├── LICENCE ├── README.md ├── eend ├── bin │ ├── infer.py │ ├── make_mixture.py │ ├── make_rttm.py │ ├── model_averaging.py │ ├── random_mixture.py │ ├── save_spkv_lab.py │ ├── train.py │ └── yaml2bash.py ├── feature.py ├── kaldi_data.py └── pytorch_backend │ ├── diarization_dataset.py │ ├── infer.py │ ├── models.py │ ├── train.py │ └── transformer.py ├── egs ├── callhome │ └── v1 │ │ ├── Log.tar.gz │ │ ├── RESULT.md │ │ ├── cmd.sh │ │ ├── conf │ │ ├── adapt.yaml │ │ ├── infer_est_nspk0.yaml │ │ ├── infer_est_nspk1.yaml │ │ ├── save_spkv_lab.yaml │ │ └── train.yaml │ │ ├── local │ │ ├── path.sh │ │ ├── rttm_results │ │ ├── README │ │ └── eend-vc_callhome2_hyp_rttm.tar.gz │ │ ├── run.sh │ │ ├── run_prepare_shared_eda.sh │ │ ├── steps │ │ └── utils └── mini_librispeech │ └── v1 │ ├── RESULT.md │ ├── cmd.sh │ ├── conf │ ├── infer_est_nspk0.yaml │ ├── infer_est_nspk1.yaml │ └── train.yaml │ ├── local │ ├── musan_bgnoise.tar.gz │ ├── path.sh │ ├── run.sh │ ├── run_prepare_shared.sh │ ├── steps │ └── utils ├── tools ├── Makefile └── environment.yml └── utils └── best_score.sh /LICENCE: -------------------------------------------------------------------------------- 1 | SOFTWARE LICENSE AGREEMENT FOR EVALUATION 2 | 3 | This SOFTWARE EVALUATION LICENSE AGREEMENT (this "Agreement") is a legal contract between a person who uses or otherwise accesses or installs the Software (“User(s)”), and Nippon Telegraph and Telephone corporation ("NTT"). 4 | READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE. 5 | 6 | 7 | BACKGROUND 8 | A. NTT is the owner of all rights, including all patent rights, copyrights and trade secret rights, in and to the Software and related documentation listed in Exhibit A to this Agreement. 9 | 10 | B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such a license to User, pursuant and subject to the terms and conditions of this Agreement. 11 | 12 | C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement. 13 | 14 | In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows: 15 | 16 | 1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in the research paper submitted by NTT to a certain academy. User may make a reasonable number of backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1. 17 | 18 | 2. Shipment and Installation. 
NTT will ship or deliver the Software by any method that NTT deems appropriate. User shall be solely responsible for proper installation of the Software. 19 | 20 | 3. Term. This Agreement is effective whichever is earlier (i) upon User’s acceptance of the Agreement, or (ii) upon User’s installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. Without prejudice to any other rights, NTT may terminate this Agreement without notice to User (i) if User breaches or fails to comply with any of the limitations or other requirements described herein, and (ii) if NTT receives a notice from the academy stating that the research paper would not be published, and in any such case User agrees that NTT may, in addition to any other remedies it may have at law or in equity, remotely disable the Software. User may terminate this Agreement at any time by User’s decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this Agreement for any reason, User agrees to uninstall the Software and either return to NTT the Software and all copies thereof, or to destroy all such materials and provide written verification of such destruction to NTT. 21 | 22 | 4. Proprietary Rights 23 | (a) The Software is the valuable, confidential, and proprietary property of NTT, and NTT shall retain exclusive title to this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges that all patent rights, copyrights and trade secret rights in the Software shall remain the exclusive property of NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software. 24 | (b) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; (iii) DISCLOSE THE SOFTWARE TO ANY THIRD PARTY, EXCEPT TO USER'S EMPLOYEES WHO REQUIRE ACCESS TO THE SOFTWARE FOR THE PURPOSES OF THIS AGREEMENT; (iv) MODIFY, DISASSEMBLE, DECOMPILE, REVERSE ENGINEER OR TRANSLATE THE SOFTWARE; OR (v) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (iv) ABOVE. 25 | (c) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied. 26 | 27 | 5.  Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage, or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT DEFEND ANY ACTTION RELATING TO THE SOFTWARE. 28 | 29 | 6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT. 30 | 31 | 7. 
Limitation of Liability. IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS, ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARD-LESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION PURSUANT TO SECTION 3. 32 | 33 | 8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned, or otherwise transferred by User without NTT's prior written consent. 34 | 35 | 9. General 36 | (a) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this Agreement shall remain in full force and effect. 37 | (b) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the parties relating to that subject matter. 38 | (c) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and assigns of NTT and User. 39 | (d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding. 40 | (e) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof. 41 | (f)   NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT’s obligation set forth under this Agreement due to any cause beyond NTT’s reasonable control. 42 |   43 | EXHIBIT A 44 | This package contains the python implementation for the paper “Advances in integration of end-to-end neural and clustering-based diarization for real conversational speech” presented at Interspeech’21. The package includes the following 45 | - Code to train the proposed EEND-vector-clustering model 46 | - Code to perform inference with the model 47 | - Code to perform evaluation based on CALLHOME data 48 | The package contains the following files. 49 | 50 | . 
51 | ├── LICENCE 52 | ├── README.md 53 | ├── eend 54 | │   ├── bin 55 | │   │   ├── infer.py 56 | │   │   ├── make_mixture.py -> ../../tools/EEND/eend/bin/make_mixture.py 57 | │   │   ├── make_rttm.py -> ../../tools/EEND/eend/bin/make_rttm.py 58 | │   │   ├── model_averaging.py 59 | │   │   ├── random_mixture.py -> ../../tools/EEND/eend/bin/random_mixture.py 60 | │   │   ├── save_spkv_lab.py 61 | │   │   ├── train.py 62 | │   │   └── yaml2bash.py -> ../../tools/EEND/eend/bin/yaml2bash.py 63 | │   ├── feature.py -> ../tools/EEND/eend/feature.py 64 | │   ├── kaldi_data.py -> ../tools/EEND/eend/kaldi_data.py 65 | │   └── pytorch_backend 66 | │   ├── diarization_dataset.py 67 | │   ├── infer.py 68 | │   ├── models.py 69 | │   ├── train.py 70 | │   └── transformer.py 71 | ├── egs 72 | │   ├── callhome 73 | │   │   └── v1 74 | │   │   ├── RESULT.md 75 | │   │   ├── cmd.sh 76 | │   │   ├── conf 77 | │   │   │   ├── adapt.yaml 78 | │   │   │   ├── infer_est_nspk0.yaml 79 | │   │   │   ├── infer_est_nspk1.yaml 80 | │   │   │   ├── save_spkv_lab.yaml 81 | │   │   │   └── train.yaml 82 | │   │   ├── local -> ../../../tools/EEND/egs/callhome/v1/local 83 | │   │   ├── path.sh -> ../../../tools/EEND/egs/callhome/v1/path.sh 84 | │   │   ├── run.sh 85 | │   │   ├── run_prepare_shared_eda.sh -> ../../../tools/EEND/egs/callhome/v1/run_prepare_shared_eda.sh 86 | │   │   ├── steps -> ../../../tools/kaldi/egs/wsj/s5/steps 87 | │   │   └── utils -> ../../../tools/kaldi/egs/wsj/s5/utils 88 | │   └── mini_librispeech 89 | │   └── v1 90 | │   ├── RESULT.md 91 | │   ├── cmd.sh 92 | │   ├── conf 93 | │   │   ├── infer_est_nspk0.yaml 94 | │   │   ├── infer_est_nspk1.yaml 95 | │   │   └── train.yaml 96 | │   ├── local -> ../../../tools/EEND/egs/mini_librispeech/v1/local 97 | │   ├── musan_bgnoise.tar.gz -> ../../../tools/EEND/egs/mini_librispeech/v1/musan_bgnoise.tar.gz 98 | │   ├── path.sh -> ../../../tools/EEND/egs/mini_librispeech/v1/path.sh 99 | │   ├── run.sh 100 | │   ├── run_prepare_shared.sh -> ../../../tools/EEND/egs/mini_librispeech/v1/run_prepare_shared.sh 101 | │   ├── steps -> ../../../tools/kaldi/egs/wsj/s5/steps 102 | │   └── utils -> ../../../tools/kaldi/egs/wsj/s5/utils 103 | ├── tools 104 | │   ├── Makefile 105 | │   └── environment.yml 106 | └── utils 107 | └── best_score.sh -> ../tools/EEND/utils/best_score.sh 108 | 109 | 12 directories, 45 files 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EEND-vector clustering 2 | 3 | The EEND-vector clustering (End-to-End-Neural-Diarization-vector clustering) is a speaker diarization framework that integrates two complementary major diarization approaches, i.e., traditional clustering-based and emerging end-to-end neural network-based approaches, to make the best of both worlds. In [1] it is shown that the EEND-vector clustering outperforms EEND when the recording is long (e.g., more than 5 min), while in [2] it is shown based on CALLHOME data that it outperforms x-vector clustering and EEND-EDA especially when the number of speakers in recordings is large. 4 | 5 | This repository contains an example implementation of the EEND-vector clustering based on Pytorch to reproduce the results in [2], i.e., the CALLHOME experiments. For the trainer, we use [Padertorch](https://github.com/fgnt/padertorch). 
This repository is implemented based on [EEND](https://github.com/hitachi-speech/EEND) and relies on some useful functions provided therein. 6 | 7 | 8 | ## References 9 | [1] Keisuke Kinoshita, Marc Delcroix, and Naohiro Tawara, "Integrating end-to-end neural and clustering-based diarization: Getting the best of both worlds," Proc. ICASSP, pp. 7198–7202, 2021 10 | 11 | [2] Keisuke Kinoshita, Marc Delcroix, and Naohiro Tawara, "Advances in integration of end-to-end neural and clustering-based diarization for real conversational speech," Proc. Interspeech, 2021 (to appear) 12 | 13 | ## Citation 14 | ``` 15 | @inproceedings{eend-vector-clustering, 16 | author = {Keisuke Kinoshita and Marc Delcroix and Naohiro Tawara}, 17 | title = {Integrating End-to-End Neural and Clustering-Based Diarization: Getting the Best of Both Worlds}, 18 | booktitle = {{ICASSP 2021 - 2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}}, 19 | pages = {7198-7202}, 20 | year = {2021} 21 | } 22 | ``` 23 | 24 | ## Install tools 25 | ### Requirements 26 | - NVIDIA CUDA GPU 27 | - CUDA Toolkit (version == 9.2, 10.1 or 10.2) 28 | 29 | ### Install kaldi and python environment 30 | ```bash 31 | cd tools 32 | make 33 | ``` 34 | - This command builds kaldi at `tools/kaldi` 35 | - If you want to use a pre-built kaldi 36 | ```bash 37 | cd tools 38 | make KALDI=<existing_kaldi_root> 39 | ``` 40 | This option makes a symlink at `tools/kaldi` 41 | - This command extracts miniconda3 at `tools/miniconda3`, and creates a conda environment named 'eend' 42 | - Then, it installs PyTorch and Padertorch into the 'eend' environment 43 | - By default, CUDA in `/usr/local/cuda/` is used 44 | - If you need to specify your CUDA path 45 | ```bash 46 | cd tools 47 | make CUDA_PATH=/your/path/to/cuda-10.1 48 | ``` 49 | The PyTorch install command to be executed depends on your CUDA version. 50 | See https://pytorch.org/get-started/previous-versions/ 51 | - Finally, it clones [EEND](https://github.com/hitachi-speech/EEND) so that the symbolic links stored under `eend/`, `egs/` and `utils/` can be resolved 52 | 53 | ## Test recipe (mini_librispeech) 54 | ### Configuration 55 | - Modify `egs/mini_librispeech/v1/cmd.sh` according to your job scheduler. 56 | If you use your local machine, use "run.pl" (default). 57 | If you use Grid Engine, use "queue.pl". 58 | If you use SLURM, use "slurm.pl". 59 | For more information about cmd.sh see http://kaldi-asr.org/doc/queue.html. 60 | ### Run data preparation, training, inference, and scoring 61 | ```bash 62 | cd egs/mini_librispeech/v1 63 | CUDA_VISIBLE_DEVICES=0 ./run.sh 64 | ``` 65 | - See `RESULT.md` and compare it with your result. 66 | 67 | ## CALLHOME experiment 68 | ### Configuration 69 | - Modify `egs/callhome/v1/cmd.sh` according to your job scheduler. 70 | If you use your local machine, use "run.pl" (default). 71 | If you use Grid Engine, use "queue.pl". 72 | If you use SLURM, use "slurm.pl". 73 | For more information about cmd.sh see http://kaldi-asr.org/doc/queue.html. 74 | ### Run data preparation, training, inference, and scoring 75 | ```bash 76 | cd egs/callhome/v1 77 | CUDA_VISIBLE_DEVICES=0 ./run.sh --db_path <db_path> 78 | # <db_path> means the absolute path of the directory where the necessary LDC corpora are stored. 79 | ``` 80 | - See `RESULT.md` and compare it with your result. 81 | - If you want to run multi-GPU training, simply set `CUDA_VISIBLE_DEVICES` appropriately. This environment variable may be automatically set by your job scheduler such as SLURM.
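
### Inspecting the inference output

During inference, `eend/pytorch_backend/infer.py` writes one HDF5 file per recording (`<recid>.h5`) into the output directory; the per-frame speaker activity matrix is stored under the key `T_hat`. Below is a minimal sketch of how such a file could be inspected. The file path and the 0.5 binarization threshold are illustrative assumptions only; the 0.1 s frame period follows from the default inference config (8 kHz sampling, 80-sample frame shift, subsampling factor 10).

```python
import h5py
import numpy as np

# Illustrative path; infer.py names each output file '<recid>.h5' under out_dir.
with h5py.File('exp/infer_out/recid.h5', 'r') as f:
    t_hat = np.asarray(f['T_hat'])   # shape: (n_frames, n_estimated_speakers)

activity = t_hat > 0.5               # example threshold, not part of the recipe
frame_period = 0.1                   # seconds per frame with the default config
for spk in range(activity.shape[1]):
    dur = activity[:, spk].sum() * frame_period
    print(f'speaker {spk}: {dur:.1f} s of detected speech')
```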
82 | 83 | -------------------------------------------------------------------------------- /eend/bin/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 4 | # All rights reserved 5 | 6 | import yamlargparse 7 | from eend.pytorch_backend.infer import infer 8 | 9 | parser = yamlargparse.ArgumentParser(description='decoding') 10 | parser.add_argument('-c', '--config', help='config file path', 11 | action=yamlargparse.ActionConfigFile) 12 | parser.add_argument('data_dir', 13 | help='kaldi-style data dir') 14 | parser.add_argument('model_file', 15 | help='best.nnet') 16 | parser.add_argument('out_dir', 17 | help='output directory.') 18 | 19 | # The following arguments are set in conf/infer_est_nspk{0,1}.yaml 20 | parser.add_argument('--est-nspk', default=1, type=int, choices=[0, 1], 21 | help='At clustering stage, --est-nspk 0 means that\ 22 | oracle number of speakers is used, --est-nspk 1 means\ 23 | estimating number of speakers') 24 | parser.add_argument('--num-speakers', default=3, type=int) 25 | parser.add_argument('--spkv-dim', default=256, type=int, 26 | help='dimension of speaker embedding vector') 27 | parser.add_argument('--hidden-size', default=256, type=int) 28 | parser.add_argument('--input-transform', default='logmel23_mn', 29 | choices=['', 'log', 'logmel', 30 | 'logmel23', 'logmel23_swn', 'logmel23_mn'], 31 | help='input transform') 32 | parser.add_argument('--chunk-size', default=300, type=int, 33 | help='input is chunked with this size') 34 | parser.add_argument('--context-size', default=7, type=int, 35 | help='frame splicing') 36 | parser.add_argument('--subsampling', default=10, type=int) 37 | parser.add_argument('--sampling-rate', default=8000, type=int, 38 | help='sampling rate') 39 | parser.add_argument('--frame-size', default=200, type=int, 40 | help='frame size') 41 | parser.add_argument('--frame-shift', default=80, type=int, 42 | help='frame shift') 43 | parser.add_argument('--transformer-encoder-n-heads', default=8, type=int) 44 | parser.add_argument('--transformer-encoder-n-layers', default=6, type=int) 45 | parser.add_argument('--sil-spk-th', default=0.05, type=float, 46 | help='activity threshold to detect the silent speaker') 47 | parser.add_argument('--ahc-dis-th', default=1.0, type=float, 48 | help='distance threshold above which clusters\ 49 | will not be merged') 50 | parser.add_argument('--clink-dis', default=1e+4, type=float, 51 | help='modified distance corresponding to cannot-link') 52 | 53 | args = parser.parse_args() 54 | print(args) 55 | 56 | infer(args) 57 | -------------------------------------------------------------------------------- /eend/bin/make_mixture.py: -------------------------------------------------------------------------------- 1 | ../../tools/EEND/eend/bin/make_mixture.py -------------------------------------------------------------------------------- /eend/bin/make_rttm.py: -------------------------------------------------------------------------------- 1 | ../../tools/EEND/eend/bin/make_rttm.py -------------------------------------------------------------------------------- /eend/bin/model_averaging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
4 | # All rights reserved 5 | 6 | import torch 7 | import argparse 8 | 9 | 10 | def average_model_pytorch(ifiles, ofile): 11 | omodel = {} 12 | for path in ifiles: 13 | state_dict = torch.load(path)['model'] 14 | for key in state_dict.keys(): 15 | val = state_dict[key] 16 | if key not in omodel: 17 | omodel[key] = val 18 | else: 19 | omodel[key] += val 20 | for key in omodel.keys(): 21 | omodel[key] /= len(ifiles) 22 | torch.save(dict(model=omodel), ofile) 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("ofile") 28 | parser.add_argument("ifiles", nargs='+') 29 | args = parser.parse_args() 30 | average_model_pytorch(args.ifiles, args.ofile) 31 | -------------------------------------------------------------------------------- /eend/bin/random_mixture.py: -------------------------------------------------------------------------------- 1 | ../../tools/EEND/eend/bin/random_mixture.py -------------------------------------------------------------------------------- /eend/bin/save_spkv_lab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 4 | # All rights reserved 5 | 6 | import yamlargparse 7 | from eend.pytorch_backend.infer import save_spkv_lab 8 | 9 | parser = yamlargparse.ArgumentParser(description='decoding') 10 | parser.add_argument('-c', '--config', help='config file path', 11 | action=yamlargparse.ActionConfigFile) 12 | parser.add_argument('data_dir', 13 | help='kaldi-style data dir') 14 | parser.add_argument('model_file', 15 | help='best.nnet') 16 | parser.add_argument('out_dir', 17 | help='output directory.') 18 | 19 | # The following arguments are set in conf/save_spkv_lab.yaml 20 | parser.add_argument('--num-speakers', default=3, type=int) 21 | parser.add_argument('--spkv-dim', default=256, type=int, 22 | help='dimension of speaker embedding vector') 23 | parser.add_argument('--hidden-size', default=256, type=int) 24 | parser.add_argument('--input-transform', default='logmel23_mn', 25 | choices=['', 'log', 'logmel', 26 | 'logmel23', 'logmel23_swn', 'logmel23_mn'], 27 | help='input transform') 28 | parser.add_argument('--chunk-size', default=300, type=int, 29 | help='input is chunked with this size') 30 | parser.add_argument('--context-size', default=7, type=int, 31 | help='frame splicing') 32 | parser.add_argument('--subsampling', default=10, type=int) 33 | parser.add_argument('--sampling-rate', default=8000, type=int, 34 | help='sampling rate') 35 | parser.add_argument('--frame-size', default=200, type=int, 36 | help='frame size') 37 | parser.add_argument('--frame-shift', default=80, type=int, 38 | help='frame shift') 39 | parser.add_argument('--transformer-encoder-n-heads', default=8, type=int) 40 | parser.add_argument('--transformer-encoder-n-layers', default=6, type=int) 41 | 42 | args = parser.parse_args() 43 | print(args) 44 | 45 | save_spkv_lab(args) 46 | -------------------------------------------------------------------------------- /eend/bin/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 
4 | # All rights reserved 5 | 6 | import os 7 | import yamlargparse 8 | from eend.pytorch_backend.train import train, save_feature 9 | 10 | parser = yamlargparse.ArgumentParser(description='training') 11 | parser.add_argument('-c', '--config', help='config file path', 12 | action=yamlargparse.ActionConfigFile) 13 | parser.add_argument('train_data_dir', 14 | help='kaldi-style data dir used for training.') 15 | parser.add_argument('valid_data_dir', 16 | help='kaldi-style data dir used for validation.') 17 | parser.add_argument('model_save_dir', 18 | help='output directory which model file will be saved in.') 19 | parser.add_argument('--initmodel', '-m', default='', 20 | help='Initialize the model from given file') 21 | parser.add_argument('--spkv-lab', default='', 22 | help='file path of speaker vector with label and\ 23 | speaker ID conversion table for adaptation') 24 | 25 | # The following arguments are set in conf/train.yaml or conf/adapt.yaml 26 | parser.add_argument('--spk-loss-ratio', default=0.03, type=float) 27 | parser.add_argument('--spkv-dim', default=256, type=int, 28 | help='dimension of speaker embedding vector') 29 | parser.add_argument('--max-epochs', default=100, type=int, 30 | help='Max. number of epochs to train') 31 | parser.add_argument('--input-transform', default='logmel23_mn', 32 | choices=['', 'log', 'logmel', 'logmel23', 'logmel23_mn', 33 | 'logmel23_mvn', 'logmel23_swn'], 34 | help='input transform') 35 | parser.add_argument('--lr', default=0.001, type=float) 36 | parser.add_argument('--optimizer', default='noam', type=str) 37 | parser.add_argument('--num-speakers', default=3, type=int) 38 | parser.add_argument('--gradclip', default=5, type=int, 39 | help='gradient clipping. if < 0, no clipping') 40 | parser.add_argument('--chunk-size', default=150, type=int, 41 | help='number of frames in one utterance') 42 | parser.add_argument('--batchsize', default=64, type=int, 43 | help='number of utterances in one batch.\ 44 | Note that real batchsize = number of gpu *\ 45 | batchsize-per-gpu * batchsize') 46 | parser.add_argument('--num-workers', default=8, type=int) 47 | parser.add_argument('--hidden-size', default=256, type=int) 48 | parser.add_argument('--context-size', default=7, type=int) 49 | parser.add_argument('--subsampling', default=10, type=int) 50 | parser.add_argument('--frame-size', default=200, type=int) 51 | parser.add_argument('--frame-shift', default=80, type=int) 52 | parser.add_argument('--sampling-rate', default=8000, type=int) 53 | parser.add_argument('--noam-scale', default=1.0, type=float) 54 | parser.add_argument('--noam-warmup-steps', default=25000, type=float) 55 | parser.add_argument('--transformer-encoder-n-heads', default=8, type=int) 56 | parser.add_argument('--transformer-encoder-n-layers', default=6, type=int) 57 | parser.add_argument('--transformer-encoder-dropout', default=0.1, type=float) 58 | parser.add_argument('--seed', default=777, type=int) 59 | parser.add_argument('--feature-nj', default=100, type=int, 60 | help='maximum number of subdirectories to store\ 61 | featlab_XXXXXXXX.npy') 62 | parser.add_argument('--batchsize-per-gpu', default=16, type=int, 63 | help='virtual_minibatch_size in padertorch') 64 | parser.add_argument('--test-run', default=0, type=int, choices=[0, 1], 65 | help='padertorch test run switch; 1 is on, 0 is off') 66 | 67 | args = parser.parse_args() 68 | print(args) 69 | 70 | # To speed up the training process, we first calculate input features 71 | # to NN and save shuffled feature data to the disc. 
During training, 72 | # we simply read the saved data from the disc. 73 | path = '{}/data/.done'.format(args.model_save_dir) 74 | is_file = os.path.isfile(path) 75 | if is_file: 76 | print("skip feature saving.") 77 | train(args) 78 | else: 79 | save_feature(args) 80 | train(args) 81 | -------------------------------------------------------------------------------- /eend/bin/yaml2bash.py: -------------------------------------------------------------------------------- 1 | ../../tools/EEND/eend/bin/yaml2bash.py -------------------------------------------------------------------------------- /eend/feature.py: -------------------------------------------------------------------------------- 1 | ../tools/EEND/eend/feature.py -------------------------------------------------------------------------------- /eend/kaldi_data.py: -------------------------------------------------------------------------------- 1 | ../tools/EEND/eend/kaldi_data.py -------------------------------------------------------------------------------- /eend/pytorch_backend/diarization_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 2 | # All rights reserved 3 | 4 | import torch 5 | import numpy as np 6 | from eend import kaldi_data 7 | from eend import feature 8 | 9 | 10 | def _count_frames(data_len, size, step): 11 | return int((data_len - size + step) / step) 12 | 13 | 14 | def _gen_frame_indices(data_length, size=2000, step=2000): 15 | i = -1 16 | for i in range(_count_frames(data_length, size, step)): 17 | yield i * step, i * step + size 18 | 19 | if i * step + size < data_length: 20 | if data_length - (i + 1) * step > 0: 21 | if i == -1: 22 | yield (i + 1) * step, data_length 23 | else: 24 | yield data_length - size, data_length 25 | 26 | 27 | class DiarizationDatasetFromWave(torch.utils.data.Dataset): 28 | def __init__( 29 | self, 30 | data_dir, 31 | dtype=np.float32, 32 | chunk_size=2000, 33 | context_size=0, 34 | frame_size=1024, 35 | frame_shift=256, 36 | subsampling=1, 37 | rate=16000, 38 | input_transform=None, 39 | n_speakers=None, 40 | ): 41 | self.data_dir = data_dir 42 | self.dtype = dtype 43 | self.chunk_size = chunk_size 44 | self.context_size = context_size 45 | self.frame_size = frame_size 46 | self.frame_shift = frame_shift 47 | self.subsampling = subsampling 48 | self.rate = rate 49 | self.input_transform = input_transform 50 | self.n_speakers = n_speakers 51 | 52 | self.chunk_indices = [] 53 | self.data = kaldi_data.KaldiData(self.data_dir) 54 | self.all_speakers = sorted(self.data.spk2utt.keys()) 55 | self.all_n_speakers = len(self.all_speakers) 56 | self.all_n_speakers_arr =\ 57 | np.arange(self.all_n_speakers, 58 | dtype=np.int64).reshape(self.all_n_speakers, 1) 59 | 60 | # Make chunk indices: filepath, start_frame, end_frame 61 | for rec in self.data.wavs: 62 | data_len = int(self.data.reco2dur[rec] * self.rate / frame_shift) 63 | data_len = int(data_len / self.subsampling) 64 | for st, ed in _gen_frame_indices(data_len, chunk_size, chunk_size): 65 | self.chunk_indices.append( 66 | (rec, st * self.subsampling, ed * self.subsampling)) 67 | print(len(self.chunk_indices), " chunks") 68 | 69 | def __len__(self): 70 | return len(self.chunk_indices) 71 | 72 | def __getitem__(self, i): 73 | rec, st, ed = self.chunk_indices[i] 74 | filtered_segments = self.data.segments[rec] 75 | # speakers: the value given from data 76 | speakers = np.unique( 77 | [self.data.utt2spk[seg['utt']] for seg in 
filtered_segments] 78 | ).tolist() 79 | n_speakers = self.n_speakers 80 | if self.n_speakers < len(speakers): 81 | n_speakers = len(speakers) 82 | 83 | Y, T = feature.get_labeledSTFT( 84 | self.data, 85 | rec, 86 | st, 87 | ed, 88 | self.frame_size, 89 | self.frame_shift, 90 | n_speakers, 91 | ) 92 | T = T.astype(np.float32) 93 | 94 | S_arr = -1 * np.ones(n_speakers).astype(np.int64) 95 | for seg in filtered_segments: 96 | speaker_index = speakers.index(self.data.utt2spk[seg['utt']]) 97 | all_speaker_index = self.all_speakers.index( 98 | self.data.utt2spk[seg['utt']]) 99 | S_arr[speaker_index] = all_speaker_index 100 | 101 | # If T[:, n_speakers - 1] == 0.0, then S_arr[n_speakers - 1] == -1, 102 | # so S_arr[n_speakers - 1] is not used for training, 103 | # e.g., in the case of training 3-spk model with 2-spk data 104 | 105 | Y = feature.transform(Y, self.input_transform) 106 | Y_spliced = feature.splice(Y, self.context_size) 107 | Y_ss, T_ss = feature.subsample(Y_spliced, T, self.subsampling) 108 | ilen = np.array(Y_ss.shape[0], dtype=np.int64) 109 | 110 | return Y_ss, T_ss, S_arr, self.all_n_speakers_arr, ilen 111 | 112 | def get_allnspk(self): 113 | return self.all_n_speakers 114 | 115 | 116 | class DiarizationDatasetFromFeat(torch.utils.data.Dataset): 117 | def __init__( 118 | self, 119 | featlab_chunk_indices_path, 120 | featdim, 121 | ): 122 | self.featlab_chunk_indices_path = featlab_chunk_indices_path 123 | self.featdim = featdim 124 | 125 | self.chunk_indices = [ 126 | (line.strip().split()[0], line.strip().split()[1]) 127 | for line in open(featlab_chunk_indices_path)] 128 | print(len(self.chunk_indices), " chunks") 129 | 130 | # define self.all_n_speakers 131 | featlab_path, chunk_idx = self.chunk_indices[0] 132 | chunks = np.load(featlab_path, mmap_mode='r') 133 | chunk = chunks[int(chunk_idx)] 134 | chunk = np.array(chunk) 135 | labs_data = chunk[:, self.featdim:] 136 | self.all_n_speakers = np.round(labs_data[0, -2]).astype(np.int64) 137 | 138 | def __len__(self): 139 | return len(self.chunk_indices) 140 | 141 | def __getitem__(self, i): 142 | featlab_path, chunk_idx = self.chunk_indices[i] 143 | chunks = np.load(featlab_path, mmap_mode='r') 144 | # mmap_mode='r' is for accessing small fragments specified by 145 | # chunk_index of the file without reading the entire file into memory 146 | # 147 | # data structure info of chunks (=featlab_XXXXXXXX.npy): 148 | # (assuming that batch_size == 1024, chunk_size == 150 149 | # featdim == 345, and num_speakers == 3) 150 | # 151 | # - chunks.shape: (1024, 150, 353) 152 | # - chunk.shape == chunks[int(chunk_idx)].shape: (150, 353) 153 | # - 1) chunk[:, :345] : feature data from audio file 154 | # - 2) chunk[:, 345:348] : reference speech activities of 3-speakers 155 | # - 3) chunk[:, 348:351] : reference speaker IDs of 3-speakers 156 | # (speaker order is same as 2)) 157 | # - 4) chunk[:, 351] : reference number of all speakers 158 | # - 5) chunk[:, 352] : real chunk size 159 | 160 | chunk = chunks[int(chunk_idx)] 161 | chunk = np.array(chunk) 162 | feat_data = chunk[:, :self.featdim] 163 | labs_data = chunk[:, self.featdim:] 164 | num_speakers = (labs_data.shape[1] - 2) // 2 165 | y = feat_data 166 | t = labs_data[:, :num_speakers] 167 | s = np.round(labs_data[0, num_speakers:-2]).astype(np.int64) 168 | n = np.round(labs_data[0, -2]).astype(np.int64) 169 | n = np.arange(n, dtype=np.int64).reshape(n, 1) 170 | ilen = np.round(labs_data[0, -1]).astype(np.int64) 171 | return y, t, s, n, ilen 172 | 173 | def get_allnspk(self): 174 | return 
self.all_n_speakers 175 | -------------------------------------------------------------------------------- /eend/pytorch_backend/infer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 2 | # All rights reserved 3 | 4 | import os 5 | import h5py 6 | import numpy as np 7 | from sklearn.cluster import AgglomerativeClustering 8 | from scipy.spatial import distance 9 | import torch 10 | from functools import partial 11 | from eend.pytorch_backend.train import collate_fn_ns 12 | from eend.pytorch_backend.models import fix_state_dict 13 | from eend.pytorch_backend.models import TransformerDiarization 14 | from eend.pytorch_backend.diarization_dataset import DiarizationDatasetFromWave 15 | from eend import feature 16 | from eend import kaldi_data 17 | 18 | 19 | def _gen_chunk_indices(data_len, chunk_size): 20 | step = chunk_size 21 | start = 0 22 | while start < data_len: 23 | end = min(data_len, start + chunk_size) 24 | yield start, end 25 | start += step 26 | 27 | 28 | def get_cl_sil(args, acti, cls_num): 29 | n_chunks = len(acti) 30 | mean_acti = np.array([np.mean(acti[i], axis=0) 31 | for i in range(n_chunks)]).flatten() 32 | n = args.num_speakers 33 | sil_spk_th = args.sil_spk_th 34 | 35 | cl_lst = [] 36 | sil_lst = [] 37 | for chunk_idx in range(n_chunks): 38 | if cls_num is not None: 39 | if args.num_speakers > cls_num: 40 | mean_acti_bi = np.array([mean_acti[n * chunk_idx + s_loc_idx] 41 | for s_loc_idx in range(n)]) 42 | min_idx = np.argmin(mean_acti_bi) 43 | mean_acti[n * chunk_idx + min_idx] = 0.0 44 | 45 | for s_loc_idx in range(n): 46 | a = n * chunk_idx + (s_loc_idx + 0) % n 47 | b = n * chunk_idx + (s_loc_idx + 1) % n 48 | if mean_acti[a] > sil_spk_th and mean_acti[b] > sil_spk_th: 49 | cl_lst.append((a, b)) 50 | else: 51 | if mean_acti[a] <= sil_spk_th: 52 | sil_lst.append(a) 53 | 54 | return cl_lst, sil_lst 55 | 56 | 57 | def clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst): 58 | org_svec_len = len(svec) 59 | svec = np.delete(svec, sil_lst, 0) 60 | 61 | # update cl_lst idx 62 | _tbl = [i - sum(sil < i for sil in sil_lst) for i in range(org_svec_len)] 63 | cl_lst = [(_tbl[_cl[0]], _tbl[_cl[1]]) for _cl in cl_lst] 64 | 65 | distMat = distance.cdist(svec, svec, metric='euclidean') 66 | for cl in cl_lst: 67 | distMat[cl[0], cl[1]] = args.clink_dis 68 | distMat[cl[1], cl[0]] = args.clink_dis 69 | 70 | clusterer = AgglomerativeClustering( 71 | n_clusters=cls_num, 72 | affinity='precomputed', 73 | linkage='average', 74 | distance_threshold=ahc_dis_th) 75 | clusterer.fit(distMat) 76 | 77 | if cls_num is not None: 78 | print("oracle n_clusters is known") 79 | else: 80 | print("oracle n_clusters is unknown") 81 | print("estimated n_clusters by constraind AHC: {}" 82 | .format(len(np.unique(clusterer.labels_)))) 83 | cls_num = len(np.unique(clusterer.labels_)) 84 | 85 | sil_lab = cls_num 86 | insert_sil_lab = [sil_lab for i in range(len(sil_lst))] 87 | insert_sil_lab_idx = [sil_lst[i] - i for i in range(len(sil_lst))] 88 | print("insert_sil_lab : {}".format(insert_sil_lab)) 89 | print("insert_sil_lab_idx : {}".format(insert_sil_lab_idx)) 90 | clslab = np.insert(clusterer.labels_, 91 | insert_sil_lab_idx, 92 | insert_sil_lab).reshape(-1, args.num_speakers) 93 | print("clslab : {}".format(clslab)) 94 | 95 | return clslab, cls_num 96 | 97 | 98 | def merge_act_max(act, i, j): 99 | for k in range(len(act)): 100 | act[k, i] = max(act[k, i], act[k, j]) 101 | act[k, j] = 0.0 102 
| return act 103 | 104 | 105 | def merge_acti_clslab(args, acti, clslab, cls_num): 106 | sil_lab = cls_num 107 | for i in range(len(clslab)): 108 | _lab = clslab[i].reshape(-1, 1) 109 | distM = distance.cdist(_lab, _lab, metric='euclidean').astype(np.int64) 110 | for j in range(len(distM)): 111 | distM[j][:j] = -1 112 | idx_lst = np.where(np.count_nonzero(distM == 0, axis=1) > 1) 113 | merge_done = [] 114 | for j in idx_lst[0]: 115 | for k in (np.where(distM[j] == 0))[0]: 116 | if j != k and clslab[i, j] != sil_lab and k not in merge_done: 117 | print("merge : (i, j, k) == ({}, {}, {})".format(i, j, k)) 118 | acti[i] = merge_act_max(acti[i], j, k) 119 | clslab[i, k] = sil_lab 120 | merge_done.append(j) 121 | 122 | return acti, clslab 123 | 124 | 125 | def stitching(args, acti, clslab, cls_num): 126 | n_chunks = len(acti) 127 | s_loc = args.num_speakers 128 | sil_lab = cls_num 129 | s_tot = max(cls_num, s_loc-1) 130 | 131 | # Extend the max value of s_loc_idx to s_tot+1 132 | add_acti = [] 133 | for chunk_idx in range(n_chunks): 134 | zeros = np.zeros((len(acti[chunk_idx]), s_tot+1)) 135 | if s_tot+1 > s_loc: 136 | zeros[:, :-(s_tot+1-s_loc)] = acti[chunk_idx] 137 | else: 138 | zeros = acti[chunk_idx] 139 | add_acti.append(zeros) 140 | acti = np.array(add_acti) 141 | 142 | out_chunks = [] 143 | for chunk_idx in range(n_chunks): 144 | # Make sloci2lab_dct. 145 | # key: s_loc_idx 146 | # value: estimated label by clustering or sil_lab 147 | cls_set = set() 148 | for s_loc_idx in range(s_tot+1): 149 | cls_set.add(s_loc_idx) 150 | 151 | sloci2lab_dct = {} 152 | for s_loc_idx in range(s_tot+1): 153 | if s_loc_idx < s_loc: 154 | sloci2lab_dct[s_loc_idx] = clslab[chunk_idx][s_loc_idx] 155 | if clslab[chunk_idx][s_loc_idx] in cls_set: 156 | cls_set.remove(clslab[chunk_idx][s_loc_idx]) 157 | else: 158 | if clslab[chunk_idx][s_loc_idx] != sil_lab: 159 | raise ValueError 160 | else: 161 | sloci2lab_dct[s_loc_idx] = list(cls_set)[s_loc_idx-s_loc] 162 | 163 | # Sort by label value 164 | sloci2lab_lst = sorted(sloci2lab_dct.items(), key=lambda x: x[1]) 165 | 166 | # Select sil_lab_idx 167 | sil_lab_idx = None 168 | for idx_lab in sloci2lab_lst: 169 | if idx_lab[1] == sil_lab: 170 | sil_lab_idx = idx_lab[0] 171 | break 172 | if sil_lab_idx is None: 173 | raise ValueError 174 | 175 | # Get swap_idx 176 | # [idx of label(0), idx of label(1), ..., idx of label(s_tot)] 177 | swap_idx = [sil_lab_idx for j in range(s_tot+1)] 178 | for lab in range(s_tot+1): 179 | for idx_lab in sloci2lab_lst: 180 | if lab == idx_lab[1]: 181 | swap_idx[lab] = idx_lab[0] 182 | 183 | print("swap_idx {}".format(swap_idx)) 184 | swap_acti = acti[chunk_idx][:, swap_idx] 185 | swap_acti = np.delete(swap_acti, sil_lab, 1) 186 | out_chunks.append(swap_acti) 187 | 188 | return out_chunks 189 | 190 | 191 | def prepare_model_for_eval(args): 192 | in_size = feature.get_input_dim( 193 | args.frame_size, 194 | args.context_size, 195 | args.input_transform) 196 | model_parameter_dict = torch.load(args.model_file)['model'] 197 | model_all_n_speakers =\ 198 | fix_state_dict(model_parameter_dict)["embed.weight"].shape[0] 199 | net = TransformerDiarization( 200 | n_speakers=args.num_speakers, 201 | in_size=in_size, 202 | n_units=args.hidden_size, 203 | n_heads=args.transformer_encoder_n_heads, 204 | n_layers=args.transformer_encoder_n_layers, 205 | dropout_rate=0, 206 | all_n_speakers=model_all_n_speakers, 207 | d=args.spkv_dim) 208 | 209 | device = [device_id for device_id in range(torch.cuda.device_count())] 210 | 
net.load_state_dict(fix_state_dict(model_parameter_dict)) 211 | net.eval() 212 | net = net.to("cuda") 213 | print('GPU device {} is used'.format(device)) 214 | print('Prepared model') 215 | 216 | return net 217 | 218 | 219 | def prediction(args, net, kaldi_obj, recid): 220 | acti_lst = [] 221 | svec_lst = [] 222 | # Prepare input features 223 | data, rate = kaldi_obj.load_wav(recid) 224 | Y = feature.stft(data, args.frame_size, args.frame_shift) 225 | Y = feature.transform(Y, transform_type=args.input_transform) 226 | Y = feature.splice(Y, context_size=args.context_size) 227 | Y = Y[::args.subsampling] 228 | 229 | with torch.no_grad(): 230 | for start, end in _gen_chunk_indices(len(Y), args.chunk_size): 231 | if start > 0 and start + args.chunk_size > end: 232 | # Ensure last chunk size 233 | Y_chunked = torch.from_numpy(Y[end-args.chunk_size:end]) 234 | else: 235 | Y_chunked = torch.from_numpy(Y[start:end]) 236 | Y_chunked = Y_chunked.to('cuda') 237 | 238 | outputs = net.batch_estimate(torch.unsqueeze(Y_chunked, 0)) 239 | ys = outputs[0] 240 | 241 | for i in range(args.num_speakers): 242 | spkivecs = outputs[i+1] 243 | svec_lst.append(spkivecs[0].cpu().detach().numpy()) 244 | 245 | if start > 0 and start + args.chunk_size > end: 246 | # Ensure last chunk size 247 | ys = list(ys) 248 | ys[0] = ys[0][args.chunk_size-(end-start):args.chunk_size] 249 | 250 | acti = ys[0].cpu().detach().numpy() 251 | acti_lst.append(acti) 252 | 253 | acti_arr = np.array(acti_lst) 254 | svec_arr = np.array(svec_lst) 255 | 256 | return acti_arr, svec_arr 257 | 258 | 259 | def infer(args): 260 | # Prepare model 261 | net = prepare_model_for_eval(args) 262 | 263 | kaldi_obj = kaldi_data.KaldiData(args.data_dir) 264 | for recid in kaldi_obj.wavs: 265 | print("recid : {}".format(recid)) 266 | # prediction 267 | acti, svec = prediction(args, net, kaldi_obj, recid) 268 | n_chunks = len(acti) 269 | # initialize clustering setting 270 | cls_num = None 271 | ahc_dis_th = args.ahc_dis_th 272 | if args.est_nspk == 0: 273 | filtered_segments = kaldi_obj.segments[recid] 274 | cls_num = len(np.unique( 275 | [kaldi_obj.utt2spk[seg['utt']] for seg in filtered_segments] 276 | ).tolist()) 277 | ahc_dis_th = None 278 | # Get cannot-link index list and silence index list 279 | cl_lst, sil_lst = get_cl_sil(args, acti, cls_num) 280 | 281 | n_samples = n_chunks * args.num_speakers - len(sil_lst) 282 | min_n_samples = 2 283 | if cls_num is not None: 284 | min_n_samples = cls_num 285 | 286 | if n_samples >= min_n_samples: 287 | # clustering (if cls_num is None, update cls_num) 288 | clslab, cls_num =\ 289 | clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst) 290 | # merge 291 | acti, clslab = merge_acti_clslab(args, acti, clslab, cls_num) 292 | # stitching 293 | out_chunks = stitching(args, acti, clslab, cls_num) 294 | else: 295 | out_chunks = acti 296 | 297 | outdata = np.vstack(out_chunks) 298 | # Saving the resuts 299 | outfname = recid + '.h5' 300 | outpath = os.path.join(args.out_dir, outfname) 301 | with h5py.File(outpath, 'w') as wf: 302 | # 'T_hat': key 303 | wf.create_dataset('T_hat', data=outdata) 304 | 305 | 306 | def save_spkv_lab(args): 307 | # Prepare data 308 | data_set = DiarizationDatasetFromWave( 309 | args.data_dir, 310 | chunk_size=args.chunk_size, 311 | context_size=args.context_size, 312 | input_transform=args.input_transform, 313 | frame_size=args.frame_size, 314 | frame_shift=args.frame_shift, 315 | subsampling=args.subsampling, 316 | rate=args.sampling_rate, 317 | n_speakers=args.num_speakers, 318 | ) 
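    # Each item yielded by data_set is a tuple
    # (Y_ss, T_ss, S_arr, all_n_speakers_arr, ilen): subsampled and spliced
    # input features, frame-level speech-activity labels, global speaker
    # indices (-1 for an unused speaker slot), the index array over all
    # training speakers, and the number of valid frames
    # (see DiarizationDatasetFromWave.__getitem__).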
319 | 320 | # Prepare model 321 | net = prepare_model_for_eval(args) 322 | 323 | # Inference and saving filtered data (spkvec_lab.npz) 324 | with torch.no_grad(): 325 | all_outputs = [] 326 | all_labels = [] 327 | 328 | # Exclude samples that exceed args.num_speakers speakers in a chunk 329 | data_loader = torch.utils.data.DataLoader( 330 | data_set, batch_size=8, shuffle=False, 331 | collate_fn=partial( 332 | collate_fn_ns, 333 | n_speakers=args.num_speakers, 334 | spkidx_tbl=None)) 335 | 336 | for batch_data in data_loader: 337 | # batch_data: (xs, ts, ss, ns, ilens) 338 | for chunk_data in list(zip(*batch_data)): 339 | # chunk_data: (x, t, s, n, ilen) 340 | Y_chunked = torch.from_numpy(chunk_data[0]).to('cuda') 341 | t_chunked = torch.from_numpy(chunk_data[1]).to('cuda') 342 | 343 | outputs = net.batch_estimate_with_perm( 344 | torch.unsqueeze(Y_chunked, 0), 345 | torch.unsqueeze(t_chunked, 0)) 346 | sigma = outputs[args.num_speakers+1][0] 347 | t_chunked_t = t_chunked.transpose(1, 0) 348 | 349 | for i in range(args.num_speakers): 350 | # Exclude samples corresponding to silent speaker 351 | if torch.sum(t_chunked_t[sigma[i]]) > 0: 352 | vec = outputs[i+1][0].cpu().detach().numpy() 353 | lab = chunk_data[2][sigma[i]] 354 | all_outputs.append(vec) 355 | all_labels.append(lab) 356 | 357 | orgdata_all_n_speakers = data_set.get_allnspk() 358 | # Generate spkidx_tbl to convert speaker ID 359 | spkidx_tbl = np.array([-1 for i in range(orgdata_all_n_speakers)]) 360 | for i, idx in enumerate(list(set(all_labels))): 361 | spkidx_tbl[idx] = i 362 | # In this line, if speaker_tbl[_idx] == -1, the speaker whose 363 | # original speaker ID is _idx is excluded for training 364 | 365 | print("number of speakers in the original data: {}" 366 | .format(orgdata_all_n_speakers)) 367 | print("number of speakers in the filtered data: {}" 368 | .format(len(set(all_labels)))) 369 | 370 | emb_npz_path = args.out_dir + '/spkvec_lab' 371 | np.savez(emb_npz_path, 372 | np.array(all_outputs), 373 | np.array(all_labels), 374 | spkidx_tbl) 375 | print("Saved {}".format(emb_npz_path + '.npz')) 376 | -------------------------------------------------------------------------------- /eend/pytorch_backend/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 2 | # All rights reserved 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | from collections import OrderedDict 10 | from itertools import permutations 11 | from eend.pytorch_backend.transformer import TransformerEncoder 12 | import padertorch as pt 13 | 14 | """ 15 | P: number of permutation 16 | T: number of frames 17 | C: number of speakers (classes) 18 | B: mini-batch size 19 | """ 20 | 21 | 22 | def pit_loss(pred, label): 23 | """ 24 | Permutation-invariant training (PIT) cross entropy loss function. 
25 | 26 | Args: 27 | pred: (T,C)-shaped pre-activation values 28 | label: (T,C)-shaped labels in {0,1} 29 | 30 | Returns: 31 | min_loss: (1,)-shape mean cross entropy 32 | label_perms[min_index]: permutated labels 33 | sigma: permutation 34 | """ 35 | 36 | device = pred.device 37 | T = len(label) 38 | C = label.shape[-1] 39 | label_perms_indices = [ 40 | list(p) for p in permutations(range(C))] 41 | P = len(label_perms_indices) 42 | perm_mat = torch.zeros(P, T, C, C).to(device) 43 | 44 | for i, p in enumerate(label_perms_indices): 45 | perm_mat[i, :, torch.arange(label.shape[-1]), p] = 1 46 | 47 | x = torch.unsqueeze(torch.unsqueeze(label, 0), -1).to(device) 48 | y = torch.arange(P * T * C).view(P, T, C, 1).to(device) 49 | 50 | broadcast_label = torch.broadcast_tensors(x, y)[0] 51 | allperm_label = torch.matmul( 52 | perm_mat, broadcast_label 53 | ).squeeze(-1) 54 | 55 | x = torch.unsqueeze(pred, 0) 56 | y = torch.arange(P * T).view(P, T, 1) 57 | broadcast_pred = torch.broadcast_tensors(x, y)[0] 58 | 59 | # broadcast_pred: (P, T, C) 60 | # allperm_label: (P, T, C) 61 | losses = F.binary_cross_entropy_with_logits( 62 | broadcast_pred, 63 | allperm_label, 64 | reduction='none') 65 | mean_losses = torch.mean(torch.mean(losses, dim=1), dim=1) 66 | min_loss = torch.min(mean_losses) * len(label) 67 | min_index = torch.argmin(mean_losses) 68 | sigma = list(permutations(range(label.shape[-1])))[min_index] 69 | 70 | return min_loss, allperm_label[min_index], sigma 71 | 72 | 73 | def batch_pit_loss(ys, ts, ilens=None): 74 | """ 75 | PIT loss over mini-batch. 76 | 77 | Args: 78 | ys: B-length list of predictions 79 | ts: B-length list of labels 80 | 81 | Returns: 82 | loss: (1,)-shape mean cross entropy over mini-batch 83 | sigmas: B-length list of permutation 84 | """ 85 | if ilens is None: 86 | ilens = [t.shape[0] for t in ts] 87 | 88 | loss_w_labels_w_sigmas = [pit_loss(y[:ilen, :], t[:ilen, :]) 89 | for (y, t, ilen) in zip(ys, ts, ilens)] 90 | losses, _, sigmas = zip(*loss_w_labels_w_sigmas) 91 | loss = torch.sum(torch.stack(losses)) 92 | n_frames = np.sum([ilen for ilen in ilens]) 93 | loss = loss / n_frames 94 | 95 | return loss, sigmas 96 | 97 | 98 | def fix_state_dict(state_dict): 99 | new_state_dict = OrderedDict() 100 | for k, v in state_dict.items(): 101 | if k.startswith('module.'): 102 | # remove 'module.' of DataParallel 103 | k = k[7:] 104 | if k.startswith('net.'): 105 | # remove 'net.' 
of PadertorchModel 106 | k = k[4:] 107 | new_state_dict[k] = v 108 | return new_state_dict 109 | 110 | 111 | class PadertorchModel(pt.base.Model): 112 | def __init__(self, net): 113 | super().__init__() 114 | self.net = net 115 | 116 | def forward(self, inputs): 117 | xs = inputs[0] 118 | ys, spksvecs = self.net(xs) 119 | 120 | return dict(prediction=ys, spksvecs=spksvecs) 121 | 122 | def review(self, inputs, outputs): 123 | ys = outputs["prediction"] 124 | spksvecs = outputs["spksvecs"] 125 | spksvecs = list(zip(*spksvecs)) 126 | ts = inputs[1] 127 | ss = inputs[2] 128 | ns = inputs[3] 129 | ilens = inputs[4] 130 | ilens = [ilen.item() for ilen in ilens] 131 | 132 | pit_loss, sigmas = batch_pit_loss(ys, ts, ilens) 133 | ss = [[i.item() for i in s] for s in ss] 134 | if pit_loss.requires_grad: 135 | spk_loss = self.batch_spk_loss( 136 | spksvecs, ys, ts, ss, sigmas, ns, ilens) 137 | else: 138 | spk_loss = torch.tensor(0.0).to(pit_loss.device) 139 | 140 | alpha = torch.clamp(self.net.alpha, min=sys.float_info.epsilon) 141 | 142 | return pt.summary.review_dict( 143 | losses={'pit_loss': pit_loss, 'spk_loss': spk_loss}, 144 | scalars={'alpha': alpha}) 145 | 146 | def batch_spk_loss(self, spksvecs, ys, ts, ss, sigmas, ns, ilens): 147 | spksvecs = [[spkvec[:ilen] for spkvec in spksvec] 148 | for spksvec, ilen in zip(spksvecs, ilens)] 149 | loss = torch.stack( 150 | [self.spk_loss(spksvec, y[:ilen], t[:ilen], s, sigma, n) 151 | for(spksvec, y, t, s, sigma, n, ilen) 152 | in zip(spksvecs, ys, ts, ss, sigmas, ns, ilens)]) 153 | loss = torch.mean(loss) 154 | 155 | return loss 156 | 157 | def spk_loss(self, spksvec, y, t, s, sigma, n): 158 | embeds = self.net.embed(n).squeeze() 159 | z = torch.sigmoid(y.transpose(1, 0)) 160 | 161 | losses = [] 162 | for spkid, spkvec in enumerate(spksvec): 163 | norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1) 164 | # Normalize speaker vectors before weighted average 165 | spkvec = torch.mul( 166 | spkvec.transpose(1, 0), norm_spkvec_inv).transpose(1, 0) 167 | wavg_spkvec = torch.mul( 168 | spkvec.transpose(1, 0), z[spkid]).transpose(1, 0) 169 | sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0) 170 | nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec) 171 | nmz_wavg_spkvec = torch.unsqueeze(nmz_wavg_spkvec, 0) 172 | norm_embeds_inv = 1.0 / torch.norm(embeds, dim=1) 173 | embeds = torch.mul( 174 | embeds.transpose(1, 0), norm_embeds_inv).transpose(1, 0) 175 | dist = torch.cdist(nmz_wavg_spkvec, embeds)[0] 176 | d = torch.add( 177 | torch.clamp( 178 | self.net.alpha, 179 | min=sys.float_info.epsilon) * torch.pow(dist, 2), 180 | self.net.beta) 181 | 182 | round_t = torch.round(t.transpose(1, 0)[sigma[spkid]]) 183 | if torch.sum(round_t) > 0: 184 | loss = -F.log_softmax(-d, 0)[s[sigma[spkid]]] 185 | else: 186 | loss = torch.tensor(0.0).to(y.device) 187 | losses.append(loss) 188 | 189 | return torch.mean(torch.stack(losses)) 190 | 191 | 192 | class TransformerDiarization(nn.Module): 193 | def __init__(self, 194 | n_speakers, 195 | in_size, 196 | n_units, 197 | n_heads, 198 | n_layers, 199 | dropout_rate, 200 | all_n_speakers, 201 | d 202 | ): 203 | super(TransformerDiarization, self).__init__() 204 | self.enc = TransformerEncoder( 205 | in_size, n_layers, n_units, h=n_heads, dropout_rate=dropout_rate) 206 | self.linear = nn.Linear(n_units, n_speakers) 207 | 208 | for i in range(n_speakers): 209 | setattr(self, '{}{:d}'.format("linear", i), nn.Linear(n_units, d)) 210 | 211 | self.n_speakers = n_speakers 212 | self.embed = nn.Embedding(all_n_speakers, d) 213 | 
self.alpha = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0]) 214 | self.beta = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0]) 215 | 216 | def modfy_emb(self, weight): 217 | self.embed = nn.Embedding.from_pretrained(weight) 218 | 219 | def forward(self, xs): 220 | # Since xs is pre-padded, the following code is extra, 221 | # but necessary for reproducibility 222 | xs = nn.utils.rnn.pad_sequence(xs, padding_value=-1, batch_first=True) 223 | pad_shape = xs.shape 224 | emb = self.enc(xs) 225 | ys = self.linear(emb) 226 | ys = ys.reshape(pad_shape[0], pad_shape[1], -1) 227 | 228 | spksvecs = [] 229 | for i in range(self.n_speakers): 230 | spkivecs = getattr(self, '{}{:d}'.format("linear", i))(emb) 231 | spkivecs = spkivecs.reshape(pad_shape[0], pad_shape[1], -1) 232 | spksvecs.append(spkivecs) 233 | 234 | return ys, spksvecs 235 | 236 | def batch_estimate(self, xs): 237 | out = self(xs) 238 | ys = out[0] 239 | spksvecs = out[1] 240 | spksvecs = list(zip(*spksvecs)) 241 | outputs = [ 242 | self.estimate(spksvec, y) 243 | for (spksvec, y) in zip(spksvecs, ys)] 244 | outputs = list(zip(*outputs)) 245 | 246 | return outputs 247 | 248 | def batch_estimate_with_perm(self, xs, ts, ilens=None): 249 | out = self(xs) 250 | ys = out[0] 251 | if ts[0].shape[1] > ys[0].shape[1]: 252 | # e.g. the case of training 3-spk model with 4-spk data 253 | add_dim = ts[0].shape[1] - ys[0].shape[1] 254 | y_device = ys[0].device 255 | zeros = [torch.zeros(ts[0].shape).to(y_device) 256 | for i in range(len(ts))] 257 | _ys = [] 258 | for zero, y in zip(zeros, ys): 259 | _zero = zero 260 | _zero[:, :-add_dim] = y 261 | _ys.append(_zero) 262 | _, sigmas = batch_pit_loss(_ys, ts, ilens) 263 | else: 264 | _, sigmas = batch_pit_loss(ys, ts, ilens) 265 | spksvecs = out[1] 266 | spksvecs = list(zip(*spksvecs)) 267 | outputs = [self.estimate(spksvec, y) 268 | for (spksvec, y) in zip(spksvecs, ys)] 269 | outputs = list(zip(*outputs)) 270 | zs = outputs[0] 271 | 272 | if ts[0].shape[1] > ys[0].shape[1]: 273 | # e.g. the case of training 3-spk model with 4-spk data 274 | add_dim = ts[0].shape[1] - ys[0].shape[1] 275 | z_device = zs[0].device 276 | zeros = [torch.zeros(ts[0].shape).to(z_device) 277 | for i in range(len(ts))] 278 | _zs = [] 279 | for zero, z in zip(zeros, zs): 280 | _zero = zero 281 | _zero[:, :-add_dim] = z 282 | _zs.append(_zero) 283 | zs = _zs 284 | outputs[0] = zs 285 | outputs.append(sigmas) 286 | 287 | # outputs: [zs, nmz_wavg_spk0vecs, nmz_wavg_spk1vecs, ..., sigmas] 288 | return outputs 289 | 290 | def estimate(self, spksvec, y): 291 | outputs = [] 292 | z = torch.sigmoid(y.transpose(1, 0)) 293 | 294 | outputs.append(z.transpose(1, 0)) 295 | for spkid, spkvec in enumerate(spksvec): 296 | norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1) 297 | # Normalize speaker vectors before weighted average 298 | spkvec = torch.mul( 299 | spkvec.transpose(1, 0), norm_spkvec_inv 300 | ).transpose(1, 0) 301 | wavg_spkvec = torch.mul( 302 | spkvec.transpose(1, 0), z[spkid] 303 | ).transpose(1, 0) 304 | sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0) 305 | nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec) 306 | outputs.append(nmz_wavg_spkvec) 307 | 308 | # outputs: [z, nmz_wavg_spk0vec, nmz_wavg_spk1vec, ...] 
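        # z has shape (T, n_speakers) and contains sigmoid speech activities;
        # each nmz_wavg_spkvec is a unit-norm, activity-weighted average of the
        # frame-level speaker vectors for the corresponding output slot.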
309 | return outputs 310 | -------------------------------------------------------------------------------- /eend/pytorch_backend/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 2 | # All rights reserved 3 | 4 | import os 5 | import numpy as np 6 | from functools import partial 7 | import torch 8 | from eend.feature import get_input_dim 9 | from eend.pytorch_backend.models import fix_state_dict 10 | from eend.pytorch_backend.models import PadertorchModel 11 | from eend.pytorch_backend.models import TransformerDiarization 12 | from eend.pytorch_backend.transformer import NoamScheduler 13 | from eend.pytorch_backend.diarization_dataset \ 14 | import DiarizationDatasetFromWave, DiarizationDatasetFromFeat 15 | import padertorch as pt 16 | import padertorch.train.optimizer as pt_opt 17 | 18 | 19 | def collate_fn_ns(batch, n_speakers, spkidx_tbl): 20 | xs, ts, ss, ns, ilens = list(zip(*batch)) 21 | valid_chunk_indices1 = [i for i in range(len(ts)) 22 | if ts[i].shape[1] == n_speakers] 23 | valid_chunk_indices2 = [] 24 | 25 | # n_speakers (rec-data) > n_speakers (model) 26 | invalid_chunk_indices1 = [i for i in range(len(ts)) 27 | if ts[i].shape[1] > n_speakers] 28 | 29 | ts = list(ts) 30 | ss = list(ss) 31 | for i in invalid_chunk_indices1: 32 | s = np.sum(ts[i], axis=0) 33 | cs = ts[i].shape[0] 34 | if len(s[s > 0.5]) <= n_speakers: 35 | # n_speakers (chunk-data) <= n_speakers (model) 36 | # update valid_chunk_indices2 37 | valid_chunk_indices2.append(i) 38 | idx_arr = np.where(s > 0.5)[0] 39 | ts[i] = ts[i][:, idx_arr] 40 | ss[i] = ss[i][idx_arr] 41 | if len(s[s > 0.5]) < n_speakers: 42 | # n_speakers (chunk-data) < n_speakers (model) 43 | # update ts[i] and ss[i] 44 | n_speakers_real = len(s[s > 0.5]) 45 | zeros_ts = np.zeros((cs, n_speakers), dtype=np.float32) 46 | zeros_ts[:, :-(n_speakers-n_speakers_real)] = ts[i] 47 | ts[i] = zeros_ts 48 | mones_ss = -1 * np.ones((n_speakers,), dtype=np.int64) 49 | mones_ss[:-(n_speakers-n_speakers_real)] = ss[i] 50 | ss[i] = mones_ss 51 | else: 52 | # n_speakers (chunk-data) == n_speakers (model) 53 | pass 54 | else: 55 | # n_speakers (chunk-data) > n_speakers (model) 56 | pass 57 | 58 | # valid_chunk_indices: chunk indices using for training 59 | valid_chunk_indices = sorted(valid_chunk_indices1 + valid_chunk_indices2) 60 | 61 | ilens = np.array(ilens) 62 | ilens = ilens[valid_chunk_indices] 63 | ns = np.array(ns)[valid_chunk_indices] 64 | ss = np.array([ss[i] for i in range(len(ss)) 65 | if ts[i].shape[1] == n_speakers]) 66 | xs = [xs[i] for i in range(len(xs)) if ts[i].shape[1] == n_speakers] 67 | ts = [ts[i] for i in range(len(ts)) if ts[i].shape[1] == n_speakers] 68 | xs = np.array([np.pad(x, [(0, np.max(ilens) - len(x)), (0, 0)], 69 | 'constant', constant_values=(-1,)) for x in xs]) 70 | ts = np.array([np.pad(t, [(0, np.max(ilens) - len(t)), (0, 0)], 71 | 'constant', constant_values=(+1,)) for t in ts]) 72 | 73 | if spkidx_tbl is not None: 74 | # Update global speaker ID 75 | all_n_speakers = np.max(spkidx_tbl) + 1 76 | bs = len(ns) 77 | ns = np.array([ 78 | np.arange( 79 | all_n_speakers, 80 | dtype=np.int64 81 | ).reshape(all_n_speakers, 1)] * bs) 82 | ss = np.array([spkidx_tbl[ss[i]] for i in range(len(ss))]) 83 | 84 | return (xs, ts, ss, ns, ilens) 85 | 86 | 87 | def collate_fn(batch): 88 | xs, ts, ss, ns, ilens = list(zip(*batch)) 89 | ilens = np.array(ilens) 90 | xs = np.array([np.pad( 91 | x, [(0, np.max(ilens) - len(x)), (0, 
0)], 92 | 'constant', constant_values=(-1,) 93 | ) for x in xs]) 94 | ts = np.array([np.pad( 95 | t, [(0, np.max(ilens) - len(t)), (0, 0)], 96 | 'constant', constant_values=(+1,) 97 | ) for t in ts]) 98 | ss = np.array(ss) 99 | ns = np.array(ns) 100 | 101 | return (xs, ts, ss, ns, ilens) 102 | 103 | 104 | def train(args): 105 | # Set seed for reproducibility 106 | np.random.seed(args.seed) 107 | torch.manual_seed(args.seed) 108 | torch.backends.cudnn.deterministic = True 109 | os.environ['PYTORCH_SEED'] = str(args.seed) 110 | torch.backends.cudnn.benchmark = False 111 | 112 | # Prepare data 113 | featlab_chunk_indices_path =\ 114 | '{}/data/featlab_chunk_indices.txt'.format(args.model_save_dir) 115 | 116 | featdim = get_input_dim(args.frame_size, 117 | args.context_size, 118 | args.input_transform) 119 | 120 | train_set = DiarizationDatasetFromFeat( 121 | featlab_chunk_indices_path, 122 | featdim, 123 | ) 124 | dev_set = DiarizationDatasetFromWave( 125 | args.valid_data_dir, 126 | chunk_size=args.chunk_size, 127 | context_size=args.context_size, 128 | input_transform=args.input_transform, 129 | frame_size=args.frame_size, 130 | frame_shift=args.frame_shift, 131 | subsampling=args.subsampling, 132 | rate=args.sampling_rate, 133 | n_speakers=args.num_speakers, 134 | ) 135 | 136 | # Prepare model 137 | all_n_speakers = train_set.get_allnspk() 138 | net = TransformerDiarization( 139 | args.num_speakers, 140 | featdim, 141 | n_units=args.hidden_size, 142 | n_heads=args.transformer_encoder_n_heads, 143 | n_layers=args.transformer_encoder_n_layers, 144 | dropout_rate=args.transformer_encoder_dropout, 145 | all_n_speakers=all_n_speakers, 146 | d=args.spkv_dim) 147 | 148 | if args.initmodel: 149 | # adaptation 150 | model_parameter_dict = torch.load(args.initmodel)['model'] 151 | fix_model_parameter_dict = fix_state_dict(model_parameter_dict) 152 | all_n_speakers = fix_model_parameter_dict["embed.weight"].shape[0] 153 | 154 | print("old all_n_speakers : {}".format(all_n_speakers)) 155 | net = TransformerDiarization( 156 | args.num_speakers, 157 | featdim, 158 | n_units=args.hidden_size, 159 | n_heads=args.transformer_encoder_n_heads, 160 | n_layers=args.transformer_encoder_n_layers, 161 | dropout_rate=args.transformer_encoder_dropout, 162 | all_n_speakers=all_n_speakers, 163 | d=args.spkv_dim) 164 | net.load_state_dict(fix_model_parameter_dict) 165 | npz = np.load(args.spkv_lab) 166 | spkvecs = npz['arr_0'] 167 | spklabs = npz['arr_1'] 168 | spkidx_tbl = npz['arr_2'] 169 | 170 | # init 171 | spk_num = len(np.unique(spklabs)) 172 | fet_dim = spkvecs.shape[1] 173 | fet_arr = np.zeros([spk_num, fet_dim]) 174 | 175 | # sum 176 | bs = spklabs.shape[0] 177 | for i in range(bs): 178 | if spkidx_tbl[spklabs[i]] == -1: 179 | raise ValueError(spklabs[i]) 180 | fet_arr[spkidx_tbl[spklabs[i]]] += spkvecs[i] 181 | 182 | # normalize 183 | for spk in range(spk_num): 184 | org = fet_arr[spk] 185 | norm = np.linalg.norm(org, ord=2) 186 | fet_arr[spk] = org / norm 187 | 188 | weight = torch.from_numpy(fet_arr.astype(np.float32)).clone() 189 | print("new all_n_speakers : {}".format(weight.shape[0])) 190 | 191 | print(net) 192 | net.modfy_emb(weight) 193 | print(net) 194 | 195 | device = [device_id for device_id in range(torch.cuda.device_count())] 196 | model = PadertorchModel(net=net) 197 | print('GPU device {} is used'.format(device)) 198 | print('Prepared model.') 199 | 200 | # Setup optimizer 201 | if args.optimizer == 'adam': 202 | optimizer = pt_opt.Adam(lr=args.lr, gradient_clipping=args.gradclip) 203 | elif 
args.optimizer == 'sgd': 204 | optimizer = pt_opt.SGD(lr=args.lr, gradient_clipping=args.gradclip) 205 | elif args.optimizer == 'noam': 206 | optimizer = pt_opt.Adam(lr=args.lr, betas=(0.9, 0.98), eps=1e-9, 207 | gradient_clipping=args.gradclip) 208 | else: 209 | raise ValueError(args.optimizer) 210 | 211 | trainloader = torch.utils.data.DataLoader(train_set, 212 | batch_size=args.batchsize, 213 | shuffle=False, 214 | num_workers=args.num_workers) 215 | 216 | # pit_loss_ratio means diarization loss ratio 217 | pit_loss_ratio = abs(1 - args.spk_loss_ratio) 218 | spk_loss_ratio = args.spk_loss_ratio 219 | virtual_minibatch_size = len(device) * args.batchsize_per_gpu 220 | 221 | trainer = pt.trainer.Trainer( 222 | model, 223 | args.model_save_dir, 224 | optimizer, 225 | stop_trigger=(args.max_epochs, 'epoch'), 226 | summary_trigger=(1, 'iteration'), 227 | virtual_minibatch_size=virtual_minibatch_size, 228 | loss_weights={ 229 | "pit_loss": pit_loss_ratio, 230 | "spk_loss": spk_loss_ratio, 231 | } 232 | ) 233 | 234 | devloader = torch.utils.data.DataLoader(dev_set, batch_size=args.batchsize, 235 | shuffle=False, 236 | num_workers=args.num_workers, 237 | collate_fn=collate_fn) 238 | if args.test_run == 1: 239 | trainer.test_run(trainloader, devloader) 240 | 241 | trainer.register_validation_hook(validation_iterator=devloader, 242 | max_checkpoints=args.max_epochs+1) 243 | 244 | # learning rate scheduler 245 | if args.optimizer == 'noam': 246 | scheduler = NoamScheduler(trainer.optimizer.optimizer, 247 | args.hidden_size, 248 | warmup_steps=args.noam_warmup_steps, 249 | tot_step=len(trainloader), 250 | scale=1.0) 251 | trainer.register_hook( 252 | pt.train.hooks.LRSchedulerHook(scheduler, trigger=(1, 'iteration')) 253 | ) 254 | 255 | trainer.train(trainloader, resume=False, device=device) 256 | print('Finished!') 257 | 258 | 259 | def save_feature(args): 260 | # Set seed for reproducibility 261 | np.random.seed(args.seed) 262 | torch.manual_seed(args.seed) 263 | torch.backends.cudnn.deterministic = True 264 | os.environ['PYTORCH_SEED'] = str(args.seed) 265 | torch.backends.cudnn.benchmark = False 266 | 267 | device = [device_id for device_id in range(torch.cuda.device_count())] 268 | print('GPU device {} is used'.format(device)) 269 | 270 | train_set = DiarizationDatasetFromWave( 271 | args.train_data_dir, 272 | chunk_size=args.chunk_size, 273 | context_size=args.context_size, 274 | input_transform=args.input_transform, 275 | frame_size=args.frame_size, 276 | frame_shift=args.frame_shift, 277 | subsampling=args.subsampling, 278 | rate=args.sampling_rate, 279 | n_speakers=args.num_speakers, 280 | ) 281 | 282 | # Count n_chunks 283 | batchsize = args.batchsize * len(device) * \ 284 | args.batchsize_per_gpu 285 | f = open('{}/batchsize.txt'.format(args.model_save_dir), 'w') 286 | f.write("{}\n".format(batchsize)) 287 | f.close() 288 | trainloader = torch.utils.data.DataLoader( 289 | train_set, batch_size=batchsize, 290 | shuffle=False, num_workers=args.num_workers, 291 | collate_fn=partial( 292 | collate_fn_ns, 293 | n_speakers=args.num_speakers, 294 | spkidx_tbl=None) 295 | ) 296 | n_chunks = len(trainloader) 297 | print("n_chunks : {}".format(n_chunks)) 298 | os.makedirs("{}/data/".format(args.model_save_dir), exist_ok=True) 299 | f = open('{}/data/n_chunks.txt'.format(args.model_save_dir), 'w') 300 | f.write("{}\n".format(n_chunks)) 301 | f.close() 302 | 303 | if n_chunks % args.feature_nj == 0: 304 | max_num_per_dir = n_chunks // args.feature_nj 305 | else: 306 | max_num_per_dir = n_chunks // 
args.feature_nj + 1 307 | print("max_num_per_dir : {}".format(max_num_per_dir)) 308 | 309 | # Save featlab_XXXXXXXX.npy and featlab_chunk_indices.txt 310 | spkidx_tbl = None 311 | if args.initmodel: 312 | # adaptation 313 | npz = np.load(args.spkv_lab) 314 | spkidx_tbl = npz['arr_2'] 315 | 316 | torch.manual_seed(args.seed) 317 | trainloader = torch.utils.data.DataLoader( 318 | train_set, batch_size=batchsize, 319 | shuffle=True, num_workers=args.num_workers, 320 | collate_fn=partial( 321 | collate_fn_ns, 322 | n_speakers=args.num_speakers, 323 | spkidx_tbl=spkidx_tbl) 324 | ) 325 | f = open('{}/data/featlab_chunk_indices.txt'. 326 | format(args.model_save_dir), 'w') 327 | idx = 0 328 | digit_num = len(str(args.feature_nj-1)) 329 | fmt = "{}/data/{:0={}}/featlab_{:0=8}.npy" 330 | for data in trainloader: 331 | dir_num = idx // max_num_per_dir 332 | os.makedirs("{}/data/{:0={}}/". 333 | format(args.model_save_dir, dir_num, digit_num), 334 | exist_ok=True) 335 | output_npy_path = fmt.format(args.model_save_dir, 336 | dir_num, digit_num, idx) 337 | print(output_npy_path) 338 | bs = data[0].shape[0] 339 | cs = data[0].shape[1] 340 | # data0 (feature) 341 | data0 = data[0] 342 | # data1 (reference speech activity) 343 | data1 = data[1] 344 | # data2 (reference speaker ID) 345 | data2 = np.zeros([bs, cs, data[2].shape[1]], dtype=np.float32) 346 | for j in range(bs): 347 | data2[j, :, :] = data[2][j, :] 348 | # data3 (reference number of all speakers) 349 | data3 = np.ones([bs, cs, 1], dtype=np.float32) * len(data[3][0]) 350 | # data4 (real chunk size) 351 | data4 = np.zeros([bs, cs, 1], dtype=np.float32) 352 | for j in range(bs): 353 | data4[j, :, :] = data[4][j] 354 | save_data = np.concatenate((data0, 355 | data1, 356 | data2, 357 | data3, 358 | data4), axis=2) 359 | 360 | np.save(output_npy_path, save_data) 361 | for j in range(save_data.shape[0]): 362 | f.write("{} {}\n".format(output_npy_path, j)) 363 | idx += 1 364 | f.close() 365 | 366 | # Create completion flag 367 | f = open('{}/data/.done'.format(args.model_save_dir), 'w') 368 | f.write("") 369 | f.close() 370 | print('Finished!') 371 | -------------------------------------------------------------------------------- /eend/pytorch_backend/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 
2 | # All rights reserved 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | from torch.optim.lr_scheduler import _LRScheduler 9 | 10 | 11 | class NoamScheduler(_LRScheduler): 12 | """ learning rate scheduler used in the transformer 13 | See https://arxiv.org/pdf/1706.03762.pdf 14 | lrate = d_model**(-0.5) * \ 15 | min(step_num**(-0.5), step_num*warmup_steps**(-1.5)) 16 | Scaling factor is implemented as in 17 | http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer 18 | """ 19 | 20 | def __init__( 21 | self, optimizer, d_model, warmup_steps, tot_step, scale, 22 | last_epoch=-1 23 | ): 24 | self.d_model = d_model 25 | self.warmup_steps = warmup_steps 26 | self.tot_step = tot_step 27 | self.scale = scale 28 | super(NoamScheduler, self).__init__(optimizer, last_epoch) 29 | 30 | def get_lr(self): 31 | self.last_epoch = max(1, self.last_epoch) 32 | step_num = self.last_epoch 33 | val = self.scale * self.d_model ** (-0.5) * \ 34 | min(step_num ** (-0.5), step_num * self.warmup_steps ** (-1.5)) 35 | 36 | return [base_lr / base_lr * val for base_lr in self.base_lrs] 37 | 38 | 39 | class MultiHeadSelfAttention(nn.Module): 40 | """ Multi head "self" attention layer 41 | """ 42 | 43 | def __init__(self, n_units, h=8, dropout_rate=0.1): 44 | super(MultiHeadSelfAttention, self).__init__() 45 | self.linearQ = nn.Linear(n_units, n_units) 46 | self.linearK = nn.Linear(n_units, n_units) 47 | self.linearV = nn.Linear(n_units, n_units) 48 | self.linearO = nn.Linear(n_units, n_units) 49 | self.d_k = n_units // h 50 | self.h = h 51 | self.dropout = nn.Dropout(p=dropout_rate) 52 | # attention for plot 53 | self.att = None 54 | 55 | def forward(self, x, batch_size): 56 | # x: (BT, F) 57 | q = self.linearQ(x).reshape(batch_size, -1, self.h, self.d_k) 58 | k = self.linearK(x).reshape(batch_size, -1, self.h, self.d_k) 59 | v = self.linearV(x).reshape(batch_size, -1, self.h, self.d_k) 60 | 61 | scores = torch.matmul( 62 | q.transpose(1, 2), k.permute(0, 2, 3, 1)) / np.sqrt(self.d_k) 63 | # scores: (B, h, T, T) = (B, h, T, d_k) x (B, h, d_k, T) 64 | self.att = F.softmax(scores, dim=3) 65 | p_att = self.dropout(self.att) 66 | x = torch.matmul(p_att, v.transpose(1, 2)) 67 | x = x.transpose(1, 2).reshape(-1, self.h * self.d_k) 68 | 69 | return self.linearO(x) 70 | 71 | 72 | class PositionwiseFeedForward(nn.Module): 73 | """ Positionwise feed-forward layer 74 | """ 75 | 76 | def __init__(self, n_units, d_units, dropout_rate): 77 | super(PositionwiseFeedForward, self).__init__() 78 | self.linear1 = nn.Linear(n_units, d_units) 79 | self.linear2 = nn.Linear(d_units, n_units) 80 | self.dropout = nn.Dropout(p=dropout_rate) 81 | 82 | def forward(self, x): 83 | return self.linear2(self.dropout(F.relu(self.linear1(x)))) 84 | 85 | 86 | class PositionalEncoding(nn.Module): 87 | """ Positional encoding function 88 | """ 89 | 90 | def __init__(self, n_units, dropout_rate, max_len): 91 | super(PositionalEncoding, self).__init__() 92 | self.dropout = nn.Dropout(p=dropout_rate) 93 | positions = np.arange(0, max_len, dtype='f')[:, None] 94 | dens = np.exp( 95 | np.arange(0, n_units, 2, dtype='f') * -(np.log(10000.) 
/ n_units)) 96 | self.enc = np.zeros((max_len, n_units), dtype='f') 97 | self.enc[:, ::2] = np.sin(positions * dens) 98 | self.enc[:, 1::2] = np.cos(positions * dens) 99 | self.scale = np.sqrt(n_units) 100 | 101 | def forward(self, x): 102 | x = x * self.scale + self.xp.array(self.enc[:, :x.shape[1]]) 103 | return self.dropout(x) 104 | 105 | 106 | class TransformerEncoder(nn.Module): 107 | def __init__(self, idim, n_layers, n_units, 108 | e_units=2048, h=8, dropout_rate=0.1): 109 | super(TransformerEncoder, self).__init__() 110 | self.linear_in = nn.Linear(idim, n_units) 111 | self.lnorm_in = nn.LayerNorm(n_units) 112 | self.pos_enc = PositionalEncoding(n_units, dropout_rate, 5000) 113 | self.n_layers = n_layers 114 | self.dropout = nn.Dropout(p=dropout_rate) 115 | for i in range(n_layers): 116 | setattr(self, '{}{:d}'.format("lnorm1_", i), 117 | nn.LayerNorm(n_units)) 118 | setattr(self, '{}{:d}'.format("self_att_", i), 119 | MultiHeadSelfAttention(n_units, h, dropout_rate)) 120 | setattr(self, '{}{:d}'.format("lnorm2_", i), 121 | nn.LayerNorm(n_units)) 122 | setattr(self, '{}{:d}'.format("ff_", i), 123 | PositionwiseFeedForward(n_units, e_units, dropout_rate)) 124 | self.lnorm_out = nn.LayerNorm(n_units) 125 | 126 | def forward(self, x): 127 | # x: (B, T, F) ... batch, time, (mel)freq 128 | BT_size = x.shape[0] * x.shape[1] 129 | # e: (BT, F) 130 | e = self.linear_in(x.reshape(BT_size, -1)) 131 | # Encoder stack 132 | for i in range(self.n_layers): 133 | # layer normalization 134 | e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e) 135 | # self-attention 136 | s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0]) 137 | # residual 138 | e = e + self.dropout(s) 139 | # layer normalization 140 | e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e) 141 | # positionwise feed-forward 142 | s = getattr(self, '{}{:d}'.format("ff_", i))(e) 143 | # residual 144 | e = e + self.dropout(s) 145 | # final layer normalization 146 | # output: (BT, F) 147 | return self.lnorm_out(e) 148 | -------------------------------------------------------------------------------- /egs/callhome/v1/Log.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttcslab-sp/EEND-vector-clustering/dd1a0cda6b0f61ad5c7eefb47dfc4be67227cb8b/egs/callhome/v1/Log.tar.gz -------------------------------------------------------------------------------- /egs/callhome/v1/RESULT.md: -------------------------------------------------------------------------------- 1 | # Environment 2 | - GPU: GeForce GTX 1070 3 | - CUDA Version: 10.1 4 | 5 | # Main common conditions 6 | - clustering algorithm using for inference: constrained AHC 7 | - chunk size for adaptation and inference: 30 seconds 8 | 9 | # Final DER 10 | 11 | ## Speaker counting: oracle 12 | ``` 13 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spk2/result_th0.5_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 8.08 percent of scored speaker time `(ALL) 14 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spk3/result_th0.4_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 11.27 percent of scored speaker time `(ALL) 15 | 
exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spk4/result_th0.4_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 15.01 percent of scored speaker time `(ALL) 16 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spk5/result_th0.5_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 23.14 percent of scored speaker time `(ALL) 17 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spk6/result_th0.5_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 26.56 percent of scored speaker time `(ALL) 18 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spkall/result_th0.4_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 12.22 percent of scored speaker time `(ALL) 19 | ``` 20 | 21 | ## Speaker counting: estimated 22 | ``` 23 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spk2/result_th0.5_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 7.96 percent of scored speaker time `(ALL) 24 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spk3/result_th0.4_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 11.93 percent of scored speaker time `(ALL) 25 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spk4/result_th0.4_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 16.38 percent of scored speaker time `(ALL) 26 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spk5/result_th0.5_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 21.21 percent of scored speaker time `(ALL) 27 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spk6/result_th0.5_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 23.10 percent of scored speaker time `(ALL) 28 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spkall/result_th0.4_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 12.49 percent of scored speaker time `(ALL) 29 | ``` 30 | -------------------------------------------------------------------------------- /egs/callhome/v1/cmd.sh: -------------------------------------------------------------------------------- 1 | # Modify this file according to a job scheduling system in your cluster. 2 | # For more information about cmd.sh see http://kaldi-asr.org/doc/queue.html. 3 | # 4 | # If you use your local machine, use "run.pl". 
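# (run.pl executes each job directly on the local machine, whereas the
#  queue.pl / slurm.pl settings below submit the same jobs to Grid Engine
#  or SLURM, respectively.)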
5 | export train_cmd="run.pl" 6 | export infer_cmd="run.pl" 7 | export save_spkv_lab_cmd="run.pl" 8 | export simu_cmd="run.pl" 9 | export sad_cmd="run.pl" 10 | 11 | # If you use Grid Engine, use "queue.pl" 12 | # export train_cmd="queue.pl --mem 32G -l 'hostname=c*'" 13 | # export infer_cmd="queue.pl --mem 32G -l 'hostname=c*'" 14 | # export save_spkv_lab_cmd="queue.pl" 15 | # export simu_cmd="queue.pl" 16 | # export sad_cmd="queue.pl" 17 | 18 | # If you use SLURM, use "slurm.pl". 19 | # export train_cmd="slurm.pl" 20 | # export infer_cmd="slurm.pl" 21 | # export save_spkv_lab_cmd="slurm.pl" 22 | # export simu_cmd="slurm.pl" 23 | # export sad_cmd="slurm.pl" 24 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/adapt.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 2 | # All rights reserved 3 | 4 | # adapt options 5 | spk_loss_ratio: 0.01 6 | spkv_dim: 256 7 | max_epochs: 25 8 | input_transform: logmel23_mn 9 | lr: 1e-5 10 | optimizer: adam 11 | num_speakers: 3 12 | gradclip: 5 13 | chunk_size: 300 14 | batchsize: 1 15 | num_workers: 8 16 | hidden_size: 256 17 | context_size: 7 18 | subsampling: 10 19 | frame_size: 200 20 | frame_shift: 80 21 | sampling_rate: 8000 22 | noam_scale: 1.0 23 | noam_warmup_steps: 25000 24 | transformer_encoder_n_heads: 8 25 | transformer_encoder_n_layers: 6 26 | transformer_encoder_dropout: 0.1 27 | seed: 777 28 | feature_nj: 100 29 | batchsize_per_gpu: 8 30 | test_run: 1 31 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/infer_est_nspk0.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 2 | # All rights reserved 3 | 4 | # inference options 5 | est_nspk: 0 6 | num_speakers: 3 7 | spkv_dim: 256 8 | hidden_size: 256 9 | input_transform: logmel23_mn 10 | chunk_size: 300 11 | context_size: 7 12 | subsampling: 10 13 | sampling_rate: 8000 14 | frame_size: 200 15 | frame_shift: 80 16 | transformer_encoder_n_heads: 8 17 | transformer_encoder_n_layers: 6 18 | sil_spk_th: 0.05 19 | ahc_dis_th: 1.0 20 | clink_dis: 1e+4 21 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/infer_est_nspk1.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 2 | # All rights reserved 3 | 4 | # inference options 5 | est_nspk: 1 6 | num_speakers: 3 7 | spkv_dim: 256 8 | hidden_size: 256 9 | input_transform: logmel23_mn 10 | chunk_size: 300 11 | context_size: 7 12 | subsampling: 10 13 | sampling_rate: 8000 14 | frame_size: 200 15 | frame_shift: 80 16 | transformer_encoder_n_heads: 8 17 | transformer_encoder_n_layers: 6 18 | sil_spk_th: 0.05 19 | ahc_dis_th: 1.0 20 | clink_dis: 1e+4 21 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/save_spkv_lab.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 
2 | # All rights reserved 3 | 4 | # options for saving speaker vector with label 5 | num_speakers: 3 6 | spkv_dim: 256 7 | hidden_size: 256 8 | input_transform: logmel23_mn 9 | chunk_size: 300 10 | context_size: 7 11 | subsampling: 10 12 | sampling_rate: 8000 13 | frame_size: 200 14 | frame_shift: 80 15 | transformer_encoder_n_heads: 8 16 | transformer_encoder_n_layers: 6 17 | -------------------------------------------------------------------------------- /egs/callhome/v1/conf/train.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 2 | # All rights reserved 3 | 4 | # training options 5 | spk_loss_ratio: 0.03 6 | spkv_dim: 256 7 | max_epochs: 100 8 | input_transform: logmel23_mn 9 | lr: 0.001 10 | optimizer: noam 11 | num_speakers: 3 12 | gradclip: 5 13 | chunk_size: 150 14 | batchsize: 64 15 | num_workers: 8 16 | hidden_size: 256 17 | context_size: 7 18 | subsampling: 10 19 | frame_size: 200 20 | frame_shift: 80 21 | sampling_rate: 8000 22 | noam_scale: 1.0 23 | noam_warmup_steps: 25000 24 | transformer_encoder_n_heads: 8 25 | transformer_encoder_n_layers: 6 26 | transformer_encoder_dropout: 0.1 27 | seed: 777 28 | feature_nj: 100 29 | batchsize_per_gpu: 16 30 | test_run: 0 31 | -------------------------------------------------------------------------------- /egs/callhome/v1/local: -------------------------------------------------------------------------------- 1 | ../../../tools/EEND/egs/callhome/v1/local -------------------------------------------------------------------------------- /egs/callhome/v1/path.sh: -------------------------------------------------------------------------------- 1 | ../../../tools/EEND/egs/callhome/v1/path.sh -------------------------------------------------------------------------------- /egs/callhome/v1/rttm_results/README: -------------------------------------------------------------------------------- 1 | This directory contains the RTTM files we obtained with the model we used in our paper: 2 | Keisuke Kinoshita, Marc Delcroix, and Naohiro Tawara, "Advances in integration of end-to-end neural and clustering-based diarization for real conversational speech," Proc. Interspeech, 2021 3 | 4 | You can obtain the RTTM files by doing the following. 
5 | $ tar ztf eend-vc_callhome2_hyp_rttm.tar.gz 6 | 7 | It will give you the following files, 8 | a) with oracle number of speakers: 9 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spk2/hyp_0.5_11.rttm 10 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spk3/hyp_0.4_11.rttm 11 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spk4/hyp_0.4_11.rttm 12 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spk5/hyp_0.5_11.rttm 13 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spk6/hyp_0.5_11.rttm 14 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk0/callhome2_spkall/hyp_0.4_11.rttm 15 | 16 | b) with estimated number of speakers: 17 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spk2/hyp_0.5_11.rttm 18 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spk3/hyp_0.4_11.rttm 19 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spk4/hyp_0.4_11.rttm 20 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spk5/hyp_0.5_11.rttm 21 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spk6/hyp_0.5_11.rttm 22 | exp/diarize/scoring/swb_sre_tr_ns3_beta10_100000.swb_sre_cv_ns3_beta10_500.train/avg91-100.callhome1_spkall.adapt.avg21-25.infer_est_nspk1/callhome2_spkall/hyp_0.4_11.rttm 23 | 24 | 25 | Note that there are very minor difference in performance with the results reported in the paper because we reran the experiments on a different machine. 26 | -------------------------------------------------------------------------------- /egs/callhome/v1/rttm_results/eend-vc_callhome2_hyp_rttm.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttcslab-sp/EEND-vector-clustering/dd1a0cda6b0f61ad5c7eefb47dfc4be67227cb8b/egs/callhome/v1/rttm_results/eend-vc_callhome2_hyp_rttm.tar.gz -------------------------------------------------------------------------------- /egs/callhome/v1/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 
4 | # All rights reserved 5 | 6 | stage=0 7 | db_path=/db # need to modify 8 | simu_opts_num_speaker=3 9 | simu_opts_sil_scale=10 10 | simu_opts_num_train=100000 11 | simu_opts_rvb_prob="0.1 --noise_snrs 10:15:20" 12 | 13 | # Base config files for {train,save_spkv_lab,infer}.py 14 | train_config=conf/train.yaml 15 | save_spkv_lab_config=conf/save_spkv_lab.yaml 16 | adapt_config=conf/adapt.yaml 17 | infer0_config=conf/infer_est_nspk0.yaml 18 | infer1_config=conf/infer_est_nspk1.yaml 19 | 20 | # Additional arguments passed to {train,save_spkv_lab,infer}.py 21 | # You need not edit the base config files above 22 | train_args= 23 | save_spkv_lab_args= 24 | adapt_args= 25 | infer_args= 26 | 27 | # Model averaging options 28 | average_start=91 29 | average_end=100 30 | 31 | # Adapted model averaging options 32 | adapt_average_start=21 33 | adapt_average_end=25 34 | 35 | . path.sh 36 | . cmd.sh 37 | . parse_options.sh || exit 38 | 39 | train_set=data/simu/data/swb_sre_tr_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${simu_opts_num_train} 40 | valid_set=data/simu/data/swb_sre_cv_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_500 41 | adapt_set=data/eval/callhome1_spkall 42 | adapt_valid_set=data/eval/callhome2_spk3 43 | test_dsets=(callhome2_spk2 callhome2_spk3 callhome2_spk4 callhome2_spk5 callhome2_spk6 callhome2_spkall) 44 | 45 | if [ $stage -le 0 ]; then 46 | echo -e "==== stage 0: prepare data ====" 47 | [ -L db ] && rm db 48 | ln -s $db_path 49 | [ ! -f musan.tar.gz ] && wget https://www.openslr.org/resources/17/musan.tar.gz 50 | [ ! -d musan ] && tar zvxf musan.tar.gz 51 | callhome_dir=$PWD/db/LDC2001S97 52 | swb2_phase1_train=$PWD/db/LDC98S75 53 | data_root=$PWD/db 54 | musan_root=$PWD/musan 55 | simu_actual_dirs=(\ 56 | $PWD/export/c05/diarization-data 57 | $PWD/export/c08/diarization-data 58 | $PWD/export/c09/diarization-data) 59 | sad_opts="--extra-left-context 79 --extra-right-context 21 --frames-per-chunk 150 --extra-left-context-initial 0 --extra-right-context-final 0 --acwt 0.3 --cmd $sad_cmd" 60 | 61 | # Note: the following options are dummies 62 | # --simu_opts_num_speaker_array[1] "$simu_opts_num_speaker" \ 63 | # --simu_opts_num_speaker_array[2] "$simu_opts_num_speaker" \ 64 | # --simu_opts_num_speaker_array[3] "$simu_opts_num_speaker" \ 65 | # --simu_opts_sil_scale_array[1] "$simu_opts_sil_scale" \ 66 | # --simu_opts_sil_scale_array[2] "$simu_opts_sil_scale" \ 67 | # --simu_opts_sil_scale_array[3] "$simu_opts_sil_scale" \ 68 | 69 | ./run_prepare_shared_eda.sh \ 70 | --callhome_dir "$callhome_dir" \ 71 | --swb2_phase1_train "$swb2_phase1_train" \ 72 | --data_root "$data_root" \ 73 | --musan_root "$musan_root" \ 74 | --simu_actual_dirs[0] "${simu_actual_dirs[0]}" \ 75 | --simu_actual_dirs[1] "${simu_actual_dirs[1]}" \ 76 | --simu_actual_dirs[2] "${simu_actual_dirs[2]}" \ 77 | --sad_opts "$sad_opts" \ 78 | --simu_opts_num_speaker_array[0] "$simu_opts_num_speaker" \ 79 | --simu_opts_num_speaker_array[1] "$simu_opts_num_speaker" \ 80 | --simu_opts_num_speaker_array[2] "$simu_opts_num_speaker" \ 81 | --simu_opts_num_speaker_array[3] "$simu_opts_num_speaker" \ 82 | --simu_opts_sil_scale_array[0] "$simu_opts_sil_scale" \ 83 | --simu_opts_sil_scale_array[1] "$simu_opts_sil_scale" \ 84 | --simu_opts_sil_scale_array[2] "$simu_opts_sil_scale" \ 85 | --simu_opts_sil_scale_array[3] "$simu_opts_sil_scale" \ 86 | --simu_opts_num_train "$simu_opts_num_train" \ 87 | --simu_opts_rvb_prob "$simu_opts_rvb_prob" 88 | 89 | # Remove the following extra data 90 | rm -rf 
data/simu/data/swb_sre_cv_ns3n3n3n3_beta10n10n10n10_500 91 | rm -rf data/simu/data/swb_sre_tr_ns3n3n3n3_beta10n10n10n10_100000 92 | 93 | # Fix callhome1_spkall and callhome2_spkall 94 | for dset in callhome1_spkall callhome2_spkall; do 95 | # Remove the segment with zero duration from callhome1_spkall 96 | perl -p -i -e 's/iait_A_0035072_0035072\n//' data/eval/$dset/utt2spk 97 | perl -p -i -e 's/iait_A_0035072_0035072\n//' data/eval/$dset/segments 98 | # Modify speaker ID 99 | (cat data/eval/$dset/utt2spk \ 100 | | grep "_0"$ \ 101 | || perl -p -i -e 's/$/_0/' data/eval/$dset/utt2spk) > /dev/null 102 | # Update spk2utt and rttm 103 | LC_ALL=C utils/fix_data_dir.sh data/eval/$dset 104 | steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ 105 | data/eval/$dset/utt2spk data/eval/$dset/segments \ 106 | data/eval/$dset/rttm 107 | done 108 | 109 | # Separate callhome1_spkall and callhome2_spkall for each number of speakers 110 | for spk_num in spk2 spk3 spk4 spk5 spk6 spk7; do 111 | for dset in callhome1 callhome2; do 112 | [ "$spk_num" == "spk7" ] && [ "$dset" == "callhome2" ] && continue 113 | if ! validate_data_dir.sh --no-text --no-feats data/${dset}_${spk_num}; then 114 | copy_data_dir.sh data/${dset} data/${dset}_${spk_num} 115 | n_spk=$(echo $spk_num | perl -pe 's/spk//') 116 | echo "n_spk : $n_spk" 117 | # Extract ${n_spk}-speaker recordings in wav.scp 118 | utils/filter_scp.pl <(awk -v n_spk=${n_spk} '{if($2==n_spk) print;}' data/${dset}/reco2num_spk) \ 119 | data/${dset}/wav.scp > data/${dset}_${spk_num}/wav.scp 120 | # Regenerate segments file from fullref.rttm 121 | # $2: recid, $4: start_time, $5: duration, $8: speakerid 122 | awk '{printf "%s_%s_%07d_%07d %s %.2f %.2f\n", \ 123 | $2, $8, $4*100, ($4+$5)*100, $2, $4, $4+$5}' \ 124 | data/callhome/fullref.rttm | sort \ 125 | | grep -v "iait_A_0035072_0035072" \ 126 | > data/${dset}_${spk_num}/segments 127 | utils/fix_data_dir.sh data/${dset}_${spk_num} 128 | # Speaker ID is '[recid]_[speakerid]_0 129 | awk '{split($1,A,"_"); printf "%s %s_%s_0\n", $1, A[1], A[2]}' \ 130 | data/${dset}_${spk_num}/segments > data/${dset}_${spk_num}/utt2spk 131 | LC_ALL=C utils/fix_data_dir.sh data/${dset}_${spk_num} 132 | # Generate rttm files for scoring 133 | steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ 134 | data/${dset}_${spk_num}/utt2spk data/${dset}_${spk_num}/segments \ 135 | data/${dset}_${spk_num}/rttm 136 | utils/data/get_reco2dur.sh data/${dset}_${spk_num} 137 | # Compose data/eval/${dset}_${spk_num} 138 | dst_dset=data/eval/${dset}_${spk_num} 139 | if ! 
validate_data_dir.sh --no-text --no-feats $dst_dset; then 140 | utils/copy_data_dir.sh data/${dset}_${spk_num} $dst_dset 141 | cp data/${dset}_${spk_num}/rttm $dst_dset/rttm 142 | awk -v dstdir=wav/eval/${dset}_${spk_num} '{print $1, dstdir"/"$1".wav"}' \ 143 | data/${dset}_${spk_num}/wav.scp > $dst_dset/wav.scp 144 | mkdir -p wav/eval/${dset}_${spk_num} 145 | wav-copy scp:data/${dset}_${spk_num}/wav.scp scp:$dst_dset/wav.scp 146 | utils/data/get_reco2dur.sh $dst_dset 147 | LC_ALL=C utils/fix_data_dir.sh $dst_dset 148 | fi 149 | fi 150 | done 151 | done 152 | fi 153 | 154 | set -eu 155 | # Parse the config file to set bash variables like: $infer0_frame_shift, $infer1_subsampling 156 | eval `yaml2bash.py --prefix infer0 $infer0_config` 157 | eval `yaml2bash.py --prefix infer1 $infer1_config` 158 | 159 | # Append gpu reservation flag to the queuing command 160 | train_cmd+=" --gpu 1" 161 | save_spkv_lab_cmd+=" --gpu 1" 162 | infer_cmd+=" --gpu 1" 163 | 164 | # Build directry names for an experiment 165 | # - Training 166 | # exp/diarize/model/${train_id}.${valid_id}.${train_config_id} 167 | # - Adapation from non-adapted averaged model 168 | # exp/diarize/model/${train_id}.${valid_id}.${train_config_id}/${ave_id}.${adapt_id}.${adapt_config_id} 169 | # - Inference 170 | # exp/diarize/infer/${train_id}.${valid_id}.${train_config_id}/${ave_id}.${adapt_id}.${adapt_config_id}.${adapt_ave_id}.${infer0_config_id} 171 | # exp/diarize/infer/${train_id}.${valid_id}.${train_config_id}/${ave_id}.${adapt_id}.${adapt_config_id}.${adapt_ave_id}.${infer1_config_id} 172 | # - Scoring 173 | # exp/diarize/scoring/${train_id}.${valid_id}.${train_config_id}/${ave_id}.${adapt_id}.${adapt_config_id}.${adapt_ave_id}.${infer0_config_id} 174 | # exp/diarize/scoring/${train_id}.${valid_id}.${train_config_id}/${ave_id}.${adapt_id}.${adapt_config_id}.${adapt_ave_id}.${infer1_config_id} 175 | train_id=$(basename $train_set) 176 | valid_id=$(basename $valid_set) 177 | train_config_id=$(echo $train_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 178 | save_spkv_lab_config_id=$(echo $save_spkv_lab_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 179 | adapt_config_id=$(echo $adapt_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 180 | infer0_config_id=$(echo $infer0_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 181 | infer1_config_id=$(echo $infer1_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 182 | 183 | # Additional arguments are added to config_id 184 | train_config_id+=$(echo $train_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 185 | save_spkv_lab_config_id+=$(echo $save_spkv_lab_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 186 | adapt_config_id+=$(echo $adapt_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 187 | infer0_config_id+=$(echo $infer_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 188 | infer1_config_id+=$(echo $infer_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 189 | 190 | model_id=$train_id.$valid_id.$train_config_id 191 | model_dir=exp/diarize/model/$model_id 192 | 193 | if [ $stage -le 1 ]; then 194 | echo -e "\n==== stage 1: training model with simulated mixtures ====" 195 | # To speed up the training process, we first calculate input features 196 | # to NN and save shuffled feature data to the disc. During training, 197 | # we simply read the saved data from the disc. 
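    # (These features are dumped by save_feature() in eend/pytorch_backend/train.py,
    #  which also writes the featlab_chunk_indices.txt list that the training
    #  loop later reads from the model directory's data/ subfolder.)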
198 | # Note: shuffled feature data (default total size: 336GB) are saved at the following place 199 | # exp/diarize/model/${train_id}.${valid_id}.${train_config_id}/data/ 200 | echo "training model at $model_dir" 201 | if [ -d $model_dir/checkpoints ]; then 202 | echo "$model_dir/checkpoints already exists. " 203 | echo " if you want to retry, please remove it." 204 | exit 1 205 | fi 206 | work=$model_dir/.work 207 | mkdir -p $work 208 | $train_cmd $work/train.log \ 209 | train.py \ 210 | -c $train_config \ 211 | $train_args \ 212 | $train_set $valid_set $model_dir \ 213 | || exit 1 214 | fi 215 | 216 | ave_id=avg${average_start}-${average_end} 217 | if [ $stage -le 2 ]; then 218 | echo -e "\n==== stage 2: averaging trained models ====" 219 | echo "averaging model parameters into $model_dir/checkpoints/$ave_id.nnet.pth" 220 | if [ -s $model_dir/checkpoints/$ave_id.nnet.pth ]; then 221 | echo "$model_dir/checkpoints/$ave_id.nnet.pth already exists. " 222 | fi 223 | last_epoch=$(ls $model_dir/checkpoints/ckpt_[0-9]*.pth | grep -v "/ckpt_0.pth"$ | wc -l) 224 | echo -e "last epoch of existence : $last_epoch" 225 | if [ $last_epoch -lt $average_end ]; then 226 | echo -e "error : average_end $average_end is too large." 227 | exit 1 228 | fi 229 | models=$(ls $model_dir/checkpoints/ckpt_[0-9]*.pth -tr | head -n $((${average_end}+1)) | tail -n $((${average_end}-${average_start}+1))) 230 | echo -e "take the average with the following models:" 231 | echo -e $models | tr " " "\n" 232 | model_averaging.py $model_dir/checkpoints/$ave_id.nnet.pth $models || exit 1 233 | fi 234 | 235 | adapt_id=$(basename $adapt_set) 236 | adapt_model_dir=exp/diarize/model/$model_id/$ave_id.$adapt_id.$adapt_config_id 237 | save_spkv_lab_dir=$adapt_model_dir/$save_spkv_lab_config_id 238 | if [ $stage -le 3 ]; then 239 | echo -e "\n==== stage 3: adapting model to CALLHOME dataset ====" 240 | 241 | # stage 3-1: saving speaker vector with label and speaker ID conversion table for initializing embedding dictionary 242 | echo "adapt_set: $(basename $adapt_set)" 243 | work=$save_spkv_lab_dir/.work 244 | mkdir -p $work 245 | $save_spkv_lab_cmd $work/save_spkv_lab.log \ 246 | save_spkv_lab.py \ 247 | -c $save_spkv_lab_config \ 248 | $save_spkv_lab_args \ 249 | $adapt_set \ 250 | $model_dir/checkpoints/$ave_id.nnet.pth \ 251 | $save_spkv_lab_dir \ 252 | || exit 1 253 | echo -e "finished saving speaker vector with label" 254 | 255 | # stage 3-2: adapting model to CALLHOME dataset 256 | echo "adapting model at $adapt_model_dir" 257 | if [ -d $adapt_model_dir/checkpoints ]; then 258 | echo "$adapt_model_dir/checkpoints already exists. " 259 | echo " if you want to retry, please remove it." 260 | exit 1 261 | fi 262 | work=$adapt_model_dir/.work 263 | mkdir -p $work 264 | $train_cmd $work/train.log \ 265 | train.py \ 266 | -c $adapt_config \ 267 | $adapt_args \ 268 | --initmodel $model_dir/checkpoints/$ave_id.nnet.pth \ 269 | --spkv-lab $save_spkv_lab_dir/spkvec_lab.npz \ 270 | $adapt_set $adapt_valid_set $adapt_model_dir \ 271 | || exit 1 272 | fi 273 | 274 | adapt_ave_id=avg${adapt_average_start}-${adapt_average_end} 275 | if [ $stage -le 4 ]; then 276 | echo -e "\n==== stage 4: averaging adapted models ====" 277 | echo "averaging models into $adapt_model_dir/checkpoints/$adapt_ave_id.nnet.pth" 278 | if [ -s $adapt_model_dir/checkpoints/$adapt_ave_id.nnet.pth ]; then 279 | echo "$adapt_model_dir/checkpoints/$adapt_ave_id.nnet.pth already exists." 
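        # (note: only a message is printed here; the average below is
        #  recomputed even if the averaged checkpoint already exists)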
280 | fi 281 | last_epoch=$(ls $adapt_model_dir/checkpoints/ckpt_[0-9]*.pth | grep -v "/ckpt_0.pth"$ | wc -l) 282 | echo -e "last epoch of existence : $last_epoch" 283 | if [ $last_epoch -lt $adapt_average_end ]; then 284 | echo -e "error : adapt_average_end $adapt_average_end is too large." 285 | exit 1 286 | fi 287 | models=$(ls $adapt_model_dir/checkpoints/ckpt_[0-9]*.pth -tr | head -n $((${adapt_average_end}+1)) | tail -n $((${adapt_average_end}-${adapt_average_start}+1))) 288 | echo -e "take the average with the following models:" 289 | echo -e $models | tr " " "\n" 290 | model_averaging.py $adapt_model_dir/checkpoints/$adapt_ave_id.nnet.pth $models || exit 1 291 | fi 292 | 293 | infer_dir=exp/diarize/infer/$model_id/$ave_id.$(basename $adapt_set).$adapt_config_id.$adapt_ave_id.$infer0_config_id 294 | if [ $stage -le 5 ]; then 295 | echo -e "\n==== stage 5: inference for evaluation (speaker counting: oracle) ====" 296 | echo "inference at $infer_dir" 297 | if [ -d $infer_dir ]; then 298 | echo "$infer_dir already exists. " 299 | fi 300 | for dset in ${test_dsets[@]}; do 301 | work=$infer_dir/$dset/.work 302 | mkdir -p $work 303 | $infer_cmd $work/infer.log \ 304 | infer.py -c $infer0_config \ 305 | $infer_args \ 306 | data/eval/${dset} \ 307 | $adapt_model_dir/checkpoints/$adapt_ave_id.nnet.pth \ 308 | $infer_dir/$dset \ 309 | || exit 1 310 | done 311 | fi 312 | 313 | scoring_dir=exp/diarize/scoring/$model_id/$ave_id.$(basename $adapt_set).$adapt_config_id.$adapt_ave_id.$infer0_config_id 314 | if [ $stage -le 6 ]; then 315 | echo -e "\n==== stage 6: scoring for evaluation (speaker counting: oracle) ====" 316 | echo "scoring at $scoring_dir" 317 | if [ -d $scoring_dir ]; then 318 | echo "$scoring_dir already exists. " 319 | fi 320 | for dset in ${test_dsets[@]}; do 321 | work=$scoring_dir/$dset/.work 322 | mkdir -p $work 323 | find $infer_dir/$dset -iname "*.h5" > $work/file_list_$dset 324 | for med in 1 11; do 325 | for th in 0.3 0.4 0.5 0.6 0.7; do 326 | make_rttm.py --median=$med --threshold=$th \ 327 | --frame_shift=$infer0_frame_shift --subsampling=$infer0_subsampling --sampling_rate=$infer0_sampling_rate \ 328 | $work/file_list_$dset $scoring_dir/$dset/hyp_${th}_$med.rttm 329 | md-eval.pl -c 0.25 \ 330 | -r data/eval/$dset/rttm \ 331 | -s $scoring_dir/$dset/hyp_${th}_$med.rttm > $scoring_dir/$dset/result_th${th}_med${med}_collar0.25 2>/dev/null || exit 332 | done 333 | done 334 | best_score.sh $scoring_dir/$dset 335 | done 336 | fi 337 | 338 | infer_dir=exp/diarize/infer/$model_id/$ave_id.$(basename $adapt_set).$adapt_config_id.$adapt_ave_id.$infer1_config_id 339 | if [ $stage -le 7 ]; then 340 | echo -e "\n==== stage 7: inference for evaluation (speaker counting: estimated) ====" 341 | echo "inference at $infer_dir" 342 | if [ -d $infer_dir ]; then 343 | echo "$infer_dir already exists. 
" 344 | fi 345 | for dset in ${test_dsets[@]}; do 346 | work=$infer_dir/$dset/.work 347 | mkdir -p $work 348 | $infer_cmd $work/infer.log \ 349 | infer.py -c $infer1_config \ 350 | $infer_args \ 351 | data/eval/${dset} \ 352 | $adapt_model_dir/checkpoints/$adapt_ave_id.nnet.pth \ 353 | $infer_dir/$dset \ 354 | || exit 1 355 | done 356 | fi 357 | 358 | scoring_dir=exp/diarize/scoring/$model_id/$ave_id.$(basename $adapt_set).$adapt_config_id.$adapt_ave_id.$infer1_config_id 359 | if [ $stage -le 8 ]; then 360 | echo -e "\n==== stage 8: scoring for evaluation (speaker counting: estimated) ====" 361 | echo "scoring at $scoring_dir" 362 | if [ -d $scoring_dir ]; then 363 | echo "$scoring_dir already exists. " 364 | fi 365 | for dset in ${test_dsets[@]}; do 366 | work=$scoring_dir/$dset/.work 367 | mkdir -p $work 368 | find $infer_dir/$dset -iname "*.h5" > $work/file_list_$dset 369 | for med in 1 11; do 370 | for th in 0.3 0.4 0.5 0.6 0.7; do 371 | make_rttm.py --median=$med --threshold=$th \ 372 | --frame_shift=$infer1_frame_shift --subsampling=$infer1_subsampling --sampling_rate=$infer1_sampling_rate \ 373 | $work/file_list_$dset $scoring_dir/$dset/hyp_${th}_$med.rttm 374 | md-eval.pl -c 0.25 \ 375 | -r data/eval/$dset/rttm \ 376 | -s $scoring_dir/$dset/hyp_${th}_$med.rttm > $scoring_dir/$dset/result_th${th}_med${med}_collar0.25 2>/dev/null || exit 377 | done 378 | done 379 | best_score.sh $scoring_dir/$dset 380 | done 381 | fi 382 | 383 | echo "Finished !" 384 | -------------------------------------------------------------------------------- /egs/callhome/v1/run_prepare_shared_eda.sh: -------------------------------------------------------------------------------- 1 | ../../../tools/EEND/egs/callhome/v1/run_prepare_shared_eda.sh -------------------------------------------------------------------------------- /egs/callhome/v1/steps: -------------------------------------------------------------------------------- 1 | ../../../tools/kaldi/egs/wsj/s5/steps -------------------------------------------------------------------------------- /egs/callhome/v1/utils: -------------------------------------------------------------------------------- 1 | ../../../tools/kaldi/egs/wsj/s5/utils -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/RESULT.md: -------------------------------------------------------------------------------- 1 | # Environment 2 | - GPU: GeForce GTX 1070 3 | - CUDA Version: 10.1 4 | 5 | # Main common conditions 6 | - clustering algorithm using for inference: constrained AHC 7 | - chunk size for inference: 30 seconds 8 | 9 | # Training curve (validation loss) 10 | Note that validation loss := (1 - spk_loss_ratio) * pit_loss 11 | 12 | ``` 13 | grep "Mean loss" exp/diarize/model/train_clean_5_ns3_beta2_500.dev_clean_2_ns3_beta2_500.train/.work/train.log 14 | Finished Validation. Mean loss: 0.6628542523635061 15 | Finished Validation. Mean loss: 0.6607387272935165 16 | Finished Validation. Mean loss: 0.6548653715535214 17 | Finished Validation. Mean loss: 0.6455376411739149 18 | Finished Validation. Mean loss: 0.6333835890418604 19 | Finished Validation. Mean loss: 0.6193064156331514 20 | Finished Validation. Mean loss: 0.6045306845715172 21 | Finished Validation. Mean loss: 0.5902855377448233 22 | Finished Validation. Mean loss: 0.5776547281365646 23 | Finished Validation. Mean loss: 0.5674387065987838 24 | Finished Validation. 
Mean loss: 0.5597694854987295 25 | ``` 26 | 27 | # Final DER 28 | 29 | ## Speaker counting: oracle 30 | ``` 31 | exp/diarize/scoring/train_clean_5_ns3_beta2_500.dev_clean_2_ns3_beta2_500.train.avg8-10.infer_est_nspk0/dev_clean_2_ns3_beta2_500/result_th0.5_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 36.70 percent of scored speaker time `(ALL) 32 | ``` 33 | 34 | ## Speaker counting: estimated 35 | ``` 36 | exp/diarize/scoring/train_clean_5_ns3_beta2_500.dev_clean_2_ns3_beta2_500.train.avg8-10.infer_est_nspk1/dev_clean_2_ns3_beta2_500/result_th0.5_med11_collar0.25: OVERALL SPEAKER DIARIZATION ERROR = 41.48 percent of scored speaker time `(ALL) 37 | ``` 38 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/cmd.sh: -------------------------------------------------------------------------------- 1 | # Modify this file according to a job scheduling system in your cluster. 2 | # For more information about cmd.sh see http://kaldi-asr.org/doc/queue.html. 3 | # 4 | # If you use your local machine, use "run.pl". 5 | export train_cmd="run.pl" 6 | export infer_cmd="run.pl" 7 | export simu_cmd="run.pl" 8 | 9 | # If you use Grid Engine, use "queue.pl" 10 | # export train_cmd="queue.pl --mem 16G -l 'hostname=c*'" 11 | # export infer_cmd="queue.pl --mem 16G -l 'hostname=c*'" 12 | # export simu_cmd="queue.pl" 13 | 14 | # If you use SLURM, use "slurm.pl". 15 | # export train_cmd="slurm.pl" 16 | # export infer_cmd="slurm.pl" 17 | # export simu_cmd="slurm.pl" 18 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/infer_est_nspk0.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 2 | # All rights reserved 3 | 4 | # inference options 5 | est_nspk: 0 6 | num_speakers: 3 7 | spkv_dim: 256 8 | hidden_size: 256 9 | input_transform: logmel23_mn 10 | chunk_size: 300 11 | context_size: 7 12 | subsampling: 10 13 | sampling_rate: 8000 14 | frame_size: 200 15 | frame_shift: 80 16 | transformer_encoder_n_heads: 8 17 | transformer_encoder_n_layers: 6 18 | sil_spk_th: 0.05 19 | ahc_dis_th: 1.0 20 | clink_dis: 1e+4 21 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/infer_est_nspk1.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 2 | # All rights reserved 3 | 4 | # inference options 5 | est_nspk: 1 6 | num_speakers: 3 7 | spkv_dim: 256 8 | hidden_size: 256 9 | input_transform: logmel23_mn 10 | chunk_size: 300 11 | context_size: 7 12 | subsampling: 10 13 | sampling_rate: 8000 14 | frame_size: 200 15 | frame_shift: 80 16 | transformer_encoder_n_heads: 8 17 | transformer_encoder_n_layers: 6 18 | sil_spk_th: 0.05 19 | ahc_dis_th: 1.0 20 | clink_dis: 1e+4 21 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/conf/train.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 
2 | # All rights reserved 3 | 4 | # training options 5 | spk_loss_ratio: 0.03 6 | spkv_dim: 256 7 | max_epochs: 10 8 | input_transform: logmel23_mn 9 | lr: 0.001 10 | optimizer: noam 11 | num_speakers: 3 12 | gradclip: 5 13 | chunk_size: 150 14 | batchsize: 64 15 | num_workers: 8 16 | hidden_size: 256 17 | context_size: 7 18 | subsampling: 10 19 | frame_size: 200 20 | frame_shift: 80 21 | sampling_rate: 8000 22 | noam_scale: 1.0 23 | noam_warmup_steps: 25000 24 | transformer_encoder_n_heads: 8 25 | transformer_encoder_n_layers: 6 26 | transformer_encoder_dropout: 0.1 27 | seed: 777 28 | feature_nj: 100 29 | batchsize_per_gpu: 16 30 | test_run: 0 31 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/local: -------------------------------------------------------------------------------- 1 | ../../../tools/EEND/egs/mini_librispeech/v1/local -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/musan_bgnoise.tar.gz: -------------------------------------------------------------------------------- 1 | ../../../tools/EEND/egs/mini_librispeech/v1/musan_bgnoise.tar.gz -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/path.sh: -------------------------------------------------------------------------------- 1 | ../../../tools/EEND/egs/mini_librispeech/v1/path.sh -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 4 | # All rights reserved 5 | 6 | stage=0 7 | 8 | # Base config files for {train,infer}.py 9 | train_config=conf/train.yaml 10 | infer0_config=conf/infer_est_nspk0.yaml 11 | infer1_config=conf/infer_est_nspk1.yaml 12 | 13 | # Additional arguments passed to {train,infer}.py 14 | # You need not edit the base config files above 15 | train_args= 16 | infer_args= 17 | 18 | # Model averaging options 19 | average_start=8 20 | average_end=10 21 | 22 | . path.sh 23 | . cmd.sh 24 | . 
parse_options.sh || exit 25 | 26 | train_set=data/simu/data/train_clean_5_ns3_beta2_500 27 | valid_set=data/simu/data/dev_clean_2_ns3_beta2_500 28 | 29 | if [ $stage -le 0 ]; then 30 | echo -e "==== stage 0: prepare data ====" 31 | ./run_prepare_shared.sh --simu_opts_num_speaker 3 32 | # Note that for simplicity we generate data/simu/data/*/utt2spk by using local/data_prep.sh, 33 | # then speaker ID is regarded as [reader]-[chapter] 34 | fi 35 | 36 | set -eu 37 | # Parse the config file to set bash variables like: $infer0_frame_shift, $infer1_subsampling 38 | eval `yaml2bash.py --prefix infer0 $infer0_config` 39 | eval `yaml2bash.py --prefix infer1 $infer1_config` 40 | 41 | # Append gpu reservation flag to the queuing command 42 | train_cmd+=" --gpu 1" 43 | infer_cmd+=" --gpu 1" 44 | 45 | # Build directry names for an experiment 46 | # - Training 47 | # exp/diarize/model/${train_id}.${valid_id}.${train_config_id} 48 | # - Inference 49 | # exp/diarize/infer/${train_id}.${valid_id}.${train_config_id}.${infer0_config_id} 50 | # exp/diarize/infer/${train_id}.${valid_id}.${train_config_id}.${infer1_config_id} 51 | # - Scoring 52 | # exp/diarize/scoring/${train_id}.${valid_id}.${train_config_id}.${infer0_config_id} 53 | # exp/diarize/scoring/${train_id}.${valid_id}.${train_config_id}.${infer1_config_id} 54 | train_id=$(basename $train_set) 55 | valid_id=$(basename $valid_set) 56 | train_config_id=$(echo $train_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 57 | infer0_config_id=$(echo $infer0_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 58 | infer1_config_id=$(echo $infer1_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') 59 | 60 | # Additional arguments are added to config_id 61 | train_config_id+=$(echo $train_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 62 | infer0_config_id+=$(echo $infer_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 63 | infer1_config_id+=$(echo $infer_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') 64 | 65 | model_id=$train_id.$valid_id.$train_config_id 66 | model_dir=exp/diarize/model/$model_id 67 | 68 | if [ $stage -le 1 ]; then 69 | echo -e "\n==== stage 1: training model with simulated mixtures ====" 70 | # To speed up the training process, we first calculate input features 71 | # to NN and save shuffled feature data to the disc. During training, 72 | # we simply read the saved data from the disc. 73 | echo "training model at $model_dir" 74 | if [ -d $model_dir/checkpoints ]; then 75 | echo "$model_dir/checkpoints already exists. " 76 | echo " if you want to retry, please remove it." 77 | exit 1 78 | fi 79 | work=$model_dir/.work 80 | mkdir -p $work 81 | $train_cmd $work/train.log \ 82 | train.py \ 83 | -c $train_config \ 84 | $train_args \ 85 | $train_set $valid_set $model_dir \ 86 | || exit 1 87 | fi 88 | 89 | ave_id=avg${average_start}-${average_end} 90 | if [ $stage -le 2 ]; then 91 | echo -e "\n==== stage 2: averaging trained models ====" 92 | echo "averaging model parameters into $model_dir/checkpoints/$ave_id.nnet.pth" 93 | if [ -s $model_dir/checkpoints/$ave_id.nnet.pth ]; then 94 | echo "$model_dir/checkpoints/$ave_id.nnet.pth already exists. " 95 | fi 96 | last_epoch=$(ls $model_dir/checkpoints/ckpt_[0-9]*.pth | grep -v "/ckpt_0.pth"$ | wc -l) 97 | echo -e "last epoch of existence : $last_epoch" 98 | if [ $last_epoch -lt $average_end ]; then 99 | echo -e "error : average_end $average_end is too large." 
100 | exit 1 101 | fi 102 | models=$(ls $model_dir/checkpoints/ckpt_[0-9]*.pth -tr | head -n $((${average_end}+1)) | tail -n $((${average_end}-${average_start}+1))) 103 | echo -e "take the average over the following models:" 104 | echo -e $models | tr " " "\n" 105 | model_averaging.py $model_dir/checkpoints/$ave_id.nnet.pth $models || exit 1 106 | fi 107 | 108 | infer_dir=exp/diarize/infer/$model_id.$ave_id.$infer0_config_id 109 | if [ $stage -le 3 ]; then 110 | echo -e "\n==== stage 3: inference for evaluation (speaker counting: oracle) ====" 111 | echo "inference at $infer_dir" 112 | if [ -d $infer_dir ]; then 113 | echo "$infer_dir already exists. " 114 | fi 115 | for dset in dev_clean_2_ns3_beta2_500; do 116 | work=$infer_dir/$dset/.work 117 | mkdir -p $work 118 | $infer_cmd $work/infer.log \ 119 | infer.py -c $infer0_config \ 120 | $infer_args \ 121 | data/simu/data/${dset} \ 122 | $model_dir/checkpoints/$ave_id.nnet.pth \ 123 | $infer_dir/$dset \ 124 | || exit 1 125 | done 126 | fi 127 | 128 | scoring_dir=exp/diarize/scoring/$model_id.$ave_id.$infer0_config_id 129 | if [ $stage -le 4 ]; then 130 | echo -e "\n==== stage 4: scoring for evaluation (speaker counting: oracle) ====" 131 | echo "scoring at $scoring_dir" 132 | if [ -d $scoring_dir ]; then 133 | echo "$scoring_dir already exists. " 134 | fi 135 | for dset in dev_clean_2_ns3_beta2_500; do 136 | work=$scoring_dir/$dset/.work 137 | mkdir -p $work 138 | find $infer_dir/$dset -iname "*.h5" > $work/file_list_$dset 139 | for med in 1 11; do 140 | for th in 0.3 0.4 0.5 0.6 0.7; do 141 | make_rttm.py --median=$med --threshold=$th \ 142 | --frame_shift=$infer0_frame_shift --subsampling=$infer0_subsampling --sampling_rate=$infer0_sampling_rate \ 143 | $work/file_list_$dset $scoring_dir/$dset/hyp_${th}_$med.rttm 144 | md-eval.pl -c 0.25 \ 145 | -r data/simu/data/$dset/rttm \ 146 | -s $scoring_dir/$dset/hyp_${th}_$med.rttm > $scoring_dir/$dset/result_th${th}_med${med}_collar0.25 2>/dev/null || exit 147 | done 148 | done 149 | best_score.sh $scoring_dir/$dset 150 | done 151 | fi 152 | 153 | infer_dir=exp/diarize/infer/$model_id.$ave_id.$infer1_config_id 154 | if [ $stage -le 5 ]; then 155 | echo -e "\n==== stage 5: inference for evaluation (speaker counting: estimated) ====" 156 | echo "inference at $infer_dir" 157 | if [ -d $infer_dir ]; then 158 | echo "$infer_dir already exists. " 159 | fi 160 | for dset in dev_clean_2_ns3_beta2_500; do 161 | work=$infer_dir/$dset/.work 162 | mkdir -p $work 163 | $infer_cmd $work/infer.log \ 164 | infer.py -c $infer1_config \ 165 | $infer_args \ 166 | data/simu/data/${dset} \ 167 | $model_dir/checkpoints/$ave_id.nnet.pth \ 168 | $infer_dir/$dset \ 169 | || exit 1 170 | done 171 | fi 172 | 173 | scoring_dir=exp/diarize/scoring/$model_id.$ave_id.$infer1_config_id 174 | if [ $stage -le 6 ]; then 175 | echo -e "\n==== stage 6: scoring for evaluation (speaker counting: estimated) ====" 176 | echo "scoring at $scoring_dir" 177 | if [ -d $scoring_dir ]; then 178 | echo "$scoring_dir already exists. 
" 179 | fi 180 | for dset in dev_clean_2_ns3_beta2_500; do 181 | work=$scoring_dir/$dset/.work 182 | mkdir -p $work 183 | find $infer_dir/$dset -iname "*.h5" > $work/file_list_$dset 184 | for med in 1 11; do 185 | for th in 0.3 0.4 0.5 0.6 0.7; do 186 | make_rttm.py --median=$med --threshold=$th \ 187 | --frame_shift=$infer1_frame_shift --subsampling=$infer1_subsampling --sampling_rate=$infer1_sampling_rate \ 188 | $work/file_list_$dset $scoring_dir/$dset/hyp_${th}_$med.rttm 189 | md-eval.pl -c 0.25 \ 190 | -r data/simu/data/$dset/rttm \ 191 | -s $scoring_dir/$dset/hyp_${th}_$med.rttm > $scoring_dir/$dset/result_th${th}_med${med}_collar0.25 2>/dev/null || exit 192 | done 193 | done 194 | best_score.sh $scoring_dir/$dset 195 | done 196 | fi 197 | 198 | echo "Finished!" 199 | -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/run_prepare_shared.sh: -------------------------------------------------------------------------------- 1 | ../../../tools/EEND/egs/mini_librispeech/v1/run_prepare_shared.sh -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/steps: -------------------------------------------------------------------------------- 1 | ../../../tools/kaldi/egs/wsj/s5/steps -------------------------------------------------------------------------------- /egs/mini_librispeech/v1/utils: -------------------------------------------------------------------------------- 1 | ../../../tools/kaldi/egs/wsj/s5/utils -------------------------------------------------------------------------------- /tools/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
2 | # All rights reserved 3 | 4 | # If you want to use a prebuilt kaldi, make KALDI= 5 | KALDI := 6 | # Specify the CUDA root path installed in your environment 7 | CUDA_PATH := /usr/local/cuda 8 | CUDA_VERSION := $(shell $(CUDA_PATH)/bin/nvcc --version | tail -n1 | awk '{print substr($$5,0,length($$5)-1)}') 9 | 10 | # Select the PyTorch install command for the detected CUDA version 11 | ifeq ($(CUDA_VERSION),10.2) 12 | PYTORCH_INSTALL_CMD := miniconda3/envs/eend/bin/pip install torch==1.5.0 torchvision==0.6.0 13 | else ifeq ($(CUDA_VERSION),10.1) 14 | PYTORCH_INSTALL_CMD := miniconda3/envs/eend/bin/pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 15 | else ifeq ($(CUDA_VERSION),9.2) 16 | PYTORCH_INSTALL_CMD := miniconda3/envs/eend/bin/pip install torch==1.5.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html 17 | else 18 | $(info CUDA $(CUDA_VERSION) not supported) 19 | all: 20 | endif 21 | 22 | all: kaldi miniconda3/envs/eend/bin EEND env.sh 23 | 24 | ifneq ($(strip $(KALDI)),) 25 | kaldi: 26 | ln -s $(abspath $(KALDI)) kaldi 27 | else 28 | kaldi: 29 | git clone https://github.com/kaldi-asr/kaldi.git 30 | cd kaldi; git checkout 797905b0140f7169caf3d97c75a1a56a92f67d38 31 | cd kaldi/tools; $(MAKE) 32 | cd kaldi/src; ./configure --shared --use-cuda=no; $(MAKE) depend; $(MAKE) all 33 | endif 34 | 35 | miniconda3.sh: 36 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O $@ 37 | 38 | miniconda3: miniconda3.sh 39 | # -b: non-interactive install 40 | # -p: installation directory 41 | bash miniconda3.sh -b -p miniconda3 42 | 43 | padertorch: 44 | git clone https://github.com/fgnt/padertorch.git 45 | 46 | # Python virtual environment 47 | miniconda3/envs/eend/bin: miniconda3 padertorch 48 | miniconda3/bin/conda update -y conda 49 | miniconda3/bin/conda env create -f environment.yml 50 | miniconda3/envs/eend/bin/python -m pip install --upgrade pip 51 | $(PYTORCH_INSTALL_CMD) 52 | # padertorch 53 | cd padertorch; git checkout fa6df6a0eae4fd4455c4d82644c2e082ac28ecb2 54 | cd padertorch; ../miniconda3/envs/eend/bin/python setup.py bdist_wheel 55 | miniconda3/envs/eend/bin/pip install padertorch/dist/padertorch-0.0.0-py3-none-any.whl 56 | 57 | EEND: 58 | git clone https://github.com/hitachi-speech/EEND.git 59 | cd EEND; git checkout 9a0f211ce7e377eaea242490c3d7ec0f6adab8af 60 | # Add run_prepare_shared_eda.sh 61 | cd EEND; git show ddf4df4da55a493f550b6c45689b2d60a716c06d:egs/callhome/v1/run_prepare_shared_eda.sh > egs/callhome/v1/run_prepare_shared_eda.sh 62 | chmod u+x EEND/egs/callhome/v1/run_prepare_shared_eda.sh 63 | 64 | update: 65 | miniconda3/bin/conda env update -f environment.yml 66 | 67 | env.sh: EEND 68 | cp EEND/tools/env.sh.in env.sh 69 | echo "export LD_LIBRARY_PATH=$(CUDA_PATH)/lib64:$$LD_LIBRARY_PATH" >> env.sh 70 | -------------------------------------------------------------------------------- /tools/environment.yml: -------------------------------------------------------------------------------- 1 | name: eend 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.7.7 6 | - pip 7 | - pip: 8 | - h5py==2.10.0 9 | - librosa==0.7.2 10 | - numpy==1.18.4 11 | - protobuf==3.7.1 12 | - scipy==1.4.1 13 | - tqdm==4.46.0 14 | - SoundFile==0.10.3.post1 15 | - matplotlib==3.2.1 16 | - kaldiio==2.15.1 17 | - yamlargparse==1.31.1 18 | - cython==0.29.19 19 | - sacred==0.8.1 20 | - numba==0.48.0 21 | - tensorboard==2.2.2 22 | -------------------------------------------------------------------------------- /utils/best_score.sh:
-------------------------------------------------------------------------------- 1 | ../tools/EEND/utils/best_score.sh --------------------------------------------------------------------------------
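For reference, the lines below are a minimal usage sketch of the mini_librispeech recipe above, assuming the toolchain has first been built under tools/ (cd tools && make, or make KALDI=/path/to/existing/kaldi to symlink a prebuilt Kaldi instead of compiling one). The --stage, --average_start and --average_end flags are assumptions based on the variables declared at the top of run.sh and the usual Kaldi parse_options.sh convention of mapping a --name value flag onto an existing shell variable; treat them as illustrative rather than as documented recipe options.

# Minimal usage sketch (assumption: tools/ has been built with `make`, and
# parse_options.sh maps --<name> flags onto the variables defined in run.sh).
cd egs/mini_librispeech/v1

# Run every stage: data simulation, training, model averaging, then
# inference and scoring with oracle and estimated speaker counts.
./run.sh

# Resume from stage 2, e.g. redo only averaging, inference and scoring,
# averaging the checkpoints from epochs 8 through 10:
./run.sh --stage 2 --average_start 8 --average_end 10

# DERs for each threshold/median setting are written to
# exp/diarize/scoring/.../result_th${th}_med${med}_collar0.25, and
# best_score.sh reports the best-scoring combination for each dataset.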