├── .gitignore ├── LICENSE ├── README.md ├── requirements.txt ├── track1_asr ├── conf │ └── train_ebranchformer.yaml ├── local │ ├── data_prep.py │ ├── enhance.py │ ├── enhancement.sh │ ├── enhancement │ │ ├── iva.py │ │ └── pfdkf.py │ ├── generate_submission_file.py │ ├── icmcasr_data_prep.sh │ ├── normalize │ │ └── cn_tn.py │ └── segment_wavs.py ├── path.sh ├── run.sh ├── tools │ ├── alignment.sh │ ├── analyze_dataset.py │ ├── cmvn_kaldi2json.py │ ├── combine_data.sh │ ├── compute-cer.py │ ├── compute-wer.py │ ├── compute_cmvn_stats.py │ ├── compute_fbank_feats.py │ ├── copy_data_dir.sh │ ├── data │ │ ├── remove_dup_utts.sh │ │ └── split_scp.pl │ ├── decode.sh │ ├── feat_to_shape.sh │ ├── filter_scp.pl │ ├── filter_uneven_data.py │ ├── fix_data_dir.sh │ ├── flake8_hook.py │ ├── format_data.sh │ ├── fst │ │ ├── add_lex_disambig.pl │ │ ├── compile_lexicon_token_fst.sh │ │ ├── ctc_token_fst.py │ │ ├── ctc_token_fst_compact.py │ │ ├── ctc_token_fst_corrected.py │ │ ├── eps2disambig.pl │ │ ├── make_lexicon_fst.pl │ │ ├── make_tlg.sh │ │ ├── prepare_dict.py │ │ ├── remove_oovs.pl │ │ ├── rnnt_token_fst.py │ │ └── s2eps.pl │ ├── git-pre-commit │ ├── install_srilm.sh │ ├── k2 │ │ ├── make_hlg.sh │ │ ├── prepare_char.py │ │ └── prepare_mmi.sh │ ├── latency_metrics.py │ ├── make_raw_list.py │ ├── make_shard_list.py │ ├── merge_scp2txt.py │ ├── onnx2horizonbin.py │ ├── parse_options.sh │ ├── perturb_data_dir_speed.sh │ ├── reduce_data_dir.sh │ ├── remove_longshortdata.py │ ├── segment.py │ ├── setup_anaconda.sh │ ├── sph2wav.sh │ ├── spk2utt_to_utt2spk.pl │ ├── spm_decode │ ├── spm_encode │ ├── spm_train │ ├── subset_data_dir.sh │ ├── subset_scp.pl │ ├── sym2int.pl │ ├── text2token.py │ ├── utt2spk_to_spk2utt.pl │ ├── validate_data_dir.sh │ ├── validate_dict_dir.pl │ ├── validate_text.pl │ ├── wav2dur.py │ ├── wav_to_duration.sh │ └── websocket │ │ └── performance-ws.py └── wenet │ ├── README.md │ ├── __init__.py │ ├── bin │ ├── alignment.py │ ├── average_model.py │ ├── export_ipex.py │ ├── export_jit.py │ ├── export_onnx_bpu.py │ ├── export_onnx_cpu.py │ ├── export_onnx_gpu.py │ ├── recognize.py │ ├── recognize_onnx_gpu.py │ └── train.py │ ├── branchformer │ ├── __init__.py │ ├── cgmlp.py │ ├── encoder.py │ └── encoder_layer.py │ ├── cif │ └── predictor.py │ ├── dataset │ ├── __init__.py │ ├── dataset.py │ ├── kaldi_io.py │ ├── processor.py │ └── wav_distortion.py │ ├── e_branchformer │ ├── encoder.py │ └── encoder_layer.py │ ├── efficient_conformer │ ├── __init__.py │ ├── attention.py │ ├── convolution.py │ ├── encoder.py │ ├── encoder_layer.py │ └── subsampling.py │ ├── paraformer │ ├── paraformer.py │ ├── search │ │ ├── beam_search.py │ │ ├── ctc.py │ │ ├── ctc_prefix_score.py │ │ └── scorer_interface.py │ └── utils.py │ ├── squeezeformer │ ├── __init__.py │ ├── attention.py │ ├── conv2d.py │ ├── convolution.py │ ├── encoder.py │ ├── encoder_layer.py │ ├── positionwise_feed_forward.py │ └── subsampling.py │ ├── ssl │ └── bestrq │ │ ├── bestqr_model.py │ │ └── mask.py │ ├── transducer │ ├── __init__.py │ ├── joint.py │ ├── predictor.py │ ├── search │ │ ├── greedy_search.py │ │ └── prefix_beam_search.py │ └── transducer.py │ ├── transformer │ ├── __init__.py │ ├── asr_model.py │ ├── attention.py │ ├── cmvn.py │ ├── convolution.py │ ├── ctc.py │ ├── decoder.py │ ├── decoder_layer.py │ ├── embedding.py │ ├── encoder.py │ ├── encoder_layer.py │ ├── label_smoothing_loss.py │ ├── positionwise_feed_forward.py │ ├── subsampling.py │ └── swish.py │ └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── cmvn.py │ ├── common.py │ ├── config.py │ ├── context_graph.py │ ├── ctc_util.py │ ├── executor.py │ ├── file_utils.py │ ├── init_model.py │ ├── mask.py │ └── scheduler.py └── track2_asdr ├── data ├── dev_ref.txt ├── dict │ └── lang_char.txt └── train_aec_iva_near │ └── global_cmvn ├── exp └── baseline_ebranchformer │ ├── avg_10.pt │ └── train.yaml ├── local ├── compute_cpcer.py ├── generate_submission_file.py ├── merge_session_rttms.py ├── pyannote_vad.py ├── run_pyannote_vad.sh └── segment_wavs_by_rttm.py ├── path.sh ├── run.sh ├── tools └── wenet /.gitignore: -------------------------------------------------------------------------------- 1 | ## ignore the following files ## 2 | /track1_asr/data/ 3 | /track1_asr/exp/ 4 | /track1_asr/tensorboard/ 5 | **/__pycache__/ 6 | **/*.pyc 7 | /track2_asdr/exp/pyannote_vad/ 8 | /track2_asdr/data/*_aec_iva_0.95/ 9 | /track2_asdr/data/vad/ 10 | /track2_asdr/data/eval_ref.txt 11 | /track2_asdr/exp/baseline_ebranchformer/dev_aec_iva_*/ 12 | /track2_asdr/exp/baseline_ebranchformer/eval_track2_*/ 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | pyyaml>=5.1 3 | sentencepiece 4 | tensorboard 5 | tensorboardX 6 | textgrid 7 | pytest 8 | flake8==3.8.2 9 | flake8-bugbear 10 | flake8-comprehensions 11 | flake8-executable 12 | flake8-pyi==20.5.0 13 | mccabe 14 | pycodestyle==2.6.0 15 | pyflakes==2.2.0 16 | editdistance==0.6.2 17 | pydub==0.25.1 18 | zhon==2.0.2 19 | -------------------------------------------------------------------------------- /track1_asr/conf/train_ebranchformer.yaml: -------------------------------------------------------------------------------- 1 | # network architecture 2 | # encoder related 3 | encoder: e_branchformer 4 | encoder_conf: 5 | output_size: 256 # dimension of attention 6 | attention_heads: 4 7 | linear_units: 1024 # the number of units of position-wise feed forward 8 | num_blocks: 16 # the number of encoder blocks 9 | cgmlp_linear_units: 1024 10 | cgmlp_conv_kernel: 31 11 | use_linear_after_conv: false 12 | gate_activation: identity 13 | merge_conv_kernel: 31 14 | dropout_rate: 0.1 15 | positional_dropout_rate: 0.1 16 | attention_dropout_rate: 0.1 17 | input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 18 | activation_type: 'swish' 19 | causal: false 20 | pos_enc_layer_type: 'rel_pos' 21 | attention_layer_type: 'rel_selfattn' 22 | 23 | # decoder related 24 | decoder: transformer 25 | decoder_conf: 26 | attention_heads: 4 27 | linear_units: 2048 28 | num_blocks: 6 29 | dropout_rate: 0.1 30 | positional_dropout_rate: 0.1 31 | self_attention_dropout_rate: 0.1 32 | src_attention_dropout_rate: 0.1 33 | 34 | # hybrid CTC/attention 35 | model_conf: 36 | ctc_weight: 0.3 37 | lsm_weight: 0.1 # label smoothing option 38 | length_normalized_loss: false 39 | 40 | dataset_conf: 41 | filter_conf: 42 | max_length: 40960 43 | min_length: 10 # 0.1 s 44 | token_max_length: 200 45 | token_min_length: 1 46 | resample_conf: 47 | resample_rate: 16000 48 | speed_perturb: false 49 | fbank_conf: 50 | num_mel_bins: 80 51 | frame_shift: 10 52 | frame_length: 25 53 | dither: 1.0 54 | spec_aug: true 55 | spec_aug_conf: 56 | num_t_mask: 2 57 | num_f_mask: 2 58 | max_t: 50 59 | max_f: 10 60 | shuffle: true 61 | shuffle_conf: 62 | shuffle_size: 1500 63 | sort: true 64 | sort_conf: 65 | sort_size: 500 # sort_size should be less than shuffle_size 66 | batch_conf: 67 | batch_type: 'dynamic' # static or dynamic 68 | max_frames_in_batch: 11000 # 2080TI 11G VRAM 69 | 70 | grad_clip: 5 71 | accum_grad: 1 72 | max_epoch: 120 73 | log_interval: 100 74 | 75 | optim: adam 76 | optim_conf: 77 | lr: 0.001 78 | weight_decay: 0.000001 79 | scheduler: warmuplr # pytorch v1.1.0+ required 80 | scheduler_conf: 81 | warmup_steps: 45000 82 | -------------------------------------------------------------------------------- /track1_asr/local/data_prep.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | 5 | sys.path.append("./normalize") 6 | 7 | from tqdm import tqdm 8 | from normalize.cn_tn import TextNorm 9 | 10 | 11 | def prepare_data(dataset_pf): 12 | # train dev eval_track1 13 | audio_dirs = glob.glob(f"{data_root}/{dataset_pf}/*") 14 | wav_scp = [] 15 | text = [] 16 | for audio_dir in tqdm(audio_dirs): 17 | # near field audio 18 | near_textgrids = glob.glob(f"{audio_dir}/DA0*.TextGrid") 19 | # 1 2 3 4 20 | near_seat_ids = [tg.split('/')[-1][3] for tg in near_textgrids] 21 | near_wav_dirs = [tg.replace('.TextGrid', '') for tg in near_textgrids] 22 | # search far-field audio by seat id 23 | far_wav_dirs = [f"{audio_dir.replace(data_root, enhanced_data_root)}/DX0{seat_id}C01" 24 | for seat_id in near_seat_ids] 25 | txt_files = glob.glob(f"{audio_dir}/*.txt") 26 | if dataset_pf == "train": 27 | for dir_ in near_wav_dirs + far_wav_dirs: 28 | wav_scp.extend([(f"{file.split('/')[-1].replace('.wav', '')}", file) 29 | for file in glob.glob(f"{dir_}/*.wav")]) 30 | else: 31 | for dir_ in far_wav_dirs: 32 | wav_scp.extend([(f"{file.split('/')[-1].replace('.wav', '')}", file) 33 | for file in glob.glob(f"{dir_}/*.wav")]) 34 | 35 | for txt_file in txt_files: 36 | lines = open(txt_file).readlines() 37 | text.extend([(line.split()[0], 38 | text_normalizer(line.split()[1].strip().replace('', '')).replace('2', '二')) 39 | for line in lines]) 40 | 41 | wav_scp.sort(key=lambda x: x[0]) 42 | text.sort(key=lambda x: x[0]) 43 | assert len(wav_scp) == len(text), f"wav_scp: {len(wav_scp)}, text: {len(text)}" 44 | for i in range(len(wav_scp)): 45 | assert wav_scp[i][0] == text[i][0] 46 | 47 | os.system(f"mkdir -p ./data/{dataset}") 48 | with open(f"./data/{dataset}/wav.scp", "w") as f: 49 | for line in wav_scp: 50 | f.write(f"{line[0]} {line[1]}\n") 51 | with open(f"./data/{dataset}/text", "w") as f: 52 | for line in text: 53 | f.write(f"{line[0]} {line[1]}\n") 54 | with open(f"./data/{dataset}/utt2spk", "w") as f: 55 | for line in wav_scp: 56 | f.write(f"{line[0]} {line[0].split('_')[0]}\n") 57 | 58 | os.system(f"./tools/utt2spk_to_spk2utt.pl ./data/{dataset}/utt2spk > ./data/{dataset}/spk2utt") 59 | os.system(f"./tools/fix_data_dir.sh ./data/{dataset}") 60 | os.system(f"./tools/validate_data_dir.sh --no-feats ./data/{dataset}") 61 | 62 | 63 | if __name__ == '__main__': 64 | data_root, enhanced_data_root, dataset = sys.argv[1], sys.argv[2], sys.argv[3] 65 | dataset_prefix = dataset.split('_aec_iva')[0] 66 | text_normalizer = TextNorm(to_banjiao=True, to_upper=True) 67 | prepare_data(dataset_prefix) 68 | -------------------------------------------------------------------------------- /track1_asr/local/enhance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import argparse 5 | import numpy as np 6 | import soundfile as sf 7 | 8 | sys.path.append("./enhancement") 9 | from tqdm import tqdm 10 | from enhancement.iva import iva 11 | from enhancement.pfdkf import pfdkf 12 | 13 | 14 | def main(args): 15 | session2utt = {} 16 | 17 | with open(args.wav_scp, "r") as f: 18 | for line in f.readlines(): 19 | utt, path = line.strip().split() 20 | session2utt.setdefault(utt.split("_")[0], []).append(path) 21 | 22 | sr = 16000 23 | for session, utts in session2utt.items(): 24 | utts.sort() 25 | os.makedirs(os.path.join(args.save_path, session), exist_ok=True) 26 | wavs = [sf.read(utt)[0] for utt in utts] 27 | # Do AEC First 28 | session_dir = os.path.dirname(utts[0]) 29 | ref_utts = glob.glob(os.path.join(session_dir, "DX0[5-6]C01.wav")) 30 | if len(ref_utts) != 0: 31 | # If there is reference, do AEC 32 | print(f"Doing AEC for session {session}") 33 | ref_utts.sort() 34 | ref1, _ = sf.read(ref_utts[0]) 35 | ref2, _ = sf.read(ref_utts[1]) 36 | for i in range(len(wavs)): 37 | mic = wavs[i] 38 | error1, echo1 = pfdkf(ref1, mic, A=0.999, keep_m_gate=0.5) 39 | error2, echo2 = pfdkf(ref2, mic, A=0.999, keep_m_gate=0.5) 40 | echo = (echo1 + echo2) / 2.0 41 | min_len = min(len(echo), len(mic)) 42 | mic = mic[:min_len] 43 | echo = echo[:min_len] 44 | wavs[i] = mic - echo 45 | else: 46 | print(f"No reference for session {session}, skip AEC") 47 | 48 | print(f"Doing IVA for session {session}") 49 | min_length = min([wav.shape[0] for wav in wavs]) 50 | wavs = [wav[:min_length] for wav in wavs] 51 | x = np.stack(wavs, axis=1) 52 | y = iva(x) 53 | 54 | for i in range(y.shape[0]): 55 | sf.write(os.path.join(args.save_path, session, f"DX0{i + 1}C01.wav"), y[i], sr) 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--wav_scp", type=str, required=True) 61 | parser.add_argument("--save_path", type=str, required=True) 62 | main(parser.parse_args()) 63 | -------------------------------------------------------------------------------- /track1_asr/local/enhancement.sh: -------------------------------------------------------------------------------- 1 | . ./path.sh || exit 1 2 | 3 | stage=0 4 | stop_stage=0 5 | nj=48 6 | 7 | . tools/parse_options.sh 8 | 9 | data_root=$1 10 | enhanced_data_root=$2 11 | dataset=$3 12 | 13 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 14 | dataset_pf=$(echo "$dataset" | cut -d '_' -f 1) 15 | if [ ${dataset_pf} == 'eval' ]; then 16 | dataset_pf=$(echo "$dataset" | cut -d _ -f 1-2) 17 | fi 18 | 19 | echo "[local/enhancement.sh] generate enhanced audio for ${dataset_pf}" 20 | 21 | mkdir -p exp/enhance/${dataset_pf} 22 | 23 | ls ${data_root}/${dataset_pf}/*/DX0[1-4]C01.wav | awk -F/ '{print $(NF-1)"_"substr($NF, 1, length($NF)-4), $0}' \ 24 | > exp/enhance/${dataset_pf}/wav.scp 25 | 26 | mkdir -p exp/enhance/${dataset_pf}/split_scp 27 | mkdir -p exp/enhance/${dataset_pf}/log 28 | 29 | file_len=`wc -l exp/enhance/${dataset_pf}/wav.scp | awk '{print $1}'` 30 | subfile_len=$[${file_len} / ${nj} + 1] 31 | 32 | if [ $[${file_len} / 4] -le ${nj} ]; then 33 | nj=$[${file_len} / 4] 34 | subfile_len=4 35 | fi 36 | 37 | prefix='split' 38 | split -l $subfile_len -d -a 3 exp/enhance/${dataset_pf}/wav.scp exp/enhance/${dataset_pf}/split_scp/${prefix}_scp_ 39 | echo "you can check the log files in exp/enhance/${dataset_pf}/log for possible errors and progress" 40 | for suffix in `seq 0 $[${nj}-1]`;do 41 | suffix=`printf '%03d' $suffix` 42 | scp_subfile=exp/enhance/${dataset_pf}/split_scp/${prefix}_scp_${suffix} 43 | python3 -u local/enhance.py \ 44 | --wav_scp ${scp_subfile} \ 45 | --save_path ${enhanced_data_root}/${dataset_pf} \ 46 | > exp/enhance/${dataset_pf}/log/${prefix}.${suffix}.log 2>&1 & 47 | done 48 | wait 49 | fi 50 | -------------------------------------------------------------------------------- /track1_asr/local/enhancement/iva.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import soundfile as sf 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def rstft(x, nfft, shift, win): 10 | nframe = int((len(x) - nfft) / shift) + 1 11 | X = np.zeros((nfft, nframe), dtype=np.float64) 12 | begin = 0 13 | 14 | for i in range(nframe): 15 | X[:, i] = x[begin:begin + nfft] 16 | begin += shift 17 | 18 | Y = np.fft.fft(X * np.tile(np.expand_dims(win, 1), (1, nframe)), axis=0) 19 | Y = Y[:nfft // 2 + 1, :] 20 | 21 | return Y 22 | 23 | 24 | def irstft(Y, shift, win=None): 25 | Y = np.vstack([Y, np.conj(np.flipud(Y[1:-1, :]))]) 26 | nfft, nframe = Y.shape 27 | 28 | syn_win = np.zeros(nfft) 29 | for i in range(1, nfft // shift + 1): 30 | syn_win += np.roll(win, i * shift) 31 | 32 | syn_win = 1.0 / syn_win 33 | syn_win = syn_win[:shift] 34 | 35 | N = (nframe - 1) * shift + nfft 36 | x = np.zeros(N) 37 | begin = 0 38 | 39 | for i in range(nframe): 40 | x[begin:begin + nfft] += np.real(np.fft.ifft(Y[:, i])) 41 | begin += shift 42 | 43 | x = x * np.tile(syn_win, N // shift) 44 | 45 | return x 46 | 47 | 48 | def iva(sp_in, fs=16000, eps=1e-6, epoch=30, nfft=512, nshift=256): 49 | nTime, M = sp_in.shape 50 | nf = nfft // 2 + 1 51 | 52 | ana_win = 0.54 - 0.46 * np.cos(2 * np.pi * np.arange(nfft) / (nfft - 1)) 53 | 54 | X = [] 55 | for k in range(M): 56 | X.append(rstft(sp_in[:, k], nfft, nshift, ana_win)) 57 | X = np.stack(X, axis=2) 58 | X = X.astype(np.complex64) 59 | 60 | W = np.tile(np.eye(M, dtype=np.complex64).reshape(M, M, 1), (1, 1, nf)) 61 | Winv = np.tile(np.eye(M, dtype=np.complex64).reshape(M, M, 1), (1, 1, nf)) 62 | Wt = np.tile(np.eye(M, dtype=np.complex64).reshape(M, M, 1), (1, 1, nf)) 63 | Vk = np.tile(np.eye(M, dtype=np.complex64).reshape(M, M, 1, 1), (1, 1, nf, M)) 64 | Rx = np.multiply(np.transpose(np.expand_dims(X, 3), (2, 3, 0, 1)), 65 | np.transpose(np.conj(np.expand_dims(X, 3)), (3, 2, 0, 1))) 66 | 67 | for iter in tqdm(range(epoch), desc="IVA Iteration"): 68 | Yp = np.sum(np.multiply(np.transpose(np.expand_dims(W, 3), (0, 1, 3, 2)), 69 | np.transpose(np.expand_dims(X, 3), (3, 2, 1, 0))), axis=1) 70 | R = np.sum(np.real(Yp * np.conj(Yp)), axis=2) 71 | Gr = 1 / (np.sqrt(R) + eps) 72 | 73 | for k in range(M): 74 | Vk[:, :, :, k] = np.mean(np.multiply(Rx, Gr[k, :].reshape(1, 1, 1, Gr.shape[1])), axis=3) 75 | 76 | for i in range(nf): 77 | wk = np.linalg.solve(Vk[:, :, i, k] + 0 * np.eye(M, dtype=np.complex64), Winv[:, k, i]) 78 | wk = wk / (np.sqrt(np.real(np.dot(wk.T.conj(), Winv[:, k, i])))) 79 | W[k, :, i] = wk.T.conj() 80 | 81 | for i in range(nf): 82 | Winv[:, :, i] = np.linalg.pinv(W[:, :, i]) 83 | 84 | # Normalize W 85 | for i in range(nf): 86 | W[:, :, i] = np.dot(np.diag(np.diag(Winv[:, :, i])), W[:, :, i]) 87 | 88 | # Get output 89 | Xp = np.transpose(X, (2, 1, 0)) 90 | Y = Xp 91 | for i in range(nf): 92 | Y[:, :, i] = np.dot(W[:, :, i], Xp[:, :, i]) 93 | 94 | # iSTFT 95 | sp_out = [] 96 | for k in range(M): 97 | sp_out.append(irstft(Y[k, :, :].T, nshift, ana_win)) 98 | sp_out = np.stack(sp_out, axis=0) 99 | 100 | return sp_out 101 | 102 | 103 | def main(args): 104 | session2utt = {} 105 | 106 | with open(args.wav_scp, "r") as f: 107 | for line in f.readlines(): 108 | utt, path = line.strip().split() 109 | session2utt.setdefault(utt.split("_")[0], []).append(path) 110 | 111 | sr = 16000 112 | for session, utts in session2utt.items(): 113 | utts.sort() 114 | 115 | wav1, _ = sf.read(utts[0]) 116 | wav2, _ = sf.read(utts[1]) 117 | wav3, _ = sf.read(utts[2]) 118 | wav4, _ = sf.read(utts[3]) 119 | 120 | min_length = min(wav1.shape[0], wav2.shape[0], wav3.shape[0], wav4.shape[0]) 121 | x = np.stack([wav1[:min_length], wav2[:min_length], wav3[:min_length], wav4[:min_length]], axis=1) 122 | 123 | y = iva(x) 124 | 125 | os.makedirs(os.path.join(args.save_path, session), exist_ok=True) 126 | for i in range(4): 127 | sf.write(os.path.join(args.save_path, session, f"DX0{i + 1}C01.wav"), y[i], sr) 128 | 129 | 130 | if __name__ == '__main__': 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument("--wav_scp", type=str, required=True) 133 | parser.add_argument("--save_path", type=str, required=True) 134 | main(parser.parse_args()) 135 | -------------------------------------------------------------------------------- /track1_asr/local/enhancement/pfdkf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | import numpy as np 5 | import soundfile as sf 6 | 7 | from numpy.fft import rfft as fft 8 | from numpy.fft import irfft as ifft 9 | 10 | 11 | class PFDKF: 12 | def __init__(self, N, M, A=0.999, P_initial=1, keep_m_gate=0.5, res=False): 13 | """Initial state of partitioned block based frequency domain kalman filter 14 | 15 | Args: 16 | N (int): Num of blocks. 17 | M (int): Filter length in one block. 18 | A (float, optional): Diag coeff, if more nonlinear components, 19 | can be set to 0.99. Defaults to 0.999. 20 | P_initial (int, optional): About the begining covergence. 21 | Defaults to 10. 22 | keep_m_gate (float, optional): When more linear, 23 | can be set to 0.2 or less. Defaults to 0.5. 24 | """ 25 | # M = 2*V 26 | self.N = N 27 | self.M = M 28 | self.A = A 29 | self.A2 = A ** 2 30 | self.m_smooth_factor = keep_m_gate 31 | self.res = res 32 | 33 | self.x = np.zeros(shape=(2 * self.M), dtype=np.float32) 34 | self.m = np.zeros(shape=(self.M + 1), dtype=np.float32) 35 | self.P = np.full((self.N, self.M + 1), P_initial) 36 | self.X = np.zeros((self.N, self.M + 1), dtype=complex) 37 | self.H = np.zeros((self.N, self.M + 1), dtype=complex) 38 | self.mu = np.zeros((self.N, self.M + 1), dtype=complex) 39 | self.half_window = np.concatenate(([1] * self.M, [0] * self.M)) 40 | 41 | def filt(self, x, d): 42 | assert (len(x) == self.M) 43 | self.x = np.concatenate([self.x[self.M:], x]) 44 | X = fft(self.x) 45 | self.X[1:] = self.X[:-1] 46 | self.X[0] = X 47 | Y = np.sum(self.H * self.X, axis=0) 48 | y = ifft(Y).real[self.M:] 49 | e = d - y 50 | 51 | e_fft = np.concatenate( 52 | (np.zeros(shape=(self.M,), dtype=np.float32), e)) 53 | self.E = fft(e_fft) 54 | self.m = self.m_smooth_factor * self.m + \ 55 | (1 - self.m_smooth_factor) * np.abs(self.E) ** 2 56 | R = np.sum(self.X * self.P * self.X.conj(), 0) + 2 * self.m / self.N 57 | self.mu = self.P / (R + 1e-10) 58 | if self.res: 59 | W = 1 - np.sum(self.mu * np.abs(self.X) ** 2, 0) 60 | E_res = W * self.E 61 | e = ifft(E_res).real[self.M:].real 62 | y = d - e 63 | return e, y 64 | 65 | def update(self): 66 | G = self.mu * self.X.conj() 67 | self.P = self.A2 * (1 - 0.5 * G * self.X) * self.P + \ 68 | (1 - self.A2) * np.abs(self.H) ** 2 69 | self.H = self.A * (self.H + fft(self.half_window * (ifft(self.E * G).real))) 70 | 71 | 72 | def pfdkf(x, d, N=10, M=256, A=0.999, P_initial=1, keep_m_gate=0.1): 73 | ft = PFDKF(N, M, A, P_initial, keep_m_gate) 74 | num_block = min(len(x), len(d)) // M 75 | 76 | e = np.zeros(num_block * M) 77 | y = np.zeros(num_block * M) 78 | for n in range(num_block): 79 | x_n = x[n * M:(n + 1) * M] 80 | d_n = d[n * M:(n + 1) * M] 81 | e_n, y_n = ft.filt(x_n, d_n) 82 | ft.update() 83 | e[n * M:(n + 1) * M] = e_n 84 | y[n * M:(n + 1) * M] = y_n 85 | return e, y 86 | 87 | 88 | def main(args): 89 | session2utt = {} 90 | 91 | with open(args.wav_scp, "r") as f: 92 | for line in f.readlines(): 93 | utt, path = line.strip().split() 94 | session2utt.setdefault(utt.split("_")[0], []).append(path) 95 | 96 | sr = 16000 97 | for session, utts in session2utt.items(): 98 | utts.sort() 99 | os.makedirs(os.path.join(args.save_path, session), exist_ok=True) 100 | 101 | wav1, _ = sf.read(utts[0]) 102 | wav2, _ = sf.read(utts[1]) 103 | wav3, _ = sf.read(utts[2]) 104 | wav4, _ = sf.read(utts[3]) 105 | 106 | ref_utts = glob.glob(os.path.join(args.ref_path, session, "DX0[5-6]C01.wav")) 107 | if len(ref_utts) == 0: 108 | sf.write(os.path.join(args.save_path, session, "DX01C01.wav"), wav1, sr) 109 | sf.write(os.path.join(args.save_path, session, "DX02C01.wav"), wav2, sr) 110 | sf.write(os.path.join(args.save_path, session, "DX03C01.wav"), wav3, sr) 111 | sf.write(os.path.join(args.save_path, session, "DX04C01.wav"), wav4, sr) 112 | continue 113 | 114 | ref_utts.sort() 115 | 116 | mic_list = np.stack([wav1, wav2, wav3, wav4]) 117 | ref1, _ = sf.read(ref_utts[0]) 118 | ref2, _ = sf.read(ref_utts[1]) 119 | errors = [] 120 | for i in range(len(mic_list)): 121 | mic = mic_list[i] 122 | error1, echo1 = pfdkf(ref1, mic, A=0.999, keep_m_gate=0.5) 123 | error2, echo2 = pfdkf(ref2, mic, A=0.999, keep_m_gate=0.5) 124 | echo = (echo1 + echo2) / 2.0 125 | min_len = min(len(echo), len(mic)) 126 | mic = mic[:min_len] 127 | echo = echo[:min_len] 128 | errors.append(mic - echo) 129 | 130 | out = np.stack(errors, axis=0) 131 | 132 | for i in range(4): 133 | sf.write(os.path.join(args.save_path, session, f"DX0{i + 1}C01.wav"), out[i], sr) 134 | 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument("--wav_scp", type=str, required=True) 139 | parser.add_argument("--ref_path", type=str, required=True) 140 | parser.add_argument("--save_path", type=str, required=True) 141 | main(parser.parse_args()) 142 | -------------------------------------------------------------------------------- /track1_asr/local/generate_submission_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | if __name__ == '__main__': 6 | assert len(sys.argv) == 2, "Usage: python3 generate_submission_file.py " 7 | test_dir = sys.argv[1] 8 | text_file = f"{test_dir}/text" 9 | if not os.path.exists(text_file): 10 | print(f"Error: {text_file} not found. Please check the path.") 11 | exit(1) 12 | print(f"Generating submission file for {test_dir} ...") 13 | lines = [] 14 | for line in open(text_file, 'r').readlines(): 15 | line_splits = line.split() 16 | if len(line_splits) != 2: 17 | # blank decoding results 18 | lines.append([line_splits[0], '']) 19 | else: 20 | lines.append([line_splits[0], line_splits[1].strip()]) 21 | 22 | # Hide the channel information in utt id 23 | for i, line in enumerate(lines): 24 | utt_id_splits = line[0].split('_') 25 | utt_id = '_'.join(utt_id_splits[:2] + utt_id_splits[3:]) 26 | lines[i][0] = utt_id 27 | lines = sorted(lines, key=lambda x: x[0]) 28 | with open(f"{test_dir}/submission.txt", 'w') as f: 29 | for line in lines: 30 | f.write(f"{line[0]} {line[1]}\n") 31 | -------------------------------------------------------------------------------- /track1_asr/local/icmcasr_data_prep.sh: -------------------------------------------------------------------------------- 1 | . ./path.sh || exit 1 2 | 3 | stage=0 4 | stop_stage=2 5 | nj=48 6 | 7 | . tools/parse_options.sh 8 | 9 | data_root=$1 10 | enhanced_data_root=$2 11 | dataset=$3 12 | 13 | 14 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 15 | echo "[local/icmcasr_data_prep.sh] stage 0: AEC + IVA Enhancement" 16 | local/enhancement.sh --nj ${nj} ${data_root} ${enhanced_data_root} ${dataset} 17 | fi 18 | 19 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 20 | echo "[local/icmcasr_data_prep.sh] stage 1: Segment ${dataset} wavs" 21 | python3 local/segment_wavs.py ${data_root} ${enhanced_data_root} ${dataset} ${nj} 22 | fi 23 | 24 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 25 | echo "[local/icmcasr_data_prep.sh] stage 2: Prepare data files" 26 | python3 local/data_prep.py ${data_root} ${enhanced_data_root} ${dataset} 27 | fi 28 | -------------------------------------------------------------------------------- /track1_asr/local/segment_wavs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import glob 5 | import math 6 | import textgrid 7 | import traceback 8 | import multiprocessing 9 | 10 | from tqdm import tqdm 11 | from functools import partial 12 | from pydub import AudioSegment 13 | from zhon.hanzi import punctuation 14 | from concurrent.futures import ProcessPoolExecutor 15 | 16 | 17 | def segment_wav(wav_path): 18 | dur = round(float(os.popen(f"soxi -D {wav_path}").read().strip()), 3) 19 | # 001_DA02 20 | wav_id = '_'.join(wav_path.split('/')[-2:]).replace('.wav', '') 21 | # only segment near-field wavs of training set 22 | if wav_id.split('_')[-1].startswith("DA0") and dataset != "train": 23 | return 24 | utt2text = [] 25 | if wav_id.split('_')[-1].startswith("DA0"): 26 | wav = AudioSegment.from_wav(wav_path) 27 | os.system(f"mkdir -p {wav_path.replace('.wav', '')}") 28 | tg = textgrid.TextGrid() 29 | try: 30 | tg.read(wav_path.replace('.wav', '.TextGrid')) 31 | except: 32 | print(f"Error With: {wav_path.replace('.wav', '.TextGrid')}") 33 | return 34 | for tier in tg.tiers: 35 | speaker = tier.name 36 | for interval in tier: 37 | if not interval.mark or interval.mark == "": 38 | continue 39 | text = re.sub(f"[{punctuation} *]", "", interval.mark) 40 | if text == "": 41 | continue 42 | start, end = float(interval.minTime), float(interval.maxTime) 43 | if end > dur or end - start < 0.03: 44 | continue 45 | # P0117_001_DA02_005156-006106.wav 46 | utterance_id = f"{'_'.join([speaker] + wav_id.split('_'))}_{int(start * 1000):0>6}-{int(end * 1000):0>6}" 47 | utt2text.append((utterance_id, text)) 48 | export_wav_path = f"{wav_path.replace('.wav', '')}/{utterance_id}.wav" 49 | wav[start * 1000:end * 1000].export(export_wav_path, format="wav") 50 | with open(f"{'/'.join(wav_path.split('/')[:-1])}/{wav_path.split('/')[-1].replace('.wav', '.txt')}", 'w') as f: 51 | for item in utt2text: 52 | f.write(f"{item[0]} {item[1]}\n") 53 | else: 54 | # for dev and eval_track1 sets 55 | # segment the enhanced far-field wavs 56 | enhanced_wav_path = wav_path.replace(data_root, enhanced_data_root) 57 | wav = AudioSegment.from_wav(enhanced_wav_path) 58 | tg_files = glob.glob(f"{os.path.dirname(wav_path)}/DA0*.TextGrid") 59 | for tg_file in tg_files: 60 | # only use the far-field wav closest to the speaker for training 61 | if wav_id.split('_')[1][3] != tg_file.split('/')[-1][3]: 62 | continue 63 | os.system(f"mkdir -p {enhanced_wav_path.replace('.wav', '')}") 64 | tg = textgrid.TextGrid() 65 | try: 66 | tg.read(tg_file) 67 | except: 68 | print(f"Error With: {tg_file}") 69 | continue 70 | for tier in tg.tiers: 71 | speaker = tier.name 72 | for interval in tier: 73 | if not interval.mark or interval.mark == "": 74 | continue 75 | text = re.sub(f"[{punctuation} *]", "", interval.mark) 76 | if text == "": 77 | continue 78 | start, end = float(interval.minTime), float(interval.maxTime) 79 | if end > dur or end - start < 0.03: 80 | continue 81 | # P0117_001_DA02_005156-006106.wav 82 | utterance_id = f"{'_'.join([speaker] + wav_id.split('_'))}_{int(start * 1000):0>6}-{int(end * 1000):0>6}" 83 | utt2text.append((utterance_id, text)) 84 | export_wav_path = f"{enhanced_wav_path.replace('.wav', '')}/{utterance_id}.wav" 85 | wav[start * 1000:end * 1000].export(export_wav_path, format="wav") 86 | with open(f"{'/'.join(wav_path.split('/')[:-1])}/{wav_path.split('/')[-1].replace('.wav', '.txt')}", 'w') as f: 87 | for item in utt2text: 88 | f.write(f"{item[0]} {item[1]}\n") 89 | 90 | 91 | def multiThread_use_ProcessPoolExecutor_dicarg(scp, numthread, func, args): 92 | executor = ProcessPoolExecutor(max_workers=numthread) 93 | results = [] 94 | for item in scp: 95 | results.append(executor.submit(partial(func, item, **args))) 96 | return [result.result() for result in tqdm(results)] 97 | 98 | 99 | if __name__ == '__main__': 100 | data_root, enhanced_data_root, dataset, nj = sys.argv[1], sys.argv[2], sys.argv[3], int(sys.argv[4]) 101 | dataset = dataset.split('_aec_iva')[0] 102 | audio_root = f"{data_root}/{dataset}" 103 | wav_scp = [wav for wav in glob.glob(f'{audio_root}/*/*.wav') if not wav.split('/')[-1].startswith(("DX05", "DX06"))] 104 | wav_scp.sort() 105 | multiThread_use_ProcessPoolExecutor_dicarg(wav_scp, nj, segment_wav, {}) 106 | -------------------------------------------------------------------------------- /track1_asr/path.sh: -------------------------------------------------------------------------------- 1 | # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C 2 | export PYTHONIOENCODING=UTF-8 3 | #export PYTHONPATH=../../../:$PYTHONPATH 4 | export PYTHONPATH=./wenet:$PYTHONPATH 5 | -------------------------------------------------------------------------------- /track1_asr/tools/alignment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 4 | . ./path.sh || exit 1; 5 | 6 | stage=0 # start from 0 if you need to start from data preparation 7 | stop_stage=0 8 | 9 | nj=16 10 | dict=data/dict/lang_char.txt 11 | 12 | dir=exp/ 13 | config=$dir/train.yaml 14 | # model trained with trim tail will get a better alignment result 15 | # (Todo) cif/attention/rnnt alignment 16 | checkpoint=$dir/final.pt 17 | 18 | set=test 19 | ali_format=ali_format.data 20 | ali_result=ali.res 21 | blank_thres=0.9999 22 | thres=0.00001 23 | . tools/parse_options.sh || exit 1; 24 | 25 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then 26 | # Prepare required data for ctc alignment 27 | echo "Prepare data, prepare required format" 28 | for x in $set; do 29 | tools/make_raw_list.py data/$x/wav.scp data/$x/text \ 30 | ali_format 31 | done 32 | fi 33 | 34 | 35 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 36 | # Test model, please specify the model you want to use by --checkpoint 37 | mkdir -p exp_${thres} 38 | python wenet/bin/alignment.py --gpu -1 \ 39 | --config $config \ 40 | --input_file $ali_format \ 41 | --checkpoint $checkpoint \ 42 | --batch_size 1 \ 43 | --dict $dict \ 44 | --result_file $ali_result \ 45 | --thres $thres \ 46 | --blank_thres $blank_thres \ 47 | --gen_praat 48 | 49 | fi 50 | 51 | 52 | -------------------------------------------------------------------------------- /track1_asr/tools/cmvn_kaldi2json.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import logging 4 | import sys 5 | import json 6 | 7 | def kaldi2json(kaldi_cmvn_file): 8 | means = [] 9 | variance = [] 10 | with open(kaldi_cmvn_file, 'r') as fid: 11 | # kaldi binary file start with '\0B' 12 | if fid.read(2) == '\0B': 13 | logging.error('kaldi cmvn binary file is not supported, please ' 14 | 'recompute it by: compute-cmvn-stats --binary=false ' 15 | ' scp:feats.scp global_cmvn') 16 | sys.exit(1) 17 | fid.seek(0) 18 | arr = fid.read().split() 19 | assert (arr[0] == '[') 20 | assert (arr[-2] == '0') 21 | assert (arr[-1] == ']') 22 | feat_dim = int((len(arr) - 2 - 2) / 2) 23 | for i in range(1, feat_dim + 1): 24 | means.append(float(arr[i])) 25 | count = float(arr[feat_dim + 1]) 26 | for i in range(feat_dim + 2, 2 * feat_dim + 2): 27 | variance.append(float(arr[i])) 28 | 29 | cmvn_info = {'mean_stat:' : means, 30 | 'var_stat' : variance, 31 | 'frame_num' : count} 32 | return cmvn_info 33 | 34 | if __name__ == '__main__': 35 | with open(sys.argv[2], 'w') as fout: 36 | cmvn = kaldi2json(sys.argv[1]) 37 | fout.write(json.dumps(cmvn)) 38 | -------------------------------------------------------------------------------- /track1_asr/tools/combine_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. 3 | # 2014 David Snyder 4 | 5 | # This script combines the data from multiple source directories into 6 | # a single destination directory. 7 | 8 | # See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information 9 | # about what these directories contain. 10 | 11 | # Begin configuration section. 12 | extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..." 13 | skip_fix=false # skip the fix_data_dir.sh in the end 14 | # End configuration section. 15 | 16 | echo "$0 $@" # Print the command line for logging 17 | 18 | if [ -f path.sh ]; then . ./path.sh; fi 19 | if [ -f parse_options.sh ]; then . parse_options.sh || exit 1; fi 20 | 21 | if [ $# -lt 2 ]; then 22 | echo "Usage: combine_data.sh [--extra-files 'file1 file2'] ..." 23 | echo "Note, files that don't appear in all source dirs will not be combined," 24 | echo "with the exception of utt2uniq and segments, which are created where necessary." 25 | exit 1 26 | fi 27 | 28 | dest=$1; 29 | shift; 30 | 31 | first_src=$1; 32 | 33 | rm -r $dest 2>/dev/null 34 | mkdir -p $dest; 35 | 36 | export LC_ALL=C 37 | 38 | for dir in $*; do 39 | if [ ! -f $dir/utt2spk ]; then 40 | echo "$0: no such file $dir/utt2spk" 41 | exit 1; 42 | fi 43 | done 44 | 45 | # Check that frame_shift are compatible, where present together with features. 46 | dir_with_frame_shift= 47 | for dir in $*; do 48 | if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then 49 | if [[ $dir_with_frame_shift ]] && 50 | ! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then 51 | echo "$0:error: different frame_shift in directories $dir and " \ 52 | "$dir_with_frame_shift. Cannot combine features." 53 | exit 1; 54 | fi 55 | dir_with_frame_shift=$dir 56 | fi 57 | done 58 | 59 | # W.r.t. utt2uniq file the script has different behavior compared to other files 60 | # it is not compulsary for it to exist in src directories, but if it exists in 61 | # even one it should exist in all. We will create the files where necessary 62 | has_utt2uniq=false 63 | for in_dir in $*; do 64 | if [ -f $in_dir/utt2uniq ]; then 65 | has_utt2uniq=true 66 | break 67 | fi 68 | done 69 | 70 | if $has_utt2uniq; then 71 | # we are going to create an utt2uniq file in the destdir 72 | for in_dir in $*; do 73 | if [ ! -f $in_dir/utt2uniq ]; then 74 | # we assume that utt2uniq is a one to one mapping 75 | cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}' 76 | else 77 | cat $in_dir/utt2uniq 78 | fi 79 | done | sort -k1 > $dest/utt2uniq 80 | echo "$0: combined utt2uniq" 81 | else 82 | echo "$0 [info]: not combining utt2uniq as it does not exist" 83 | fi 84 | # some of the old scripts might provide utt2uniq as an extrafile, so just remove it 85 | extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g") 86 | 87 | # segments are treated similarly to utt2uniq. If it exists in some, but not all 88 | # src directories, then we generate segments where necessary. 89 | has_segments=false 90 | for in_dir in $*; do 91 | if [ -f $in_dir/segments ]; then 92 | has_segments=true 93 | break 94 | fi 95 | done 96 | 97 | if $has_segments; then 98 | for in_dir in $*; do 99 | if [ ! -f $in_dir/segments ]; then 100 | echo "$0 [info]: will generate missing segments for $in_dir" 1>&2 101 | utils/data/get_segments_for_data.sh $in_dir 102 | else 103 | cat $in_dir/segments 104 | fi 105 | done | sort -k1 > $dest/segments 106 | echo "$0: combined segments" 107 | else 108 | echo "$0 [info]: not combining segments as it does not exist" 109 | fi 110 | 111 | for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do 112 | exists_somewhere=false 113 | absent_somewhere=false 114 | for d in $*; do 115 | if [ -f $d/$file ]; then 116 | exists_somewhere=true 117 | else 118 | absent_somewhere=true 119 | fi 120 | done 121 | 122 | if ! $absent_somewhere; then 123 | set -o pipefail 124 | ( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1; 125 | set +o pipefail 126 | echo "$0: combined $file" 127 | else 128 | if ! $exists_somewhere; then 129 | echo "$0 [info]: not combining $file as it does not exist" 130 | else 131 | echo "$0 [info]: **not combining $file as it does not exist everywhere**" 132 | fi 133 | fi 134 | done 135 | 136 | tools/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt 137 | 138 | if [[ $dir_with_frame_shift ]]; then 139 | cp $dir_with_frame_shift/frame_shift $dest 140 | fi 141 | 142 | if ! $skip_fix ; then 143 | tools/fix_data_dir.sh $dest || exit 1; 144 | fi 145 | 146 | exit 0 147 | -------------------------------------------------------------------------------- /track1_asr/tools/compute_fbank_feats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Chao Yang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import logging 17 | 18 | import torchaudio 19 | import torchaudio.compliance.kaldi as kaldi 20 | 21 | import wenet.dataset.kaldi_io as kaldi_io 22 | 23 | # The "sox" backends are deprecated and will be removed in 0.9.0 release. 24 | # So here we use sox_io backend 25 | torchaudio.set_audio_backend("sox_io") 26 | 27 | 28 | def parse_opts(): 29 | parser = argparse.ArgumentParser(description='training your network') 30 | parser.add_argument('--num_mel_bins', 31 | default=80, 32 | type=int, 33 | help='Number of triangular mel-frequency bins') 34 | parser.add_argument('--frame_length', 35 | type=int, 36 | default=25, 37 | help='Frame length in milliseconds') 38 | parser.add_argument('--frame_shift', 39 | type=int, 40 | default=10, 41 | help='Frame shift in milliseconds') 42 | parser.add_argument('--dither', 43 | type=int, 44 | default=0.0, 45 | help='Dithering constant (0.0 means no dither)') 46 | parser.add_argument('--segments', default=None, help='segments file') 47 | parser.add_argument('wav_scp', help='wav scp file') 48 | parser.add_argument('out_ark', help='output ark file') 49 | parser.add_argument('out_scp', help='output scp file') 50 | args = parser.parse_args() 51 | return args 52 | 53 | 54 | # wav format: 55 | def load_wav_scp(wav_scp_file): 56 | wav_list = [] 57 | with open(wav_scp_file, 'r', encoding='utf8') as fin: 58 | for line in fin: 59 | arr = line.strip().split() 60 | assert len(arr) == 2 61 | wav_list.append((arr[0], arr[1])) 62 | return wav_list 63 | 64 | 65 | # wav format: 66 | def load_wav_scp_dict(wav_scp_file): 67 | wav_dict = {} 68 | with open(wav_scp_file, 'r', encoding='utf8') as fin: 69 | for line in fin: 70 | arr = line.strip().split() 71 | assert len(arr) == 2 72 | wav_dict[arr[0]] = arr[1] 73 | return wav_dict 74 | 75 | 76 | # Segments format: 77 | def load_wav_segments(wav_scp_file, segments_file): 78 | wav_dict = load_wav_scp_dict(wav_scp_file) 79 | audio_list = [] 80 | with open(segments_file, 'r', encoding='utf8') as fin: 81 | for line in fin: 82 | arr = line.strip().split() 83 | assert len(arr) == 4 84 | key = arr[0] 85 | wav_file = wav_dict[arr[1]] 86 | start = float(arr[2]) 87 | end = float(arr[3]) 88 | audio_list.append((key, wav_file, start, end)) 89 | return audio_list 90 | 91 | 92 | if __name__ == '__main__': 93 | args = parse_opts() 94 | logging.basicConfig(level=logging.DEBUG, 95 | format='%(asctime)s %(levelname)s %(message)s') 96 | if args.segments is None: 97 | audio_list = load_wav_scp(args.wav_scp) 98 | else: 99 | audio_list = load_wav_segments(args.wav_scp, args.segments) 100 | 101 | count = 0 102 | with open(args.out_ark, 'wb') as ark_fout, \ 103 | open(args.out_scp, 'w', encoding='utf8') as scp_fout: 104 | for item in audio_list: 105 | if len(item) == 2: 106 | key, wav_path = item 107 | waveform, sample_rate = torchaudio.load_wav(wav_path) 108 | else: 109 | assert len(item) == 4 110 | key, wav_path, start, end = item 111 | sample_rate = torchaudio.info(wav_path).sample_rate 112 | frame_offset = int(start * sample_rate) 113 | num_frames = int((end - start) * sample_rate) 114 | waveform, sample_rate = torchaudio.load_wav( 115 | wav_path, frame_offset, num_frames) 116 | 117 | mat = kaldi.fbank(waveform, 118 | num_mel_bins=args.num_mel_bins, 119 | frame_length=args.frame_length, 120 | frame_shift=args.frame_shift, 121 | dither=args.dither, 122 | energy_floor=0.0, 123 | sample_frequency=sample_rate) 124 | mat = mat.detach().numpy() 125 | kaldi_io.write_ark_scp(key, mat, ark_fout, scp_fout) 126 | count += 1 127 | if count % 10000 == 0: 128 | logging.info('Progress {}/{}'.format(count, len(audio_list))) 129 | -------------------------------------------------------------------------------- /track1_asr/tools/copy_data_dir.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2013 Johns Hopkins University (author: Daniel Povey) 4 | # Apache 2.0 5 | 6 | # This script operates on a directory, such as in data/train/, 7 | # that contains some subset of the following files: 8 | # feats.scp 9 | # wav.scp 10 | # vad.scp 11 | # spk2utt 12 | # utt2spk 13 | # text 14 | # 15 | # It copies to another directory, possibly adding a specified prefix or a suffix 16 | # to the utterance and/or speaker names. Note, the recording-ids stay the same. 17 | # 18 | 19 | 20 | # begin configuration section 21 | spk_prefix= 22 | utt_prefix= 23 | spk_suffix= 24 | utt_suffix= 25 | validate_opts= # should rarely be needed. 26 | # end configuration section 27 | 28 | . utils/parse_options.sh 29 | 30 | if [ $# != 2 ]; then 31 | echo "Usage: " 32 | echo " $0 [options] " 33 | echo "e.g.:" 34 | echo " $0 --spk-prefix=1- --utt-prefix=1- data/train data/train_1" 35 | echo "Options" 36 | echo " --spk-prefix= # Prefix for speaker ids, default empty" 37 | echo " --utt-prefix= # Prefix for utterance ids, default empty" 38 | echo " --spk-suffix= # Suffix for speaker ids, default empty" 39 | echo " --utt-suffix= # Suffix for utterance ids, default empty" 40 | exit 1; 41 | fi 42 | 43 | 44 | export LC_ALL=C 45 | 46 | srcdir=$1 47 | destdir=$2 48 | 49 | if [ ! -f $srcdir/utt2spk ]; then 50 | echo "copy_data_dir.sh: no such file $srcdir/utt2spk" 51 | exit 1; 52 | fi 53 | 54 | if [ "$destdir" == "$srcdir" ]; then 55 | echo "$0: this script requires and to be different." 56 | exit 1 57 | fi 58 | 59 | set -e; 60 | 61 | mkdir -p $destdir 62 | 63 | cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/utt_map 64 | cat $srcdir/spk2utt | awk -v p=$spk_prefix -v s=$spk_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/spk_map 65 | 66 | if [ ! -f $srcdir/utt2uniq ]; then 67 | if [[ ! -z $utt_prefix || ! -z $utt_suffix ]]; then 68 | cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $1);}' > $destdir/utt2uniq 69 | fi 70 | else 71 | cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq 72 | fi 73 | 74 | cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \ 75 | utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk 76 | 77 | utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt 78 | 79 | if [ -f $srcdir/feats.scp ]; then 80 | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp 81 | fi 82 | 83 | if [ -f $srcdir/vad.scp ]; then 84 | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp 85 | fi 86 | 87 | if [ -f $srcdir/segments ]; then 88 | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments 89 | cp $srcdir/wav.scp $destdir 90 | else # no segments->wav indexed by utt. 91 | if [ -f $srcdir/wav.scp ]; then 92 | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp 93 | fi 94 | fi 95 | 96 | if [ -f $srcdir/reco2file_and_channel ]; then 97 | cp $srcdir/reco2file_and_channel $destdir/ 98 | fi 99 | 100 | if [ -f $srcdir/text ]; then 101 | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text 102 | fi 103 | if [ -f $srcdir/utt2dur ]; then 104 | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur 105 | fi 106 | if [ -f $srcdir/utt2num_frames ]; then 107 | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames 108 | fi 109 | if [ -f $srcdir/reco2dur ]; then 110 | if [ -f $srcdir/segments ]; then 111 | cp $srcdir/reco2dur $destdir/reco2dur 112 | else 113 | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur 114 | fi 115 | fi 116 | if [ -f $srcdir/spk2gender ]; then 117 | utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender 118 | fi 119 | if [ -f $srcdir/cmvn.scp ]; then 120 | utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp 121 | fi 122 | for f in frame_shift stm glm ctm; do 123 | if [ -f $srcdir/$f ]; then 124 | cp $srcdir/$f $destdir 125 | fi 126 | done 127 | 128 | rm $destdir/spk_map $destdir/utt_map 129 | 130 | echo "$0: copied data from $srcdir to $destdir" 131 | 132 | for f in feats.scp cmvn.scp vad.scp utt2lang utt2uniq utt2dur utt2num_frames text wav.scp reco2file_and_channel frame_shift stm glm ctm; do 133 | if [ -f $destdir/$f ] && [ ! -f $srcdir/$f ]; then 134 | echo "$0: file $f exists in dest $destdir but not in src $srcdir. Moving it to" 135 | echo " ... $destdir/.backup/$f" 136 | mkdir -p $destdir/.backup 137 | mv $destdir/$f $destdir/.backup/ 138 | fi 139 | done 140 | 141 | 142 | [ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats" 143 | [ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text" 144 | 145 | echo $validate_opts 146 | echo $destdir 147 | utils/validate_data_dir.sh $validate_opts $destdir 148 | -------------------------------------------------------------------------------- /track1_asr/tools/data/remove_dup_utts.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Script taken from kaldi repo: 4 | # https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/utils/data/remove_dup_utts.sh 5 | 6 | # Remove excess utterances once they appear more than a specified 7 | # number of times with the same transcription, in a data set. 8 | # E.g. useful for removing excess "uh-huh" from training. 9 | 10 | if [ $# != 3 ]; then 11 | echo "Usage: remove_dup_utts.sh max-count " 12 | echo "e.g.: remove_dup_utts.sh 10 data/train data/train_nodup" 13 | echo "This script is used to filter out utterances that have from over-represented" 14 | echo "transcriptions (such as 'uh-huh'), by limiting the number of repetitions of" 15 | echo "any given word-sequence to a specified value. It's often used to get" 16 | echo "subsets for early stages of training." 17 | exit 1; 18 | fi 19 | 20 | maxcount=$1 21 | srcdir=$2 22 | destdir=$3 23 | mkdir -p $destdir 24 | 25 | [ ! -f $srcdir/text ] && echo "$0: Invalid input directory $srcdir" && exit 1; 26 | 27 | ! mkdir -p $destdir && echo "$0: could not create directory $destdir" && exit 1; 28 | 29 | ! [ "$maxcount" -gt 1 ] && echo "$0: invalid max-count '$maxcount'" && exit 1; 30 | 31 | cp $srcdir/* $destdir 32 | cat $srcdir/text | \ 33 | perl -e ' 34 | $maxcount = shift @ARGV; 35 | @all = (); 36 | $p1 = 103349; $p2 = 71147; $k = 0; 37 | sub random { # our own random number generator: predictable. 38 | $k = ($k + $p1) % $p2; 39 | return ($k / $p2); 40 | } 41 | while(<>) { 42 | push @all, $_; 43 | @A = split(" ", $_); 44 | shift @A; 45 | $text = join(" ", @A); 46 | $count{$text} ++; 47 | } 48 | foreach $line (@all) { 49 | @A = split(" ", $line); 50 | shift @A; 51 | $text = join(" ", @A); 52 | $n = $count{$text}; 53 | if ($n < $maxcount || random() < ($maxcount / $n)) { 54 | print $line; 55 | } 56 | }' $maxcount >$destdir/text 57 | 58 | echo "Reduced number of utterances from `cat $srcdir/text | wc -l` to `cat $destdir/text | wc -l`" 59 | 60 | # Not doing these steps as this script doesn't exist 61 | # + the calling script already validates data 62 | #echo "Using fix_data_dir.sh to reconcile the other files." 63 | #utils/fix_data_dir.sh $destdir 64 | #rm -r $destdir/.backup 65 | 66 | exit 0 67 | -------------------------------------------------------------------------------- /track1_asr/tools/decode.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2021 Mobvoi Inc. All Rights Reserved. 3 | # Author: binbinzhang@mobvoi.com (Binbin Zhang) 4 | export GLOG_logtostderr=1 5 | export GLOG_v=2 6 | 7 | set -e 8 | 9 | nj=1 10 | chunk_size=-1 11 | ctc_weight=0.0 12 | reverse_weight=0.0 13 | rescoring_weight=1.0 14 | # For CTC WFST based decoding 15 | fst_path= 16 | dict_path= 17 | acoustic_scale=1.0 18 | beam=15.0 19 | lattice_beam=12.0 20 | min_active=200 21 | max_active=7000 22 | blank_skip_thresh=1.0 23 | length_penalty=0.0 24 | 25 | . tools/parse_options.sh || exit 1; 26 | if [ $# != 5 ]; then 27 | echo "Usage: $0 [options] " 28 | exit 1; 29 | fi 30 | 31 | if ! which decoder_main > /dev/null; then 32 | echo "decoder_main is not built, please go to runtime/libtorch to build it." 33 | exit 1; 34 | fi 35 | 36 | scp=$1 37 | label_file=$2 38 | model_file=$3 39 | unit_file=$4 40 | dir=$5 41 | 42 | mkdir -p $dir/split${nj} 43 | 44 | # Step 1. Split wav.scp 45 | split_scps="" 46 | for n in $(seq ${nj}); do 47 | split_scps="${split_scps} ${dir}/split${nj}/wav.${n}.scp" 48 | done 49 | tools/data/split_scp.pl ${scp} ${split_scps} 50 | 51 | # Step 2. Parallel decoding 52 | wfst_decode_opts= 53 | if [ ! -z $fst_path ]; then 54 | wfst_decode_opts="--fst_path $fst_path" 55 | wfst_decode_opts="$wfst_decode_opts --beam $beam" 56 | wfst_decode_opts="$wfst_decode_opts --dict_path $dict_path" 57 | wfst_decode_opts="$wfst_decode_opts --lattice_beam $lattice_beam" 58 | wfst_decode_opts="$wfst_decode_opts --max_active $max_active" 59 | wfst_decode_opts="$wfst_decode_opts --min_active $min_active" 60 | wfst_decode_opts="$wfst_decode_opts --acoustic_scale $acoustic_scale" 61 | wfst_decode_opts="$wfst_decode_opts --blank_skip_thresh $blank_skip_thresh" 62 | wfst_decode_opts="$wfst_decode_opts --length_penalty $length_penalty" 63 | echo $wfst_decode_opts > $dir/config 64 | fi 65 | for n in $(seq ${nj}); do 66 | { 67 | decoder_main \ 68 | --rescoring_weight $rescoring_weight \ 69 | --ctc_weight $ctc_weight \ 70 | --reverse_weight $reverse_weight \ 71 | --chunk_size $chunk_size \ 72 | --wav_scp ${dir}/split${nj}/wav.${n}.scp \ 73 | --model_path $model_file \ 74 | --unit_path $unit_file \ 75 | $wfst_decode_opts \ 76 | --result ${dir}/split${nj}/${n}.text &> ${dir}/split${nj}/${n}.log 77 | } & 78 | done 79 | wait 80 | 81 | # Step 3. Merge files 82 | for n in $(seq ${nj}); do 83 | cat ${dir}/split${nj}/${n}.text 84 | done > ${dir}/text 85 | tail $dir/split${nj}/*.log | grep RTF | awk '{sum+=$NF}END{print sum/NR}' > $dir/rtf 86 | 87 | # Step 4. Compute WER 88 | python3 tools/compute-wer.py --char=1 --v=1 \ 89 | $label_file $dir/text > $dir/wer 90 | -------------------------------------------------------------------------------- /track1_asr/tools/feat_to_shape.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Begin configuration section. 4 | nj=4 5 | cmd=run.pl 6 | verbose=0 7 | filetype="" 8 | preprocess_conf="" 9 | # End configuration section. 10 | 11 | help_message=$(cat << EOF 12 | Usage: $0 [options] [] 13 | e.g.: $0 data/train/feats.scp data/train/shape.scp data/train/log 14 | Options: 15 | --nj # number of parallel jobs 16 | --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs. 17 | --filetype # Specify the format of feats file 18 | --preprocess-conf # Apply preprocess to feats when creating shape.scp 19 | --verbose # Default: 0 20 | EOF 21 | ) 22 | 23 | echo "$0 $*" 1>&2 # Print the command line for logging 24 | 25 | . parse_options.sh || exit 1; 26 | 27 | if [ $# -lt 2 ] || [ $# -gt 3 ]; then 28 | echo "${help_message}" 1>&2 29 | exit 1; 30 | fi 31 | 32 | set -euo pipefail 33 | 34 | scp=$1 35 | outscp=$2 36 | data=$(dirname ${scp}) 37 | if [ $# -eq 3 ]; then 38 | logdir=$3 39 | else 40 | logdir=${data}/log 41 | fi 42 | mkdir -p ${logdir} 43 | 44 | split_scps="" 45 | for n in $(seq ${nj}); do 46 | split_scps="${split_scps} ${logdir}/feats.${n}.scp" 47 | done 48 | 49 | utils/split_scp.pl ${scp} ${split_scps} 50 | 51 | if [ -n "${preprocess_conf}" ]; then 52 | preprocess_opt="--preprocess-conf ${preprocess_conf}" 53 | else 54 | preprocess_opt="" 55 | fi 56 | if [ -n "${filetype}" ]; then 57 | filetype_opt="--filetype ${filetype}" 58 | else 59 | filetype_opt="" 60 | fi 61 | 62 | ${cmd} JOB=1:${nj} ${logdir}/feat_to_shape.JOB.log \ 63 | feat-to-len --verbose=${verbose} \ 64 | scp:${logdir}/feats.JOB.scp ark,t:${logdir}/shape.JOB.scp 65 | 66 | feat_dim=$(feat-to-dim scp:$logdir/feats.1.scp -) 67 | 68 | # concatenate the .scp files together. 69 | for n in $(seq ${nj}); do 70 | sed "s:\ *$:,$feat_dim:g" ${logdir}/shape.${n}.scp 71 | done > ${outscp} 72 | 73 | rm -f ${logdir}/feats.*.scp 2>/dev/null 74 | -------------------------------------------------------------------------------- /track1_asr/tools/filter_scp.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2012 Microsoft Corporation 3 | # Johns Hopkins University (author: Daniel Povey) 4 | 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 12 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 13 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 14 | # MERCHANTABLITY OR NON-INFRINGEMENT. 15 | # See the Apache 2 License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | 19 | # This script takes a list of utterance-ids or any file whose first field 20 | # of each line is an utterance-id, and filters an scp 21 | # file (or any file whose "n-th" field is an utterance id), printing 22 | # out only those lines whose "n-th" field is in id_list. The index of 23 | # the "n-th" field is 1, by default, but can be changed by using 24 | # the -f switch 25 | 26 | $exclude = 0; 27 | $field = 1; 28 | $shifted = 0; 29 | 30 | do { 31 | $shifted=0; 32 | if ($ARGV[0] eq "--exclude") { 33 | $exclude = 1; 34 | shift @ARGV; 35 | $shifted=1; 36 | } 37 | if ($ARGV[0] eq "-f") { 38 | $field = $ARGV[1]; 39 | shift @ARGV; shift @ARGV; 40 | $shifted=1 41 | } 42 | } while ($shifted); 43 | 44 | if(@ARGV < 1 || @ARGV > 2) { 45 | die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . 46 | "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . 47 | "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . 48 | "only the lines that were *not* in id_list.\n" . 49 | "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . 50 | "If your older scripts (written before Oct 2014) stopped working and you used the\n" . 51 | "-f option, add 1 to the argument.\n" . 52 | "See also: utils/filter_scp.pl .\n"; 53 | } 54 | 55 | 56 | $idlist = shift @ARGV; 57 | open(F, "<$idlist") || die "Could not open id-list file $idlist"; 58 | while() { 59 | @A = split; 60 | @A>=1 || die "Invalid id-list file line $_"; 61 | $seen{$A[0]} = 1; 62 | } 63 | 64 | if ($field == 1) { # Treat this as special case, since it is common. 65 | while(<>) { 66 | $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; 67 | # $1 is what we filter on. 68 | if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { 69 | print $_; 70 | } 71 | } 72 | } else { 73 | while(<>) { 74 | @A = split; 75 | @A > 0 || die "Invalid scp file line $_"; 76 | @A >= $field || die "Invalid scp file line $_"; 77 | if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { 78 | print $_; 79 | } 80 | } 81 | } 82 | 83 | # tests: 84 | # the following should print "foo 1" 85 | # ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) 86 | # the following should print "bar 2". 87 | # ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) 88 | -------------------------------------------------------------------------------- /track1_asr/tools/filter_uneven_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright [2023-04-27] 4 | 5 | import os 6 | import random 7 | import tarfile 8 | 9 | random.seed(1024) 10 | 11 | # parse arg from command line 12 | datalist = os.sys.argv[1] 13 | datatype = os.sys.argv[2] 14 | num_gpus = int(os.sys.argv[3]) 15 | num_samples_per_tar = int(os.sys.argv[4]) # only used in shard mode 16 | new_datalist = os.sys.argv[5] 17 | 18 | assert datatype in ["shard", "raw"] 19 | 20 | 21 | filtered_list = [] 22 | with open(datalist, "r") as f: 23 | lines = f.readlines() 24 | lines = [l.strip() for l in lines] 25 | if datatype == "raw": 26 | valid_num = len(lines) // num_gpus * num_gpus 27 | random.shuffle(lines) 28 | filtered_list = lines[:valid_num] 29 | else: 30 | for line in lines: 31 | cnt = 0 32 | with open(line, "rb") as tar: 33 | stream = tarfile.open(fileobj=tar, mode="r|*") 34 | for tarinfo in stream: 35 | name = tarinfo.name 36 | pos = name.rfind('.') 37 | assert pos > 0 38 | prefix, postfix = name[:pos], name[pos + 1:] 39 | if postfix == 'txt': 40 | cnt += 1 41 | if cnt == num_samples_per_tar: 42 | filtered_list.append(line) 43 | valid_num = len(filtered_list) // num_gpus * num_gpus 44 | random.shuffle(filtered_list) 45 | filtered_list = filtered_list[:valid_num] 46 | filtered_list.sort() 47 | print("before filter: {} after filter: {}".format(len(lines), len(filtered_list))) 48 | 49 | with open(new_datalist, "w") as f: 50 | for line in filtered_list: 51 | f.writelines("{}\n".format(line)) 52 | -------------------------------------------------------------------------------- /track1_asr/tools/flake8_hook.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | import sys 4 | 5 | from flake8.main import git 6 | 7 | if __name__ == '__main__': 8 | sys.exit( 9 | git.hook( 10 | strict=True, 11 | lazy=git.config_for('lazy'), 12 | ) 13 | ) 14 | -------------------------------------------------------------------------------- /track1_asr/tools/fst/compile_lexicon_token_fst.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2015 Yajie Miao (Carnegie Mellon University) 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 11 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 12 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 13 | # MERCHANTABLITY OR NON-INFRINGEMENT. 14 | # See the Apache 2 License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # This script compiles the lexicon and CTC tokens into FSTs. FST compiling slightly differs between the 18 | # phoneme and character-based lexicons. 19 | set -eo pipefail 20 | . tools/parse_options.sh 21 | 22 | if [ $# -ne 3 ]; then 23 | echo "usage: tools/fst/compile_lexicon_token_fst.sh " 24 | echo "e.g.: tools/fst/compile_lexicon_token_fst.sh data/local/dict data/local/lang_tmp data/lang" 25 | echo " should contain the following files:" 26 | echo "lexicon.txt units.txt" 27 | echo "options: " 28 | exit 1; 29 | fi 30 | 31 | srcdir=$1 32 | tmpdir=$2 33 | dir=$3 34 | mkdir -p $dir $tmpdir 35 | 36 | [ -f path.sh ] && . ./path.sh 37 | 38 | export LC_ALL=C 39 | 40 | cp $srcdir/units.txt $dir 41 | 42 | # Add probabilities to lexicon entries. There is in fact no point of doing this here since all the entries have 1.0. 43 | # But utils/make_lexicon_fst.pl requires a probabilistic version, so we just leave it as it is. 44 | perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $srcdir/lexicon.txt > $tmpdir/lexiconp.txt || exit 1; 45 | 46 | # Add disambiguation symbols to the lexicon. This is necessary for determinizing the composition of L.fst and G.fst. 47 | # Without these symbols, determinization will fail. 48 | ndisambig=`tools/fst/add_lex_disambig.pl $tmpdir/lexiconp.txt $tmpdir/lexiconp_disambig.txt` 49 | ndisambig=$[$ndisambig+1]; 50 | 51 | ( for n in `seq 0 $ndisambig`; do echo '#'$n; done ) > $tmpdir/disambig.list 52 | 53 | # Get the full list of CTC tokens used in FST. These tokens include , the blank , 54 | # the actual model unit, and the disambiguation symbols. 55 | cat $srcdir/units.txt | awk '{print $1}' > $tmpdir/units.list 56 | (echo '';) | cat - $tmpdir/units.list $tmpdir/disambig.list | awk '{print $1 " " (NR-1)}' > $dir/tokens.txt 57 | 58 | # ctc_token_fst_corrected is too big and too slow for character based chinese modeling, 59 | # so here use ctc_token_fst_compact 60 | tools/fst/ctc_token_fst_compact.py $dir/tokens.txt | \ 61 | fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/tokens.txt --keep_isymbols=false --keep_osymbols=false | \ 62 | fstarcsort --sort_type=olabel > $dir/T.fst || exit 1; 63 | 64 | # Encode the words with indices. Will be used in lexicon and language model FST compiling. 65 | cat $tmpdir/lexiconp.txt | awk '{print $1}' | sort | uniq | awk ' 66 | BEGIN { 67 | print " 0"; 68 | } 69 | { 70 | printf("%s %d\n", $1, NR); 71 | } 72 | END { 73 | printf("#0 %d\n", NR+1); 74 | printf(" %d\n", NR+2); 75 | printf(" %d\n", NR+3); 76 | }' > $dir/words.txt || exit 1; 77 | 78 | # Now compile the lexicon FST. Depending on the size of your lexicon, it may take some time. 79 | token_disambig_symbol=`grep \#0 $dir/tokens.txt | awk '{print $2}'` 80 | word_disambig_symbol=`grep \#0 $dir/words.txt | awk '{print $2}'` 81 | 82 | tools/fst/make_lexicon_fst.pl --pron-probs $tmpdir/lexiconp_disambig.txt 0 "sil" '#'$ndisambig | \ 83 | fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/words.txt \ 84 | --keep_isymbols=false --keep_osymbols=false | \ 85 | fstaddselfloops "echo $token_disambig_symbol |" "echo $word_disambig_symbol |" | \ 86 | fstarcsort --sort_type=olabel > $dir/L.fst || exit 1; 87 | 88 | echo "Lexicon and token FSTs compiling succeeded" 89 | -------------------------------------------------------------------------------- /track1_asr/tools/fst/ctc_token_fst.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | 5 | print('0 1 ') 6 | print('1 1 ') 7 | print('2 2 ') 8 | print('2 0 ') 9 | 10 | with open(sys.argv[1], 'r', encoding='utf8') as fin: 11 | node = 3 12 | for entry in fin: 13 | fields = entry.strip().split(' ') 14 | phone = fields[0] 15 | if phone == '' or phone == '': 16 | continue 17 | elif '#' in phone: # disambiguous phone 18 | print('{} {} {} {}'.format(0, 0, '', phone)) 19 | else: 20 | print('{} {} {} {}'.format(1, node, phone, phone)) 21 | print('{} {} {} {}'.format(node, node, phone, '')) 22 | print('{} {} {} {}'.format(node, 2, '', '')) 23 | node += 1 24 | print('0') 25 | -------------------------------------------------------------------------------- /track1_asr/tools/fst/ctc_token_fst_compact.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | 5 | print('0 0 ') 6 | 7 | with open(sys.argv[1], 'r', encoding='utf8') as fin: 8 | node = 1 9 | for entry in fin: 10 | fields = entry.strip().split(' ') 11 | phone = fields[0] 12 | if phone == '' or phone == '': 13 | continue 14 | elif '#' in phone: # disambiguous phone 15 | print('{} {} {} {}'.format(0, 0, '', phone)) 16 | else: 17 | print('{} {} {} {}'.format(0, node, phone, phone)) 18 | print('{} {} {} {}'.format(node, node, phone, '')) 19 | print('{} {} {} {}'.format(node, 0, '', '')) 20 | node += 1 21 | print('0') 22 | -------------------------------------------------------------------------------- /track1_asr/tools/fst/ctc_token_fst_corrected.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | 5 | 6 | def il(n): 7 | return n + 1 8 | 9 | 10 | def ol(n): 11 | return n + 1 12 | 13 | 14 | def s(n): 15 | return n 16 | 17 | 18 | if __name__ == "__main__": 19 | with open(sys.argv[1]) as f: 20 | lines = f.readlines() 21 | phone_count = 0 22 | disambig_count = 0 23 | for line in lines: 24 | sp = line.split() 25 | phone = sp[0] 26 | if phone == '' or phone == '': 27 | continue 28 | if phone.startswith('#'): 29 | disambig_count += 1 30 | else: 31 | phone_count += 1 32 | 33 | # 1. add start state 34 | print('0 0 {} 0'.format(il(0))) 35 | 36 | # 2. 0 -> i, i -> i, i -> 0 37 | for i in range(1, phone_count + 1): 38 | print('0 {} {} {}'.format(s(i), il(i), ol(i))) 39 | print('{} {} {} 0'.format(s(i), s(i), il(i))) 40 | print('{} 0 {} 0'.format(s(i), il(0))) 41 | 42 | # 3. i -> other phone 43 | for i in range(1, phone_count + 1): 44 | for j in range(1, phone_count + 1): 45 | if i != j: 46 | print('{} {} {} {}'.format(s(i), s(j), il(j), ol(j))) 47 | 48 | # 4. add disambiguous arcs on every final state 49 | for i in range(0, phone_count + 1): 50 | for j in range(phone_count + 2, phone_count + disambig_count + 2): 51 | print('{} {} {} {}'.format(s(i), s(i), 0, j)) 52 | 53 | # 5. every i is final state 54 | for i in range(0, phone_count + 1): 55 | print(s(i)) 56 | -------------------------------------------------------------------------------- /track1_asr/tools/fst/eps2disambig.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2011 Microsoft Corporation 3 | # 2015 Guoguo Chen 4 | 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 12 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 13 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 14 | # MERCHANTABLITY OR NON-INFRINGEMENT. 15 | # See the Apache 2 License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | # This script replaces epsilon with #0 on the input side only, of the G.fst 19 | # acceptor. 20 | 21 | while(<>){ 22 | if (/\s+#0\s+/) { 23 | print STDERR "$0: ERROR: LM has word #0, " . 24 | "which is reserved as disambiguation symbol\n"; 25 | exit 1; 26 | } 27 | s:^(\d+\s+\d+\s+)\(\s+):$1#0$2:; 28 | print; 29 | } 30 | -------------------------------------------------------------------------------- /track1_asr/tools/fst/make_tlg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | 4 | if [ -f path.sh ]; then . path.sh; fi 5 | 6 | lm_dir=$1 7 | src_lang=$2 8 | tgt_lang=$3 9 | 10 | arpa_lm=${lm_dir}/lm.arpa 11 | [ ! -f $arpa_lm ] && echo No such file $arpa_lm && exit 1; 12 | 13 | rm -rf $tgt_lang 14 | cp -r $src_lang $tgt_lang 15 | 16 | # Compose the language model to FST 17 | cat $arpa_lm | \ 18 | grep -v ' ' | \ 19 | grep -v ' ' | \ 20 | grep -v ' ' | \ 21 | grep -v -i '' | \ 22 | grep -v -i '' | \ 23 | arpa2fst --read-symbol-table=$tgt_lang/words.txt --keep-symbols=true - | fstprint | \ 24 | tools/fst/eps2disambig.pl | tools/fst/s2eps.pl | fstcompile --isymbols=$tgt_lang/words.txt \ 25 | --osymbols=$tgt_lang/words.txt --keep_isymbols=false --keep_osymbols=false | \ 26 | fstrmepsilon | fstarcsort --sort_type=ilabel > $tgt_lang/G.fst 27 | 28 | 29 | echo "Checking how stochastic G is (the first of these numbers should be small):" 30 | fstisstochastic $tgt_lang/G.fst 31 | 32 | # Compose the token, lexicon and language-model FST into the final decoding graph 33 | fsttablecompose $tgt_lang/L.fst $tgt_lang/G.fst | fstdeterminizestar --use-log=true | \ 34 | fstminimizeencoded | fstarcsort --sort_type=ilabel > $tgt_lang/LG.fst || exit 1; 35 | fsttablecompose $tgt_lang/T.fst $tgt_lang/LG.fst > $tgt_lang/TLG.fst || exit 1; 36 | 37 | echo "Composing decoding graph TLG.fst succeeded" 38 | #rm -r $tgt_lang/LG.fst # We don't need to keep this intermediate FST 39 | -------------------------------------------------------------------------------- /track1_asr/tools/fst/prepare_dict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | 4 | import sys 5 | 6 | # sys.argv[1]: e2e model unit file(lang_char.txt) 7 | # sys.argv[2]: raw lexicon file 8 | # sys.argv[3]: output lexicon file 9 | # sys.argv[4]: bpemodel 10 | 11 | unit_table = set() 12 | with open(sys.argv[1], 'r', encoding='utf8') as fin: 13 | for line in fin: 14 | unit = line.split()[0] 15 | unit_table.add(unit) 16 | 17 | 18 | def contain_oov(units): 19 | for unit in units: 20 | if unit not in unit_table: 21 | return True 22 | return False 23 | 24 | 25 | bpemode = len(sys.argv) > 4 26 | if bpemode: 27 | import sentencepiece as spm 28 | sp = spm.SentencePieceProcessor() 29 | sp.Load(sys.argv[4]) 30 | lexicon_table = set() 31 | with open(sys.argv[2], 'r', encoding='utf8') as fin, \ 32 | open(sys.argv[3], 'w', encoding='utf8') as fout: 33 | for line in fin: 34 | word = line.split()[0] 35 | if word == 'SIL' and not bpemode: # `sil` might be a valid piece in bpemodel 36 | continue 37 | elif word == '': 38 | continue 39 | else: 40 | # each word only has one pronunciation for e2e system 41 | if word in lexicon_table: 42 | continue 43 | if bpemode: 44 | # We assume that the lexicon does not contain code-switch, 45 | # i.e. the word contains both English and Chinese. 46 | # see PR https://github.com/wenet-e2e/wenet/pull/1693 47 | # and Issue https://github.com/wenet-e2e/wenet/issues/1653 48 | if word.replace('\'', '').encode("utf-8").isalpha(): 49 | pieces = sp.EncodeAsPieces(word) 50 | else: 51 | pieces = word 52 | if contain_oov(pieces): 53 | print( 54 | 'Ignoring words {}, which contains oov unit'.format( 55 | ''.join(word).strip('▁')) 56 | ) 57 | continue 58 | chars = ' '.join( 59 | [p if p in unit_table else '' for p in pieces]) 60 | else: 61 | # ignore words with OOV 62 | if contain_oov(word): 63 | print('Ignoring words {}, which contains oov unit'.format(word)) 64 | continue 65 | # Optional, append ▁ in front of english word 66 | # we assume the model unit of our e2e system is char now. 67 | if word.replace('\'', '').encode("utf-8").isalpha() and \ 68 | '▁' in unit_table: 69 | word = '▁' + word 70 | chars = ' '.join(word) # word is a char list 71 | fout.write('{} {}\n'.format(word, chars)) 72 | lexicon_table.add(word) 73 | -------------------------------------------------------------------------------- /track1_asr/tools/fst/remove_oovs.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2011 Microsoft Corporation 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 11 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 12 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 13 | # MERCHANTABLITY OR NON-INFRINGEMENT. 14 | # See the Apache 2 License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # This script removes lines that contain these OOVs on either the 18 | # third or fourth fields of the line. It is intended to remove arcs 19 | # with OOVs on, from FSTs (probably compiled from ARPAs with OOVs in). 20 | 21 | if ( @ARGV < 1 && @ARGV > 2) { 22 | die "Usage: remove_oovs.pl unk_list.txt [ printed-fst ]\n"; 23 | } 24 | 25 | $unklist = shift @ARGV; 26 | open(S, "<$unklist") || die "Failed opening unknown-symbol list $unklist\n"; 27 | while(){ 28 | @A = split(" ", $_); 29 | @A == 1 || die "Bad line in unknown-symbol list: $_"; 30 | $unk{$A[0]} = 1; 31 | } 32 | 33 | $num_removed = 0; 34 | while(<>){ 35 | @A = split(" ", $_); 36 | if(defined $unk{$A[2]} || defined $unk{$A[3]}) { 37 | $num_removed++; 38 | } else { 39 | print; 40 | } 41 | } 42 | print STDERR "remove_oovs.pl: removed $num_removed lines.\n"; 43 | 44 | -------------------------------------------------------------------------------- /track1_asr/tools/fst/rnnt_token_fst.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | 5 | print('0 0 ') 6 | 7 | with open(sys.argv[1], 'r', encoding='utf8') as fin: 8 | for entry in fin: 9 | fields = entry.strip().split(' ') 10 | phone = fields[0] 11 | if phone == '' or phone == '': 12 | continue 13 | elif '#' in phone: # disambiguous phone 14 | print('{} {} {} {}'.format(0, 0, '', phone)) 15 | else: 16 | print('{} {} {} {}'.format(0, 0, phone, phone)) 17 | print('0') 18 | -------------------------------------------------------------------------------- /track1_asr/tools/fst/s2eps.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2011 Microsoft Corporation 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 11 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 12 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 13 | # MERCHANTABLITY OR NON-INFRINGEMENT. 14 | # See the Apache 2 License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # This script replaces and with (on both input and output sides), 18 | # for the G.fst acceptor. 19 | 20 | while(<>){ 21 | @A = split(" ", $_); 22 | if ( @A >= 4 ) { 23 | if ($A[2] eq "" || $A[2] eq "") { $A[2] = ""; } 24 | if ($A[3] eq "" || $A[3] eq "") { $A[3] = ""; } 25 | } 26 | print join("\t", @A) . "\n"; 27 | } 28 | -------------------------------------------------------------------------------- /track1_asr/tools/git-pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "Running pre-commit flake8" 5 | python tools/flake8_hook.py 6 | -------------------------------------------------------------------------------- /track1_asr/tools/install_srilm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. 3 | # 2022 Binbin Zhang(binbzha@qq.com) 4 | 5 | current_path=`pwd` 6 | current_dir=`basename "$current_path"` 7 | 8 | if [ "tools" != "$current_dir" ]; then 9 | echo "You should run this script in tools/ directory!!" 10 | exit 1 11 | fi 12 | 13 | ! command -v gawk > /dev/null && \ 14 | echo "GNU awk is not installed so SRILM will probably not work correctly: refusing to install" && exit 1; 15 | 16 | srilm_url="https://github.com/BitSpeech/SRILM/archive/refs/tags/1.7.3.tar.gz" 17 | 18 | if [ ! -f ./srilm.tar.gz ]; then 19 | if ! wget -O ./srilm.tar.gz "$srilm_url"; then 20 | echo 'There was a problem downloading the file.' 21 | echo 'Check you internet connection and try again.' 22 | exit 1 23 | fi 24 | fi 25 | 26 | tar -zxvf srilm.tar.gz 27 | mv SRILM-1.7.3 srilm 28 | 29 | # set the SRILM variable in the top-level Makefile to this directory. 30 | cd srilm 31 | cp Makefile tmpf 32 | 33 | cat tmpf | gawk -v pwd=`pwd` '/SRILM =/{printf("SRILM = %s\n", pwd); next;} {print;}' \ 34 | > Makefile || exit 1 35 | rm tmpf 36 | 37 | make || exit 38 | cd .. 39 | 40 | ( 41 | [ ! -z "${SRILM}" ] && \ 42 | echo >&2 "SRILM variable is aleady defined. Undefining..." && \ 43 | unset SRILM 44 | 45 | [ -f ./env.sh ] && . ./env.sh 46 | 47 | [ ! -z "${SRILM}" ] && \ 48 | echo >&2 "SRILM config is already in env.sh" && exit 49 | 50 | wd=`pwd` 51 | wd=`readlink -f $wd || pwd` 52 | 53 | echo "export SRILM=$wd/srilm" 54 | dirs="\${PATH}" 55 | for directory in $(cd srilm && find bin -type d ) ; do 56 | dirs="$dirs:\${SRILM}/$directory" 57 | done 58 | echo "export PATH=$dirs" 59 | ) >> env.sh 60 | 61 | echo >&2 "Installation of SRILM finished successfully" 62 | echo >&2 "Please source the tools/env.sh in your path.sh to enable it" 63 | -------------------------------------------------------------------------------- /track1_asr/tools/k2/make_hlg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, 3 | # Wei Kang) 4 | # Copyright 2022 Ximalaya Speech Team (author: Xiang Lyu) 5 | 6 | lexion_dir=$1 7 | lm_dir=$2 8 | tgt_dir=$3 9 | 10 | # k2 and icefall updates very fast. Below commits are veryfied in this script. 11 | # k2 3dc222f981b9fdbc8061b3782c3b385514a2d444, icefall 499ac24ecba64f687ff244c7d66baa5c222ecf0f 12 | 13 | # For k2 installation, please refer to https://github.com/k2-fsa/k2/ 14 | python -c "import k2; print(k2.__file__)" 15 | python -c "import torch; import _k2; print(_k2.__file__)" 16 | 17 | # Prepare necessary icefall scripts 18 | if [ ! -d tools/k2/icefall ]; then 19 | git clone --depth 1 https://github.com/k2-fsa/icefall.git tools/k2/icefall 20 | fi 21 | pip install -r tools/k2/icefall/requirements.txt 22 | export PYTHONPATH=`pwd`/tools/k2/icefall:`pwd`/tools/k2/icefall/egs/aishell/ASR/local:$PYTHONPATH 23 | 24 | # 8.1 Prepare char based lang 25 | mkdir -p $tgt_dir 26 | python tools/k2/prepare_char.py $lexion_dir/units.txt $lm_dir/wordlist $tgt_dir 27 | echo "Compile lexicon L.pt L_disambig.pt succeeded" 28 | 29 | # 8.2 Prepare G 30 | mkdir -p data/lm 31 | python -m kaldilm \ 32 | --read-symbol-table="$tgt_dir/words.txt" \ 33 | --disambig-symbol='#0' \ 34 | --max-order=3 \ 35 | $lm_dir/lm.arpa > data/lm/G_3_gram.fst.txt 36 | 37 | # 8.3 Compile HLG 38 | python tools/k2/icefall/egs/aishell/ASR/local/compile_hlg.py --lang-dir $tgt_dir 39 | echo "Compile decoding graph HLG.pt succeeded" -------------------------------------------------------------------------------- /track1_asr/tools/k2/prepare_mmi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, 3 | # Wei Kang) 4 | # Copyright 2023 Ximalaya Speech Team (author: Xiang Lyu) 5 | 6 | train_dir=$1 7 | dev_dir=$2 8 | tgt_dir=$3 9 | 10 | # k2 and icefall updates very fast. Below commits are veryfied in this script. 11 | # k2 3dc222f981b9fdbc8061b3782c3b385514a2d444, icefall 499ac24ecba64f687ff244c7d66baa5c222ecf0f 12 | 13 | # For k2 installation, please refer to https://github.com/k2-fsa/k2/ 14 | python -c "import k2; print(k2.__file__)" 15 | python -c "import torch; import _k2; print(_k2.__file__)" 16 | 17 | # Prepare necessary icefall scripts 18 | if [ ! -d tools/k2/icefall ]; then 19 | git clone --depth 1 https://github.com/k2-fsa/icefall.git tools/k2/icefall 20 | fi 21 | pip install -r tools/k2/icefall/requirements.txt 22 | export PYTHONPATH=`pwd`/tools/k2/icefall:`pwd`/tools/k2/icefall/egs/aishell/ASR/local:$PYTHONPATH 23 | 24 | # 1. prepare wordlist 25 | mkdir -p $tgt_dir 26 | awk 'FNR>2&&FNR<=4232{print $1}END{printf("")}' $train_dir/units.txt > $tgt_dir/wordlist 27 | 28 | # 2. prepare L.pt tokens.txt words.txt lexicon.txt uniq_lexicon.txt 29 | python tools/k2/prepare_char.py $train_dir/units.txt $tgt_dir/wordlist $tgt_dir 30 | ln -s lexicon.txt $tgt_dir/uniq_lexicon.txt 31 | 32 | # 3. prepare token level bigram 33 | cat $train_dir/text | awk '{print $2}'| sed -r 's/(.)/ \1/g' > $tgt_dir/transcript_chars.txt 34 | cat $dev_dir/text | awk '{print $2}'| sed -r 's/(.)/ \1/g' >> $tgt_dir/transcript_chars.txt 35 | 36 | ./shared/make_kn_lm.py \ 37 | -ngram-order 2 \ 38 | -text $tgt_dir/transcript_chars.txt \ 39 | -lm $tgt_dir/P.arpa 40 | python -m kaldilm \ 41 | --read-symbol-table="$tgt_dir/words.txt" \ 42 | --disambig-symbol='#0' \ 43 | --max-order=2 \ 44 | $tgt_dir/P.arpa > $tgt_dir/P.fst.txt -------------------------------------------------------------------------------- /track1_asr/tools/make_raw_list.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse 18 | import json 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser(description='') 22 | parser.add_argument('--segments', default=None, help='segments file') 23 | parser.add_argument('wav_file', help='wav file') 24 | parser.add_argument('text_file', help='text file') 25 | parser.add_argument('output_file', help='output list file') 26 | args = parser.parse_args() 27 | 28 | wav_table = {} 29 | with open(args.wav_file, 'r', encoding='utf8') as fin: 30 | for line in fin: 31 | arr = line.strip().split() 32 | assert len(arr) == 2 33 | wav_table[arr[0]] = arr[1] 34 | 35 | if args.segments is not None: 36 | segments_table = {} 37 | with open(args.segments, 'r', encoding='utf8') as fin: 38 | for line in fin: 39 | arr = line.strip().split() 40 | assert len(arr) == 4 41 | segments_table[arr[0]] = (arr[1], float(arr[2]), float(arr[3])) 42 | 43 | with open(args.text_file, 'r', encoding='utf8') as fin, \ 44 | open(args.output_file, 'w', encoding='utf8') as fout: 45 | for line in fin: 46 | arr = line.strip().split(maxsplit=1) 47 | key = arr[0] 48 | txt = arr[1] if len(arr) > 1 else '' 49 | if args.segments is None: 50 | assert key in wav_table 51 | wav = wav_table[key] 52 | line = dict(key=key, wav=wav, txt=txt) 53 | else: 54 | assert key in segments_table 55 | wav_key, start, end = segments_table[key] 56 | wav = wav_table[wav_key] 57 | line = dict(key=key, wav=wav, txt=txt, start=start, end=end) 58 | json_line = json.dumps(line, ensure_ascii=False) 59 | fout.write(json_line + '\n') 60 | -------------------------------------------------------------------------------- /track1_asr/tools/merge_scp2txt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import argparse 8 | import codecs 9 | from distutils.util import strtobool 10 | from io import open 11 | import logging 12 | import sys 13 | 14 | PY2 = sys.version_info[0] == 2 15 | sys.stdin = codecs.getreader('utf-8')(sys.stdin if PY2 else sys.stdin.buffer) 16 | sys.stdout = codecs.getwriter('utf-8')( 17 | sys.stdout if PY2 else sys.stdout.buffer) 18 | 19 | 20 | # Special types: 21 | def shape(x): 22 | """Change str to List[int] 23 | 24 | >>> shape('3,5') 25 | [3, 5] 26 | >>> shape(' [3, 5] ') 27 | [3, 5] 28 | 29 | """ 30 | 31 | # x: ' [3, 5] ' -> '3, 5' 32 | x = x.strip() 33 | if x[0] == '[': 34 | x = x[1:] 35 | if x[-1] == ']': 36 | x = x[:-1] 37 | 38 | return list(map(int, x.split(','))) 39 | 40 | 41 | def get_parser(): 42 | parser = argparse.ArgumentParser( 43 | description='Given each file paths with such format as ' 44 | '::. type> can be omitted and the default ' 45 | 'is "str". e.g. {} ' 46 | '--input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape ' 47 | '--input-scps feat:data/feats2.scp shape:data/utt2feat2_shape:shape ' 48 | '--output-scps text:data/text shape:data/utt2text_shape:shape ' 49 | '--scps utt2spk:data/utt2spk'.format(sys.argv[0]), 50 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 51 | parser.add_argument('--input-scps', 52 | type=str, 53 | nargs='*', 54 | action='append', 55 | default=[], 56 | help='files for the inputs') 57 | parser.add_argument('--output-scps', 58 | type=str, 59 | nargs='*', 60 | action='append', 61 | default=[], 62 | help='files for the outputs') 63 | parser.add_argument('--scps', 64 | type=str, 65 | nargs='+', 66 | default=[], 67 | help='The files except for the input and outputs') 68 | parser.add_argument('--verbose', 69 | '-V', 70 | default=1, 71 | type=int, 72 | help='Verbose option') 73 | parser.add_argument('--allow-one-column', 74 | type=strtobool, 75 | default=False, 76 | help='Allow one column in input scp files. ' 77 | 'In this case, the value will be empty string.') 78 | parser.add_argument('--out', 79 | '-O', 80 | type=str, 81 | help='The output filename. ' 82 | 'If omitted, then output to sys.stdout') 83 | return parser 84 | 85 | 86 | if __name__ == '__main__': 87 | parser = get_parser() 88 | args = parser.parse_args() 89 | args.scps = [args.scps] 90 | 91 | # logging info 92 | logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" 93 | if args.verbose > 0: 94 | logging.basicConfig(level=logging.INFO, format=logfmt) 95 | else: 96 | logging.basicConfig(level=logging.WARN, format=logfmt) 97 | 98 | inputs = {} 99 | assert (len(args.input_scps) == 1) 100 | for f in args.input_scps[0]: 101 | arr = f.strip().split(':') 102 | inputs[arr[0]] = arr[1] 103 | assert ('feat' in inputs) 104 | assert ('shape' in inputs) 105 | 106 | outputs = {} 107 | assert (len(args.output_scps) == 1) 108 | for f in args.output_scps[0]: 109 | arr = f.strip().split(':') 110 | outputs[arr[0]] = arr[1] 111 | assert ('shape' in outputs) 112 | assert ('text' in outputs) 113 | assert ('token' in outputs) 114 | assert ('tokenid' in outputs) 115 | 116 | files = [ 117 | inputs['feat'], inputs['shape'], outputs['text'], outputs['token'], 118 | outputs['tokenid'], outputs['shape'] 119 | ] 120 | fields = ['feat', 'feat_shape', 'text', 'token', 'tokenid', 'token_shape'] 121 | fids = [open(f, 'r', encoding='utf-8') for f in files] 122 | 123 | if args.out is None: 124 | out = sys.stdout 125 | else: 126 | out = open(args.out, 'w', encoding='utf-8') 127 | done = False 128 | while not done: 129 | for i, fid in enumerate(fids): 130 | line = fid.readline() 131 | if line == '': 132 | done = True 133 | break 134 | arr = line.strip().split() 135 | content = ' '.join(arr[1:]) 136 | if i == 0: 137 | out.write('utt:{}'.format(arr[0])) 138 | out.write('\t') 139 | out.write('{}:{}'.format(fields[i], content)) 140 | out.write('\n') 141 | 142 | for f in fids: 143 | f.close() 144 | if args.out is not None: 145 | out.close() 146 | -------------------------------------------------------------------------------- /track1_asr/tools/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### No we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 98 | -------------------------------------------------------------------------------- /track1_asr/tools/perturb_data_dir_speed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 2020 @kamo-naoyuki 4 | # This file was copied from Kaldi and 5 | # I deleted parts related to wav duration 6 | # because we shouldn't use kaldi's command here 7 | # and we don't need the files actually. 8 | 9 | # Copyright 2013 Johns Hopkins University (author: Daniel Povey) 10 | # 2014 Tom Ko 11 | # 2018 Emotech LTD (author: Pawel Swietojanski) 12 | # Apache 2.0 13 | 14 | # This script operates on a directory, such as in data/train/, 15 | # that contains some subset of the following files: 16 | # wav.scp 17 | # spk2utt 18 | # utt2spk 19 | # text 20 | # 21 | # It generates the files which are used for perturbing the speed of the original data. 22 | 23 | export LC_ALL=C 24 | set -euo pipefail 25 | 26 | if [[ $# != 3 ]]; then 27 | echo "Usage: perturb_data_dir_speed.sh " 28 | echo "e.g.:" 29 | echo " $0 0.9 data/train_si284 data/train_si284p" 30 | exit 1 31 | fi 32 | 33 | factor=$1 34 | srcdir=$2 35 | destdir=$3 36 | label="sp" 37 | spk_prefix="${label}${factor}-" 38 | utt_prefix="${label}${factor}-" 39 | 40 | #check is sox on the path 41 | 42 | ! command -v sox &>/dev/null && echo "sox: command not found" && exit 1; 43 | 44 | if [[ ! -f ${srcdir}/utt2spk ]]; then 45 | echo "$0: no such file ${srcdir}/utt2spk" 46 | exit 1; 47 | fi 48 | 49 | if [[ ${destdir} == "${srcdir}" ]]; then 50 | echo "$0: this script requires and to be different." 51 | exit 1 52 | fi 53 | 54 | mkdir -p "${destdir}" 55 | 56 | <"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map" 57 | <"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map" 58 | <"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map" 59 | if [[ ! -f ${srcdir}/utt2uniq ]]; then 60 | <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq" 61 | else 62 | <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq" 63 | fi 64 | 65 | 66 | <"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \ 67 | utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk 68 | 69 | utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt 70 | 71 | if [[ -f ${srcdir}/segments ]]; then 72 | 73 | utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \ 74 | utils/apply_map.pl -f 2 "${destdir}"/reco_map | \ 75 | awk -v factor="${factor}" \ 76 | '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \ 77 | >"${destdir}"/segments 78 | 79 | utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ 80 | # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" 81 | awk -v factor="${factor}" \ 82 | '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} 83 | else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" } 84 | else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \ 85 | > "${destdir}"/wav.scp 86 | if [[ -f ${srcdir}/reco2file_and_channel ]]; then 87 | utils/apply_map.pl -f 1 "${destdir}"/reco_map \ 88 | <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel 89 | fi 90 | 91 | else # no segments->wav indexed by utterance. 92 | if [[ -f ${srcdir}/wav.scp ]]; then 93 | utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ 94 | # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" 95 | awk -v factor="${factor}" \ 96 | '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} 97 | else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" } 98 | else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \ 99 | > "${destdir}"/wav.scp 100 | fi 101 | fi 102 | 103 | if [[ -f ${srcdir}/text ]]; then 104 | utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text 105 | fi 106 | if [[ -f ${srcdir}/spk2gender ]]; then 107 | utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender 108 | fi 109 | if [[ -f ${srcdir}/utt2lang ]]; then 110 | utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang 111 | fi 112 | 113 | rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null 114 | echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}" 115 | 116 | utils/validate_data_dir.sh --no-feats --no-text "${destdir}" 117 | -------------------------------------------------------------------------------- /track1_asr/tools/reduce_data_dir.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # koried, 10/29/2012 4 | 5 | # Reduce a data set based on a list of turn-ids 6 | 7 | help_message="usage: $0 srcdir turnlist destdir" 8 | 9 | if [ $1 == "--help" ]; then 10 | echo "${help_message}" 11 | exit 0; 12 | fi 13 | 14 | if [ $# != 3 ]; then 15 | echo "${help_message}" 16 | exit 1; 17 | fi 18 | 19 | srcdir=$1 20 | reclist=$2 21 | destdir=$3 22 | 23 | if [ ! -f ${srcdir}/utt2spk ]; then 24 | echo "$0: no such file $srcdir/utt2spk" 25 | exit 1; 26 | fi 27 | 28 | function do_filtering { 29 | # assumes the utt2spk and spk2utt files already exist. 30 | [ -f ${srcdir}/feats.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/feats.scp >${destdir}/feats.scp 31 | [ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/wav.scp >${destdir}/wav.scp 32 | [ -f ${srcdir}/text ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/text >${destdir}/text 33 | [ -f ${srcdir}/utt2num_frames ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/utt2num_frames >${destdir}/utt2num_frames 34 | [ -f ${srcdir}/spk2gender ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/spk2gender >${destdir}/spk2gender 35 | [ -f ${srcdir}/cmvn.scp ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/cmvn.scp >${destdir}/cmvn.scp 36 | if [ -f ${srcdir}/segments ]; then 37 | utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/segments >${destdir}/segments 38 | awk '{print $2;}' ${destdir}/segments | sort | uniq > ${destdir}/reco # recordings. 39 | # The next line would override the command above for wav.scp, which would be incorrect. 40 | [ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/reco <${srcdir}/wav.scp >${destdir}/wav.scp 41 | [ -f ${srcdir}/reco2file_and_channel ] && \ 42 | utils/filter_scp.pl ${destdir}/reco <${srcdir}/reco2file_and_channel >${destdir}/reco2file_and_channel 43 | 44 | # Filter the STM file for proper sclite scoring (this will also remove the comments lines) 45 | [ -f ${srcdir}/stm ] && utils/filter_scp.pl ${destdir}/reco < ${srcdir}/stm > ${destdir}/stm 46 | rm ${destdir}/reco 47 | fi 48 | srcutts=$(wc -l < ${srcdir}/utt2spk) 49 | destutts=$(wc -l < ${destdir}/utt2spk) 50 | echo "Reduced #utt from $srcutts to $destutts" 51 | } 52 | 53 | mkdir -p ${destdir} 54 | 55 | # filter the utt2spk based on the set of recordings 56 | utils/filter_scp.pl ${reclist} < ${srcdir}/utt2spk > ${destdir}/utt2spk 57 | 58 | utils/utt2spk_to_spk2utt.pl < ${destdir}/utt2spk > ${destdir}/spk2utt 59 | do_filtering; 60 | -------------------------------------------------------------------------------- /track1_asr/tools/remove_longshortdata.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | 4 | import argparse 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser( 8 | description='remove too long or too short data in format.data') 9 | parser.add_argument('--data_file', 10 | type=str, 11 | help='input format data') 12 | parser.add_argument('--output_data_file', 13 | type=str, 14 | help='output format data') 15 | parser.add_argument('--min_input_len', type=float, 16 | default=0, 17 | help='minimum input seq length, in seconds for raw wav, \ 18 | in frame numbers for feature data') 19 | parser.add_argument('--max_input_len', type=float, 20 | default=20, 21 | help='maximum output seq length, in seconds for raw wav, \ 22 | in frame numbers for feature data') 23 | parser.add_argument('--min_output_len', type=float, 24 | default=0, help='minimum input seq length, in modeling units') 25 | parser.add_argument('--max_output_len', type=float, 26 | default=500, 27 | help='maximum output seq length, in modeling units') 28 | parser.add_argument('--min_output_input_ratio', type=float, default=0.05, 29 | help='minimum output seq length/output seq length ratio') 30 | parser.add_argument('--max_output_input_ratio', type=float, default=10, 31 | help='maximum output seq length/output seq length ratio') 32 | args = parser.parse_args() 33 | 34 | data_file = args.data_file 35 | output_data_file = args.output_data_file 36 | min_input_len = args.min_input_len 37 | max_input_len = args.max_input_len 38 | min_output_len = args.min_output_len 39 | max_output_len = args.max_output_len 40 | min_output_input_ratio = args.min_output_input_ratio 41 | max_output_input_ratio = args.max_output_input_ratio 42 | 43 | with open(data_file, 'r') as f, open(output_data_file, 'w') as fout: 44 | for l in f: 45 | l = l.strip() 46 | if l: 47 | items = l.strip().split('\t') 48 | token_shape = items[6] 49 | feature_shape = items[2] 50 | feat_len = float(feature_shape.split(':')[1].split(',')[0]) 51 | token_len = float(token_shape.split(':')[1].split(',')[0]) 52 | condition = [feat_len > min_input_len, 53 | feat_len < max_input_len, 54 | token_len > min_output_len, 55 | token_len < max_output_len, 56 | token_len / feat_len > min_output_input_ratio, 57 | token_len / feat_len < max_output_input_ratio, 58 | ] 59 | if all(condition): 60 | fout.write('{}\n'.format(l)) 61 | continue 62 | -------------------------------------------------------------------------------- /track1_asr/tools/segment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2021 Mobvoi Inc. (Di Wu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | import argparse 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser(description='generate segmented wav.scp') 14 | parser.add_argument('--segments', required=True, help='segments file') 15 | parser.add_argument('--input', 16 | required=True, 17 | help='origin wav.scp that not segmented') 18 | parser.add_argument('--output', 19 | required=True, 20 | help='output segmented wav.scp') 21 | wav_dic = {} 22 | args = parser.parse_args() 23 | ori_wav = args.input 24 | segment_file = args.segments 25 | wav_scp = args.output 26 | with open(ori_wav, 'r') as ori: 27 | for l in ori: 28 | item = l.strip().split() 29 | wav_dic[item[0]] = item[1] 30 | with open(wav_scp, 'w') as f, open(segment_file, 'r') as sgement: 31 | for l in sgement: 32 | item = l.strip().split() 33 | if item[1] in wav_dic: 34 | item[1] = wav_dic[item[1]] 35 | f.write("{} {},{},{}\n".format(item[0], item[1], item[2], item[3])) 36 | -------------------------------------------------------------------------------- /track1_asr/tools/setup_anaconda.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # NOTE(hslee): this code is borrowed from ESPnet (https://github.com/espnet/espnet) 3 | set -euo pipefail 4 | 5 | if [ -z "${PS1:-}" ]; then 6 | PS1=__dummy__ 7 | fi 8 | CONDA_URL=https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh 9 | 10 | if [ $# -gt 4 ]; then 11 | echo "Usage: $0 [output] [conda-env-name] [python-version>]" 12 | exit 1; 13 | elif [ $# -eq 3 ]; then 14 | output_dir="$1" 15 | name="$2" 16 | PYTHON_VERSION="$3" 17 | elif [ $# -eq 2 ]; then 18 | output_dir="$1" 19 | name="$2" 20 | PYTHON_VERSION="" 21 | elif [ $# -eq 1 ]; then 22 | output_dir="$1" 23 | name="" 24 | PYTHON_VERSION="" 25 | elif [ $# -eq 0 ]; then 26 | output_dir=venv 27 | name="" 28 | PYTHON_VERSION="" 29 | fi 30 | 31 | if [ -e activate_python.sh ]; then 32 | echo "Warning: activate_python.sh already exists. It will be overwritten" 33 | fi 34 | 35 | if [ ! -e "${output_dir}/etc/profile.d/conda.sh" ]; then 36 | if [ ! -e miniconda.sh ]; then 37 | wget --tries=3 "${CONDA_URL}" -O miniconda.sh 38 | fi 39 | 40 | bash miniconda.sh -b -p "${output_dir}" 41 | fi 42 | 43 | # shellcheck disable=SC1090 44 | source "${output_dir}/etc/profile.d/conda.sh" 45 | conda deactivate 46 | 47 | # If the env already exists, skip recreation 48 | if [ -n "${name}" ] && ! conda activate ${name}; then 49 | conda create -yn "${name}" 50 | fi 51 | conda activate ${name} 52 | 53 | if [ -n "${PYTHON_VERSION}" ]; then 54 | conda install -y conda "python=${PYTHON_VERSION}" 55 | else 56 | conda install -y conda 57 | fi 58 | 59 | conda install -y pip setuptools 60 | 61 | cat << EOF > activate_python.sh 62 | #!/usr/bin/env bash 63 | # THIS FILE IS GENERATED BY tools/setup_anaconda.sh 64 | if [ -z "\${PS1:-}" ]; then 65 | PS1=__dummy__ 66 | fi 67 | . $(cd ${output_dir}; pwd)/etc/profile.d/conda.sh && conda deactivate && conda activate ${name} 68 | EOF 69 | -------------------------------------------------------------------------------- /track1_asr/tools/sph2wav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # convert sph scp to segmented wav scp 3 | nj=1 4 | . tools/parse_options.sh || exit 1; 5 | 6 | inscp=$1 7 | segments=$2 8 | outscp=$3 9 | data=$(dirname ${inscp}) 10 | if [ $# -eq 4 ]; then 11 | logdir=$4 12 | else 13 | logdir=${data}/log 14 | fi 15 | mkdir -p ${logdir} 16 | 17 | sph2pipe_version="v2.5" 18 | if [ ! -d tools/sph2pipe_${sph2pipe_version} ]; then 19 | echo "Download sph2pipe_${sph2pipe_version} ......" 20 | wget -T 10 -t 3 -P tools https://www.openslr.org/resources/3/sph2pipe_${sph2pipe_version}.tar.gz || \ 21 | wget -T 10 -c -P tools https://sourceforge.net/projects/kaldi/files/sph2pipe_${sph2pipe_version}.tar.gz; \ 22 | tar --no-same-owner -xzf tools/sph2pipe_${sph2pipe_version}.tar.gz -C tools 23 | cd tools/sph2pipe_${sph2pipe_version}/ && \ 24 | gcc -o sph2pipe *.c -lm 25 | cd - 26 | fi 27 | sph2pipe=`which sph2pipe` || sph2pipe=`pwd`/tools/sph2pipe_${sph2pipe_version}/sph2pipe 28 | [ ! -x $sph2pipe ] && echo "Could not find the sph2pipe program at $sph2pipe" && exit 1; 29 | sox=`which sox` 30 | [ ! -x $sox ] && echo "Could not find the sox program at $sph2pipe" && exit 1; 31 | 32 | cat $inscp | awk -v sph2pipe=$sph2pipe '{printf("%s-A %s#-f#wav#-p#-c#1#%s#|\n", $1, sph2pipe, $2); 33 | printf("%s-B %s#-f#wav#-p#-c#2#%s#|\n", $1, sph2pipe, $2);}' | \ 34 | sort > $data/wav_ori.scp || exit 1; 35 | 36 | tools/segment.py --segments $segments --input $data/wav_ori.scp --output $data/wav_segments.scp 37 | sed -i 's/ /,/g' $data/wav_segments.scp 38 | sed -i 's/#/ /g' $data/wav_segments.scp 39 | 40 | rm -f $logdir/wav_*.slice 41 | rm -f $logdir/*.log 42 | split --additional-suffix .slice -d -n l/$nj $data/wav_segments.scp $logdir/wav_ 43 | 44 | for slice in `ls $logdir/wav_*.slice`; do 45 | { 46 | name=`basename -s .slice $slice` 47 | mkdir -p ${data}/wavs/${name} 48 | cat ${slice} | awk -F ',' -v sox=$sox -v data=`pwd`/$data/wavs/$name \ 49 | -v logdir=$logdir -v name=$name '{ 50 | during=$4-$3 51 | cmd=$2 sox " - " data "/" $1 ".wav" " trim " $3 " " during; 52 | system(cmd) 53 | printf("%s %s/%s.wav\n", $1, data, $1); 54 | }' | \ 55 | sort > ${data}/wavs_${name}.scp || exit 1; 56 | } & 57 | done 58 | wait 59 | cat ${data}/wavs_*.scp > $outscp 60 | rm ${data}/wavs_*.scp 61 | -------------------------------------------------------------------------------- /track1_asr/tools/spk2utt_to_utt2spk.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2011 Microsoft Corporation 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 11 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 12 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 13 | # MERCHANTABLITY OR NON-INFRINGEMENT. 14 | # See the Apache 2 License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | while(<>){ 19 | @A = split(" ", $_); 20 | @A > 1 || die "Invalid line in spk2utt file: $_"; 21 | $s = shift @A; 22 | foreach $u ( @A ) { 23 | print "$u $s\n"; 24 | } 25 | } 26 | 27 | 28 | -------------------------------------------------------------------------------- /track1_asr/tools/spm_decode: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # https://github.com/pytorch/fairseq/blob/master/LICENSE 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | import sys 12 | 13 | import sentencepiece as spm 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--model", required=True, 19 | help="sentencepiece model to use for decoding") 20 | parser.add_argument("--input", default=None, help="input file to decode") 21 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece") 22 | args = parser.parse_args() 23 | 24 | sp = spm.SentencePieceProcessor() 25 | sp.Load(args.model) 26 | 27 | if args.input_format == "piece": 28 | def decode(l): 29 | return "".join(sp.DecodePieces(l)) 30 | elif args.input_format == "id": 31 | def decode(l): 32 | return "".join(sp.DecodeIds(l)) 33 | else: 34 | raise NotImplementedError 35 | 36 | def tok2int(tok): 37 | # remap reference-side (represented as <>) to 0 38 | return int(tok) if tok != "<>" else 0 39 | 40 | if args.input is None: 41 | h = sys.stdin 42 | else: 43 | h = open(args.input, "r", encoding="utf-8") 44 | for line in h: 45 | print(decode(line.split())) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /track1_asr/tools/spm_encode: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in 6 | # https://github.com/pytorch/fairseq/blob/master/LICENSE 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | import contextlib 12 | import sys 13 | 14 | import sentencepiece as spm 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--model", required=True, 20 | help="sentencepiece model to use for encoding") 21 | parser.add_argument("--inputs", nargs="+", default=['-'], 22 | help="input files to filter/encode") 23 | parser.add_argument("--outputs", nargs="+", default=['-'], 24 | help="path to save encoded outputs") 25 | parser.add_argument("--output_format", choices=["piece", "id"], default="piece") 26 | parser.add_argument("--min-len", type=int, metavar="N", 27 | help="filter sentence pairs with fewer than N tokens") 28 | parser.add_argument("--max-len", type=int, metavar="N", 29 | help="filter sentence pairs with more than N tokens") 30 | args = parser.parse_args() 31 | 32 | assert len(args.inputs) == len(args.outputs), \ 33 | "number of input and output paths should match" 34 | 35 | sp = spm.SentencePieceProcessor() 36 | sp.Load(args.model) 37 | 38 | if args.output_format == "piece": 39 | def encode(l): 40 | return sp.EncodeAsPieces(l) 41 | elif args.output_format == "id": 42 | def encode(l): 43 | return list(map(str, sp.EncodeAsIds(l))) 44 | else: 45 | raise NotImplementedError 46 | 47 | if args.min_len is not None or args.max_len is not None: 48 | def valid(line): 49 | return ( 50 | (args.min_len is None or len(line) >= args.min_len) and 51 | (args.max_len is None or len(line) <= args.max_len) 52 | ) 53 | else: 54 | def valid(lines): 55 | return True 56 | 57 | with contextlib.ExitStack() as stack: 58 | inputs = [ 59 | stack.enter_context(open(input, "r", encoding="utf-8")) 60 | if input != "-" else sys.stdin 61 | for input in args.inputs 62 | ] 63 | outputs = [ 64 | stack.enter_context(open(output, "w", encoding="utf-8")) 65 | if output != "-" else sys.stdout 66 | for output in args.outputs 67 | ] 68 | 69 | stats = { 70 | "num_empty": 0, 71 | "num_filtered": 0, 72 | } 73 | 74 | def encode_line(line): 75 | line = line.strip() 76 | if len(line) > 0: 77 | line = encode(line) 78 | if valid(line): 79 | return line 80 | else: 81 | stats["num_filtered"] += 1 82 | else: 83 | stats["num_empty"] += 1 84 | return None 85 | 86 | for i, lines in enumerate(zip(*inputs), start=1): 87 | enc_lines = list(map(encode_line, lines)) 88 | if not any(enc_line is None for enc_line in enc_lines): 89 | for enc_line, output_h in zip(enc_lines, outputs): 90 | print(" ".join(enc_line), file=output_h) 91 | if i % 10000 == 0: 92 | print("processed {} lines".format(i), file=sys.stderr) 93 | 94 | print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr) 95 | print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /track1_asr/tools/spm_train: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # https://github.com/pytorch/fairseq/blob/master/LICENSE 7 | import sys 8 | 9 | import sentencepiece as spm 10 | 11 | 12 | if __name__ == "__main__": 13 | spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:])) 14 | -------------------------------------------------------------------------------- /track1_asr/tools/subset_scp.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | use warnings; #sed replacement for -w perl parameter 3 | # Copyright 2010-2011 Microsoft Corporation 4 | 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 12 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 13 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 14 | # MERCHANTABLITY OR NON-INFRINGEMENT. 15 | # See the Apache 2 License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | # This program selects a subset of N elements in the scp. 19 | 20 | # By default, it selects them evenly from throughout the scp, in order to avoid 21 | # selecting too many from the same speaker. It prints them on the standard 22 | # output. 23 | # With the option --first, it just selects the N first utterances. 24 | # With the option --last, it just selects the N last utterances. 25 | 26 | # Last modified by JHU & HKUST @2013 27 | 28 | 29 | $quiet = 0; 30 | $first = 0; 31 | $last = 0; 32 | 33 | if (@ARGV > 0 && $ARGV[0] eq "--quiet") { 34 | shift; 35 | $quiet = 1; 36 | } 37 | if (@ARGV > 0 && $ARGV[0] eq "--first") { 38 | shift; 39 | $first = 1; 40 | } 41 | if (@ARGV > 0 && $ARGV[0] eq "--last") { 42 | shift; 43 | $last = 1; 44 | } 45 | 46 | if(@ARGV < 2 ) { 47 | die "Usage: subset_scp.pl [--quiet][--first|--last] N in.scp\n" . 48 | " --quiet causes it to not die if N < num lines in scp.\n" . 49 | " --first and --last make it equivalent to head or tail.\n" . 50 | "See also: filter_scp.pl\n"; 51 | } 52 | 53 | $N = shift @ARGV; 54 | if($N == 0) { 55 | die "First command-line parameter to subset_scp.pl must be an integer, got \"$N\""; 56 | } 57 | $inscp = shift @ARGV; 58 | open(I, "<$inscp") || die "Opening input scp file $inscp"; 59 | 60 | @F = (); 61 | while() { 62 | push @F, $_; 63 | } 64 | $numlines = @F; 65 | if($N > $numlines) { 66 | if ($quiet) { 67 | $N = $numlines; 68 | } else { 69 | die "You requested from subset_scp.pl more elements than available: $N > $numlines"; 70 | } 71 | } 72 | 73 | sub select_n { 74 | my ($start,$end,$num_needed) = @_; 75 | my $diff = $end - $start; 76 | if ($num_needed > $diff) { 77 | die "select_n: code error"; 78 | } 79 | if ($diff == 1 ) { 80 | if ($num_needed > 0) { 81 | print $F[$start]; 82 | } 83 | } else { 84 | my $halfdiff = int($diff/2); 85 | my $halfneeded = int($num_needed/2); 86 | select_n($start, $start+$halfdiff, $halfneeded); 87 | select_n($start+$halfdiff, $end, $num_needed - $halfneeded); 88 | } 89 | } 90 | 91 | if ( ! $first && ! $last) { 92 | if ($N > 0) { 93 | select_n(0, $numlines, $N); 94 | } 95 | } else { 96 | if ($first) { # --first option: same as head. 97 | for ($n = 0; $n < $N; $n++) { 98 | print $F[$n]; 99 | } 100 | } else { # --last option: same as tail. 101 | for ($n = @F - $N; $n < @F; $n++) { 102 | print $F[$n]; 103 | } 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /track1_asr/tools/sym2int.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 11 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 12 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 13 | # MERCHANTABLITY OR NON-INFRINGEMENT. 14 | # See the Apache 2 License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | $ignore_oov = 0; 19 | 20 | for($x = 0; $x < 2; $x++) { 21 | if ($ARGV[0] eq "--map-oov") { 22 | shift @ARGV; 23 | $map_oov = shift @ARGV; 24 | if ($map_oov eq "-f" || $map_oov =~ m/words\.txt$/ || $map_oov eq "") { 25 | # disallow '-f', the empty string and anything ending in words.txt as the 26 | # OOV symbol because these are likely command-line errors. 27 | die "the --map-oov option requires an argument"; 28 | } 29 | } 30 | if ($ARGV[0] eq "-f") { 31 | shift @ARGV; 32 | $field_spec = shift @ARGV; 33 | if ($field_spec =~ m/^\d+$/) { 34 | $field_begin = $field_spec - 1; $field_end = $field_spec - 1; 35 | } 36 | if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10) 37 | if ($1 ne "") { 38 | $field_begin = $1 - 1; # Change to zero-based indexing. 39 | } 40 | if ($2 ne "") { 41 | $field_end = $2 - 1; # Change to zero-based indexing. 42 | } 43 | } 44 | if (!defined $field_begin && !defined $field_end) { 45 | die "Bad argument to -f option: $field_spec"; 46 | } 47 | } 48 | } 49 | 50 | $symtab = shift @ARGV; 51 | if (!defined $symtab) { 52 | print STDERR "Usage: sym2int.pl [options] symtab [input transcriptions] > output transcriptions\n" . 53 | "options: [--map-oov ] [-f ]\n" . 54 | "note: can look like 4-5, or 4-, or 5-, or 1.\n"; 55 | } 56 | open(F, "<$symtab") || die "Error opening symbol table file $symtab"; 57 | while() { 58 | @A = split(" ", $_); 59 | @A == 2 || die "bad line in symbol table file: $_"; 60 | $sym2int{$A[0]} = $A[1] + 0; 61 | } 62 | 63 | if (defined $map_oov && $map_oov !~ m/^\d+$/) { # not numeric-> look it up 64 | if (!defined $sym2int{$map_oov}) { die "OOV symbol $map_oov not defined."; } 65 | $map_oov = $sym2int{$map_oov}; 66 | } 67 | 68 | $num_warning = 0; 69 | $max_warning = 20; 70 | 71 | while (<>) { 72 | @A = split(" ", $_); 73 | @B = (); 74 | for ($n = 0; $n < @A; $n++) { 75 | $a = $A[$n]; 76 | if ( (!defined $field_begin || $n >= $field_begin) 77 | && (!defined $field_end || $n <= $field_end)) { 78 | $i = $sym2int{$a}; 79 | if (!defined ($i)) { 80 | if (defined $map_oov) { 81 | if ($num_warning++ < $max_warning) { 82 | print STDERR "sym2int.pl: replacing $a with $map_oov\n"; 83 | if ($num_warning == $max_warning) { 84 | print STDERR "sym2int.pl: not warning for OOVs any more times\n"; 85 | } 86 | } 87 | $i = $map_oov; 88 | } else { 89 | $pos = $n+1; 90 | die "sym2int.pl: undefined symbol $a (in position $pos)\n"; 91 | } 92 | } 93 | $a = $i; 94 | } 95 | push @B, $a; 96 | } 97 | print join(" ", @B); 98 | print "\n"; 99 | } 100 | if ($num_warning > 0) { 101 | print STDERR "** Replaced $num_warning instances of OOVs with $map_oov\n"; 102 | } 103 | 104 | exit(0); 105 | -------------------------------------------------------------------------------- /track1_asr/tools/utt2spk_to_spk2utt.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2011 Microsoft Corporation 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 11 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 12 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 13 | # MERCHANTABLITY OR NON-INFRINGEMENT. 14 | # See the Apache 2 License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # converts an utt2spk file to a spk2utt file. 18 | # Takes input from the stdin or from a file argument; 19 | # output goes to the standard out. 20 | 21 | if ( @ARGV > 1 ) { 22 | die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt"; 23 | } 24 | 25 | while(<>){ 26 | @A = split(" ", $_); 27 | @A == 2 || die "Invalid line in utt2spk file: $_"; 28 | ($u,$s) = @A; 29 | if(!$seen_spk{$s}) { 30 | $seen_spk{$s} = 1; 31 | push @spklist, $s; 32 | } 33 | push (@{$spk_hash{$s}}, "$u"); 34 | } 35 | foreach $s (@spklist) { 36 | $l = join(' ',@{$spk_hash{$s}}); 37 | print "$s $l\n"; 38 | } 39 | -------------------------------------------------------------------------------- /track1_asr/tools/validate_text.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | #=============================================================================== 4 | # Copyright 2017 Johns Hopkins University (author: Yenda Trmal ) 5 | # Johns Hopkins University (author: Daniel Povey) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 15 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 16 | # MERCHANTABLITY OR NON-INFRINGEMENT. 17 | # See the Apache 2 License for the specific language governing permissions and 18 | # limitations under the License. 19 | #=============================================================================== 20 | 21 | # validation script for data//text 22 | # to be called (preferably) from utils/validate_data_dir.sh 23 | use strict; 24 | use warnings; 25 | use utf8; 26 | use Fcntl qw< SEEK_SET >; 27 | 28 | # this function reads the opened file (supplied as a first 29 | # parameter) into an array of lines. For each 30 | # line, it tests whether it's a valid utf-8 compatible 31 | # line. If all lines are valid utf-8, it returns the lines 32 | # decoded as utf-8, otherwise it assumes the file's encoding 33 | # is one of those 1-byte encodings, such as ISO-8859-x 34 | # or Windows CP-X. 35 | # Please recall we do not really care about 36 | # the actually encoding, we just need to 37 | # make sure the length of the (decoded) string 38 | # is correct (to make the output formatting looking right). 39 | sub get_utf8_or_bytestream { 40 | use Encode qw(decode encode); 41 | my $is_utf_compatible = 1; 42 | my @unicode_lines; 43 | my @raw_lines; 44 | my $raw_text; 45 | my $lineno = 0; 46 | my $file = shift; 47 | 48 | while (<$file>) { 49 | $raw_text = $_; 50 | last unless $raw_text; 51 | if ($is_utf_compatible) { 52 | my $decoded_text = eval { decode("UTF-8", $raw_text, Encode::FB_CROAK) } ; 53 | $is_utf_compatible = $is_utf_compatible && defined($decoded_text); 54 | push @unicode_lines, $decoded_text; 55 | } else { 56 | #print STDERR "WARNING: the line $raw_text cannot be interpreted as UTF-8: $decoded_text\n"; 57 | ; 58 | } 59 | push @raw_lines, $raw_text; 60 | $lineno += 1; 61 | } 62 | 63 | if (!$is_utf_compatible) { 64 | return (0, @raw_lines); 65 | } else { 66 | return (1, @unicode_lines); 67 | } 68 | } 69 | 70 | # check if the given unicode string contain unicode whitespaces 71 | # other than the usual four: TAB, LF, CR and SPACE 72 | sub validate_utf8_whitespaces { 73 | my $unicode_lines = shift; 74 | use feature 'unicode_strings'; 75 | for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) { 76 | my $current_line = $unicode_lines->[$i]; 77 | if ((substr $current_line, -1) ne "\n"){ 78 | print STDERR "$0: The current line (nr. $i) has invalid newline\n"; 79 | return 1; 80 | } 81 | my @A = split(" ", $current_line); 82 | my $utt_id = $A[0]; 83 | # we replace TAB, LF, CR, and SPACE 84 | # this is to simplify the test 85 | if ($current_line =~ /\x{000d}/) { 86 | print STDERR "$0: The line for utterance $utt_id contains CR (0x0D) character\n"; 87 | return 1; 88 | } 89 | $current_line =~ s/[\x{0009}\x{000a}\x{0020}]/./g; 90 | if ($current_line =~/\s/) { 91 | print STDERR "$0: The line for utterance $utt_id contains disallowed Unicode whitespaces\n"; 92 | return 1; 93 | } 94 | } 95 | return 0; 96 | } 97 | 98 | # checks if the text in the file (supplied as the argument) is utf-8 compatible 99 | # if yes, checks if it contains only allowed whitespaces. If no, then does not 100 | # do anything. The function seeks to the original position in the file after 101 | # reading the text. 102 | sub check_allowed_whitespace { 103 | my $file = shift; 104 | my $filename = shift; 105 | my $pos = tell($file); 106 | (my $is_utf, my @lines) = get_utf8_or_bytestream($file); 107 | seek($file, $pos, SEEK_SET); 108 | if ($is_utf) { 109 | my $has_invalid_whitespaces = validate_utf8_whitespaces(\@lines); 110 | if ($has_invalid_whitespaces) { 111 | print STDERR "$0: ERROR: text file '$filename' contains disallowed UTF-8 whitespace character(s)\n"; 112 | return 0; 113 | } 114 | } 115 | return 1; 116 | } 117 | 118 | if(@ARGV != 1) { 119 | die "Usage: validate_text.pl \n" . 120 | "e.g.: validate_text.pl data/train/text\n"; 121 | } 122 | 123 | my $text = shift @ARGV; 124 | 125 | if (-z "$text") { 126 | print STDERR "$0: ERROR: file '$text' is empty or does not exist\n"; 127 | exit 1; 128 | } 129 | 130 | if(!open(FILE, "<$text")) { 131 | print STDERR "$0: ERROR: failed to open $text\n"; 132 | exit 1; 133 | } 134 | 135 | check_allowed_whitespace(\*FILE, $text) or exit 1; 136 | close(FILE); 137 | -------------------------------------------------------------------------------- /track1_asr/tools/wav2dur.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | 4 | import sys 5 | 6 | import torchaudio 7 | torchaudio.set_audio_backend("sox_io") 8 | 9 | scp = sys.argv[1] 10 | dur_scp = sys.argv[2] 11 | 12 | with open(scp, 'r') as f, open(dur_scp, 'w') as fout: 13 | cnt = 0 14 | total_duration = 0 15 | for l in f: 16 | items = l.strip().split() 17 | wav_id = items[0] 18 | fname = items[1] 19 | cnt += 1 20 | waveform, rate = torchaudio.load(fname) 21 | frames = len(waveform[0]) 22 | duration = frames / float(rate) 23 | total_duration += duration 24 | fout.write('{} {}\n'.format(wav_id, duration)) 25 | print('process {} utts'.format(cnt)) 26 | print('total {} s'.format(total_duration)) 27 | -------------------------------------------------------------------------------- /track1_asr/tools/wav_to_duration.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # split the wav scp, calculate duration and merge 3 | nj=4 4 | . tools/parse_options.sh || exit 1; 5 | 6 | inscp=$1 7 | outscp=$2 8 | data=$(dirname ${inscp}) 9 | if [ $# -eq 3 ]; then 10 | logdir=$3 11 | else 12 | logdir=${data}/log 13 | fi 14 | mkdir -p ${logdir} 15 | 16 | rm -f $logdir/wav_*.slice 17 | rm -f $logdir/wav_*.shape 18 | split --additional-suffix .slice -d -n l/$nj $inscp $logdir/wav_ 19 | 20 | for slice in `ls $logdir/wav_*.slice`; do 21 | { 22 | name=`basename -s .slice $slice` 23 | tools/wav2dur.py $slice $logdir/$name.shape 1>$logdir/$name.log 24 | } & 25 | done 26 | wait 27 | cat $logdir/wav_*.shape > $outscp 28 | -------------------------------------------------------------------------------- /track1_asr/wenet/README.md: -------------------------------------------------------------------------------- 1 | # Module Introduction 2 | 3 | Here is a brief introduction of each module(directory). 4 | 5 | * `bin`: training and recognition binaries 6 | * `dataset`: IO design 7 | * `utils`: common utils 8 | * `transformer`: the core of `WeNet`, in which the standard transformer/conformer is implemented. It contains the common blocks(backbone) of speech transformers. 9 | * transformer/attention.py: Standard multi head attention 10 | * transformer/embedding.py: Standard position encoding 11 | * transformer/positionwise_feed_forward.py: Standard feed forward in transformer 12 | * transformer/convolution.py: ConvolutionModule in Conformer model 13 | * transformer/subsampling.py: Subsampling implementation for speech task 14 | * `transducer`: transducer implementation 15 | * `squeezeformer`: squeezeformer implementation, please refer [paper](https://arxiv.org/pdf/2206.00888.pdf) 16 | * `efficient_conformer`: efficient conformer implementation, please refer [paper](https://arxiv.org/pdf/2109.01163.pdf) 17 | * `cif`: Continuous Integrate-and-Fire implemented, please refer [paper](https://arxiv.org/pdf/1905.11235.pdf) 18 | * `branchformer`: branchformer implementation, please refer [paper](https://arxiv.org/abs/2207.02971) 19 | 20 | 21 | `transducer`, `squeezeformer`, `efficient_conformer`, `branchformer` and `cif` are all based on `transformer`, 22 | they resue a lot of the common blocks of `tranformer`. 23 | 24 | **If you want to contribute your own x-former, please reuse the current code as much as possible**. 25 | 26 | 27 | -------------------------------------------------------------------------------- /track1_asr/wenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrSupW/ICMC-ASR_Baseline/355626cef459e58a0fe7cc62af1326d67a43a0d7/track1_asr/wenet/__init__.py -------------------------------------------------------------------------------- /track1_asr/wenet/bin/average_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import argparse 18 | import glob 19 | 20 | import yaml 21 | import numpy as np 22 | import torch 23 | 24 | 25 | def get_args(): 26 | parser = argparse.ArgumentParser(description='average model') 27 | parser.add_argument('--dst_model', required=True, help='averaged model') 28 | parser.add_argument('--src_path', 29 | required=True, 30 | help='src model path for average') 31 | parser.add_argument('--val_best', 32 | action="store_true", 33 | help='averaged model') 34 | parser.add_argument('--num', 35 | default=5, 36 | type=int, 37 | help='nums for averaged model') 38 | parser.add_argument('--min_epoch', 39 | default=0, 40 | type=int, 41 | help='min epoch used for averaging model') 42 | parser.add_argument('--max_epoch', 43 | default=65536, 44 | type=int, 45 | help='max epoch used for averaging model') 46 | 47 | args = parser.parse_args() 48 | print(args) 49 | return args 50 | 51 | 52 | def main(): 53 | args = get_args() 54 | checkpoints = [] 55 | val_scores = [] 56 | if args.val_best: 57 | yamls = glob.glob('{}/[!train]*.yaml'.format(args.src_path)) 58 | for y in yamls: 59 | with open(y, 'r') as f: 60 | dic_yaml = yaml.load(f, Loader=yaml.FullLoader) 61 | loss = dic_yaml['cv_loss'] 62 | epoch = dic_yaml['epoch'] 63 | if epoch >= args.min_epoch and epoch <= args.max_epoch: 64 | val_scores += [[epoch, loss]] 65 | val_scores = np.array(val_scores) 66 | sort_idx = np.argsort(val_scores[:, -1]) 67 | sorted_val_scores = val_scores[sort_idx][::1] 68 | print("best val scores = " + str(sorted_val_scores[:args.num, 1])) 69 | print("selected epochs = " + 70 | str(sorted_val_scores[:args.num, 0].astype(np.int64))) 71 | path_list = [ 72 | args.src_path + '/{}.pt'.format(int(epoch)) 73 | for epoch in sorted_val_scores[:args.num, 0] 74 | ] 75 | else: 76 | path_list = glob.glob('{}/[0-9]*.pt'.format(args.src_path)) 77 | path_list = sorted(path_list, key=os.path.getmtime) 78 | path_list = path_list[-args.num:] 79 | print(path_list) 80 | avg = None 81 | num = args.num 82 | assert num == len(path_list) 83 | for path in path_list: 84 | print('Processing {}'.format(path)) 85 | states = torch.load(path, map_location=torch.device('cpu')) 86 | if avg is None: 87 | avg = states 88 | else: 89 | for k in avg.keys(): 90 | avg[k] += states[k] 91 | # average 92 | for k in avg.keys(): 93 | if avg[k] is not None: 94 | # pytorch 1.6 use true_divide instead of /= 95 | avg[k] = torch.true_divide(avg[k], num) 96 | print('Saving to {}'.format(args.dst_model)) 97 | torch.save(avg, args.dst_model) 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /track1_asr/wenet/bin/export_ipex.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023 Intel Corporation 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from __future__ import print_function 5 | 6 | import argparse 7 | import os 8 | 9 | import torch 10 | import yaml 11 | 12 | from wenet.utils.checkpoint import load_checkpoint 13 | from wenet.utils.init_model import init_model 14 | import intel_extension_for_pytorch as ipex 15 | from intel_extension_for_pytorch.quantization import prepare, convert 16 | 17 | def get_args(): 18 | parser = argparse.ArgumentParser(description='export your script model') 19 | parser.add_argument('--config', required=True, help='config file') 20 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 21 | parser.add_argument('--output_file', default=None, help='output file') 22 | parser.add_argument('--dtype', 23 | default="fp32", 24 | help='choose the dtype to run:[fp32,bf16]') 25 | parser.add_argument('--output_quant_file', 26 | default=None, 27 | help='output quantized model file') 28 | args = parser.parse_args() 29 | return args 30 | 31 | def scripting(model): 32 | with torch.inference_mode(): 33 | script_model = torch.jit.script(model) 34 | script_model = torch.jit.freeze( 35 | script_model, 36 | preserved_attrs=["forward_encoder_chunk", 37 | "ctc_activation", 38 | "forward_attention_decoder", 39 | "subsampling_rate", 40 | "right_context", 41 | "sos_symbol", 42 | "eos_symbol", 43 | "is_bidirectional_decoder"] 44 | ) 45 | return script_model 46 | 47 | def main(): 48 | args = get_args() 49 | # No need gpu for model export 50 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 51 | 52 | with open(args.config, 'r') as fin: 53 | configs = yaml.load(fin, Loader=yaml.FullLoader) 54 | model = init_model(configs) 55 | print(model) 56 | 57 | load_checkpoint(model, args.checkpoint) 58 | 59 | # Apply IPEX optimization 60 | model.eval() 61 | torch._C._jit_set_texpr_fuser_enabled(False) 62 | model.to(memory_format=torch.channels_last) 63 | if args.dtype == "fp32": 64 | ipex_model = ipex.optimize(model) 65 | elif args.dtype == "bf16": # For Intel 4th generation Xeon (SPR) 66 | ipex_model = ipex.optimize(model, dtype=torch.bfloat16, weights_prepack=False) 67 | 68 | # Export jit torch script model 69 | if args.output_file: 70 | if args.dtype == "fp32": 71 | script_model = scripting(ipex_model) 72 | elif args.dtype == "bf16": 73 | torch._C._jit_set_autocast_mode(True) 74 | with torch.cpu.amp.autocast(): 75 | script_model = scripting(ipex_model) 76 | script_model.save(args.output_file) 77 | print('Export model successfully, see {}'.format(args.output_file)) 78 | 79 | # Export quantized jit torch script model 80 | if args.output_quant_file: 81 | dynamic_qconfig = ipex.quantization.default_dynamic_qconfig 82 | dummy_data = (torch.zeros(1, 67, 80), 83 | 16, 84 | -16, 85 | torch.zeros(12, 4, 32, 128), 86 | torch.zeros(12, 1, 256, 7)) 87 | model = prepare(model, dynamic_qconfig, dummy_data) 88 | model = convert(model) 89 | script_quant_model = scripting(model) 90 | script_quant_model.save(args.output_quant_file) 91 | print('Export quantized model successfully, ' 92 | 'see {}'.format(args.output_quant_file)) 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /track1_asr/wenet/bin/export_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import os 19 | 20 | import torch 21 | import yaml 22 | 23 | from wenet.utils.checkpoint import load_checkpoint 24 | from wenet.utils.init_model import init_model 25 | 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser(description='export your script model') 29 | parser.add_argument('--config', required=True, help='config file') 30 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 31 | parser.add_argument('--output_file', default=None, help='output file') 32 | parser.add_argument('--output_quant_file', 33 | default=None, 34 | help='output quantized model file') 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def main(): 40 | args = get_args() 41 | # No need gpu for model export 42 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 43 | 44 | with open(args.config, 'r') as fin: 45 | configs = yaml.load(fin, Loader=yaml.FullLoader) 46 | model = init_model(configs) 47 | print(model) 48 | 49 | load_checkpoint(model, args.checkpoint) 50 | # Export jit torch script model 51 | 52 | if args.output_file: 53 | script_model = torch.jit.script(model) 54 | script_model.save(args.output_file) 55 | print('Export model successfully, see {}'.format(args.output_file)) 56 | 57 | # Export quantized jit torch script model 58 | if args.output_quant_file: 59 | quantized_model = torch.quantization.quantize_dynamic( 60 | model, {torch.nn.Linear}, dtype=torch.qint8 61 | ) 62 | print(quantized_model) 63 | script_quant_model = torch.jit.script(quantized_model) 64 | script_quant_model.save(args.output_quant_file) 65 | print('Export quantized model successfully, ' 66 | 'see {}'.format(args.output_quant_file)) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /track1_asr/wenet/branchformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrSupW/ICMC-ASR_Baseline/355626cef459e58a0fe7cc62af1326d67a43a0d7/track1_asr/wenet/branchformer/__init__.py -------------------------------------------------------------------------------- /track1_asr/wenet/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrSupW/ICMC-ASR_Baseline/355626cef459e58a0fe7cc62af1326d67a43a0d7/track1_asr/wenet/dataset/__init__.py -------------------------------------------------------------------------------- /track1_asr/wenet/efficient_conformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrSupW/ICMC-ASR_Baseline/355626cef459e58a0fe7cc62af1326d67a43a0d7/track1_asr/wenet/efficient_conformer/__init__.py -------------------------------------------------------------------------------- /track1_asr/wenet/efficient_conformer/subsampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) 2 | # 2022 58.com(Wuba) Inc AI Lab. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | 17 | 18 | """Subsampling layer definition.""" 19 | 20 | from typing import Tuple, Union 21 | 22 | import torch 23 | from wenet.transformer.subsampling import BaseSubsampling 24 | 25 | 26 | class Conv2dSubsampling2(BaseSubsampling): 27 | """Convolutional 2D subsampling (to 1/4 length). 28 | 29 | Args: 30 | idim (int): Input dimension. 31 | odim (int): Output dimension. 32 | dropout_rate (float): Dropout rate. 33 | 34 | """ 35 | def __init__(self, idim: int, odim: int, dropout_rate: float, 36 | pos_enc_class: torch.nn.Module): 37 | """Construct an Conv2dSubsampling4 object.""" 38 | super().__init__() 39 | self.conv = torch.nn.Sequential( 40 | torch.nn.Conv2d(1, odim, 3, 2), 41 | torch.nn.ReLU() 42 | ) 43 | self.out = torch.nn.Sequential( 44 | torch.nn.Linear(odim * ((idim - 1) // 2), odim)) 45 | self.pos_enc = pos_enc_class 46 | # The right context for every conv layer is computed by: 47 | # (kernel_size - 1) * frame_rate_of_this_layer 48 | self.subsampling_rate = 2 49 | # 2 = (3 - 1) * 1 50 | self.right_context = 2 51 | 52 | def forward( 53 | self, 54 | x: torch.Tensor, 55 | x_mask: torch.Tensor, 56 | offset: Union[int, torch.Tensor] = 0 57 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 58 | """Subsample x. 59 | 60 | Args: 61 | x (torch.Tensor): Input tensor (#batch, time, idim). 62 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 63 | 64 | Returns: 65 | torch.Tensor: Subsampled tensor (#batch, time', odim), 66 | where time' = time // 2. 67 | torch.Tensor: Subsampled mask (#batch, 1, time'), 68 | where time' = time // 2. 69 | torch.Tensor: positional encoding 70 | 71 | """ 72 | x = x.unsqueeze(1) # (b, c=1, t, f) 73 | x = self.conv(x) 74 | b, c, t, f = x.size() 75 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 76 | x, pos_emb = self.pos_enc(x, offset) 77 | return x, pos_emb, x_mask[:, :, :-2:2] 78 | -------------------------------------------------------------------------------- /track1_asr/wenet/paraformer/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 ASLP@NWPU (authors: He Wang, Fan Yu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. Modified from 14 | # FunASR(https://github.com/alibaba-damo-academy/FunASR) 15 | 16 | from typing import Optional 17 | 18 | import six 19 | import torch 20 | import numpy as np 21 | 22 | 23 | def sequence_mask(lengths, maxlen: Optional[int] = None, 24 | dtype: torch.dtype = torch.float32, 25 | device: Optional[torch.device] = None) -> torch.Tensor: 26 | if maxlen is None: 27 | maxlen = lengths.max() 28 | row_vector = torch.arange(0, maxlen, 1).to(lengths.device) 29 | matrix = torch.unsqueeze(lengths, dim=-1) 30 | mask = row_vector < matrix 31 | mask = mask.detach() 32 | 33 | return mask.type(dtype).to(device) if device is not None else \ 34 | mask.type(dtype) 35 | 36 | 37 | def end_detect(ended_hyps, i, M=3, d_end=np.log(1 * np.exp(-10))): 38 | """End detection. 39 | 40 | described in Eq. (50) of S. Watanabe et al 41 | "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition" 42 | 43 | :param ended_hyps: 44 | :param i: 45 | :param M: 46 | :param d_end: 47 | :return: 48 | """ 49 | if len(ended_hyps) == 0: 50 | return False 51 | count = 0 52 | best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0] 53 | for m in six.moves.range(M): 54 | # get ended_hyps with their length is i - m 55 | hyp_length = i - m 56 | hyps_same_length = [x for x in ended_hyps if 57 | len(x["yseq"]) == hyp_length] 58 | if len(hyps_same_length) > 0: 59 | best_hyp_same_length = sorted( 60 | hyps_same_length, key=lambda x: x["score"], reverse=True)[0] 61 | if best_hyp_same_length["score"] - best_hyp["score"] < d_end: 62 | count += 1 63 | 64 | if count == M: 65 | return True 66 | else: 67 | return False 68 | -------------------------------------------------------------------------------- /track1_asr/wenet/squeezeformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrSupW/ICMC-ASR_Baseline/355626cef459e58a0fe7cc62af1326d67a43a0d7/track1_asr/wenet/squeezeformer/__init__.py -------------------------------------------------------------------------------- /track1_asr/wenet/squeezeformer/conv2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Conv2d Module with Valid Padding""" 16 | 17 | import torch.nn.functional as F 18 | from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional 19 | 20 | 21 | class Conv2dValid(_ConvNd): 22 | """ 23 | Conv2d operator for VALID mode padding. 24 | """ 25 | def __init__( 26 | self, 27 | in_channels: int, 28 | out_channels: int, 29 | kernel_size: _size_2_t, 30 | stride: _size_2_t = 1, 31 | padding: Union[str, _size_2_t] = 0, 32 | dilation: _size_2_t = 1, 33 | groups: int = 1, 34 | bias: bool = True, 35 | padding_mode: str = 'zeros', # TODO: refine this type 36 | device=None, 37 | dtype=None, 38 | valid_trigx: bool = False, 39 | valid_trigy: bool = False 40 | ) -> None: 41 | factory_kwargs = {'device': device, 'dtype': dtype} 42 | kernel_size_ = _pair(kernel_size) 43 | stride_ = _pair(stride) 44 | padding_ = padding if isinstance(padding, str) else _pair(padding) 45 | dilation_ = _pair(dilation) 46 | super(Conv2dValid, self).__init__( 47 | in_channels, out_channels, kernel_size_, 48 | stride_, padding_, dilation_, False, _pair(0), 49 | groups, bias, padding_mode, **factory_kwargs) 50 | self.valid_trigx = valid_trigx 51 | self.valid_trigy = valid_trigy 52 | 53 | def _conv_forward( 54 | self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): 55 | validx, validy = 0, 0 56 | if self.valid_trigx: 57 | validx = (input.size(-2) * (self.stride[-2] - 1) - 1 58 | + self.kernel_size[-2]) // 2 59 | if self.valid_trigy: 60 | validy = (input.size(-1) * (self.stride[-1] - 1) - 1 61 | + self.kernel_size[-1]) // 2 62 | return F.conv2d(input, weight, bias, self.stride, 63 | (validx, validy), self.dilation, self.groups) 64 | 65 | def forward(self, input: Tensor) -> Tensor: 66 | return self._conv_forward(input, self.weight, self.bias) 67 | -------------------------------------------------------------------------------- /track1_asr/wenet/squeezeformer/encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """SqueezeformerEncoderLayer definition.""" 16 | 17 | import torch 18 | import torch.nn as nn 19 | from typing import Optional, Tuple 20 | 21 | 22 | class SqueezeformerEncoderLayer(nn.Module): 23 | """Encoder layer module. 24 | Args: 25 | size (int): Input dimension. 26 | self_attn (torch.nn.Module): Self-attention module instance. 27 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 28 | instance can be used as the argument. 29 | feed_forward1 (torch.nn.Module): Feed-forward module instance. 30 | `PositionwiseFeedForward` instance can be used as the argument. 31 | conv_module (torch.nn.Module): Convolution module instance. 32 | `ConvlutionModule` instance can be used as the argument. 33 | feed_forward2 (torch.nn.Module): Feed-forward module instance. 34 | `PositionwiseFeedForward` instance can be used as the argument. 35 | dropout_rate (float): Dropout rate. 36 | normalize_before (bool): 37 | True: use layer_norm before each sub-block. 38 | False: use layer_norm after each sub-block. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | size: int, 44 | self_attn: torch.nn.Module, 45 | feed_forward1: Optional[nn.Module] = None, 46 | conv_module: Optional[nn.Module] = None, 47 | feed_forward2: Optional[nn.Module] = None, 48 | normalize_before: bool = False, 49 | dropout_rate: float = 0.1, 50 | concat_after: bool = False, 51 | ): 52 | super(SqueezeformerEncoderLayer, self).__init__() 53 | self.size = size 54 | self.self_attn = self_attn 55 | self.layer_norm1 = nn.LayerNorm(size) 56 | self.ffn1 = feed_forward1 57 | self.layer_norm2 = nn.LayerNorm(size) 58 | self.conv_module = conv_module 59 | self.layer_norm3 = nn.LayerNorm(size) 60 | self.ffn2 = feed_forward2 61 | self.layer_norm4 = nn.LayerNorm(size) 62 | self.normalize_before = normalize_before 63 | self.dropout = nn.Dropout(dropout_rate) 64 | self.concat_after = concat_after 65 | if concat_after: 66 | self.concat_linear = nn.Linear(size + size, size) 67 | else: 68 | self.concat_linear = nn.Identity() 69 | 70 | def forward( 71 | self, 72 | x: torch.Tensor, 73 | mask: torch.Tensor, 74 | pos_emb: torch.Tensor, 75 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 76 | att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 77 | cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 78 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 79 | # self attention module 80 | residual = x 81 | if self.normalize_before: 82 | x = self.layer_norm1(x) 83 | x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) 84 | if self.concat_after: 85 | x_concat = torch.cat((x, x_att), dim=-1) 86 | x = residual + self.concat_linear(x_concat) 87 | else: 88 | x = residual + self.dropout(x_att) 89 | if not self.normalize_before: 90 | x = self.layer_norm1(x) 91 | 92 | # ffn module 93 | residual = x 94 | if self.normalize_before: 95 | x = self.layer_norm2(x) 96 | x = self.ffn1(x) 97 | x = residual + self.dropout(x) 98 | if not self.normalize_before: 99 | x = self.layer_norm2(x) 100 | 101 | # conv module 102 | new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 103 | residual = x 104 | if self.normalize_before: 105 | x = self.layer_norm3(x) 106 | x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) 107 | x = residual + self.dropout(x) 108 | if not self.normalize_before: 109 | x = self.layer_norm3(x) 110 | 111 | # ffn module 112 | residual = x 113 | if self.normalize_before: 114 | x = self.layer_norm4(x) 115 | x = self.ffn2(x) 116 | # we do not use dropout here since it is inside feed forward function 117 | x = residual + self.dropout(x) 118 | if not self.normalize_before: 119 | x = self.layer_norm4(x) 120 | 121 | return x, mask, new_att_cache, new_cnn_cache 122 | -------------------------------------------------------------------------------- /track1_asr/wenet/squeezeformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 2022 Ximalaya Inc (Yuguang Yang) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Positionwise feed forward layer definition.""" 18 | 19 | import torch 20 | 21 | 22 | class PositionwiseFeedForward(torch.nn.Module): 23 | """Positionwise feed forward layer. 24 | 25 | FeedForward are appied on each position of the sequence. 26 | The output dim is same with the input dim. 27 | 28 | Args: 29 | idim (int): Input dimenstion. 30 | hidden_units (int): The number of hidden units. 31 | dropout_rate (float): Dropout rate. 32 | activation (torch.nn.Module): Activation function 33 | """ 34 | 35 | def __init__(self, 36 | idim: int, 37 | hidden_units: int, 38 | dropout_rate: float, 39 | activation: torch.nn.Module = torch.nn.ReLU(), 40 | adaptive_scale: bool = False, 41 | init_weights: bool = False 42 | ): 43 | """Construct a PositionwiseFeedForward object.""" 44 | super(PositionwiseFeedForward, self).__init__() 45 | self.idim = idim 46 | self.hidden_units = hidden_units 47 | self.w_1 = torch.nn.Linear(idim, hidden_units) 48 | self.activation = activation 49 | self.dropout = torch.nn.Dropout(dropout_rate) 50 | self.w_2 = torch.nn.Linear(hidden_units, idim) 51 | self.ada_scale = None 52 | self.ada_bias = None 53 | self.adaptive_scale = adaptive_scale 54 | self.ada_scale = torch.nn.Parameter( 55 | torch.ones([1, 1, idim]), requires_grad=adaptive_scale) 56 | self.ada_bias = torch.nn.Parameter( 57 | torch.zeros([1, 1, idim]), requires_grad=adaptive_scale) 58 | if init_weights: 59 | self.init_weights() 60 | 61 | def init_weights(self): 62 | ffn1_max = self.idim ** -0.5 63 | ffn2_max = self.hidden_units ** -0.5 64 | torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max) 65 | torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max) 66 | torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max) 67 | torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max) 68 | 69 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 70 | """Forward function. 71 | 72 | Args: 73 | xs: input tensor (B, L, D) 74 | Returns: 75 | output tensor, (B, L, D) 76 | """ 77 | if self.adaptive_scale: 78 | xs = self.ada_scale * xs + self.ada_bias 79 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 80 | -------------------------------------------------------------------------------- /track1_asr/wenet/transducer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrSupW/ICMC-ASR_Baseline/355626cef459e58a0fe7cc62af1326d67a43a0d7/track1_asr/wenet/transducer/__init__.py -------------------------------------------------------------------------------- /track1_asr/wenet/transducer/joint.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | from wenet.utils.common import get_activation 6 | 7 | 8 | class TransducerJoint(torch.nn.Module): 9 | 10 | def __init__(self, 11 | voca_size: int, 12 | enc_output_size: int, 13 | pred_output_size: int, 14 | join_dim: int, 15 | prejoin_linear: bool = True, 16 | postjoin_linear: bool = False, 17 | joint_mode: str = 'add', 18 | activation: str = "tanh", 19 | hat_joint: bool = False, 20 | dropout_rate: float = 0.1, 21 | hat_activation: str = 'tanh'): 22 | # TODO(Mddct): concat in future 23 | assert joint_mode in ['add'] 24 | super().__init__() 25 | 26 | self.activatoin = get_activation(activation) 27 | self.prejoin_linear = prejoin_linear 28 | self.postjoin_linear = postjoin_linear 29 | self.joint_mode = joint_mode 30 | 31 | if not self.prejoin_linear and not self.postjoin_linear: 32 | assert enc_output_size == pred_output_size == join_dim 33 | # torchscript compatibility 34 | self.enc_ffn: Optional[nn.Linear] = None 35 | self.pred_ffn: Optional[nn.Linear] = None 36 | if self.prejoin_linear: 37 | self.enc_ffn = nn.Linear(enc_output_size, join_dim) 38 | self.pred_ffn = nn.Linear(pred_output_size, join_dim) 39 | # torchscript compatibility 40 | self.post_ffn: Optional[nn.Linear] = None 41 | if self.postjoin_linear: 42 | self.post_ffn = nn.Linear(join_dim, join_dim) 43 | 44 | # NOTE: in vocab_size 45 | self.hat_joint = hat_joint 46 | self.vocab_size = voca_size 47 | self.ffn_out: Optional[torch.nn.Linear] = None 48 | if not self.hat_joint: 49 | self.ffn_out = nn.Linear(join_dim, voca_size) 50 | 51 | self.blank_pred: Optional[torch.nn.Module] = None 52 | self.token_pred: Optional[torch.nn.Module] = None 53 | if self.hat_joint: 54 | self.blank_pred = torch.nn.Sequential( 55 | torch.nn.Tanh(), torch.nn.Dropout(dropout_rate), 56 | torch.nn.Linear(join_dim, 1), torch.nn.LogSigmoid()) 57 | self.token_pred = torch.nn.Sequential( 58 | get_activation(hat_activation), torch.nn.Dropout(dropout_rate), 59 | torch.nn.Linear(join_dim, self.vocab_size - 1)) 60 | 61 | def forward(self, 62 | enc_out: torch.Tensor, 63 | pred_out: torch.Tensor, 64 | pre_project: bool = True) -> torch.Tensor: 65 | """ 66 | Args: 67 | enc_out (torch.Tensor): [B, T, E] 68 | pred_out (torch.Tensor): [B, T, P] 69 | Return: 70 | [B,T,U,V] 71 | """ 72 | if (pre_project and self.prejoin_linear and self.enc_ffn is not None 73 | and self.pred_ffn is not None): 74 | enc_out = self.enc_ffn(enc_out) # [B,T,E] -> [B,T,D] 75 | pred_out = self.pred_ffn(pred_out) 76 | if enc_out.ndim != 4: 77 | enc_out = enc_out.unsqueeze(2) # [B,T,D] -> [B,T,1,D] 78 | if pred_out.ndim != 4: 79 | pred_out = pred_out.unsqueeze(1) # [B,U,D] -> [B,1,U,D] 80 | 81 | # TODO(Mddct): concat joint 82 | _ = self.joint_mode 83 | out = enc_out + pred_out # [B,T,U,V] 84 | 85 | if self.postjoin_linear and self.post_ffn is not None: 86 | out = self.post_ffn(out) 87 | 88 | if not self.hat_joint and self.ffn_out is not None: 89 | out = self.activatoin(out) 90 | out = self.ffn_out(out) 91 | return out 92 | else: 93 | assert self.blank_pred is not None 94 | assert self.token_pred is not None 95 | blank_logp = self.blank_pred(out) # [B,T,U,1] 96 | 97 | # scale blank logp 98 | scale_logp = torch.clamp(1 - torch.exp(blank_logp), min=1e-6) 99 | label_logp = self.token_pred(out).log_softmax( 100 | dim=-1) # [B,T,U,vocab-1] 101 | # scale token logp 102 | label_logp = torch.log(scale_logp) + label_logp 103 | 104 | out = torch.cat((blank_logp, label_logp), dim=-1) # [B,T,U,vocab] 105 | return out 106 | -------------------------------------------------------------------------------- /track1_asr/wenet/transducer/search/greedy_search.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | def basic_greedy_search( 7 | model: torch.nn.Module, 8 | encoder_out: torch.Tensor, 9 | encoder_out_lens: torch.Tensor, 10 | n_steps: int = 64, 11 | ) -> List[List[int]]: 12 | # fake padding 13 | padding = torch.zeros(1, 1).to(encoder_out.device) 14 | # sos 15 | pred_input_step = torch.tensor([model.blank]).reshape(1, 1).to(encoder_out.device) 16 | cache = model.predictor.init_state(1, 17 | method="zero", 18 | device=encoder_out.device) 19 | new_cache: List[torch.Tensor] = [] 20 | t = 0 21 | hyps = [] 22 | prev_out_nblk = True 23 | pred_out_step = None 24 | per_frame_max_noblk = n_steps 25 | per_frame_noblk = 0 26 | while t < encoder_out_lens: 27 | encoder_out_step = encoder_out[:, t:t + 1, :] # [1, 1, E] 28 | if prev_out_nblk: 29 | step_outs = model.predictor.forward_step(pred_input_step, padding, 30 | cache) # [1, 1, P] 31 | pred_out_step, new_cache = step_outs[0], step_outs[1] 32 | 33 | joint_out_step = model.joint(encoder_out_step, 34 | pred_out_step) # [1,1,v] 35 | joint_out_probs = joint_out_step.log_softmax(dim=-1) 36 | 37 | joint_out_max = joint_out_probs.argmax(dim=-1).squeeze() # [] 38 | if joint_out_max != model.blank: 39 | hyps.append(joint_out_max.item()) 40 | prev_out_nblk = True 41 | per_frame_noblk = per_frame_noblk + 1 42 | pred_input_step = joint_out_max.reshape(1, 1) 43 | # state_m, state_c = clstate_out_m, state_out_c 44 | cache = new_cache 45 | 46 | if joint_out_max == model.blank or per_frame_noblk >= per_frame_max_noblk: 47 | if joint_out_max == model.blank: 48 | prev_out_nblk = False 49 | # TODO(Mddct): make t in chunk for streamming 50 | # or t should't be too lang to predict none blank 51 | t = t + 1 52 | per_frame_noblk = 0 53 | 54 | return [hyps] 55 | -------------------------------------------------------------------------------- /track1_asr/wenet/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrSupW/ICMC-ASR_Baseline/355626cef459e58a0fe7cc62af1326d67a43a0d7/track1_asr/wenet/transformer/__init__.py -------------------------------------------------------------------------------- /track1_asr/wenet/transformer/cmvn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | class GlobalCMVN(torch.nn.Module): 19 | def __init__(self, 20 | mean: torch.Tensor, 21 | istd: torch.Tensor, 22 | norm_var: bool = True): 23 | """ 24 | Args: 25 | mean (torch.Tensor): mean stats 26 | istd (torch.Tensor): inverse std, std which is 1.0 / std 27 | """ 28 | super().__init__() 29 | assert mean.shape == istd.shape 30 | self.norm_var = norm_var 31 | # The buffer can be accessed from this module using self.mean 32 | self.register_buffer("mean", mean) 33 | self.register_buffer("istd", istd) 34 | 35 | def forward(self, x: torch.Tensor): 36 | """ 37 | Args: 38 | x (torch.Tensor): (batch, max_len, feat_dim) 39 | 40 | Returns: 41 | (torch.Tensor): normalized feature 42 | """ 43 | x = x - self.mean 44 | if self.norm_var: 45 | x = x * self.istd 46 | return x 47 | -------------------------------------------------------------------------------- /track1_asr/wenet/transformer/ctc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified from ESPnet(https://github.com/espnet/espnet) 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | 20 | class CTC(torch.nn.Module): 21 | """CTC module""" 22 | def __init__( 23 | self, 24 | odim: int, 25 | encoder_output_size: int, 26 | dropout_rate: float = 0.0, 27 | reduce: bool = True, 28 | ): 29 | """ Construct CTC module 30 | Args: 31 | odim: dimension of outputs 32 | encoder_output_size: number of encoder projection units 33 | dropout_rate: dropout rate (0.0 ~ 1.0) 34 | reduce: reduce the CTC loss into a scalar 35 | """ 36 | super().__init__() 37 | eprojs = encoder_output_size 38 | self.dropout_rate = dropout_rate 39 | self.ctc_lo = torch.nn.Linear(eprojs, odim) 40 | 41 | reduction_type = "sum" if reduce else "none" 42 | self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type) 43 | 44 | def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor, 45 | ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor: 46 | """Calculate CTC loss. 47 | 48 | Args: 49 | hs_pad: batch of padded hidden state sequences (B, Tmax, D) 50 | hlens: batch of lengths of hidden state sequences (B) 51 | ys_pad: batch of padded character id sequence tensor (B, Lmax) 52 | ys_lens: batch of lengths of character sequence (B) 53 | """ 54 | # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab) 55 | ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) 56 | # ys_hat: (B, L, D) -> (L, B, D) 57 | ys_hat = ys_hat.transpose(0, 1) 58 | ys_hat = ys_hat.log_softmax(2) 59 | loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens) 60 | # Batch-size average 61 | loss = loss / ys_hat.size(1) 62 | return loss 63 | 64 | def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor: 65 | """log_softmax of frame activations 66 | 67 | Args: 68 | Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 69 | Returns: 70 | torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim) 71 | """ 72 | return F.log_softmax(self.ctc_lo(hs_pad), dim=2) 73 | 74 | def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor: 75 | """argmax of frame activations 76 | 77 | Args: 78 | torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 79 | Returns: 80 | torch.Tensor: argmax applied 2d tensor (B, Tmax) 81 | """ 82 | return torch.argmax(self.ctc_lo(hs_pad), dim=2) 83 | -------------------------------------------------------------------------------- /track1_asr/wenet/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Decoder self-attention layer definition.""" 17 | from typing import Optional, Tuple 18 | 19 | import torch 20 | from torch import nn 21 | 22 | 23 | class DecoderLayer(nn.Module): 24 | """Single decoder layer module. 25 | 26 | Args: 27 | size (int): Input dimension. 28 | self_attn (torch.nn.Module): Self-attention module instance. 29 | `MultiHeadedAttention` instance can be used as the argument. 30 | src_attn (torch.nn.Module): Inter-attention module instance. 31 | `MultiHeadedAttention` instance can be used as the argument. 32 | If `None` is passed, Inter-attention is not used, such as 33 | CIF, GPT, and other decoder only model. 34 | feed_forward (torch.nn.Module): Feed-forward module instance. 35 | `PositionwiseFeedForward` instance can be used as the argument. 36 | dropout_rate (float): Dropout rate. 37 | normalize_before (bool): 38 | True: use layer_norm before each sub-block. 39 | False: to use layer_norm after each sub-block. 40 | """ 41 | def __init__( 42 | self, 43 | size: int, 44 | self_attn: nn.Module, 45 | src_attn: Optional[nn.Module], 46 | feed_forward: nn.Module, 47 | dropout_rate: float, 48 | normalize_before: bool = True, 49 | ): 50 | """Construct an DecoderLayer object.""" 51 | super().__init__() 52 | self.size = size 53 | self.self_attn = self_attn 54 | self.src_attn = src_attn 55 | self.feed_forward = feed_forward 56 | self.norm1 = nn.LayerNorm(size, eps=1e-5) 57 | self.norm2 = nn.LayerNorm(size, eps=1e-5) 58 | self.norm3 = nn.LayerNorm(size, eps=1e-5) 59 | self.dropout = nn.Dropout(dropout_rate) 60 | self.normalize_before = normalize_before 61 | 62 | def forward( 63 | self, 64 | tgt: torch.Tensor, 65 | tgt_mask: torch.Tensor, 66 | memory: torch.Tensor, 67 | memory_mask: torch.Tensor, 68 | cache: Optional[torch.Tensor] = None 69 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 70 | """Compute decoded features. 71 | 72 | Args: 73 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). 74 | tgt_mask (torch.Tensor): Mask for input tensor 75 | (#batch, maxlen_out). 76 | memory (torch.Tensor): Encoded memory 77 | (#batch, maxlen_in, size). 78 | memory_mask (torch.Tensor): Encoded memory mask 79 | (#batch, maxlen_in). 80 | cache (torch.Tensor): cached tensors. 81 | (#batch, maxlen_out - 1, size). 82 | 83 | Returns: 84 | torch.Tensor: Output tensor (#batch, maxlen_out, size). 85 | torch.Tensor: Mask for output tensor (#batch, maxlen_out). 86 | torch.Tensor: Encoded memory (#batch, maxlen_in, size). 87 | torch.Tensor: Encoded memory mask (#batch, maxlen_in). 88 | 89 | """ 90 | residual = tgt 91 | if self.normalize_before: 92 | tgt = self.norm1(tgt) 93 | 94 | if cache is None: 95 | tgt_q = tgt 96 | tgt_q_mask = tgt_mask 97 | else: 98 | # compute only the last frame query keeping dim: max_time_out -> 1 99 | assert cache.shape == ( 100 | tgt.shape[0], 101 | tgt.shape[1] - 1, 102 | self.size, 103 | ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 104 | tgt_q = tgt[:, -1:, :] 105 | residual = residual[:, -1:, :] 106 | tgt_q_mask = tgt_mask[:, -1:, :] 107 | 108 | x = residual + self.dropout( 109 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) 110 | if not self.normalize_before: 111 | x = self.norm1(x) 112 | 113 | if self.src_attn is not None: 114 | residual = x 115 | if self.normalize_before: 116 | x = self.norm2(x) 117 | x = residual + self.dropout( 118 | self.src_attn(x, memory, memory, memory_mask)[0]) 119 | if not self.normalize_before: 120 | x = self.norm2(x) 121 | 122 | residual = x 123 | if self.normalize_before: 124 | x = self.norm3(x) 125 | x = residual + self.dropout(self.feed_forward(x)) 126 | if not self.normalize_before: 127 | x = self.norm3(x) 128 | 129 | if cache is not None: 130 | x = torch.cat([cache, x], dim=1) 131 | 132 | return x, tgt_mask, memory, memory_mask 133 | -------------------------------------------------------------------------------- /track1_asr/wenet/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Label smoothing module.""" 17 | 18 | import torch 19 | from torch import nn 20 | 21 | 22 | class LabelSmoothingLoss(nn.Module): 23 | """Label-smoothing loss. 24 | 25 | In a standard CE loss, the label's data distribution is: 26 | [0,1,2] -> 27 | [ 28 | [1.0, 0.0, 0.0], 29 | [0.0, 1.0, 0.0], 30 | [0.0, 0.0, 1.0], 31 | ] 32 | 33 | In the smoothing version CE Loss,some probabilities 34 | are taken from the true label prob (1.0) and are divided 35 | among other labels. 36 | 37 | e.g. 38 | smoothing=0.1 39 | [0,1,2] -> 40 | [ 41 | [0.9, 0.05, 0.05], 42 | [0.05, 0.9, 0.05], 43 | [0.05, 0.05, 0.9], 44 | ] 45 | 46 | Args: 47 | size (int): the number of class 48 | padding_idx (int): padding class id which will be ignored for loss 49 | smoothing (float): smoothing rate (0.0 means the conventional CE) 50 | normalize_length (bool): 51 | normalize loss by sequence length if True 52 | normalize loss by batch size if False 53 | """ 54 | def __init__(self, 55 | size: int, 56 | padding_idx: int, 57 | smoothing: float, 58 | normalize_length: bool = False): 59 | """Construct an LabelSmoothingLoss object.""" 60 | super(LabelSmoothingLoss, self).__init__() 61 | self.criterion = nn.KLDivLoss(reduction="none") 62 | self.padding_idx = padding_idx 63 | self.confidence = 1.0 - smoothing 64 | self.smoothing = smoothing 65 | self.size = size 66 | self.normalize_length = normalize_length 67 | 68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 69 | """Compute loss between x and target. 70 | 71 | The model outputs and data labels tensors are flatten to 72 | (batch*seqlen, class) shape and a mask is applied to the 73 | padding part which should not be calculated for loss. 74 | 75 | Args: 76 | x (torch.Tensor): prediction (batch, seqlen, class) 77 | target (torch.Tensor): 78 | target signal masked with self.padding_id (batch, seqlen) 79 | Returns: 80 | loss (torch.Tensor) : The KL loss, scalar float value 81 | """ 82 | assert x.size(2) == self.size 83 | batch_size = x.size(0) 84 | x = x.view(-1, self.size) 85 | target = target.view(-1) 86 | # use zeros_like instead of torch.no_grad() for true_dist, 87 | # since no_grad() can not be exported by JIT 88 | true_dist = torch.zeros_like(x) 89 | true_dist.fill_(self.smoothing / (self.size - 1)) 90 | ignore = target == self.padding_idx # (B,) 91 | total = len(target) - ignore.sum().item() 92 | target = target.masked_fill(ignore, 0) # avoid -1 index 93 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 94 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 95 | denom = total if self.normalize_length else batch_size 96 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 97 | -------------------------------------------------------------------------------- /track1_asr/wenet/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Positionwise feed forward layer definition.""" 17 | 18 | import torch 19 | 20 | 21 | class PositionwiseFeedForward(torch.nn.Module): 22 | """Positionwise feed forward layer. 23 | 24 | FeedForward are appied on each position of the sequence. 25 | The output dim is same with the input dim. 26 | 27 | Args: 28 | idim (int): Input dimenstion. 29 | hidden_units (int): The number of hidden units. 30 | dropout_rate (float): Dropout rate. 31 | activation (torch.nn.Module): Activation function 32 | """ 33 | def __init__(self, 34 | idim: int, 35 | hidden_units: int, 36 | dropout_rate: float, 37 | activation: torch.nn.Module = torch.nn.ReLU()): 38 | """Construct a PositionwiseFeedForward object.""" 39 | super(PositionwiseFeedForward, self).__init__() 40 | self.w_1 = torch.nn.Linear(idim, hidden_units) 41 | self.activation = activation 42 | self.dropout = torch.nn.Dropout(dropout_rate) 43 | self.w_2 = torch.nn.Linear(hidden_units, idim) 44 | 45 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 46 | """Forward function. 47 | 48 | Args: 49 | xs: input tensor (B, L, D) 50 | Returns: 51 | output tensor, (B, L, D) 52 | """ 53 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 54 | -------------------------------------------------------------------------------- /track1_asr/wenet/transformer/swish.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Swish() activation function for Conformer.""" 18 | 19 | import torch 20 | 21 | 22 | class Swish(torch.nn.Module): 23 | """Construct an Swish object.""" 24 | def forward(self, x: torch.Tensor) -> torch.Tensor: 25 | """Return Swish activation function.""" 26 | return x * torch.sigmoid(x) 27 | -------------------------------------------------------------------------------- /track1_asr/wenet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrSupW/ICMC-ASR_Baseline/355626cef459e58a0fe7cc62af1326d67a43a0d7/track1_asr/wenet/utils/__init__.py -------------------------------------------------------------------------------- /track1_asr/wenet/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import os 17 | import re 18 | 19 | import yaml 20 | import torch 21 | from collections import OrderedDict 22 | 23 | import datetime 24 | 25 | 26 | def load_checkpoint(model: torch.nn.Module, path: str, local_rank: int = 0) -> dict: 27 | if torch.cuda.is_available(): 28 | logging.info('Checkpoint: loading from checkpoint %s for GPU' % path) 29 | checkpoint = torch.load(path, map_location=lambda storage, loc: storage.cuda(local_rank)) 30 | else: 31 | logging.info('Checkpoint: loading from checkpoint %s for CPU' % path) 32 | checkpoint = torch.load(path, map_location='cpu') 33 | model.load_state_dict(checkpoint, strict=False) 34 | info_path = re.sub('.pt$', '.yaml', path) 35 | configs = {} 36 | if os.path.exists(info_path): 37 | with open(info_path, 'r') as fin: 38 | configs = yaml.load(fin, Loader=yaml.FullLoader) 39 | return configs 40 | 41 | 42 | def save_checkpoint(model: torch.nn.Module, path: str, infos=None): 43 | ''' 44 | Args: 45 | infos (dict or None): any info you want to save. 46 | ''' 47 | logging.info('Checkpoint: save to checkpoint %s' % path) 48 | if isinstance(model, torch.nn.DataParallel): 49 | state_dict = model.module.state_dict() 50 | elif isinstance(model, torch.nn.parallel.DistributedDataParallel): 51 | state_dict = model.module.state_dict() 52 | else: 53 | state_dict = model.state_dict() 54 | torch.save(state_dict, path) 55 | info_path = re.sub('.pt$', '.yaml', path) 56 | if infos is None: 57 | infos = {} 58 | infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') 59 | with open(info_path, 'w') as fout: 60 | data = yaml.dump(infos) 61 | fout.write(data) 62 | 63 | 64 | def filter_modules(model_state_dict, modules): 65 | new_mods = [] 66 | incorrect_mods = [] 67 | mods_model = model_state_dict.keys() 68 | for mod in modules: 69 | if any(key.startswith(mod) for key in mods_model): 70 | new_mods += [mod] 71 | else: 72 | incorrect_mods += [mod] 73 | if incorrect_mods: 74 | logging.warning( 75 | "module(s) %s don't match or (partially match) " 76 | "available modules in model.", 77 | incorrect_mods, 78 | ) 79 | logging.warning("for information, the existing modules in model are:") 80 | logging.warning("%s", mods_model) 81 | 82 | return new_mods 83 | 84 | 85 | def load_trained_modules(model: torch.nn.Module, args: None): 86 | # Load encoder modules with pre-trained model(s). 87 | enc_model_path = args.enc_init 88 | enc_modules = args.enc_init_mods 89 | main_state_dict = model.state_dict() 90 | logging.warning("model(s) found for pre-initialization") 91 | if os.path.isfile(enc_model_path): 92 | logging.info('Checkpoint: loading from checkpoint %s for CPU' % 93 | enc_model_path) 94 | model_state_dict = torch.load(enc_model_path, map_location='cpu') 95 | modules = filter_modules(model_state_dict, enc_modules) 96 | partial_state_dict = OrderedDict() 97 | for key, value in model_state_dict.items(): 98 | if any(key.startswith(m) for m in modules): 99 | partial_state_dict[key] = value 100 | main_state_dict.update(partial_state_dict) 101 | else: 102 | logging.warning("model was not found : %s", enc_model_path) 103 | 104 | model.load_state_dict(main_state_dict) 105 | configs = {} 106 | return configs 107 | -------------------------------------------------------------------------------- /track1_asr/wenet/utils/cmvn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import math 17 | 18 | import numpy as np 19 | 20 | 21 | def _load_json_cmvn(json_cmvn_file): 22 | """ Load the json format cmvn stats file and calculate cmvn 23 | 24 | Args: 25 | json_cmvn_file: cmvn stats file in json format 26 | 27 | Returns: 28 | a numpy array of [means, vars] 29 | """ 30 | with open(json_cmvn_file) as f: 31 | cmvn_stats = json.load(f) 32 | 33 | means = cmvn_stats['mean_stat'] 34 | variance = cmvn_stats['var_stat'] 35 | count = cmvn_stats['frame_num'] 36 | for i in range(len(means)): 37 | means[i] /= count 38 | variance[i] = variance[i] / count - means[i] * means[i] 39 | if variance[i] < 1.0e-20: 40 | variance[i] = 1.0e-20 41 | variance[i] = 1.0 / math.sqrt(variance[i]) 42 | cmvn = np.array([means, variance]) 43 | return cmvn 44 | 45 | 46 | def _load_kaldi_cmvn(kaldi_cmvn_file): 47 | """ Load the kaldi format cmvn stats file and calculate cmvn 48 | 49 | Args: 50 | kaldi_cmvn_file: kaldi text style global cmvn file, which 51 | is generated by: 52 | compute-cmvn-stats --binary=false scp:feats.scp global_cmvn 53 | 54 | Returns: 55 | a numpy array of [means, vars] 56 | """ 57 | means = [] 58 | variance = [] 59 | with open(kaldi_cmvn_file, 'r') as fid: 60 | # kaldi binary file start with '\0B' 61 | if fid.read(2) == '\0B': 62 | logging.error('kaldi cmvn binary file is not supported, please ' 63 | 'recompute it by: compute-cmvn-stats --binary=false ' 64 | ' scp:feats.scp global_cmvn') 65 | sys.exit(1) 66 | fid.seek(0) 67 | arr = fid.read().split() 68 | assert (arr[0] == '[') 69 | assert (arr[-2] == '0') 70 | assert (arr[-1] == ']') 71 | feat_dim = int((len(arr) - 2 - 2) / 2) 72 | for i in range(1, feat_dim + 1): 73 | means.append(float(arr[i])) 74 | count = float(arr[feat_dim + 1]) 75 | for i in range(feat_dim + 2, 2 * feat_dim + 2): 76 | variance.append(float(arr[i])) 77 | 78 | for i in range(len(means)): 79 | means[i] /= count 80 | variance[i] = variance[i] / count - means[i] * means[i] 81 | if variance[i] < 1.0e-20: 82 | variance[i] = 1.0e-20 83 | variance[i] = 1.0 / math.sqrt(variance[i]) 84 | cmvn = np.array([means, variance]) 85 | return cmvn 86 | 87 | 88 | def load_cmvn(cmvn_file, is_json): 89 | if is_json: 90 | cmvn = _load_json_cmvn(cmvn_file) 91 | else: 92 | cmvn = _load_kaldi_cmvn(cmvn_file) 93 | return cmvn[0], cmvn[1] 94 | -------------------------------------------------------------------------------- /track1_asr/wenet/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Shaoshang Qi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import copy 17 | 18 | def override_config(configs, override_list): 19 | new_configs = copy.deepcopy(configs) 20 | for item in override_list: 21 | arr = item.split() 22 | if len(arr) != 2: 23 | print(f"the overrive {item} format not correct, skip it") 24 | continue 25 | keys = arr[0].split('.') 26 | s_configs = new_configs 27 | for i, key in enumerate(keys): 28 | if key not in s_configs: 29 | print(f"the overrive {item} format not correct, skip it") 30 | if i == len(keys) - 1: 31 | param_type = type(s_configs[key]) 32 | if param_type != bool: 33 | s_configs[key] = param_type(arr[1]) 34 | else: 35 | s_configs[key] = arr[1] in ['true', 'True'] 36 | print(f"override {arr[0]} with {arr[1]}") 37 | else: 38 | s_configs = s_configs[key] 39 | return new_configs 40 | -------------------------------------------------------------------------------- /track1_asr/wenet/utils/context_graph.py: -------------------------------------------------------------------------------- 1 | from wenet.dataset.processor import __tokenize_by_bpe_model 2 | from typing import Dict, List 3 | 4 | 5 | def tokenize(context_list_path, symbol_table, bpe_model=None): 6 | """ Read biasing list from the biasing list address, tokenize and convert it 7 | into token id 8 | """ 9 | if bpe_model is not None: 10 | import sentencepiece as spm 11 | sp = spm.SentencePieceProcessor() 12 | sp.load(bpe_model) 13 | else: 14 | sp = None 15 | 16 | with open(context_list_path, "r") as fin: 17 | context_txts = fin.readlines() 18 | 19 | context_list = [] 20 | for context_txt in context_txts: 21 | context_txt = context_txt.strip() 22 | 23 | labels = [] 24 | tokens = [] 25 | if bpe_model is not None: 26 | tokens = __tokenize_by_bpe_model(sp, context_txt) 27 | else: 28 | for ch in context_txt: 29 | if ch == ' ': 30 | ch = "▁" 31 | tokens.append(ch) 32 | for ch in tokens: 33 | if ch in symbol_table: 34 | labels.append(symbol_table[ch]) 35 | elif '' in symbol_table: 36 | labels.append(symbol_table['']) 37 | context_list.append(labels) 38 | return context_list 39 | 40 | 41 | class ContextGraph: 42 | """ Context decoding graph, constructing graph using dict instead of WFST 43 | Args: 44 | context_list_path(str): context list path 45 | bpe_model(str): model for english bpe part 46 | context_score(float): context score for each token 47 | """ 48 | def __init__(self, 49 | context_list_path: str, 50 | symbol_table: Dict[str, int], 51 | bpe_model: str = None, 52 | context_score: float = 6): 53 | self.context_score = context_score 54 | self.context_list = tokenize(context_list_path, symbol_table, 55 | bpe_model) 56 | self.graph = {0: {}} 57 | self.graph_size = 0 58 | self.state2token = {} 59 | self.back_score = {0: 0.0} 60 | self.build_graph(self.context_list) 61 | 62 | def build_graph(self, context_list: List[List[int]]): 63 | """ Constructing the context decoding graph, add arcs with negative 64 | scores returning to the starting state for each non-terminal tokens 65 | of hotwords, and add arcs with scores of 0 returning to the starting 66 | state for terminal tokens. 67 | """ 68 | self.graph = {0: {}} 69 | self.graph_size = 0 70 | self.state2token = {} 71 | self.back_score = {0: 0.0} 72 | for context_token in context_list: 73 | now_state = 0 74 | for i in range(len(context_token)): 75 | if context_token[i] in self.graph[now_state]: 76 | now_state = self.graph[now_state][context_token[i]] 77 | if i == len(context_token) - 1: 78 | self.back_score[now_state] = 0 79 | else: 80 | self.graph_size += 1 81 | self.graph[self.graph_size] = {} 82 | self.graph[now_state][context_token[i]] = self.graph_size 83 | now_state = self.graph_size 84 | if i != len(context_token) - 1: 85 | self.back_score[now_state] = -(i + 86 | 1) * self.context_score 87 | else: 88 | self.back_score[now_state] = 0 89 | self.state2token[now_state] = context_token[i] 90 | 91 | def find_next_state(self, now_state: int, token: int): 92 | """ Search for an arc with the input being a token from the current state, 93 | returning the score on the arc and the state it points to. If there is 94 | no match, return to the starting state and perform an additional search 95 | from the starting state to avoid token consumption due to mismatches. 96 | """ 97 | if token in self.graph[now_state]: 98 | return self.graph[now_state][token], self.context_score 99 | back_score = self.back_score[now_state] 100 | now_state = 0 101 | if token in self.graph[now_state]: 102 | return self.graph[now_state][ 103 | token], back_score + self.context_score 104 | return 0, back_score 105 | -------------------------------------------------------------------------------- /track1_asr/wenet/utils/ctc_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch 17 | 18 | def insert_blank(label, blank_id=0): 19 | """Insert blank token between every two label token.""" 20 | label = np.expand_dims(label, 1) 21 | blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id 22 | label = np.concatenate([blanks, label], axis=1) 23 | label = label.reshape(-1) 24 | label = np.append(label, label[0]) 25 | return label 26 | 27 | def forced_align(ctc_probs: torch.Tensor, 28 | y: torch.Tensor, 29 | blank_id=0) -> list: 30 | """ctc forced alignment. 31 | 32 | Args: 33 | torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D) 34 | torch.Tensor y: id sequence tensor 1d tensor (L) 35 | int blank_id: blank symbol index 36 | Returns: 37 | torch.Tensor: alignment result 38 | """ 39 | ctc_probs = ctc_probs.cpu() 40 | y = y.cpu() 41 | y_insert_blank = insert_blank(y, blank_id) 42 | 43 | log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank))) 44 | log_alpha = log_alpha - float('inf') # log of zero 45 | state_path = (torch.zeros( 46 | (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1 47 | ) # state path 48 | 49 | # init start state 50 | log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] 51 | log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] 52 | 53 | for t in range(1, ctc_probs.size(0)): 54 | for s in range(len(y_insert_blank)): 55 | if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ 56 | s] == y_insert_blank[s - 2]: 57 | candidates = torch.tensor( 58 | [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]]) 59 | prev_state = [s, s - 1] 60 | else: 61 | candidates = torch.tensor([ 62 | log_alpha[t - 1, s], 63 | log_alpha[t - 1, s - 1], 64 | log_alpha[t - 1, s - 2], 65 | ]) 66 | prev_state = [s, s - 1, s - 2] 67 | log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]] 68 | state_path[t, s] = prev_state[torch.argmax(candidates)] 69 | 70 | state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16) 71 | 72 | candidates = torch.tensor([ 73 | log_alpha[-1, len(y_insert_blank) - 1], 74 | log_alpha[-1, len(y_insert_blank) - 2] 75 | ]) 76 | final_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] 77 | state_seq[-1] = final_state[torch.argmax(candidates)] 78 | for t in range(ctc_probs.size(0) - 2, -1, -1): 79 | state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] 80 | 81 | output_alignment = [] 82 | for t in range(0, ctc_probs.size(0)): 83 | output_alignment.append(y_insert_blank[state_seq[t, 0]]) 84 | 85 | return output_alignment 86 | -------------------------------------------------------------------------------- /track1_asr/wenet/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | 17 | 18 | def read_lists(list_file): 19 | lists = [] 20 | with open(list_file, 'r', encoding='utf8') as fin: 21 | for line in fin: 22 | lists.append(line.strip()) 23 | return lists 24 | 25 | 26 | def read_non_lang_symbols(non_lang_sym_path): 27 | """read non-linguistic symbol from file. 28 | 29 | The file format is like below: 30 | 31 | {NOISE}\n 32 | {BRK}\n 33 | ... 34 | 35 | 36 | Args: 37 | non_lang_sym_path: non-linguistic symbol file path, None means no any 38 | syms. 39 | 40 | """ 41 | if non_lang_sym_path is None: 42 | return None 43 | else: 44 | syms = read_lists(non_lang_sym_path) 45 | non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") 46 | for sym in syms: 47 | if non_lang_syms_pattern.fullmatch(sym) is None: 48 | class BadSymbolFormat(Exception): 49 | pass 50 | raise BadSymbolFormat( 51 | "Non-linguistic symbols should be " 52 | "formatted in {xxx}//[xxx], consider" 53 | " modify '%s' to meet the requirment. " 54 | "More details can be found in discussions here : " 55 | "https://github.com/wenet-e2e/wenet/pull/819" % (sym)) 56 | return syms 57 | 58 | 59 | def read_symbol_table(symbol_table_file): 60 | symbol_table = {} 61 | with open(symbol_table_file, 'r', encoding='utf8') as fin: 62 | for line in fin: 63 | arr = line.strip().split() 64 | assert len(arr) == 2 65 | symbol_table[arr[0]] = int(arr[1]) 66 | return symbol_table 67 | -------------------------------------------------------------------------------- /track2_asdr/data/dict/lang_char.txt: -------------------------------------------------------------------------------- 1 | ../../../track1_asr/data/dict/lang_char.txt -------------------------------------------------------------------------------- /track2_asdr/data/train_aec_iva_near/global_cmvn: -------------------------------------------------------------------------------- 1 | ../../../track1_asr/data/train_aec_iva_near/global_cmvn -------------------------------------------------------------------------------- /track2_asdr/exp/baseline_ebranchformer/avg_10.pt: -------------------------------------------------------------------------------- 1 | ../../../track1_asr/exp/baseline_ebranchformer/avg_10.pt -------------------------------------------------------------------------------- /track2_asdr/exp/baseline_ebranchformer/train.yaml: -------------------------------------------------------------------------------- 1 | ../../../track1_asr/exp/baseline_ebranchformer/train.yaml -------------------------------------------------------------------------------- /track2_asdr/local/generate_submission_file.py: -------------------------------------------------------------------------------- 1 | ../../track1_asr/local/generate_submission_file.py -------------------------------------------------------------------------------- /track2_asdr/local/merge_session_rttms.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import numpy as np 5 | 6 | 7 | def main(args): 8 | rttm = {} 9 | rttm2dur = {} 10 | 11 | if os.path.exists(args.save_path): 12 | os.system(f"rm {args.save_path}") 13 | 14 | for rttm_file in glob.glob(os.path.join(args.segments_path, "*.rttm"), recursive=False): 15 | session = rttm_file.split("/")[-1].split(".")[0] 16 | with open(rttm_file, "r") as f: 17 | for line in f.readlines(): 18 | start = line.strip().split()[3] 19 | dur = line.strip().split()[4] 20 | start, dur = float(start), float(dur) 21 | rttm.setdefault(session, []).append((round(start, 2), round(start + dur, 2))) 22 | if session not in rttm2dur: 23 | rttm2dur[session] = 0. 24 | else: 25 | rttm2dur[session] += dur 26 | 27 | for utt in rttm.keys(): 28 | rttm[utt].sort() 29 | 30 | result = {} 31 | result_save_index = {} 32 | 33 | for session in sorted(rttm.keys()): 34 | segs = rttm[session] 35 | real_session, channel = session.split("_") 36 | if real_session not in result: 37 | result[real_session] = [] 38 | durs = [] 39 | channel = channel[:3] + "1" + channel[4:] 40 | result[real_session].append(rttm[real_session + "_" + channel]) 41 | durs.append(rttm2dur[real_session + "_" + channel]) 42 | 43 | channel = channel[:3] + "2" + channel[4:] 44 | result[real_session].append(rttm[real_session + "_" + channel]) 45 | durs.append(rttm2dur[real_session + "_" + channel]) 46 | 47 | channel = channel[:3] + "3" + channel[4:] 48 | result[real_session].append(rttm[real_session + "_" + channel]) 49 | durs.append(rttm2dur[real_session + "_" + channel]) 50 | 51 | channel = channel[:3] + "4" + channel[4:] 52 | result[real_session].append(rttm[real_session + "_" + channel]) 53 | durs.append(rttm2dur[real_session + "_" + channel]) 54 | 55 | # only two speakers in each session. so we only need to save the two longest duration speakers 56 | result_save_index[real_session] = np.argsort(durs)[-2:] 57 | 58 | with open(args.save_path, "w") as wf: 59 | for session, segs in result.items(): 60 | for i in range(len(segs)): 61 | if i in result_save_index[session]: 62 | for seg in segs[i]: 63 | print(f"SPEAKER {session} 1 {seg[0]} {round(seg[1] - seg[0], 2)} {i + 1} ", 64 | file=wf) 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument("--segments_path", type=str, required=True) 70 | parser.add_argument("--save_path", type=str, required=True) 71 | main(parser.parse_args()) 72 | -------------------------------------------------------------------------------- /track2_asdr/local/pyannote_vad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | from pyannote.audio import Model 7 | from pyannote.audio.pipelines import VoiceActivityDetection 8 | 9 | 10 | def main(args): 11 | if os.path.exists(args.save_path): 12 | os.system(f"rm -rf {args.save_path}") 13 | os.makedirs(args.save_path) 14 | 15 | wav_scp = {} 16 | with open(args.wav_scp, "r") as f: 17 | for line in f.readlines(): 18 | utt, path = line.strip().split() 19 | wav_scp[utt] = path 20 | 21 | model = Model.from_pretrained("pyannote/segmentation", use_auth_token=args.token).cuda() 22 | pipeline = VoiceActivityDetection(segmentation=model) 23 | 24 | HYPER_PARAMETERS = { 25 | # onset/offset activation thresholds 26 | "onset": args.threshold, "offset": 0.5, 27 | # remove speech regions shorter than that many seconds. 28 | "min_duration_on": 0.1, 29 | # fill non-speech regions shorter than that many seconds. 30 | "min_duration_off": 0.3 31 | } 32 | 33 | pipeline.instantiate(HYPER_PARAMETERS) 34 | for utt, path in tqdm(wav_scp.items()): 35 | vad = pipeline(path) 36 | with open(os.path.join(args.save_path, f"{utt}.rttm"), "w") as rttm: 37 | vad.write_rttm(rttm) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--wav_scp', type=str, required=True) 43 | parser.add_argument('--save_path', type=str, required=True) 44 | parser.add_argument('--threshold', type=float, required=True) 45 | parser.add_argument('--token', type=str, required=True) 46 | 47 | main(parser.parse_args()) 48 | -------------------------------------------------------------------------------- /track2_asdr/local/run_pyannote_vad.sh: -------------------------------------------------------------------------------- 1 | . ./path.sh || exit 1 2 | 3 | stage=0 4 | stop_stage=2 5 | 6 | . tools/parse_options.sh 7 | 8 | enhanced_data_root=$1 9 | dataset=$2 10 | threshold=$3 11 | HUGGINGFACE_ACCESS_TOKEN=$4 12 | dataset_prefix=$(echo "$dataset" | cut -d '_' -f 1) 13 | if [ "${dataset_prefix}" == "eval" ]; then 14 | dataset_prefix=$(echo "$dataset" | cut -d _ -f 1-2) 15 | fi 16 | 17 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 18 | echo "[local/run_vad.sh] stage 0 generate wav.scp file for ${dataset}_${threshold} set" 19 | mkdir -p data/vad/${dataset}_${threshold} 20 | ls ${enhanced_data_root}/${dataset_prefix}/*/DX0[1-4]C01.wav | awk -F/ '{print $(NF-1)"_"substr($NF, 1, length($NF)-4), $0}' \ 21 | > data/vad/${dataset}_${threshold}/wav.scp 22 | fi 23 | 24 | 25 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 26 | echo "[local/run_vad.sh] stage 1 run vad for ${dataset}_${threshold} set" 27 | python3 local/pyannote_vad.py \ 28 | --wav_scp data/vad/${dataset}_${threshold}/wav.scp \ 29 | --save_path exp/pyannote_vad/${dataset}_${threshold} \ 30 | --threshold $threshold \ 31 | --token "$HUGGINGFACE_ACCESS_TOKEN" 32 | fi 33 | 34 | 35 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 36 | echo "[local/run_vad.sh] stage 2 merge session rttms for ${dataset}_${threshold} set" 37 | python3 local/merge_session_rttms.py \ 38 | --segments_path exp/pyannote_vad/${dataset}_${threshold} \ 39 | --save_path exp/pyannote_vad/${dataset}_${threshold}/pyannote_vad.rttm 40 | fi 41 | -------------------------------------------------------------------------------- /track2_asdr/local/segment_wavs_by_rttm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import glob 5 | import math 6 | import argparse 7 | import textgrid 8 | import traceback 9 | import multiprocessing 10 | 11 | from tqdm import tqdm 12 | from functools import partial 13 | from pydub import AudioSegment 14 | from concurrent.futures import ProcessPoolExecutor 15 | 16 | 17 | def rttm2wav(line): 18 | line = line.strip() 19 | _, session, _, start, duration, _, _, seat, _, _ = line.split() 20 | wav = AudioSegment.from_wav(f"{enhanced_data_root}/{dataset.split('_aec_iva')[0]}/{session}/DX0{seat}C01.wav") 21 | start = int(float(start) * 1000) 22 | end = start + int(float(duration) * 1000) 23 | export_wav_path = f"{output_dir}/P000{seat}_{session}_DX0{seat}C01_{start:0>6}-{end:0>6}.wav" 24 | wav[start:end].export(export_wav_path, format="wav") 25 | 26 | 27 | def multiThread_use_ProcessPoolExecutor_dicarg(scp, numthread, func, args): 28 | executor = ProcessPoolExecutor(max_workers=numthread) 29 | results = [] 30 | for item in scp: 31 | results.append(executor.submit(partial(func, item, **args))) 32 | return [result.result() for result in tqdm(results)] 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--enhanced_data_root', type=str, help='enhanced data root') 38 | parser.add_argument('--dataset', type=str, help='dataset') 39 | parser.add_argument('--nj', type=int, default=32, help='number of jobs') 40 | parser.add_argument('rttm_file', help='rttm file path') 41 | parser.add_argument('output_dir', help='output dir of segment wavs') 42 | args = parser.parse_args() 43 | enhanced_data_root, dataset, nj, rttm_file, output_dir = \ 44 | args.enhanced_data_root, args.dataset, args.nj, args.rttm_file, args.output_dir 45 | os.makedirs(output_dir, exist_ok=True) 46 | lines = open(rttm_file, 'r').readlines() 47 | # segment_wavs_by_rttm 48 | multiThread_use_ProcessPoolExecutor_dicarg(lines, nj, rttm2wav, {}) 49 | # write scp and blank text 50 | output_dir_dir = '/'.join(output_dir.strip('/').split('/')[:-1]) 51 | with open(f"{output_dir_dir}/wav.scp", 'w') as ws: 52 | with open(f"{output_dir_dir}/text", 'w') as tt: 53 | for wav in glob.glob(f"{output_dir}/*.wav"): 54 | ws.write(f"{os.path.basename(wav).split('.')[0]} {wav}\n") 55 | tt.write(f"{os.path.basename(wav).split('.')[0]} 空\n") 56 | -------------------------------------------------------------------------------- /track2_asdr/path.sh: -------------------------------------------------------------------------------- 1 | ../track1_asr/path.sh -------------------------------------------------------------------------------- /track2_asdr/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 4 | . ./path.sh || exit 1; 5 | 6 | # Use this to control how many gpu you use, It's 1-gpu training if you specify 7 | # just 1gpu, otherwise it's is multiple gpu training based on DDP in pytorch 8 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 9 | # for debug purpose, please set it to 1 otherwise, set it to 0 10 | export CUDA_LAUNCH_BLOCKING=0 11 | 12 | stage=0 # start from 0 if you need to start from data preparation 13 | stop_stage=5 14 | 15 | # Create your access token at https://huggingface.co/settings/tokens 16 | HUGGINGFACE_ACCESS_TOKEN="YOUR_HUGGINGFACE_ACCESS_TOKEN" 17 | ################################################ 18 | # The icmc-asr dataset location, please change this to your own path!!! 19 | # Make sure of using absolute path. DO-NOT-USE relatvie path! 20 | # data dir for IVA + AEC enhanced audio 21 | data_enhanced=/home/work_nfs4_ssd/hwang/data/ICMC-ASR_ENHANCED 22 | ################################################ 23 | 24 | nj=64 25 | dict=data/dict/lang_char.txt 26 | 27 | # Pyannote VAD activation threshold 28 | threshold=0.95 29 | # data_type can be `raw` or `shard`. Typically, raw is used for small dataset, 30 | # `shard` is used for large dataset which is over 1k hours, and `shard` is 31 | # faster on reading data and training. 32 | data_type=raw 33 | num_utts_per_shard=1000 34 | 35 | test_set="eval_track2_aec_iva" 36 | dir=exp/baseline_ebranchformer 37 | 38 | # use average_checkpoint will get better result 39 | decode_checkpoint=$dir/avg_10.pt 40 | decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring" 41 | 42 | . tools/parse_options.sh || exit 1; 43 | 44 | 45 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 46 | echo "stage 0: Do VAD for enhanced audio data" 47 | for x in ${test_set} ; do 48 | # enhanced_data_root dataset threshold HUGGINGFACE_ACCESS_TOKEN 49 | local/run_pyannote_vad.sh $data_enhanced $x $threshold $HUGGINGFACE_ACCESS_TOKEN 50 | done 51 | fi 52 | 53 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 54 | echo "stage 1: Segment audio data based on VAD results" 55 | for x in ${test_set} ; do 56 | python3 local/segment_wavs_by_rttm.py --enhanced_data_root $data_enhanced --dataset $x --nj $nj \ 57 | exp/pyannote_vad/${x}_${threshold}/pyannote_vad.rttm data/${x}_${threshold}/pyannote_vad_wavs/ 58 | done 59 | fi 60 | 61 | 62 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 63 | echo "stage 2: Prepare data in WeNet required format" 64 | for x in ${test_set}; do 65 | if [ $data_type == "shard" ]; then 66 | tools/make_shard_list.py --num_utts_per_shard $num_utts_per_shard \ 67 | --num_threads ${nj} data/${x}_${threshold}/wav.scp data/$x/text \ 68 | $(realpath data/${x}_${threshold}/shards) data/${x}_${threshold}/data.list 69 | else 70 | tools/make_raw_list.py data/${x}_${threshold}/wav.scp data/${x}_${threshold}/text \ 71 | data/${x}_${threshold}/data.list 72 | fi 73 | done 74 | fi 75 | 76 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 77 | # Test model, please specify the model you want to test by --checkpoint 78 | echo "stage 3: Test model testset ${test_set}_${threshold}" 79 | mkdir -p $dir 80 | 81 | if [ ! -f $decode_checkpoint ]; then 82 | echo "error: $decode_checkpoint does not exist." 83 | echo "please copy the trained model from track1 to $decode_checkpoint" 84 | exit 1 85 | fi 86 | 87 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') 88 | # Please specify decoding_chunk_size for unified streaming and 89 | # non-streaming model. The default value is -1, which is full chunk 90 | # for non-streaming inference. 91 | decoding_chunk_size= 92 | ctc_weight=0.3 93 | idx=0 94 | for mode in ${decode_modes}; do 95 | { 96 | test_dir=$dir/${test_set}_${mode}_${threshold} 97 | mkdir -p $test_dir 98 | gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$idx+1]) 99 | python3 wenet/bin/recognize.py --gpu $gpu_id \ 100 | --mode $mode \ 101 | --config $dir/train.yaml \ 102 | --data_type $data_type \ 103 | --test_data data/${test_set}_${threshold}/data.list \ 104 | --checkpoint $decode_checkpoint \ 105 | --beam_size 10 \ 106 | --batch_size 1 \ 107 | --penalty 0.0 \ 108 | --dict $dict \ 109 | --ctc_weight $ctc_weight \ 110 | --result_file $test_dir/text \ 111 | ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} 112 | } & 113 | ((idx+=1)) 114 | if [ $idx -eq $num_gpus ]; then 115 | idx=0 116 | fi 117 | done 118 | wait 119 | fi 120 | 121 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 122 | echo "stage 4: Generate submission file for track2 leaderboard" 123 | for mode in ${decode_modes}; do 124 | { 125 | test_dir=$dir/${test_set}_${mode}_${threshold} 126 | python3 local/generate_submission_file.py "$test_dir" 127 | } 128 | done 129 | fi 130 | 131 | # only for dev_aec_iva set 132 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then 133 | echo "stage 5: Compute cpCER of dev_aec_iva set" 134 | for mode in ${decode_modes}; do 135 | { 136 | echo "compute cpCER for ${test_set}_${mode}_${threshold}" 137 | test_dir=$dir/${test_set}_${mode}_${threshold} 138 | python3 local/compute_cpcer.py --hyp-path $test_dir/submission.txt --ref-path data/dev_ref.txt 139 | echo "" 140 | } 141 | done 142 | fi 143 | -------------------------------------------------------------------------------- /track2_asdr/tools: -------------------------------------------------------------------------------- 1 | ../track1_asr/tools -------------------------------------------------------------------------------- /track2_asdr/wenet: -------------------------------------------------------------------------------- 1 | ../track1_asr/wenet --------------------------------------------------------------------------------