├── requirements.txt ├── path.sh ├── myscripts ├── convert_model.py ├── LibriMix │ ├── create_utt_group.py │ ├── prepare_librimix.sh │ ├── prepare_librimix_full_len_kaldi.py │ ├── prepare_librimix_full_len.sh │ └── prepare_librimix_kaldi.py ├── dump_segments.py └── data_prep_kaldi.py ├── dict.ltr.txt ├── config ├── decode │ └── infer_viterbi.yaml └── base.yaml ├── eval_scripts ├── LS.sh ├── LS_full_len.sh └── LS_full_len_JSM.sh ├── train_scripts ├── LS_wavLM.sh ├── LS_wavLM_spk.sh ├── LS_full_len_wavLM_spk.sh └── LS_full_len_wavLM_spk_JSM.sh └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | editdistance 2 | tensorboardX 3 | -------------------------------------------------------------------------------- /path.sh: -------------------------------------------------------------------------------- 1 | export PATH="/export/c05/hzili1/tools/anaconda3/envs/multispk/bin:$PATH" 2 | -------------------------------------------------------------------------------- /myscripts/convert_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser(description='') 5 | parser.add_argument('input_path', type=str, help='Input path') 6 | parser.add_argument('w2v_path', type=str, help='Pretrain model path') 7 | parser.add_argument('output_path', type=str, help='Output path') 8 | args = parser.parse_args() 9 | 10 | def main(): 11 | model = torch.load(args.input_path) 12 | model['cfg']['model']['w2v_path'] = args.w2v_path 13 | torch.save(model, args.output_path) 14 | return 0 15 | 16 | if __name__ == '__main__': 17 | main() 18 | -------------------------------------------------------------------------------- /dict.ltr.txt: -------------------------------------------------------------------------------- 1 | | 18338755 2 | E 8179133 3 | T 6894145 4 | O 5770740 5 | A 5629002 6 | I 5599762 7 | N 4569738 8 | H 4259753 9 | S 4045500 10 | L 3210690 11 | R 3175869 12 | U 2687408 13 | D 2428865 14 | Y 2349836 15 | M 1893751 16 | W 1823450 17 | K 1488066 18 | G 1470103 19 | C 1389838 20 | F 1199731 21 | B 1058633 22 | ' 1033687 23 | P 980371 24 | V 665527 25 | J 229612 26 | X 87238 27 | Z 40639 28 | Q 24300 29 | 2 13482 30 | 3 2114 31 | 0 1918 32 | 1 1400 33 | 4 812 34 | 5 714 35 | 9 588 36 | 6 406 37 | 7 364 38 | ( 336 39 | ) 336 40 | 8 238 41 |  140 42 | - 70 43 | < 28 44 | > 28 45 | : 28 46 | [ 14 47 | ] 14 48 | \ 14 49 | , 14 50 | -------------------------------------------------------------------------------- /config/decode/infer_viterbi.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | defaults: 4 | - model: null 5 | 6 | hydra: 7 | run: 8 | dir: ${common_eval.results_path}/viterbi 9 | sweep: 10 | dir: ${common_eval.results_path} 11 | subdir: viterbi 12 | 13 | task: 14 | _name: hubert_pretraining 15 | single_target: true 16 | fine_tuning: true 17 | data: ??? 18 | normalize: ??? 19 | embed_dir: ??? 20 | JSD: false 21 | spkfield: 1 22 | nspks: 1 23 | 24 | decoding: 25 | type: viterbi 26 | unique_wer_file: true 27 | common_eval: 28 | results_path: ??? 29 | path: ??? 30 | post_process: letter 31 | dataset: 32 | max_tokens: 5600000 33 | gen_subset: ??? 
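# Note: every "???" above is a required override supplied on the command line at
# decode time by eval_scripts/*.sh. An illustrative invocation (placeholder paths,
# not actual values from this repo):
#   python $fairseq_path/examples/speech_recognition/new/infer.py \
#     --config-dir config/decode --config-name infer_viterbi \
#     task.data=<fairseq_data_dir> task.normalize=false task.embed_dir=<xvector_dir> \
#     common_eval.results_path=<results_dir> \
#     common_eval.path=<exp_dir>/checkpoints/checkpoint_last.pt \
#     dataset.gen_subset=test dataset.batch_size=1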
34 | -------------------------------------------------------------------------------- /eval_scripts/LS.sh: -------------------------------------------------------------------------------- 1 | source path.sh 2 | 3 | fairseq_path=/export/c05/hzili1/projects/ICASSP23/fairseq_d01 4 | export PYTHONPATH="${PYTHONPATH}:${fairseq_path}" 5 | 6 | egs_path=$fairseq_path/examples/ 7 | data_dir=`pwd`/datasets/fairseq/LibriMix 8 | exp_dir=`pwd`/experiment/LS/wavLM_spk 9 | embed_dir=/export/c05/hzili1/SSL_multispk/embeddings/LS_enroll_15s/xvec 10 | #embed_dir='None' 11 | ckpt=checkpoint_last 12 | 13 | for split in test; do 14 | results_path=$exp_dir/decode/${ckpt}/LibriMix_${split} 15 | mkdir -p $results_path 16 | 17 | CUDA_VISIBLE_DEVICES=0 python $egs_path/speech_recognition/new/infer.py \ 18 | --config-dir `pwd`/config/decode \ 19 | --config-name infer_viterbi \ 20 | task.data=$data_dir \ 21 | task.normalize=false \ 22 | task.embed_dir=$embed_dir \ 23 | common_eval.results_path=$results_path \ 24 | common_eval.path=$exp_dir/checkpoints/${ckpt}.pt \ 25 | decoding.results_path=$results_path \ 26 | dataset.gen_subset=$split \ 27 | dataset.batch_size=1 28 | done 29 | -------------------------------------------------------------------------------- /eval_scripts/LS_full_len.sh: -------------------------------------------------------------------------------- 1 | source path.sh 2 | 3 | fairseq_path=/export/c05/hzili1/projects/ICASSP23/fairseq_d01 4 | export PYTHONPATH="${PYTHONPATH}:${fairseq_path}" 5 | 6 | egs_path=$fairseq_path/examples/ 7 | data_dir=`pwd`/datasets/fairseq/LibriMix_full_len 8 | exp_dir=`pwd`/experiment/LS_full_len/wavLM_spk 9 | embed_dir=/export/c05/hzili1/SSL_multispk/embeddings/LS_enroll_15s/xvec 10 | JSD=false 11 | ckpt=checkpoint_last 12 | 13 | for split in test; do 14 | results_path=$exp_dir/decode/${ckpt}/LibriMix_${split} 15 | mkdir -p $results_path 16 | 17 | CUDA_VISIBLE_DEVICES=0 python $egs_path/speech_recognition/new/infer.py \ 18 | --config-dir `pwd`/config/decode \ 19 | --config-name infer_viterbi \ 20 | task.data=$data_dir \ 21 | task.normalize=false \ 22 | task.embed_dir=$embed_dir \ 23 | task.JSD=$JSD \ 24 | common_eval.results_path=$results_path \ 25 | common_eval.path=$exp_dir/checkpoints/${ckpt}.pt \ 26 | decoding.results_path=$results_path \ 27 | dataset.gen_subset=$split \ 28 | dataset.batch_size=1 29 | done 30 | -------------------------------------------------------------------------------- /eval_scripts/LS_full_len_JSM.sh: -------------------------------------------------------------------------------- 1 | source path.sh 2 | 3 | fairseq_path=/export/c05/hzili1/projects/ICASSP23/fairseq_d01 4 | export PYTHONPATH="${PYTHONPATH}:${fairseq_path}" 5 | 6 | egs_path=$fairseq_path/examples/ 7 | data_dir=`pwd`/datasets/fairseq/LibriMix_full_len 8 | exp_dir=`pwd`/experiment/LS_full_len/wavLM_spk_JSM 9 | embed_dir=/export/c05/hzili1/SSL_multispk/embeddings/LS_enroll_15s/xvec 10 | JSD=true 11 | nspks=2 12 | ckpt=checkpoint_last 13 | 14 | for split in test_utt_group; do 15 | results_path=$exp_dir/decode/${ckpt}/LibriMix_${split} 16 | mkdir -p $results_path 17 | 18 | CUDA_VISIBLE_DEVICES=0 python $egs_path/speech_recognition/new/infer.py \ 19 | --config-dir `pwd`/config/decode \ 20 | --config-name infer_viterbi \ 21 | task.data=$data_dir \ 22 | task.normalize=false \ 23 | task.embed_dir=$embed_dir \ 24 | task.JSD=$JSD \ 25 | task.nspks=$nspks \ 26 | common_eval.results_path=$results_path \ 27 | common_eval.path=$exp_dir/checkpoints/${ckpt}.pt \ 28 | 
decoding.results_path=$results_path \ 29 | dataset.gen_subset=$split \ 30 | dataset.batch_size=1 31 | done 32 | -------------------------------------------------------------------------------- /train_scripts/LS_wavLM.sh: -------------------------------------------------------------------------------- 1 | source path.sh 2 | echo `eval hostname` 3 | 4 | train_subset=train-100 5 | valid_subset=dev 6 | embed_type=xvec 7 | embed_dir=/export/c05/hzili1/SSL_multispk/embeddings/LS_enroll_15s/$embed_type 8 | embed_dim=512 9 | spk_method="None" 10 | spk_layers='0' 11 | cat_layers='0' 12 | max_update=50000 13 | lr=3e-5 14 | apply_mask=false 15 | keep_spk_layers=true 16 | ln_after_adapt=false 17 | cln=false 18 | cln_bias=false 19 | port=$((15388 + 6)) 20 | fp16=true 21 | exp_dir=experiment/LS/wavLM 22 | 23 | CUDA_VISIBLE_DEVICES=0,1 fairseq-hydra-train \ 24 | task.data=`pwd`/datasets/fairseq/LibriMix \ 25 | task.label_dir=`pwd`/datasets/fairseq/LibriMix \ 26 | task.embed_dir=$embed_dir \ 27 | dataset.max_tokens=1200000 \ 28 | dataset.train_subset=${train_subset} \ 29 | dataset.valid_subset=${valid_subset} \ 30 | distributed_training.distributed_world_size=2 \ 31 | distributed_training.distributed_init_method='tcp://localhost:'${port} \ 32 | optimization.lr=[$lr] \ 33 | optimization.update_freq=[8] \ 34 | optimization.max_update=${max_update} \ 35 | model.apply_mask=$apply_mask \ 36 | model.spk_aware=false \ 37 | model.spk_embed=${embed_dim} \ 38 | model.spk_method=${spk_method} \ 39 | model.spk_layers=${spk_layers} \ 40 | model.cat_layers=${cat_layers} \ 41 | model.keep_spk_layers=${keep_spk_layers} \ 42 | model.init_linear=true \ 43 | model.ln_after_adapt=${ln_after_adapt} \ 44 | model.cln=${cln} \ 45 | model.cln_bias=${cln_bias} \ 46 | model.w2v_path=`pwd`/downloads/WavLM-Base+.pt \ 47 | hydra.run.dir=${exp_dir} \ 48 | common.fp16=${fp16} \ 49 | --config-dir `pwd`/config \ 50 | --config-name base 51 | -------------------------------------------------------------------------------- /train_scripts/LS_wavLM_spk.sh: -------------------------------------------------------------------------------- 1 | source path.sh 2 | echo `eval hostname` 3 | 4 | train_subset=train-100 5 | valid_subset=dev 6 | embed_type=xvec 7 | embed_dir=/export/c05/hzili1/SSL_multispk/embeddings/LS_enroll_15s/$embed_type 8 | embed_dim=512 9 | spk_method="None" 10 | spk_layers='0' 11 | cat_layers='0' 12 | max_update=50000 13 | lr=3e-5 14 | apply_mask=false 15 | keep_spk_layers=true 16 | ln_after_adapt=false 17 | cln=true 18 | cln_bias=false 19 | port=$((15388 + 12)) 20 | fp16=true 21 | exp_dir=experiment/LS/wavLM_spk 22 | 23 | CUDA_VISIBLE_DEVICES=0,1 fairseq-hydra-train \ 24 | task.data=`pwd`/datasets/fairseq/LibriMix \ 25 | task.label_dir=`pwd`/datasets/fairseq/LibriMix \ 26 | task.embed_dir=$embed_dir \ 27 | dataset.max_tokens=1200000 \ 28 | dataset.train_subset=${train_subset} \ 29 | dataset.valid_subset=${valid_subset} \ 30 | distributed_training.distributed_world_size=2 \ 31 | distributed_training.distributed_init_method='tcp://localhost:'${port} \ 32 | optimization.lr=[$lr] \ 33 | optimization.update_freq=[8] \ 34 | optimization.max_update=${max_update} \ 35 | model.apply_mask=$apply_mask \ 36 | model.spk_aware=true \ 37 | model.spk_embed=${embed_dim} \ 38 | model.spk_method=${spk_method} \ 39 | model.spk_layers=${spk_layers} \ 40 | model.cat_layers=${cat_layers} \ 41 | model.keep_spk_layers=${keep_spk_layers} \ 42 | model.init_linear=true \ 43 | model.ln_after_adapt=${ln_after_adapt} \ 44 | model.cln=${cln} \ 45 | 
model.cln_bias=${cln_bias} \ 46 | model.w2v_path=`pwd`/downloads/WavLM-Base+.pt \ 47 | hydra.run.dir=${exp_dir} \ 48 | common.fp16=${fp16} \ 49 | --config-dir `pwd`/config \ 50 | --config-name base 51 | -------------------------------------------------------------------------------- /train_scripts/LS_full_len_wavLM_spk.sh: -------------------------------------------------------------------------------- 1 | source path.sh 2 | echo `eval hostname` 3 | 4 | train_subset=train-100 5 | valid_subset=dev 6 | embed_type=xvec 7 | embed_dir=/export/c05/hzili1/SSL_multispk/embeddings/LS_enroll_15s/$embed_type 8 | embed_dim=512 9 | spk_method="None" 10 | spk_layers='0' 11 | cat_layers='0' 12 | max_update=50000 13 | lr=3e-5 14 | apply_mask=false 15 | keep_spk_layers=true 16 | ln_after_adapt=false 17 | cln=true 18 | cln_bias=false 19 | port=$((15388 + 12)) 20 | fp16=true 21 | exp_dir=experiment/LS_full_len/wavLM_spk 22 | 23 | CUDA_VISIBLE_DEVICES=0,1 fairseq-hydra-train \ 24 | task.data=`pwd`/datasets/fairseq/LibriMix_full_len \ 25 | task.label_dir=`pwd`/datasets/fairseq/LibriMix_full_len \ 26 | task.embed_dir=$embed_dir \ 27 | dataset.max_tokens=1200000 \ 28 | dataset.train_subset=${train_subset} \ 29 | dataset.valid_subset=${valid_subset} \ 30 | distributed_training.distributed_world_size=2 \ 31 | distributed_training.distributed_init_method='tcp://localhost:'${port} \ 32 | optimization.lr=[$lr] \ 33 | optimization.update_freq=[8] \ 34 | optimization.max_update=${max_update} \ 35 | model.apply_mask=$apply_mask \ 36 | model.spk_aware=true \ 37 | model.spk_embed=${embed_dim} \ 38 | model.spk_method=${spk_method} \ 39 | model.spk_layers=${spk_layers} \ 40 | model.cat_layers=${cat_layers} \ 41 | model.keep_spk_layers=${keep_spk_layers} \ 42 | model.init_linear=true \ 43 | model.ln_after_adapt=${ln_after_adapt} \ 44 | model.cln=${cln} \ 45 | model.cln_bias=${cln_bias} \ 46 | model.w2v_path=`pwd`/downloads/WavLM-Base+.pt \ 47 | hydra.run.dir=${exp_dir} \ 48 | common.fp16=${fp16} \ 49 | --config-dir `pwd`/config \ 50 | --config-name base 51 | -------------------------------------------------------------------------------- /train_scripts/LS_full_len_wavLM_spk_JSM.sh: -------------------------------------------------------------------------------- 1 | source path.sh 2 | echo `eval hostname` 3 | 4 | train_subset=train-100_utt_group 5 | valid_subset=dev_utt_group 6 | embed_type=xvec 7 | embed_dir=/export/c05/hzili1/SSL_multispk/embeddings/LS_enroll_15s/$embed_type 8 | embed_dim=512 9 | spk_method="None" 10 | spk_layers='0' 11 | cat_layers='0' 12 | max_update=50000 13 | lr=3e-5 14 | apply_mask=false 15 | keep_spk_layers=true 16 | ln_after_adapt=false 17 | cln=true 18 | cln_bias=false 19 | port=$((15388 + 11)) 20 | JSD=true 21 | nspks=2 22 | JSD_layers=1 23 | encoder_layers_finetune=12 24 | fp16=true 25 | exp_dir=experiment/LS_full_len/wavLM_spk_JSM 26 | 27 | CUDA_VISIBLE_DEVICES=0,1 fairseq-hydra-train \ 28 | task.data=`pwd`/datasets/fairseq/LibriMix_full_len \ 29 | task.label_dir=`pwd`/datasets/fairseq/LibriMix_full_len \ 30 | task.embed_dir=$embed_dir \ 31 | task.JSD=${JSD} \ 32 | task.nspks=${nspks} \ 33 | dataset.max_tokens=1200000 \ 34 | dataset.train_subset=${train_subset} \ 35 | dataset.valid_subset=${valid_subset} \ 36 | distributed_training.distributed_world_size=2 \ 37 | distributed_training.distributed_init_method='tcp://localhost:'${port} \ 38 | optimization.lr=[$lr] \ 39 | optimization.update_freq=[8] \ 40 | optimization.max_update=${max_update} \ 41 | model.apply_mask=$apply_mask \ 42 | 
model.spk_aware=true \ 43 | model.spk_embed=${embed_dim} \ 44 | model.spk_method=${spk_method} \ 45 | model.spk_layers=${spk_layers} \ 46 | model.cat_layers=${cat_layers} \ 47 | model.keep_spk_layers=${keep_spk_layers} \ 48 | model.init_linear=true \ 49 | model.ln_after_adapt=${ln_after_adapt} \ 50 | model.cln=${cln} \ 51 | model.cln_bias=${cln_bias} \ 52 | model.nspks=${nspks} \ 53 | model.JSD_layers=${JSD_layers} \ 54 | model.encoder_layers_finetune=${encoder_layers_finetune} \ 55 | model.w2v_path=`pwd`/downloads/WavLM-Base+.pt \ 56 | hydra.run.dir=${exp_dir} \ 57 | common.fp16=${fp16} \ 58 | --config-dir `pwd`/config \ 59 | --config-name base 60 | -------------------------------------------------------------------------------- /myscripts/LibriMix/create_utt_group.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser(description='Create utterance group for JSM') 5 | parser.add_argument('data_dir', type=str, help='Fairseq data directory') 6 | parser.add_argument('split', type=str, help='Data split') 7 | args = parser.parse_args() 8 | 9 | def main(): 10 | with open("{}/{}.tsv".format(args.data_dir, args.split), 'r') as fh: 11 | content_tsv = fh.readlines() 12 | with open("{}/{}.ltr".format(args.data_dir, args.split), 'r') as fh: 13 | content_ltr = fh.readlines() 14 | assert len(content_tsv) == len(content_ltr) + 1 15 | 16 | common_dir = content_tsv[0] 17 | content_tsv = content_tsv[1:] 18 | 19 | utt2seginfo = {} 20 | for i in range(len(content_ltr)): 21 | text = content_ltr[i].strip('\n') 22 | tsv_line = content_tsv[i].strip('\n') 23 | audio_path, nsamples = tsv_line.split()[0], int(tsv_line.split()[1]) 24 | segname = (audio_path.split('/')[-1]).split('.')[0] 25 | spk = segname.split('-')[0] 26 | uttname = '-'.join(segname.split('-')[1:]) 27 | seginfo = [audio_path, nsamples, text, spk] 28 | if uttname not in utt2seginfo: 29 | utt2seginfo[uttname] = [] 30 | utt2seginfo[uttname].append(seginfo) 31 | 32 | tsv_file = open("{}/{}_utt_group.tsv".format(args.data_dir, args.split), 'w') 33 | ltr_file = open("{}/{}_utt_group.ltr".format(args.data_dir, args.split), 'w') 34 | tsv_file.write(common_dir) 35 | uttlist = list(utt2seginfo.keys()) 36 | uttlist.sort() 37 | for utt in uttlist: 38 | seginfo = utt2seginfo[utt] 39 | if len(seginfo) != 2: 40 | continue 41 | tsv_file.write("{}\t{}\n".format(seginfo[0][0], seginfo[0][1])) 42 | ltr = '#'.join(["({}) {}".format(seg[3], seg[2]) for seg in seginfo]) 43 | ltr_file.write(ltr + '\n') 44 | tsv_file.close() 45 | ltr_file.close() 46 | return 0 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /myscripts/LibriMix/prepare_librimix.sh: -------------------------------------------------------------------------------- 1 | source path.sh 2 | 3 | # Force alignment from https://github.com/s3prl/LibriMix/tree/master/metadata/LibriSpeech 4 | rttm_dir=/export/b06/hzili1/datasets/LibriSpeech_rttm/ 5 | 6 | # Follow https://github.com/kaldi-asr/kaldi to prepare text 7 | # transcripts for LibriSpeech 8 | kaldi_LS_dir=/export/c12/hzili1/tools/kaldi/egs/librispeech/s5/data/ 9 | 10 | # Follow https://github.com/JorisCos/LibriMix to prepare Libri2Mix 11 | # Set librimix_dir to Libri2Mix/wav16k/max 12 | librimix_dir=/export/c12/hzili1/dataset/LibriMix/Libri2Mix/wav16k/max/ 13 | 14 | # Output directories 15 | kaldi_dir=datasets/kaldi/LibriMix 16 | dump_dir=datasets/dump/LibriMix 17 | 
fairseq_dir=datasets/fairseq/LibriMix 18 | 19 | stage=1 20 | 21 | # Prepare KALDI directory 22 | if [ $stage -le 1 ]; then 23 | python myscripts/LibriMix/prepare_librimix_kaldi.py ${rttm_dir}/train_clean_100.rttm ${kaldi_LS_dir}/train_clean_100/text ${librimix_dir}/train-100/mix_clean ${kaldi_dir}/train-100 24 | utils/fix_data_dir.sh ${kaldi_dir}/train-100 25 | 26 | for split in dev test; do 27 | python myscripts/LibriMix/prepare_librimix_kaldi.py ${rttm_dir}/${split}_clean.rttm ${kaldi_LS_dir}/${split}_clean/text ${librimix_dir}/${split}/mix_clean ${kaldi_dir}/${split} 28 | utils/fix_data_dir.sh ${kaldi_dir}/${split} 29 | done 30 | fi 31 | 32 | # Dump segment files 33 | if [ $stage -le 2 ]; then 34 | for split in train-100 dev test; do 35 | python myscripts/dump_segments.py ${kaldi_dir}/${split}/wav.scp $dump_dir/${split} --segments ${kaldi_dir}/${split}/segments 36 | awk -F' ' '{print $1,$1}' $dump_dir/${split}/wav.scp > $dump_dir/${split}/utt2spk 37 | cp ${kaldi_dir}/${split}/text $dump_dir/${split}/. 38 | utils/fix_data_dir.sh $dump_dir/${split} 39 | done 40 | fi 41 | 42 | # Convert KALDI directory to fairseq format 43 | if [ $stage -le 3 ]; then 44 | for split in train-100 dev test; do 45 | python myscripts/data_prep_kaldi.py ${dump_dir}/${split} $fairseq_dir $split $dump_dir/${split}/data 46 | done 47 | cp dict.ltr.txt $fairseq_dir/. 48 | fi 49 | -------------------------------------------------------------------------------- /myscripts/LibriMix/prepare_librimix_full_len_kaldi.py: -------------------------------------------------------------------------------- 1 | # Data preparation for the LibriMix dataset 2 | 3 | import os 4 | import numpy as np 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(description='Prepare the LibriMix dataset') 8 | parser.add_argument('text_file', type=str, help='Text file from Kaldi directory') 9 | parser.add_argument('data_dir', type=str, help='LibriMix data directory') 10 | parser.add_argument('output_dir', type=str, help='Output directory') 11 | args = parser.parse_args() 12 | 13 | def get_text(text_file): 14 | seg2text = {} 15 | with open(text_file, 'r') as fh: 16 | content = fh.readlines() 17 | for line in content: 18 | line = line.strip('\n') 19 | line_split = line.split(None, 1) 20 | seg2text[line_split[0]] = line_split[1] 21 | return seg2text 22 | 23 | def main(): 24 | seg2text = get_text(args.text_file) 25 | 26 | if not os.path.exists(args.output_dir): 27 | os.makedirs(args.output_dir) 28 | 29 | mix_files = os.listdir(args.data_dir) 30 | 31 | wav_scp_file = open("{}/wav.scp".format(args.output_dir), 'w') 32 | utt2spk_file = open("{}/utt2spk".format(args.output_dir), 'w') 33 | text_file = open("{}/text".format(args.output_dir), 'w') 34 | for mix_f in mix_files: 35 | utt = mix_f.split('.')[0] 36 | seg1, seg2 = utt.split('_')[0], utt.split('_')[1] 37 | spk1, spk2 = seg1.split('-')[0], seg2.split('-')[0] 38 | utt1, utt2 = "{}-{}".format(spk1, utt), "{}-{}".format(spk2, utt) 39 | wav_scp_file.write("{} {}\n".format(utt1, "{}/{}".format(args.data_dir, mix_f))) 40 | wav_scp_file.write("{} {}\n".format(utt2, "{}/{}".format(args.data_dir, mix_f))) 41 | utt2spk_file.write("{} {}\n".format(utt1, spk1)) 42 | utt2spk_file.write("{} {}\n".format(utt2, spk2)) 43 | text_file.write("{} {}\n".format(utt1, seg2text[seg1])) 44 | text_file.write("{} {}\n".format(utt2, seg2text[seg2])) 45 | wav_scp_file.close() 46 | utt2spk_file.close() 47 | text_file.close() 48 | return 0 49 | 50 | if __name__ == '__main__': 51 | main() 52 | 
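# Illustrative example of the naming convention above (hypothetical IDs): a mixture
# file "1089-134686-0000_121-121726-0000.wav" yields two Kaldi utterances, one per
# speaker, both pointing at the same mixture audio in wav.scp:
#   1089-1089-134686-0000_121-121726-0000  (speaker 1089, transcript of 1089-134686-0000)
#   121-1089-134686-0000_121-121726-0000   (speaker 121,  transcript of 121-121726-0000)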
-------------------------------------------------------------------------------- /myscripts/LibriMix/prepare_librimix_full_len.sh: -------------------------------------------------------------------------------- 1 | source path.sh 2 | 3 | # Force alignment from https://github.com/s3prl/LibriMix/tree/master/metadata/LibriSpeech 4 | rttm_dir=/export/b06/hzili1/datasets/LibriSpeech_rttm/ 5 | 6 | # Follow https://github.com/kaldi-asr/kaldi to prepare text 7 | # transcripts for LibriSpeech 8 | kaldi_LS_dir=/export/c12/hzili1/tools/kaldi/egs/librispeech/s5/data/ 9 | 10 | # Follow https://github.com/JorisCos/LibriMix to prepare Libri2Mix 11 | # Set librimix_dir to Libri2Mix/wav16k/max 12 | librimix_dir=/export/c12/hzili1/dataset/LibriMix/Libri2Mix/wav16k/max/ 13 | 14 | # Output directories 15 | kaldi_dir=datasets/kaldi/LibriMix_full_len 16 | dump_dir=datasets/dump/LibriMix_full_len 17 | fairseq_dir=datasets/fairseq/LibriMix_full_len 18 | 19 | stage=1 20 | 21 | # Prepare KALDI directory 22 | if [ $stage -le 1 ]; then 23 | python myscripts/LibriMix/prepare_librimix_full_len_kaldi.py ${kaldi_LS_dir}/train_clean_100/text ${librimix_dir}/train-100/mix_clean ${kaldi_dir}/train-100 24 | utils/fix_data_dir.sh ${kaldi_dir}/train-100 25 | 26 | for split in dev test; do 27 | python myscripts/LibriMix/prepare_librimix_full_len_kaldi.py ${kaldi_LS_dir}/${split}_clean/text ${librimix_dir}/${split}/mix_clean ${kaldi_dir}/${split} 28 | utils/fix_data_dir.sh ${kaldi_dir}/${split} 29 | done 30 | fi 31 | 32 | # Dump segment files 33 | if [ $stage -le 2 ]; then 34 | for split in train-100 dev test; do 35 | python myscripts/dump_segments.py ${kaldi_dir}/${split}/wav.scp $dump_dir/${split} 36 | awk -F' ' '{print $1,$1}' $dump_dir/${split}/wav.scp > $dump_dir/${split}/utt2spk 37 | cp ${kaldi_dir}/${split}/text $dump_dir/${split}/. 38 | utils/fix_data_dir.sh $dump_dir/${split} 39 | done 40 | fi 41 | 42 | # Convert KALDI directory to fairseq format 43 | if [ $stage -le 3 ]; then 44 | for split in train-100 dev test; do 45 | python myscripts/data_prep_kaldi.py ${dump_dir}/${split} $fairseq_dir $split $dump_dir/${split}/data 46 | done 47 | cp dict.ltr.txt $fairseq_dir/. 
48 | fi 49 | 50 | # Create utterance group data for joint speaker modeling (JSM) 51 | if [ $stage -le 4 ]; then 52 | for split in train-100 dev test; do 53 | python myscripts/LibriMix/create_utt_group.py $fairseq_dir $split 54 | done 55 | fi 56 | -------------------------------------------------------------------------------- /myscripts/dump_segments.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import subprocess 4 | 5 | parser = argparse.ArgumentParser(description='Dump short segments') 6 | parser.add_argument('wav_scp', type=str, help='wav.scp file') 7 | parser.add_argument('output_dir', type=str, help='output directory') 8 | parser.add_argument('--segments', type=str, default=None, help='segments file') 9 | args = parser.parse_args() 10 | 11 | def load_wav_scp(fname): 12 | utt2wav = {} 13 | with open(fname, 'r') as fh: 14 | content = fh.readlines() 15 | for line in content: 16 | line = line.strip('\n') 17 | line_split = line.split() 18 | assert len(line_split) == 2 19 | utt2wav[line_split[0]] = line_split[1] 20 | return utt2wav 21 | 22 | def main(): 23 | data_dir = args.output_dir + '/data' 24 | if not os.path.exists(data_dir): 25 | os.makedirs(data_dir) 26 | 27 | utt2wav = load_wav_scp(args.wav_scp) 28 | 29 | if args.segments is not None: 30 | with open(args.segments, 'r') as fh: 31 | content = fh.readlines() 32 | 33 | wav_scp_file = open("{}/wav.scp".format(args.output_dir), 'w') 34 | for i in range(len(content)): 35 | line = content[i] 36 | line = line.strip('\n') 37 | line_split = line.split() 38 | seg, utt, start_t, end_t = line_split[0], line_split[1], round(float(line_split[2]), 2), round(float(line_split[3]), 2) 39 | output_audio = "{}/{}.wav".format(data_dir, seg) 40 | cmd = "sox {} {} trim {:.2f} {:.2f}".format(utt2wav[utt], output_audio, start_t, end_t-start_t) 41 | status, output = subprocess.getstatusoutput(cmd) 42 | assert status == 0 43 | wav_scp_file.write("{} {}\n".format(seg, output_audio)) 44 | print("Finish {}/{}".format(i + 1, len(content))) 45 | wav_scp_file.close() 46 | else: 47 | uttlist = list(utt2wav.keys()) 48 | uttlist.sort() 49 | 50 | wav_scp_file = open("{}/wav.scp".format(args.output_dir), 'w') 51 | for i in range(len(uttlist)): 52 | utt = uttlist[i] 53 | output_audio = "{}/{}.wav".format(data_dir, utt) 54 | cmd = "ln -s {} {}".format(utt2wav[utt], output_audio) 55 | status, output = subprocess.getstatusoutput(cmd) 56 | assert status == 0 57 | wav_scp_file.write("{} {}\n".format(utt, output_audio)) 58 | print("Finish {}/{}".format(i + 1, len(uttlist))) 59 | wav_scp_file.close() 60 | return 0 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /config/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: false 5 | log_format: json 6 | log_interval: 200 7 | tensorboard_logdir: tblog 8 | seed: 1337 9 | 10 | checkpoint: 11 | no_epoch_checkpoints: true 12 | best_checkpoint_metric: wer 13 | save_interval_updates: 2500 14 | 15 | task: 16 | _name: hubert_pretraining 17 | data: ??? 18 | fine_tuning: true 19 | label_dir: ??? 
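  # label_dir holds the {split}.ltr letter transcripts; the training scripts
  # point it to the same directory as task.data (see train_scripts/*.sh)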
20 | normalize: false # must be consistent with pre-training 21 | labels: ["ltr"] 22 | single_target: true 23 | embed_dir: None 24 | JSD: false 25 | spkfield: 1 26 | nspks: 1 27 | 28 | dataset: 29 | num_workers: 4 30 | max_tokens: 3200000 31 | skip_invalid_size_inputs_valid_test: true 32 | validate_interval: 100000 33 | validate_interval_updates: 2500 34 | 35 | distributed_training: 36 | ddp_backend: legacy_ddp 37 | distributed_world_size: 2 38 | 39 | criterion: 40 | _name: ctc 41 | zero_infinity: true 42 | 43 | optimization: 44 | max_update: 80000 45 | lr: [0.00003] 46 | sentence_avg: true 47 | update_freq: [4] 48 | 49 | optimizer: 50 | _name: adam 51 | adam_betas: (0.9,0.98) 52 | adam_eps: 1e-08 53 | 54 | lr_scheduler: 55 | _name: tri_stage 56 | phase_ratio: [0.1, 0.4, 0.5] 57 | final_lr_scale: 0.05 58 | 59 | model: 60 | _name: hubert_ctc 61 | w2v_path: ??? 62 | apply_mask: true 63 | mask_selection: static 64 | mask_length: 10 65 | mask_other: 0 66 | mask_prob: 0.65 67 | mask_channel_selection: static 68 | mask_channel_other: 0 69 | mask_channel_prob: 0.5 70 | mask_channel_length: 64 71 | layerdrop: 0.1 72 | dropout: 0.0 73 | activation_dropout: 0.1 74 | attention_dropout: 0.0 75 | feature_grad_mult: 0.0 76 | freeze_finetune_updates: 0 77 | spk_aware: false 78 | spk_embed: 256 79 | spk_method: cat 80 | spk_layers: "0" 81 | cat_layers: "0" 82 | keep_spk_layers: false 83 | init_linear: false 84 | ln_after_adapt: false 85 | cln: false 86 | cln_bias: false 87 | cattn: false 88 | add_adapter: false 89 | adapter_size: 256 90 | adapter_act: swish 91 | adapter_init_range: 1e-3 92 | adapter_method: standard 93 | nspks: 1 94 | JSD_layers: 0 95 | encoder_layers_finetune: 12 96 | 97 | hydra: 98 | job: 99 | config: 100 | override_dirname: 101 | kv_sep: '-' 102 | item_sep: '__' 103 | exclude_keys: 104 | - run 105 | - task.data 106 | - task.label_dir 107 | - model.w2v_path 108 | - dataset.train_subset 109 | - dataset.valid_subset 110 | - criterion.wer_kenlm_model 111 | - criterion.wer_lexicon 112 | run: 113 | dir: ??? 114 | sweep: 115 | dir: ??? 
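    # Required overrides: the training scripts supply hydra.run.dir=<exp_dir>;
    # sweep.dir is only used for hydra multirun sweeps.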
116 | subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} 117 | -------------------------------------------------------------------------------- /myscripts/data_prep_kaldi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # This script creates *.ltr, *.wrd and *.tsv 4 | # from a kaldi-style directory 5 | 6 | import os 7 | import sys 8 | import argparse 9 | import soundfile as sf 10 | 11 | parser = argparse.ArgumentParser(description='Data preparation from a kaldi-style directory') 12 | parser.add_argument('kaldi_dir', type=str, help='Kaldi style directory') 13 | parser.add_argument('output_dir', type=str, help='Output directory') 14 | parser.add_argument('output_name', type=str, help='Output name') 15 | parser.add_argument('data_dir', type=str, help='Common data directory for audio files') 16 | parser.add_argument('--min_sample', type=int, default=0) 17 | parser.add_argument('--max_sample', type=int, default=10000000000) 18 | args = parser.parse_args() 19 | 20 | def process_file(fname): 21 | utt2info = {} 22 | with open(fname, 'r') as fh: 23 | content = fh.readlines() 24 | for line in content: 25 | line = line.strip('\n') 26 | line_split = line.split(None, 1) 27 | if len(line_split) == 2: 28 | utt2info[line_split[0]] = line_split[1] 29 | else: 30 | utt2info[line_split[0]] = "" 31 | return utt2info 32 | 33 | def main(): 34 | for f in ["wav.scp", "text"]: 35 | assert os.path.exists("{}/{}".format(args.kaldi_dir, f)) 36 | utt2path, utt2text = process_file("{}/wav.scp".format(args.kaldi_dir)), process_file("{}/text".format(args.kaldi_dir)) 37 | assert len(utt2path) == len(utt2text) 38 | 39 | if not os.path.exists(args.output_dir): 40 | os.makedirs(args.output_dir) 41 | 42 | uttlist = list(utt2path.keys()) 43 | uttlist.sort() 44 | 45 | tsv_file = open("{}/{}.tsv".format(args.output_dir, args.output_name), 'w') 46 | ltr_file = open("{}/{}.ltr".format(args.output_dir, args.output_name), 'w') 47 | wrd_file = open("{}/{}.wrd".format(args.output_dir, args.output_name), 'w') 48 | tsv_file.write("{}\n".format(args.data_dir)) 49 | cnt, cnt_success = 0, 0 50 | for utt in uttlist: 51 | cnt += 1 52 | file_path, text = utt2path[utt], utt2text[utt] 53 | num_samples = sf.info(file_path).frames 54 | if int(num_samples) < args.min_sample or int(num_samples) > args.max_sample: 55 | continue 56 | tsv_file.write("{}\t{}\n".format(os.path.relpath(file_path, args.data_dir), num_samples)) 57 | wrd_out = " ".join(text.split()) 58 | wrd_file.write("{}\n".format(wrd_out)) 59 | ltr_out = " ".join(list(wrd_out.replace(" ", "|"))) + " |" 60 | ltr_file.write("{}\n".format(ltr_out)) 61 | cnt_success += 1 62 | tsv_file.close() 63 | ltr_file.close() 64 | wrd_file.close() 65 | print("Keep {} of {} utterances".format(cnt_success, cnt)) 66 | return 0 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /myscripts/LibriMix/prepare_librimix_kaldi.py: -------------------------------------------------------------------------------- 1 | # Data preparation for the LibriMix dataset 2 | 3 | import os 4 | import numpy as np 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(description='Prepare the LibriMix dataset') 8 | parser.add_argument('rttm_file', type=str, help='RTTM file from force alignment (from s3prl Jiatong)') 9 | parser.add_argument('text_file', type=str, help='Text file from Kaldi directory') 10 | parser.add_argument('data_dir', type=str, help='LibriMix 
data directory') 11 | parser.add_argument('output_dir', type=str, help='Output directory') 12 | args = parser.parse_args() 13 | 14 | def get_start_end_time(rttm_file): 15 | seg2time = {} 16 | with open(rttm_file, 'r') as fh: 17 | content = fh.readlines() 18 | for line in content: 19 | line = line.strip('\n') 20 | line_split = line.split() 21 | segname, start_t, dur, spk = line_split[1], float(line_split[3]), float(line_split[4]), line_split[7] 22 | end_t = start_t + dur 23 | segname = spk + '-' + segname 24 | if segname not in seg2time: 25 | seg2time[segname] = [] 26 | seg2time[segname].append([start_t, end_t]) 27 | for seg in seg2time.keys(): 28 | align = seg2time[seg] 29 | align = np.array(align) 30 | start_t, end_t = np.min(align), np.max(align) 31 | seg2time[seg] = [start_t, end_t] 32 | return seg2time 33 | 34 | def get_text(text_file): 35 | seg2text = {} 36 | with open(text_file, 'r') as fh: 37 | content = fh.readlines() 38 | for line in content: 39 | line = line.strip('\n') 40 | line_split = line.split(None, 1) 41 | seg2text[line_split[0]] = line_split[1] 42 | return seg2text 43 | 44 | def main(): 45 | seg2time = get_start_end_time(args.rttm_file) 46 | seg2text = get_text(args.text_file) 47 | assert len(seg2time) == len(seg2text) 48 | 49 | if not os.path.exists(args.output_dir): 50 | os.makedirs(args.output_dir) 51 | 52 | mix_files = os.listdir(args.data_dir) 53 | 54 | wav_scp_file = open("{}/wav.scp".format(args.output_dir), 'w') 55 | segments_file = open("{}/segments".format(args.output_dir), 'w') 56 | utt2spk_file = open("{}/utt2spk".format(args.output_dir), 'w') 57 | text_file = open("{}/text".format(args.output_dir), 'w') 58 | for mix_f in mix_files: 59 | utt = mix_f.split('.')[0] 60 | seg1, seg2 = utt.split('_')[0], utt.split('_')[1] 61 | spk1, spk2 = seg1.split('-')[0], seg2.split('-')[0] 62 | utt1, utt2 = "{}-{}".format(spk1, utt), "{}-{}".format(spk2, utt) 63 | wav_scp_file.write("{} {}\n".format(utt, "{}/{}".format(args.data_dir, mix_f))) 64 | utt2spk_file.write("{} {}\n".format(utt1, spk1)) 65 | utt2spk_file.write("{} {}\n".format(utt2, spk2)) 66 | segments_file.write("{} {} {:.2f} {:.2f}\n".format(utt1, utt, seg2time[seg1][0], seg2time[seg1][1])) 67 | segments_file.write("{} {} {:.2f} {:.2f}\n".format(utt2, utt, seg2time[seg2][0], seg2time[seg2][1])) 68 | text_file.write("{} {}\n".format(utt1, seg2text[seg1])) 69 | text_file.write("{} {}\n".format(utt2, seg2text[seg2])) 70 | wav_scp_file.close() 71 | segments_file.close() 72 | utt2spk_file.close() 73 | text_file.close() 74 | return 0 75 | 76 | if __name__ == '__main__': 77 | main() 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | We provide the code and models for our ICASSP paper [Adapting self-supervised models to multi-talker speech recognition using speaker embeddings](https://arxiv.org/abs/2211.00482). 

# Requirements and Installation
* Python version == 3.7
* torch==1.10.0, torchaudio==0.10.0

``` bash
# Install fairseq
git clone -b multispk --single-branch https://github.com/HuangZiliAndy/fairseq.git
cd fairseq
pip install --editable ./

# Install apex
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
  --global-option="--deprecated_fused_adam" --global-option="--xentropy" ./

# Install the requirements of this repository (run from the repository root)
pip install -r requirements.txt
```

# Data preparation

``` bash
# Prepare LibriMix (https://github.com/JorisCos/LibriMix)
# Only the 16 kHz "max" condition is needed for our experiments;
# train-360 is not required.

# Install Kaldi (https://github.com/kaldi-asr/kaldi)

# Link Kaldi's utils directory into the current directory
# (replace <kaldi_root> with the path to your Kaldi installation)
ln -s <kaldi_root>/egs/wsj/s5/utils .

# Run the following two scripts to prepare fairseq-style
# training data for LibriMix.

# The difference between the two scripts is that the former uses
# forced-alignment results to create tight segment boundaries
# (utterance-based evaluation), while the latter keeps the
# full-length mixtures (utterance group-based evaluation).
./myscripts/LibriMix/prepare_librimix.sh
./myscripts/LibriMix/prepare_librimix_full_len.sh
```

Next, extract speaker embeddings for the enrollment utterances. We use 15 seconds of speech from LibriVox (not included in LibriSpeech), provided as [LS 15 seconds enrollment](https://drive.google.com/file/d/1AmZQnTUCPW3VHZeYpBzH4fxExi_JBkv3/view?usp=share_link), as the enrollment data. We also provide extracted [x-vector](https://drive.google.com/file/d/1kKVtXTtjwS0V4ZsYzj1863f9AXgLqMvP/view?usp=share_link) embeddings.
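For reference, the two preparation scripts above end by running `myscripts/data_prep_kaldi.py` (and, for the full-length setup, `myscripts/LibriMix/create_utt_group.py`), which write fairseq-style manifests into the output directory. A minimal sketch of the resulting files, with purely illustrative IDs, paths, and sample counts:

```
# {split}.tsv — first line is the common audio root, then "<relative wav path><TAB><num samples>"
/path/to/datasets/dump/LibriMix/train-100/data
1089-1089-134686-0000_121-121726-0000.wav	182400

# {split}.wrd — word-level transcript, one line per utterance
HE HOPED THERE WOULD BE STEW FOR DINNER

# {split}.ltr — letter-level transcript, "|" marks word boundaries
H E | H O P E D | T H E R E | ...

# {split}_utt_group.ltr — full-length setup only: the two speakers' transcripts of a
# mixture joined by "#", each prefixed with its speaker ID (used for JSM training)
(1089) H E | H O P E D | ... #(121) I T | W A S | ...
```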
# Training

Download the [WavLM Base+](https://github.com/microsoft/UniSpeech/tree/main/WavLM) model
and put it under the `downloads` directory.

We offer a few example scripts for training.

``` bash
# Utterance-based evaluation (WavLM Base+ without speaker embedding)
./train_scripts/LS_wavLM.sh

# Utterance-based evaluation (WavLM Base+ with speaker embedding)
./train_scripts/LS_wavLM_spk.sh

# Utterance group-based evaluation (WavLM Base+ with speaker embedding)
./train_scripts/LS_full_len_wavLM_spk.sh

# Utterance group-based evaluation (WavLM Base+ with speaker embedding + Joint Speaker Modeling (JSM))
./train_scripts/LS_full_len_wavLM_spk_JSM.sh
```

# Evaluation

``` bash
# Utterance-based evaluation (with and without speaker embedding)
./eval_scripts/LS.sh

# Utterance group-based evaluation (WavLM Base+ with speaker embedding)
./eval_scripts/LS_full_len.sh

# Utterance group-based evaluation (WavLM Base+ with speaker embedding + JSM)
./eval_scripts/LS_full_len_JSM.sh
```

# Pretrained models

[Utterance-based evaluation (WavLM Base+ without speaker embedding)](https://drive.google.com/file/d/1tMARaaR0YmgcJUEDVrnTFfNE2i68uC4W/view?usp=share_link)

[Utterance-based evaluation (WavLM Base+ with speaker embedding)](https://drive.google.com/file/d/1XcdxeSbWa6cQAfnUmlEg1YeWdndbDQTA/view?usp=share_link)

[Utterance group-based evaluation (WavLM Base+ with speaker embedding)](https://drive.google.com/file/d/1A3kXrXlyYDZhZVcHr_4NqjIjR4Kd9sgm/view?usp=share_link)

[Utterance group-based evaluation (WavLM Base+ with speaker embedding + JSM)](https://drive.google.com/file/d/1gb85DUNRs5Ep6HjLuVHOKWDka5LhK9KZ/view?usp=share_link)

Before running inference with a pretrained model, first convert it so that its `w2v_path` points to your local WavLM checkpoint (replace `<model_dir>` with the directory holding the downloaded checkpoint):

```bash
python myscripts/convert_model.py <model_dir>/checkpoint_last.pt downloads/WavLM-Base+.pt <model_dir>/checkpoint_last_tmp.pt
mv <model_dir>/checkpoint_last_tmp.pt <model_dir>/checkpoint_last.pt
```

# Citation

Please cite as:

``` bibtex
@inproceedings{huang2023adapting,
  title={Adapting self-supervised models to multi-talker speech recognition using speaker embeddings},
  author={Huang, Zili and Raj, Desh and Garc{\'\i}a, Paola and Khudanpur, Sanjeev},
  booktitle={IEEE ICASSP},
  year={2023},
}
```
--------------------------------------------------------------------------------