├── LICENSE
├── README.md
└── DANet
    ├── data_utils.py
    ├── torch_utils.py
    ├── DANet_test.ipynb
    └── DANet_train.ipynb

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 NapLab (Columbia University)
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Attractor Network (DANet) for single-channel speech separation
2 | 
3 | This repository provides a Jupyter Notebook (.ipynb) implementation of the Deep Attractor Network (DANet) for single-channel speech separation. DANet was introduced in the following papers:
4 | 
5 | Zhuo Chen, Yi Luo, and Nima Mesgarani, [Deep attractor network for single-microphone speaker separation](https://ieeexplore.ieee.org/abstract/document/7952155)
6 | 
7 | Yi Luo, Zhuo Chen, and Nima Mesgarani, [Speaker-independent speech separation with deep attractor network](https://ieeexplore.ieee.org/abstract/document/8264702)
8 | 
9 | More information about the papers can also be found on [our lab website](http://naplab.ee.columbia.edu/danet.html).
10 | 
11 | ## Citation
12 | 
13 | If you find the scripts helpful in your research, please consider citing:
14 | 
15 |     @inproceedings{chen2017deep,
16 |       title={Deep attractor network for single-microphone speaker separation},
17 |       author={Chen, Zhuo and Luo, Yi and Mesgarani, Nima},
18 |       booktitle={Acoustics, Speech and Signal Processing (ICASSP), 2017 IEEE International Conference on},
19 |       pages={246--250},
20 |       year={2017},
21 |       organization={IEEE}
22 |     }
23 | 
24 |     @article{luo2018speaker,
25 |       title={Speaker-independent speech separation with deep attractor network},
26 |       author={Luo, Yi and Chen, Zhuo and Mesgarani, Nima},
27 |       journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
28 |       volume={26},
29 |       number={4},
30 |       pages={787--796},
31 |       year={2018},
32 |       publisher={IEEE}
33 |     }
34 | 
35 | ### Requirements
36 | - Python 3.6.4
37 | - PyTorch 0.4.1
38 | - h5py 2.7.1
39 | - scikit-learn 0.19.1
40 | - numpy 1.15.0
41 | - librosa 0.6.0
42 | - jupyter 1.0.0 or above
43 | - notebook 5.4.0 or above
44 | 
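45 | ### How it works (overview)
46 | 
47 | A brief sketch of the method as implemented in the notebooks (see `DANet_train.ipynb` and `DANet_test.ipynb` for the exact code):
48 | 
49 | 1. A stacked BLSTM maps the mixture log-magnitude spectrogram to a K-dimensional embedding `V` for every time-frequency (T-F) bin.
50 | 2. During training, one attractor per speaker is formed as the weighted average of the embeddings, `A = V'Y / sum(Y)`, where the source assignment `Y` is the ideal binary mask restricted to high-energy T-F bins.
51 | 3. The estimated masks are `softmax(VA)`; the network is trained with an MSE objective between the estimated and the Wiener-filter-like target masks, weighted by the mixture magnitude.
52 | 4. At test time no oracle assignment is available, so the attractors are estimated by K-means clustering of the embeddings instead.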
--------------------------------------------------------------------------------
/DANet/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import h5py
4 | 
5 | class WSJDataset(Dataset):
6 |     """
7 |     Wrapper for the WSJ dataset.
8 |     The dataset is saved in the HDF5 binary data format and
9 |     contains the input features, the mixture magnitude
10 |     spectrograms, the Wiener-filter-like masks used as training targets,
11 |     the ideal binary masks as the oracle source assignments,
12 |     and the weight threshold matrices for masking out low-
13 |     energy T-F bins.
14 |     """
15 | 
16 |     def __init__(self, path):
17 |         super(WSJDataset, self).__init__()
18 | 
19 |         self.h5pyLoader = h5py.File(path, 'r')
20 | 
21 |         self.infeat = self.h5pyLoader['infeat']    # input feature, shape: (num_sample, time, freq)
22 |         self.mixture = self.h5pyLoader['mix']      # mixture magnitude spectrogram, shape: (num_sample, time, freq)
23 |         self.wf = self.h5pyLoader['wf']            # Wiener-filter-like mask, shape: (num_sample, time*freq, num_spk)
24 |         self.ibm = self.h5pyLoader['ibm']          # ideal binary mask, shape: (num_sample, time*freq, num_spk)
25 |         self.weight = self.h5pyLoader['weight']    # weight threshold matrix, shape: (num_sample, time*freq, 1)
26 | 
27 |         self._len = self.infeat.shape[0]
28 | 
29 |     def __getitem__(self, index):
30 |         """
31 |         Wrap one sample into PyTorch tensors.
32 |         """
33 |         infeat_tensor = torch.from_numpy(self.infeat[index])
34 |         wf_tensor = torch.from_numpy(self.wf[index])
35 |         mixture_tensor = torch.from_numpy(self.mixture[index])
36 |         ibm_tensor = torch.from_numpy(self.ibm[index])
37 |         weight_tensor = torch.from_numpy(self.weight[index])
38 |         return infeat_tensor, wf_tensor, mixture_tensor, ibm_tensor, weight_tensor
39 | 
40 |     def __len__(self):
41 |         return self._len
42 | 
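43 | 
44 | if __name__ == '__main__':
45 |     # Minimal usage sketch (illustrative only): assumes an HDF5 file laid
46 |     # out as described in the class docstring; 'train.h5' is a placeholder path.
47 |     from torch.utils.data import DataLoader
48 |     dataset = WSJDataset('train.h5')
49 |     loader = DataLoader(dataset, batch_size=4, shuffle=True)
50 |     infeat, wf, mix, ibm, weight = next(iter(loader))
51 |     print(infeat.shape, wf.shape, mix.shape, ibm.shape, weight.shape)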
--------------------------------------------------------------------------------
/DANet/torch_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | 
7 | class MultiRNN(nn.Module):
8 |     """
9 |     Container module for multiple stacked RNN layers.
10 | 
11 |     args:
12 |         rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
13 |         input_size: int, dimension of the input feature. The input should have shape
14 |                     (batch, seq_len, input_size).
15 |         hidden_size: int, dimension of the hidden state. The corresponding output has
16 |                      shape (batch, seq_len, num_directions*hidden_size).
17 |         dropout: float, dropout rate applied between stacked RNN layers. Default is 0.
18 |         num_layers: int, number of stacked RNN layers. Default is 1.
19 |         bidirectional: bool, whether the RNN layers are bidirectional. Default is False.
20 |     """
21 | 
22 |     def __init__(self, rnn_type, input_size, hidden_size, dropout=0, num_layers=1, bidirectional=False):
23 |         super(MultiRNN, self).__init__()
24 | 
25 |         self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, num_layers, dropout=dropout,
26 |                                          batch_first=True, bidirectional=bidirectional)
27 | 
28 | 
29 |         self.rnn_type = rnn_type
30 |         self.hidden_size = hidden_size
31 |         self.num_layers = num_layers
32 |         self.num_direction = int(bidirectional) + 1
33 | 
34 |     def forward(self, input, hidden):
35 |         self.rnn.flatten_parameters()
36 |         return self.rnn(input, hidden)
37 | 
38 |     def init_hidden(self, batch_size):
39 |         weight = next(self.parameters()).data
40 |         if self.rnn_type == 'LSTM':
41 |             return (Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_()),
42 |                     Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_()))
43 |         else:
44 |             return Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_())
45 | 
46 | 
47 | class FCLayer(nn.Module):
48 |     """
49 |     Container module for a fully-connected layer.
50 | 
51 |     args:
52 |         input_size: int, dimension of the input feature. The input should have shape
53 |                     (batch, input_size).
54 |         hidden_size: int, dimension of the output. The corresponding output should
55 |                      have shape (batch, hidden_size).
56 |         nonlinearity: string, name of a torch function (e.g. 'tanh') applied to the transformation. Default is None.
57 |     """
58 | 
59 |     def __init__(self, input_size, hidden_size, nonlinearity=None):
60 |         super(FCLayer, self).__init__()
61 | 
62 |         self.input_size = input_size
63 |         self.hidden_size = hidden_size
64 |         self.FC = nn.Linear(self.input_size, self.hidden_size)
65 |         if nonlinearity:
66 |             self.nonlinearity = getattr(torch, nonlinearity)
67 |         else:
68 |             self.nonlinearity = None
69 | 
70 |         self.init_weights()
71 | 
72 |     def forward(self, input):
73 |         if self.nonlinearity is not None:
74 |             return self.nonlinearity(self.FC(input))
75 |         else:
76 |             return self.FC(input)
77 | 
78 |     def init_weights(self):
79 |         initrange = 1. / np.sqrt(self.input_size)
80 |         self.FC.bias.data.fill_(0)
81 |         self.FC.weight.data.uniform_(-initrange, initrange)
82 | 
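83 | 
84 | if __name__ == '__main__':
85 |     # Shape smoke test (illustrative only): run a random batch through
86 |     # MultiRNN and FCLayer with the sizes used in the DANet notebooks.
87 |     rnn = MultiRNN('LSTM', input_size=129, hidden_size=300,
88 |                    num_layers=2, bidirectional=True)
89 |     fc = FCLayer(600, 129*20, nonlinearity='tanh')
90 |     x = torch.randn(4, 100, 129)            # batch, seq_len, input_size
91 |     h = rnn.init_hidden(4)
92 |     out, h = rnn(x, h)                      # out: (4, 100, 600); bidirectional doubles hidden_size
93 |     v = fc(out.contiguous().view(-1, 600))  # (4*100, 129*20)
94 |     print(out.size(), v.size())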
--------------------------------------------------------------------------------
/DANet/DANet_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [
8 |     {
9 |      "name": "stderr",
10 |      "output_type": "stream",
11 |      "text": [
12 |       "/home/cong/anaconda3/envs/DP/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
13 |       "  from ._conv import register_converters as _register_converters\n"
14 |      ]
15 |     },
16 |     {
17 |      "name": "stdout",
18 |      "output_type": "stream",
19 |      "text": [
20 |       "env: CUDA_VISIBLE_DEVICES=1\n"
21 |      ]
22 |     }
23 |    ],
24 |    "source": [
25 |     "from __future__ import print_function\n",
26 |     "import argparse\n",
27 |     "\n",
28 |     "import torch\n",
29 |     "import torch.nn as nn\n",
30 |     "import torch.nn.functional as F\n",
31 |     "from torch.autograd import Variable\n",
32 |     "\n",
33 |     "import os\n",
34 |     "import numpy as np\n",
35 |     "import h5py\n",
36 |     "import time\n",
37 |     "\n",
38 |     "import torch_utils\n",
39 |     "import data_utils\n",
40 |     "\n",
41 |     "import librosa\n",
42 |     "from sklearn.cluster import KMeans"
43 |    ]
44 |   },
45 |   {
46 |    "cell_type": "code",
47 |    "execution_count": 2,
48 |    "metadata": {},
49 |    "outputs": [],
50 |    "source": [
51 |     "# global params\n",
52 |     "\n",
53 |     "parser = argparse.ArgumentParser(description='DANet')\n",
54 |     "parser.add_argument('--batch-size', type=int, default=128,\n",
55 |     "                    help='input batch size for training (default: 128)')\n",
56 |     "parser.add_argument('--epochs', type=int, default=100,\n",
57 |     "                    help='number of epochs to train (default: 100)')\n",
58 |     "parser.add_argument('--cuda', action='store_true', default=True,\n",
59 |     "                    help='enables CUDA training (default: True)')\n",
60 |     "parser.add_argument('--seed', type=int, default=20170220,\n",
61 |     "                    help='random seed (default: 20170220)')\n",
62 |     "parser.add_argument('--infeat-dim', type=int, default=129,\n",
63 |     "                    help='dimension of the input feature (default: 129)')\n",
64 |     "parser.add_argument('--outfeat-dim', type=int, default=20,\n",
65 |     "                    help='dimension of the embedding (default: 20)')\n",
66 |     "parser.add_argument('--threshold', type=float, default=0.9,\n",
67 |     "                    help='the weight threshold (default: 0.9)')\n",
68 |     "parser.add_argument('--seq-len', type=int, default=100,\n",
69 |     "                    help='length of the sequence (default: 100)')\n",
70 |     "parser.add_argument('--log-step', type=int, default=100,\n",
71 |     "                    help='how many batches to wait before logging training status (default: 100)')\n",
72 |     "parser.add_argument('--lr', type=float, default=1e-3,\n",
73 |     "                    help='learning rate (default: 1e-3)')\n",
74 |     "parser.add_argument('--num-layers', type=int, default=4,\n",
75 |     "                    help='number of stacked RNN layers (default: 4)')\n",
76 |     "parser.add_argument('--bidirectional', action='store_true', default=True,\n",
77 |     "                    help='whether to use bidirectional RNN layers (default: True)')\n",
78 |     "parser.add_argument('--val-save', type=str, default='model.pt',\n",
79 |     "                    help='path to save the best model')\n",
80 |     "\n",
81 |     "args, _ = parser.parse_known_args()\n",
82 |     "args.cuda = args.cuda and torch.cuda.is_available()\n",
83 |     "args.num_direction = int(args.bidirectional)+1\n",
84 |     "\n",
85 |     "torch.manual_seed(args.seed)\n",
86 |     "if args.cuda:\n",
87 |     "    torch.cuda.manual_seed(args.seed)\n",
88 |     "    kwargs = {'num_workers': 1, 'pin_memory': True}\n",
89 |     "else:\n",
90 |     "    kwargs = {}\n",
91 |     "\n",
92 |     "# STFT parameters\n",
93 |     "sr = 8000\n",
94 |     "nfft = 256\n",
95 |     "nhop = 64\n",
96 |     "nspk = 2"
97 |    ]
98 |   },
99 |   {
100 |    "cell_type": "code",
101 |    "execution_count": 4,
102 |    "metadata": {},
103 |    "outputs": [],
104 |    "source": [
105 |     "# define model\n",
106 |     "\n",
107 |     "class DANet(nn.Module):\n",
108 |     "    def __init__(self):\n",
109 |     "        super(DANet, self).__init__()\n",
110 |     "\n",
111 |     "        self.rnn = torch_utils.MultiRNN('LSTM', args.infeat_dim, 300,\n",
112 |     "                                        num_layers=args.num_layers,\n",
113 |     "                                        bidirectional=args.bidirectional)\n",
114 |     "        self.FC = torch_utils.FCLayer(600, args.infeat_dim*args.outfeat_dim, nonlinearity='tanh')\n",
115 |     "\n",
116 |     "        self.infeat_dim = args.infeat_dim\n",
117 |     "        self.outfeat_dim = args.outfeat_dim\n",
118 |     "        self.eps = 1e-8\n",
119 |     "\n",
120 |     "    def forward(self, input, hidden):\n",
121 |     "        \"\"\"\n",
122 |     "        input: the input feature;\n",
123 |     "               shape: (B, T, F)\n",
124 |     "\n",
125 |     "        hidden: the initial hidden state in the LSTM layers.\n",
126 |     "        \"\"\"\n",
127 |     "\n",
128 |     "        seq_len = input.size(1)\n",
129 |     "\n",
130 |     "        # generate the embeddings (V) with the LSTM layers\n",
131 |     "        LSTM_output, hidden = self.rnn(input, hidden)\n",
132 |     "        LSTM_output = LSTM_output.contiguous().view(-1, LSTM_output.size(2))  # B*T, H\n",
133 |     "        V = self.FC(LSTM_output)  # B*T, F*K\n",
134 |     "        V = V.view(-1, seq_len*self.infeat_dim, self.outfeat_dim)  # B, T*F, K\n",
135 |     "\n",
136 |     "        return V\n",
137 |     "\n",
138 |     "    def init_hidden(self, batch_size):\n",
139 |     "        return self.rnn.init_hidden(batch_size)"
140 |    ]
141 |   },
142 |   {
143 |    "cell_type": "code",
144 |    "execution_count": 5,
145 |    "metadata": {},
146 |    "outputs": [
147 |     {
148 |      "data": {
149 |       "text/plain": [
150 |        "DANet(\n",
151 |        "  (rnn): MultiRNN(\n",
152 |        "    (rnn): LSTM(129, 300, num_layers=4, batch_first=True, bidirectional=True)\n",
153 |        "  )\n",
154 |        "  (FC): FCLayer(\n",
155 |        "    (FC): Linear(in_features=600, out_features=2580, bias=True)\n",
156 |        "  )\n",
157 |        ")"
158 |       ]
159 |      },
160 |      "execution_count": 5,
161 |      "metadata": {},
162 |      "output_type": "execute_result"
163 |     }
164 |    ],
165 |    "source": [
166 |     "# load model\n",
167 |     "model = DANet()\n",
168 |     "model.load_state_dict(torch.load('model.pt'))\n",
169 |     "\n",
170 |     "if args.cuda:\n",
171 |     "    model.cuda()\n",
172 |     "model.eval()"
173 |    ]
174 |   },
175 |   {
176 |    "cell_type": "code",
177 |    "execution_count": null,
178 |    "metadata": {},
179 |    "outputs": [],
180 |    "source": [
181 |     "# load mixture data\n",
182 |     "mix, _ = librosa.load('your_path_to_mixture_audio', sr=sr)\n",
183 |     "\n",
184 |     "# STFT\n",
185 |     "mix_spec = librosa.stft(mix, nfft, nhop)  # F, T\n",
186 |     "mix_phase = np.angle(mix_spec)  # F, T\n",
187 |     "mix_spec = np.abs(mix_spec)  # F, T\n",
188 |     "\n",
189 |     "# magnitude spectrogram on a dB scale\n",
190 |     "infeat = 20*np.log10(mix_spec.T + 1e-8)  # small eps avoids log(0) in silent bins\n",
191 |     "infeat = infeat[None, :, :]  # add the batch dimension\n",
192 |     "# optional: normalize the input feature with your pre-calculated\n",
193 |     "# statistics of the training set\n",
194 |     "\n",
195 |     "batch_infeat = Variable(torch.from_numpy(infeat)).contiguous()\n",
196 |     "if args.cuda:\n",
197 |     "    batch_infeat = batch_infeat.cuda()\n",
198 |     "\n",
199 |     "with torch.no_grad():\n",
200 |     "    hidden = model.init_hidden(batch_infeat.size(0))\n",
201 |     "    embeddings = model(batch_infeat, hidden)\n",
202 |     "\n",
203 |     "# estimate attractors via K-means\n",
204 |     "embeddings = embeddings[0].data.cpu().numpy()  # T*F, K\n",
205 |     "kmeans_model = KMeans(n_clusters=nspk, random_state=0).fit(embeddings.astype('float64'))\n",
206 |     "attractor = kmeans_model.cluster_centers_  # nspk, K\n",
207 |     "\n",
208 |     "# estimate masks\n",
209 |     "embeddings = torch.from_numpy(embeddings).float()  # T*F, K\n",
210 |     "attractor = torch.from_numpy(attractor.T).float()  # K, nspk\n",
211 |     "if args.cuda:\n",
212 |     "    embeddings = embeddings.cuda()\n",
213 |     "    attractor = attractor.cuda()\n",
214 |     "\n",
215 |     "mask = F.softmax(torch.mm(embeddings, attractor), dim=1)  # T*F, nspk\n",
216 |     "mask = mask.data.cpu().numpy()\n",
217 |     "\n",
218 |     "mask_1 = mask[:,0].reshape(-1, args.infeat_dim).T\n",
219 |     "mask_2 = mask[:,1].reshape(-1, args.infeat_dim).T\n",
220 |     "\n",
221 |     "# masking the mixture magnitude spectrogram\n",
222 |     "s1_spec = (mix_spec * mask_1) * np.exp(1j*mix_phase)\n",
223 |     "s2_spec = (mix_spec * mask_2) * np.exp(1j*mix_phase)\n",
224 |     "\n",
225 |     "# reconstruct waveforms\n",
226 |     "res_1 = librosa.istft(s1_spec, hop_length=nhop, win_length=nfft)\n",
227 |     "res_2 = librosa.istft(s2_spec, hop_length=nhop, win_length=nfft)\n",
228 |     "\n",
229 |     "if len(res_1) < len(mix):\n",
230 |     "    # pad zeros at the end\n",
231 |     "    res_1 = np.concatenate([res_1, np.zeros(len(mix)-len(res_1))])\n",
232 |     "    res_2 = np.concatenate([res_2, np.zeros(len(mix)-len(res_2))])\n",
233 |     "else:\n",
234 |     "    res_1 = res_1[:len(mix)]\n",
235 |     "    res_2 = res_2[:len(mix)]"
236 |    ]
237 |   }
238 |  ],
239 |  "metadata": {
240 |   "kernelspec": {
241 |    "display_name": "Python 3",
242 |    "language": "python",
243 |    "name": "python3"
244 |   },
245 |   "language_info": {
246 |    "codemirror_mode": {
247 |     "name": "ipython",
248 |     "version": 3
249 |    },
250 |    "file_extension": ".py",
251 |    "mimetype": "text/x-python",
252 |    "name": "python",
253 |    "nbconvert_exporter": "python",
254 |    "pygments_lexer": "ipython3",
255 |    "version": "3.6.4"
256 |   }
257 |  },
258 |  "nbformat": 4,
259 |  "nbformat_minor": 2
260 | }
261 | 
--------------------------------------------------------------------------------
/DANet/DANet_train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "from __future__ import print_function\n",
10 |     "import argparse\n",
11 |     "\n",
12 |     "import torch\n",
13 |     "import torch.nn as nn\n",
14 |     "import torch.nn.functional as F\n",
15 |     "from torch.autograd import Variable\n",
16 |     "from torch.utils.data import DataLoader\n",
17 |     "import torch.optim as optim\n",
18 |     "\n",
19 |     "import os\n",
20 |     "import numpy as np\n",
21 |     "import h5py\n",
22 |     "import time\n",
23 |     "\n",
24 |     "import torch_utils\n",
25 |     "import data_utils"
26 |    ]
27 |   },
28 |   {
29 |    "cell_type": "code",
30 |    "execution_count": null,
31 |    "metadata": {},
32 |    "outputs": [],
33 |    "source": [
34 |     "# global params\n",
35 |     "\n",
36 |     "parser = argparse.ArgumentParser(description='DANet')\n",
37 |     "parser.add_argument('--batch-size', type=int, default=128,\n",
38 |     "                    help='input batch size for training (default: 128)')\n",
39 |     "parser.add_argument('--epochs', type=int, default=100,\n",
40 |     "                    help='number of epochs to train (default: 100)')\n",
41 |     "parser.add_argument('--cuda', action='store_true', default=True,\n",
42 |     "                    help='enables CUDA training (default: True)')\n",
43 |     "parser.add_argument('--seed', type=int, default=20170220,\n",
44 |     "                    help='random seed (default: 20170220)')\n",
45 |     "parser.add_argument('--infeat-dim', type=int, default=129,\n",
46 |     "                    help='dimension of the input feature (default: 129)')\n",
47 |     "parser.add_argument('--outfeat-dim', type=int, default=20,\n",
48 |     "                    help='dimension of the embedding (default: 20)')\n",
49 |     "parser.add_argument('--threshold', type=float, default=0.9,\n",
50 |     "                    help='the weight threshold (default: 0.9)')\n",
51 |     "parser.add_argument('--seq-len', type=int, default=100,\n",
52 |     "                    help='length of the sequence (default: 100)')\n",
53 |     "parser.add_argument('--log-step', type=int, default=100,\n",
54 |     "                    help='how many batches to wait before logging training status (default: 100)')\n",
55 |     "parser.add_argument('--lr', type=float, default=1e-3,\n",
56 |     "                    help='learning rate (default: 1e-3)')\n",
57 |     "parser.add_argument('--num-layers', type=int, default=4,\n",
58 |     "                    help='number of stacked RNN layers (default: 4)')\n",
59 |     "parser.add_argument('--bidirectional', action='store_true', default=True,\n",
60 |     "                    help='whether to use bidirectional RNN layers (default: True)')\n",
61 |     "parser.add_argument('--val-save', type=str, default='model.pt',\n",
62 |     "                    help='path to save the best model')\n",
63 |     "\n",
64 |     "args, _ = parser.parse_known_args()\n",
65 |     "args.cuda = args.cuda and torch.cuda.is_available()\n",
66 |     "args.num_direction = int(args.bidirectional)+1\n",
67 |     "\n",
68 |     "torch.manual_seed(args.seed)\n",
69 |     "if args.cuda:\n",
70 |     "    torch.cuda.manual_seed(args.seed)\n",
71 |     "    kwargs = {'num_workers': 1, 'pin_memory': True}\n",
72 |     "else:\n",
73 |     "    kwargs = {}\n",
74 |     "\n",
75 |     "# training and validation dataset paths\n",
76 |     "training_data_path = 'your_path_to_training_set'\n",
77 |     "validation_data_path = 'your_path_to_validation_set'"
78 |    ]
79 |   },
80 |   {
81 |    "cell_type": "code",
82 |    "execution_count": null,
83 |    "metadata": {},
84 |    "outputs": [],
85 |    "source": [
86 |     "# define data loaders\n",
87 |     "\n",
88 |     "train_loader = DataLoader(data_utils.WSJDataset(training_data_path),\n",
89 |     "                          batch_size=args.batch_size,\n",
90 |     "                          shuffle=True,\n",
91 |     "                          **kwargs)\n",
92 |     "\n",
93 |     "validation_loader = DataLoader(data_utils.WSJDataset(validation_data_path),\n",
94 |     "                               batch_size=args.batch_size,\n",
95 |     "                               shuffle=False,\n",
96 |     "                               **kwargs)"
97 |    ]
98 |   },
99 |   {
100 |    "cell_type": "code",
101 |    "execution_count": null,
102 |    "metadata": {},
103 |    "outputs": [],
104 |    "source": [
105 |     "# define model\n",
106 |     "\n",
107 |     "class DANet(nn.Module):\n",
108 |     "    def __init__(self):\n",
109 |     "        super(DANet, self).__init__()\n",
110 |     "\n",
111 |     "        self.rnn = torch_utils.MultiRNN('LSTM', args.infeat_dim, 300,\n",
112 |     "                                        num_layers=args.num_layers,\n",
113 |     "                                        bidirectional=args.bidirectional)\n",
114 |     "        self.FC = torch_utils.FCLayer(600, args.infeat_dim*args.outfeat_dim, nonlinearity='tanh')\n",
115 |     "\n",
116 |     "        self.infeat_dim = args.infeat_dim\n",
117 |     "        self.outfeat_dim = args.outfeat_dim\n",
118 |     "        self.eps = 1e-8\n",
119 |     "\n",
120 |     "    def forward(self, input, ibm, weight, hidden):\n",
121 |     "        \"\"\"\n",
122 |     "        input: the input feature;\n",
123 |     "               shape: (B, T, F)\n",
124 |     "\n",
125 |     "        ibm: the ideal binary mask used for calculating the\n",
126 |     "             ideal attractors;\n",
127 |     "             shape: (B, T*F, nspk)\n",
128 |     "\n",
129 |     "        weight: the binary energy threshold matrix for masking\n",
130 |     "                out low-energy T-F bins;\n",
131 |     "                shape: (B, T*F, 1)\n",
132 |     "\n",
133 |     "        hidden: the initial hidden state in the LSTM layers.\n",
134 |     "        \"\"\"\n",
135 |     "\n",
136 |     "        seq_len = input.size(1)\n",
137 |     "\n",
138 |     "        # generate the embeddings (V) with the LSTM layers\n",
139 |     "        LSTM_output, hidden = self.rnn(input, hidden)\n",
140 |     "        LSTM_output = LSTM_output.contiguous().view(-1, LSTM_output.size(2))  # B*T, H\n",
141 |     "        V = self.FC(LSTM_output)  # B*T, F*K\n",
142 |     "        V = V.view(-1, seq_len*self.infeat_dim, self.outfeat_dim)  # B, T*F, K\n",
143 |     "\n",
144 |     "        # calculate the ideal attractors:\n",
145 |     "        # first calculate the source assignment matrix Y\n",
" Y = ibm * weight.expand_as(ibm) # B, T*F, nspk\n", 147 | " \n", 148 | " # attractors are the weighted average of the embeddings\n", 149 | " # calculated by V and Y\n", 150 | " V_Y = torch.bmm(torch.transpose(V, 1,2), Y) # B, K, nspk\n", 151 | " sum_Y = torch.sum(Y, 1, keepdim=True).expand_as(V_Y) # B, K, nspk\n", 152 | " attractor = V_Y / (sum_Y + self.eps) # B, K, 2\n", 153 | " \n", 154 | " # calculate the distance bewteen embeddings and attractors\n", 155 | " # and generate the masks\n", 156 | " dist = V.bmm(attractor) # B, T*F, nspk\n", 157 | " mask = F.softmax(dist, dim=2) # B, T*F, nspk\n", 158 | " \n", 159 | " return mask, hidden\n", 160 | " \n", 161 | " def init_hidden(self, batch_size):\n", 162 | " return self.rnn.init_hidden(batch_size)\n", 163 | " \n", 164 | " \n", 165 | "def objective(mixture, wfm, estimated_mask):\n", 166 | " \"\"\"\n", 167 | " MSE as the training objective. The mask estimation loss is calculated.\n", 168 | " You can also change it into the spectrogram estimation loss, which is \n", 169 | " to calculate the MSE between the clean source spectrograms and the \n", 170 | " masked mixture spectrograms.\n", 171 | " \n", 172 | " mixture: the spectrogram of the mixture;\n", 173 | " shape: (B, T, F)\n", 174 | " \n", 175 | " wfm: the target masks, which are the wiener-filter like masks here;\n", 176 | " shape: (B, T*F, nspk)\n", 177 | " \n", 178 | " estimated_mask: the estimated masks generated by the network;\n", 179 | " shape: (B, T*F, nspk)\n", 180 | " \"\"\"\n", 181 | " \n", 182 | " loss = mixture.expand(mixture.size(0), mixture.size(1), wfm.size(2)) * (wfm - estimated_mask)\n", 183 | " loss = loss.view(-1, loss.size(1)*loss.size(2))\n", 184 | " \n", 185 | " return torch.mean(torch.sum(torch.pow(loss, 2), 1))\n", 186 | " \n", 187 | "# define the model and the optimizer\n", 188 | "model = DANet()\n", 189 | "if args.cuda:\n", 190 | " model.cuda()\n", 191 | "\n", 192 | "current_lr = args.lr\n", 193 | "optimizer = optim.Adam(model.parameters(), lr=args.lr)\n", 194 | "scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)\n", 195 | "scheduler.step()" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "# function for training and validation\n", 205 | "\n", 206 | "def train(epoch):\n", 207 | " start_time = time.time()\n", 208 | " model.train()\n", 209 | " train_loss = 0.\n", 210 | " \n", 211 | " # data loading\n", 212 | " # see data_utils.py for dataloader details\n", 213 | " for batch_idx, data in enumerate(train_loader):\n", 214 | " # batch_infeat is the input feature\n", 215 | " batch_infeat = Variable(data[0]).contiguous()\n", 216 | " \n", 217 | " # wiener-filter like mask as the training target\n", 218 | " batch_wfm = Variable(data[1]).contiguous()\n", 219 | " \n", 220 | " # spectrogram of mixture, used in objective\n", 221 | " batch_mix = Variable(data[2]).contiguous()\n", 222 | " \n", 223 | " # ideal binary mask as the ideal source assignment\n", 224 | " # used during the calculation of attractors\n", 225 | " batch_ibm = Variable(data[3]).contiguous()\n", 226 | " \n", 227 | " # energy threshold matrix calculated from the mixture spectrogram\n", 228 | " batch_weight = Variable(data[4]).contiguous()\n", 229 | " \n", 230 | " if args.cuda:\n", 231 | " batch_infeat = batch_infeat.cuda() # B, T, F\n", 232 | " batch_wfm = batch_wfm.cuda() # B, T*F, nspk\n", 233 | " batch_mix = batch_mix.cuda() # B, T, F\n", 234 | " batch_ibm = batch_ibm.cuda() # B, T*F, nspk\n", 
235 |     "            batch_weight = batch_weight.cuda()  # B, T*F, 1\n",
236 |     "\n",
237 |     "        # training\n",
238 |     "        hidden = model.init_hidden(batch_infeat.size(0))\n",
239 |     "        optimizer.zero_grad()\n",
240 |     "        estimated_mask, hidden = model(batch_infeat, batch_ibm, batch_weight, hidden)\n",
241 |     "\n",
242 |     "        loss = objective(batch_mix, batch_wfm, estimated_mask)\n",
243 |     "        loss.backward()\n",
244 |     "        train_loss += loss.data.item()\n",
245 |     "        optimizer.step()\n",
246 |     "\n",
247 |     "        # output logs\n",
248 |     "        if (batch_idx+1) % args.log_step == 0:\n",
249 |     "            elapsed = time.time() - start_time\n",
250 |     "            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} |'.format(\n",
251 |     "                epoch, batch_idx+1, len(train_loader),\n",
252 |     "                elapsed * 1000 / (batch_idx+1), train_loss / (batch_idx+1)))\n",
253 |     "\n",
254 |     "    train_loss /= (batch_idx+1)\n",
255 |     "    print('-' * 99)\n",
256 |     "    print(' | end of training epoch {:3d} | time: {:5.2f}s | training loss {:5.2f} |'.format(\n",
257 |     "        epoch, (time.time() - start_time), train_loss))\n",
258 |     "\n",
259 |     "    return train_loss\n",
260 |     "\n",
261 |     "def validate(epoch):\n",
262 |     "    start_time = time.time()\n",
263 |     "    model.eval()\n",
264 |     "    validation_loss = 0.\n",
265 |     "\n",
266 |     "    # data loading\n",
267 |     "    for batch_idx, data in enumerate(validation_loader):\n",
268 |     "        batch_infeat = Variable(data[0]).contiguous()\n",
269 |     "        batch_wfm = Variable(data[1]).contiguous()\n",
270 |     "        batch_mix = Variable(data[2]).contiguous()\n",
271 |     "        batch_ibm = Variable(data[3]).contiguous()\n",
272 |     "        batch_weight = Variable(data[4]).contiguous()\n",
273 |     "\n",
274 |     "        if args.cuda:\n",
275 |     "            batch_infeat = batch_infeat.cuda()\n",
276 |     "            batch_wfm = batch_wfm.cuda()\n",
277 |     "            batch_mix = batch_mix.cuda()\n",
278 |     "            batch_ibm = batch_ibm.cuda()\n",
279 |     "            batch_weight = batch_weight.cuda()\n",
280 |     "\n",
281 |     "        # mask estimation\n",
282 |     "        with torch.no_grad():\n",
283 |     "            hidden = model.init_hidden(batch_infeat.size(0))\n",
284 |     "            estimated_mask, hidden = model(batch_infeat, batch_ibm, batch_weight, hidden)\n",
285 |     "\n",
286 |     "        loss = objective(batch_mix, batch_wfm, estimated_mask)\n",
287 |     "        validation_loss += loss.data.item()\n",
288 |     "\n",
289 |     "    validation_loss /= (batch_idx+1)\n",
290 |     "    print(' | end of validation epoch {:3d} | time: {:5.2f}s | validation loss {:5.2f} |'.format(\n",
291 |     "        epoch, (time.time() - start_time), validation_loss))\n",
292 |     "    print('-' * 99)\n",
293 |     "\n",
294 |     "    return validation_loss"
295 |    ]
296 |   },
297 |   {
298 |    "cell_type": "code",
299 |    "execution_count": null,
300 |    "metadata": {
301 |     "scrolled": true
302 |    },
303 |    "outputs": [],
304 |    "source": [
305 |     "# main function\n",
306 |     "\n",
307 |     "training_loss = []\n",
308 |     "validation_loss = []\n",
309 |     "decay_cnt = 0\n",
310 |     "for epoch in range(1, args.epochs + 1):\n",
311 |     "    if args.cuda:\n",
312 |     "        model.cuda()\n",
313 |     "    training_loss.append(train(epoch))\n",
314 |     "    validation_loss.append(validate(epoch))\n",
315 |     "    if training_loss[-1] == np.min(training_loss):\n",
316 |     "        print(' Best training model found.')\n",
317 |     "        print('-' * 99)\n",
318 |     "    if validation_loss[-1] == np.min(validation_loss):\n",
319 |     "        # save the current best model\n",
320 |     "        with open(args.val_save, 'wb') as f:\n",
321 |     "            torch.save(model.cpu().state_dict(), f)\n",
322 |     "        print(' Best validation model found and saved.')\n",
323 |     "        print('-' * 99)\n",
324 |     "    decay_cnt += 1\n",
325 |     "    # lr decay\n",
326 |     "    if np.min(training_loss) not in training_loss[-3:] and decay_cnt >= 3:\n",
327 |     "        scheduler.step()\n",
328 |     "        decay_cnt = 0\n",
329 |     "        print(' Learning rate decreased.')\n",
330 |     "        print('-' * 99)"
331 |    ]
332 |   }
333 |  ],
334 |  "metadata": {
335 |   "kernelspec": {
336 |    "display_name": "Python 3",
337 |    "language": "python",
338 |    "name": "python3"
339 |   },
340 |   "language_info": {
341 |    "codemirror_mode": {
342 |     "name": "ipython",
343 |     "version": 3
344 |    },
345 |    "file_extension": ".py",
346 |    "mimetype": "text/x-python",
347 |    "name": "python",
348 |    "nbconvert_exporter": "python",
349 |    "pygments_lexer": "ipython3",
350 |    "version": "3.6.4"
351 |   }
352 |  },
353 |  "nbformat": 4,
354 |  "nbformat_minor": 1
355 | }
356 | 
--------------------------------------------------------------------------------