├── .gitattributes ├── .gitignore ├── README.md ├── asset ├── dev_tuples.csv └── train_tuples.csv ├── img └── voice-filter.png ├── nnet ├── conf.py ├── data_simulate.py ├── libs │ ├── __init__.py │ ├── audio.py │ ├── dataset.py │ ├── evaluator.py │ └── trainer.py ├── nnet.py ├── separate.py ├── stft.py └── train.py ├── requirements.txt └── train.sh /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto 2 | 3 | *.py text eol=lf 4 | *.sh text eol=lf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.vscode/ 2 | 3 | __pycache__ 4 | .py[cod] 5 | 6 | /nnet/profile_nnet.py 7 | /dilated_cnn.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## VoiceFilter 2 | 3 | A Pytorch implementation of Google's [VoiceFilter](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/1101.pdf) System 4 | 5 | ![](img/voice-filter.png) 6 | 7 | ### Usage 8 | 9 | 1. Data simulation 10 | ```shell 11 | ./nnet/data_simulate.py --dump-dir simu/train /path/to/librispeech/train.scp asset/train_tuples.csv 12 | ./nnet/data_simulate.py --dump-dir simu/dev /path/to/librispeech/dev.scp asset/dev_tuples.csv 13 | ``` 14 | 2. Speaker embedding (I used public xvector from [here](http://kaldi-asr.org/models/m7)) 15 | 3. Data prepare 16 | 17 | Prepare data as `{mix,ref,emb}.scp` and the format of scp file follows Kaldi's recipe, egs ` ` pair in each line. 18 | 4. Confugure `nnet/conf.py` and train the model (see `train.sh` for details). 19 | 5. Use `nnet/separate.py` for inference. 20 | 21 | ### Note 22 | 23 | 1. I used Si-SNR loss instead of MSE of spectrogram, which could achieve better perfermance. 24 | -------------------------------------------------------------------------------- /img/voice-filter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funcwj/voice-filter/9daa53b74c4aa821ce3e4376b298b119b168f263/img/voice-filter.png -------------------------------------------------------------------------------- /nnet/conf.py: -------------------------------------------------------------------------------- 1 | # model config 2 | 3 | fs = 16000 4 | frame_len = 512 5 | frame_hop = 256 6 | 7 | nnet_conf = { 8 | "frame_len": frame_len, 9 | "frame_hop": frame_hop, 10 | "log_mag": False, 11 | "mvn_mag": False, 12 | "lstm_dim": 400, 13 | "linear_dim": 600, 14 | "l2_norm": True, 15 | "round_pow_of_two": True, 16 | "embedding_dim": 512, 17 | "non_linear": "relu" 18 | } 19 | 20 | # trainer config 21 | adam_kwargs = { 22 | "lr": 1e-3, 23 | "weight_decay": 1e-5, 24 | } 25 | 26 | trainer_conf = { 27 | "optimizer": "adam", 28 | "optimizer_kwargs": adam_kwargs, 29 | "logging_period": 200, 30 | "gradient_clip": 10, 31 | "min_lr": 1e-8, 32 | "patience": 1, 33 | "factor": 0.5 34 | } 35 | 36 | train_dir = "data/train/" 37 | dev_dir = "data/dev/" 38 | 39 | train_data = { 40 | "sr": fs, 41 | "mix_scp": train_dir + "mix.scp", 42 | "ref_scp": train_dir + "ref.scp", 43 | "emb_scp": train_dir + "emb.scp" 44 | } 45 | 46 | dev_data = { 47 | "sr": fs, 48 | "mix_scp": dev_dir + "mix.scp", 49 | "ref_scp": dev_dir + "ref.scp", 50 | "emb_scp": dev_dir + "emb.scp" 51 | } 52 | -------------------------------------------------------------------------------- /nnet/data_simulate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # wujian@2019 4 | 5 | import os 6 | import csv 7 | import random 8 | import argparse 9 | 10 | import tqdm 11 | import numpy as np 12 | from libs.audio import WaveReader, write_wav 13 | 14 | 15 | def mix_audio(src, itf): 16 | """ 17 | According to the paper, seems they do not scale speakers via SNRs 18 | """ 19 | min_len = min(src.size, itf.size) 20 | src_beg = random.randint(0, src.size - min_len) 21 | itf_beg = random.randint(0, itf.size - min_len) 22 | src_seg = src[src_beg:src_beg + min_len] 23 | itf_seg = itf[itf_beg:itf_beg + min_len] 24 | mix_seg = src_seg + itf_seg 25 | scale = random.uniform(0.5, 0.9) / np.max(np.abs(mix_seg)) 26 | return src_seg * scale, mix_seg * scale 27 | 28 | 29 | def run(args): 30 | if args.dump_dir: 31 | os.makedirs(args.dump_dir, exist_ok=True) 32 | wave_reader = WaveReader(args.wav_scp) 33 | with open(os.path.join(args.dump_dir, "emb.key"), "w") as emb: 34 | with open(args.csv, "r") as f: 35 | reader = csv.reader(f) 36 | for ids in tqdm.tqdm(reader): 37 | src_id, ref_id, itf_id = ids 38 | emb.write("{}\t{}\n".format("_".join(ids), ref_id)) 39 | src = wave_reader[src_id] 40 | itf = wave_reader[itf_id] 41 | src, mix = mix_audio(src, itf) 42 | write_wav( 43 | os.path.join(args.dump_dir, 44 | "src/{}.wav".format("_".join(ids))), src) 45 | write_wav( 46 | os.path.join(args.dump_dir, 47 | "mix/{}.wav".format("_".join(ids))), mix) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser( 52 | description="Command to simulate data for VoiceFilter training", 53 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 54 | parser.add_argument("wav_scp", 55 | type=str, 56 | help="Rspecifier of wave files for " 57 | "Librispeech dataset") 58 | parser.add_argument("csv", 59 | type=str, 60 | help="CSV files obtained from " 61 | "https://github.com/google/speaker-id") 62 | parser.add_argument("--dump-dir", 63 | type=str, 64 | default="voice_data", 65 | help="Directory of output data triplet") 66 | args = parser.parse_args() 67 | run(args) 68 | -------------------------------------------------------------------------------- /nnet/libs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funcwj/voice-filter/9daa53b74c4aa821ce3e4376b298b119b168f263/nnet/libs/__init__.py -------------------------------------------------------------------------------- /nnet/libs/audio.py: -------------------------------------------------------------------------------- 1 | # wujian@2018 2 | 3 | import os 4 | import numpy as np 5 | 6 | import scipy.io.wavfile as wf 7 | 8 | MAX_INT16 = np.iinfo(np.int16).max 9 | 10 | 11 | def _parse_script(scp_path, value_processor=lambda x: x, num_tokens=2): 12 | """ 13 | Parse kaldi's script(.scp) file 14 | If num_tokens >= 2, function will check token number 15 | """ 16 | scp_dict = dict() 17 | line = 0 18 | with open(scp_path, "r") as f: 19 | for raw_line in f: 20 | scp_tokens = raw_line.strip().split() 21 | line += 1 22 | if num_tokens >= 2 and len(scp_tokens) != num_tokens or len( 23 | scp_tokens) < 2: 24 | raise RuntimeError( 25 | "For {}, format error in line[{:d}]: {}".format( 26 | scp_path, line, raw_line)) 27 | if num_tokens == 2: 28 | key, value = scp_tokens 29 | else: 30 | key, value = scp_tokens[0], scp_tokens[1:] 31 | if key in scp_dict: 32 | raise ValueError("Duplicated key \'{0}\' exists in {1}".format( 33 | key, scp_path)) 34 | scp_dict[key] = value_processor(value) 35 | return scp_dict 36 | 37 | 38 | def read_wav(fname, normalize=True, return_rate=False): 39 | """ 40 | Read wave files using scipy.io.wavfile(support multi-channel) 41 | """ 42 | # samps_int16: N x C or N 43 | # N: number of samples 44 | # C: number of channels 45 | samp_rate, samps_int16 = wf.read(fname) 46 | # N x C => C x N 47 | samps = samps_int16.astype(np.float32) 48 | # tranpose because I used to put channel axis first 49 | if samps.ndim != 1: 50 | samps = np.transpose(samps) 51 | # normalize like MATLAB and librosa 52 | if normalize: 53 | samps = samps / MAX_INT16 54 | if return_rate: 55 | return samp_rate, samps 56 | return samps 57 | 58 | 59 | def write_wav(fname, samps, sr=16000, normalize=True): 60 | """ 61 | Write wav files in int16, support single/multi-channel 62 | """ 63 | if normalize: 64 | samps = samps * MAX_INT16 65 | # scipy.io.wavfile.write could write single/multi-channel files 66 | # for multi-channel, accept ndarray [Nsamples, Nchannels] 67 | if samps.ndim != 1 and samps.shape[0] < samps.shape[1]: 68 | samps = np.transpose(samps) 69 | samps = np.squeeze(samps) 70 | # same as MATLAB and kaldi 71 | samps_int16 = samps.astype(np.int16) 72 | fdir = os.path.dirname(fname) 73 | if fdir: 74 | os.makedirs(fdir, exist_ok=True) 75 | # NOTE: librosa 0.6.0 seems could not write non-float narray 76 | # so use scipy.io.wavfile instead 77 | wf.write(fname, sr, samps_int16) 78 | 79 | 80 | class Reader(object): 81 | """ 82 | Basic Reader Class 83 | """ 84 | 85 | def __init__(self, scp_path, value_processor=lambda x: x): 86 | self.index_dict = _parse_script(scp_path, 87 | value_processor=value_processor, 88 | num_tokens=2) 89 | self.index_keys = list(self.index_dict.keys()) 90 | 91 | def _load(self, key): 92 | # return path 93 | return self.index_dict[key] 94 | 95 | # number of utterance 96 | def __len__(self): 97 | return len(self.index_dict) 98 | 99 | # avoid key error 100 | def __contains__(self, key): 101 | return key in self.index_dict 102 | 103 | # sequential index 104 | def __iter__(self): 105 | for key in self.index_keys: 106 | yield key, self._load(key) 107 | 108 | # random index, support str/int as index 109 | def __getitem__(self, index): 110 | if type(index) not in [int, str]: 111 | raise IndexError("Unsupported index type: {}".format(type(index))) 112 | if type(index) == int: 113 | # from int index to key 114 | num_utts = len(self.index_keys) 115 | if index >= num_utts or index < 0: 116 | raise KeyError( 117 | "Interger index out of range, {:d} vs {:d}".format( 118 | index, num_utts)) 119 | index = self.index_keys[index] 120 | if index not in self.index_dict: 121 | raise KeyError("Missing utterance {}!".format(index)) 122 | return self._load(index) 123 | 124 | 125 | class WaveReader(Reader): 126 | """ 127 | Sequential/Random Reader for single channel wave 128 | Format of wav.scp follows Kaldi's definition: 129 | key1 /path/to/wav 130 | ... 131 | """ 132 | 133 | def __init__(self, wav_scp, sr=None, normalize=True): 134 | super(WaveReader, self).__init__(wav_scp) 135 | self.sr = sr 136 | self.normalize = normalize 137 | 138 | def _load(self, key): 139 | # return C x N or N 140 | sr, samps = read_wav(self.index_dict[key], 141 | normalize=self.normalize, 142 | return_rate=True) 143 | # if given samp_rate, check it 144 | if self.sr is not None and sr != self.sr: 145 | raise RuntimeError("SampleRate mismatch: {:d} vs {:d}".format( 146 | sr, self.sr)) 147 | return samps -------------------------------------------------------------------------------- /nnet/libs/dataset.py: -------------------------------------------------------------------------------- 1 | # wujian@2018 2 | 3 | import random 4 | import torch as th 5 | import numpy as np 6 | 7 | from torch.utils.data.dataloader import default_collate 8 | from kaldi_python_io import ScriptReader 9 | 10 | from .audio import Reader, WaveReader 11 | 12 | 13 | def make_dataloader(train=True, 14 | data_kwargs=None, 15 | chunk_size=32000, 16 | batch_size=16, 17 | cache_size=32): 18 | perutt_loader = PeruttLoader(shuffle=train, **data_kwargs) 19 | return DataLoader(perutt_loader, 20 | train=train, 21 | chunk_size=chunk_size, 22 | batch_size=batch_size, 23 | cache_size=cache_size) 24 | 25 | 26 | class NumpyReader(Reader): 27 | """ 28 | Sequential/Random Reader for numpy's ndarray(*.npy) file 29 | """ 30 | 31 | def __init__(self, npy_scp): 32 | super(NumpyReader, self).__init__(npy_scp) 33 | 34 | def _load(self, key): 35 | return np.load(self.index_dict[key]) 36 | 37 | 38 | class PeruttLoader(object): 39 | """ 40 | Per Utterance Loader 41 | """ 42 | 43 | def __init__(self, 44 | shuffle=True, 45 | mix_scp="", 46 | ref_scp="", 47 | emb_scp="", 48 | embed_format="kaldi", 49 | sr=16000): 50 | if embed_format not in ["kaldi", "numpy"]: 51 | raise RuntimeError( 52 | "Unknown embedding format {}".format(embed_format)) 53 | self.mix = WaveReader(mix_scp, sr=sr) 54 | self.ref = WaveReader(ref_scp, sr=sr) 55 | self.emb = NumpyReader( 56 | emb_scp) if embed_format == "numpy" else ScriptReader(emb_scp, 57 | matrix=False) 58 | self.shuffle = shuffle 59 | 60 | def __iter__(self): 61 | if self.shuffle: 62 | random.shuffle(self.mix.index_keys) 63 | for key, mix in self.mix: 64 | eg = dict() 65 | eg["mix"] = mix 66 | eg["ref"] = self.ref[key] 67 | emb = self.emb[key] 68 | eg["emb"] = emb / (np.linalg.norm(emb, 2) + 1e-8) 69 | yield eg 70 | 71 | 72 | class ChunkSplitter(object): 73 | """ 74 | Split utterance into small chunks 75 | """ 76 | 77 | def __init__(self, chunk_size, train=True, hop=16000): 78 | self.chunk_size = chunk_size 79 | self.hop = hop 80 | self.train = train 81 | 82 | def _make_chunk(self, eg, s): 83 | """ 84 | Make a chunk instance, which contains: 85 | "emb": ndarray, 86 | "mix": ndarray, 87 | "ref": ndarray 88 | """ 89 | chunk = dict() 90 | # support for multi-channel 91 | chunk["emb"] = eg["emb"] 92 | chunk["mix"] = eg["mix"][..., s:s + self.chunk_size] 93 | chunk["ref"] = eg["ref"][..., s:s + self.chunk_size] 94 | return chunk 95 | 96 | def split(self, eg): 97 | N = eg["mix"].shape[-1] 98 | # too short, throw away 99 | if N < self.hop: 100 | return [] 101 | chunks = [] 102 | # padding zeros 103 | if N < self.chunk_size: 104 | chunk = dict() 105 | P = self.chunk_size - N 106 | pad_width = ((0, 0), (0, P)) if eg["mix"].ndim == 2 else (0, P) 107 | chunk["mix"] = np.pad(eg["mix"], pad_width, "constant") 108 | chunk["emb"] = eg["emb"] 109 | chunk["ref"] = np.pad(eg["ref"], (0, P), "constant") 110 | chunks.append(chunk) 111 | else: 112 | # random select start point for training 113 | s = random.randint(0, N % self.hop) if self.train else 0 114 | while True: 115 | if s + self.chunk_size > N: 116 | break 117 | chunk = self._make_chunk(eg, s) 118 | chunks.append(chunk) 119 | s += self.hop 120 | return chunks 121 | 122 | 123 | class DataLoader(object): 124 | """ 125 | Online dataloader for chunk-level loss 126 | """ 127 | 128 | def __init__(self, 129 | perutt_loader, 130 | chunk_size=32000, 131 | batch_size=16, 132 | cache_size=16, 133 | train=True): 134 | self.loader = perutt_loader 135 | self.cache_size = cache_size * batch_size 136 | self.batch_size = batch_size 137 | self.train = train 138 | self.splitter = ChunkSplitter(chunk_size, 139 | train=train, 140 | hop=chunk_size // 2) 141 | 142 | def _fetch_batch(self): 143 | while True: 144 | if len(self.load_list) >= self.cache_size: 145 | break 146 | try: 147 | eg = next(self.load_iter) 148 | cs = self.splitter.split(eg) 149 | self.load_list.extend(cs) 150 | except StopIteration: 151 | self.stop_iter = True 152 | break 153 | if self.train: 154 | random.shuffle(self.load_list) 155 | N = len(self.load_list) 156 | blist = [] 157 | for s in range(0, N - self.batch_size + 1, self.batch_size): 158 | batch = default_collate(self.load_list[s:s + self.batch_size]) 159 | blist.append(batch) 160 | # update load_list 161 | rn = N % self.batch_size 162 | if rn: 163 | self.load_list = self.load_list[-rn:] 164 | else: 165 | self.load_list = [] 166 | return blist 167 | 168 | def __iter__(self): 169 | # reset flags 170 | self.load_iter = iter(self.loader) 171 | self.stop_iter = False 172 | self.load_list = [] 173 | 174 | while not self.stop_iter: 175 | bs = self._fetch_batch() 176 | for obj in bs: 177 | yield obj -------------------------------------------------------------------------------- /nnet/libs/evaluator.py: -------------------------------------------------------------------------------- 1 | # wujian@2019 2 | 3 | import json 4 | import os.path as op 5 | 6 | import torch as th 7 | 8 | from .trainer import get_logger 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | class Evaluator(object): 14 | """ 15 | A simple wrapper for model evaluation 16 | """ 17 | 18 | def __init__(self, nnet_cls, cpt_dir, gpu_id=-1): 19 | # load nnet 20 | self.nnet = self._load_nnet(nnet_cls, cpt_dir) 21 | self.device = th.device( 22 | "cpu" if gpu_id < 0 else "cuda:{:d}".format(gpu_id)) 23 | if gpu_id >= 0: 24 | self.nnet.to(self.device) 25 | # set eval model 26 | self.nnet.eval() 27 | 28 | def compute(self, egs): 29 | raise NotImplementedError 30 | 31 | def _load_nnet(self, nnet_cls, cpt_dir): 32 | """ 33 | Load model from checkpoints 34 | """ 35 | with open(op.join(cpt_dir, "mdl.json"), "r") as f: 36 | # load nnet conf 37 | nnet_conf = json.load(f) 38 | nnet = nnet_cls(**nnet_conf) 39 | # load checkpoint 40 | cpt_fname = op.join(cpt_dir, "best.pt.tar") 41 | cpt = th.load(cpt_fname, map_location="cpu") 42 | nnet.load_state_dict(cpt["model_state_dict"]) 43 | # log state 44 | logger.info( 45 | "Load model from checkpoint at {}, on epoch {:d}".format( 46 | cpt_fname, cpt["epoch"])) 47 | return nnet -------------------------------------------------------------------------------- /nnet/libs/trainer.py: -------------------------------------------------------------------------------- 1 | # wujian@2018 2 | 3 | import os 4 | import sys 5 | import time 6 | import logging 7 | 8 | from itertools import permutations 9 | from collections import defaultdict 10 | 11 | import torch as th 12 | import torch.nn.functional as F 13 | from torch.optim.lr_scheduler import ReduceLROnPlateau 14 | from torch.nn.utils import clip_grad_norm_ 15 | from torch import autograd 16 | 17 | 18 | def load_obj(obj, device): 19 | """ 20 | Offload tensor object in obj to cuda device 21 | """ 22 | 23 | def cuda(obj): 24 | return obj.to(device) if isinstance(obj, th.Tensor) else obj 25 | 26 | if isinstance(obj, dict): 27 | return {key: load_obj(obj[key], device) for key in obj} 28 | elif isinstance(obj, list): 29 | return [load_obj(val, device) for val in obj] 30 | else: 31 | return cuda(obj) 32 | 33 | 34 | def get_logger( 35 | name, 36 | format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s", 37 | date_format="%Y-%m-%d %H:%M:%S", 38 | file=False): 39 | """ 40 | Get logger instance 41 | """ 42 | 43 | def get_handler(handler): 44 | handler.setLevel(logging.INFO) 45 | formatter = logging.Formatter(fmt=format_str, datefmt=date_format) 46 | handler.setFormatter(formatter) 47 | return handler 48 | 49 | logger = logging.getLogger(name) 50 | logger.setLevel(logging.INFO) 51 | if file: 52 | # both stdout & file 53 | logger.addHandler(get_handler(logging.FileHandler(name))) 54 | # logger.addHandler(logging.StreamHandler()) 55 | else: 56 | logger.addHandler(logging.StreamHandler()) 57 | return logger 58 | 59 | 60 | class SimpleTimer(object): 61 | """ 62 | A simple timer 63 | """ 64 | 65 | def __init__(self): 66 | self.reset() 67 | 68 | def reset(self): 69 | self.start = time.time() 70 | 71 | def elapsed(self): 72 | return (time.time() - self.start) / 60 73 | 74 | 75 | class ProgressReporter(object): 76 | """ 77 | A simple progress reporter 78 | """ 79 | 80 | def __init__(self, logger, period=100): 81 | self.period = period 82 | self.logger = logger 83 | self.reset() 84 | 85 | def reset(self): 86 | self.stats = defaultdict(list) 87 | self.timer = SimpleTimer() 88 | 89 | def add(self, key, value): 90 | self.stats[key].append(value) 91 | N = len(self.stats[key]) 92 | if not N % self.period: 93 | avg = sum(self.stats[key][-self.period:]) / self.period 94 | self.logger.info("Processed {:.2e} batches " 95 | "({} = {:+.2f})...".format(N, key, avg)) 96 | 97 | def report(self, details=False): 98 | N = len(self.stats["loss"]) 99 | if details: 100 | sstr = ",".join( 101 | map(lambda f: "{:.2f}".format(f), self.stats["loss"])) 102 | self.logger.info("Loss on {:d} batches: {}".format(N, sstr)) 103 | return { 104 | "loss": sum(self.stats["loss"]) / N, 105 | "batches": N, 106 | "cost": self.timer.elapsed() 107 | } 108 | 109 | 110 | class Trainer(object): 111 | def __init__(self, 112 | nnet, 113 | checkpoint="checkpoint", 114 | optimizer="adam", 115 | gpuid=0, 116 | optimizer_kwargs=None, 117 | gradient_clip=None, 118 | min_lr=0, 119 | patience=0, 120 | factor=0.5, 121 | logger=None, 122 | logging_period=100, 123 | resume=None, 124 | no_impr=6): 125 | if not th.cuda.is_available(): 126 | raise RuntimeError("CUDA device unavailable...exist") 127 | if not isinstance(gpuid, tuple): 128 | gpuid = (gpuid, ) 129 | self.device = th.device("cuda:{}".format(gpuid[0])) 130 | self.gpuid = gpuid 131 | if checkpoint: 132 | os.makedirs(checkpoint, exist_ok=True) 133 | self.checkpoint = checkpoint 134 | self.logger = logger if logger else get_logger( 135 | os.path.join(checkpoint, "trainer.log"), file=True) 136 | 137 | self.gradient_clip = gradient_clip 138 | self.logging_period = logging_period 139 | self.cur_epoch = 0 # zero based 140 | self.no_impr = no_impr 141 | 142 | if resume: 143 | if not os.path.exists(resume): 144 | raise FileNotFoundError( 145 | "Could not find resume checkpoint: {}".format(resume)) 146 | cpt = th.load(resume, map_location="cpu") 147 | self.cur_epoch = cpt["epoch"] 148 | self.logger.info("Resume from checkpoint {}: epoch {:d}".format( 149 | resume, self.cur_epoch)) 150 | # load nnet 151 | nnet.load_state_dict(cpt["model_state_dict"]) 152 | self.nnet = nnet.to(self.device) 153 | self.optimizer = self.create_optimizer( 154 | optimizer, optimizer_kwargs, state=cpt["optim_state_dict"]) 155 | else: 156 | self.nnet = nnet.to(self.device) 157 | self.optimizer = self.create_optimizer(optimizer, optimizer_kwargs) 158 | self.scheduler = ReduceLROnPlateau(self.optimizer, 159 | mode="min", 160 | factor=factor, 161 | patience=patience, 162 | min_lr=min_lr, 163 | verbose=True) 164 | self.num_params = sum( 165 | [param.nelement() for param in nnet.parameters()]) / 10.0**6 166 | 167 | # logging 168 | self.logger.info("Model summary:\n{}".format(nnet)) 169 | self.logger.info("Loading model to GPUs:{}, #param: {:.2f}M".format( 170 | gpuid, self.num_params)) 171 | if gradient_clip: 172 | self.logger.info( 173 | "Gradient clipping by {}, default L2".format(gradient_clip)) 174 | 175 | def save_checkpoint(self, best=True): 176 | cpt = { 177 | "epoch": self.cur_epoch, 178 | "model_state_dict": self.nnet.state_dict(), 179 | "optim_state_dict": self.optimizer.state_dict() 180 | } 181 | th.save( 182 | cpt, 183 | os.path.join(self.checkpoint, 184 | "{0}.pt.tar".format("best" if best else "last"))) 185 | 186 | def create_optimizer(self, optimizer, kwargs, state=None): 187 | supported_optimizer = { 188 | "sgd": th.optim.SGD, # momentum, weight_decay, lr 189 | "rmsprop": th.optim.RMSprop, # momentum, weight_decay, lr 190 | "adam": th.optim.Adam, # weight_decay, lr 191 | "adadelta": th.optim.Adadelta, # weight_decay, lr 192 | "adagrad": th.optim.Adagrad, # lr, lr_decay, weight_decay 193 | "adamax": th.optim.Adamax # lr, weight_decay 194 | # ... 195 | } 196 | if optimizer not in supported_optimizer: 197 | raise ValueError("Now only support optimizer {}".format(optimizer)) 198 | opt = supported_optimizer[optimizer](self.nnet.parameters(), **kwargs) 199 | self.logger.info("Create optimizer {0}: {1}".format(optimizer, kwargs)) 200 | if state is not None: 201 | opt.load_state_dict(state) 202 | self.logger.info("Load optimizer state dict from checkpoint") 203 | return opt 204 | 205 | def compute_loss(self, egs): 206 | raise NotImplementedError 207 | 208 | def train(self, data_loader): 209 | self.logger.info("Set train mode...") 210 | self.nnet.train() 211 | reporter = ProgressReporter(self.logger, period=self.logging_period) 212 | 213 | # with autograd.detect_anomaly(): 214 | for egs in data_loader: 215 | # load to gpu 216 | egs = load_obj(egs, self.device) 217 | 218 | self.optimizer.zero_grad() 219 | loss = self.compute_loss(egs) 220 | loss.backward() 221 | if self.gradient_clip: 222 | norm = clip_grad_norm_(self.nnet.parameters(), 223 | self.gradient_clip) 224 | reporter.add("norm", norm) 225 | self.optimizer.step() 226 | 227 | reporter.add("loss", loss.item()) 228 | return reporter.report() 229 | 230 | def eval(self, data_loader): 231 | self.logger.info("Set eval mode...") 232 | self.nnet.eval() 233 | reporter = ProgressReporter(self.logger, period=self.logging_period) 234 | 235 | with th.no_grad(): 236 | for egs in data_loader: 237 | egs = load_obj(egs, self.device) 238 | loss = self.compute_loss(egs) 239 | reporter.add("loss", loss.item()) 240 | return reporter.report(details=True) 241 | 242 | def run(self, 243 | train_loader, 244 | dev_loader, 245 | num_epoches=100, 246 | eval_interval=4000): 247 | stats = dict() 248 | # make dilated conv faster 249 | th.backends.cudnn.benchmark = True 250 | # avoid alloc memory from gpu0 251 | th.cuda.set_device(self.gpuid[0]) 252 | # check if save is OK 253 | self.save_checkpoint(best=False) 254 | cv = self.eval(dev_loader) 255 | best_loss = cv["loss"] 256 | self.logger.info("START FROM EPOCH {:d}, LOSS = {:.4f}".format( 257 | self.cur_epoch, best_loss)) 258 | no_impr = 0 259 | stop = False 260 | trained_batches = 0 261 | train_reporter = ProgressReporter(self.logger, 262 | period=self.logging_period) 263 | # make sure not inf 264 | self.scheduler.best = best_loss 265 | # set train mode 266 | self.nnet.train() 267 | while True: 268 | # trained on several batches 269 | for egs in train_loader: 270 | trained_batches = (trained_batches + 1) % eval_interval 271 | # update per-batch 272 | egs = load_obj(egs, self.device) 273 | self.optimizer.zero_grad() 274 | loss = self.compute_loss(egs) 275 | loss.backward() 276 | if self.gradient_clip: 277 | norm = clip_grad_norm_(self.nnet.parameters(), 278 | self.gradient_clip) 279 | train_reporter.add("norm", norm) 280 | self.optimizer.step() 281 | # record loss 282 | train_reporter.add("loss", loss.item()) 283 | # if trained on batches done, start evaluation 284 | if trained_batches == 0: 285 | self.cur_epoch += 1 286 | cur_lr = self.optimizer.param_groups[0]["lr"] 287 | stats[ 288 | "title"] = "Loss(time/N, lr={:.3e}) - Epoch {:2d}:".format( 289 | cur_lr, self.cur_epoch) 290 | tr = train_reporter.report() 291 | stats["tr"] = "train = {:+.4f}({:.2f}m/{:d})".format( 292 | tr["loss"], tr["cost"], tr["batches"]) 293 | cv = self.eval(dev_loader) 294 | stats["cv"] = "dev = {:+.4f}({:.2f}m/{:d})".format( 295 | cv["loss"], cv["cost"], cv["batches"]) 296 | stats["scheduler"] = "" 297 | if cv["loss"] > best_loss: 298 | no_impr += 1 299 | stats["scheduler"] = "| no impr, best = {:.4f}".format( 300 | self.scheduler.best) 301 | else: 302 | best_loss = cv["loss"] 303 | no_impr = 0 304 | self.save_checkpoint(best=True) 305 | self.logger.info( 306 | "{title} {tr} | {cv} {scheduler}".format(**stats)) 307 | # schedule here 308 | self.scheduler.step(cv["loss"]) 309 | # flush scheduler info 310 | sys.stdout.flush() 311 | # save last checkpoint 312 | self.save_checkpoint(best=False) 313 | # reset reporter 314 | train_reporter.reset() 315 | # early stop or not 316 | if no_impr == self.no_impr: 317 | self.logger.info( 318 | "Stop training cause no impr for {:d} epochs". 319 | format(no_impr)) 320 | stop = True 321 | break 322 | if self.cur_epoch == num_epoches: 323 | stop = True 324 | break 325 | # enable train mode 326 | self.nnet.train() 327 | if stop: 328 | break 329 | self.logger.info("Training for {:d}/{:d} epoches done!".format( 330 | self.cur_epoch, num_epoches)) 331 | 332 | 333 | class SiSnrTrainer(Trainer): 334 | def __init__(self, *args, **kwargs): 335 | super(SiSnrTrainer, self).__init__(*args, **kwargs) 336 | 337 | def sisnr(self, x, s, eps=1e-8): 338 | """ 339 | Arguments: 340 | x: separated signal, N x S tensor 341 | s: reference signal, N x S tensor 342 | Return: 343 | sisnr: N tensor 344 | """ 345 | 346 | def l2norm(mat, keepdim=False): 347 | return th.norm(mat, dim=-1, keepdim=keepdim) 348 | 349 | if x.shape != s.shape: 350 | raise RuntimeError( 351 | "Dimention mismatch when calculate si-snr, {} vs {}".format( 352 | x.shape, s.shape)) 353 | x_zm = x - th.mean(x, dim=-1, keepdim=True) 354 | s_zm = s - th.mean(s, dim=-1, keepdim=True) 355 | t = th.sum(x_zm * s_zm, dim=-1, 356 | keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps) 357 | return 20 * th.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps)) 358 | 359 | def compute_loss(self, egs): 360 | # flatten for parallel module 361 | self.nnet.flatten_parameters() 362 | # N x S 363 | est = th.nn.parallel.data_parallel(self.nnet, (egs["mix"], egs["emb"]), 364 | device_ids=self.gpuid) 365 | # N 366 | snr = self.sisnr(est, egs["ref"]) 367 | return -th.sum(snr) / est.size(0) -------------------------------------------------------------------------------- /nnet/nnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # wujian@2019 4 | 5 | import math 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from stft import STFT, iSTFT 11 | 12 | EPSILON = th.finfo(th.float32).eps 13 | 14 | 15 | class Conv2dBlock(nn.Module): 16 | """ 17 | 2D convolutional blocks used in VoiceFilter 18 | """ 19 | 20 | def __init__(self, 21 | in_channels, 22 | out_channels, 23 | kernel_size=(5, 5), 24 | dilation=(1, 1)): 25 | super(Conv2dBlock, self).__init__() 26 | self.conv = nn.Conv2d(in_channels, 27 | out_channels, 28 | kernel_size, 29 | stride=1, 30 | dilation=dilation, 31 | padding=tuple( 32 | d * (k - 1) // 2 33 | for k, d in zip(kernel_size, dilation))) 34 | self.bn = nn.BatchNorm2d(out_channels) 35 | 36 | def forward(self, x): 37 | """ 38 | x: N x F x T 39 | """ 40 | x = self.bn(self.conv(x)) 41 | return F.relu(x) 42 | 43 | 44 | class VoiceFilter(nn.Module): 45 | """ 46 | Reference from 47 | VoiceFilter: Targeted Voice Separation by Speaker-Conditioned Spectrogram Masking 48 | """ 49 | 50 | def __init__(self, 51 | frame_len, 52 | frame_hop, 53 | round_pow_of_two=True, 54 | embedding_dim=512, 55 | log_mag=False, 56 | mvn_mag=False, 57 | lstm_dim=400, 58 | linear_dim=600, 59 | l2_norm=True, 60 | bidirectional=False, 61 | non_linear="relu"): 62 | super(VoiceFilter, self).__init__() 63 | supported_nonlinear = { 64 | "relu": F.relu, 65 | "sigmoid": th.sigmoid, 66 | "tanh": th.tanh 67 | } 68 | if non_linear not in supported_nonlinear: 69 | raise RuntimeError( 70 | "Unsupported non-linear function: {}".format(non_linear)) 71 | N = 2**math.ceil( 72 | math.log2(frame_len)) if round_pow_of_two else frame_len 73 | num_bins = N // 2 + 1 74 | 75 | self.stft = STFT(frame_len, 76 | frame_hop, 77 | round_pow_of_two=round_pow_of_two) 78 | self.istft = iSTFT(frame_len, 79 | frame_hop, 80 | round_pow_of_two=round_pow_of_two) 81 | self.cnn_f = Conv2dBlock(1, 64, kernel_size=(7, 1)) 82 | self.cnn_t = Conv2dBlock(64, 64, kernel_size=(1, 7)) 83 | blocks = [] 84 | for d in range(5): 85 | blocks.append( 86 | Conv2dBlock(64, 64, kernel_size=(5, 5), dilation=(1, 2**d))) 87 | self.cnn_tf = nn.Sequential(*blocks) 88 | self.proj = Conv2dBlock(64, 8, kernel_size=(1, 1)) 89 | self.lstm = nn.LSTM(8 * num_bins + embedding_dim, 90 | lstm_dim, 91 | batch_first=True, 92 | bidirectional=bidirectional) 93 | self.mask = nn.Sequential( 94 | nn.Linear(lstm_dim * 2 if bidirectional else lstm_dim, linear_dim), 95 | nn.ReLU(), nn.Linear(linear_dim, num_bins)) 96 | self.non_linear = supported_nonlinear[non_linear] 97 | self.embedding_dim = embedding_dim 98 | self.l2_norm = l2_norm 99 | self.log_mag = log_mag 100 | self.bn = nn.BatchNorm1d(num_bins) if mvn_mag else None 101 | 102 | def flatten_parameters(self): 103 | self.lstm.flatten_parameters() 104 | 105 | def check_args(self, x, e): 106 | if x.dim() != e.dim(): 107 | raise RuntimeError( 108 | "{} got invalid input dim: x/e = {:d}/{:d}".format( 109 | self.__name__, x.dim(), e.dim())) 110 | if e.size(-1) != self.embedding_dim: 111 | raise RuntimeError("input embedding dim do not match with " 112 | "network's, {:d} vs {:d}".format( 113 | e.size(-1), self.embedding_dim)) 114 | 115 | def forward(self, x, e, return_mask=False): 116 | """ 117 | x: N x S 118 | e: N x D 119 | """ 120 | if x.dim() == 1: 121 | x = th.unsqueeze(x, 0) 122 | e = th.unsqueeze(e, 0) 123 | if self.l2_norm: 124 | e = e / th.norm(e, 2, dim=1, keepdim=True) 125 | 126 | # N x S => N x F x T 127 | mag, ang = self.stft(x) 128 | 129 | # clip 130 | y = th.clamp(mag, min=EPSILON) 131 | # apply log 132 | if self.log_mag: 133 | y = th.log(y) 134 | # apply bn 135 | if self.bn: 136 | y = self.bn(y) 137 | 138 | N, _, T = mag.shape 139 | # N x 1 x F x T 140 | y = th.unsqueeze(y, 1) 141 | # N x D => N x D x T 142 | e = th.unsqueeze(e, 2).repeat(1, 1, T) 143 | 144 | y = self.cnn_f(y) 145 | y = self.cnn_t(y) 146 | y = self.cnn_tf(y) 147 | # N x C x F x T 148 | y = self.proj(y) 149 | # N x CF x T 150 | y = y.view(N, -1, T) 151 | # N x (CF+D) x T 152 | f = th.cat([y, e], 1) 153 | # N x T x (CF+D) 154 | f = th.transpose(f, 1, 2) 155 | f, _ = self.lstm(f) 156 | # N x T x F 157 | m = self.non_linear(self.mask(f)) 158 | if return_mask: 159 | return m 160 | # N x F x T 161 | m = th.transpose(m, 1, 2) 162 | # N x S 163 | s = self.istft(mag * m, ang, squeeze=True) 164 | return s 165 | 166 | 167 | def run(): 168 | x = th.rand(1, 2000) 169 | e = th.rand(1, 512) 170 | 171 | nnet = VoiceFilter(256, 128) 172 | print(nnet) 173 | s = nnet(x, e, return_mask=True) 174 | print(s.squeeze().shape) 175 | 176 | 177 | if __name__ == "__main__": 178 | run() 179 | -------------------------------------------------------------------------------- /nnet/separate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # wujian@2018 4 | 5 | import os 6 | import argparse 7 | 8 | import torch as th 9 | import numpy as np 10 | 11 | from nnet import VoiceFilter 12 | 13 | from libs.audio import WaveReader, write_wav 14 | from libs.dataset import NumpyReader 15 | from libs.evaluator import Evaluator 16 | from libs.trainer import get_logger 17 | 18 | from kaldi_python_io import ScriptReader 19 | 20 | logger = get_logger(__name__) 21 | 22 | 23 | class NnetComputer(Evaluator): 24 | """ 25 | Evaluator implementation 26 | """ 27 | 28 | def __init__(self, *args, **kwargs): 29 | super(NnetComputer, self).__init__(*args, **kwargs) 30 | 31 | def compute(self, mix, emb): 32 | with th.no_grad(): 33 | mix = th.from_numpy(mix).to(self.device) 34 | emb = th.from_numpy(emb).to(self.device) 35 | spk = self.nnet(mix, emb) 36 | return spk.detach().squeeze().cpu().numpy() 37 | 38 | 39 | def run(args): 40 | mix_reader = WaveReader(args.mix_scp, sr=args.fs) 41 | spk_embed = NumpyReader( 42 | args.emb_scp) if args.format == "numpy" else ScriptReader(args.emb_scp, 43 | matrix=False) 44 | os.makedirs(args.dump_dir, exist_ok=True) 45 | computer = NnetComputer(VoiceFilter, args.checkpoint, gpu_id=args.gpu) 46 | for key, mix in mix_reader: 47 | logger.info("Compute on utterance {}...".format(key)) 48 | emb = spk_embed[key] 49 | emb = emb / (np.linalg.norm(emb, 2) + 1e-8) 50 | spk = computer.compute(mix, emb) 51 | norm = np.linalg.norm(mix, np.inf) 52 | # norm 53 | spk = spk * norm / np.max(np.abs(spk)) 54 | write_wav(os.path.join(args.dump_dir, "{}.wav".format(key)), 55 | spk, 56 | sr=args.fs) 57 | logger.info("Compute over {:d} utterances".format(len(mix_reader))) 58 | 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser( 62 | description="Command to do speaker aware separation", 63 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 64 | parser.add_argument("checkpoint", type=str, help="Directory of checkpoint") 65 | parser.add_argument("--mix-scp", 66 | type=str, 67 | required=True, 68 | help="Rspecifier for input waveform") 69 | parser.add_argument("--emb-scp", 70 | type=str, 71 | required=True, 72 | help="Rspecifier for speaker embeddings") 73 | parser.add_argument("--emb-format", 74 | type=str, 75 | dest="format", 76 | choices=["kaldi", "numpy"], 77 | default="kaldi", 78 | help="Storage type for speaker embeddings") 79 | parser.add_argument("--gpu", 80 | type=int, 81 | default=-1, 82 | help="GPU-id to offload model to, -1 means " 83 | "running on CPU") 84 | parser.add_argument("--fs", 85 | type=int, 86 | default=16000, 87 | help="Sample rate for mixture input") 88 | parser.add_argument("--dump-dir", 89 | type=str, 90 | default="spk", 91 | help="Directory to dump separated speakers out") 92 | args = parser.parse_args() 93 | run(args) -------------------------------------------------------------------------------- /nnet/stft.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # wujian@2019 4 | 5 | import math 6 | import torch as th 7 | 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | 11 | EPSILON = th.finfo(th.float32).eps 12 | 13 | 14 | def init_kernel(frame_len, 15 | frame_hop, 16 | round_pow_of_two=True, 17 | window="sqrt_hann"): 18 | if window != "sqrt_hann": 19 | raise RuntimeError("Now only support sqrt hanning window in order " 20 | "to make signal perfectly reconstructed") 21 | # FFT points 22 | N = 2**math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len 23 | # window 24 | W = th.hann_window(frame_len)**0.5 25 | S = 0.5 * (N * N / frame_hop)**0.5 26 | # F x N/2+1 x 2 27 | K = th.rfft(th.eye(N) / S, 1)[:frame_len] 28 | # 2 x N/2+1 x F 29 | K = th.transpose(K, 0, 2) * W 30 | # N+2 x 1 x F 31 | K = th.reshape(K, (N + 2, 1, frame_len)) 32 | return K 33 | 34 | 35 | class STFTBase(nn.Module): 36 | """ 37 | Base layer for (i)STFT 38 | NOTE: 39 | 1) Recommend sqrt_hann window with 2**N frame length, because it 40 | could achieve perfect reconstruction after overlap-add 41 | 2) Now haven't consider padding problems yet 42 | """ 43 | 44 | def __init__(self, 45 | frame_len, 46 | frame_hop, 47 | window="sqrt_hann", 48 | round_pow_of_two=True): 49 | super(STFTBase, self).__init__() 50 | K = init_kernel(frame_len, 51 | frame_hop, 52 | round_pow_of_two=round_pow_of_two, 53 | window=window) 54 | self.K = nn.Parameter(K, requires_grad=False) 55 | self.stride = frame_hop 56 | self.window = window 57 | 58 | def extra_repr(self): 59 | return "window={0}, stride={1}, kernel_size={2[0]}x{2[2]}".format( 60 | self.window, self.stride, self.K.shape) 61 | 62 | 63 | class STFT(STFTBase): 64 | """ 65 | Short-time Fourier Transform as a Layer 66 | """ 67 | 68 | def __init__(self, *args, **kwargs): 69 | super(STFT, self).__init__(*args, **kwargs) 70 | 71 | def forward(self, x): 72 | """ 73 | Accept raw waveform and output magnitude and phase 74 | x: input signal, N x 1 x S or N x S 75 | m: magnitude, N x F x T 76 | p: phase, N x F x T 77 | """ 78 | if x.dim() not in [2, 3]: 79 | raise RuntimeError("Expect 2D/3D tensor, but got {:d}D".format( 80 | x.dim())) 81 | # if N x S, reshape N x 1 x S 82 | if x.dim() == 2: 83 | x = th.unsqueeze(x, 1) 84 | # N x 2F x T 85 | c = F.conv1d(x, self.K, stride=self.stride, padding=0) 86 | # N x F x T 87 | r, i = th.chunk(c, 2, dim=1) 88 | m = (r**2 + i**2)**0.5 89 | p = th.atan2(i, r) 90 | return m, p 91 | 92 | 93 | class iSTFT(STFTBase): 94 | """ 95 | Inverse Short-time Fourier Transform as a Layer 96 | """ 97 | 98 | def __init__(self, *args, **kwargs): 99 | super(iSTFT, self).__init__(*args, **kwargs) 100 | 101 | def forward(self, m, p, squeeze=False): 102 | """ 103 | Accept phase & magnitude and output raw waveform 104 | m, p: N x F x T 105 | s: N x C x S 106 | """ 107 | if p.dim() != m.dim() or p.dim() not in [2, 3]: 108 | raise RuntimeError("Expect 2D/3D tensor, but got {:d}D".format( 109 | p.dim())) 110 | # if F x T, reshape 1 x F x T 111 | if p.dim() == 2: 112 | p = th.unsqueeze(p, 0) 113 | m = th.unsqueeze(m, 0) 114 | r = m * th.cos(p) 115 | i = m * th.sin(p) 116 | # N x 2F x T 117 | c = th.cat([r, i], dim=1) 118 | # N x 2F x T 119 | s = F.conv_transpose1d(c, self.K, stride=self.stride, padding=0) 120 | if squeeze: 121 | s = th.squeeze(s) 122 | return s -------------------------------------------------------------------------------- /nnet/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # wujian@2018 4 | 5 | import os 6 | import json 7 | import pprint 8 | import argparse 9 | import random 10 | 11 | from libs.trainer import SiSnrTrainer, get_logger 12 | from libs.dataset import make_dataloader 13 | from nnet import VoiceFilter 14 | from conf import trainer_conf, nnet_conf, train_data, dev_data 15 | 16 | 17 | def run(args): 18 | gpuids = tuple(map(int, args.gpus.split(","))) 19 | 20 | if args.checkpoint: 21 | os.makedirs(args.checkpoint, exist_ok=True) 22 | 23 | logger = get_logger(os.path.join(args.checkpoint, "trainer.log"), 24 | file=True) 25 | logger.info("Arguments in command:\n{}".format(pprint.pformat(vars(args)))) 26 | 27 | nnet = VoiceFilter(**nnet_conf) 28 | trainer = SiSnrTrainer(nnet, 29 | gpuid=gpuids, 30 | checkpoint=args.checkpoint, 31 | resume=args.resume, 32 | logger=logger, 33 | **trainer_conf) 34 | 35 | data_conf = { 36 | "train": train_data, 37 | "dev": dev_data, 38 | } 39 | # dump configs 40 | for conf, fname in zip([nnet_conf, trainer_conf, data_conf], 41 | ["mdl.json", "trainer.json", "data.json"]): 42 | with open(os.path.join(args.checkpoint, fname), "w") as f: 43 | json.dump(conf, f, indent=4, sort_keys=False) 44 | 45 | train_loader = make_dataloader(train=True, 46 | data_kwargs=train_data, 47 | batch_size=args.batch_size, 48 | cache_size=args.cache_size, 49 | chunk_size=args.chunk_size) 50 | dev_loader = make_dataloader(train=False, 51 | data_kwargs=dev_data, 52 | batch_size=args.batch_size, 53 | cache_size=args.cache_size, 54 | chunk_size=args.chunk_size) 55 | 56 | trainer.run(train_loader, 57 | dev_loader, 58 | eval_interval=args.eval_interval, 59 | num_epoches=args.epoches) 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser( 64 | description= 65 | "Command to start train voice-filter, configured from conf.py", 66 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 67 | parser.add_argument("--epoches", 68 | type=int, 69 | default=50, 70 | help="Number of training epoches") 71 | parser.add_argument("--gpus", 72 | type=str, 73 | default="0,1", 74 | help="Training on which GPUs (one or more, egs " 75 | "0, 0,1)") 76 | parser.add_argument("--eval-interval", 77 | type=int, 78 | default=3000, 79 | help="Number of batches trained per epoch (for larger " 80 | "training dataset)") 81 | parser.add_argument("--checkpoint", 82 | type=str, 83 | required=True, 84 | help="Directory to dump models") 85 | parser.add_argument("--resume", 86 | type=str, 87 | default="", 88 | help="Exist model to resume training from") 89 | parser.add_argument("--chunk-size", 90 | type=int, 91 | default=64256, 92 | help="Chunk size to feed networks") 93 | parser.add_argument("--batch-size", 94 | type=int, 95 | default=16, 96 | help="Number of utterances in each batch") 97 | parser.add_argument("--cache-size", 98 | type=int, 99 | default=16, 100 | help="Number of chunks cached in dataloader") 101 | args = parser.parse_args() 102 | run(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.2.1 2 | tqdm==4.31.1 3 | torch==1.0.1.post2 4 | numpy==1.16.2 5 | kaldi-python-io==1.0.0 6 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # wujian@2019 3 | 4 | set -eu 5 | 6 | epoches=100 7 | batch_size=32 8 | cache_size=8 9 | chunk_size=64256 10 | eval_interval=3000 11 | 12 | echo "$0 $@" 13 | 14 | [ $# -ne 2 ] && echo "Script format error: $0 " && exit 1 15 | 16 | exp_id=$1 17 | gpu_id=$2 18 | 19 | ./nnet/train.py \ 20 | --gpu $gpu_id \ 21 | --checkpoint exp/nnet/$exp_id \ 22 | --batch-size $batch_size \ 23 | --cache-size $cache_size \ 24 | --chunk-size $chunk_size \ 25 | --epoches $epoches \ 26 | --eval-interval $eval_interval \ 27 | > $exp_id.train.log 2>&1 28 | --------------------------------------------------------------------------------