├── .gitignore ├── LICENSE ├── README.md ├── eend ├── bin │ ├── infer.py │ ├── make_mixture.py │ ├── make_mixture_nooverlap.py │ ├── make_rttm.py │ ├── model_averaging.py │ ├── random_mixture.py │ ├── random_mixture_nooverlap.py │ ├── rttm_stats.py │ ├── train.py │ ├── visualize_attention.py │ └── yaml2bash.py ├── chainer_backend │ ├── diarization_dataset.py │ ├── encoder_decoder_attractor.py │ ├── infer.py │ ├── models.py │ ├── train.py │ ├── transformer.py │ ├── updater.py │ └── utils.py ├── feature.py ├── kaldi_data.py └── system_info.py ├── egs ├── callhome │ └── v1 │ │ ├── cmd.sh │ │ ├── conf │ │ ├── adapt.yaml │ │ ├── blstm │ │ │ ├── adapt.yaml │ │ │ ├── infer.yaml │ │ │ └── train.yaml │ │ ├── debug │ │ │ ├── adapt.yaml │ │ │ └── train.yaml │ │ ├── eda │ │ │ ├── adapt.yaml │ │ │ ├── infer.yaml │ │ │ ├── train.yaml │ │ │ └── train_2spk.yaml │ │ ├── infer.yaml │ │ ├── mfcc.conf │ │ ├── mfcc_hires.conf │ │ └── train.yaml │ │ ├── local │ │ ├── make_callhome.sh │ │ ├── make_musan.py │ │ ├── make_musan.sh │ │ ├── make_sre.pl │ │ ├── make_sre.sh │ │ ├── make_swbd2_phase1.pl │ │ ├── make_swbd2_phase2.pl │ │ ├── make_swbd2_phase3.pl │ │ ├── make_swbd_cellular1.pl │ │ ├── make_swbd_cellular2.pl │ │ └── run_blstm.sh │ │ ├── path.sh │ │ ├── run.sh │ │ ├── run_eda.sh │ │ ├── run_prepare_shared.sh │ │ ├── run_prepare_shared_eda.sh │ │ ├── steps │ │ └── utils └── mini_librispeech │ └── v1 │ ├── RESULT.md │ ├── cmd.sh │ ├── conf │ ├── blstm │ │ ├── infer.yaml │ │ └── train.yaml │ ├── eda │ │ ├── infer.yaml │ │ └── train.yaml │ ├── infer.yaml │ ├── mfcc.conf │ ├── mfcc_hires.conf │ └── train.yaml │ ├── local │ ├── data_prep.sh │ ├── download_and_untar.sh │ └── run_blstm.sh │ ├── musan_bgnoise.tar.gz │ ├── path.sh │ ├── run.sh │ ├── run_prepare_shared.sh │ ├── steps │ └── utils ├── tools ├── Makefile ├── env.sh.in └── environment.yml └── utils └── best_score.sh /.gitignore: -------------------------------------------------------------------------------- 1 | egs/**/data/ 2 | egs/**/exp/ 3 | tools/kaldi 4 | tools/miniconda3.sh 5 | tools/miniconda3/ 6 | tools/sctk 7 | tools/sctk-2.4.10-20151007-1312Z.tar.bz2 8 | tools/sctk-2.4.10/ 9 | tools/env.sh 10 | .nfs* 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # Environments 96 | .env 97 | .venv 98 | env/ 99 | venv/ 100 | ENV/ 101 | env.bak/ 102 | venv.bak/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Hitachi, Ltd. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EEND (End-to-End Neural Diarization) 2 | 3 | EEND (End-to-End Neural Diarization) is a neural-network-based speaker diarization method. 4 | - BLSTM EEND (INTERSPEECH 2019) 5 | - https://www.isca-speech.org/archive/Interspeech_2019/abstracts/2899.html 6 | - Self-attentive EEND (ASRU 2019) 7 | - https://ieeexplore.ieee.org/abstract/document/9003959/ 8 | 9 | The EEND extension for various number of speakers is also provided in this repository. 
10 | - Self-attentive EEND with encoder-decoder based attractors 11 | - https://arxiv.org/abs/2005.09921 12 | 13 | ## Install tools 14 | ### Requirements 15 | - NVIDIA CUDA GPU 16 | - CUDA Toolkit (8.0 <= version <= 10.1) 17 | 18 | ### Install kaldi and python environment 19 | ```bash 20 | cd tools 21 | make 22 | ``` 23 | - This command builds kaldi at `tools/kaldi` 24 | - If you want to use a pre-built kaldi 25 | ```bash 26 | cd tools 27 | make KALDI= 28 | ``` 29 | This option makes a symlink at `tools/kaldi` 30 | - This command extracts miniconda3 at `tools/miniconda3`, and creates a conda environment named 'eend' 31 | - Then, it installs Chainer and CuPy into the 'eend' environment 32 | - By default, it uses CUDA in `/usr/local/cuda/` 33 | - If you need to specify your CUDA path 34 | ```bash 35 | cd tools 36 | make CUDA_PATH=/your/path/to/cuda-8.0 37 | ``` 38 | This command installs cupy-cudaXX according to your CUDA version. 39 | See https://docs-cupy.chainer.org/en/stable/install.html#install-cupy 40 | 41 | ## Test recipe (mini_librispeech) 42 | ### Configuration 43 | - Modify `egs/mini_librispeech/v1/cmd.sh` according to your job scheduler. 44 | If you use your local machine, use "run.pl". 45 | If you use Grid Engine, use "queue.pl". 46 | If you use SLURM, use "slurm.pl". 47 | For more information about cmd.sh see http://kaldi-asr.org/doc/queue.html. 48 | ### Data preparation 49 | ```bash 50 | cd egs/mini_librispeech/v1 51 | ./run_prepare_shared.sh 52 | ``` 53 | ### Run training, inference, and scoring 54 | ```bash 55 | ./run.sh 56 | ``` 57 | - If you use encoder-decoder based attractors [3], modify `run.sh` to use `conf/eda/{train,infer}.yaml` 58 | - See `RESULT.md` and compare with your result. 59 | 60 | ## CALLHOME two-speaker experiment 61 | ### Configuration 62 | - Modify `egs/callhome/v1/cmd.sh` according to your job scheduler. 63 | If you use your local machine, use "run.pl". 64 | If you use Grid Engine, use "queue.pl". 65 | If you use SLURM, use "slurm.pl". 66 | For more information about cmd.sh see http://kaldi-asr.org/doc/queue.html. 67 | - Modify `egs/callhome/v1/run_prepare_shared.sh` according to the storage paths of your corpora. 68 | 69 | ### Data preparation 70 | ```bash 71 | cd egs/callhome/v1 72 | ./run_prepare_shared.sh 73 | # If you want to conduct 1-4 speaker experiments, run the script below. 74 | # You also have to set paths to your corpora properly. 75 | ./run_prepare_shared_eda.sh 76 | ``` 77 | ### Self-attention-based model using 2-speaker mixtures 78 | ```bash 79 | ./run.sh 80 | ``` 81 | ### BLSTM-based model using 2-speaker mixtures 82 | ```bash 83 | local/run_blstm.sh 84 | ``` 85 | ### Self-attention-based model with EDA using 1-4-speaker mixtures 86 | ```bash 87 | ./run_eda.sh 88 | ``` 89 | 90 | ## References 91 | [1] Yusuke Fujita, Naoyuki Kanda, Shota Horiguchi, Kenji Nagamatsu, Shinji Watanabe, " 92 | End-to-End Neural Speaker Diarization with Permutation-free Objectives," Proc. Interspeech, pp. 4300-4304, 2019 93 | 94 | [2] Yusuke Fujita, Naoyuki Kanda, Shota Horiguchi, Yawen Xue, Kenji Nagamatsu, Shinji Watanabe, " 95 | End-to-End Neural Speaker Diarization with Self-attention," Proc. ASRU, pp. 296-303, 2019 96 | 97 | [3] Shota Horiguchi, Yusuke Fujita, Shinji Watanabe, Yawen Xue, Kenji Nagamatsu, " 98 | End-to-End Speaker Diarization for an Unknown Number of Speakers with Encoder-Decoder Based Attractors," Proc.
INTERSPEECH, 2020 99 | 100 | 101 | 102 | ## Citation 103 | ``` 104 | @inproceedings{Fujita2019Interspeech, 105 | author={Yusuke Fujita and Naoyuki Kanda and Shota Horiguchi and Kenji Nagamatsu and Shinji Watanabe}, 106 | title={{End-to-End Neural Speaker Diarization with Permutation-free Objectives}}, 107 | booktitle={Interspeech}, 108 | pages={4300--4304}, 109 | year=2019 110 | } 111 | ``` 112 | -------------------------------------------------------------------------------- /eend/bin/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | import yamlargparse 7 | from eend import system_info 8 | 9 | parser = yamlargparse.ArgumentParser(description='decoding') 10 | parser.add_argument('-c', '--config', help='config file path', 11 | action=yamlargparse.ActionConfigFile) 12 | parser.add_argument('data_dir', 13 | help='kaldi-style data dir') 14 | parser.add_argument('model_file', 15 | help='best.nnet') 16 | parser.add_argument('out_dir', 17 | help='output directory.') 18 | parser.add_argument('--backend', default='chainer', 19 | choices=['chainer', 'pytorch'], 20 | help='backend framework') 21 | parser.add_argument('--model_type', default='LSTM', type=str) 22 | parser.add_argument('--gpu', type=int, default=-1) 23 | parser.add_argument('--num-speakers', type=int, default=4) 24 | parser.add_argument('--hidden-size', default=256, type=int, 25 | help='number of lstm output nodes') 26 | parser.add_argument('--num-lstm-layers', default=1, type=int, 27 | help='number of lstm layers') 28 | parser.add_argument('--input-transform', default='', 29 | choices=['', 'log', 'logmel', 30 | 'logmel23', 'logmel23_swn', 'logmel23_mn'], 31 | help='input transform') 32 | parser.add_argument('--embedding-size', default=256, type=int) 33 | parser.add_argument('--embedding-layers', default=2, type=int) 34 | parser.add_argument('--chunk-size', default=2000, type=int, 35 | help='input is chunked with this size') 36 | parser.add_argument('--context-size', default=0, type=int, 37 | help='frame splicing') 38 | parser.add_argument('--subsampling', default=1, type=int) 39 | parser.add_argument('--sampling-rate', default=16000, type=int, 40 | help='sampling rate') 41 | parser.add_argument('--frame-size', default=1024, type=int, 42 | help='frame size') 43 | parser.add_argument('--frame-shift', default=256, type=int, 44 | help='frame shift') 45 | parser.add_argument('--transformer-encoder-n-heads', default=4, type=int) 46 | parser.add_argument('--transformer-encoder-n-layers', default=2, type=int) 47 | parser.add_argument('--save-attention-weight', default=0, type=int) 48 | 49 | attractor_args = parser.add_argument_group('attractor') 50 | attractor_args.add_argument('--use-attractor', action='store_true', 51 | help='Enable encoder-decoder attractor mode') 52 | attractor_args.add_argument('--shuffle', action='store_true', 53 | help='Shuffle the order in time-axis before input to the network') 54 | attractor_args.add_argument('--attractor-loss-ratio', default=1.0, type=float, 55 | help='weighting parameter') 56 | attractor_args.add_argument('--attractor-encoder-dropout', default=0.1, type=float) 57 | attractor_args.add_argument('--attractor-decoder-dropout', default=0.1, type=float) 58 | attractor_args.add_argument('--attractor-threshold', default=0.5, type=float) 59 | args = parser.parse_args() 60 | 61 | system_info.print_system_info() 62 | print(args) 63
| if args.backend == 'chainer': 64 | from eend.chainer_backend.infer import infer 65 | infer(args) 66 | elif args.backend == 'pytorch': 67 | # TODO 68 | # from eend.pytorch_backend.infer import infer 69 | # infer(args) 70 | raise NotImplementedError() 71 | else: 72 | raise ValueError() 73 | -------------------------------------------------------------------------------- /eend/bin/make_mixture.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | # This script generates simulated multi-talker mixtures for diarization 7 | # 8 | # common/make_mixture.py \ 9 | # mixture.scp \ 10 | # data/mixture \ 11 | # wav/mixture 12 | 13 | 14 | import argparse 15 | import os 16 | from eend import kaldi_data 17 | import numpy as np 18 | import math 19 | import soundfile as sf 20 | import json 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('script', 24 | help='list of json') 25 | parser.add_argument('out_data_dir', 26 | help='output data dir of mixture') 27 | parser.add_argument('out_wav_dir', 28 | help='output mixture wav files are stored here') 29 | parser.add_argument('--rate', type=int, default=16000, 30 | help='sampling rate') 31 | args = parser.parse_args() 32 | 33 | # open output data files 34 | segments_f = open(args.out_data_dir + '/segments', 'w') 35 | utt2spk_f = open(args.out_data_dir + '/utt2spk', 'w') 36 | wav_scp_f = open(args.out_data_dir + '/wav.scp', 'w') 37 | 38 | # "-R" forces the default random seed for reproducibility 39 | resample_cmd = "sox -R -t wav - -t wav - rate {}".format(args.rate) 40 | 41 | for line in open(args.script): 42 | recid, jsonstr = line.strip().split(None, 1) 43 | indata = json.loads(jsonstr) 44 | wavfn = indata['recid'] 45 | # recid now include out_wav_dir 46 | recid = os.path.join(args.out_wav_dir, wavfn).replace('/','_') 47 | noise = indata['noise'] 48 | noise_snr = indata['snr'] 49 | mixture = [] 50 | for speaker in indata['speakers']: 51 | spkid = speaker['spkid'] 52 | utts = speaker['utts'] 53 | intervals = speaker['intervals'] 54 | rir = speaker['rir'] 55 | data = [] 56 | pos = 0 57 | for interval, utt in zip(intervals, utts): 58 | # append silence interval data 59 | silence = np.zeros(int(interval * args.rate)) 60 | data.append(silence) 61 | # utterance is reverberated using room impulse response 62 | preprocess = "wav-reverberate --print-args=false " \ 63 | " --impulse-response={} - -".format(rir) 64 | if isinstance(utt, list): 65 | rec, st, et = utt 66 | st = np.rint(st * args.rate).astype(int) 67 | et = np.rint(et * args.rate).astype(int) 68 | else: 69 | rec = utt 70 | st = 0 71 | et = None 72 | if rir is not None: 73 | wav_rxfilename = kaldi_data.process_wav(rec, preprocess) 74 | else: 75 | wav_rxfilename = rec 76 | wav_rxfilename = kaldi_data.process_wav( 77 | wav_rxfilename, resample_cmd) 78 | speech, _ = kaldi_data.load_wav(wav_rxfilename, st, et) 79 | data.append(speech) 80 | # calculate start/end position in samples 81 | startpos = pos + len(silence) 82 | endpos = startpos + len(speech) 83 | # write segments and utt2spk 84 | uttid = '{}_{}_{:07d}_{:07d}'.format( 85 | spkid, recid, int(startpos / args.rate * 100), 86 | int(endpos / args.rate * 100)) 87 | print(uttid, recid, 88 | startpos / args.rate, endpos / args.rate, file=segments_f) 89 | print(uttid, spkid, file=utt2spk_f) 90 | # update position for next utterance 91 | pos = endpos 92 | data = 
np.concatenate(data) 93 | mixture.append(data) 94 | 95 | # fitting to the maximum-length speaker data, then mix all speakers 96 | maxlen = max(len(x) for x in mixture) 97 | mixture = [np.pad(x, (0, maxlen - len(x)), 'constant') for x in mixture] 98 | mixture = np.sum(mixture, axis=0) 99 | # noise is repeated or cutted for fitting to the mixture data length 100 | noise_resampled = kaldi_data.process_wav(noise, resample_cmd) 101 | noise_data, _ = kaldi_data.load_wav(noise_resampled) 102 | if maxlen > len(noise_data): 103 | noise_data = np.pad(noise_data, (0, maxlen - len(noise_data)), 'wrap') 104 | else: 105 | noise_data = noise_data[:maxlen] 106 | # noise power is scaled according to selected SNR, then mixed 107 | signal_power = np.sum(mixture**2) / len(mixture) 108 | noise_power = np.sum(noise_data**2) / len(noise_data) 109 | scale = math.sqrt( 110 | math.pow(10, - noise_snr / 10) * signal_power / noise_power) 111 | mixture += noise_data * scale 112 | # output the wav file and write wav.scp 113 | outfname = '{}.wav'.format(wavfn) 114 | outpath = os.path.join(args.out_wav_dir, outfname) 115 | sf.write(outpath, mixture, args.rate) 116 | print(recid, os.path.abspath(outpath), file=wav_scp_f) 117 | 118 | wav_scp_f.close() 119 | segments_f.close() 120 | utt2spk_f.close() 121 | -------------------------------------------------------------------------------- /eend/bin/make_mixture_nooverlap.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | # This script generates simulated multi-talker mixtures for diarization 7 | # (No speaker overlaps) 8 | # 9 | # common/make_mixture_nooverlap.py \ 10 | # mixture.scp \ 11 | # data/mixture \ 12 | # wav/mixture 13 | 14 | 15 | import argparse 16 | import os 17 | import kaldi_data 18 | import numpy as np 19 | import math 20 | import soundfile as sf 21 | import json 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('script', 25 | help='list of json') 26 | parser.add_argument('out_data_dir', 27 | help='output data dir of mixture') 28 | parser.add_argument('out_wav_dir', 29 | help='output mixture wav files are stored here') 30 | parser.add_argument('--rate', type=int, default=16000, 31 | help='sampling rate') 32 | args = parser.parse_args() 33 | 34 | # open output data files 35 | segments_f = open(args.out_data_dir + '/segments', 'w') 36 | utt2spk_f = open(args.out_data_dir + '/utt2spk', 'w') 37 | wav_scp_f = open(args.out_data_dir + '/wav.scp', 'w') 38 | 39 | # outputs are resampled at target sample rate 40 | resample_cmd = "sox -t wav - -t wav - rate {}".format(args.rate) 41 | 42 | for line in open(args.script): 43 | recid, jsonstr = line.strip().split(None, 1) 44 | indata = json.loads(jsonstr) 45 | recid = indata['recid'] 46 | noise = indata['noise'] 47 | noise_snr = indata['snr'] 48 | mixture = [] 49 | data = [] 50 | pos = 0 51 | for utt in indata['utts']: 52 | spkid = utt['spkid'] 53 | wav = utt['utt'] 54 | interval = utt['interval'] 55 | rir = utt['rir'] 56 | st = 0 57 | et = None 58 | if 'st' in utt: 59 | st = np.rint(utt['st'] * args.rate).astype(int) 60 | if 'et' in utt: 61 | et = np.rint(utt['et'] * args.rate).astype(int) 62 | silence = np.zeros(int(interval * args.rate)) 63 | data.append(silence) 64 | # utterance is reverberated using room impulse response 65 | if rir: 66 | preprocess = "wav-reverberate --print-args=false " \ 67 | " --impulse-response={} - 
-".format(rir) 68 | wav_rxfilename = kaldi_data.process_wav(wav, preprocess) 69 | else: 70 | wav_rxfilename = wav 71 | wav_rxfilename = kaldi_data.process_wav(wav_rxfilename, resample_cmd) 72 | speech, _ = kaldi_data.load_wav(wav_rxfilename, st, et) 73 | data.append(speech) 74 | # calculate start/end position in samples 75 | startpos = pos + len(silence) 76 | endpos = startpos + len(speech) 77 | # write segments and utt2spk 78 | uttid = '{}_{}_{:07d}_{:07d}'.format( 79 | spkid, recid, int(startpos / args.rate * 100), 80 | int(endpos / args.rate * 100)) 81 | print(uttid, recid, 82 | startpos / args.rate, endpos / args.rate, file=segments_f) 83 | print(uttid, spkid, file=utt2spk_f) 84 | pos = endpos 85 | mixture = np.concatenate(data) 86 | maxlen = len(mixture) 87 | # noise is repeated or cutted for fitting to the mixture data length 88 | noise_resampled = kaldi_data.process_wav(noise, resample_cmd) 89 | noise_data, _ = kaldi_data.load_wav(noise_resampled) 90 | if maxlen > len(noise_data): 91 | noise_data = np.pad(noise_data, (0, maxlen - len(noise_data)), 'wrap') 92 | else: 93 | noise_data = noise_data[:maxlen] 94 | # noise power is scaled according to selected SNR, then mixed 95 | signal_power = np.sum(mixture**2) / len(mixture) 96 | noise_power = np.sum(noise_data**2) / len(noise_data) 97 | scale = math.sqrt( 98 | math.pow(10, - noise_snr / 10) * signal_power / noise_power) 99 | mixture += noise_data * scale 100 | # output the wav file and write wav.scp 101 | outfname = '{}.wav'.format(recid) 102 | outpath = os.path.join(args.out_wav_dir, outfname) 103 | sf.write(outpath, mixture, args.rate) 104 | print(recid, os.path.abspath(outpath), file=wav_scp_f) 105 | 106 | wav_scp_f.close() 107 | segments_f.close() 108 | utt2spk_f.close() 109 | -------------------------------------------------------------------------------- /eend/bin/make_rttm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 
5 | 6 | import argparse 7 | import h5py 8 | import numpy as np 9 | import os 10 | from scipy.signal import medfilt 11 | 12 | parser = argparse.ArgumentParser(description='make rttm from decoded result') 13 | parser.add_argument('file_list_hdf5') 14 | parser.add_argument('out_rttm_file') 15 | parser.add_argument('--threshold', default=0.5, type=float) 16 | parser.add_argument('--frame_shift', default=256, type=int) 17 | parser.add_argument('--subsampling', default=1, type=int) 18 | parser.add_argument('--median', default=1, type=int) 19 | parser.add_argument('--sampling_rate', default=16000, type=int) 20 | args = parser.parse_args() 21 | 22 | filepaths = [line.strip() for line in open(args.file_list_hdf5)] 23 | filepaths.sort() 24 | 25 | with open(args.out_rttm_file, 'w') as wf: 26 | for filepath in filepaths: 27 | session, _ = os.path.splitext(os.path.basename(filepath)) 28 | data = h5py.File(filepath, 'r') 29 | a = np.where(data['T_hat'][:] > args.threshold, 1, 0) 30 | if args.median > 1: 31 | a = medfilt(a, (args.median, 1)) 32 | for spkid, frames in enumerate(a.T): 33 | frames = np.pad(frames, (1, 1), 'constant') 34 | changes, = np.where(np.diff(frames, axis=0) != 0) 35 | fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>" 36 | for s, e in zip(changes[::2], changes[1::2]): 37 | print(fmt.format( 38 | session, 39 | s * args.frame_shift * args.subsampling / args.sampling_rate, 40 | (e - s) * args.frame_shift * args.subsampling / args.sampling_rate, 41 | session + "_" + str(spkid)), file=wf) 42 | -------------------------------------------------------------------------------- /eend/bin/model_averaging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | # averaging chainer serialized models 7 | 8 | import numpy as np 9 | import argparse 10 | 11 | 12 | def average_model_chainer(ifiles, ofile): 13 | omodel = {} 14 | # get keys from the first file 15 | model = np.load(ifiles[0]) 16 | for x in model: 17 | if 'model' in x: 18 | print(x) 19 | keys = [x.split('main/')[1] for x in model if 'model' in x] 20 | print(keys) 21 | for path in ifiles: 22 | model = np.load(path) 23 | for key in keys: 24 | val = model['updater/model:main/{}'.format(key)] 25 | if key not in omodel: 26 | omodel[key] = val 27 | else: 28 | omodel[key] += val 29 | for key in keys: 30 | omodel[key] /= len(ifiles) 31 | np.savez_compressed(ofile, **omodel) 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("ofile") 37 | parser.add_argument("ifiles", nargs='+') 38 | parser.add_argument("--backend", default='chainer', 39 | choices=['chainer', 'pytorch']) 40 | args = parser.parse_args() 41 | if args.backend == 'chainer': 42 | average_model_chainer(args.ifiles, args.ofile) 43 | elif args.backend == 'pytorch': 44 | # TODO 45 | # average_model_pytorch(args.ifiles, args.ofile) 46 | raise NotImplementedError() 47 | else: 48 | raise ValueError() 49 | -------------------------------------------------------------------------------- /eend/bin/random_mixture.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | 6 | """ 7 | This script generates random multi-talker mixtures for diarization. 8 | It generates a scp-like outputs: lines of "[recid] [json]". 
9 | recid: recording id of mixture 10 | serial numbers like mix_0000001, mix_0000002, ... 11 | json: mixture configuration formatted in "one-line" 12 | The json format is as following: 13 | { 14 | 'speakers':[ # list of speakers 15 | { 16 | 'spkid': 'Name', # speaker id 17 | 'rir': '/rirdir/rir.wav', # wav_rxfilename of room impulse response 18 | 'utts': [ # list of wav_rxfilenames of utterances 19 | '/wavdir/utt1.wav', 20 | '/wavdir/utt2.wav',...], 21 | 'intervals': [1.2, 3.4, ...] # list of silence durations before utterances 22 | }, ... ], 23 | 'noise': '/noisedir/noise.wav' # wav_rxfilename of background noise 24 | 'snr': 15.0, # SNR for mixing background noise 25 | 'recid': 'mix_000001' # recording id of the mixture 26 | } 27 | 28 | Usage: 29 | common/random_mixture.py \ 30 | --n_mixtures=10000 \ # number of mixtures 31 | data/voxceleb1_train \ # kaldi-style data dir of utterances 32 | data/musan_noise_bg \ # background noises 33 | data/simu_rirs \ # room impulse responses 34 | > mixture.scp # output scp-like file 35 | 36 | The actual data dir and wav files are generated using make_mixture.py: 37 | common/make_mixture.py \ 38 | mixture.scp \ # scp-like file for mixture 39 | data/mixture \ # output data dir 40 | wav/mixture # output wav dir 41 | """ 42 | 43 | import argparse 44 | import os 45 | from eend import kaldi_data 46 | import random 47 | import numpy as np 48 | import json 49 | import itertools 50 | 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument('data_dir', 53 | help='data dir of single-speaker recordings') 54 | parser.add_argument('noise_dir', 55 | help='data dir of background noise recordings') 56 | parser.add_argument('rir_dir', 57 | help='data dir of room impulse responses') 58 | parser.add_argument('--n_mixtures', type=int, default=10, 59 | help='number of mixture recordings') 60 | parser.add_argument('--n_speakers', type=int, default=4, 61 | help='number of speakers in a mixture') 62 | parser.add_argument('--min_utts', type=int, default=10, 63 | help='minimum number of uttenraces per speaker') 64 | parser.add_argument('--max_utts', type=int, default=20, 65 | help='maximum number of utterances per speaker') 66 | parser.add_argument('--sil_scale', type=float, default=10.0, 67 | help='average silence time') 68 | parser.add_argument('--noise_snrs', default="5:10:15:20", 69 | help='colon-delimited SNRs for background noises') 70 | parser.add_argument('--random_seed', type=int, default=777, 71 | help='random seed') 72 | parser.add_argument('--speech_rvb_probability', type=float, default=1, 73 | help='reverb probability') 74 | args = parser.parse_args() 75 | 76 | random.seed(args.random_seed) 77 | np.random.seed(args.random_seed) 78 | 79 | # load list of wav files from kaldi-style data dirs 80 | wavs = kaldi_data.load_wav_scp( 81 | os.path.join(args.data_dir, 'wav.scp')) 82 | noises = kaldi_data.load_wav_scp( 83 | os.path.join(args.noise_dir, 'wav.scp')) 84 | rirs = kaldi_data.load_wav_scp( 85 | os.path.join(args.rir_dir, 'wav.scp')) 86 | 87 | # spk2utt is used for counting number of utterances per speaker 88 | spk2utt = kaldi_data.load_spk2utt( 89 | os.path.join(args.data_dir, 'spk2utt')) 90 | 91 | segments = kaldi_data.load_segments_hash( 92 | os.path.join(args.data_dir, 'segments')) 93 | 94 | # choice lists for random sampling 95 | all_speakers = list(spk2utt.keys()) 96 | all_noises = list(noises.keys()) 97 | all_rirs = list(rirs.keys()) 98 | noise_snrs = [float(x) for x in args.noise_snrs.split(':')] 99 | 100 | mixtures = [] 101 | for it in 
range(args.n_mixtures): 102 | # recording ids are mix_0000001, mix_0000002, ... 103 | recid = 'mix_{:07d}'.format(it + 1) 104 | # randomly select speakers, a background noise and a SNR 105 | speakers = random.sample(all_speakers, args.n_speakers) 106 | noise = random.choice(all_noises) 107 | noise_snr = random.choice(noise_snrs) 108 | mixture = {'speakers': []} 109 | for speaker in speakers: 110 | # randomly select the number of utterances 111 | n_utts = np.random.randint(args.min_utts, args.max_utts + 1) 112 | # utts = spk2utt[speaker][:n_utts] 113 | cycle_utts = itertools.cycle(spk2utt[speaker]) 114 | # random start utterance 115 | roll = np.random.randint(0, len(spk2utt[speaker])) 116 | for i in range(roll): 117 | next(cycle_utts) 118 | utts = [next(cycle_utts) for i in range(n_utts)] 119 | # randomly select wait time before appending utterance 120 | intervals = np.random.exponential(args.sil_scale, size=n_utts) 121 | # randomly select a room impulse response 122 | if random.random() < args.speech_rvb_probability: 123 | rir = rirs[random.choice(all_rirs)] 124 | else: 125 | rir = None 126 | if segments is not None: 127 | utts = [segments[utt] for utt in utts] 128 | utts = [(wavs[rec], st, et) for (rec, st, et) in utts] 129 | mixture['speakers'].append({ 130 | 'spkid': speaker, 131 | 'rir': rir, 132 | 'utts': utts, 133 | 'intervals': intervals.tolist() 134 | }) 135 | else: 136 | mixture['speakers'].append({ 137 | 'spkid': speaker, 138 | 'rir': rir, 139 | 'utts': [wavs[utt] for utt in utts], 140 | 'intervals': intervals.tolist() 141 | }) 142 | mixture['noise'] = noises[noise] 143 | mixture['snr'] = noise_snr 144 | mixture['recid'] = recid 145 | print(recid, json.dumps(mixture)) 146 | -------------------------------------------------------------------------------- /eend/bin/random_mixture_nooverlap.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | 6 | """ 7 | This script generates random multi-talker mixtures for diarization. 8 | (No speaker overlaps) 9 | It generates a scp-like outputs: lines of "[recid] [json]". 10 | recid: recording id of mixture 11 | serial numbers like mix_0000001, mix_0000002, ... 12 | json: mixture configuration formatted in "one-line" 13 | The json format is as following: 14 | { 15 | 'speakers':[ # list of speakers 16 | { 17 | 'spkid': 'Name', # speaker id 18 | 'rir': '/rirdir/rir.wav', # wav_rxfilename of room impulse response 19 | 'utts': [ # list of wav_rxfilenames of utterances 20 | '/wavdir/utt1.wav', 21 | '/wavdir/utt2.wav',...], 22 | 'intervals': [1.2, 3.4, ...] # list of silence durations before utterances 23 | }, ... 
], 24 | 'noise': '/noisedir/noise.wav' # wav_rxfilename of background noise 25 | 'snr': 15.0, # SNR for mixing background noise 26 | 'recid': 'mix_000001' # recording id of the mixture 27 | } 28 | 29 | Usage: 30 | common/random_mixture.py \ 31 | --n_mixtures=10000 \ # number of mixtures 32 | data/voxceleb1_train \ # kaldi-style data dir of utterances 33 | data/musan_noise_bg \ # background noises 34 | data/simu_rirs \ # room impulse responses 35 | > mixture.scp # output scp-like file 36 | 37 | The actual data dir and wav files are generated using make_mixture.py: 38 | common/make_mixture.py \ 39 | mixture.scp \ # scp-like file for mixture 40 | data/mixture \ # output data dir 41 | wav/mixture # output wav dir 42 | """ 43 | 44 | import argparse 45 | import os 46 | import kaldi_data 47 | import random 48 | import numpy as np 49 | import json 50 | import itertools 51 | 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument('data_dir', 54 | help='data dir of single-speaker recordings') 55 | parser.add_argument('noise_dir', 56 | help='data dir of background noise recordings') 57 | parser.add_argument('rir_dir', 58 | help='data dir of room impulse responses') 59 | parser.add_argument('--n_mixtures', type=int, default=10, 60 | help='number of mixture recordings') 61 | parser.add_argument('--n_speakers', type=int, default=4, 62 | help='number of speakers in a mixture') 63 | parser.add_argument('--min_utts', type=int, default=20, 64 | help='minimum number of uttenraces per speaker') 65 | parser.add_argument('--max_utts', type=int, default=40, 66 | help='maximum number of utterances per speaker') 67 | parser.add_argument('--sil_scale', type=float, default=1.0, 68 | help='average silence time') 69 | parser.add_argument('--noise_snrs', default="10:15:20", 70 | help='colon-delimited SNRs for background noises') 71 | parser.add_argument('--random_seed', type=int, default=777, 72 | help='random seed') 73 | parser.add_argument('--speech_rvb_probability', type=float, default=1, 74 | help='reverb probability') 75 | args = parser.parse_args() 76 | 77 | random.seed(args.random_seed) 78 | np.random.seed(args.random_seed) 79 | 80 | # load list of wav files from kaldi-style data dirs 81 | wavs = kaldi_data.load_wav_scp( 82 | os.path.join(args.data_dir, 'wav.scp')) 83 | noises = kaldi_data.load_wav_scp( 84 | os.path.join(args.noise_dir, 'wav.scp')) 85 | rirs = kaldi_data.load_wav_scp( 86 | os.path.join(args.rir_dir, 'wav.scp')) 87 | 88 | # spk2utt is used for counting number of utterances per speaker 89 | spk2utt = kaldi_data.load_spk2utt( 90 | os.path.join(args.data_dir, 'spk2utt')) 91 | 92 | segments = kaldi_data.load_segments_hash( 93 | os.path.join(args.data_dir, 'segments')) 94 | 95 | # choice lists for random sampling 96 | all_speakers = list(spk2utt.keys()) 97 | all_noises = list(noises.keys()) 98 | all_rirs = list(rirs.keys()) 99 | noise_snrs = [float(x) for x in args.noise_snrs.split(':')] 100 | 101 | mixtures = [] 102 | for it in range(args.n_mixtures): 103 | # recording ids are mix_0000001, mix_0000002, ... 
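    # Each iteration emits one mixture spec: the speakers, background noise and
    # SNR are sampled first, then n_utts utterances are laid out one after
    # another (a speaker is drawn at random for each slot) separated by
    # exponentially distributed silence intervals, so utterances never overlap.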
104 | recid = 'mix_{:07d}'.format(it + 1) 105 | # randomly select speakers, a background noise and a SNR 106 | speakers = random.sample(all_speakers, args.n_speakers) 107 | noise = random.choice(all_noises) 108 | noise_snr = random.choice(noise_snrs) 109 | mixture = {'utts': []} 110 | n_utts = np.random.randint(args.min_utts, args.max_utts + 1) 111 | # randomly select wait time before appending utterance 112 | intervals = np.random.exponential(args.sil_scale, size=n_utts) 113 | spk2rir = {} 114 | spk2cycleutts = {} 115 | for speaker in speakers: 116 | # select rvb for each speaker 117 | if random.random() < args.speech_rvb_probability: 118 | spk2rir[speaker] = random.choice(all_rirs) 119 | else: 120 | spk2rir[speaker] = None 121 | spk2cycleutts[speaker] = itertools.cycle(spk2utt[speaker]) 122 | # random start utterance 123 | roll = np.random.randint(0, len(spk2utt[speaker])) 124 | for i in range(roll): 125 | next(spk2cycleutts[speaker]) 126 | # randomly select speaker 127 | for interval in intervals: 128 | speaker = np.random.choice(speakers) 129 | utt = next(spk2cycleutts[speaker]) 130 | # rir = spk2rir[speaker] 131 | if spk2rir[speaker]: 132 | rir = rirs[spk2rir[speaker]] 133 | else: 134 | rir = None 135 | if segments is not None: 136 | rec, st, et = segments[utt] 137 | mixture['utts'].append({ 138 | 'spkid': speaker, 139 | 'rir': rir, 140 | 'utt': wavs[rec], 141 | 'st': st, 142 | 'et': et, 143 | 'interval': interval 144 | }) 145 | else: 146 | mixture['utts'].append({ 147 | 'spkid': speaker, 148 | 'rir': rir, 149 | 'utt': wavs[utt], 150 | 'interval': interval 151 | }) 152 | mixture['noise'] = noises[noise] 153 | mixture['snr'] = noise_snr 154 | mixture['recid'] = recid 155 | print(recid, json.dumps(mixture)) 156 | -------------------------------------------------------------------------------- /eend/bin/rttm_stats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 
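#
# This script reads an RTTM file and prints corpus-level statistics: number of
# recordings, mean recording length, voice-activity factor, mean utterance
# duration, mean inter-utterance interval, and overlap ratios derived from
# frame-wise counts of simultaneously active speakers.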
5 | 6 | import numpy as np 7 | import argparse 8 | 9 | def load_rttm(rttm_file): 10 | """ load rttm file as numpy structured array """ 11 | segments = [] 12 | for line in open(rttm_file): 13 | toks = line.strip().split() 14 | # number of columns is 9 (RT-05S) or 10 (RT-09S) 15 | (stype, fileid, ch, start, duration, 16 | _, _, speaker, _) = toks[:9] 17 | if stype != "SPEAKER": 18 | continue 19 | start = float(start) 20 | end = start + float(duration) 21 | segments.append((fileid, speaker, start, end)) 22 | return np.array(segments, dtype=[ 23 | ('recid', 'object'), ('speaker', 'object'), ('st', 'f'), ('et', 'f')]) 24 | 25 | def time2frame(t, rate, shift): 26 | """ time in second (float) to frame index (int) """ 27 | return np.rint(t * rate / shift).astype(int) 28 | 29 | def get_frame_labels( 30 | rttm, start=0, end=None, rate=16000, shift=256): 31 | """ Get frame labels from RTTM file 32 | Args: 33 | start: start time in seconds 34 | end: end time in seconds 35 | rate: sampling rate 36 | shift: number of frame shift samples 37 | n_speakers: number of speakers 38 | if None, determined from rttm file 39 | Returns: 40 | labels.T: frame labels 41 | (n_frames, n_speaker)-shaped numpy.int32 array 42 | speakers: list of speaker ids 43 | """ 44 | # sorted uniq speaker ids 45 | speakers = np.unique(rttm['speaker']).tolist() 46 | # start and end frames 47 | rec_sf = time2frame(start, rate, shift) 48 | rec_ef = time2frame(end if end else rttm['et'].max(), rate, shift) 49 | labels = np.zeros((rec_ef - rec_sf, len(speakers)), dtype=np.int32) 50 | for seg in rttm: 51 | seg_sp = speakers.index(seg['speaker']) 52 | seg_sf = time2frame(seg['st'], rate, shift) 53 | seg_ef = time2frame(seg['et'], rate, shift) 54 | # relative frame index from 'rec_sf' 55 | sf = ef = None 56 | if rec_sf <= seg_sf and seg_sf < rec_ef: 57 | sf = seg_sf - rec_sf 58 | if rec_sf < seg_ef and seg_ef <= rec_ef: 59 | ef = seg_ef - rec_sf 60 | if sf is not None or ef is not None: 61 | labels[sf:ef, seg_sp] = 1 62 | return labels.T, speakers 63 | 64 | 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('rttm') 67 | args = parser.parse_args() 68 | rttm = load_rttm(args.rttm) 69 | 70 | def _min_max_ave(a): 71 | return [f(a) for f in [np.min, np.max, np.mean]] 72 | 73 | vafs = [] 74 | uds = [] 75 | ids = [] 76 | reclens = [] 77 | pres = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) 78 | den = 0 79 | recordings = np.unique(rttm['recid']) 80 | for recid in recordings: 81 | rec = rttm[rttm['recid'] == recid] 82 | speakers = np.unique(rec['speaker']) 83 | for speaker in speakers: 84 | spk = rec[rec['speaker'] == speaker] 85 | spk.sort() 86 | durs = spk['et'] - spk['st'] 87 | stats_dur = _min_max_ave(durs) 88 | uds.append(np.mean(durs)) 89 | if len(durs) > 1: 90 | intervals = spk['st'][1:] - spk['et'][:-1] 91 | stats_int = _min_max_ave(intervals) 92 | ids.append(np.mean(intervals)) 93 | vafs.append(np.sum(durs)/(np.sum(durs) + np.sum(intervals))) 94 | labels, _ = get_frame_labels(rec) 95 | n_presense = np.sum(labels, axis=0) 96 | for n in np.unique(n_presense): 97 | pres[n] += np.sum(n_presense == n) 98 | den += len(n_presense) 99 | #for s in speakers: print(s) 100 | reclens.append(rec['et'].max() - rec['st'].min()) 101 | 102 | print(list(range(2, len(pres)))) 103 | total_speaker = np.sum([n * pres[n] for n in range(len(pres))]) 104 | total_overlap = np.sum([n * pres[n] for n in range(2, len(pres))]) 105 | print(total_speaker, total_overlap, total_overlap/total_speaker) 106 | print("single-speaker overlap", pres[3]/np.sum(pres[2:])) 107 
| print(len(recordings), np.mean(reclens), np.mean(vafs), np.mean(uds), np.mean(ids), "overlap ratio:", np.sum(pres[2:])/np.sum(pres[1:]), "overlaps", ' '.join(str(x) for x in pres/den)) 108 | -------------------------------------------------------------------------------- /eend/bin/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | import yamlargparse 7 | from eend import system_info 8 | 9 | parser = yamlargparse.ArgumentParser(description='EEND training') 10 | parser.add_argument('-c', '--config', help='config file path', 11 | action=yamlargparse.ActionConfigFile) 12 | parser.add_argument('train_data_dir', 13 | help='kaldi-style data dir used for training.') 14 | parser.add_argument('valid_data_dir', 15 | help='kaldi-style data dir used for validation.') 16 | parser.add_argument('model_save_dir', 17 | help='output directory which model file will be saved in.') 18 | parser.add_argument('--backend', default='chainer', 19 | choices=['chainer', 'pytorch'], 20 | help='backend framework') 21 | parser.add_argument('--model-type', default='Transformer', 22 | help='Type of model (Transformer or BLSTM)') 23 | parser.add_argument('--initmodel', '-m', default='', 24 | help='Initialize the model from given file') 25 | parser.add_argument('--resume', '-r', default='', 26 | help='Resume the optimization from snapshot') 27 | parser.add_argument('--gpu', '-g', default=-1, type=int, 28 | help='GPU ID (negative value indicates CPU)') 29 | parser.add_argument('--max-epochs', default=20, type=int, 30 | help='Max. number of epochs to train') 31 | parser.add_argument('--input-transform', default='', 32 | choices=['', 'log', 'logmel', 'logmel23', 'logmel23_mn', 33 | 'logmel23_mvn', 'logmel23_swn'], 34 | help='input transform') 35 | parser.add_argument('--lr', default=0.001, type=float) 36 | parser.add_argument('--optimizer', default='adam', type=str) 37 | parser.add_argument('--num-speakers', type=int) 38 | parser.add_argument('--gradclip', default=-1, type=int, 39 | help='gradient clipping. 
if < 0, no clipping') 40 | parser.add_argument('--num-frames', default=2000, type=int, 41 | help='number of frames in one utterance') 42 | parser.add_argument('--batchsize', default=1, type=int, 43 | help='number of utterances in one batch') 44 | parser.add_argument('--label-delay', default=0, type=int, 45 | help='number of frames delayed from original labels' 46 | ' for uni-directional rnn to see in the future') 47 | parser.add_argument('--hidden-size', default=256, type=int, 48 | help='number of lstm output nodes') 49 | parser.add_argument('--num-lstm-layers', default=1, type=int, 50 | help='number of lstm layers') 51 | parser.add_argument('--dc-loss-ratio', default=0.5, type=float) 52 | parser.add_argument('--embedding-layers', default=2, type=int) 53 | parser.add_argument('--embedding-size', default=256, type=int) 54 | parser.add_argument('--context-size', default=0, type=int) 55 | parser.add_argument('--subsampling', default=1, type=int) 56 | parser.add_argument('--frame-size', default=1024, type=int) 57 | parser.add_argument('--frame-shift', default=256, type=int) 58 | parser.add_argument('--sampling-rate', default=16000, type=int) 59 | parser.add_argument('--noam-scale', default=1.0, type=float) 60 | parser.add_argument('--noam-warmup-steps', default=25000, type=float) 61 | parser.add_argument('--transformer-encoder-n-heads', default=4, type=int) 62 | parser.add_argument('--transformer-encoder-n-layers', default=2, type=int) 63 | parser.add_argument('--transformer-encoder-dropout', default=0.1, type=float) 64 | parser.add_argument('--gradient-accumulation-steps', default=1, type=int) 65 | parser.add_argument('--seed', default=777, type=int) 66 | 67 | attractor_args = parser.add_argument_group('attractor') 68 | attractor_args.add_argument('--use-attractor', action='store_true', 69 | help='Enable encoder-decoder attractor mode') 70 | attractor_args.add_argument('--shuffle', action='store_true', 71 | help='Shuffle the order in time-axis before input to the network') 72 | attractor_args.add_argument('--attractor-loss-ratio', default=1.0, type=float, 73 | help='weighting parameter') 74 | attractor_args.add_argument('--attractor-encoder-dropout', default=0.1, type=float) 75 | attractor_args.add_argument('--attractor-decoder-dropout', default=0.1, type=float) 76 | args = parser.parse_args() 77 | 78 | system_info.print_system_info() 79 | print(args) 80 | if args.backend == 'chainer': 81 | from eend.chainer_backend.train import train 82 | train(args) 83 | elif args.backend == 'pytorch': 84 | # TODO 85 | # from eend.pytorch_backend.train import train 86 | # train(args) 87 | raise NotImplementedError() 88 | else: 89 | raise ValueError() 90 | -------------------------------------------------------------------------------- /eend/bin/visualize_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 
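#
# This script plots the self-attention weight matrices saved at inference time
# (infer.py with --save-attention-weight 1), one heatmap per attention head,
# together with reference speaker-activity labels taken from an RTTM file.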
5 | import argparse 6 | import os 7 | import numpy as np 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.axes_grid1 import make_axes_locatable 11 | 12 | 13 | def load_rttm(rttm_file): 14 | """ load rttm file as numpy structured array """ 15 | segments = [] 16 | for line in open(rttm_file): 17 | toks = line.strip().split() 18 | # number of columns is 9 (RT-05S) or 10 (RT-09S) 19 | (stype, fileid, ch, start, duration, 20 | _, _, speaker, _) = toks[:9] 21 | if stype != "SPEAKER": 22 | continue 23 | start = float(start) 24 | end = start + float(duration) 25 | segments.append((fileid, speaker, start, end)) 26 | return np.array(segments, dtype=[ 27 | ('recid', 'object'), ('speaker', 'object'), ('st', 'f'), ('et', 'f')]) 28 | 29 | 30 | def time2frame(t, rate, shift): 31 | """ time in second (float) to frame index (int) """ 32 | return np.rint(t * rate / shift).astype(int) 33 | 34 | 35 | def get_frame_labels( 36 | rttm_file, recid, start=0, end=None, rate=16000, shift=256): 37 | """ Get frame labels from RTTM file 38 | Args: 39 | rttm_file: RTTM file 40 | recid: Recording id is the 2nd column of RTTM file, 41 | must be identical to the rec id in wav.scp 42 | start: start time in seconds 43 | end: end time in seconds 44 | rate: sampling rate 45 | shift: number of frame shift samples 46 | Returns: 47 | labels.T: frame labels 48 | (n_frames, n_speaker)-shaped numpy.int32 array 49 | speakers: list of speaker ids 50 | """ 51 | rttm = load_rttm(rttm_file) 52 | # filter by recording id 53 | rttm = rttm[rttm['recid'] == recid] 54 | # sorted uniq speaker ids 55 | speakers = np.unique(rttm['speaker']).tolist() 56 | # start and end frames 57 | rec_sf = time2frame(start, rate, shift) 58 | rec_ef = time2frame(end if end else rttm['et'].max(), rate, shift) 59 | labels = np.zeros((rec_ef - rec_sf, len(speakers)), dtype=np.int32) 60 | for seg in rttm: 61 | seg_sp = speakers.index(seg['speaker']) 62 | seg_sf = time2frame(seg['st'], rate, shift) 63 | seg_ef = time2frame(seg['et'], rate, shift) 64 | # relative frame index from 'rec_sf' 65 | sf = ef = None 66 | if rec_sf <= seg_sf and seg_sf < rec_ef: 67 | sf = seg_sf - rec_sf 68 | if rec_sf < seg_ef and seg_ef <= rec_ef: 69 | ef = seg_ef - rec_sf 70 | if seg_sf < rec_sf and rec_ef < seg_ef: 71 | sf = 0 72 | if sf is not None or ef is not None: 73 | labels[sf:ef, seg_sp] = 1 74 | return labels.T, speakers 75 | 76 | 77 | def attention_plot(args): 78 | """ Plots attention weights with reference labels. 79 | """ 80 | # attention weights at specified layer 81 | att_w = np.load(args.att_file)[args.layer, ...] 
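    # After indexing with --layer, att_w has shape
    # (n_heads, n_query_frames, n_key_frames). The inference script names these
    # files "{recid}_{start}_{end}.att.npy", which is why the recid and start
    # frame are parsed back out of the file name below.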
82 | # extract recid from att_file name, "__.att.npy" 83 | recid = '_'.join(os.path.basename(args.att_file).split('_')[:-2]) 84 | 85 | start_frame = int(os.path.basename(args.att_file).split('_')[-2]) 86 | end_frame = start_frame + att_w.shape[1] 87 | 88 | ref, ref_spks = get_frame_labels( 89 | args.rttm_file, recid, rate=args.rate, 90 | start=start_frame * args.shift / args.rate, 91 | end=end_frame * args.shift / args.rate, 92 | shift=args.shift) 93 | 94 | if args.span: 95 | start, end = [int(x) for x in args.span.split(":")] 96 | else: 97 | start = 0 98 | end = att_w.shape[2] 99 | n_subplots = att_w.shape[0] 100 | 101 | ref_height = 1 102 | 103 | fig, axes = plt.subplots(1, n_subplots, 104 | figsize=(att_w.shape[0] * 4, 4 + ref_height)) 105 | 106 | for i, (ax, aw) in enumerate(zip(axes, att_w)): 107 | divider = make_axes_locatable(ax) 108 | if args.ref_type == 'line': 109 | colors = ['k', 'r'] 110 | # stack figures from bottom to top 111 | for spk, r in reversed(list(enumerate(ref))): 112 | ax_label = divider.append_axes('top', 0.3, pad=0.1, sharex=ax) 113 | ax_label.xaxis.set_tick_params(labelbottom=False) 114 | ax_label.xaxis.set_tick_params(bottom=False) 115 | ax_label.yaxis.set_tick_params(left=False) 116 | ax_label.set_ylim([-0.5, 1.5]) 117 | ax_label.set_yticks(np.arange(2)) 118 | ax_label.set_yticklabels(['silence', 'speech']) 119 | ax_label.set_ylabel('Spk {}'.format(spk+1), 120 | rotation=0, va='center', labelpad=15) 121 | if i > 0: 122 | ax_label.yaxis.set_tick_params(labelleft=False) 123 | ax_label.set_ylabel('') 124 | ax_label.plot(np.arange(r.size), r, lw=1, c=colors[spk == i]) 125 | elif args.ref_type == 'fill': 126 | for spk, r in reversed(list(enumerate(ref, 1))): 127 | ax_label = divider.append_axes('top', 0.2, pad=0.1, sharex=ax) 128 | ax_label.xaxis.set_tick_params(labelbottom=False) 129 | ax_label.yaxis.set_tick_params(labelleft=False) 130 | ax_label.xaxis.set_tick_params(bottom=False, left=False) 131 | ax_label.yaxis.set_tick_params(left=False) 132 | if args.span: 133 | ax_label.set_xlim([start, end]) 134 | ax_label.set_ylabel('Spk {}'.format(spk), 135 | rotation=0, va='center', labelpad=15) 136 | aspect = (end - start) / 40 137 | ax_label.imshow(r[np.newaxis, :], aspect=aspect, cmap='binary') 138 | 139 | ax.imshow(aw, aspect='equal', cmap=args.colormap) 140 | if args.span: 141 | ax.set_xlim([start, end]) 142 | ax.set_ylim([end, start]) 143 | ax.set_xlabel('Key') 144 | ax.set_ylabel('Query') 145 | ax.set_xticks(ax.get_yticks()[1:-1]) 146 | if args.invert_yaxis: 147 | ax.invert_yaxis() 148 | if args.add_title: 149 | ax.set_title('Head {}'.format(i + 1), y=-0.25, fontsize=16) 150 | 151 | if args.add_title: 152 | # manual spacing at bottom 153 | plt.tight_layout(rect=[0, 0.05, 1, 1]) 154 | else: 155 | plt.tight_layout() 156 | plt.savefig(args.pdf_file) 157 | 158 | 159 | if __name__ == '__main__': 160 | parser = argparse.ArgumentParser('Plot attention weight') 161 | parser.add_argument('att_file', 162 | help='Attention weight file.') 163 | parser.add_argument('rttm_file', 164 | help='RTTM file,') 165 | parser.add_argument('pdf_file', 166 | help='Output pdf file.') 167 | parser.add_argument('--colormap', default='binary', 168 | help='colormap for heatmaps: gray, jet, viridis, etc.') 169 | parser.add_argument('--invert_yaxis', action='store_true', 170 | help='invert y-axis in heatmap') 171 | parser.add_argument('--add_title', action='store_true', 172 | help='put captions "Head N" under heatmaps') 173 | parser.add_argument('--ref_type', choices=['line', 'fill'], 
default='fill', 174 | help='reference label appearance') 175 | parser.add_argument('--span', default='', 176 | help='colon-delimited start/end frame id') 177 | parser.add_argument('--layer', default=1, 178 | help='0-origin layer index') 179 | parser.add_argument('--rate', default=8000, 180 | help='sampleing rate') 181 | parser.add_argument('--shift', default=800, 182 | help='frame-shift * subsampling') 183 | args = parser.parse_args() 184 | 185 | matplotlib.pyplot.switch_backend('Agg') 186 | attention_plot(args) 187 | -------------------------------------------------------------------------------- /eend/bin/yaml2bash.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | 6 | import argparse 7 | import yaml 8 | 9 | 10 | def print_var_assign_statements(obj, prefix=''): 11 | """ Print variable assignment statements from yaml object. 12 | - { key: value } -> key=value 13 | - { parent: { child: value } } -> parent_child=value 14 | - { key: [ val1, val2 ]} -> key_0=val1 key_1=val2 15 | """ 16 | if isinstance(obj, dict): 17 | for key in obj: 18 | child_prefix = prefix + '_' + key if prefix else key 19 | print_var_assign_statements(obj[key], child_prefix) 20 | elif isinstance(obj, list): 21 | for key, val in enumerate(obj): 22 | child_prefix = prefix + '_' + str(key) if prefix else str(key) 23 | print_var_assign_statements(val, child_prefix) 24 | else: 25 | if obj is None: 26 | obj = '' 27 | elif obj is False: 28 | obj = 'false' 29 | elif obj is True: 30 | obj = 'true' 31 | print(f"{prefix}={obj}") 32 | 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('input_yaml') 36 | parser.add_argument('--prefix', default='') 37 | args = parser.parse_args() 38 | 39 | data = yaml.safe_load(open(args.input_yaml)) 40 | print_var_assign_statements(data, prefix=args.prefix) 41 | -------------------------------------------------------------------------------- /eend/chainer_backend/diarization_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 2 | # Licensed under the MIT license. 
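#
# KaldiDiarizationDataset splits every recording of a Kaldi-style data dir into
# fixed-length chunks of STFT frames. Each example is a pair (Y, T): Y is the
# transformed, context-spliced and subsampled feature matrix, T the frame-wise
# 0/1 speaker-activity labels. If a chunk contains more than n_speakers
# speakers, only the most active ones are kept; the optional time-axis shuffle
# is the ordering trick used for EEND-EDA training.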
3 | 4 | import chainer 5 | import numpy as np 6 | from eend import kaldi_data 7 | from eend import feature 8 | 9 | 10 | def _count_frames(data_len, size, step): 11 | # no padding at edges, last remaining samples are ignored 12 | return int((data_len - size + step) / step) 13 | 14 | 15 | def _gen_frame_indices( 16 | data_length, size=2000, step=2000, 17 | use_last_samples=False, 18 | label_delay=0, 19 | subsampling=1): 20 | i = -1 21 | for i in range(_count_frames(data_length, size, step)): 22 | yield i * step, i * step + size 23 | if use_last_samples and i * step + size < data_length: 24 | if data_length - (i + 1) * step - subsampling * label_delay > 0: 25 | yield (i + 1) * step, data_length 26 | 27 | 28 | class KaldiDiarizationDataset(chainer.dataset.DatasetMixin): 29 | 30 | def __init__( 31 | self, 32 | data_dir, 33 | dtype=np.float32, 34 | chunk_size=2000, 35 | context_size=0, 36 | frame_size=1024, 37 | frame_shift=256, 38 | subsampling=1, 39 | rate=16000, 40 | input_transform=None, 41 | use_last_samples=False, 42 | label_delay=0, 43 | n_speakers=None, 44 | shuffle=False, 45 | ): 46 | self.data_dir = data_dir 47 | self.dtype = dtype 48 | self.chunk_size = chunk_size 49 | self.context_size = context_size 50 | self.frame_size = frame_size 51 | self.frame_shift = frame_shift 52 | self.subsampling = subsampling 53 | self.input_transform = input_transform 54 | self.n_speakers = n_speakers 55 | self.chunk_indices = [] 56 | self.label_delay = label_delay 57 | 58 | self.data = kaldi_data.KaldiData(self.data_dir) 59 | 60 | # make chunk indices: filepath, start_frame, end_frame 61 | for rec in self.data.wavs: 62 | data_len = int(self.data.reco2dur[rec] * rate / frame_shift) 63 | data_len = int(data_len / self.subsampling) 64 | for st, ed in _gen_frame_indices( 65 | data_len, chunk_size, chunk_size, use_last_samples, 66 | label_delay=self.label_delay, 67 | subsampling=self.subsampling): 68 | self.chunk_indices.append( 69 | (rec, st * self.subsampling, ed * self.subsampling)) 70 | print(len(self.chunk_indices), " chunks") 71 | 72 | self.shuffle = shuffle 73 | 74 | def __len__(self): 75 | return len(self.chunk_indices) 76 | 77 | def get_example(self, i): 78 | rec, st, ed = self.chunk_indices[i] 79 | Y, T = feature.get_labeledSTFT( 80 | self.data, 81 | rec, 82 | st, 83 | ed, 84 | self.frame_size, 85 | self.frame_shift, 86 | self.n_speakers) 87 | Y = feature.transform(Y, self.input_transform) 88 | Y_spliced = feature.splice(Y, self.context_size) 89 | Y_ss, T_ss = feature.subsample(Y_spliced, T, self.subsampling) 90 | 91 | # If the sample contains more than "self.n_speakers" speakers, 92 | # extract top-(self.n_speakers) speakers 93 | if self.n_speakers and T_ss.shape[1] > self.n_speakers: 94 | selected_speakers = np.argsort(T_ss.sum(axis=0))[::-1][:self.n_speakers] 95 | T_ss = T_ss[:, selected_speakers] 96 | 97 | # If self.shuffle is True, shuffle the order in time-axis 98 | # This operation improves the performance of EEND-EDA 99 | if self.shuffle: 100 | order = np.arange(Y_ss.shape[0]) 101 | np.random.shuffle(order) 102 | Y_ss = Y_ss[order] 103 | T_ss = T_ss[order] 104 | 105 | return Y_ss, T_ss 106 | -------------------------------------------------------------------------------- /eend/chainer_backend/encoder_decoder_attractor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2020 Hitachi, Ltd. (author: Shota Horiguchi) 4 | # Licensed under the MIT license. 
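#
# Encoder-decoder attractor (EDA) module: an LSTM encoder consumes a (T, D)
# sequence of frame embeddings and its final state initializes an LSTM decoder
# that is fed zero vectors, emitting one D-dimensional attractor per step. A
# linear layer ("counter") with a sigmoid gives each attractor an existence
# probability. During training, n_speakers + 1 attractors are decoded with
# targets [1] * n_speakers + [0] and the final attractor is discarded; at
# inference, estimate() returns up to max_n_speakers attractors together with
# their existence probabilities, which the caller thresholds (e.g. via
# --attractor-threshold in infer.py).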
5 | 6 | from chainer import Chain, cuda 7 | import chainer.functions as F 8 | import chainer.links as L 9 | 10 | 11 | class EncoderDecoderAttractor(Chain): 12 | 13 | def __init__(self, n_units, encoder_dropout=0.1, decoder_dropout=0.1): 14 | super(EncoderDecoderAttractor, self).__init__() 15 | with self.init_scope(): 16 | self.encoder = L.NStepLSTM(1, n_units, n_units, encoder_dropout) 17 | self.decoder = L.NStepLSTM(1, n_units, n_units, decoder_dropout) 18 | self.counter = L.Linear(n_units, 1) 19 | self.n_units = n_units 20 | 21 | def forward(self, xs, zeros): 22 | hx, cx, _ = self.encoder(None, None, xs) 23 | _, _, attractors = self.decoder(hx, cx, zeros) 24 | return attractors 25 | 26 | def estimate(self, xs, max_n_speakers=15): 27 | """ 28 | Calculate attractors from embedding sequences 29 | without prior knowledge of the number of speakers 30 | 31 | Args: 32 | xs: List of (T,D)-shaped embeddings 33 | max_n_speakers (int) 34 | Returns: 35 | attractors: List of (N,D)-shaped attractors 36 | probs: List of attractor existence probabilities 37 | """ 38 | 39 | xp = cuda.get_array_module(xs[0]) 40 | zeros = [xp.zeros((max_n_speakers, self.n_units), dtype=xp.float32) for _ in xs] 41 | attractors = self.forward(xs, zeros) 42 | probs = [F.sigmoid(F.flatten(self.counter(att))) for att in attractors] 43 | return attractors, probs 44 | 45 | def __call__(self, xs, n_speakers): 46 | """ 47 | Calculate attractors from embedding sequences with given number of speakers 48 | 49 | Args: 50 | xs: List of (T,D)-shaped embeddings 51 | n_speakers: List of number of speakers, or None if the number of speakers is unknown (ex. test phase) 52 | Returns: 53 | loss: Attractor existence loss 54 | attractors: List of (N,D)-shaped attractors 55 | """ 56 | 57 | xp = cuda.get_array_module(xs[0]) 58 | zeros = [xp.zeros((n_spk + 1, self.n_units), dtype=xp.float32) for n_spk in n_speakers] 59 | attractors = self.forward(xs, zeros) 60 | labels = F.concat([xp.array([[1] * n_spk + [0]], xp.int32) for n_spk in n_speakers], axis=1) 61 | logit = F.concat([F.reshape(self.counter(att), (-1, n_spk + 1)) for att, n_spk in zip(attractors, n_speakers)], axis=1) 62 | loss = F.sigmoid_cross_entropy(logit, labels) 63 | 64 | # The final attractor does not correspond to a speaker so remove it 65 | # attractors = [att[:-1] for att in attractors] 66 | attractors = [att[slice(0, att.shape[0] - 1)] for att in attractors] 67 | return loss, attractors 68 | -------------------------------------------------------------------------------- /eend/chainer_backend/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 
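#
# Output format note (added for illustration).  For every recording id in
# wav.scp this script writes <out_dir>/<recid>.h5 containing a dataset named
# 'T_hat' with the per-frame, per-speaker activity posteriors, presumably
# consumed by eend/bin/make_rttm.py.  A minimal reader, with a hypothetical
# path and an assumed 0.5 decision threshold:
#
#     import h5py
#     with h5py.File('exp/infer/out/rec1.h5', 'r') as f:   # hypothetical path
#         t_hat = f['T_hat'][:]                            # (n_frames, n_speakers)
#     active = t_hat > 0.5                                 # threshold is an assumption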
5 | # 6 | import os 7 | import h5py 8 | import numpy as np 9 | import chainer 10 | from chainer import Variable 11 | from chainer import serializers 12 | from scipy.ndimage import shift 13 | from eend.chainer_backend.models import BLSTMDiarization 14 | from eend.chainer_backend.models import TransformerDiarization, TransformerEDADiarization 15 | from eend.chainer_backend.utils import use_single_gpu 16 | from eend import feature 17 | from eend import kaldi_data 18 | from eend import system_info 19 | 20 | 21 | def _gen_chunk_indices(data_len, chunk_size): 22 | step = chunk_size 23 | start = 0 24 | while start < data_len: 25 | end = min(data_len, start + chunk_size) 26 | yield start, end 27 | start += step 28 | 29 | 30 | def infer(args): 31 | system_info.print_system_info() 32 | 33 | # Prepare model 34 | in_size = feature.get_input_dim( 35 | args.frame_size, 36 | args.context_size, 37 | args.input_transform) 38 | 39 | if args.model_type == "BLSTM": 40 | model = BLSTMDiarization( 41 | in_size=in_size, 42 | n_speakers=args.num_speakers, 43 | hidden_size=args.hidden_size, 44 | n_layers=args.num_lstm_layers, 45 | embedding_layers=args.embedding_layers, 46 | embedding_size=args.embedding_size 47 | ) 48 | elif args.model_type == 'Transformer': 49 | if args.use_attractor: 50 | model = TransformerEDADiarization( 51 | in_size, 52 | n_units=args.hidden_size, 53 | n_heads=args.transformer_encoder_n_heads, 54 | n_layers=args.transformer_encoder_n_layers, 55 | dropout=0, 56 | attractor_encoder_dropout=args.attractor_encoder_dropout, 57 | attractor_decoder_dropout=args.attractor_decoder_dropout, 58 | ) 59 | else: 60 | model = TransformerDiarization( 61 | args.num_speakers, 62 | in_size, 63 | n_units=args.hidden_size, 64 | n_heads=args.transformer_encoder_n_heads, 65 | n_layers=args.transformer_encoder_n_layers, 66 | dropout=0 67 | ) 68 | else: 69 | raise ValueError('Unknown model type.') 70 | 71 | serializers.load_npz(args.model_file, model) 72 | 73 | if args.gpu >= 0: 74 | gpuid = use_single_gpu() 75 | model.to_gpu() 76 | 77 | kaldi_obj = kaldi_data.KaldiData(args.data_dir) 78 | for recid in kaldi_obj.wavs: 79 | data, rate = kaldi_obj.load_wav(recid) 80 | Y = feature.stft(data, args.frame_size, args.frame_shift) 81 | Y = feature.transform(Y, transform_type=args.input_transform) 82 | Y = feature.splice(Y, context_size=args.context_size) 83 | Y = Y[::args.subsampling] 84 | out_chunks = [] 85 | with chainer.no_backprop_mode(), chainer.using_config('train', False): 86 | hs = None 87 | for start, end in _gen_chunk_indices(len(Y), args.chunk_size): 88 | Y_chunked = Variable(Y[start:end]) 89 | if args.gpu >= 0: 90 | Y_chunked.to_gpu(gpuid) 91 | hs, ys = model.estimate_sequential( 92 | hs, [Y_chunked], 93 | n_speakers=args.num_speakers, 94 | th=args.attractor_threshold, 95 | shuffle=args.shuffle 96 | ) 97 | if args.gpu >= 0: 98 | ys[0].to_cpu() 99 | out_chunks.append(ys[0].data) 100 | if args.save_attention_weight == 1: 101 | att_fname = f"{recid}_{start}_{end}.att.npy" 102 | att_path = os.path.join(args.out_dir, att_fname) 103 | model.save_attention_weight(att_path) 104 | outfname = recid + '.h5' 105 | outpath = os.path.join(args.out_dir, outfname) 106 | if hasattr(model, 'label_delay'): 107 | outdata = shift(np.vstack(out_chunks), (-model.label_delay, 0)) 108 | else: 109 | max_n_speakers = max([o.shape[1] for o in out_chunks]) 110 | out_chunks = [np.insert(o, o.shape[1], np.zeros((max_n_speakers - o.shape[1], o.shape[0])), axis=1) for o in out_chunks] 111 | outdata = np.vstack(out_chunks) 112 | with 
h5py.File(outpath, 'w') as wf: 113 | wf.create_dataset('T_hat', data=outdata) 114 | -------------------------------------------------------------------------------- /eend/chainer_backend/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | import os 7 | import numpy as np 8 | import chainer 9 | from chainer import optimizers 10 | from chainer import serializers 11 | from chainer import iterators 12 | from chainer import training 13 | from chainer.training import extensions 14 | from eend.chainer_backend.models import BLSTMDiarization 15 | from eend.chainer_backend.models import TransformerDiarization, TransformerEDADiarization 16 | from eend.chainer_backend.transformer import NoamScheduler 17 | from eend.chainer_backend.updater import GradientAccumulationUpdater 18 | from eend.chainer_backend.diarization_dataset import KaldiDiarizationDataset 19 | from eend.chainer_backend.utils import use_single_gpu 20 | 21 | 22 | @chainer.dataset.converter() 23 | def _convert(batch, device): 24 | def to_device_batch(batch): 25 | if device is None: 26 | return batch 27 | batch_dst = [device.send(x) for x in batch] 28 | return batch_dst 29 | return {'xs': to_device_batch([x for x, _ in batch]), 30 | 'ts': to_device_batch([t for _, t in batch])} 31 | 32 | 33 | def train(args): 34 | """ Training model with chainer backend. 35 | This function is called from eend/bin/train.py with 36 | parsed command-line arguments. 37 | """ 38 | np.random.seed(args.seed) 39 | os.environ['CHAINER_SEED'] = str(args.seed) 40 | chainer.global_config.cudnn_deterministic = True 41 | 42 | train_set = KaldiDiarizationDataset( 43 | args.train_data_dir, 44 | chunk_size=args.num_frames, 45 | context_size=args.context_size, 46 | input_transform=args.input_transform, 47 | frame_size=args.frame_size, 48 | frame_shift=args.frame_shift, 49 | subsampling=args.subsampling, 50 | rate=args.sampling_rate, 51 | use_last_samples=True, 52 | label_delay=args.label_delay, 53 | n_speakers=args.num_speakers, 54 | shuffle=args.shuffle, 55 | ) 56 | dev_set = KaldiDiarizationDataset( 57 | args.valid_data_dir, 58 | chunk_size=args.num_frames, 59 | context_size=args.context_size, 60 | input_transform=args.input_transform, 61 | frame_size=args.frame_size, 62 | frame_shift=args.frame_shift, 63 | subsampling=args.subsampling, 64 | rate=args.sampling_rate, 65 | use_last_samples=True, 66 | label_delay=args.label_delay, 67 | n_speakers=args.num_speakers, 68 | shuffle=args.shuffle, 69 | ) 70 | 71 | # Prepare model 72 | Y, T = train_set.get_example(0) 73 | 74 | if args.model_type == 'BLSTM': 75 | assert args.num_speakers is not None 76 | model = BLSTMDiarization( 77 | in_size=Y.shape[1], 78 | n_speakers=args.num_speakers, 79 | hidden_size=args.hidden_size, 80 | n_layers=args.num_lstm_layers, 81 | embedding_layers=args.embedding_layers, 82 | embedding_size=args.embedding_size, 83 | dc_loss_ratio=args.dc_loss_ratio, 84 | ) 85 | elif args.model_type == 'Transformer': 86 | if args.use_attractor: 87 | model = TransformerEDADiarization( 88 | Y.shape[1], 89 | n_units=args.hidden_size, 90 | n_heads=args.transformer_encoder_n_heads, 91 | n_layers=args.transformer_encoder_n_layers, 92 | dropout=args.transformer_encoder_dropout, 93 | attractor_loss_ratio=args.attractor_loss_ratio, 94 | attractor_encoder_dropout=args.attractor_encoder_dropout, 95 | attractor_decoder_dropout=args.attractor_decoder_dropout, 
96 | ) 97 | else: 98 | assert args.num_speakers is not None 99 | model = TransformerDiarization( 100 | args.num_speakers, 101 | Y.shape[1], 102 | n_units=args.hidden_size, 103 | n_heads=args.transformer_encoder_n_heads, 104 | n_layers=args.transformer_encoder_n_layers, 105 | dropout=args.transformer_encoder_dropout 106 | ) 107 | else: 108 | raise ValueError('Possible model_type are "Transformer" and "BLSTM"') 109 | 110 | if args.gpu >= 0: 111 | gpuid = use_single_gpu() 112 | print('GPU device {} is used'.format(gpuid)) 113 | model.to_gpu() 114 | else: 115 | gpuid = -1 116 | print('Prepared model') 117 | 118 | # Setup optimizer 119 | if args.optimizer == 'adam': 120 | optimizer = optimizers.Adam(alpha=args.lr) 121 | elif args.optimizer == 'sgd': 122 | optimizer = optimizers.SGD(lr=args.lr) 123 | elif args.optimizer == 'noam': 124 | optimizer = optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9) 125 | else: 126 | raise ValueError(args.optimizer) 127 | 128 | optimizer.setup(model) 129 | if args.gradclip > 0: 130 | optimizer.add_hook( 131 | chainer.optimizer_hooks.GradientClipping(args.gradclip)) 132 | 133 | # Init/Resume 134 | if args.initmodel: 135 | print('Load model from', args.initmodel) 136 | serializers.load_npz(args.initmodel, model) 137 | 138 | train_iter = iterators.MultiprocessIterator( 139 | train_set, 140 | batch_size=args.batchsize, 141 | repeat=True, shuffle=True, 142 | # shared_mem=64000000, 143 | shared_mem=None, 144 | n_processes=4, n_prefetch=2) 145 | 146 | dev_iter = iterators.MultiprocessIterator( 147 | dev_set, 148 | batch_size=args.batchsize, 149 | repeat=False, shuffle=False, 150 | # shared_mem=64000000, 151 | shared_mem=None, 152 | n_processes=4, n_prefetch=2) 153 | 154 | if args.gradient_accumulation_steps > 1: 155 | updater = GradientAccumulationUpdater( 156 | train_iter, optimizer, converter=_convert, device=gpuid) 157 | else: 158 | updater = training.StandardUpdater( 159 | train_iter, optimizer, converter=_convert, device=gpuid) 160 | 161 | trainer = training.Trainer( 162 | updater, 163 | (args.max_epochs, 'epoch'), 164 | out=os.path.join(args.model_save_dir)) 165 | 166 | evaluator = extensions.Evaluator( 167 | dev_iter, model, converter=_convert, device=gpuid) 168 | trainer.extend(evaluator) 169 | 170 | if args.optimizer == 'noam': 171 | trainer.extend( 172 | NoamScheduler(args.hidden_size, 173 | warmup_steps=args.noam_warmup_steps, 174 | scale=args.noam_scale), 175 | trigger=(1, 'iteration')) 176 | 177 | if args.resume: 178 | chainer.serializers.load_npz(args.resume, trainer) 179 | 180 | # MICRO AVERAGE 181 | metrics = [ 182 | ('diarization_error', 'speaker_scored', 'DER'), 183 | ('speech_miss', 'speech_scored', 'SAD_MR'), 184 | ('speech_falarm', 'speech_scored', 'SAD_FR'), 185 | ('speaker_miss', 'speaker_scored', 'MI'), 186 | ('speaker_falarm', 'speaker_scored', 'FA'), 187 | ('speaker_error', 'speaker_scored', 'CF'), 188 | ('correct', 'frames', 'accuracy')] 189 | for num, den, name in metrics: 190 | trainer.extend(extensions.MicroAverage( 191 | 'main/{}'.format(num), 192 | 'main/{}'.format(den), 193 | 'main/{}'.format(name))) 194 | trainer.extend(extensions.MicroAverage( 195 | 'validation/main/{}'.format(num), 196 | 'validation/main/{}'.format(den), 197 | 'validation/main/{}'.format(name))) 198 | 199 | trainer.extend(extensions.LogReport(log_name='log_iter', 200 | trigger=(1000, 'iteration'))) 201 | 202 | trainer.extend(extensions.LogReport()) 203 | trainer.extend(extensions.PrintReport( 204 | ['epoch', 'main/loss', 'validation/main/loss', 205 | 
'main/diarization_error_rate', 206 | 'validation/main/diarization_error_rate', 207 | 'elapsed_time'])) 208 | trainer.extend(extensions.PlotReport( 209 | ['main/loss', 'validation/main/loss'], 210 | x_key='epoch', 211 | file_name='loss.png')) 212 | trainer.extend(extensions.PlotReport( 213 | ['main/diarization_error_rate', 214 | 'validation/main/diarization_error_rate'], 215 | x_key='epoch', 216 | file_name='DER.png')) 217 | trainer.extend(extensions.ProgressBar(update_interval=100)) 218 | trainer.extend(extensions.snapshot( 219 | filename='snapshot_epoch-{.updater.epoch}')) 220 | 221 | trainer.extend(extensions.dump_graph('main/loss', out_name="cg.dot")) 222 | 223 | trainer.run() 224 | print('Finished!') 225 | -------------------------------------------------------------------------------- /eend/chainer_backend/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 2 | # Licensed under the MIT license. 3 | 4 | import numpy as np 5 | from chainer.training import extension 6 | from chainer import Chain 7 | import chainer.functions as F 8 | import chainer.links as L 9 | 10 | 11 | class NoamScheduler(extension.Extension): 12 | """ learning rate scheduler used in the transformer 13 | See https://arxiv.org/pdf/1706.03762.pdf 14 | lrate = d_model**(-0.5) * \ 15 | min(step_num**(-0.5), step_num*warmup_steps**(-1.5)) 16 | Scaling factor is implemented as in 17 | http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer 18 | """ 19 | 20 | def __init__(self, d_model, warmup_steps, scale=1.0): 21 | self.d_model = d_model 22 | self.warmup_steps = warmup_steps 23 | self.scale = scale 24 | self.last_value = None 25 | self.t = 0 26 | 27 | def initialize(self, trainer): 28 | optimizer = trainer.updater.get_optimizer('main') 29 | if self.last_value: 30 | # resume 31 | setattr(optimizer, 'alpha', self.last_value) 32 | else: 33 | # the initiallearning rate is set as step = 1, 34 | init_value = self.scale * self.d_model ** (-0.5) * \ 35 | self.warmup_steps ** (-1.5) 36 | setattr(optimizer, 'alpha', init_value) 37 | self.last_value = init_value 38 | 39 | def __call__(self, trainer): 40 | self.t += 1 41 | optimizer = trainer.updater.get_optimizer('main') 42 | value = self.scale * self.d_model ** (-0.5) * \ 43 | min(self.t ** (-0.5), self.t * self.warmup_steps ** (-1.5)) 44 | setattr(optimizer, 'alpha', value) 45 | self.last_value = value 46 | 47 | def serialize(self, serializer): 48 | self.t = serializer('t', self.t) 49 | self.last_value = serializer('last_value', self.last_value) 50 | 51 | 52 | class MultiHeadSelfAttention(Chain): 53 | 54 | """ Multi head "self" attention layer 55 | """ 56 | 57 | def __init__(self, n_units, h=8, dropout=0.1): 58 | super(MultiHeadSelfAttention, self).__init__() 59 | with self.init_scope(): 60 | self.linearQ = L.Linear(n_units, n_units) 61 | self.linearK = L.Linear(n_units, n_units) 62 | self.linearV = L.Linear(n_units, n_units) 63 | self.linearO = L.Linear(n_units, n_units) 64 | self.d_k = n_units // h 65 | self.h = h 66 | self.dropout = dropout 67 | # attention for plot 68 | self.att = None 69 | 70 | def __call__(self, x, batch_size): 71 | # x: (BT, F) 72 | # TODO: if chainer >= 5.0, use linear functions with 'n_batch_axes' 73 | # and x be (B, T, F), then remove batch_size. 
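# Shape walkthrough (added comment): x arrives flattened as (B*T, n_units);
# each linear projection keeps that shape, and the reshape below splits the
# feature dimension into h heads of size d_k, giving (B, T, h, d_k).  The
# swapaxes/transpose calls then yield (B, h, T, d_k) queries and values and
# (B, h, d_k, T) keys, so the scaled dot product produces (B, h, T, T) scores.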
74 | q = self.linearQ(x).reshape(batch_size, -1, self.h, self.d_k) 75 | k = self.linearK(x).reshape(batch_size, -1, self.h, self.d_k) 76 | v = self.linearV(x).reshape(batch_size, -1, self.h, self.d_k) 77 | scores = F.matmul( 78 | F.swapaxes(q, 1, 2), k.transpose(0, 2, 3, 1)) / np.sqrt(self.d_k) 79 | # scores: (B, h, T, T) 80 | self.att = F.softmax(scores, axis=3) 81 | p_att = F.dropout(self.att, self.dropout) 82 | x = F.matmul(p_att, F.swapaxes(v, 1, 2)) 83 | x = F.swapaxes(x, 1, 2).reshape(-1, self.h * self.d_k) 84 | return self.linearO(x) 85 | 86 | 87 | class PositionwiseFeedForward(Chain): 88 | 89 | """ Positionwise feed-forward layer 90 | """ 91 | 92 | def __init__(self, n_units, d_units, dropout): 93 | super(PositionwiseFeedForward, self).__init__() 94 | with self.init_scope(): 95 | self.linear1 = L.Linear(n_units, d_units) 96 | self.linear2 = L.Linear(d_units, n_units) 97 | self.dropout = dropout 98 | 99 | def __call__(self, x): 100 | return self.linear2(F.dropout(F.relu(self.linear1(x)), self.dropout)) 101 | 102 | 103 | class PositionalEncoding(Chain): 104 | 105 | """ Positional encoding function 106 | """ 107 | 108 | def __init__(self, n_units, dropout, max_len): 109 | super(PositionalEncoding, self).__init__() 110 | self.dropout = dropout 111 | positions = np.arange(0, max_len, dtype='f')[:, None] 112 | dens = np.exp( 113 | np.arange(0, n_units, 2, dtype='f') * -(np.log(10000.) / n_units)) 114 | self.enc = np.zeros((max_len, n_units), dtype='f') 115 | self.enc[:, ::2] = np.sin(positions * dens) 116 | self.enc[:, 1::2] = np.cos(positions * dens) 117 | self.scale = np.sqrt(n_units) 118 | 119 | def __call__(self, x): 120 | x = x * self.scale + self.xp.array(self.enc[:, :x.shape[1]]) 121 | return F.dropout(x, self.dropout) 122 | 123 | 124 | class TransformerEncoder(Chain): 125 | def __init__(self, idim, n_layers, n_units, 126 | e_units=2048, h=8, dropout=0.1): 127 | super(TransformerEncoder, self).__init__() 128 | with self.init_scope(): 129 | self.linear_in = L.Linear(idim, n_units) 130 | self.lnorm_in = L.LayerNormalization(n_units) 131 | self.pos_enc = PositionalEncoding(n_units, dropout, 5000) 132 | self.n_layers = n_layers 133 | self.dropout = dropout 134 | for i in range(n_layers): 135 | setattr(self, '{}{:d}'.format("lnorm1_", i), 136 | L.LayerNormalization(n_units)) 137 | setattr(self, '{}{:d}'.format("self_att_", i), 138 | MultiHeadSelfAttention(n_units, h)) 139 | setattr(self, '{}{:d}'.format("lnorm2_", i), 140 | L.LayerNormalization(n_units)) 141 | setattr(self, '{}{:d}'.format("ff_", i), 142 | PositionwiseFeedForward(n_units, e_units, dropout)) 143 | self.lnorm_out = L.LayerNormalization(n_units) 144 | 145 | def __call__(self, x): 146 | # x: (B, T, F) ... 
batch, time, (mel)freq 147 | BT_size = x.shape[0] * x.shape[1] 148 | # e: (BT, F) 149 | e = self.linear_in(x.reshape(BT_size, -1)) 150 | # Encoder stack 151 | for i in range(self.n_layers): 152 | # layer normalization 153 | e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e) 154 | # self-attention 155 | s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0]) 156 | # residual 157 | e = e + F.dropout(s, self.dropout) 158 | # layer normalization 159 | e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e) 160 | # positionwise feed-forward 161 | s = getattr(self, '{}{:d}'.format("ff_", i))(e) 162 | # residual 163 | e = e + F.dropout(s, self.dropout) 164 | # final layer normalization 165 | # output: (BT, F) 166 | return self.lnorm_out(e) 167 | -------------------------------------------------------------------------------- /eend/chainer_backend/updater.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 2 | # Licensed under the MIT license. 3 | 4 | from chainer import training 5 | 6 | 7 | class GradientAccumulationUpdater(training.StandardUpdater): 8 | """ The optimizer is run once every `n_accs` minibatches. 9 | The gradients over `n_accs` minibatches are accumulated. 10 | It virtually enlarges minibatch size. 11 | """ 12 | def __init__(self, iterator, optimizer, converter, device, n_accs=1): 13 | super(GradientAccumulationUpdater, self).__init__( 14 | iterator, optimizer, converter=converter, device=device) 15 | self.step = 0 16 | self.n_accs = n_accs 17 | 18 | def update_core(self): 19 | self.step += 1 20 | iterator = self.get_iterator('main') 21 | optimizer = self.get_optimizer('main') 22 | batch = iterator.next() 23 | # converter outputs 'dict', 'tuple', or array 24 | x = self.converter(batch, self.device) 25 | if self.step == 1: 26 | optimizer.target.cleargrads() 27 | # Compute the loss at this time step and accumulate it 28 | if isinstance(x, tuple): 29 | loss = optimizer.target(*x) / self.n_accs 30 | elif isinstance(x, dict): 31 | loss = optimizer.target(**x) / self.n_accs 32 | else: 33 | loss = optimizer.target(x) / self.n_accs 34 | loss.backward() 35 | loss.unchain_backward() 36 | # Update parameters once every n_accs 37 | if self.step % self.n_accs == 0: 38 | optimizer.update() 39 | optimizer.target.cleargrads() 40 | -------------------------------------------------------------------------------- /eend/chainer_backend/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | 6 | import os 7 | import chainer 8 | import subprocess 9 | import cupy 10 | 11 | 12 | def get_free_gpus(): 13 | """ Get IDs of free GPUs using `nvidia-smi`. 14 | 15 | Returns: 16 | sorted list of GPUs which have no running process. 
17 | """ 18 | p = subprocess.Popen( 19 | ["nvidia-smi", 20 | "--query-gpu=index,gpu_bus_id", 21 | "--format=csv,noheader"], 22 | stdout=subprocess.PIPE) 23 | stdout, stderr = p.communicate() 24 | gpus = {} 25 | for line in stdout.decode('utf-8').strip().split(os.linesep): 26 | if not line: 27 | continue 28 | idx, busid = line.strip().split(',') 29 | gpus[busid] = int(idx) 30 | p = subprocess.Popen( 31 | ["nvidia-smi", 32 | "--query-compute-apps=pid,gpu_bus_id", 33 | "--format=csv,noheader"], 34 | stdout=subprocess.PIPE) 35 | stdout, stderr = p.communicate() 36 | for line in stdout.decode('utf-8').strip().split(os.linesep): 37 | if not line: 38 | continue 39 | pid, busid = line.strip().split(',') 40 | del gpus[busid] 41 | return sorted([gpus[busid] for busid in gpus]) 42 | 43 | 44 | def use_single_gpu(): 45 | """ Use single GPU device. 46 | 47 | If CUDA_VISIBLE_DEVICES is set, select a device from the variable. 48 | Otherwise, get a free GPU device and use it. 49 | 50 | Returns: 51 | assigned GPU id. 52 | """ 53 | cvd = os.environ.get('CUDA_VISIBLE_DEVICES') 54 | if cvd is None: 55 | # no GPUs are researved 56 | cvd = get_free_gpus()[0] 57 | elif ',' in cvd: 58 | # multiple GPUs are researved 59 | cvd = int(cvd.split(',')[0]) 60 | else: 61 | # single GPU is reserved 62 | cvd = int(cvd) 63 | # Use the GPU immediately 64 | chainer.cuda.get_device_from_id(cvd).use() 65 | cupy.empty((1,), dtype=cupy.float32) 66 | return cvd 67 | -------------------------------------------------------------------------------- /eend/feature.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 2 | # Licensed under the MIT license. 3 | # 4 | # This module is for computing audio features 5 | 6 | import numpy as np 7 | import librosa 8 | 9 | 10 | def get_input_dim( 11 | frame_size, 12 | context_size, 13 | transform_type, 14 | ): 15 | if transform_type.startswith('logmel23'): 16 | frame_size = 23 17 | elif transform_type.startswith('logmel'): 18 | frame_size = 40 19 | else: 20 | fft_size = 1 << (frame_size - 1).bit_length() 21 | frame_size = int(fft_size / 2) + 1 22 | input_dim = (2 * context_size + 1) * frame_size 23 | return input_dim 24 | 25 | 26 | def transform( 27 | Y, 28 | transform_type=None, 29 | dtype=np.float32): 30 | """ Transform STFT feature 31 | 32 | Args: 33 | Y: STFT 34 | (n_frames, n_bins)-shaped np.complex array 35 | transform_type: 36 | None, "log" 37 | dtype: output data type 38 | np.float32 is expected 39 | Returns: 40 | Y (numpy.array): transformed feature 41 | """ 42 | Y = np.abs(Y) 43 | if not transform_type: 44 | pass 45 | elif transform_type == 'log': 46 | Y = np.log(np.maximum(Y, 1e-10)) 47 | elif transform_type == 'logmel': 48 | n_fft = 2 * (Y.shape[1] - 1) 49 | sr = 16000 50 | n_mels = 40 51 | mel_basis = librosa.filters.mel(sr, n_fft, n_mels) 52 | Y = np.dot(Y ** 2, mel_basis.T) 53 | Y = np.log10(np.maximum(Y, 1e-10)) 54 | elif transform_type == 'logmel23': 55 | n_fft = 2 * (Y.shape[1] - 1) 56 | sr = 8000 57 | n_mels = 23 58 | mel_basis = librosa.filters.mel(sr, n_fft, n_mels) 59 | Y = np.dot(Y ** 2, mel_basis.T) 60 | Y = np.log10(np.maximum(Y, 1e-10)) 61 | elif transform_type == 'logmel23_mn': 62 | n_fft = 2 * (Y.shape[1] - 1) 63 | sr = 8000 64 | n_mels = 23 65 | mel_basis = librosa.filters.mel(sr, n_fft, n_mels) 66 | Y = np.dot(Y ** 2, mel_basis.T) 67 | Y = np.log10(np.maximum(Y, 1e-10)) 68 | mean = np.mean(Y, axis=0) 69 | Y = Y - mean 70 | elif transform_type == 'logmel23_swn': 71 | n_fft = 2 * 
(Y.shape[1] - 1) 72 | sr = 8000 73 | n_mels = 23 74 | mel_basis = librosa.filters.mel(sr, n_fft, n_mels) 75 | Y = np.dot(Y ** 2, mel_basis.T) 76 | Y = np.log10(np.maximum(Y, 1e-10)) 77 | # b = np.ones(300)/300 78 | # mean = scipy.signal.convolve2d(Y, b[:, None], mode='same') 79 | 80 | # simple 2-means based threshoding for mean calculation 81 | powers = np.sum(Y, axis=1) 82 | th = (np.max(powers) + np.min(powers)) / 2.0 83 | for i in range(10): 84 | th = (np.mean(powers[powers >= th]) + np.mean(powers[powers < th])) / 2 85 | mean = np.mean(Y[powers > th, :], axis=0) 86 | Y = Y - mean 87 | elif transform_type == 'logmel23_mvn': 88 | n_fft = 2 * (Y.shape[1] - 1) 89 | sr = 8000 90 | n_mels = 23 91 | mel_basis = librosa.filters.mel(sr, n_fft, n_mels) 92 | Y = np.dot(Y ** 2, mel_basis.T) 93 | Y = np.log10(np.maximum(Y, 1e-10)) 94 | mean = np.mean(Y, axis=0) 95 | Y = Y - mean 96 | std = np.maximum(np.std(Y, axis=0), 1e-10) 97 | Y = Y / std 98 | else: 99 | raise ValueError('Unknown transform_type: %s' % transform_type) 100 | return Y.astype(dtype) 101 | 102 | 103 | def subsample(Y, T, subsampling=1): 104 | """ Frame subsampling 105 | """ 106 | Y_ss = Y[::subsampling] 107 | T_ss = T[::subsampling] 108 | return Y_ss, T_ss 109 | 110 | 111 | def splice(Y, context_size=0): 112 | """ Frame splicing 113 | 114 | Args: 115 | Y: feature 116 | (n_frames, n_featdim)-shaped numpy array 117 | context_size: 118 | number of frames concatenated on left-side 119 | if context_size = 5, 11 frames are concatenated. 120 | 121 | Returns: 122 | Y_spliced: spliced feature 123 | (n_frames, n_featdim * (2 * context_size + 1))-shaped 124 | """ 125 | Y_pad = np.pad( 126 | Y, 127 | [(context_size, context_size), (0, 0)], 128 | 'constant') 129 | Y_spliced = np.lib.stride_tricks.as_strided( 130 | np.ascontiguousarray(Y_pad), 131 | (Y.shape[0], Y.shape[1] * (2 * context_size + 1)), 132 | (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False) 133 | return Y_spliced 134 | 135 | 136 | def stft( 137 | data, 138 | frame_size=1024, 139 | frame_shift=256): 140 | """ Compute STFT features 141 | 142 | Args: 143 | data: audio signal 144 | (n_samples,)-shaped np.float32 array 145 | frame_size: number of samples in a frame (must be a power of two) 146 | frame_shift: number of samples between frames 147 | 148 | Returns: 149 | stft: STFT frames 150 | (n_frames, n_bins)-shaped np.complex64 array 151 | """ 152 | # round up to nearest power of 2 153 | fft_size = 1 << (frame_size - 1).bit_length() 154 | # HACK: The last frame is ommited 155 | # as librosa.stft produces such an excessive frame 156 | if len(data) % frame_shift == 0: 157 | return librosa.stft(data, n_fft=fft_size, win_length=frame_size, 158 | hop_length=frame_shift).T[:-1] 159 | else: 160 | return librosa.stft(data, n_fft=fft_size, win_length=frame_size, 161 | hop_length=frame_shift).T 162 | 163 | 164 | def _count_frames(data_len, size, shift): 165 | # HACK: Assuming librosa.stft(..., center=True) 166 | n_frames = 1 + int(data_len / shift) 167 | if data_len % shift == 0: 168 | n_frames = n_frames - 1 169 | return n_frames 170 | 171 | 172 | def get_frame_labels( 173 | kaldi_obj, 174 | rec, 175 | start=0, 176 | end=None, 177 | frame_size=1024, 178 | frame_shift=256, 179 | n_speakers=None): 180 | """ Get frame-aligned labels of given recording 181 | Args: 182 | kaldi_obj (KaldiData) 183 | rec (str): recording id 184 | start (int): start frame index 185 | end (int): end frame index 186 | None means the last frame of recording 187 | frame_size (int): number of frames in a frame 188 | 
frame_shift (int): number of shift samples 189 | n_speakers (int): number of speakers 190 | if None, the value is given from data 191 | Returns: 192 | T: label 193 | (n_frames, n_speakers)-shaped np.int32 array 194 | """ 195 | filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec] 196 | speakers = np.unique( 197 | [kaldi_obj.utt2spk[seg['utt']] for seg 198 | in filtered_segments]).tolist() 199 | if n_speakers is None: 200 | n_speakers = len(speakers) 201 | es = end * frame_shift if end is not None else None 202 | data, rate = kaldi_obj.load_wav(rec, start * frame_shift, es) 203 | n_frames = _count_frames(len(data), frame_size, frame_shift) 204 | T = np.zeros((n_frames, n_speakers), dtype=np.int32) 205 | if end is None: 206 | end = n_frames 207 | 208 | for seg in filtered_segments: 209 | speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']]) 210 | start_frame = np.rint( 211 | seg['st'] * rate / frame_shift).astype(int) 212 | end_frame = np.rint( 213 | seg['et'] * rate / frame_shift).astype(int) 214 | rel_start = rel_end = None 215 | if start <= start_frame and start_frame < end: 216 | rel_start = start_frame - start 217 | if start < end_frame and end_frame <= end: 218 | rel_end = end_frame - start 219 | if rel_start is not None or rel_end is not None: 220 | T[rel_start:rel_end, speaker_index] = 1 221 | return T 222 | 223 | 224 | def get_labeledSTFT( 225 | kaldi_obj, 226 | rec, start, end, frame_size, frame_shift, 227 | n_speakers=None, 228 | use_speaker_id=False): 229 | """ Extracts STFT and corresponding labels 230 | 231 | Extracts STFT and corresponding diarization labels for 232 | given recording id and start/end times 233 | 234 | Args: 235 | kaldi_obj (KaldiData) 236 | rec (str): recording id 237 | start (int): start frame index 238 | end (int): end frame index 239 | frame_size (int): number of samples in a frame 240 | frame_shift (int): number of shift samples 241 | n_speakers (int): number of speakers 242 | if None, the value is given from data 243 | Returns: 244 | Y: STFT 245 | (n_frames, n_bins)-shaped np.complex64 array, 246 | T: label 247 | (n_frmaes, n_speakers)-shaped np.int32 array. 
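        Note (added): start/end are frame indices; they are converted to
        sample positions via frame_shift when loading audio.  In the callhome
        recipes frame_size=200 and frame_shift=80 samples, i.e. 25 ms windows
        with a 10 ms shift at 8 kHz.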
248 | """ 249 | data, rate = kaldi_obj.load_wav( 250 | rec, start * frame_shift, end * frame_shift) 251 | Y = stft(data, frame_size, frame_shift) 252 | filtered_segments = kaldi_obj.segments[rec] 253 | # filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec] 254 | speakers = np.unique( 255 | [kaldi_obj.utt2spk[seg['utt']] for seg 256 | in filtered_segments]).tolist() 257 | if n_speakers is None: 258 | n_speakers = len(speakers) 259 | T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32) 260 | 261 | if use_speaker_id: 262 | all_speakers = sorted(kaldi_obj.spk2utt.keys()) 263 | S = np.zeros((Y.shape[0], len(all_speakers)), dtype=np.int32) 264 | 265 | for seg in filtered_segments: 266 | speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']]) 267 | if use_speaker_id: 268 | all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg['utt']]) 269 | start_frame = np.rint( 270 | seg['st'] * rate / frame_shift).astype(int) 271 | end_frame = np.rint( 272 | seg['et'] * rate / frame_shift).astype(int) 273 | rel_start = rel_end = None 274 | if start <= start_frame and start_frame < end: 275 | rel_start = start_frame - start 276 | if start < end_frame and end_frame <= end: 277 | rel_end = end_frame - start 278 | if rel_start is not None or rel_end is not None: 279 | T[rel_start:rel_end, speaker_index] = 1 280 | if use_speaker_id: 281 | S[rel_start:rel_end, all_speaker_index] = 1 282 | 283 | if use_speaker_id: 284 | return Y, T, S 285 | else: 286 | return Y, T 287 | -------------------------------------------------------------------------------- /eend/kaldi_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 2 | # Licensed under the MIT license. 3 | # 4 | # This library provides utilities for kaldi-style data directory. 5 | 6 | 7 | from __future__ import print_function 8 | import os 9 | import sys 10 | import numpy as np 11 | import subprocess 12 | import soundfile as sf 13 | import io 14 | from functools import lru_cache 15 | 16 | 17 | def load_segments(segments_file): 18 | """ load segments file as array """ 19 | if not os.path.exists(segments_file): 20 | return None 21 | return np.loadtxt( 22 | segments_file, 23 | dtype=[('utt', 'object'), 24 | ('rec', 'object'), 25 | ('st', 'f'), 26 | ('et', 'f')], 27 | ndmin=1) 28 | 29 | 30 | def load_segments_hash(segments_file): 31 | ret = {} 32 | if not os.path.exists(segments_file): 33 | return None 34 | for line in open(segments_file): 35 | utt, rec, st, et = line.strip().split() 36 | ret[utt] = (rec, float(st), float(et)) 37 | return ret 38 | 39 | 40 | def load_segments_rechash(segments_file): 41 | ret = {} 42 | if not os.path.exists(segments_file): 43 | return None 44 | for line in open(segments_file): 45 | utt, rec, st, et = line.strip().split() 46 | if rec not in ret: 47 | ret[rec] = [] 48 | ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)}) 49 | return ret 50 | 51 | 52 | def load_wav_scp(wav_scp_file): 53 | """ return dictionary { rec: wav_rxfilename } """ 54 | lines = [line.strip().split(None, 1) for line in open(wav_scp_file)] 55 | return {x[0]: x[1] for x in lines} 56 | 57 | 58 | @lru_cache(maxsize=1) 59 | def load_wav(wav_rxfilename, start=0, end=None): 60 | """ This function reads audio file and return data in numpy.float32 array. 61 | "lru_cache" holds recently loaded audio so that can be called 62 | many times on the same audio file. 
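    Note (added): wav_rxfilename may be a plain wav path, '-' for stdin, or a
    Kaldi-style piped command ending in '|' (e.g. the 'sph2pipe -f wav -p ... |'
    entries written by local/make_callhome.sh).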
63 | OPTIMIZE: controls lru_cache size for random access, 64 | considering memory size 65 | """ 66 | if wav_rxfilename.endswith('|'): 67 | # input piped command 68 | p = subprocess.Popen(wav_rxfilename[:-1], shell=True, 69 | stdout=subprocess.PIPE) 70 | data, samplerate = sf.read(io.BytesIO(p.stdout.read()), 71 | dtype='float32') 72 | # cannot seek 73 | data = data[start:end] 74 | elif wav_rxfilename == '-': 75 | # stdin 76 | data, samplerate = sf.read(sys.stdin, dtype='float32') 77 | # cannot seek 78 | data = data[start:end] 79 | else: 80 | # normal wav file 81 | data, samplerate = sf.read(wav_rxfilename, start=start, stop=end) 82 | return data, samplerate 83 | 84 | 85 | def load_utt2spk(utt2spk_file): 86 | """ returns dictionary { uttid: spkid } """ 87 | lines = [line.strip().split(None, 1) for line in open(utt2spk_file)] 88 | return {x[0]: x[1] for x in lines} 89 | 90 | 91 | def load_spk2utt(spk2utt_file): 92 | """ returns dictionary { spkid: list of uttids } """ 93 | if not os.path.exists(spk2utt_file): 94 | return None 95 | lines = [line.strip().split() for line in open(spk2utt_file)] 96 | return {x[0]: x[1:] for x in lines} 97 | 98 | 99 | def load_reco2dur(reco2dur_file): 100 | """ returns dictionary { recid: duration } """ 101 | if not os.path.exists(reco2dur_file): 102 | return None 103 | lines = [line.strip().split(None, 1) for line in open(reco2dur_file)] 104 | return {x[0]: float(x[1]) for x in lines} 105 | 106 | 107 | def process_wav(wav_rxfilename, process): 108 | """ This function returns preprocessed wav_rxfilename 109 | Args: 110 | wav_rxfilename: input 111 | process: command which can be connected via pipe, 112 | use stdin and stdout 113 | Returns: 114 | wav_rxfilename: output piped command 115 | """ 116 | if wav_rxfilename.endswith('|'): 117 | # input piped command 118 | return wav_rxfilename + process + "|" 119 | else: 120 | # stdin "-" or normal file 121 | return "cat {} | {} |".format(wav_rxfilename, process) 122 | 123 | 124 | def extract_segments(wavs, segments=None): 125 | """ This function returns generator of segmented audio as 126 | (utterance id, numpy.float32 array) 127 | TODO?: sampling rate is not converted. 
128 | """ 129 | if segments is not None: 130 | # segments should be sorted by rec-id 131 | for seg in segments: 132 | wav = wavs[seg['rec']] 133 | data, samplerate = load_wav(wav) 134 | st_sample = np.rint(seg['st'] * samplerate).astype(int) 135 | et_sample = np.rint(seg['et'] * samplerate).astype(int) 136 | yield seg['utt'], data[st_sample:et_sample] 137 | else: 138 | # segments file not found, 139 | # wav.scp is used as segmented audio list 140 | for rec in wavs: 141 | data, samplerate = load_wav(wavs[rec]) 142 | yield rec, data 143 | 144 | 145 | class KaldiData: 146 | def __init__(self, data_dir): 147 | self.data_dir = data_dir 148 | self.segments = load_segments_rechash( 149 | os.path.join(self.data_dir, 'segments')) 150 | self.utt2spk = load_utt2spk( 151 | os.path.join(self.data_dir, 'utt2spk')) 152 | self.wavs = load_wav_scp( 153 | os.path.join(self.data_dir, 'wav.scp')) 154 | self.reco2dur = load_reco2dur( 155 | os.path.join(self.data_dir, 'reco2dur')) 156 | self.spk2utt = load_spk2utt( 157 | os.path.join(self.data_dir, 'spk2utt')) 158 | 159 | def load_wav(self, recid, start=0, end=None): 160 | data, rate = load_wav( 161 | self.wavs[recid], start, end) 162 | return data, rate 163 | -------------------------------------------------------------------------------- /eend/system_info.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 2 | # Licensed under the MIT license. 3 | 4 | import sys 5 | import chainer 6 | import cupy 7 | import cupy.cuda 8 | from cupy.cuda import cudnn 9 | 10 | 11 | def print_system_info(): 12 | pyver = sys.version.replace('\n', ' ') 13 | print(f"python version: {pyver}") 14 | print(f"chainer version: {chainer.__version__}") 15 | print(f"cupy version: {cupy.__version__}") 16 | print(f"cuda version: {cupy.cuda.runtime.runtimeGetVersion()}") 17 | print(f"cudnn version: {cudnn.getVersion()}") 18 | -------------------------------------------------------------------------------- /egs/callhome/v1/cmd.sh: -------------------------------------------------------------------------------- 1 | # Modify this file according to a job scheduling system in your cluster. 2 | # For more information about cmd.sh see http://kaldi-asr.org/doc/queue.html. 3 | # 4 | # If you use your local machine, use "run.pl". 5 | # export train_cmd="run.pl" 6 | # export infer_cmd="run.pl" 7 | # export simu_cmd="run.pl" 8 | 9 | # If you use Grid Engine, use "queue.pl" 10 | export train_cmd="queue.pl --mem 32G -l 'hostname=c*'" 11 | export infer_cmd="queue.pl --mem 32G -l 'hostname=c*'" 12 | export simu_cmd="queue.pl" 13 | 14 | # If you use SLURM, use "slurm.pl". 
15 | # export train_cmd="slurm.pl" 16 | # export infer_cmd="slurm.pl" 17 | # export simu_cmd="slurm.pl" 18 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/adapt.yaml: -------------------------------------------------------------------------------- 1 | # adapt options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | max_epochs: 100 7 | gradclip: 5 8 | batchsize: 64 9 | hidden_size: 256 10 | num_frames: 500 11 | num_speakers: 2 12 | input_transform: logmel23_mn 13 | optimizer: adam 14 | lr: 1e-5 15 | context_size: 7 16 | subsampling: 10 17 | noam_scale: 1.0 18 | gradient_accumulation_steps: 1 19 | transformer_encoder_n_heads: 4 20 | transformer_encoder_n_layers: 2 21 | transformer_encoder_dropout: 0.1 22 | noam_warmup_steps: 25000 23 | seed: 777 24 | gpu: 0 25 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/blstm/adapt.yaml: -------------------------------------------------------------------------------- 1 | # adapt options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: BLSTM 6 | max_epochs: 10 7 | gradclip: 5 8 | batchsize: 10 9 | hidden_size: 256 10 | num_lstm_layers: 5 11 | num_frames: 4000 12 | num_speakers: 2 13 | input_transform: logmel23_mn 14 | optimizer: adam 15 | lr: 1e-6 16 | context_size: 7 17 | subsampling: 10 18 | seed: 777 19 | gpu: 0 20 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/blstm/infer.yaml: -------------------------------------------------------------------------------- 1 | # inference options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: BLSTM 6 | num_lstm_layers: 5 7 | hidden_size: 256 8 | num_speakers: 2 9 | input_transform: logmel23_mn 10 | context_size: 7 11 | subsampling: 10 12 | chunk_size: 4000 13 | gpu: 0 14 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/blstm/train.yaml: -------------------------------------------------------------------------------- 1 | # training options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: BLSTM 6 | max_epochs: 20 7 | gradclip: 5 8 | batchsize: 10 9 | hidden_size: 256 10 | num_lstm_layers: 5 11 | lr: 0.001 12 | num_frames: 4000 13 | num_speakers: 2 14 | input_transform: logmel23_mn 15 | optimizer: adam 16 | context_size: 7 17 | subsampling: 10 18 | seed: 777 19 | gpu: 0 20 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/debug/adapt.yaml: -------------------------------------------------------------------------------- 1 | # adapt options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | max_epochs: 10 7 | gradclip: 5 8 | batchsize: 64 9 | hidden_size: 256 10 | num_frames: 500 11 | num_speakers: 2 12 | input_transform: logmel23_mn 13 | optimizer: adam 14 | lr: 1e-5 15 | context_size: 7 16 | subsampling: 10 17 | noam_scale: 1.0 18 | gradient_accumulation_steps: 1 19 | transformer_encoder_n_heads: 4 20 | transformer_encoder_n_layers: 2 21 | transformer_encoder_dropout: 0.1 22 | noam_warmup_steps: 25000 23 | seed: 777 24 | gpu: 0 25 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/debug/train.yaml: -------------------------------------------------------------------------------- 1 | # training options 2 | 
sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | max_epochs: 5 7 | gradclip: 5 8 | batchsize: 64 9 | hidden_size: 256 10 | num_frames: 500 11 | num_speakers: 2 12 | input_transform: logmel23_mn 13 | optimizer: noam 14 | context_size: 7 15 | subsampling: 10 16 | noam_scale: 1.0 17 | gradient_accumulation_steps: 1 18 | transformer_encoder_n_heads: 4 19 | transformer_encoder_n_layers: 2 20 | transformer_encoder_dropout: 0.1 21 | noam_warmup_steps: 25000 22 | seed: 777 23 | gpu: 0 24 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/eda/adapt.yaml: -------------------------------------------------------------------------------- 1 | # adapt options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | max_epochs: 100 7 | gradclip: 5 8 | batchsize: 64 9 | hidden_size: 256 10 | num_frames: 500 11 | input_transform: logmel23_mn 12 | optimizer: adam 13 | lr: 1e-5 14 | context_size: 7 15 | subsampling: 10 16 | gradient_accumulation_steps: 1 17 | transformer_encoder_n_heads: 4 18 | transformer_encoder_n_layers: 4 19 | transformer_encoder_dropout: 0.1 20 | use_attractor: True 21 | shuffle: True 22 | attractor_loss_ratio: 0.01 23 | attractor_encoder_dropout: 0.1 24 | attractor_decoder_dropout: 0.1 25 | seed: 777 26 | gpu: 0 27 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/eda/infer.yaml: -------------------------------------------------------------------------------- 1 | # inference options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | hidden_size: 256 7 | input_transform: logmel23_mn 8 | context_size: 7 9 | subsampling: 10 10 | chunk_size: 200000000 11 | transformer_encoder_n_heads: 4 12 | transformer_encoder_n_layers: 4 13 | use_attractor: True 14 | shuffle: True 15 | attractor_encoder_dropout: 0.1 16 | attractor_decoder_dropout: 0.1 17 | gpu: 0 18 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/eda/train.yaml: -------------------------------------------------------------------------------- 1 | # training options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | max_epochs: 25 7 | gradclip: 5 8 | batchsize: 64 9 | hidden_size: 256 10 | num_frames: 500 11 | input_transform: logmel23_mn 12 | optimizer: noam 13 | context_size: 7 14 | subsampling: 10 15 | noam_scale: 1.0 16 | gradient_accumulation_steps: 1 17 | transformer_encoder_n_heads: 4 18 | transformer_encoder_n_layers: 4 19 | transformer_encoder_dropout: 0.1 20 | noam_warmup_steps: 100000 21 | use_attractor: True 22 | shuffle: True 23 | attractor_loss_ratio: 1.0 24 | attractor_encoder_dropout: 0.1 25 | attractor_decoder_dropout: 0.1 26 | seed: 777 27 | gpu: 0 28 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/eda/train_2spk.yaml: -------------------------------------------------------------------------------- 1 | # training options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | max_epochs: 100 7 | gradclip: 5 8 | batchsize: 64 9 | hidden_size: 256 10 | num_frames: 500 11 | num_speakers: 2 12 | input_transform: logmel23_mn 13 | optimizer: noam 14 | context_size: 7 15 | subsampling: 10 16 | noam_scale: 1.0 17 | gradient_accumulation_steps: 1 18 | transformer_encoder_n_heads: 4 19 | 
transformer_encoder_n_layers: 4 20 | transformer_encoder_dropout: 0.1 21 | noam_warmup_steps: 100000 22 | use_attractor: True 23 | shuffle: True 24 | attractor_loss_ratio: 1.0 25 | attractor_encoder_dropout: 0.1 26 | attractor_decoder_dropout: 0.1 27 | seed: 777 28 | gpu: 0 29 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/infer.yaml: -------------------------------------------------------------------------------- 1 | # inference options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | hidden_size: 256 7 | num_speakers: 2 8 | input_transform: logmel23_mn 9 | context_size: 7 10 | subsampling: 10 11 | chunk_size: 2000 12 | transformer_encoder_n_heads: 4 13 | transformer_encoder_n_layers: 2 14 | gpu: 0 15 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/mfcc.conf: -------------------------------------------------------------------------------- 1 | --sample-frequency=8000 2 | --frame-length=25 # the default is 25 3 | --low-freq=20 # the default. 4 | --high-freq=3700 # the default is zero meaning use the Nyquist (4k in this case). 5 | --num-ceps=23 # higher than the default which is 12. 6 | --snip-edges=false 7 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/mfcc_hires.conf: -------------------------------------------------------------------------------- 1 | # config for high-resolution MFCC features, intended for neural network training. 2 | # Note: we keep all cepstra, so it has the same info as filterbank features, 3 | # but MFCC is more easily compressible (because less correlated) which is why 4 | # we prefer this method. 5 | --use-energy=false # use average of log energy, not energy. 6 | --sample-frequency=8000 # Switchboard is sampled at 8kHz 7 | --num-mel-bins=40 # similar to Google's setup. 8 | --num-ceps=40 # there is no dimensionality reduction. 9 | --low-freq=40 # low cutoff frequency for mel bins 10 | --high-freq=-200 # high cutoff frequently, relative to Nyquist of 4000 (=3800) 11 | --allow-downsample=true 12 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/train.yaml: -------------------------------------------------------------------------------- 1 | # training options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | max_epochs: 100 7 | gradclip: 5 8 | batchsize: 64 9 | hidden_size: 256 10 | num_frames: 500 11 | num_speakers: 2 12 | input_transform: logmel23_mn 13 | optimizer: noam 14 | context_size: 7 15 | subsampling: 10 16 | noam_scale: 1.0 17 | gradient_accumulation_steps: 1 18 | transformer_encoder_n_heads: 4 19 | transformer_encoder_n_layers: 2 20 | transformer_encoder_dropout: 0.1 21 | noam_warmup_steps: 25000 22 | seed: 777 23 | gpu: 0 24 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/make_callhome.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 David Snyder 3 | # Apache 2.0. 4 | # 5 | # This script prepares the Callhome portion of the NIST SRE 2000 6 | # corpus (LDC2001S97). It is the evaluation dataset used in the 7 | # callhome_diarization recipe. 
8 | 9 | if [ $# -ne 2 ]; then 10 | echo "Usage: $0 " 11 | echo "e.g.: $0 /mnt/data/LDC2001S97 data/" 12 | exit 1; 13 | fi 14 | 15 | src_dir=$1 16 | data_dir=$2 17 | 18 | tmp_dir=$data_dir/callhome/.tmp/ 19 | mkdir -p $tmp_dir 20 | 21 | # Download some metadata that wasn't provided in the LDC release 22 | if [ ! -d "$tmp_dir/sre2000-key" ]; then 23 | wget --no-check-certificate -P $tmp_dir/ \ 24 | http://www.openslr.org/resources/10/sre2000-key.tar.gz 25 | tar -xvf $tmp_dir/sre2000-key.tar.gz -C $tmp_dir/ 26 | fi 27 | 28 | # The list of 500 recordings 29 | awk '{print $1}' $tmp_dir/sre2000-key/reco2num > $tmp_dir/reco.list 30 | 31 | # Create wav.scp file 32 | count=0 33 | missing=0 34 | while read reco; do 35 | path=$(find $src_dir -name "$reco.sph") 36 | if [ -z "${path// }" ]; then 37 | >&2 echo "$0: Missing Sphere file for $reco" 38 | missing=$((missing+1)) 39 | else 40 | echo "$reco sph2pipe -f wav -p $path |" 41 | fi 42 | count=$((count+1)) 43 | done < $tmp_dir/reco.list > $data_dir/callhome/wav.scp 44 | 45 | if [ $missing -gt 0 ]; then 46 | echo "$0: Missing $missing out of $count recordings" 47 | fi 48 | 49 | cp $tmp_dir/sre2000-key/segments $data_dir/callhome/ 50 | awk '{print $1, $2}' $data_dir/callhome/segments > $data_dir/callhome/utt2spk 51 | utils/utt2spk_to_spk2utt.pl $data_dir/callhome/utt2spk > $data_dir/callhome/spk2utt 52 | cp $tmp_dir/sre2000-key/reco2num $data_dir/callhome/reco2num_spk 53 | cp $tmp_dir/sre2000-key/fullref.rttm $data_dir/callhome/ 54 | 55 | utils/validate_data_dir.sh --no-text --no-feats $data_dir/callhome 56 | utils/fix_data_dir.sh $data_dir/callhome 57 | 58 | utils/copy_data_dir.sh $data_dir/callhome $data_dir/callhome1 59 | utils/copy_data_dir.sh $data_dir/callhome $data_dir/callhome2 60 | 61 | utils/shuffle_list.pl $data_dir/callhome/wav.scp | head -n 250 \ 62 | | utils/filter_scp.pl - $data_dir/callhome/wav.scp \ 63 | > $data_dir/callhome1/wav.scp 64 | utils/fix_data_dir.sh $data_dir/callhome1 65 | utils/filter_scp.pl --exclude $data_dir/callhome1/wav.scp \ 66 | $data_dir/callhome/wav.scp > $data_dir/callhome2/wav.scp 67 | utils/fix_data_dir.sh $data_dir/callhome2 68 | utils/filter_scp.pl $data_dir/callhome1/wav.scp $data_dir/callhome/reco2num_spk \ 69 | > $data_dir/callhome1/reco2num_spk 70 | utils/filter_scp.pl $data_dir/callhome2/wav.scp $data_dir/callhome/reco2num_spk \ 71 | > $data_dir/callhome2/reco2num_spk 72 | 73 | rm -rf $tmp_dir 2> /dev/null 74 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/make_musan.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2015 David Snyder 3 | # 2018 Ewald Enzinger 4 | # Apache 2.0. 5 | # 6 | # Modified version of egs/sre16/v1/local/make_musan.py (commit e3fb7c4a0da4167f8c94b80f4d3cc5ab4d0e22e8). 7 | # This version uses the raw MUSAN audio files (16 kHz) and does not use sox to resample at 8 kHz. 8 | # 9 | # This file is meant to be invoked by make_musan.sh. 
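#
# Invocation sketch (added for illustration).  make_musan.sh calls this script as
#
#     local/make_musan.py <musan-root> <out-data-dir> <use_vocals: Y|N>
#
# e.g. local/make_musan.py /path/to/musan data/musan Y   (the MUSAN path is
# hypothetical).  It writes wav.scp and utt2spk into <out-data-dir>.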
10 | 11 | import os, sys 12 | 13 | def process_music_annotations(path): 14 | utt2spk = {} 15 | utt2vocals = {} 16 | lines = open(path, 'r').readlines() 17 | for line in lines: 18 | utt, genres, vocals, musician = line.rstrip().split()[:4] 19 | # For this application, the musican ID isn't important 20 | utt2spk[utt] = utt 21 | utt2vocals[utt] = vocals == "Y" 22 | return utt2spk, utt2vocals 23 | 24 | def prepare_music(root_dir, use_vocals): 25 | utt2vocals = {} 26 | utt2spk = {} 27 | utt2wav = {} 28 | num_good_files = 0 29 | num_bad_files = 0 30 | music_dir = os.path.join(root_dir, "music") 31 | for root, dirs, files in os.walk(music_dir): 32 | for file in files: 33 | file_path = os.path.join(root, file) 34 | if file.endswith(".wav"): 35 | utt = str(file).replace(".wav", "") 36 | utt2wav[utt] = file_path 37 | elif str(file) == "ANNOTATIONS": 38 | utt2spk_part, utt2vocals_part = process_music_annotations(file_path) 39 | utt2spk.update(utt2spk_part) 40 | utt2vocals.update(utt2vocals_part) 41 | utt2spk_str = "" 42 | utt2wav_str = "" 43 | for utt in utt2vocals: 44 | if utt in utt2wav: 45 | if use_vocals or not utt2vocals[utt]: 46 | utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n" 47 | utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" 48 | num_good_files += 1 49 | else: 50 | print("Missing file {}".format(utt)) 51 | num_bad_files += 1 52 | print("In music directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) 53 | return utt2spk_str, utt2wav_str 54 | 55 | def prepare_speech(root_dir): 56 | utt2spk = {} 57 | utt2wav = {} 58 | num_good_files = 0 59 | num_bad_files = 0 60 | speech_dir = os.path.join(root_dir, "speech") 61 | for root, dirs, files in os.walk(speech_dir): 62 | for file in files: 63 | file_path = os.path.join(root, file) 64 | if file.endswith(".wav"): 65 | utt = str(file).replace(".wav", "") 66 | utt2wav[utt] = file_path 67 | utt2spk[utt] = utt 68 | utt2spk_str = "" 69 | utt2wav_str = "" 70 | for utt in utt2spk: 71 | if utt in utt2wav: 72 | utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n" 73 | utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" 74 | num_good_files += 1 75 | else: 76 | print("Missing file {}".format(utt)) 77 | num_bad_files += 1 78 | print("In speech directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) 79 | return utt2spk_str, utt2wav_str 80 | 81 | def prepare_noise(root_dir): 82 | utt2spk = {} 83 | utt2wav = {} 84 | num_good_files = 0 85 | num_bad_files = 0 86 | noise_dir = os.path.join(root_dir, "noise") 87 | for root, dirs, files in os.walk(noise_dir): 88 | for file in files: 89 | file_path = os.path.join(root, file) 90 | if file.endswith(".wav"): 91 | utt = str(file).replace(".wav", "") 92 | utt2wav[utt] = file_path 93 | utt2spk[utt] = utt 94 | utt2spk_str = "" 95 | utt2wav_str = "" 96 | for utt in utt2spk: 97 | if utt in utt2wav: 98 | utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n" 99 | utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" 100 | num_good_files += 1 101 | else: 102 | print("Missing file {}".format(utt)) 103 | num_bad_files += 1 104 | print("In noise directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) 105 | return utt2spk_str, utt2wav_str 106 | 107 | def main(): 108 | in_dir = sys.argv[1] 109 | out_dir = sys.argv[2] 110 | use_vocals = sys.argv[3] == "Y" 111 | utt2spk_music, utt2wav_music = prepare_music(in_dir, use_vocals) 112 | utt2spk_speech, utt2wav_speech = 
prepare_speech(in_dir) 113 | utt2spk_noise, utt2wav_noise = prepare_noise(in_dir) 114 | utt2spk = utt2spk_speech + utt2spk_music + utt2spk_noise 115 | utt2wav = utt2wav_speech + utt2wav_music + utt2wav_noise 116 | wav_fi = open(os.path.join(out_dir, "wav.scp"), 'w') 117 | wav_fi.write(utt2wav) 118 | utt2spk_fi = open(os.path.join(out_dir, "utt2spk"), 'w') 119 | utt2spk_fi.write(utt2spk) 120 | 121 | 122 | if __name__=="__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/make_musan.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2015 David Snyder 3 | # Apache 2.0. 4 | # 5 | # This script, called by ../run.sh, creates the MUSAN 6 | # data directory. The required dataset is freely available at 7 | # http://www.openslr.org/17/ 8 | 9 | set -e 10 | in_dir=$1 11 | data_dir=$2 12 | use_vocals='Y' 13 | 14 | mkdir -p local/musan.tmp 15 | 16 | echo "Preparing ${data_dir}/musan..." 17 | mkdir -p ${data_dir}/musan 18 | local/make_musan.py ${in_dir} ${data_dir}/musan ${use_vocals} 19 | 20 | utils/fix_data_dir.sh ${data_dir}/musan 21 | 22 | grep "music" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_music 23 | grep "speech" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_speech 24 | grep "noise" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_noise 25 | utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_music \ 26 | ${data_dir}/musan ${data_dir}/musan_music 27 | utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_speech \ 28 | ${data_dir}/musan ${data_dir}/musan_speech 29 | utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_noise \ 30 | ${data_dir}/musan ${data_dir}/musan_noise 31 | 32 | utils/fix_data_dir.sh ${data_dir}/musan_music 33 | utils/fix_data_dir.sh ${data_dir}/musan_speech 34 | utils/fix_data_dir.sh ${data_dir}/musan_noise 35 | 36 | rm -rf local/musan.tmp 37 | 38 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/make_sre.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | # 3 | # Copyright 2015 David Snyder 4 | # Apache 2.0. 5 | # Usage: make_sre.pl <db-base> <sre-name> <sre-ref> <out-dir> 6 | 7 | if (@ARGV != 4) { 8 | print STDERR "Usage: $0 <db-base> <sre-name> <sre-ref> <out-dir>\n"; 9 | print STDERR "e.g. 
$0 /export/corpora5/LDC/LDC2006S44 sre2004 sre_ref data/sre2004\n"; 10 | exit(1); 11 | } 12 | 13 | ($db_base, $sre_name, $sre_ref_filename, $out_dir) = @ARGV; 14 | %utt2sph = (); 15 | %spk2gender = (); 16 | 17 | $tmp_dir = "$out_dir/tmp"; 18 | if (system("mkdir -p $tmp_dir") != 0) { 19 | die "Error making directory $tmp_dir"; 20 | } 21 | 22 | if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) { 23 | die "Error getting list of sph files"; 24 | } 25 | open(WAVLIST, "<", "$tmp_dir/sph.list") or die "cannot open wav list"; 26 | 27 | while(<WAVLIST>) { 28 | chomp; 29 | $sph = $_; 30 | @A1 = split("/",$sph); 31 | @A2 = split("[./]",$A1[$#A1]); 32 | $uttId=$A2[0]; 33 | $utt2sph{$uttId} = $sph; 34 | } 35 | 36 | open(GNDR,">", "$out_dir/spk2gender") or die "Could not open the output file $out_dir/spk2gender"; 37 | open(SPKR,">", "$out_dir/utt2spk") or die "Could not open the output file $out_dir/utt2spk"; 38 | open(WAV,">", "$out_dir/wav.scp") or die "Could not open the output file $out_dir/wav.scp"; 39 | open(SRE_REF, "<", $sre_ref_filename) or die "Cannot open SRE reference."; 40 | while (<SRE_REF>) { 41 | chomp; 42 | ($speaker, $gender, $other_sre_name, $utt_id, $channel) = split(" ", $_); 43 | $channel_num = "1"; 44 | if ($channel eq "A") { 45 | $channel_num = "1"; 46 | } else { 47 | $channel_num = "2"; 48 | } 49 | if (($other_sre_name eq $sre_name) and (exists $utt2sph{$utt_id})) { 50 | $full_utt_id = "$speaker-$gender-$sre_name-$utt_id-$channel"; 51 | $spk2gender{"$speaker-$gender"} = $gender; 52 | print WAV "$full_utt_id"," sph2pipe -f wav -p -c $channel_num $utt2sph{$utt_id} |\n"; 53 | print SPKR "$full_utt_id $speaker-$gender","\n"; 54 | } 55 | } 56 | foreach $speaker (keys %spk2gender) { 57 | print GNDR "$speaker $spk2gender{$speaker}\n"; 58 | } 59 | 60 | close(GNDR) || die; 61 | close(SPKR) || die; 62 | close(WAV) || die; 63 | close(SRE_REF) || die; 64 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/make_sre.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2015 David Snyder 3 | # Apache 2.0. 4 | # 5 | # See README.txt for more info on data required. 
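#
# Illustrative note (not part of the upstream script): the speaker_list file
# downloaded below is passed to local/make_sre.pl as its <sre-ref> argument.
# make_sre.pl splits each of its lines into five whitespace-separated fields,
#   <speaker-id> <gender> <sre-name> <utt-id> <channel>
# so a hypothetical, made-up line would look like:
#   10001 f sre2004 xabcde A
# Only lines whose <sre-name> matches the requested SRE and whose <utt-id>
# has a matching .sph file are kept.
#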
6 | 7 | set -e 8 | 9 | data_root=$1 10 | data_dir=$2 11 | 12 | wget -P data/local/ http://www.openslr.org/resources/15/speaker_list.tgz 13 | tar -C data/local/ -xvf data/local/speaker_list.tgz 14 | sre_ref=data/local/speaker_list 15 | 16 | local/make_sre.pl $data_root/LDC2006S44/ \ 17 | sre2004 $sre_ref $data_dir/sre2004 18 | 19 | local/make_sre.pl $data_root/LDC2011S01 \ 20 | sre2005 $sre_ref $data_dir/sre2005_train 21 | 22 | local/make_sre.pl $data_root/LDC2011S04 \ 23 | sre2005 $sre_ref $data_dir/sre2005_test 24 | 25 | local/make_sre.pl $data_root/LDC2011S09 \ 26 | sre2006 $sre_ref $data_dir/sre2006_train 27 | 28 | local/make_sre.pl $data_root/LDC2011S10 \ 29 | sre2006 $sre_ref $data_dir/sre2006_test_1 30 | 31 | local/make_sre.pl $data_root/LDC2012S01 \ 32 | sre2006 $sre_ref $data_dir/sre2006_test_2 33 | 34 | local/make_sre.pl $data_root/LDC2011S05 \ 35 | sre2008 $sre_ref $data_dir/sre2008_train 36 | 37 | local/make_sre.pl $data_root/LDC2011S08 \ 38 | sre2008 $sre_ref $data_dir/sre2008_test 39 | 40 | utils/combine_data.sh $data_dir/sre \ 41 | $data_dir/sre2004 $data_dir/sre2005_train \ 42 | $data_dir/sre2005_test $data_dir/sre2006_train \ 43 | $data_dir/sre2006_test_1 $data_dir/sre2006_test_2 \ 44 | $data_dir/sre2008_train $data_dir/sre2008_test 45 | 46 | utils/validate_data_dir.sh --no-text --no-feats $data_dir/sre 47 | utils/fix_data_dir.sh $data_dir/sre 48 | rm data/local/speaker_list.* 49 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/make_swbd2_phase1.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | use warnings; #sed replacement for -w perl parameter 3 | # 4 | # Copyright 2017 David Snyder 5 | # Apache 2.0 6 | 7 | if (@ARGV != 2) { 8 | print STDERR "Usage: $0 <db-base> <out-dir>\n"; 9 | print STDERR "e.g. $0 /export/corpora3/LDC/LDC98S75 data/swbd2_phase1_train\n"; 10 | exit(1); 11 | } 12 | ($db_base, $out_dir) = @ARGV; 13 | 14 | if (system("mkdir -p $out_dir")) { 15 | die "Error making directory $out_dir"; 16 | } 17 | 18 | open(CS, "<$db_base/doc/callstat.tbl") || die "Could not open $db_base/doc/callstat.tbl"; 19 | open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; 20 | open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; 21 | open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; 22 | 23 | @badAudio = ("3", "4"); 24 | 25 | $tmp_dir = "$out_dir/tmp"; 26 | if (system("mkdir -p $tmp_dir") != 0) { 27 | die "Error making directory $tmp_dir"; 28 | } 29 | 30 | if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) { 31 | die "Error getting list of sph files"; 32 | } 33 | 34 | open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list"; 35 | 36 | %wavs = (); 37 | while(<WAVLIST>) { 38 | chomp; 39 | $sph = $_; 40 | @t = split("/",$sph); 41 | @t1 = split("[./]",$t[$#t]); 42 | $uttId = $t1[0]; 43 | $wavs{$uttId} = $sph; 44 | } 45 | 46 | while (<CS>) { 47 | $line = $_ ; 48 | @A = split(",", $line); 49 | @A1 = split("[./]",$A[0]); 50 | $wav = $A1[0]; 51 | if (/$wav/i ~~ @badAudio) { 52 | # do nothing 53 | print "Bad Audio = $wav"; 54 | } else { 55 | $spkr1= "sw_" . $A[2]; 56 | $spkr2= "sw_" . 
$A[3]; 57 | $gender1 = $A[5]; 58 | $gender2 = $A[6]; 59 | if ($gender1 eq "M") { 60 | $gender1 = "m"; 61 | } elsif ($gender1 eq "F") { 62 | $gender1 = "f"; 63 | } else { 64 | die "Unknown Gender in $line"; 65 | } 66 | if ($gender2 eq "M") { 67 | $gender2 = "m"; 68 | } elsif ($gender2 eq "F") { 69 | $gender2 = "f"; 70 | } else { 71 | die "Unknown Gender in $line"; 72 | } 73 | if (-e "$wavs{$wav}") { 74 | $uttId = $spkr1 ."_" . $wav ."_1"; 75 | if (!$spk2gender{$spkr1}) { 76 | $spk2gender{$spkr1} = $gender1; 77 | print GNDR "$spkr1"," $gender1\n"; 78 | } 79 | print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wavs{$wav} |\n"; 80 | print SPKR "$uttId"," $spkr1","\n"; 81 | 82 | $uttId = $spkr2 . "_" . $wav ."_2"; 83 | if (!$spk2gender{$spkr2}) { 84 | $spk2gender{$spkr2} = $gender2; 85 | print GNDR "$spkr2"," $gender2\n"; 86 | } 87 | print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wavs{$wav} |\n"; 88 | print SPKR "$uttId"," $spkr2","\n"; 89 | } else { 90 | print STDERR "Missing $wavs{$wav} for $wav\n"; 91 | } 92 | } 93 | } 94 | 95 | close(WAV) || die; 96 | close(SPKR) || die; 97 | close(GNDR) || die; 98 | if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { 99 | die "Error creating spk2utt file in directory $out_dir"; 100 | } 101 | if (system("utils/fix_data_dir.sh $out_dir") != 0) { 102 | die "Error fixing data dir $out_dir"; 103 | } 104 | if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { 105 | die "Error validating directory $out_dir"; 106 | } 107 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/make_swbd2_phase2.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | use warnings; #sed replacement for -w perl parameter 3 | # 4 | # Copyright 2013 Daniel Povey 5 | # Apache 2.0 6 | 7 | if (@ARGV != 2) { 8 | print STDERR "Usage: $0 <db-base> <out-dir>\n"; 9 | print STDERR "e.g. $0 /export/corpora5/LDC/LDC99S79 data/swbd2_phase2_train\n"; 10 | exit(1); 11 | } 12 | ($db_base, $out_dir) = @ARGV; 13 | 14 | if (system("mkdir -p $out_dir")) { 15 | die "Error making directory $out_dir"; 16 | } 17 | 18 | open(CS, "<$db_base/DISC1/doc/callstat.tbl") || die "Could not open $db_base/DISC1/doc/callstat.tbl"; 19 | open(CI, "<$db_base/DISC1/doc/callinfo.tbl") || die "Could not open $db_base/DISC1/doc/callinfo.tbl"; 20 | open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; 21 | open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; 22 | open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; 23 | 24 | @badAudio = ("3", "4"); 25 | 26 | $tmp_dir = "$out_dir/tmp"; 27 | if (system("mkdir -p $tmp_dir") != 0) { 28 | die "Error making directory $tmp_dir"; 29 | } 30 | 31 | if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) { 32 | die "Error getting list of sph files"; 33 | } 34 | 35 | open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list"; 36 | 37 | while(<WAVLIST>) { 38 | chomp; 39 | $sph = $_; 40 | @t = split("/",$sph); 41 | @t1 = split("[./]",$t[$#t]); 42 | $uttId=$t1[0]; 43 | $wav{$uttId} = $sph; 44 | } 45 | 46 | while (<CS>) { 47 | $line = $_ ; 48 | $ci = <CI>; 49 | $ci = <CI>; 50 | @ci = split(",",$ci); 51 | $wav = $ci[0]; 52 | @A = split(",", $line); 53 | if (/$wav/i ~~ @badAudio) { 54 | # do nothing 55 | } else { 56 | $spkr1= "sw_" . $A[2]; 57 | $spkr2= "sw_" . 
$A[3]; 58 | $gender1 = $A[4]; 59 | $gender2 = $A[5]; 60 | if ($gender1 eq "M") { 61 | $gender1 = "m"; 62 | } elsif ($gender1 eq "F") { 63 | $gender1 = "f"; 64 | } else { 65 | die "Unknown Gender in $line"; 66 | } 67 | if ($gender2 eq "M") { 68 | $gender2 = "m"; 69 | } elsif ($gender2 eq "F") { 70 | $gender2 = "f"; 71 | } else { 72 | die "Unknown Gender in $line"; 73 | } 74 | if (-e "$wav{$wav}") { 75 | $uttId = $spkr1 ."_" . $wav ."_1"; 76 | if (!$spk2gender{$spkr1}) { 77 | $spk2gender{$spkr1} = $gender1; 78 | print GNDR "$spkr1"," $gender1\n"; 79 | } 80 | print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wav{$wav} |\n"; 81 | print SPKR "$uttId"," $spkr1","\n"; 82 | 83 | $uttId = $spkr2 . "_" . $wav ."_2"; 84 | if (!$spk2gender{$spkr2}) { 85 | $spk2gender{$spkr2} = $gender2; 86 | print GNDR "$spkr2"," $gender2\n"; 87 | } 88 | print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wav{$wav} |\n"; 89 | print SPKR "$uttId"," $spkr2","\n"; 90 | } else { 91 | print STDERR "Missing $wav{$wav} for $wav\n"; 92 | } 93 | } 94 | } 95 | 96 | close(WAV) || die; 97 | close(SPKR) || die; 98 | close(GNDR) || die; 99 | if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { 100 | die "Error creating spk2utt file in directory $out_dir"; 101 | } 102 | if (system("utils/fix_data_dir.sh $out_dir") != 0) { 103 | die "Error fixing data dir $out_dir"; 104 | } 105 | if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { 106 | die "Error validating directory $out_dir"; 107 | } 108 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/make_swbd2_phase3.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | use warnings; #sed replacement for -w perl parameter 3 | # 4 | # Copyright 2013 Daniel Povey 5 | # Apache 2.0 6 | 7 | if (@ARGV != 2) { 8 | print STDERR "Usage: $0 <db-base> <out-dir>\n"; 9 | print STDERR "e.g. $0 /export/corpora5/LDC/LDC2002S06 data/swbd2_phase3_train\n"; 10 | exit(1); 11 | } 12 | ($db_base, $out_dir) = @ARGV; 13 | 14 | if (system("mkdir -p $out_dir")) { 15 | die "Error making directory $out_dir"; 16 | } 17 | 18 | open(CS, "<$db_base/DISC1/docs/callstat.tbl") || die "Could not open $db_base/DISC1/docs/callstat.tbl"; 19 | open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; 20 | open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; 21 | open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; 22 | 23 | @badAudio = ("3", "4"); 24 | 25 | $tmp_dir = "$out_dir/tmp"; 26 | if (system("mkdir -p $tmp_dir") != 0) { 27 | die "Error making directory $tmp_dir"; 28 | } 29 | 30 | if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) { 31 | die "Error getting list of sph files"; 32 | } 33 | 34 | open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list"; 35 | while(<WAVLIST>) { 36 | chomp; 37 | $sph = $_; 38 | @t = split("/",$sph); 39 | @t1 = split("[./]",$t[$#t]); 40 | $uttId=$t1[0]; 41 | $wav{$uttId} = $sph; 42 | } 43 | 44 | while (<CS>) { 45 | $line = $_ ; 46 | @A = split(",", $line); 47 | $wav = "sw_" . $A[0] ; 48 | if (/$wav/i ~~ @badAudio) { 49 | # do nothing 50 | } else { 51 | $spkr1= "sw_" . $A[3]; 52 | $spkr2= "sw_" . 
$A[4]; 53 | $gender1 = $A[5]; 54 | $gender2 = $A[6]; 55 | if ($gender1 eq "M") { 56 | $gender1 = "m"; 57 | } elsif ($gender1 eq "F") { 58 | $gender1 = "f"; 59 | } else { 60 | die "Unknown Gender in $line"; 61 | } 62 | if ($gender2 eq "M") { 63 | $gender2 = "m"; 64 | } elsif ($gender2 eq "F") { 65 | $gender2 = "f"; 66 | } else { 67 | die "Unknown Gender in $line"; 68 | } 69 | if (-e "$wav{$wav}") { 70 | $uttId = $spkr1 ."_" . $wav ."_1"; 71 | if (!$spk2gender{$spkr1}) { 72 | $spk2gender{$spkr1} = $gender1; 73 | print GNDR "$spkr1"," $gender1\n"; 74 | } 75 | print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wav{$wav} |\n"; 76 | print SPKR "$uttId"," $spkr1","\n"; 77 | 78 | $uttId = $spkr2 . "_" . $wav ."_2"; 79 | if (!$spk2gender{$spkr2}) { 80 | $spk2gender{$spkr2} = $gender2; 81 | print GNDR "$spkr2"," $gender2\n"; 82 | } 83 | print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wav{$wav} |\n"; 84 | print SPKR "$uttId"," $spkr2","\n"; 85 | } else { 86 | print STDERR "Missing $wav{$wav} for $wav\n"; 87 | } 88 | } 89 | } 90 | 91 | close(WAV) || die; 92 | close(SPKR) || die; 93 | close(GNDR) || die; 94 | if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { 95 | die "Error creating spk2utt file in directory $out_dir"; 96 | } 97 | if (system("utils/fix_data_dir.sh $out_dir") != 0) { 98 | die "Error fixing data dir $out_dir"; 99 | } 100 | if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { 101 | die "Error validating directory $out_dir"; 102 | } 103 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/make_swbd_cellular1.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | use warnings; #sed replacement for -w perl parameter 3 | # 4 | # Copyright 2013 Daniel Povey 5 | # Apache 2.0 6 | 7 | if (@ARGV != 2) { 8 | print STDERR "Usage: $0 <db-base> <out-dir>\n"; 9 | print STDERR "e.g. $0 /export/corpora5/LDC/LDC2001S13 data/swbd_cellular1_train\n"; 10 | exit(1); 11 | } 12 | ($db_base, $out_dir) = @ARGV; 13 | 14 | if (system("mkdir -p $out_dir")) { 15 | die "Error making directory $out_dir"; 16 | } 17 | 18 | open(CS, "<$db_base/doc/swb_callstats.tbl") || die "Could not open $db_base/doc/swb_callstats.tbl"; 19 | open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; 20 | open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; 21 | open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; 22 | 23 | @badAudio = ("40019", "45024", "40022"); 24 | 25 | while (<CS>) { 26 | $line = $_ ; 27 | @A = split(",", $line); 28 | if (/$A[0]/i ~~ @badAudio) { 29 | # do nothing 30 | } else { 31 | $wav = "sw_" . $A[0]; 32 | $spkr1= "sw_" . $A[1]; 33 | $spkr2= "sw_" . $A[2]; 34 | $gender1 = $A[3]; 35 | $gender2 = $A[4]; 36 | if ($A[3] eq "M") { 37 | $gender1 = "m"; 38 | } elsif ($A[3] eq "F") { 39 | $gender1 = "f"; 40 | } else { 41 | die "Unknown Gender in $line"; 42 | } 43 | if ($A[4] eq "M") { 44 | $gender2 = "m"; 45 | } elsif ($A[4] eq "F") { 46 | $gender2 = "f"; 47 | } else { 48 | die "Unknown Gender in $line"; 49 | } 50 | if (-e "$db_base/$wav.sph") { 51 | $uttId = $spkr1 . "-swbdc_" . $wav ."_1"; 52 | if (!$spk2gender{$spkr1}) { 53 | $spk2gender{$spkr1} = $gender1; 54 | print GNDR "$spkr1"," $gender1\n"; 55 | } 56 | print WAV "$uttId"," sph2pipe -f wav -p -c 1 $db_base/$wav.sph |\n"; 57 | print SPKR "$uttId"," $spkr1","\n"; 58 | 59 | $uttId = $spkr2 . "-swbdc_" . 
$wav ."_2"; 60 | if (!$spk2gender{$spkr2}) { 61 | $spk2gender{$spkr2} = $gender2; 62 | print GNDR "$spkr2"," $gender2\n"; 63 | } 64 | print WAV "$uttId"," sph2pipe -f wav -p -c 2 $db_base/$wav.sph |\n"; 65 | print SPKR "$uttId"," $spkr2","\n"; 66 | } else { 67 | print STDERR "Missing $db_base/$wav.sph\n"; 68 | } 69 | } 70 | } 71 | 72 | close(WAV) || die; 73 | close(SPKR) || die; 74 | close(GNDR) || die; 75 | if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { 76 | die "Error creating spk2utt file in directory $out_dir"; 77 | } 78 | if (system("utils/fix_data_dir.sh $out_dir") != 0) { 79 | die "Error fixing data dir $out_dir"; 80 | } 81 | if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { 82 | die "Error validating directory $out_dir"; 83 | } 84 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/make_swbd_cellular2.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | use warnings; #sed replacement for -w perl parameter 3 | # 4 | # Copyright 2013 Daniel Povey 5 | # Apache 2.0 6 | 7 | if (@ARGV != 2) { 8 | print STDERR "Usage: $0 <db-base> <out-dir>\n"; 9 | print STDERR "e.g. $0 /export/corpora5/LDC/LDC2004S07 data/swbd_cellular2_train\n"; 10 | exit(1); 11 | } 12 | ($db_base, $out_dir) = @ARGV; 13 | 14 | if (system("mkdir -p $out_dir")) { 15 | die "Error making directory $out_dir"; 16 | } 17 | 18 | open(CS, "<$db_base/docs/swb_callstats.tbl") || die "Could not open $db_base/docs/swb_callstats.tbl"; 19 | open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; 20 | open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; 21 | open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; 22 | 23 | @badAudio=("45024", "40022"); 24 | 25 | while (<CS>) { 26 | $line = $_ ; 27 | @A = split(",", $line); 28 | if (/$A[0]/i ~~ @badAudio) { 29 | # do nothing 30 | } else { 31 | $wav = "sw_" . $A[0]; 32 | $spkr1= "sw_" . $A[1]; 33 | $spkr2= "sw_" . $A[2]; 34 | $gender1 = $A[3]; 35 | $gender2 = $A[4]; 36 | if ($A[3] eq "M") { 37 | $gender1 = "m"; 38 | } elsif ($A[3] eq "F") { 39 | $gender1 = "f"; 40 | } else { 41 | die "Unknown Gender in $line"; 42 | } 43 | if ($A[4] eq "M") { 44 | $gender2 = "m"; 45 | } elsif ($A[4] eq "F") { 46 | $gender2 = "f"; 47 | } else { 48 | die "Unknown Gender in $line"; 49 | } 50 | if (-e "$db_base/data/$wav.sph") { 51 | $uttId = $spkr1 . "-swbdc_" . $wav ."_1"; 52 | if (!$spk2gender{$spkr1}) { 53 | $spk2gender{$spkr1} = $gender1; 54 | print GNDR "$spkr1"," $gender1\n"; 55 | } 56 | print WAV "$uttId"," sph2pipe -f wav -p -c 1 $db_base/data/$wav.sph |\n"; 57 | print SPKR "$uttId"," $spkr1","\n"; 58 | 59 | $uttId = $spkr2 . "-swbdc_" . 
$wav ."_2"; 60 | if (!$spk2gender{$spkr2}) { 61 | $spk2gender{$spkr2} = $gender2; 62 | print GNDR "$spkr2"," $gender2\n"; 63 | } 64 | print WAV "$uttId"," sph2pipe -f wav -p -c 2 $db_base/data/$wav.sph |\n"; 65 | print SPKR "$uttId"," $spkr2","\n"; 66 | } else { 67 | print STDERR "Missing $db_base/data/$wav.sph\n"; 68 | } 69 | } 70 | } 71 | 72 | close(WAV) || die; 73 | close(SPKR) || die; 74 | close(GNDR) || die; 75 | if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { 76 | die "Error creating spk2utt file in directory $out_dir"; 77 | } 78 | if (system("utils/fix_data_dir.sh $out_dir") != 0) { 79 | die "Error fixing data dir $out_dir"; 80 | } 81 | if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { 82 | die "Error validating directory $out_dir"; 83 | } 84 | -------------------------------------------------------------------------------- /egs/callhome/v1/local/run_blstm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | # BLSTM-based model experiment 7 | ./run.sh --train-config conf/blstm/train.yaml --average-start 20 --average-end 20 \ 8 | --adapt-config conf/blstm/adapt.yaml --adapt-average-start 10 --adapt-average-end 10 \ 9 | --infer-config conf/blstm/infer.yaml $* 10 | -------------------------------------------------------------------------------- /egs/callhome/v1/path.sh: -------------------------------------------------------------------------------- 1 | . ../../../tools/env.sh 2 | -------------------------------------------------------------------------------- /egs/callhome/v1/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | stage=0 7 | 8 | # The datasets for training must be formatted as kaldi data directory. 9 | # Also, make sure the audio files in wav.scp are 'regular' wav files. 10 | # Including piped commands in wav.scp makes training very slow 11 | train_set=data/simu/data/swb_sre_tr_ns2_beta2_100000 12 | valid_set=data/simu/data/swb_sre_cv_ns2_beta2_500 13 | adapt_set=data/eval/callhome1_spk2 14 | adapt_valid_set=data/eval/callhome2_spk2 15 | 16 | # Base config files for {train,infer}.py 17 | train_config=conf/train.yaml 18 | infer_config=conf/infer.yaml 19 | adapt_config=conf/adapt.yaml 20 | 21 | # Additional arguments passed to {train,infer}.py. 22 | # You need not edit the base config files above 23 | train_args= 24 | infer_args= 25 | adapt_args= 26 | 27 | # Model averaging options 28 | average_start=91 29 | average_end=100 30 | 31 | # Adapted model averaging options 32 | adapt_average_start=91 33 | adapt_average_end=100 34 | 35 | # Resume training from snapshot at this epoch 36 | # TODO: not tested 37 | resume=-1 38 | 39 | # Debug purpose 40 | debug= 41 | 42 | . path.sh 43 | . cmd.sh 44 | . 
parse_options.sh || exit 45 | 46 | set -eu 47 | 48 | if [ "$debug" != "" ]; then 49 | # debug mode 50 | train_set=data/simu/data/swb_sre_tr_ns2_beta2_1000 51 | train_config=conf/debug/train.yaml 52 | average_start=3 53 | average_end=5 54 | adapt_config=conf/debug/adapt.yaml 55 | adapt_average_start=6 56 | adapt_average_end=10 57 | fi 58 | 59 | # Parse the config file to set bash variables like: $train_frame_shift, $infer_gpu 60 | eval `yaml2bash.py --prefix train $train_config` 61 | eval `yaml2bash.py --prefix infer $infer_config` 62 | 63 | # Append gpu reservation flag to the queuing command 64 | if [ $train_gpu -le 0 ]; then 65 | train_cmd+=" --gpu 1" 66 | fi 67 | if [ $infer_gpu -le 0 ]; then 68 | infer_cmd+=" --gpu 1" 69 | fi 70 | 71 | # Build directry names for an experiment 72 | # - Training 73 | # exp/diarize/model/{train_id}.{valid_id}.{train_config_id} 74 | # - Decoding 75 | # exp/diarize/infer/{train_id}.{valid_id}.{train_config_id}.{infer_config_id} 76 | # - Scoring 77 | # exp/diarize/scoring/{train_id}.{valid_id}.{train_config_id}.{infer_config_id} 78 | # - Adapation from non-adapted averaged model 79 | # exp/diarize/model/{train_id}.{valid_id}.{train_config_id}.{avgid}.{adapt_config_id} 80 | train_id=$(basename $train_set) 81 | valid_id=$(basename $valid_set) 82 | train_config_id=$(echo $train_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 83 | infer_config_id=$(echo $infer_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 84 | adapt_config_id=$(echo $adapt_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 85 | 86 | # Additional arguments are added to config_id 87 | train_config_id+=$(echo $train_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 88 | infer_config_id+=$(echo $infer_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 89 | adapt_config_id+=$(echo $adapt_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 90 | 91 | model_id=$train_id.$valid_id.$train_config_id 92 | model_dir=exp/diarize/model/$model_id 93 | if [ $stage -le 1 ]; then 94 | echo "training model at $model_dir." 95 | if [ -d $model_dir ]; then 96 | echo "$model_dir already exists. " 97 | echo " if you want to retry, please remove it." 98 | exit 1 99 | fi 100 | work=$model_dir/.work 101 | mkdir -p $work 102 | $train_cmd $work/train.log \ 103 | train.py \ 104 | -c $train_config \ 105 | $train_args \ 106 | $train_set $valid_set $model_dir \ 107 | || exit 1 108 | fi 109 | 110 | ave_id=avg${average_start}-${average_end} 111 | if [ $stage -le 2 ]; then 112 | echo "averaging model parameters into $model_dir/$ave_id.nnet.npz" 113 | if [ -s $model_dir/$ave_id.nnet.npz ]; then 114 | echo "$model_dir/$ave_id.nnet.npz already exists. " 115 | echo " if you want to retry, please remove it." 116 | exit 1 117 | fi 118 | models=`eval echo $model_dir/snapshot_epoch-{$average_start..$average_end}` 119 | model_averaging.py $model_dir/$ave_id.nnet.npz $models || exit 1 120 | fi 121 | 122 | infer_dir=exp/diarize/infer/$model_id.$ave_id.$infer_config_id 123 | if [ $stage -le 3 ]; then 124 | echo "inference at $infer_dir" 125 | if [ -d $infer_dir ]; then 126 | echo "$infer_dir already exists. " 127 | echo " if you want to retry, please remove it." 
128 | exit 1 129 | fi 130 | for dset in callhome2_spk2; do 131 | work=$infer_dir/$dset/.work 132 | mkdir -p $work 133 | $infer_cmd $work/infer.log \ 134 | infer.py \ 135 | -c $infer_config \ 136 | $infer_args \ 137 | data/eval/$dset \ 138 | $model_dir/$ave_id.nnet.npz \ 139 | $infer_dir/$dset \ 140 | || exit 1 141 | done 142 | fi 143 | 144 | scoring_dir=exp/diarize/scoring/$model_id.$ave_id.$infer_config_id 145 | if [ $stage -le 4 ]; then 146 | echo "scoring at $scoring_dir" 147 | if [ -d $scoring_dir ]; then 148 | echo "$scoring_dir already exists. " 149 | echo " if you want to retry, please remove it." 150 | exit 1 151 | fi 152 | for dset in callhome2_spk2; do 153 | work=$scoring_dir/$dset/.work 154 | mkdir -p $work 155 | find $infer_dir/$dset -iname "*.h5" > $work/file_list_$dset 156 | for med in 1 11; do 157 | for th in 0.3 0.4 0.5 0.6 0.7; do 158 | make_rttm.py --median=$med --threshold=$th \ 159 | --frame_shift=$infer_frame_shift --subsampling=$infer_subsampling --sampling_rate=$infer_sampling_rate \ 160 | $work/file_list_$dset $scoring_dir/$dset/hyp_${th}_$med.rttm 161 | md-eval.pl -c 0.25 \ 162 | -r data/eval/$dset/rttm \ 163 | -s $scoring_dir/$dset/hyp_${th}_$med.rttm > $scoring_dir/$dset/result_th${th}_med${med}_collar0.25 2>/dev/null || exit 164 | done 165 | done 166 | done 167 | fi 168 | 169 | adapt_model_dir=exp/diarize/model/$model_id.$ave_id.$adapt_config_id 170 | if [ $stage -le 5 ]; then 171 | echo "adapting model at $adapt_model_dir" 172 | if [ -d $adapt_model_dir ]; then 173 | echo "$adapt_model_dir already exists. " 174 | echo " if you want to retry, please remove it." 175 | exit 1 176 | fi 177 | work=$adapt_model_dir/.work 178 | mkdir -p $work 179 | $train_cmd $work/train.log \ 180 | train.py \ 181 | -c $adapt_config \ 182 | $adapt_args \ 183 | --initmodel $model_dir/$ave_id.nnet.npz \ 184 | $adapt_set $adapt_valid_set $adapt_model_dir \ 185 | || exit 1 186 | fi 187 | 188 | adapt_ave_id=avg${adapt_average_start}-${adapt_average_end} 189 | if [ $stage -le 6 ]; then 190 | echo "averaging models into $adapt_model_dir/$adapt_ave_id.nnet.gz" 191 | if [ -s $adapt_model_dir/$adapt_ave_id.nnet.npz ]; then 192 | echo "$adapt_model_dir/$adapt_ave_id.nnet.npz already exists." 193 | echo " if you want to retry, please remove it." 194 | exit 1 195 | fi 196 | models=`eval echo $adapt_model_dir/snapshot_epoch-{$adapt_average_start..$adapt_average_end}` 197 | model_averaging.py $adapt_model_dir/$adapt_ave_id.nnet.npz $models || exit 1 198 | fi 199 | 200 | infer_dir=exp/diarize/infer/$model_id.$ave_id.$adapt_config_id.$adapt_ave_id.$infer_config_id 201 | if [ $stage -le 7 ]; then 202 | echo "inference at $infer_dir" 203 | if [ -d $infer_dir ]; then 204 | echo "$infer_dir already exists. " 205 | echo " if you want to retry, please remove it." 206 | exit 1 207 | fi 208 | for dset in callhome2_spk2; do 209 | work=$infer_dir/$dset/.work 210 | mkdir -p $work 211 | $train_cmd $work/infer.log \ 212 | infer.py -c $infer_config \ 213 | data/eval/${dset} \ 214 | $adapt_model_dir/$adapt_ave_id.nnet.npz \ 215 | $infer_dir/$dset \ 216 | || exit 1 217 | done 218 | fi 219 | 220 | scoring_dir=exp/diarize/scoring/$model_id.$ave_id.$adapt_config_id.$adapt_ave_id.$infer_config_id 221 | if [ $stage -le 8 ]; then 222 | echo "scoring at $scoring_dir" 223 | if [ -d $scoring_dir ]; then 224 | echo "$scoring_dir already exists. " 225 | echo " if you want to retry, please remove it." 
226 | exit 1 227 | fi 228 | for dset in callhome2_spk2; do 229 | work=$scoring_dir/$dset/.work 230 | mkdir -p $work 231 | find $infer_dir/$dset -iname "*.h5" > $work/file_list_$dset 232 | for med in 1 11; do 233 | for th in 0.3 0.4 0.5 0.6 0.7; do 234 | make_rttm.py --median=$med --threshold=$th \ 235 | --frame_shift=$infer_frame_shift --subsampling=$infer_subsampling --sampling_rate=$infer_sampling_rate \ 236 | $work/file_list_$dset $scoring_dir/$dset/hyp_${th}_$med.rttm 237 | md-eval.pl -c 0.25 \ 238 | -r data/eval/$dset/rttm \ 239 | -s $scoring_dir/$dset/hyp_${th}_$med.rttm > $scoring_dir/$dset/result_th${th}_med${med}_collar0.25 2>/dev/null || exit 240 | done 241 | done 242 | done 243 | fi 244 | 245 | if [ $stage -le 9 ]; then 246 | for dset in callhome2_spk2; do 247 | best_score.sh $scoring_dir/$dset 248 | done 249 | fi 250 | echo "Finished !" 251 | 252 | -------------------------------------------------------------------------------- /egs/callhome/v1/run_eda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2019-2020 Hitachi, Ltd. (author: Yusuke Fujita, Shota Horiguchi) 4 | # Licensed under the MIT license. 5 | # 6 | stage=0 7 | 8 | # The datasets for training must be formatted as kaldi data directory. 9 | # Also, make sure the audio files in wav.scp are 'regular' wav files. 10 | # Including piped commands in wav.scp makes training very slow 11 | train_2spk_set=data/simu/data/swb_sre_tr_ns2_beta2_100000 12 | valid_2spk_set=data/simu/data/swb_sre_cv_ns2_beta2_500 13 | train_set=data/simu/data/swb_sre_tr_ns1n2n3n4_beta2n2n5n9_100000 14 | valid_set=data/simu/data/swb_sre_cv_ns1n2n3n4_beta2n2n5n9_500 15 | adapt_set=data/eval/callhome1_spkall 16 | adapt_valid_set=data/eval/callhome2_spkall 17 | 18 | # Base config files for {train,infer}.py 19 | train_2spk_config=conf/eda/train_2spk.yaml 20 | train_config=conf/eda/train.yaml 21 | infer_config=conf/eda/infer.yaml 22 | adapt_config=conf/eda/adapt.yaml 23 | 24 | # Additional arguments passed to {train,infer}.py. 25 | # You need not edit the base config files above 26 | train_2spk_args= 27 | train_args= 28 | infer_args= 29 | adapt_args= 30 | 31 | # 2-speaker model averaging options 32 | average_2spk_start=91 33 | average_2spk_end=100 34 | 35 | # Model averaging options 36 | average_start=16 37 | average_end=25 38 | 39 | # Adapted model averaging options 40 | adapt_average_start=91 41 | adapt_average_end=100 42 | 43 | # Resume training from snapshot at this epoch 44 | # TODO: not tested 45 | resume=-1 46 | 47 | # Debug purpose 48 | debug= 49 | 50 | . path.sh 51 | . cmd.sh 52 | . 
parse_options.sh || exit 53 | 54 | set -eu 55 | 56 | if [ "$debug" != "" ]; then 57 | # debug mode 58 | train_set=data/simu/data/swb_sre_tr_ns2_beta2_1000 59 | train_config=conf/debug/train.yaml 60 | average_start=3 61 | average_end=5 62 | adapt_config=conf/debug/adapt.yaml 63 | adapt_average_start=6 64 | adapt_average_end=10 65 | fi 66 | 67 | # Parse the config file to set bash variables like: $train_frame_shift, $infer_gpu 68 | eval `yaml2bash.py --prefix train $train_config` 69 | eval `yaml2bash.py --prefix infer $infer_config` 70 | 71 | # Append gpu reservation flag to the queuing command 72 | if [ $train_gpu -le 0 ]; then 73 | train_cmd+=" --gpu 1" 74 | fi 75 | if [ $infer_gpu -le 0 ]; then 76 | infer_cmd+=" --gpu 1" 77 | fi 78 | 79 | # Build directry names for an experiment 80 | # - Training (2 speakers) 81 | # exp/diarize/model/{train_2spk_id}.{valid_2spk_id}.{train_2spk_config_id} 82 | # - Training (1-4 speakers, finetune from the 2-speaker model) 83 | # exp/diarize/model/{train_id}.{valid_id}.{train_config_id} 84 | # - Adapation from non-adapted averaged model 85 | # exp/diarize/model/{train_id}.{valid_id}.{train_config_id}.{avgid}.{adapt_config_id} 86 | # - Decoding 87 | # exp/diarize/infer/{train_id}.{valid_id}.{train_config_id}.{avgid}.{adapt_config_id}.{infer_config_id} 88 | # - Scoring 89 | # exp/diarize/scoring/{train_id}.{valid_id}.{train_config_id}.{avgid}.{adapt_config_id}.{infer_config_id} 90 | train_2spk_id=$(basename $train_2spk_set) 91 | valid_2spk_id=$(basename $valid_2spk_set) 92 | train_id=$(basename $train_set) 93 | valid_id=$(basename $valid_set) 94 | train_2spk_config_id=$(echo $train_2spk_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 95 | train_config_id=$(echo $train_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 96 | infer_config_id=$(echo $infer_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 97 | adapt_config_id=$(echo $adapt_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 98 | 99 | # Additional arguments are added to config_id 100 | train_2spk_config_id+=$(echo $train_2spk_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 101 | train_config_id+=$(echo $train_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 102 | infer_config_id+=$(echo $infer_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 103 | adapt_config_id+=$(echo $adapt_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 104 | 105 | model_2spk_id=$train_2spk_id.$valid_2spk_id.$train_2spk_config_id 106 | model_2spk_dir=exp/diarize/model/$model_2spk_id 107 | if [ $stage -le 1 ]; then 108 | echo "training 2-speaker model at $model_2spk_dir." 109 | if [ -d $model_2spk_dir ]; then 110 | echo "$model_2spk_dir already exists. " 111 | echo " if you want to retry, please remove it." 112 | exit 1 113 | fi 114 | work=$model_2spk_dir/.work 115 | mkdir -p $work 116 | $train_cmd $work/train.log \ 117 | train.py \ 118 | -c $train_2spk_config \ 119 | $train_2spk_args \ 120 | $train_2spk_set $valid_2spk_set $model_2spk_dir \ 121 | || exit 1 122 | fi 123 | 124 | ave_id=avg${average_2spk_start}-${average_2spk_end} 125 | if [ $stage -le 2 ]; then 126 | echo "averaging model parameters into $model_2spk_dir/$ave_id.nnet.npz" 127 | if [ -s $model_2spk_dir/$ave_id.nnet.npz ]; then 128 | echo "$model_2spk_dir/$ave_id.nnet.npz already exists. " 129 | echo " if you want to retry, please remove it." 
130 | exit 1 131 | fi 132 | models=`eval echo $model_2spk_dir/snapshot_epoch-{$average_2spk_start..$average_2spk_end}` 133 | model_averaging.py $model_2spk_dir/$ave_id.nnet.npz $models || exit 1 134 | fi 135 | 136 | model_id=$train_id.$valid_id.$train_config_id 137 | model_dir=exp/diarize/model/$model_id 138 | if [ $stage -le 3 ]; then 139 | echo "training model at $model_dir." 140 | if [ -d $model_dir ]; then 141 | echo "$model_dir already exists. " 142 | echo " if you want to retry, please remove it." 143 | exit 1 144 | fi 145 | work=$model_dir/.work 146 | mkdir -p $work 147 | $train_cmd $work/train.log \ 148 | train.py \ 149 | -c $train_config \ 150 | $train_args \ 151 | --initmodel $model_2spk_dir/$ave_id.nnet.npz \ 152 | $train_set $valid_set $model_dir \ 153 | || exit 1 154 | fi 155 | 156 | ave_id=avg${average_start}-${average_end} 157 | if [ $stage -le 4 ]; then 158 | echo "averaging model parameters into $model_dir/$ave_id.nnet.npz" 159 | if [ -s $model_dir/$ave_id.nnet.npz ]; then 160 | echo "$model_dir/$ave_id.nnet.npz already exists. " 161 | echo " if you want to retry, please remove it." 162 | exit 1 163 | fi 164 | models=`eval echo $model_dir/snapshot_epoch-{$average_start..$average_end}` 165 | model_averaging.py $model_dir/$ave_id.nnet.npz $models || exit 1 166 | fi 167 | 168 | adapt_model_dir=exp/diarize/model/$model_id.$ave_id.$adapt_config_id 169 | if [ $stage -le 5 ]; then 170 | echo "adapting model at $adapt_model_dir" 171 | if [ -d $adapt_model_dir ]; then 172 | echo "$adapt_model_dir already exists. " 173 | echo " if you want to retry, please remove it." 174 | exit 1 175 | fi 176 | work=$adapt_model_dir/.work 177 | mkdir -p $work 178 | $train_cmd $work/train.log \ 179 | train.py \ 180 | -c $adapt_config \ 181 | $adapt_args \ 182 | --initmodel $model_dir/$ave_id.nnet.npz \ 183 | $adapt_set $adapt_valid_set $adapt_model_dir \ 184 | || exit 1 185 | fi 186 | 187 | adapt_ave_id=avg${adapt_average_start}-${adapt_average_end} 188 | if [ $stage -le 6 ]; then 189 | echo "averaging models into $adapt_model_dir/$adapt_ave_id.nnet.gz" 190 | if [ -s $adapt_model_dir/$adapt_ave_id.nnet.npz ]; then 191 | echo "$adapt_model_dir/$adapt_ave_id.nnet.npz already exists." 192 | echo " if you want to retry, please remove it." 193 | exit 1 194 | fi 195 | models=`eval echo $adapt_model_dir/snapshot_epoch-{$adapt_average_start..$adapt_average_end}` 196 | model_averaging.py $adapt_model_dir/$adapt_ave_id.nnet.npz $models || exit 1 197 | fi 198 | 199 | infer_dir=exp/diarize/infer/$model_id.$ave_id.$adapt_config_id.$adapt_ave_id.$infer_config_id 200 | if [ $stage -le 7 ]; then 201 | echo "inference at $infer_dir" 202 | if [ -d $infer_dir ]; then 203 | echo "$infer_dir already exists. " 204 | echo " if you want to retry, please remove it." 205 | exit 1 206 | fi 207 | for dset in callhome2_spkall; do 208 | work=$infer_dir/$dset/.work 209 | mkdir -p $work 210 | $train_cmd $work/infer.log \ 211 | infer.py -c $infer_config \ 212 | data/eval/${dset} \ 213 | $adapt_model_dir/$adapt_ave_id.nnet.npz \ 214 | $infer_dir/$dset \ 215 | || exit 1 216 | done 217 | fi 218 | 219 | scoring_dir=exp/diarize/scoring/$model_id.$ave_id.$adapt_config_id.$adapt_ave_id.$infer_config_id 220 | if [ $stage -le 8 ]; then 221 | echo "scoring at $scoring_dir" 222 | if [ -d $scoring_dir ]; then 223 | echo "$scoring_dir already exists. " 224 | echo " if you want to retry, please remove it." 
225 | exit 1 226 | fi 227 | for dset in callhome2_spkall; do 228 | work=$scoring_dir/$dset/.work 229 | mkdir -p $work 230 | find $infer_dir/$dset -iname "*.h5" > $work/file_list_$dset 231 | for med in 1 11; do 232 | for th in 0.3 0.4 0.5 0.6 0.7; do 233 | make_rttm.py --median=$med --threshold=$th \ 234 | --frame_shift=$infer_frame_shift --subsampling=$infer_subsampling --sampling_rate=$infer_sampling_rate \ 235 | $work/file_list_$dset $scoring_dir/$dset/hyp_${th}_$med.rttm 236 | md-eval.pl -c 0.25 \ 237 | -r data/eval/$dset/rttm \ 238 | -s $scoring_dir/$dset/hyp_${th}_$med.rttm > $scoring_dir/$dset/result_th${th}_med${med}_collar0.25 2>/dev/null || exit 239 | done 240 | done 241 | done 242 | fi 243 | 244 | if [ $stage -le 9 ]; then 245 | for dset in callhome2_spkall; do 246 | best_score.sh $scoring_dir/$dset 247 | done 248 | fi 249 | echo "Finished !" 250 | -------------------------------------------------------------------------------- /egs/callhome/v1/run_prepare_shared.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | # This script prepares kaldi-style data sets shared with different experiments 7 | # - data/xxxx 8 | # callhome, sre, swb2, and swb_cellular datasets 9 | # - data/simu_${simu_outputs} 10 | # simulation mixtures generated with various options 11 | 12 | stage=0 13 | 14 | # Modify corpus directories 15 | # - callhome_dir 16 | # CALLHOME (LDC2001S97) 17 | # - swb2_phase1_train 18 | # Switchboard-2 Phase 1 (LDC98S75) 19 | # - data_root 20 | # LDC99S79, LDC2002S06, LDC2001S13, LDC2004S07, 21 | # LDC2006S44, LDC2011S01, LDC2011S04, LDC2011S09, 22 | # LDC2011S10, LDC2012S01, LDC2011S05, LDC2011S08 23 | # - musan_root 24 | # MUSAN corpus (https://www.openslr.org/17/) 25 | callhome_dir=/export/corpora/NIST/LDC2001S97 26 | swb2_phase1_train=/export/corpora/LDC/LDC98S75 27 | data_root=/export/corpora5/LDC 28 | musan_root=/export/corpora/JHU/musan 29 | # Modify simulated data storage area. 30 | # This script distributes simulated data under these directories 31 | simu_actual_dirs=( 32 | /export/c05/$USER/diarization-data 33 | /export/c08/$USER/diarization-data 34 | /export/c09/$USER/diarization-data 35 | ) 36 | 37 | # data preparation options 38 | max_jobs_run=4 39 | sad_num_jobs=30 40 | sad_opts="--extra-left-context 79 --extra-right-context 21 --frames-per-chunk 150 --extra-left-context-initial 0 --extra-right-context-final 0 --acwt 0.3" 41 | sad_graph_opts="--min-silence-duration=0.03 --min-speech-duration=0.3 --max-speech-duration=10.0" 42 | sad_priors_opts="--sil-scale=0.1" 43 | 44 | # simulation options 45 | simu_opts_overlap=yes 46 | simu_opts_num_speaker=2 47 | simu_opts_sil_scale=2 48 | simu_opts_rvb_prob=0.5 49 | simu_opts_num_train=100000 50 | simu_opts_min_utts=10 51 | simu_opts_max_utts=20 52 | 53 | . path.sh 54 | . cmd.sh 55 | . parse_options.sh || exit 56 | 57 | if [ $stage -le 0 ]; then 58 | echo "prepare kaldi-style datasets" 59 | # Prepare CALLHOME dataset. This will be used to evaluation. 60 | if ! validate_data_dir.sh --no-text --no-feats data/callhome1_spk2 \ 61 | || ! 
validate_data_dir.sh --no-text --no-feats data/callhome2_spk2; then 62 | # imported from https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v1 63 | local/make_callhome.sh $callhome_dir data 64 | # Generate two-speaker subsets 65 | for dset in callhome1 callhome2; do 66 | # Extract two-speaker recordings in wav.scp 67 | copy_data_dir.sh data/${dset} data/${dset}_spk2 68 | utils/filter_scp.pl <(awk '{if($2==2) print;}' data/${dset}/reco2num_spk) \ 69 | data/${dset}/wav.scp > data/${dset}_spk2/wav.scp 70 | # Regenerate segments file from fullref.rttm 71 | # $2: recid, $4: start_time, $5: duration, $8: speakerid 72 | awk '{printf "%s_%s_%07d_%07d %s %.2f %.2f\n", \ 73 | $2, $8, $4*100, ($4+$5)*100, $2, $4, $4+$5}' \ 74 | data/callhome/fullref.rttm | sort > data/${dset}_spk2/segments 75 | utils/fix_data_dir.sh data/${dset}_spk2 76 | # Speaker ID is '[recid]_[speakerid] 77 | awk '{split($1,A,"_"); printf "%s %s_%s\n", $1, A[1], A[2]}' \ 78 | data/${dset}_spk2/segments > data/${dset}_spk2/utt2spk 79 | utils/fix_data_dir.sh data/${dset}_spk2 80 | # Generate rttm files for scoring 81 | steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ 82 | data/${dset}_spk2/utt2spk data/${dset}_spk2/segments \ 83 | data/${dset}_spk2/rttm 84 | utils/data/get_reco2dur.sh data/${dset}_spk2 85 | done 86 | fi 87 | # Prepare a collection of NIST SRE and SWB data. This will be used to train, 88 | if ! validate_data_dir.sh --no-text --no-feats data/swb_sre_comb; then 89 | local/make_sre.sh $data_root data 90 | # Prepare SWB for x-vector DNN training. 91 | local/make_swbd2_phase1.pl $swb2_phase1_train \ 92 | data/swbd2_phase1_train 93 | local/make_swbd2_phase2.pl $data_root/LDC99S79 \ 94 | data/swbd2_phase2_train 95 | local/make_swbd2_phase3.pl $data_root/LDC2002S06 \ 96 | data/swbd2_phase3_train 97 | local/make_swbd_cellular1.pl $data_root/LDC2001S13 \ 98 | data/swbd_cellular1_train 99 | local/make_swbd_cellular2.pl $data_root/LDC2004S07 \ 100 | data/swbd_cellular2_train 101 | # Combine swb and sre data 102 | utils/combine_data.sh data/swb_sre_comb \ 103 | data/swbd_cellular1_train data/swbd_cellular2_train \ 104 | data/swbd2_phase1_train \ 105 | data/swbd2_phase2_train data/swbd2_phase3_train data/sre 106 | fi 107 | # musan data. "back-ground 108 | if ! validate_data_dir.sh --no-text --no-feats data/musan_noise_bg; then 109 | local/make_musan.sh $musan_root data 110 | utils/copy_data_dir.sh data/musan_noise data/musan_noise_bg 111 | awk '{if(NR>1) print $1,$1}' $musan_root/noise/free-sound/ANNOTATIONS > data/musan_noise_bg/utt2spk 112 | utils/fix_data_dir.sh data/musan_noise_bg 113 | fi 114 | # simu rirs 8k 115 | if ! validate_data_dir.sh --no-text --no-feats data/simu_rirs_8k; then 116 | mkdir -p data/simu_rirs_8k 117 | if [ ! -e sim_rir_8k.zip ]; then 118 | wget --no-check-certificate http://www.openslr.org/resources/26/sim_rir_8k.zip 119 | fi 120 | unzip sim_rir_8k.zip -d data/sim_rir_8k 121 | find $PWD/data/sim_rir_8k -iname "*.wav" \ 122 | | awk '{n=split($1,A,/[\/\.]/); print A[n-3]"_"A[n-1], $1}' \ 123 | | sort > data/simu_rirs_8k/wav.scp 124 | awk '{print $1, $1}' data/simu_rirs_8k/wav.scp > data/simu_rirs_8k/utt2spk 125 | utils/fix_data_dir.sh data/simu_rirs_8k 126 | fi 127 | # Automatic segmentation using pretrained SAD model 128 | # it will take one day using 30 CPU jobs: 129 | # make_mfcc: 1 hour, compute_output: 18 hours, decode: 0.5 hours 130 | sad_nnet_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a 131 | sad_work_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a 132 | if ! 
validate_data_dir.sh --no-text $sad_work_dir/swb_sre_comb_seg; then 133 | if [ ! -d exp/segmentation_1a ]; then 134 | wget http://kaldi-asr.org/models/4/0004_tdnn_stats_asr_sad_1a.tar.gz 135 | tar zxf 0004_tdnn_stats_asr_sad_1a.tar.gz 136 | fi 137 | steps/segmentation/detect_speech_activity.sh \ 138 | --nj $sad_num_jobs \ 139 | --graph-opts "$sad_graph_opts" \ 140 | --transform-probs-opts "$sad_priors_opts" $sad_opts \ 141 | data/swb_sre_comb $sad_nnet_dir mfcc_hires $sad_work_dir \ 142 | $sad_work_dir/swb_sre_comb || exit 1 143 | fi 144 | # Extract >1.5 sec segments and split into train/valid sets 145 | if ! validate_data_dir.sh --no-text --no-feats data/swb_sre_cv; then 146 | copy_data_dir.sh data/swb_sre_comb data/swb_sre_comb_seg 147 | awk '$4-$3>1.5{print;}' $sad_work_dir/swb_sre_comb_seg/segments > data/swb_sre_comb_seg/segments 148 | cp $sad_work_dir/swb_sre_comb_seg/{utt2spk,spk2utt} data/swb_sre_comb_seg 149 | fix_data_dir.sh data/swb_sre_comb_seg 150 | utils/subset_data_dir_tr_cv.sh data/swb_sre_comb_seg data/swb_sre_tr data/swb_sre_cv 151 | fi 152 | fi 153 | 154 | simudir=data/simu 155 | if [ $stage -le 1 ]; then 156 | echo "simulation of mixture" 157 | mkdir -p $simudir/.work 158 | random_mixture_cmd=random_mixture_nooverlap.py 159 | make_mixture_cmd=make_mixture_nooverlap.py 160 | if [ "$simu_opts_overlap" == "yes" ]; then 161 | random_mixture_cmd=random_mixture.py 162 | make_mixture_cmd=make_mixture.py 163 | fi 164 | 165 | for simu_opts_sil_scale in 2; do 166 | for dset in swb_sre_tr swb_sre_cv; do 167 | if [ "$dset" == "swb_sre_tr" ]; then 168 | n_mixtures=${simu_opts_num_train} 169 | else 170 | n_mixtures=500 171 | fi 172 | simuid=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures} 173 | # check if you have the simulation 174 | if ! validate_data_dir.sh --no-text --no-feats $simudir/data/$simuid; then 175 | # random mixture generation 176 | $train_cmd $simudir/.work/random_mixture_$simuid.log \ 177 | $random_mixture_cmd --n_speakers $simu_opts_num_speaker --n_mixtures $n_mixtures \ 178 | --speech_rvb_probability $simu_opts_rvb_prob \ 179 | --sil_scale $simu_opts_sil_scale \ 180 | data/$dset data/musan_noise_bg data/simu_rirs_8k \ 181 | \> $simudir/.work/mixture_$simuid.scp 182 | nj=100 183 | mkdir -p $simudir/wav/$simuid 184 | # distribute simulated data to $simu_actual_dir 185 | split_scps= 186 | for n in $(seq $nj); do 187 | split_scps="$split_scps $simudir/.work/mixture_$simuid.$n.scp" 188 | mkdir -p $simudir/.work/data_$simuid.$n 189 | actual=${simu_actual_dirs[($n-1)%${#simu_actual_dirs[@]}]}/$simudir/wav/$simuid/$n 190 | mkdir -p $actual 191 | ln -nfs $actual $simudir/wav/$simuid/$n 192 | done 193 | utils/split_scp.pl $simudir/.work/mixture_$simuid.scp $split_scps || exit 1 194 | 195 | $simu_cmd --max-jobs-run 32 JOB=1:$nj $simudir/.work/make_mixture_$simuid.JOB.log \ 196 | $make_mixture_cmd --rate=8000 \ 197 | $simudir/.work/mixture_$simuid.JOB.scp \ 198 | $simudir/.work/data_$simuid.JOB $simudir/wav/$simuid/JOB 199 | utils/combine_data.sh $simudir/data/$simuid $simudir/.work/data_$simuid.* 200 | steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ 201 | $simudir/data/$simuid/utt2spk $simudir/data/$simuid/segments \ 202 | $simudir/data/$simuid/rttm 203 | utils/data/get_reco2dur.sh $simudir/data/$simuid 204 | fi 205 | done 206 | done 207 | fi 208 | 209 | if [ $stage -le 3 ]; then 210 | # compose eval/callhome2_spk2 211 | eval_set=data/eval/callhome2_spk2 212 | if ! 
validate_data_dir.sh --no-text --no-feats $eval_set; then 213 | utils/copy_data_dir.sh data/callhome2_spk2 $eval_set 214 | cp data/callhome2_spk2/rttm $eval_set/rttm 215 | awk -v dstdir=wav/eval/callhome2_spk2 '{print $1, dstdir"/"$1".wav"}' data/callhome2_spk2/wav.scp > $eval_set/wav.scp 216 | mkdir -p wav/eval/callhome2_spk2 217 | wav-copy scp:data/callhome2_spk2/wav.scp scp:$eval_set/wav.scp 218 | utils/data/get_reco2dur.sh $eval_set 219 | fi 220 | 221 | # compose eval/callhome1_spk2 222 | adapt_set=data/eval/callhome1_spk2 223 | if ! validate_data_dir.sh --no-text --no-feats $adapt_set; then 224 | utils/copy_data_dir.sh data/callhome1_spk2 $adapt_set 225 | cp data/callhome1_spk2/rttm $adapt_set/rttm 226 | awk -v dstdir=wav/eval/callhome1_spk2 '{print $1, dstdir"/"$1".wav"}' data/callhome1_spk2/wav.scp > $adapt_set/wav.scp 227 | mkdir -p wav/eval/callhome1_spk2 228 | wav-copy scp:data/callhome1_spk2/wav.scp scp:$adapt_set/wav.scp 229 | utils/data/get_reco2dur.sh $adapt_set 230 | fi 231 | fi 232 | -------------------------------------------------------------------------------- /egs/callhome/v1/run_prepare_shared_eda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita, Shota Horiguchi) 4 | # Licensed under the MIT license. 5 | # 6 | # This script prepares kaldi-style data sets shared with different experiments 7 | # - data/xxxx 8 | # callhome, sre, swb2, and swb_cellular datasets 9 | # - data/simu_${simu_outputs} 10 | # simulation mixtures generated with various options 11 | 12 | stage=0 13 | 14 | # Modify corpus directories 15 | # - callhome_dir 16 | # CALLHOME (LDC2001S97) 17 | # - swb2_phase1_train 18 | # Switchboard-2 Phase 1 (LDC98S75) 19 | # - data_root 20 | # LDC99S79, LDC2002S06, LDC2001S13, LDC2004S07, 21 | # LDC2006S44, LDC2011S01, LDC2011S04, LDC2011S09, 22 | # LDC2011S10, LDC2012S01, LDC2011S05, LDC2011S08 23 | # - musan_root 24 | # MUSAN corpus (https://www.openslr.org/17/) 25 | callhome_dir=/export/corpora/NIST/LDC2001S97 26 | swb2_phase1_train=/export/corpora/LDC/LDC98S75 27 | data_root=/export/corpora5/LDC 28 | musan_root=/export/corpora/JHU/musan 29 | # Modify simulated data storage area. 30 | # This script distributes simulated data under these directories 31 | simu_actual_dirs=( 32 | /export/c05/$USER/diarization-data 33 | /export/c08/$USER/diarization-data 34 | /export/c09/$USER/diarization-data 35 | ) 36 | 37 | # data preparation options 38 | max_jobs_run=4 39 | sad_num_jobs=30 40 | sad_opts="--extra-left-context 79 --extra-right-context 21 --frames-per-chunk 150 --extra-left-context-initial 0 --extra-right-context-final 0 --acwt 0.3" 41 | sad_graph_opts="--min-silence-duration=0.03 --min-speech-duration=0.3 --max-speech-duration=10.0" 42 | sad_priors_opts="--sil-scale=0.1" 43 | 44 | # simulation options 45 | simu_opts_overlap=yes 46 | simu_opts_num_speaker_array=(1 2 3 4) 47 | simu_opts_sil_scale_array=(2 2 5 9) 48 | simu_opts_rvb_prob=0.5 49 | simu_opts_num_train=100000 50 | simu_opts_min_utts=10 51 | simu_opts_max_utts=20 52 | 53 | . path.sh 54 | . cmd.sh 55 | . parse_options.sh || exit 56 | 57 | if [ $stage -le 0 ]; then 58 | echo "prepare kaldi-style datasets" 59 | # Prepare CALLHOME dataset. This will be used to evaluation. 60 | if ! validate_data_dir.sh --no-text --no-feats data/callhome1_spkall \ 61 | || ! 
validate_data_dir.sh --no-text --no-feats data/callhome2_spkall; then 62 | # imported from https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v1 63 | local/make_callhome.sh $callhome_dir data 64 | # Generate two-speaker subsets 65 | for dset in callhome1 callhome2; do 66 | # Extract two-speaker recordings in wav.scp 67 | copy_data_dir.sh data/${dset} data/${dset}_spkall 68 | # Regenerate segments file from fullref.rttm 69 | # $2: recid, $4: start_time, $5: duration, $8: speakerid 70 | awk '{printf "%s_%s_%07d_%07d %s %.2f %.2f\n", \ 71 | $2, $8, $4*100, ($4+$5)*100, $2, $4, $4+$5}' \ 72 | data/callhome/fullref.rttm | sort > data/${dset}_spkall/segments 73 | utils/fix_data_dir.sh data/${dset}_spkall 74 | # Speaker ID is '[recid]_[speakerid] 75 | awk '{split($1,A,"_"); printf "%s %s_%s\n", $1, A[1], A[2]}' \ 76 | data/${dset}_spkall/segments > data/${dset}_spkall/utt2spk 77 | utils/fix_data_dir.sh data/${dset}_spkall 78 | # Generate rttm files for scoring 79 | steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ 80 | data/${dset}_spkall/utt2spk data/${dset}_spkall/segments \ 81 | data/${dset}_spkall/rttm 82 | utils/data/get_reco2dur.sh data/${dset}_spkall 83 | done 84 | fi 85 | # Prepare a collection of NIST SRE and SWB data. This will be used to train, 86 | if ! validate_data_dir.sh --no-text --no-feats data/swb_sre_comb; then 87 | local/make_sre.sh $data_root data 88 | # Prepare SWB for x-vector DNN training. 89 | local/make_swbd2_phase1.pl $swb2_phase1_train \ 90 | data/swbd2_phase1_train 91 | local/make_swbd2_phase2.pl $data_root/LDC99S79 \ 92 | data/swbd2_phase2_train 93 | local/make_swbd2_phase3.pl $data_root/LDC2002S06 \ 94 | data/swbd2_phase3_train 95 | local/make_swbd_cellular1.pl $data_root/LDC2001S13 \ 96 | data/swbd_cellular1_train 97 | local/make_swbd_cellular2.pl $data_root/LDC2004S07 \ 98 | data/swbd_cellular2_train 99 | # Combine swb and sre data 100 | utils/combine_data.sh data/swb_sre_comb \ 101 | data/swbd_cellular1_train data/swbd_cellular2_train \ 102 | data/swbd2_phase1_train \ 103 | data/swbd2_phase2_train data/swbd2_phase3_train data/sre 104 | fi 105 | # musan data. "back-ground 106 | if ! validate_data_dir.sh --no-text --no-feats data/musan_noise_bg; then 107 | local/make_musan.sh $musan_root data 108 | utils/copy_data_dir.sh data/musan_noise data/musan_noise_bg 109 | awk '{if(NR>1) print $1,$1}' $musan_root/noise/free-sound/ANNOTATIONS > data/musan_noise_bg/utt2spk 110 | utils/fix_data_dir.sh data/musan_noise_bg 111 | fi 112 | # simu rirs 8k 113 | if ! validate_data_dir.sh --no-text --no-feats data/simu_rirs_8k; then 114 | mkdir -p data/simu_rirs_8k 115 | if [ ! -e sim_rir_8k.zip ]; then 116 | wget --no-check-certificate http://www.openslr.org/resources/26/sim_rir_8k.zip 117 | fi 118 | unzip sim_rir_8k.zip -d data/sim_rir_8k 119 | find $PWD/data/sim_rir_8k -iname "*.wav" \ 120 | | awk '{n=split($1,A,/[\/\.]/); print A[n-3]"_"A[n-1], $1}' \ 121 | | sort > data/simu_rirs_8k/wav.scp 122 | awk '{print $1, $1}' data/simu_rirs_8k/wav.scp > data/simu_rirs_8k/utt2spk 123 | utils/fix_data_dir.sh data/simu_rirs_8k 124 | fi 125 | # Automatic segmentation using pretrained SAD model 126 | # it will take one day using 30 CPU jobs: 127 | # make_mfcc: 1 hour, compute_output: 18 hours, decode: 0.5 hours 128 | sad_nnet_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a 129 | sad_work_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a 130 | if ! validate_data_dir.sh --no-text $sad_work_dir/swb_sre_comb_seg; then 131 | if [ ! 
-d exp/segmentation_1a ]; then 132 | wget http://kaldi-asr.org/models/4/0004_tdnn_stats_asr_sad_1a.tar.gz 133 | tar zxf 0004_tdnn_stats_asr_sad_1a.tar.gz 134 | fi 135 | steps/segmentation/detect_speech_activity.sh \ 136 | --nj $sad_num_jobs \ 137 | --graph-opts "$sad_graph_opts" \ 138 | --transform-probs-opts "$sad_priors_opts" $sad_opts \ 139 | data/swb_sre_comb $sad_nnet_dir mfcc_hires $sad_work_dir \ 140 | $sad_work_dir/swb_sre_comb || exit 1 141 | fi 142 | # Extract >1.5 sec segments and split into train/valid sets 143 | if ! validate_data_dir.sh --no-text --no-feats data/swb_sre_cv; then 144 | copy_data_dir.sh data/swb_sre_comb data/swb_sre_comb_seg 145 | awk '$4-$3>1.5{print;}' $sad_work_dir/swb_sre_comb_seg/segments > data/swb_sre_comb_seg/segments 146 | cp $sad_work_dir/swb_sre_comb_seg/{utt2spk,spk2utt} data/swb_sre_comb_seg 147 | fix_data_dir.sh data/swb_sre_comb_seg 148 | utils/subset_data_dir_tr_cv.sh data/swb_sre_comb_seg data/swb_sre_tr data/swb_sre_cv 149 | fi 150 | fi 151 | 152 | simudir=data/simu 153 | if [ $stage -le 1 ]; then 154 | echo "simulation of mixture" 155 | mkdir -p $simudir/.work 156 | random_mixture_cmd=random_mixture_nooverlap.py 157 | make_mixture_cmd=make_mixture_nooverlap.py 158 | if [ "$simu_opts_overlap" == "yes" ]; then 159 | random_mixture_cmd=random_mixture.py 160 | make_mixture_cmd=make_mixture.py 161 | fi 162 | 163 | for ((i=0; i<${#simu_opts_sil_scale_array[@]}; ++i)); do 164 | simu_opts_num_speaker=${simu_opts_num_speaker_array[i]} 165 | simu_opts_sil_scale=${simu_opts_sil_scale_array[i]} 166 | for dset in swb_sre_tr swb_sre_cv; do 167 | if [ "$dset" == "swb_sre_tr" ]; then 168 | n_mixtures=${simu_opts_num_train} 169 | else 170 | n_mixtures=500 171 | fi 172 | simuid=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures} 173 | # check if you have the simulation 174 | if ! 
validate_data_dir.sh --no-text --no-feats $simudir/data/$simuid; then 175 | # random mixture generation 176 | $train_cmd $simudir/.work/random_mixture_$simuid.log \ 177 | $random_mixture_cmd --n_speakers $simu_opts_num_speaker --n_mixtures $n_mixtures \ 178 | --speech_rvb_probability $simu_opts_rvb_prob \ 179 | --sil_scale $simu_opts_sil_scale \ 180 | data/$dset data/musan_noise_bg data/simu_rirs_8k \ 181 | \> $simudir/.work/mixture_$simuid.scp 182 | nj=100 183 | mkdir -p $simudir/wav/$simuid 184 | # distribute simulated data to $simu_actual_dir 185 | split_scps= 186 | for n in $(seq $nj); do 187 | split_scps="$split_scps $simudir/.work/mixture_$simuid.$n.scp" 188 | mkdir -p $simudir/.work/data_$simuid.$n 189 | actual=${simu_actual_dirs[($n-1)%${#simu_actual_dirs[@]}]}/$simudir/wav/$simuid/$n 190 | mkdir -p $actual 191 | ln -nfs $actual $simudir/wav/$simuid/$n 192 | done 193 | utils/split_scp.pl $simudir/.work/mixture_$simuid.scp $split_scps || exit 1 194 | 195 | $simu_cmd --max-jobs-run 32 JOB=1:$nj $simudir/.work/make_mixture_$simuid.JOB.log \ 196 | $make_mixture_cmd --rate=8000 \ 197 | $simudir/.work/mixture_$simuid.JOB.scp \ 198 | $simudir/.work/data_$simuid.JOB $simudir/wav/$simuid/JOB 199 | utils/combine_data.sh $simudir/data/$simuid $simudir/.work/data_$simuid.* 200 | steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ 201 | $simudir/data/$simuid/utt2spk $simudir/data/$simuid/segments \ 202 | $simudir/data/$simuid/rttm 203 | utils/data/get_reco2dur.sh $simudir/data/$simuid 204 | fi 205 | simuid_concat=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} 206 | mkdir -p $simudir/data/$simuid_concat 207 | for f in `ls -F $simudir/data/$simuid | grep -v "/"`; do 208 | cat $simudir/data/$simuid/$f >> $simudir/data/$simuid_concat/$f 209 | done 210 | done 211 | done 212 | fi 213 | 214 | if [ $stage -le 3 ]; then 215 | # compose eval/callhome2_spkall 216 | eval_set=data/eval/callhome2_spkall 217 | if ! validate_data_dir.sh --no-text --no-feats $eval_set; then 218 | utils/copy_data_dir.sh data/callhome2_spkall $eval_set 219 | cp data/callhome2_spkall/rttm $eval_set/rttm 220 | awk -v dstdir=wav/eval/callhome2_spkall '{print $1, dstdir"/"$1".wav"}' data/callhome2_spkall/wav.scp > $eval_set/wav.scp 221 | mkdir -p wav/eval/callhome2_spkall 222 | wav-copy scp:data/callhome2_spkall/wav.scp scp:$eval_set/wav.scp 223 | utils/data/get_reco2dur.sh $eval_set 224 | fi 225 | 226 | # compose eval/callhome1_spkall 227 | adapt_set=data/eval/callhome1_spkall 228 | if ! 
validate_data_dir.sh --no-text --no-feats $adapt_set; then 229 | utils/copy_data_dir.sh data/callhome1_spkall $adapt_set 230 | cp data/callhome1_spkall/rttm $adapt_set/rttm 231 | awk -v dstdir=wav/eval/callhome1_spkall '{print $1, dstdir"/"$1".wav"}' data/callhome1_spkall/wav.scp > $adapt_set/wav.scp 232 | mkdir -p wav/eval/callhome1_spkall 233 | wav-copy scp:data/callhome1_spkall/wav.scp scp:$adapt_set/wav.scp 234 | utils/data/get_reco2dur.sh $adapt_set 235 | fi 236 | fi 237 | -------------------------------------------------------------------------------- /egs/callhome/v1/steps: -------------------------------------------------------------------------------- 1 | ../../../tools/kaldi/egs/wsj/s5/steps -------------------------------------------------------------------------------- /egs/callhome/v1/utils: -------------------------------------------------------------------------------- 1 | ../../../tools/kaldi/egs/wsj/s5/utils -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/RESULT.md: -------------------------------------------------------------------------------- 1 | 2 | # Training curve 3 | GPU: Tesla K80 4 | 5 | ``` 6 | grep loss exp/diarize/model/train_clean_5_ns2_beta2_500.dev_clean_2_ns2_beta2_500.train/log 7 | "main/loss": 0.8094629645347595, 8 | "validation/main/loss": 0.7502496838569641, 9 | "main/loss": 0.6841621398925781, 10 | "validation/main/loss": 0.6442975997924805, 11 | "main/loss": 0.5633799433708191, 12 | "validation/main/loss": 0.5978456139564514, 13 | "main/loss": 0.5073038339614868, 14 | "validation/main/loss": 0.5673154592514038, 15 | "main/loss": 0.47916650772094727, 16 | "validation/main/loss": 0.5508813261985779, 17 | "main/loss": 0.46045243740081787, 18 | "validation/main/loss": 0.53536057472229, 19 | "main/loss": 0.44897904992103577, 20 | "validation/main/loss": 0.5264081358909607, 21 | "main/loss": 0.4393312335014343, 22 | "validation/main/loss": 0.520709216594696, 23 | "main/loss": 0.4310261905193329, 24 | "validation/main/loss": 0.510313093662262, 25 | "main/loss": 0.42343708872795105, 26 | "validation/main/loss": 0.5055857300758362, 27 | ``` 28 | 29 | # Final DER 30 | 31 | ``` 32 | exp/diarize/scoring/train_clean_5_ns2_beta2_500.dev_clean_2_ns2_beta2_500.train.avg8-10.infer/dev_clean_2_ns2_beta2_500/result_th0.7_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 29.96 percent of scored speaker time `(ALL) 33 | ``` 34 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/cmd.sh: -------------------------------------------------------------------------------- 1 | # Modify this file according to a job scheduling system in your cluster. 2 | # For more information about cmd.sh see http://kaldi-asr.org/doc/queue.html. 3 | # 4 | # If you use your local machine, use "run.pl". 5 | # export train_cmd="run.pl" 6 | # export infer_cmd="run.pl" 7 | # export simu_cmd="run.pl" 8 | 9 | # If you use Grid Engine, use "queue.pl" 10 | export train_cmd="queue.pl --mem 16G -l 'hostname=c*'" 11 | export infer_cmd="queue.pl --mem 16G -l 'hostname=c*'" 12 | export simu_cmd="queue.pl" 13 | 14 | # If you use SLURM, use "slurm.pl". 
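# (slurm.pl accepts the same queue.pl-style options as shown above, e.g. --mem 16G.
#  These *_cmd variables are what run.sh and the preparation scripts use to dispatch
#  every job, e.g. "$train_cmd $work/train.log train.py ...", and run.sh appends
#  "--gpu 1" to train_cmd/infer_cmd for GPU reservation depending on the gpu entry
#  in the train/infer configs.)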
15 | # export train_cmd="slurm.pl" 16 | # export infer_cmd="slurm.pl" 17 | # export simu_cmd="slurm.pl" 18 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/blstm/infer.yaml: -------------------------------------------------------------------------------- 1 | # inference options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: BLSTM 6 | num_lstm_layers: 5 7 | hidden_size: 256 8 | num_speakers: 2 9 | input_transform: logmel23_mn 10 | context_size: 7 11 | subsampling: 10 12 | chunk_size: 4000 13 | gpu: 0 14 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/blstm/train.yaml: -------------------------------------------------------------------------------- 1 | # training options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: BLSTM 6 | max_epochs: 10 7 | gradclip: 5 8 | batchsize: 10 9 | hidden_size: 256 10 | num_lstm_layers: 5 11 | lr: 0.001 12 | num_frames: 4000 13 | num_speakers: 2 14 | input_transform: logmel23_mn 15 | optimizer: adam 16 | context_size: 7 17 | subsampling: 10 18 | seed: 777 19 | gpu: 0 20 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/eda/infer.yaml: -------------------------------------------------------------------------------- 1 | # inference options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | hidden_size: 256 7 | num_speakers: 2 8 | input_transform: logmel23_mn 9 | context_size: 7 10 | subsampling: 10 11 | chunk_size: 2000 12 | transformer_encoder_n_heads: 4 13 | transformer_encoder_n_layers: 2 14 | use_attractor: True 15 | shuffle: True 16 | attractor_encoder_dropout: 0.1 17 | attractor_decoder_dropout: 0.1 18 | gpu: 0 19 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/eda/train.yaml: -------------------------------------------------------------------------------- 1 | # training options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | max_epochs: 10 7 | gradclip: 5 8 | batchsize: 64 9 | hidden_size: 256 10 | num_frames: 500 11 | num_speakers: 2 12 | input_transform: logmel23_mn 13 | optimizer: noam 14 | context_size: 7 15 | subsampling: 10 16 | noam_scale: 1.0 17 | gradient_accumulation_steps: 1 18 | transformer_encoder_n_heads: 4 19 | transformer_encoder_n_layers: 2 20 | transformer_encoder_dropout: 0.1 21 | noam_warmup_steps: 25000 22 | use_attractor: True 23 | shuffle: True 24 | attractor_loss_ratio: 1.0 25 | attractor_encoder_dropout: 0.1 26 | attractor_decoder_dropout: 0.1 27 | seed: 777 28 | gpu: 0 29 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/infer.yaml: -------------------------------------------------------------------------------- 1 | # inference options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | hidden_size: 256 7 | num_speakers: 2 8 | input_transform: logmel23_mn 9 | context_size: 7 10 | subsampling: 10 11 | chunk_size: 2000 12 | transformer_encoder_n_heads: 4 13 | transformer_encoder_n_layers: 2 14 | gpu: 0 15 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/mfcc.conf: -------------------------------------------------------------------------------- 1 | 
--sample-frequency=8000 2 | --frame-length=25 # the default is 25 3 | --low-freq=20 # the default. 4 | --high-freq=3700 # the default is zero meaning use the Nyquist (4k in this case). 5 | --num-ceps=23 # higher than the default which is 12. 6 | --snip-edges=false 7 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/mfcc_hires.conf: -------------------------------------------------------------------------------- 1 | # config for high-resolution MFCC features, intended for neural network training. 2 | # Note: we keep all cepstra, so it has the same info as filterbank features, 3 | # but MFCC is more easily compressible (because less correlated) which is why 4 | # we prefer this method. 5 | --use-energy=false # use average of log energy, not energy. 6 | --sample-frequency=8000 # Switchboard is sampled at 8kHz 7 | --num-mel-bins=40 # similar to Google's setup. 8 | --num-ceps=40 # there is no dimensionality reduction. 9 | --low-freq=40 # low cutoff frequency for mel bins 10 | --high-freq=-200 # high cutoff frequently, relative to Nyquist of 4000 (=3800) 11 | --allow-downsample=true 12 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/train.yaml: -------------------------------------------------------------------------------- 1 | # training options 2 | sampling_rate: 8000 3 | frame_size: 200 4 | frame_shift: 80 5 | model_type: Transformer 6 | max_epochs: 10 7 | gradclip: 5 8 | batchsize: 64 9 | hidden_size: 256 10 | num_frames: 500 11 | num_speakers: 2 12 | input_transform: logmel23_mn 13 | optimizer: noam 14 | context_size: 7 15 | subsampling: 10 16 | noam_scale: 1.0 17 | gradient_accumulation_steps: 1 18 | transformer_encoder_n_heads: 4 19 | transformer_encoder_n_layers: 2 20 | transformer_encoder_dropout: 0.1 21 | noam_warmup_steps: 25000 22 | seed: 777 23 | gpu: 0 24 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/local/data_prep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2014 Vassil Panayotov 4 | # 2014 Johns Hopkins University (author: Daniel Povey) 5 | # Apache 2.0 6 | 7 | if [ "$#" -ne 2 ]; then 8 | echo "Usage: $0 " 9 | echo "e.g.: $0 /export/a15/vpanayotov/data/LibriSpeech/dev-clean data/dev-clean" 10 | exit 1 11 | fi 12 | 13 | src=$1 14 | dst=$2 15 | 16 | # all utterances are FLAC compressed 17 | if ! which flac >&/dev/null; then 18 | echo "Please install 'flac' on ALL worker nodes!" 19 | exit 1 20 | fi 21 | 22 | spk_file=$src/../SPEAKERS.TXT 23 | 24 | mkdir -p $dst || exit 1; 25 | 26 | [ ! -d $src ] && echo "$0: no such directory $src" && exit 1; 27 | [ ! -f $spk_file ] && echo "$0: expected file $spk_file to exist" && exit 1; 28 | 29 | 30 | wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm $wav_scp 31 | trans=$dst/text; [[ -f "$trans" ]] && rm $trans 32 | utt2spk=$dst/utt2spk; [[ -f "$utt2spk" ]] && rm $utt2spk 33 | spk2gender=$dst/spk2gender; [[ -f $spk2gender ]] && rm $spk2gender 34 | utt2dur=$dst/utt2dur; [[ -f "$utt2dur" ]] && rm $utt2dur 35 | 36 | for reader_dir in $(find -L $src -mindepth 1 -maxdepth 1 -type d | sort); do 37 | reader=$(basename $reader_dir) 38 | if ! [ $reader -eq $reader ]; then # not integer. 
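# ("test -eq" succeeds only for integer operands, so any reader directory whose
#  name is not purely numeric fails this check and aborts the preparation)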
39 | echo "$0: unexpected subdirectory name $reader" 40 | exit 1; 41 | fi 42 | 43 | reader_gender=$(egrep "^$reader[ ]+\|" $spk_file | awk -F'|' '{gsub(/[ ]+/, ""); print tolower($2)}') 44 | if [ "$reader_gender" != 'm' ] && [ "$reader_gender" != 'f' ]; then 45 | echo "Unexpected gender: '$reader_gender'" 46 | exit 1; 47 | fi 48 | 49 | for chapter_dir in $(find -L $reader_dir/ -mindepth 1 -maxdepth 1 -type d | sort); do 50 | chapter=$(basename $chapter_dir) 51 | if ! [ "$chapter" -eq "$chapter" ]; then 52 | echo "$0: unexpected chapter-subdirectory name $chapter" 53 | exit 1; 54 | fi 55 | 56 | find -L $chapter_dir/ -iname "*.flac" | sort | xargs -I% basename % .flac | \ 57 | awk -v "dir=$chapter_dir" '{printf "%s flac -c -d -s %s/%s.flac |\n", $0, dir, $0}' >>$wav_scp|| exit 1 58 | 59 | chapter_trans=$chapter_dir/${reader}-${chapter}.trans.txt 60 | [ ! -f $chapter_trans ] && echo "$0: expected file $chapter_trans to exist" && exit 1 61 | cat $chapter_trans >>$trans 62 | 63 | # NOTE: For now we are using per-chapter utt2spk. That is each chapter is considered 64 | # to be a different speaker. This is done for simplicity and because we want 65 | # e.g. the CMVN to be calculated per-chapter 66 | awk -v "reader=$reader" -v "chapter=$chapter" '{printf "%s %s-%s\n", $1, reader, chapter}' \ 67 | <$chapter_trans >>$utt2spk || exit 1 68 | 69 | # reader -> gender map (again using per-chapter granularity) 70 | echo "${reader}-${chapter} $reader_gender" >>$spk2gender 71 | done 72 | done 73 | 74 | spk2utt=$dst/spk2utt 75 | utils/utt2spk_to_spk2utt.pl <$utt2spk >$spk2utt || exit 1 76 | 77 | ntrans=$(wc -l <$trans) 78 | nutt2spk=$(wc -l <$utt2spk) 79 | ! [ "$ntrans" -eq "$nutt2spk" ] && \ 80 | echo "Inconsistent #transcripts($ntrans) and #utt2spk($nutt2spk)" && exit 1; 81 | 82 | utils/data/get_utt2dur.sh $dst 1>&2 || exit 1 83 | 84 | utils/validate_data_dir.sh --no-feats $dst || exit 1; 85 | 86 | echo "$0: successfully prepared data in $dst" 87 | 88 | exit 0 89 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/local/download_and_untar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2014 Johns Hopkins University (author: Daniel Povey) 4 | # 2017 Luminar Technologies, Inc. (author: Daniel Galvez) 5 | # Apache 2.0 6 | 7 | remove_archive=false 8 | 9 | if [ "$1" == --remove-archive ]; then 10 | remove_archive=true 11 | shift 12 | fi 13 | 14 | if [ $# -ne 3 ]; then 15 | echo "Usage: $0 [--remove-archive] " 16 | echo "e.g.: $0 /export/a05/dgalvez/ www.openslr.org/resources/31 dev-clean-2" 17 | echo "With --remove-archive it will remove the archive after successfully un-tarring it." 18 | echo " can be one of: dev-clean-2, test-clean-5, dev-other, test-other," 19 | echo " train-clean-100, train-clean-360, train-other-500." 20 | fi 21 | 22 | data=$1 23 | url=$2 24 | part=$3 25 | 26 | if [ ! -d "$data" ]; then 27 | echo "$0: no such directory $data" 28 | exit 1; 29 | fi 30 | 31 | data=$(readlink -f $data) 32 | 33 | part_ok=false 34 | list="dev-clean-2 train-clean-5" 35 | for x in $list; do 36 | if [ "$part" == $x ]; then part_ok=true; fi 37 | done 38 | if ! $part_ok; then 39 | echo "$0: expected to be one of $list, but got '$part'" 40 | exit 1; 41 | fi 42 | 43 | if [ -z "$url" ]; then 44 | echo "$0: empty URL base." 45 | exit 1; 46 | fi 47 | 48 | if [ -f $data/LibriSpeech/$part/.complete ]; then 49 | echo "$0: data part $part was already successfully extracted, nothing to do." 
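# (the .complete marker is created only after a successful download and un-tar
#  further below, so it is safe to skip the whole procedure here)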
50 | exit 0; 51 | fi 52 | 53 | 54 | #sizes="126046265 332747356" 55 | sizes="126046265 332954390" 56 | 57 | if [ -f $data/$part.tar.gz ]; then 58 | size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}') 59 | size_ok=false 60 | for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done 61 | if ! $size_ok; then 62 | echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size" 63 | echo "does not equal the size of one of the archives." 64 | rm $data/$part.tar.gz 65 | else 66 | echo "$data/$part.tar.gz exists and appears to be complete." 67 | fi 68 | fi 69 | 70 | if [ ! -f $data/$part.tar.gz ]; then 71 | if ! which wget >/dev/null; then 72 | echo "$0: wget is not installed." 73 | exit 1; 74 | fi 75 | full_url=$url/$part.tar.gz 76 | echo "$0: downloading data from $full_url. This may take some time, please be patient." 77 | 78 | cd $data 79 | if ! wget --no-check-certificate $full_url; then 80 | echo "$0: error executing wget $full_url" 81 | exit 1; 82 | fi 83 | cd - 84 | fi 85 | 86 | cd $data 87 | 88 | if ! tar -xvzf $part.tar.gz; then 89 | echo "$0: error un-tarring archive $data/$part.tar.gz" 90 | exit 1; 91 | fi 92 | 93 | touch $data/LibriSpeech/$part/.complete 94 | 95 | echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz" 96 | 97 | if $remove_archive; then 98 | echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied." 99 | rm $data/$part.tar.gz 100 | fi 101 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/local/run_blstm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | # BLSTM-based model experiment 7 | ./run.sh --train-config conf/blstm/train.yaml --infer-config conf/blstm/infer.yaml $* 8 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/musan_bgnoise.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitachi-speech/EEND/b851eecd8d7a966487ed3e4ff934a1581a73cc9e/egs/mini_librispeech/v1/musan_bgnoise.tar.gz -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/path.sh: -------------------------------------------------------------------------------- 1 | . ../../../tools/env.sh 2 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | stage=0 7 | 8 | # The datasets for training must be formatted as kaldi data directory. 9 | # Also, make sure the audio files in wav.scp are 'regular' wav files. 10 | # Including piped commands in wav.scp makes training very slow 11 | train_set=data/simu/data/train_clean_5_ns2_beta2_500 12 | valid_set=data/simu/data/dev_clean_2_ns2_beta2_500 13 | 14 | # Base config files for {train,infer}.py 15 | train_config=conf/train.yaml 16 | infer_config=conf/infer.yaml 17 | # If you want to use EDA-EEND, uncommend two lines below. 18 | # train_config=conf/eda/train.yaml 19 | # infer_config=conf/eda/infer.yaml 20 | 21 | # Additional arguments passed to {train,infer}.py. 
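# These are appended after "-c $train_config" / "-c $infer_config" on the
# train.py/infer.py command lines, and any non-empty value is also folded into
# the experiment directory name (see train_config_id / infer_config_id below).
# Illustrative example (the exact flag name depends on train.py's options;
# see train.py --help):
#   ./run.sh --train-args "--max-epochs 20"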
22 | # You need not edit the base config files above 23 | train_args= 24 | infer_args= 25 | 26 | # Model averaging options 27 | average_start=8 28 | average_end=10 29 | 30 | # Resume training from snapshot at this epoch 31 | # TODO: not tested 32 | resume=-1 33 | 34 | . path.sh 35 | . cmd.sh 36 | . parse_options.sh || exit 37 | 38 | set -eu 39 | 40 | # Parse the config file to set bash variables like: $train_frame_shift, $infer_gpu 41 | eval `yaml2bash.py --prefix train $train_config` 42 | eval `yaml2bash.py --prefix infer $infer_config` 43 | 44 | # Append gpu reservation flag to the queuing command 45 | if [ $train_gpu -le 0 ]; then 46 | train_cmd+=" --gpu 1" 47 | fi 48 | if [ $infer_gpu -le 0 ]; then 49 | infer_cmd+=" --gpu 1" 50 | fi 51 | 52 | # Build directry names for an experiment 53 | # - Training 54 | # exp/diarize/model/{train_id}.{valid_id}.{train_config_id} 55 | # - Decoding 56 | # exp/diarize/infer/{train_id}.{valid_id}.{train_config_id}.{infer_config_id} 57 | # - Scoring 58 | # exp/diarize/scoring/{train_id}.{valid_id}.{train_config_id}.{infer_config_id} 59 | # - Adapation from non-adapted averaged model 60 | # exp/diarize/model/{train_id}.{valid_id}.{train_config_id}.{avgid}.{adapt_config_id} 61 | train_id=$(basename $train_set) 62 | valid_id=$(basename $valid_set) 63 | train_config_id=$(echo $train_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 64 | infer_config_id=$(echo $infer_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 65 | 66 | # Additional arguments are added to config_id 67 | train_config_id+=$(echo $train_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 68 | infer_config_id+=$(echo $infer_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 69 | 70 | model_id=$train_id.$valid_id.$train_config_id 71 | model_dir=exp/diarize/model/$model_id 72 | if [ $stage -le 1 ]; then 73 | echo "training model at $model_dir." 74 | if [ -d $model_dir ]; then 75 | echo "$model_dir already exists. " 76 | echo " if you want to retry, please remove it." 77 | exit 1 78 | fi 79 | work=$model_dir/.work 80 | mkdir -p $work 81 | $train_cmd $work/train.log \ 82 | train.py \ 83 | -c $train_config \ 84 | $train_args \ 85 | $train_set $valid_set $model_dir \ 86 | || exit 1 87 | fi 88 | 89 | ave_id=avg${average_start}-${average_end} 90 | if [ $stage -le 2 ]; then 91 | echo "averaging model parameters into $model_dir/$ave_id.nnet.npz" 92 | if [ -s $model_dir/$ave_id.nnet.npz ]; then 93 | echo "$model_dir/$ave_id.nnet.npz already exists. " 94 | echo " if you want to retry, please remove it." 95 | exit 1 96 | fi 97 | models=`eval echo $model_dir/snapshot_epoch-{$average_start..$average_end}` 98 | model_averaging.py $model_dir/$ave_id.nnet.npz $models || exit 1 99 | fi 100 | 101 | infer_dir=exp/diarize/infer/$model_id.$ave_id.$infer_config_id 102 | if [ $stage -le 3 ]; then 103 | echo "inference at $infer_dir" 104 | if [ -d $infer_dir ]; then 105 | echo "$infer_dir already exists. " 106 | echo " if you want to retry, please remove it." 
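# (infer.py writes per-recording HDF5 files of frame-wise speaker-activity
#  posteriors under $infer_dir/$dset; the scoring stage below gathers them via
#  "find ... -iname *.h5", so results from an old run must not be mixed in)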
107 | exit 1 108 | fi 109 | for dset in dev_clean_2_ns2_beta2_500; do 110 | work=$infer_dir/$dset/.work 111 | mkdir -p $work 112 | $infer_cmd $work/infer.log \ 113 | infer.py \ 114 | -c $infer_config \ 115 | $infer_args \ 116 | data/simu/data/$dset \ 117 | $model_dir/$ave_id.nnet.npz \ 118 | $infer_dir/$dset \ 119 | || exit 1 120 | done 121 | fi 122 | 123 | scoring_dir=exp/diarize/scoring/$model_id.$ave_id.$infer_config_id 124 | if [ $stage -le 4 ]; then 125 | echo "scoring at $scoring_dir" 126 | if [ -d $scoring_dir ]; then 127 | echo "$scoring_dir already exists. " 128 | echo " if you want to retry, please remove it." 129 | exit 1 130 | fi 131 | for dset in dev_clean_2_ns2_beta2_500; do 132 | work=$scoring_dir/$dset/.work 133 | mkdir -p $work 134 | find $infer_dir/$dset -iname "*.h5" > $work/file_list_$dset 135 | for med in 1 11; do 136 | for th in 0.3 0.4 0.5 0.6 0.7; do 137 | make_rttm.py --median=$med --threshold=$th \ 138 | --frame_shift=$infer_frame_shift --subsampling=$infer_subsampling --sampling_rate=$infer_sampling_rate \ 139 | $work/file_list_$dset $scoring_dir/$dset/hyp_${th}_$med.rttm 140 | md-eval.pl -c 0.25 \ 141 | -r data/simu/data/$dset/rttm \ 142 | -s $scoring_dir/$dset/hyp_${th}_$med.rttm > $scoring_dir/$dset/result_th${th}_med${med}_collar0.25 2>/dev/null || exit 143 | done 144 | done 145 | done 146 | fi 147 | 148 | if [ $stage -le 5 ]; then 149 | for dset in dev_clean_2_ns2_beta2_500; do 150 | best_score.sh $scoring_dir/$dset 151 | done 152 | fi 153 | echo "Finished !" 154 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/run_prepare_shared.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 4 | # Licensed under the MIT license. 5 | # 6 | # This script prepares kaldi-style data sets shared with different experiments 7 | # - data/xxxx 8 | # callhome, sre, swb2, and swb_cellular datasets 9 | # - data/simu_${simu_outputs} 10 | # simulation mixtures generated with various options 11 | # This script does NOT include the composition of train/valid/test sets. 12 | # The composition will be done at stage 1 of ./run.sh 13 | 14 | stage=0 15 | 16 | # This script distributes simulated data under these directories 17 | simu_actual_dirs=( 18 | $PWD/data/local/diarization-data 19 | ) 20 | 21 | # simulation options 22 | simu_opts_overlap=yes 23 | simu_opts_num_speaker=2 24 | simu_opts_sil_scale=2 25 | simu_opts_rvb_prob=0.5 26 | simu_opts_num_train=100 27 | simu_opts_min_utts=10 28 | simu_opts_max_utts=20 29 | 30 | . path.sh 31 | . cmd.sh 32 | . parse_options.sh || exit 33 | 34 | if [ $stage -le 0 ]; then 35 | echo "prepare kaldi-style datasets" 36 | mini_librispeech_url=http://www.openslr.org/resources/31 37 | mkdir -p data/local 38 | local/download_and_untar.sh data/local $mini_librispeech_url dev-clean-2 39 | local/download_and_untar.sh data/local $mini_librispeech_url train-clean-5 40 | if [ ! -f data/dev_clean_2/.done ]; then 41 | local/data_prep.sh data/local/LibriSpeech/dev-clean-2 data/dev_clean_2 || exit 42 | touch data/dev_clean_2/.done 43 | fi 44 | if [ ! -f data/train_clean_5/.done ]; then 45 | local/data_prep.sh data/local/LibriSpeech/train-clean-5 data/train_clean_5 46 | touch data/train_clean_5/.done 47 | fi 48 | if [ ! -d data/musan_bgnoise ]; then 49 | tar xzf musan_bgnoise.tar.gz 50 | fi 51 | if [ ! -f data/simu_rirs_8k/.done ]; then 52 | mkdir -p data/simu_rirs_8k 53 | if [ ! 
-e sim_rir_8k.zip ]; then 54 | wget --no-check-certificate http://www.openslr.org/resources/26/sim_rir_8k.zip 55 | fi 56 | unzip sim_rir_8k.zip -d data/sim_rir_8k 57 | find $PWD/data/sim_rir_8k -iname "*.wav" \ 58 | | awk '{n=split($1,A,/[\/\.]/); print A[n-3]"_"A[n-1], $1}' \ 59 | | sort > data/simu_rirs_8k/wav.scp 60 | awk '{print $1, $1}' data/simu_rirs_8k/wav.scp > data/simu_rirs_8k/utt2spk 61 | utils/fix_data_dir.sh data/simu_rirs_8k 62 | touch data/simu_rirs_8k/.done 63 | fi 64 | fi 65 | 66 | simudir=data/simu 67 | if [ $stage -le 1 ]; then 68 | echo "simulation of mixture" 69 | mkdir -p $simudir/.work 70 | random_mixture_cmd=random_mixture_nooverlap.py 71 | make_mixture_cmd=make_mixture_nooverlap.py 72 | if [ "$simu_opts_overlap" == "yes" ]; then 73 | random_mixture_cmd=random_mixture.py 74 | make_mixture_cmd=make_mixture.py 75 | fi 76 | 77 | for simu_opts_sil_scale in 2; do 78 | for dset in train_clean_5 dev_clean_2; do 79 | n_mixtures=500 80 | simuid=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures} 81 | # check if you have the simulation 82 | if ! validate_data_dir.sh --no-text --no-feats $simudir/data/$simuid; then 83 | # random mixture generation 84 | $simu_cmd $simudir/.work/random_mixture_$simuid.log \ 85 | $random_mixture_cmd --n_speakers $simu_opts_num_speaker --n_mixtures $n_mixtures \ 86 | --speech_rvb_probability $simu_opts_rvb_prob \ 87 | --sil_scale $simu_opts_sil_scale \ 88 | data/$dset data/musan_bgnoise data/simu_rirs_8k \ 89 | \> $simudir/.work/mixture_$simuid.scp 90 | nj=100 91 | mkdir -p $simudir/wav/$simuid 92 | # distribute simulated data to $simu_actual_dir 93 | split_scps= 94 | for n in $(seq $nj); do 95 | split_scps="$split_scps $simudir/.work/mixture_$simuid.$n.scp" 96 | mkdir -p $simudir/.work/data_$simuid.$n 97 | actual=${simu_actual_dirs[($n-1)%${#simu_actual_dirs[@]}]}/$simudir/wav/$simuid/$n 98 | mkdir -p $actual 99 | ln -nfs $actual $simudir/wav/$simuid/$n 100 | done 101 | utils/split_scp.pl $simudir/.work/mixture_$simuid.scp $split_scps || exit 1 102 | 103 | $simu_cmd --max-jobs-run 32 JOB=1:$nj $simudir/.work/make_mixture_$simuid.JOB.log \ 104 | $make_mixture_cmd --rate=8000 \ 105 | $simudir/.work/mixture_$simuid.JOB.scp \ 106 | $simudir/.work/data_$simuid.JOB $simudir/wav/$simuid/JOB 107 | utils/combine_data.sh $simudir/data/$simuid $simudir/.work/data_$simuid.* 108 | steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ 109 | $simudir/data/$simuid/utt2spk $simudir/data/$simuid/segments \ 110 | $simudir/data/$simuid/rttm 111 | utils/data/get_reco2dur.sh $simudir/data/$simuid 112 | fi 113 | done 114 | done 115 | fi 116 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/steps: -------------------------------------------------------------------------------- 1 | ../../../tools/kaldi/egs/wsj/s5/steps -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/utils: -------------------------------------------------------------------------------- 1 | ../../../tools/kaldi/egs/wsj/s5/utils -------------------------------------------------------------------------------- /tools/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile 2 | # Copyright 2018-2019 Hitachi, Ltd. 
(author: Yusuke Fujita) 3 | # 4 | # install tools 5 | 6 | # If you want to use prebuild kaldi, make KALDI= 7 | KALDI := 8 | # Specify cuda root path installed in your environment 9 | CUDA_PATH := /usr/local/cuda 10 | CUDA_VERSION := $(shell $(CUDA_PATH)/bin/nvcc --version | tail -n1 | awk '{print substr($$5,0,length($$5)-1)}') 11 | 12 | all: kaldi miniconda3/envs/eend/bin env.sh 13 | 14 | ifneq ($(strip $(KALDI)),) 15 | kaldi: 16 | ln -s $(abspath $(KALDI)) kaldi 17 | else 18 | kaldi: 19 | git clone https://github.com/kaldi-asr/kaldi.git 20 | cd kaldi/tools; $(MAKE) 21 | cd kaldi/src; ./configure --shared --use-cuda=no; $(MAKE) depend; $(MAKE) all 22 | endif 23 | 24 | miniconda3.sh: 25 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O $@ 26 | 27 | miniconda3: miniconda3.sh 28 | # -b: non-interactive install 29 | # -p: installed directory 30 | bash miniconda3.sh -b -p miniconda3 31 | 32 | # virtual environment of python 33 | miniconda3/envs/eend/bin: miniconda3 34 | miniconda3/bin/conda update -y conda 35 | miniconda3/bin/conda env create -f environment.yml 36 | miniconda3/envs/eend/bin/pip install cupy-cuda$(subst .,,$(CUDA_VERSION))==6.2.0 chainer==6.2.0 37 | update: 38 | miniconda3/bin/conda env update -f environment.yml 39 | 40 | env.sh: miniconda3/envs/eend/bin 41 | cp env.sh.in env.sh 42 | echo "export LD_LIBRARY_PATH=$(CUDA_PATH)/lib64:$$LD_LIBRARY_PATH" >> env.sh 43 | -------------------------------------------------------------------------------- /tools/env.sh.in: -------------------------------------------------------------------------------- 1 | # Task-independent environmental variables 2 | export KALDI_ROOT=`pwd`/../../../tools/kaldi 3 | [ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh 4 | export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sph2pipe_v2.5:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH 5 | [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 6 | . $KALDI_ROOT/tools/config/common_path.sh 7 | export PATH=../../../tools/miniconda3/envs/eend/bin:$PATH 8 | export PATH=../../../eend/bin:../../../utils:$PATH 9 | export PYTHONPATH=../../..:$PYTHONPATH 10 | # cuda runtime 11 | -------------------------------------------------------------------------------- /tools/environment.yml: -------------------------------------------------------------------------------- 1 | name: eend 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.7 6 | - pip 7 | - pip: 8 | - h5py 9 | - librosa 10 | - numpy 11 | - protobuf 12 | - scipy 13 | - tqdm 14 | - SoundFile 15 | - matplotlib 16 | - kaldiio 17 | - yamlargparse 18 | -------------------------------------------------------------------------------- /utils/best_score.sh: -------------------------------------------------------------------------------- 1 | grep OVER $1/result_th0.[^_]*_med[^_]*_collar0.25 | grep -v nooverlap | sort -nrk 7 | tail -n 1 2 | --------------------------------------------------------------------------------
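
A note on `utils/best_score.sh` (the one-liner above): the scoring stage of `run.sh` writes one `result_th${th}_med${med}_collar0.25` file per threshold/median combination, and `md-eval.pl` reports the DER in each file on an "OVERALL SPEAKER DIARIZATION ERROR" line whose seventh whitespace-separated field is the DER percentage. The script greps those lines across the sweep (skipping any "nooverlap" variants), sorts them numerically on that field in descending order, and prints the last, i.e. lowest-DER, line. A usage sketch with the directory names produced by the mini_librispeech recipe (the same scoring directory reported in RESULT.md):

```
scoring_dir=exp/diarize/scoring/train_clean_5_ns2_beta2_500.dev_clean_2_ns2_beta2_500.train.avg8-10.infer
utils/best_score.sh $scoring_dir/dev_clean_2_ns2_beta2_500
# e.g. -> .../result_th0.7_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 29.96 percent of scored speaker time
```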