├── an4 ├── test.sh └── local │ ├── test_data.sh │ └── test_data_prep.py ├── tedlium2 ├── local │ ├── join_suffix.py │ └── prepare_test_data.sh └── test.sh ├── tedlium3 ├── local │ ├── join_suffix.py │ └── prepare_test_data.sh └── test.sh ├── timit ├── test.sh └── local │ ├── timit_format_test_data.sh │ ├── test_data.sh │ ├── timit_norm_trans.pl │ └── timit_test_data_prep.sh ├── README.md ├── utils └── score_sclite.sh ├── espnet2 ├── asr │ └── decoder │ │ └── rnn_decoder.py └── bin │ └── asr_test.py ├── espnet ├── bin │ └── asr_test.py ├── utils │ ├── gini_utils.py │ └── gini_guide.py └── nets │ └── test_beam_search.py └── TEMPLATE └── asr1 └── asr_test.sh /an4/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Set bash to 'debug' mode, it will exit on : 3 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', 4 | set -e 5 | set -u 6 | set -o pipefail 7 | 8 | test_sets="test-room" 9 | orig_flag=false 10 | need_decode=true 11 | selected_num=130 12 | stage=1 13 | stop_stage=4 14 | dataset="an4" 15 | . utils/parse_options.sh 16 | 17 | ./asr_test.sh \ 18 | --use_lm false \ 19 | --test_sets ${test_sets} \ 20 | --local_data_opts ${test_sets}\ 21 | --orig_flag ${orig_flag} \ 22 | --need_decode ${need_decode} \ 23 | --stage ${stage} \ 24 | --stop_stage ${stop_stage} \ 25 | --token_type bpe \ 26 | --asr_exp "exp/asr_train_raw_en_bpe30" \ 27 | --dataset ${dataset} \ 28 | --selected_num ${selected_num} 29 | -------------------------------------------------------------------------------- /tedlium2/local/join_suffix.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright 2014 Nickolay V. Shmyrev 4 | # 2016 Johns Hopkins University (author: Daniel Povey) 5 | # Apache 2.0 6 | 7 | 8 | import sys 9 | 10 | # This script joins together pairs of split-up words like "you 're" -> "you're". 11 | # The TEDLIUM transcripts are normalized in a way that's not traditional for 12 | # speech recognition. 13 | 14 | prev_line = "" 15 | for line in sys.stdin: 16 | if line == prev_line: 17 | continue 18 | items = line.split() 19 | new_items = [] 20 | i = 0 21 | while i < len(items): 22 | if i < len(items) - 1 and items[i + 1][0] == "'": 23 | new_items.append(items[i] + items[i + 1]) 24 | i = i + 1 25 | else: 26 | new_items.append(items[i]) 27 | i = i + 1 28 | print(" ".join(new_items)) 29 | prev_line = line 30 | -------------------------------------------------------------------------------- /tedlium3/local/join_suffix.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright 2014 Nickolay V. Shmyrev 4 | # 2016 Johns Hopkins University (author: Daniel Povey) 5 | # Apache 2.0 6 | 7 | 8 | import sys 9 | 10 | # This script joins together pairs of split-up words like "you 're" -> "you're". 11 | # The TEDLIUM transcripts are normalized in a way that's not traditional for 12 | # speech recognition. 
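# Illustrative example of the intended behavior (assumed input, not taken from the corpus):
#   stdin : "and you 're going to be fine"
#   stdout: "and you're going to be fine"
# A token is merged onto the previous one only when it starts with an apostrophe,
# and a line identical to the immediately preceding input line is skipped.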
13 | 14 | prev_line = "" 15 | for line in sys.stdin: 16 | if line == prev_line: 17 | continue 18 | items = line.split() 19 | new_items = [] 20 | i = 0 21 | while i < len(items): 22 | if i < len(items) - 1 and items[i + 1][0] == "'": 23 | new_items.append(items[i] + items[i + 1]) 24 | i = i + 1 25 | else: 26 | new_items.append(items[i]) 27 | i = i + 1 28 | print(" ".join(new_items)) 29 | prev_line = line 30 | -------------------------------------------------------------------------------- /timit/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Set bash to 'debug' mode, it will exit on : 3 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', 4 | 5 | set -e 6 | set -u 7 | set -o pipefail 8 | 9 | test_sets="test-feature" 10 | model_name="kamo-naoyuki/timit_asr_train_asr_raw_word_valid.acc.ave" 11 | asr_exp="exp/${model_name}" 12 | selected_num=192 13 | orig_flag=false 14 | need_decode=true 15 | stage=1 16 | stop_stage=4 17 | train_flag=false 18 | dataset="TIMIT" 19 | 20 | # Set this to one of ["phn", "char"] depending on your requirement 21 | trans_type=phn 22 | if [ "${trans_type}" = phn ]; then 23 | # If the transcription is "phn" type, the token splitting should be done in word level 24 | token_type=word 25 | else 26 | token_type="${trans_type}" 27 | fi 28 | 29 | asr_config=conf/train_asr.yaml 30 | lm_config=conf/train_lm_rnn.yaml 31 | inference_config=conf/decode_asr.yaml 32 | 33 | . utils/parse_options.sh 34 | 35 | ./asr_test.sh \ 36 | --token_type "${token_type}" \ 37 | --test_sets "${test_sets}" \ 38 | --use_lm false \ 39 | --asr_config "${asr_config}" \ 40 | --lm_config "${lm_config}" \ 41 | --inference_config "${inference_config}" \ 42 | --local_data_opts "--trans_type ${trans_type} --train_flag ${train_flag} --test_sets ${test_sets}" \ 43 | --selected_num ${selected_num} \ 44 | --orig_flag ${orig_flag} \ 45 | --need_decode ${need_decode} \ 46 | --stage ${stage} \ 47 | --stop_stage ${stop_stage} \ 48 | --asr_exp ${asr_exp} \ 49 | --dataset ${dataset} 50 | -------------------------------------------------------------------------------- /an4/local/test_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Set bash to 'debug' mode, it will exit on : 3 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', 4 | set -e 5 | set -u 6 | set -o pipefail 7 | 8 | log() { 9 | local fname=${BASH_SOURCE[1]##*/} 10 | echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 11 | } 12 | SECONDS=0 13 | 14 | stage=1 15 | stop_stage=100 16 | 17 | datadir=./downloads 18 | an4_root=${datadir}/an4 19 | data_url=http://www.speech.cs.cmu.edu/databases/an4/ 20 | ndev_utt=100 21 | test_set=${1} 22 | log "$0 $*" 23 | echo ${test_set} 24 | . utils/parse_options.sh 25 | 26 | # if [ $# -ne 0 ]; then 27 | # log "Error: No positional arguments are required." 28 | # exit 2 29 | # fi 30 | 31 | . ./path.sh 32 | . 
./cmd.sh 33 | 34 | 35 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 36 | log "stage 1: Data Download" 37 | mkdir -p ${datadir} 38 | local/download_and_untar.sh ${datadir} ${data_url} 39 | fi 40 | 41 | 42 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 43 | log "stage 2: Data preparation" 44 | mkdir -p data/${test_set} 45 | 46 | if [[ ${test_set} =~ "130" ]] 47 | then 48 | log "Do nut use preparing data" 49 | else 50 | python3 local/test_data_prep.py ${an4_root} ${test_set} 51 | fi 52 | for x in ${test_set}; do 53 | for f in text wav.scp utt2spk; do 54 | sort data/${x}/${f} -o data/${x}/${f} 55 | done 56 | utils/utt2spk_to_spk2utt.pl data/${x}/utt2spk > "data/${x}/spk2utt" 57 | done 58 | 59 | fi 60 | 61 | log "Successfully finished. [elapsed=${SECONDS}s]" 62 | 63 | -------------------------------------------------------------------------------- /timit/local/timit_format_test_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2013 (Author: Daniel Povey) 4 | # Apache 2.0 5 | 6 | # This script takes data prepared in a corpus-dependent way 7 | # in data/local/, and converts it into the "canonical" form, 8 | # in various subdirectories of data/, e.g. data/lang, data/train, etc. 9 | 10 | . ./path.sh || exit 1; 11 | 12 | echo "Preparing test data" 13 | srcdir=data/local/data 14 | lmdir=data/local/nist_lm 15 | tmpdir=data/local/lm_tmp 16 | lexicon=data/local/dict/lexicon.txt 17 | mkdir -p $tmpdir 18 | # str1="$1" 19 | # str2="ORGI" 20 | train_flag=$1 21 | test=$(echo $2 | tr '[A-Z]' '[a-z]') 22 | # result=$(echo $str1 | grep "${str2}") 23 | 24 | # if [[ "$result" != "" ]] ; then 25 | # if ! "${train_flag}"; then 26 | # test=test-orgi 27 | # else 28 | # test=train-orgi 29 | # fi 30 | # else 31 | # if ! "${train_flag}"; then 32 | # test=test-new 33 | # else 34 | # test=train-new 35 | # fi 36 | # fi 37 | 38 | for x in $test dev; do 39 | echo $x 40 | mkdir -p data/$x 41 | if [[ $test =~ "192" ]] 42 | then 43 | log "do not use copy" 44 | else 45 | cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1; 46 | cp $srcdir/$x.text data/$x/text || exit 1; 47 | cp $srcdir/$x.spk2utt data/$x/spk2utt || exit 1; 48 | cp $srcdir/$x.utt2spk data/$x/utt2spk || exit 1; 49 | utils/filter_scp.pl data/$x/spk2utt $srcdir/$x.spk2gender > data/$x/spk2gender || exit 1; 50 | [ -e $srcdir/${x}.stm ] && cp $srcdir/${x}.stm data/$x/stm 51 | [ -e $srcdir/${x}.glm ] && cp $srcdir/${x}.glm data/$x/glm 52 | fi 53 | utils/fix_data_dir.sh data/$x 54 | utils/validate_data_dir.sh --no-feats data/$x || exit 1 55 | done 56 | -------------------------------------------------------------------------------- /timit/local/test_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright IIIT-Bangalore (Shreekantha Nadig) 3 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 4 | set -euo pipefail 5 | SECONDS=0 6 | log() { 7 | local fname=${BASH_SOURCE[1]##*/} 8 | echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 9 | } 10 | 11 | 12 | stage=0 # start from 0 if you need to start from data preparation 13 | stop_stage=100 14 | train_flag=false 15 | trans_type=phn 16 | test_sets= 17 | log "$0 $*" 18 | . utils/parse_options.sh 19 | 20 | if [ $# -ne 0 ]; then 21 | log "Error: No positional arguments are required." 22 | exit 2 23 | fi 24 | 25 | . ./path.sh 26 | . ./cmd.sh 27 | . 
./db.sh 28 | 29 | echo "train flag is ${train_flag}" 30 | 31 | # general configuration 32 | if [ -z "${TIMIT}" ]; then 33 | log "Fill the value of 'TIMIT' of db.sh" 34 | exit 1 35 | fi 36 | 37 | log "data preparation started" 38 | #TIMIT=/root/espnet/egs2/timit/asr1/data/local/data 39 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 40 | ### Task dependent. You have to make data the following preparation part by yourself. 41 | ### But you can utilize Kaldi recipes in most cases 42 | log "stage1: Preparing data for TIMIT for ${trans_type} level transcripts" 43 | echo $TIMIT 44 | if [[ ${test_sets} =~ "192" ]] 45 | then 46 | log "Do nut use preparing data" 47 | else 48 | local/timit_test_data_prep.sh ${TIMIT} ${trans_type} ${train_flag} ${test_sets} 49 | fi 50 | fi 51 | 52 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 53 | log "stage2: Formatting TIMIT directories" 54 | ### Task dependent. You have to make data the following preparation part by yourself. 55 | ### But you can utilize Kaldi recipes in most cases 56 | local/timit_format_test_data.sh ${train_flag} ${test_sets} 57 | fi 58 | 59 | log "Successfully finished. [elapsed=${SECONDS}s]" 60 | -------------------------------------------------------------------------------- /an4/local/test_data_prep.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import re 5 | import sys 6 | 7 | if len(sys.argv) != 3: 8 | print("Usage: python test_data_prep.py [an4_root] [test_dir]") 9 | sys.exit(1) 10 | an4_root = sys.argv[1] 11 | test_set = sys.argv[2] 12 | wav_dir = {"test": test_set} 13 | for x in ["test"]: 14 | with open( 15 | os.path.join(an4_root, "etc", "an4_" + test_set + ".transcription") 16 | ) as transcript_f, open(os.path.join("data", wav_dir[x], "text"), "w") as text_f, open( 17 | os.path.join("data", wav_dir[x], "wav.scp"), "w" 18 | ) as wav_scp_f, open( 19 | os.path.join("data", wav_dir[x], "utt2spk"), "w" 20 | ) as utt2spk_f: 21 | 22 | text_f.truncate() 23 | wav_scp_f.truncate() 24 | utt2spk_f.truncate() 25 | 26 | lines = sorted(transcript_f.readlines(), key=lambda s: s.split(" ")[0]) 27 | for line in lines: 28 | line = line.strip() 29 | if not line: 30 | continue 31 | words = re.search(r"^(.*) \(", line).group(1) 32 | if words[:4] == " ": 33 | words = words[4:] 34 | if words[-5:] == " ": 35 | words = words[:-5] 36 | source = re.search(r"\((.*)\)", line).group(1) 37 | pre, mid, last = source.split("-") 38 | utt_id = "-".join([mid, pre, last]) 39 | text_f.write(utt_id + " " + words + "\n") 40 | if "test-orgi" in test_set: 41 | wav_scp_f.write( 42 | utt_id 43 | + " " 44 | + "sph2pipe" 45 | + " -f wav -p -c 1 " 46 | + os.path.join(an4_root, "wav", wav_dir[x], mid, source + ".sph") 47 | + " |\n" 48 | ) 49 | elif ("dev-orig" in test_set) | ("train-orig" in test_set): 50 | wav_scp_f.write( 51 | utt_id 52 | + " " 53 | + os.path.join(an4_root, "wav", wav_dir[x], mid, source + ".wav") 54 | + "\n" 55 | ) 56 | else: 57 | wav_scp_f.write( 58 | utt_id 59 | + " " 60 | + os.path.join(an4_root, "wav", wav_dir[x], mid, source) 61 | + "\n" 62 | ) 63 | utt2spk_f.write(utt_id + " " + mid + "\n") 64 | 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ASRTest: Automated Testing for Deep-Neural-Network-Driven Speech Recognition Systems 2 | We experiment ASRTest with the ASR toolkit ESPnet: (https://github.com/espnet/espnet). 
3 | 4 | ## Generated data 5 | - We generated 8 times the size of the original test data, which is sufficient for all kinds of model testing. 6 | - AN4: https://pan.baidu.com/s/1KCDb1LKdmaExqqZqlrfAyw?pwd=xk95 7 | - TIMIT: https://pan.baidu.com/s/12WJJLdnKejEVTW1Ft4DX2A?pwd=vg5w 8 | - TEDLIUM2: https://pan.baidu.com/s/1lSo1zuf-QPwKz6MOeYN6iQ?pwd=pyqf 9 | - TEDLIUM3: https://pan.baidu.com/s/1-21psFg2oZrt_RSgabdbHg?pwd=p8xu 10 | 11 | ## How to use the generated data for testing 12 | ### Install ESPnet (https://github.com/espnet/espnet) 13 | - You can install it easily by following the instructions provided by ESPnet. 14 | ### Add scripts to your espnet directory 15 | - For guidance utils: 16 | - ASRTest/utils/* ---> your espnet directory/utils 17 | - ASRTest/espnet/* ---> your espnet directory/espnet 18 | - ASRTest/espnet2/* ---> your espnet directory/espnet2 19 | - For egs2 (Scripts are also available for other models in egs2): 20 | - ASRTest/TEMPLATE/asr1/asr_test.sh ---> your espnet directory/egs2/TEMPLATE/asr1/asr_test.sh 21 | - ASRTest/an4/* ---> your espnet directory/egs2/an4/asr1/ 22 | - ASRTest/timit/* ---> your espnet directory/egs2/timit/asr1/ 23 | - For egs: 24 | - ASRTest/tedlium2/* ---> your espnet directory/egs/tedlium2/asr1/ 25 | - ASRTest/tedlium3/* ---> your espnet directory/egs/tedlium3/asr1/ 26 | ### Decode the generated data 27 | - For egs2: 28 | - cd egs2/xxx/asr1/ 29 | - ln -s ../../TEMPLATE/asr1/asr_test.sh ./asr_test.sh 30 | - decode the original test set: sh test.sh --test_sets "test-orig" --need_decode true --orig_flag true --dataset "xxx" --stage 1 --stop_stage 4 31 | - decode all the transformed test sets: sh test.sh --test_sets "xxxx" --need_decode true --orig_flag false --dataset "xxx" --stage 1 --stop_stage 4 32 | - obtain the results on the test sets transformed by ASRTest: sh test.sh --test_sets "xxxx" --need_decode false --orig_flag false --dataset "xxx" --stage 3 --stop_stage 4 33 | - For egs: 34 | - cd egs/xxx/asr1/ 35 | - decode the original test set: sh test.sh --recog_set "test-orig" --need_decode true --orig_flag true --stage 0 --stop_stage 4 36 | - decode the transformed test sets: sh test.sh --recog_set "xxxx" --need_decode true --orig_flag false --stage 0 --stop_stage 4 37 | - obtain the results on the test sets transformed by ASRTest: sh test.sh --recog_set "xxxx" --need_decode false --orig_flag false --stage 3 --stop_stage 4 38 | - test_set/recog_set: test-feature, test-noise, test-room, test-orig 39 | - orig_flag: if test_set/recog_set is test-orig, orig_flag is true; otherwise, it is false 40 | -------------------------------------------------------------------------------- /tedlium3/local/prepare_test_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Copyright 2014 Nickolay V. Shmyrev 4 | # 2014 Brno University of Technology (Author: Karel Vesely) 5 | # 2016 Johns Hopkins University (Author: Daniel Povey) 6 | # Apache 2.0 7 | 8 | # To be run from one directory above this script. 9 | 10 | . 
./path.sh 11 | 12 | export LC_ALL=C 13 | 14 | sph2pipe=sph2pipe 15 | 16 | data_type=$1 17 | recog_set=$2 18 | # Prepare: test, train, 19 | for set in ${recog_set}; do 20 | dir=data/$set.orig 21 | mkdir -p $dir 22 | 23 | # Merge transcripts into a single 'stm' file, do some mappings: 24 | # - -> : map dev stm labels to be coherent with train + test, 25 | # - -> : --||-- 26 | # - (2) -> null : remove pronunciation variants in transcripts, keep in dictionary 27 | # - -> null : remove marked , it is modelled implicitly (in kaldi) 28 | # - (...) -> null : remove utterance names from end-lines of train 29 | # - it 's -> it's : merge words that contain apostrophe (if compound in dictionary, local/join_suffix.py) 30 | { # Add STM header, so sclite can prepare the '.lur' file 31 | echo ';; 32 | ;; LABEL "o" "Overall" "Overall results" 33 | ;; LABEL "f0" "f0" "Wideband channel" 34 | ;; LABEL "f2" "f2" "Telephone channel" 35 | ;; LABEL "male" "Male" "Male Talkers" 36 | ;; LABEL "female" "Female" "Female Talkers" 37 | ;;' 38 | # Process the STMs 39 | cat db/TEDLIUM_release-3/${data_type}/$set/stm/*.stm | sort -k1,1 -k2,2 -k4,4n | \ 40 | sed -e 's:::' \ 41 | -e 's:::' \ 42 | -e 's:([0-9])::g' \ 43 | -e 's:::g' \ 44 | -e 's:([^ ]*)$::' | \ 45 | awk '{ $2 = "A"; print $0; }' 46 | } | local/join_suffix.py > data/$set.orig/stm 47 | 48 | # Prepare 'text' file 49 | # - {NOISE} -> [NOISE] : map the tags to match symbols in dictionary 50 | cat $dir/stm | grep -v -e 'ignore_time_segment_in_scoring' -e ';;' | \ 51 | awk '{ printf ("%s-%07d-%07d", $1, $4*100, $5*100); 52 | for (i=7;i<=NF;i++) { printf(" %s", $i); } 53 | printf("\n"); 54 | }' | tr '{}' '[]' | sort -k1,1 > $dir/text || exit 1 55 | 56 | # Prepare 'segments', 'utt2spk', 'spk2utt' 57 | cat $dir/text | cut -d" " -f 1 | awk -F"-" '{printf("%s %s %07.2f %07.2f\n", $0, $1, $2/100.0, $3/100.0)}' > $dir/segments 58 | cat $dir/segments | awk '{print $1, $2}' > $dir/utt2spk 59 | cat $dir/utt2spk | utils/utt2spk_to_spk2utt.pl > $dir/spk2utt 60 | 61 | # Prepare 'wav.scp', 'reco2file_and_channel' 62 | if [ $set == "test-orig" ]; then 63 | cat $dir/spk2utt | awk -v data_type=$data_type -v set=$set -v pwd=$PWD '{ printf("%s '$sph2pipe' -f wav -p %s/db/TEDLIUM_release-3/%s/%s/sph/%s.sph |\n", $1, pwd, data_type, set, $1); }' > $dir/wav.scp 64 | else 65 | cat $dir/spk2utt | awk -v data_type=$data_type -v set=$set -v pwd=$PWD '{ printf("%s %s/db/TEDLIUM_release-3/%s/%s/wav/%s.wav\n", $1, pwd, data_type, set, $1); }' > $dir/wav.scp 66 | fi 67 | cat $dir/wav.scp | awk '{ print $1, $1, "A"; }' > $dir/reco2file_and_channel 68 | 69 | # Create empty 'glm' file 70 | echo ';; empty.glm 71 | [FAKE] => %HESITATION / [ ] __ [ ] ;; hesitation token 72 | ' > data/$set.orig/glm 73 | 74 | # Check that data dirs are okay! 75 | utils/validate_data_dir.sh --no-feats $dir || exit 1 76 | done 77 | 78 | -------------------------------------------------------------------------------- /tedlium2/local/prepare_test_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Copyright 2014 Nickolay V. Shmyrev 4 | # 2014 Brno University of Technology (Author: Karel Vesely) 5 | # 2016 Johns Hopkins University (Author: Daniel Povey) 6 | # Apache 2.0 7 | 8 | # To be run from one directory above this script. 9 | 10 | . 
./path.sh 11 | 12 | export LC_ALL=C 13 | 14 | sph2pipe=sph2pipe 15 | recog_set=${1} 16 | 17 | # Prepare: test, train, 18 | for set in ${recog_set}; do 19 | dir=data/$set.orig 20 | mkdir -p $dir 21 | 22 | # Merge transcripts into a single 'stm' file, do some mappings: 23 | # - -> : map dev stm labels to be coherent with train + test, 24 | # - -> : --||-- 25 | # - (2) -> null : remove pronunciation variants in transcripts, keep in dictionary 26 | # - -> null : remove marked , it is modelled implicitly (in kaldi) 27 | # - (...) -> null : remove utterance names from end-lines of train 28 | # - it 's -> it's : merge words that contain apostrophe (if compound in dictionary, local/join_suffix.py) 29 | { # Add STM header, so sclite can prepare the '.lur' file 30 | echo ';; 31 | ;; LABEL "o" "Overall" "Overall results" 32 | ;; LABEL "f0" "f0" "Wideband channel" 33 | ;; LABEL "f2" "f2" "Telephone channel" 34 | ;; LABEL "male" "Male" "Male Talkers" 35 | ;; LABEL "female" "Female" "Female Talkers" 36 | ;;' 37 | # Process the STMs 38 | cat db/TEDLIUM_release2/$set/stm/*.stm | sort -k1,1 -k2,2 -k4,4n | \ 39 | sed -e 's:::' \ 40 | -e 's:::' \ 41 | -e 's:([0-9])::g' \ 42 | -e 's:::g' \ 43 | -e 's:([^ ]*)$::' | \ 44 | awk '{ $2 = "A"; print $0; }' 45 | } | local/join_suffix.py > data/$set.orig/stm 46 | 47 | # Prepare 'text' file 48 | # - {NOISE} -> [NOISE] : map the tags to match symbols in dictionary 49 | cat $dir/stm | grep -v -e 'ignore_time_segment_in_scoring' -e ';;' | \ 50 | awk '{ printf ("%s-%07d-%07d", $1, $4*100, $5*100); 51 | for (i=7;i<=NF;i++) { printf(" %s", $i); } 52 | printf("\n"); 53 | }' | tr '{}' '[]' | sort -k1,1 > $dir/text || exit 1 54 | 55 | # Prepare 'segments', 'utt2spk', 'spk2utt' 56 | cat $dir/text | cut -d" " -f 1 | awk -F"-" '{printf("%s %s %07.2f %07.2f\n", $0, $1, $2/100.0, $3/100.0)}' > $dir/segments 57 | cat $dir/segments | awk '{print $1, $2}' > $dir/utt2spk 58 | cat $dir/utt2spk | utils/utt2spk_to_spk2utt.pl > $dir/spk2utt 59 | 60 | # Prepare 'wav.scp', 'reco2file_and_channel' 61 | if [ $set == "test-orgi" ]; then 62 | cat $dir/spk2utt | awk -v set=$set -v pwd=$PWD '{ printf("%s '$sph2pipe' -f wav -p %s/db/TEDLIUM_release2/%s/sph/%s.sph |\n", $1, pwd, set, $1); }' > $dir/wav.scp 63 | else 64 | cat $dir/spk2utt | awk -v set=$set -v pwd=$PWD '{ printf("%s %s/db/TEDLIUM_release2/%s/wav/%s.wav\n", $1, pwd, set, $1); }' > $dir/wav.scp 65 | fi 66 | cat $dir/wav.scp | awk '{ print $1, $1, "A"; }' > $dir/reco2file_and_channel 67 | 68 | # Create empty 'glm' file 69 | echo ';; empty.glm 70 | [FAKE] => %HESITATION / [ ] __ [ ] ;; hesitation token 71 | ' > data/$set.orig/glm 72 | 73 | # The training set seems to not have enough silence padding in the segmentations, 74 | # especially at the beginning of segments. Extend the times. 75 | if [ $set == "train" ]; then 76 | mv data/$set.orig/segments data/$set.orig/segments.temp 77 | utils/data/extend_segment_times.py --start-padding=0.15 \ 78 | --end-padding=0.1 data/$set.orig/segments || exit 1 79 | rm data/$set.orig/segments.temp 80 | fi 81 | 82 | # Check that data dirs are okay! 
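# (validate_data_dir.sh is the standard Kaldi-style sanity check: with --no-feats it
#  verifies that text, wav.scp, segments, utt2spk and spk2utt are present, sorted, and
#  refer to a consistent set of utterance/speaker ids, without requiring feats.scp.)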
83 | utils/validate_data_dir.sh --no-feats $dir || exit 1 84 | done 85 | 86 | 87 | -------------------------------------------------------------------------------- /timit/local/timit_norm_trans.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | use warnings; #sed replacement for -w perl parameter 3 | 4 | # Copyright 2012 Arnab Ghoshal 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # This script normalizes the TIMIT phonetic transcripts that have been 21 | # extracted in a format where each line contains an utterance ID followed by 22 | # the transcript, e.g.: 23 | # fcke0_si1111 h# hh ah dx ux w iy dcl d ix f ay n ih q h# 24 | 25 | my $usage = "Usage: timit_norm_trans.pl -i transcript -m phone_map -from [60|48] -to [48|39] > normalized\n 26 | Normalizes phonetic transcriptions for TIMIT, by mapping the phones to a 27 | smaller set defined by the -m option. This script assumes that the mapping is 28 | done in the \"standard\" fashion, i.e. to 48 or 39 phones. The input is 29 | assumed to have 60 phones (+1 for glottal stop, which is deleted), but that can 30 | be changed using the -from option. The input format is assumed to be utterance 31 | ID followed by transcript on the same line.\n"; 32 | 33 | use strict; 34 | use Getopt::Long; 35 | die "$usage" unless(@ARGV >= 1); 36 | my ($in_trans, $phone_map, $num_phones_out); 37 | my $num_phones_in = 60; 38 | GetOptions ("i=s" => \$in_trans, # Input transcription 39 | "m=s" => \$phone_map, # File containing phone mappings 40 | "from=i" => \$num_phones_in, # Input #phones: must be 60 or 48 41 | "to=i" => \$num_phones_out ); # Output #phones: must be 48 or 39 42 | 43 | die $usage unless(defined($in_trans) && defined($phone_map) && 44 | defined($num_phones_out)); 45 | if ($num_phones_in != 60 && $num_phones_in != 48) { 46 | die "Can only used 60 or 48 for -from (used $num_phones_in)." 47 | } 48 | if ($num_phones_out != 48 && $num_phones_out != 39) { 49 | die "Can only used 48 or 39 for -to (used $num_phones_out)." 50 | } 51 | unless ($num_phones_out < $num_phones_in) { 52 | die "Argument to -from ($num_phones_in) must be greater than that to -to ($num_phones_out)." 53 | } 54 | 55 | 56 | open(M, "<$phone_map") or die "Cannot open mappings file '$phone_map': $!"; 57 | my (%phonemap, %seen_phones); 58 | my $num_seen_phones = 0; 59 | while () { 60 | chomp; 61 | next if ($_ =~ /^q\s*.*$/); # Ignore glottal stops. 62 | m:^(\S+)\s+(\S+)\s+(\S+)$: or die "Bad line: $_"; 63 | my $mapped_from = ($num_phones_in == 60)? $1 : $2; 64 | my $mapped_to = ($num_phones_out == 48)? 
$2 : $3; 65 | if (!defined($seen_phones{$mapped_to})) { 66 | $seen_phones{$mapped_to} = 1; 67 | $num_seen_phones += 1; 68 | } 69 | $phonemap{$mapped_from} = $mapped_to; 70 | } 71 | if ($num_seen_phones != $num_phones_out) { 72 | die "Trying to map to $num_phones_out phones, but seen only $num_seen_phones"; 73 | } 74 | 75 | open(T, "<$in_trans") or die "Cannot open transcription file '$in_trans': $!"; 76 | while () { 77 | chomp; 78 | $_ =~ m:^(\S+)\s+(.+): or die "Bad line: $_"; 79 | my $utt_id = $1; 80 | my $trans = $2; 81 | 82 | $trans =~ s/q//g; # Remove glottal stops. 83 | $trans =~ s/^\s*//; $trans =~ s/\s*$//; # Normalize spaces 84 | 85 | print $utt_id; 86 | for my $phone (split(/\s+/, $trans)) { 87 | if(exists $phonemap{$phone}) { print " $phonemap{$phone}"; } 88 | if(not exists $phonemap{$phone}) { print " $phone"; } 89 | } 90 | print "\n"; 91 | } 92 | -------------------------------------------------------------------------------- /utils/score_sclite.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2017 Johns Hopkins University (Shinji Watanabe) 4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | [ -f ./path.sh ] && . ./path.sh 7 | 8 | nlsyms="" 9 | wer=false 10 | bpe="" 11 | bpemodel="" 12 | remove_blank=true 13 | filter="" 14 | num_spkrs=1 15 | need_decode= 16 | guide_type= 17 | help_message="Usage: $0 " 18 | 19 | . utils/parse_options.sh 20 | 21 | if [ $# != 2 ]; then 22 | echo "${help_message}" 23 | exit 1; 24 | fi 25 | 26 | dir=$1 27 | dic=$2 28 | 29 | if "${need_decode}"; then 30 | concatjson.py ${dir}/data.*.json > ${dir}/data.json 31 | else 32 | dir=$1/${guide_type}-1155 33 | fi 34 | 35 | if [ $num_spkrs -eq 1 ]; then 36 | json2trn.py ${dir}/data.json ${dic} --num-spkrs ${num_spkrs} --refs ${dir}/ref.trn --hyps ${dir}/hyp.trn 37 | 38 | if ${remove_blank}; then 39 | sed -i.bak2 -r 's/ //g' ${dir}/hyp.trn 40 | fi 41 | if [ -n "${nlsyms}" ]; then 42 | cp ${dir}/ref.trn ${dir}/ref.trn.org 43 | cp ${dir}/hyp.trn ${dir}/hyp.trn.org 44 | filt.py -v ${nlsyms} ${dir}/ref.trn.org > ${dir}/ref.trn 45 | filt.py -v ${nlsyms} ${dir}/hyp.trn.org > ${dir}/hyp.trn 46 | fi 47 | if [ -n "${filter}" ]; then 48 | sed -i.bak3 -f ${filter} ${dir}/hyp.trn 49 | sed -i.bak3 -f ${filter} ${dir}/ref.trn 50 | fi 51 | 52 | sclite -r ${dir}/ref.trn trn -h ${dir}/hyp.trn trn -i rm -o all stdout > ${dir}/result.txt 53 | 54 | echo "write a CER (or TER) result in ${dir}/result.txt" 55 | grep -e Avg -e SPKR -m 2 ${dir}/result.txt 56 | 57 | if ${wer}; then 58 | if [ -n "$bpe" ]; then 59 | spm_decode --model=${bpemodel} --input_format=piece < ${dir}/ref.trn | sed -e "s/▁/ /g" > ${dir}/ref.wrd.trn 60 | spm_decode --model=${bpemodel} --input_format=piece < ${dir}/hyp.trn | sed -e "s/▁/ /g" > ${dir}/hyp.wrd.trn 61 | else 62 | sed -e "s/ //g" -e "s/(/ (/" -e "s// /g" ${dir}/ref.trn > ${dir}/ref.wrd.trn 63 | sed -e "s/ //g" -e "s/(/ (/" -e "s// /g" ${dir}/hyp.trn > ${dir}/hyp.wrd.trn 64 | fi 65 | sclite -r ${dir}/ref.wrd.trn trn -h ${dir}/hyp.wrd.trn trn -i rm -o all stdout > ${dir}/result.wrd.txt 66 | 67 | echo "write a WER result in ${dir}/result.wrd.txt" 68 | grep -e Avg -e SPKR -m 2 ${dir}/result.wrd.txt 69 | fi 70 | elif [ ${num_spkrs} -lt 4 ]; then 71 | ref_trns="" 72 | hyp_trns="" 73 | for i in $(seq ${num_spkrs}); do 74 | ref_trns=${ref_trns}"${dir}/ref${i}.trn " 75 | hyp_trns=${hyp_trns}"${dir}/hyp${i}.trn " 76 | done 77 | json2trn.py ${dir}/data.json ${dic} --num-spkrs ${num_spkrs} --refs ${ref_trns} --hyps 
${hyp_trns} 78 | 79 | for n in $(seq ${num_spkrs}); do 80 | if ${remove_blank}; then 81 | sed -i.bak2 -r 's/ //g' ${dir}/hyp${n}.trn 82 | fi 83 | if [ -n "${nlsyms}" ]; then 84 | cp ${dir}/ref${n}.trn ${dir}/ref${n}.trn.org 85 | cp ${dir}/hyp${n}.trn ${dir}/hyp${n}.trn.org 86 | filt.py -v ${nlsyms} ${dir}/ref${n}.trn.org > ${dir}/ref${n}.trn 87 | filt.py -v ${nlsyms} ${dir}/hyp${n}.trn.org > ${dir}/hyp${n}.trn 88 | fi 89 | if [ -n "${filter}" ]; then 90 | sed -i.bak3 -f ${filter} ${dir}/hyp${n}.trn 91 | sed -i.bak3 -f ${filter} ${dir}/ref${n}.trn 92 | fi 93 | done 94 | 95 | results_str="" 96 | for (( i=0; i<$((num_spkrs * num_spkrs)); i++ )); do 97 | ind_r=$((i / num_spkrs + 1)) 98 | ind_h=$((i % num_spkrs + 1)) 99 | results_str=${results_str}"${dir}/result_r${ind_r}h${ind_h}.txt " 100 | sclite -r ${dir}/ref${ind_r}.trn trn -h ${dir}/hyp${ind_h}.trn trn -i rm -o all stdout > ${dir}/result_r${ind_r}h${ind_h}.txt 101 | done 102 | 103 | echo "write CER (or TER) results in ${dir}/result_r*h*.txt" 104 | eval_perm_free_error.py --num-spkrs ${num_spkrs} \ 105 | ${results_str} > ${dir}/min_perm_result.json 106 | sed -n '2,4p' ${dir}/min_perm_result.json 107 | 108 | if ${wer}; then 109 | for n in $(seq ${num_spkrs}); do 110 | if [ -n "$bpe" ]; then 111 | spm_decode --model=${bpemodel} --input_format=piece < ${dir}/ref${n}.trn | sed -e "s/▁/ /g" > ${dir}/ref${n}.wrd.trn 112 | spm_decode --model=${bpemodel} --input_format=piece < ${dir}/hyp${n}.trn | sed -e "s/▁/ /g" > ${dir}/hyp${n}.wrd.trn 113 | else 114 | sed -e "s/ //g" -e "s/(/ (/" -e "s// /g" ${dir}/ref${n}.trn > ${dir}/ref${n}.wrd.trn 115 | sed -e "s/ //g" -e "s/(/ (/" -e "s// /g" ${dir}/hyp${n}.trn > ${dir}/hyp${n}.wrd.trn 116 | fi 117 | done 118 | results_str="" 119 | for (( i=0; i<$((num_spkrs * num_spkrs)); i++ )); do 120 | ind_r=$((i / num_spkrs + 1)) 121 | ind_h=$((i % num_spkrs + 1)) 122 | results_str=${results_str}"${dir}/result_r${ind_r}h${ind_h}.wrd.txt " 123 | sclite -r ${dir}/ref${ind_r}.wrd.trn trn -h ${dir}/hyp${ind_h}.wrd.trn trn -i rm -o all stdout > ${dir}/result_r${ind_r}h${ind_h}.wrd.txt 124 | done 125 | 126 | echo "write WER results in ${dir}/result_r*h*.wrd.txt" 127 | eval_perm_free_error.py --num-spkrs ${num_spkrs} \ 128 | ${results_str} > ${dir}/min_perm_result.wrd.json 129 | sed -n '2,4p' ${dir}/min_perm_result.wrd.json 130 | fi 131 | fi 132 | 133 | -------------------------------------------------------------------------------- /timit/local/timit_test_data_prep.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | create_glm_stm=false 5 | 6 | if [ $# -le 0 ]; then 7 | echo "Argument should be the Timit directory, see ../run.sh for example." 8 | exit 1; 9 | fi 10 | 11 | dir=`pwd`/data/local/data 12 | lmdir=`pwd`/data/local/nist_lm 13 | mkdir -p $dir $lmdir 14 | local=`pwd`/local 15 | utils=`pwd`/utils 16 | conf=`pwd`/conf 17 | train_flag=$3 18 | if [ $2 ]; then 19 | if [[ $2 = "char" || $2 = "phn" ]]; then 20 | trans_type=$2 21 | else 22 | echo "Transcript type must be one of [phn, char]" 23 | echo $2 24 | fi 25 | else 26 | trans_type=phn 27 | fi 28 | 29 | . ./path.sh # Needed for KALDI_ROOT 30 | export PATH=$PATH:$KALDI_ROOT/tools/irstlm/bin 31 | sph2pipe=sph2pipe 32 | if ! 
command -v "${sph2pipe}" &> /dev/null; then 33 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; 34 | exit 1; 35 | fi 36 | 37 | [ -f $conf/test_spk.list ] || error_exit "$PROG: Eval-set speaker list not found."; 38 | 39 | # First check if the test directories exist (these can either be upper- 40 | # or lower-cased 41 | # if [ ! -d $1/TEST-NEW ] && [ ! -d $1/test-new ] && [ ! -d $1/TEST-ORGI ] && [ ! -d $1/test-orgi ]; then 42 | # # if [ ! -d $1/TEST ] && [ ! -d $1/test ] && [ ! -d $1/TEST ] && [ ! -d $1/test ]; then 43 | # echo "timit_test_data_prep.sh: Spot check of command line argument failed" 44 | # echo "Command line argument must be absolute pathname to TIMIT directory" 45 | # echo "with name like /export/corpora5/LDC/LDC93S1/timit/TIMIT" 46 | # exit 1; 47 | # fi 48 | 49 | # Now check what case the directory structure is 50 | 51 | uppercased=true 52 | test=$4 53 | test_dir=$(echo $test | tr '[a-z]' '[A-Z]') 54 | # if [ -d $1/TEST-NEW ] || [ -d $1/TEST-ORGI ] ; then 55 | # uppercased=true 56 | # fi 57 | 58 | # str1="$1" 59 | # str2="ORGI" 60 | 61 | # result=$(echo $str1 | grep "${str2}") 62 | # if [[ "$result" != "" ]] ; then 63 | # if ! "${train_flag}"; then 64 | # test=test-orgi 65 | # test_dir=TEST-ORGI 66 | # else 67 | # test=train-orgi 68 | # test_dir=TRAIN-ORGI 69 | # fi 70 | # else 71 | # if ! "${train_flag}"; then 72 | # test=test-new 73 | # test_dir=TEST-NEW 74 | # else 75 | # test=train-new 76 | # test_dir=TRAIN-NEW 77 | # fi 78 | # fi 79 | 80 | tmpdir=$(mktemp -d /tmp/kaldi.XXXX); 81 | trap 'rm -rf "$tmpdir"' EXIT 82 | # Get the list of speakers. The list of speakers in the 24-speaker core test 83 | # set and the 50-speaker development set must be supplied to the script. All 84 | # speakers in the 'train' directory are used for training. 85 | if $uppercased; then 86 | tr '[:lower:]' '[:upper:]' < $conf/dev_spk.list > $tmpdir/dev_spk 87 | if ! "${train_flag}"; then 88 | tr '[:lower:]' '[:upper:]' < $conf/test_spk.list > $tmpdir/${test}_spk 89 | else 90 | ls -d "$1"/$test_dir/DR*/* | sed -e "s:^.*/::" > $tmpdir/${test}_spk 91 | fi 92 | else 93 | tr '[:upper:]' '[:lower:]' < $conf/dev_spk.list > $tmpdir/dev_spk 94 | if ! "${train_flag}"; then 95 | tr '[:lower:]' '[:upper:]' < $conf/test_spk.list > $tmpdir/${test}_spk 96 | else 97 | ls -d "$1"/$test_dir/DR*/* | sed -e "s:^.*/::" > $tmpdir/${test}_spk 98 | fi 99 | fi 100 | 101 | cd $dir 102 | for x in $test dev; do 103 | # First, find the list of audio files (use only si & sx utterances). 104 | # Note: train & test sets are under different directories, but doing find on 105 | # both and grepping for the speakers will work correctly. 106 | echo "test_dir is $1/$test_dir" 107 | find $1/$test_dir -not \( -iname 'SA*' \) -iname '*.WAV' \ 108 | | grep -f $tmpdir/${x}_spk > ${x}_sph.flist 109 | 110 | sed -e 's:.*/\(.*\)/\(.*\).WAV$:\1_\2:i' ${x}_sph.flist \ 111 | > $tmpdir/${x}_sph.uttids 112 | paste $tmpdir/${x}_sph.uttids ${x}_sph.flist \ 113 | | sort -k1,1 > ${x}_sph.scp 114 | 115 | cat ${x}_sph.scp | awk '{print $1}' > ${x}.uttids 116 | 117 | # Now, Convert the transcripts into our format (no normalization yet) 118 | # Get the transcripts: each line of the output contains an utterance 119 | # ID followed by the transcript. 120 | 121 | if [ $trans_type = "phn" ] 122 | then 123 | echo "phone transcript!" 
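# The pipeline below lists the .PHN files for the selected speakers and flattens each
# one into a single "<utt-id> <phone sequence>" line, e.g. (format borrowed from the
# header of local/timit_norm_trans.pl): fcke0_si1111 h# hh ah dx ux w iy dcl d ix f ay n ih q h#
# The mapping down to the 39-phone set is applied later by timit_norm_trans.pl.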
124 | find $1/$test_dir -not \( -iname 'SA*' \) -iname '*.PHN' \ 125 | | grep -f $tmpdir/${x}_spk > $tmpdir/${x}_phn.flist 126 | sed -e 's:.*/\(.*\)/\(.*\).PHN$:\1_\2:i' $tmpdir/${x}_phn.flist \ 127 | > $tmpdir/${x}_phn.uttids 128 | while read line; do 129 | [ -f $line ] || error_exit "Cannot find transcription file '$line'"; 130 | cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;' 131 | done < $tmpdir/${x}_phn.flist > $tmpdir/${x}_phn.trans 132 | paste $tmpdir/${x}_phn.uttids $tmpdir/${x}_phn.trans \ 133 | | sort -k1,1 > ${x}.trans 134 | 135 | elif [ $trans_type = "char" ] 136 | then 137 | echo "char transcript!" 138 | find $1/$test_dir -not \( -iname 'SA*' \) -iname '*.WRD' \ 139 | | grep -f $tmpdir/${x}_spk > $tmpdir/${x}_wrd.flist 140 | sed -e 's:.*/\(.*\)/\(.*\).WRD$:\1_\2:i' $tmpdir/${x}_wrd.flist \ 141 | > $tmpdir/${x}_wrd.uttids 142 | while read line; do 143 | [ -f $line ] || error_exit "Cannot find transcription file '$line'"; 144 | cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;' | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z A-Z]//g' 145 | done < $tmpdir/${x}_wrd.flist > $tmpdir/${x}_wrd.trans 146 | paste $tmpdir/${x}_wrd.uttids $tmpdir/${x}_wrd.trans \ 147 | | sort -k1,1 > ${x}.trans 148 | else 149 | echo "WRONG!" 150 | echo $trans_type 151 | exit 0; 152 | fi 153 | 154 | # Do normalization steps. 155 | cat ${x}.trans | $local/timit_norm_trans.pl -i - -m $conf/phones.60-48-39.map -to 39 | sort > $x.text || exit 1; 156 | 157 | # Create wav.scp 158 | awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < ${x}_sph.scp > ${x}_wav.scp 159 | 160 | # Make the utt2spk and spk2utt files. 161 | cut -f1 -d'_' $x.uttids | paste -d' ' $x.uttids - > $x.utt2spk 162 | cat $x.utt2spk | $utils/utt2spk_to_spk2utt.pl > $x.spk2utt || exit 1; 163 | 164 | # Prepare gender mapping 165 | cat $x.spk2utt | awk '{print $1}' | perl -ane 'chop; m:^.:; $g = lc($&); print "$_ $g\n";' > $x.spk2gender 166 | 167 | 168 | if "${create_glm_stm}"; then 169 | # Prepare STM file for sclite: 170 | wav-to-duration --read-entire-file=true scp:${x}_wav.scp ark,t:${x}_dur.ark || exit 1 171 | awk -v dur=${x}_dur.ark \ 172 | 'BEGIN{ 173 | while(getline < dur) { durH[$1]=$2; } 174 | print ";; LABEL \"O\" \"Overall\" \"Overall\""; 175 | print ";; LABEL \"F\" \"Female\" \"Female speakers\""; 176 | print ";; LABEL \"M\" \"Male\" \"Male speakers\""; 177 | } 178 | { wav=$1; spk=wav; sub(/_.*/,"",spk); $1=""; ref=$0; 179 | gender=(substr(spk,0,1) == "f" ? "F" : "M"); 180 | printf("%s 1 %s 0.0 %f %s\n", wav, spk, durH[wav], gender, ref); 181 | } 182 | ' ${x}.text >${x}.stm || exit 1 183 | 184 | # Create dummy GLM file for sclite: 185 | echo ';; empty.glm 186 | [FAKE] => %HESITATION / [ ] __ [ ] ;; hesitation token 187 | ' > ${x}.glm 188 | fi 189 | done 190 | 191 | echo "Data preparation succeeded" -------------------------------------------------------------------------------- /tedlium3/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2019 Nagoya University (Masao Someki) 4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | . ./path.sh || exit 1; 7 | . ./cmd.sh || exit 1; 8 | 9 | # general configuration 10 | backend=pytorch 11 | stage=3 # start from -1 if you need to start from data download 12 | stop_stage=3 13 | ngpu=0 # number of gpus ("0" uses cpu, otherwise use gpu) 14 | debugmode=1 15 | dumpdir=dump # directory to dump full features 16 | N=0 # number of minibatches to be used (mainly for debugging). 
"0" uses all minibatches. 17 | verbose=1 # verbose option 18 | resume= # Resume the training from snapshot 19 | 20 | # feature configuration 21 | do_delta=false 22 | cmvn= 23 | # rnnlm related 24 | lm_resume= # specify a snapshot file to resume LM training 25 | lmtag= # tag for managing LMs 26 | use_lang_model=false 27 | lang_model= 28 | 29 | # decoding parameter 30 | p=0.005 31 | recog_model= 32 | recog_dir= 33 | decode_config= 34 | decode_dir=decode 35 | api=v2 36 | 37 | # bpemode (unigram or bpe) 38 | nbpe=500 39 | bpemode=unigram 40 | 41 | # exp tag 42 | tag="" # tag for managing experiments. 43 | 44 | train_config= 45 | decode_config= 46 | preprocess_config= 47 | lm_config= 48 | models=tedlium3.conformer 49 | # gini related 50 | orig_flag=false 51 | orig_dir= 52 | need_decode=false 53 | 54 | data_type=legacy 55 | train_set=train_trim_sp 56 | recog_set=dev-new 57 | 58 | . utils/parse_options.sh || exit 1; 59 | 60 | # Set bash to 'debug' mode, it will exit on : 61 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', 62 | set -e 63 | set -u 64 | set -o pipefail 65 | 66 | 67 | # legacy setup 68 | 69 | 70 | #adaptation setup 71 | #data_type=speaker-adaptation 72 | #train_set=train_adapt_trim_sp 73 | #train_dev=dev_adapt_trim 74 | # recog_set="dev_adapt test_adapt" 75 | 76 | download_dir=${decode_dir}/download 77 | 78 | if [ "${api}" = "v2" ] && [ "${backend}" = "chainer" ]; then 79 | echo "chainer backend does not support api v2." >&2 80 | exit 1; 81 | fi 82 | 83 | if [ -z $models ]; then 84 | if [ $use_lang_model = "true" ]; then 85 | if [[ -z $cmvn || -z $lang_model || -z $recog_model || -z $decode_config ]]; then 86 | echo 'Error: models or set of cmvn, lang_model, recog_model and decode_config are required.' >&2 87 | exit 1 88 | fi 89 | else 90 | if [[ -z $cmvn || -z $recog_model || -z $decode_config ]]; then 91 | echo 'Error: models or set of cmvn, recog_model and decode_config are required.' >&2 92 | exit 1 93 | fi 94 | fi 95 | fi 96 | 97 | dir=${download_dir}/${models} 98 | mkdir -p ${dir} 99 | 100 | # Download trained models 101 | if [ -z "${cmvn}" ]; then 102 | #download_models 103 | cmvn=$(find ${download_dir}/${models} -name "cmvn.ark" | head -n 1) 104 | fi 105 | if [ -z "${lang_model}" ] && ${use_lang_model}; then 106 | #download_models 107 | lang_model=$(find ${download_dir}/${models} -name "rnnlm*.best*" | head -n 1) 108 | fi 109 | if [ -z "${recog_model}" ]; then 110 | #download_models 111 | if [ -z "${recog_dir}" ]; then 112 | recog_model=$(find ${download_dir}/${models} -name "model*.best*" | head -n 1) 113 | else 114 | recog_model=$(find "${recog_dir}/results" -name "model.loss.best" | head -n 1) 115 | fi 116 | echo "recog_model is ${recog_model}" 117 | fi 118 | if [ -z "${decode_config}" ]; then 119 | #download_models 120 | decode_config=$(find ${download_dir}/${models} -name "decode*.yaml" | head -n 1) 121 | fi 122 | 123 | # Check file existence 124 | if [ ! -f "${cmvn}" ]; then 125 | echo "No such CMVN file: ${cmvn}" 126 | exit 1 127 | fi 128 | if [ ! -f "${lang_model}" ] && ${use_lang_model}; then 129 | echo "No such language model: ${lang_model}" 130 | exit 1 131 | fi 132 | if [ ! -f "${recog_model}" ]; then 133 | echo "No such E2E model: ${recog_model}" 134 | exit 1 135 | fi 136 | if [ ! -f "${decode_config}" ]; then 137 | echo "No such config file: ${decode_config}" 138 | exit 1 139 | fi 140 | 141 | 142 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 143 | ### Task dependent. 
You have to make data the following preparation part by yourself. 144 | ### But you can utilize Kaldi recipes in most cases 145 | echo "stage 0: Data preparation" 146 | if [ -z "${recog_dir}" ]; then 147 | local/prepare_test_data.sh ${data_type} ${recog_set} 148 | fi 149 | for dset in ${recog_set}; do 150 | utils/fix_data_dir.sh data/${dset}.orig 151 | utils/data/modify_speaker_info.sh --seconds-per-spk-max 180 data/${dset}.orig data/${dset} 152 | done 153 | fi 154 | 155 | # feat_tr_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${feat_tr_dir} 156 | # feat_dt_dir=${dumpdir}/${train_dev}/delta${do_delta}; mkdir -p ${feat_dt_dir} 157 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 158 | ### Task dependent. You have to design training and dev sets by yourself. 159 | ### But you can utilize Kaldi recipes in most cases 160 | echo "stage 1: Feature Generation" 161 | fbankdir=fbank 162 | # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame 163 | for x in ${recog_set}; do 164 | steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj 32 --write_utt2num_frames true \ 165 | data/${x} exp/make_fbank/${x} ${fbankdir} 166 | utils/fix_data_dir.sh data/${x} 167 | done 168 | 169 | for rtask in ${recog_set}; do 170 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}; mkdir -p ${feat_recog_dir} 171 | dump.sh --cmd "$train_cmd" --nj 32 --do_delta ${do_delta} \ 172 | data/${rtask}/feats.scp ${cmvn} exp/dump_feats/recog/${rtask} \ 173 | ${feat_recog_dir} 174 | done 175 | fi 176 | 177 | dict=data/lang_char/train_trim_sp_${bpemode}${nbpe}_units.txt 178 | bpemodel=data/lang_char/train_trim_sp_${bpemode}${nbpe} 179 | echo "dictionary: ${dict}" 180 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 181 | ### Task dependent. You have to check non-linguistic symbols used in the corpus. 182 | echo "stage 2: Dictionary and Json Data Preparation" 183 | 184 | # make json labels 185 | for rtask in ${recog_set}; do 186 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta} 187 | data2json.sh --feat ${feat_recog_dir}/feats.scp --bpecode ${bpemodel}.model\ 188 | data/${rtask} ${dict} > ${feat_recog_dir}/data_${bpemode}${nbpe}.json 189 | done 190 | fi 191 | 192 | # It takes a few days. 
If you just want to end-to-end ASR without LM, 193 | # you can skip this and remove --rnnlm option in the recognition (stage 5) 194 | 195 | if [ -z "${recog_dir}" ]; then 196 | expname=${models} 197 | expdir=exp/${expname} 198 | mkdir -p ${expdir} 199 | else 200 | expdir=${recog_dir} 201 | fi 202 | 203 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 204 | echo "stage 3: Decoding" 205 | nj=32 206 | if ${use_lang_model}; then 207 | recog_opts="--rnnlm ${lang_model}" 208 | else 209 | recog_opts="" 210 | fi 211 | #feat_recog_dir=${decode_dir}/dump 212 | pids=() # initialize pids 213 | #trap 'rm -rf data/"${recog_set}" data/"${recog_set}.orig"' EXIT 214 | for rtask in ${recog_set}; do 215 | ( 216 | decode_dir=decode_${rtask}_decode_${lmtag} 217 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta} 218 | 219 | # # split data 220 | if "${need_decode}"; then 221 | splitjson.py --parts ${nj} ${feat_recog_dir}/data_${bpemode}${nbpe}.json 222 | fi 223 | orig_dir=${expdir}/decode_test-orig_decode_${lmtag} 224 | 225 | #### use CPU for decoding 226 | ngpu=0 227 | if "${need_decode}"; then 228 | ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \ 229 | asr_test.py \ 230 | --config ${decode_config} \ 231 | --ngpu ${ngpu} \ 232 | --backend ${backend} \ 233 | --debugmode ${debugmode} \ 234 | --verbose ${verbose} \ 235 | --recog-json ${feat_recog_dir}/split32utt/data_${bpemode}${nbpe}.JOB.json \ 236 | --result-label ${expdir}/${decode_dir}/data.JOB.json \ 237 | --model ${recog_model} \ 238 | --api ${api} \ 239 | --orig_dir ${orig_dir} \ 240 | --need_decode ${need_decode} \ 241 | --orig_flag ${orig_flag} \ 242 | --recog_set ${recog_set} \ 243 | ${recog_opts} 244 | 245 | else 246 | asr_test.py \ 247 | --config ${decode_config} \ 248 | --ngpu ${ngpu} \ 249 | --backend ${backend} \ 250 | --debugmode ${debugmode} \ 251 | --verbose ${verbose} \ 252 | --recog-json ${feat_recog_dir}/split${nj}utt/data_${bpemode}${nbpe}.JOB.json \ 253 | --result-label ${expdir}/${decode_dir}/data.JOB.json \ 254 | --model ${recog_model} \ 255 | --api ${api} \ 256 | --orig_dir ${orig_dir} \ 257 | --need_decode ${need_decode} \ 258 | --orig_flag ${orig_flag} \ 259 | --recog_set ${recog_set} \ 260 | ${recog_opts} 261 | fi 262 | 263 | ) & 264 | pids+=($!) # store background pids 265 | done 266 | i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done 267 | [ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false 268 | 269 | fi 270 | 271 | 272 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 273 | echo "stage 4: Scoring" 274 | decode_dir=decode_${recog_set}_decode_${lmtag} 275 | if "${need_decode}"; then 276 | score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true --need_decode ${need_decode} --guide_type "gini" ${expdir}/${decode_dir} ${dict} 277 | else 278 | score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true --need_decode ${need_decode} --guide_type "gini" ${expdir}/${decode_dir} ${dict} 279 | fi 280 | echo "Finished" 281 | fi 282 | -------------------------------------------------------------------------------- /tedlium2/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2017 Johns Hopkins University (Shinji Watanabe) 4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | . ./path.sh || exit 1; 7 | . 
./cmd.sh || exit 1; 8 | 9 | # general configuration 10 | backend=pytorch 11 | stage=3 # start from -1 if you need to start from data download 12 | stop_stage=4 13 | ngpu=0 # number of gpus ("0" uses cpu, otherwise use gpu) 14 | debugmode=1 15 | dumpdir=dump # directory to dump full features 16 | N=0 # number of minibatches to be used (mainly for debugging). "0" uses all minibatches. 17 | verbose=1 # verbose option 18 | resume= # Resume the training from snapshot 19 | 20 | # feature configuration 21 | do_delta=false 22 | cmvn= 23 | 24 | preprocess_config=conf/specaug.yaml 25 | train_config= 26 | lm_config= 27 | 28 | # rnnlm related 29 | skip_lm_training=true # for only using end-to-end ASR model without LM 30 | lm_resume= # specify a snapshot file to resume LM training 31 | lmtag= # tag for managing LMs 32 | use_lang_model=false 33 | lang_model= 34 | 35 | # test related 36 | models=tedlium2.transformer.v1 37 | # decoding parameter 38 | p=0.005 39 | recog_model= 40 | recog_dir= 41 | decode_config= 42 | decode_dir=decode 43 | api=v2 44 | 45 | 46 | # model average realted (only for transformer) 47 | # n_average=10 # the number of ASR models to be averaged 48 | # use_valbest_average=true # if true, the validation `n_average`-best ASR models will be averaged. 49 | # # if false, the last `n_average` ASR models will be averaged. 50 | 51 | # bpemode (unigram or bpe) 52 | nbpe=500 53 | bpemode=unigram 54 | 55 | # exp tag 56 | tag="" # tag for managing experiments. 57 | recog_set=test-noise 58 | 59 | # gini related 60 | orig_flag=false 61 | orig_dir= 62 | need_decode=false 63 | 64 | . utils/parse_options.sh || exit 1; 65 | 66 | # Set bash to 'debug' mode, it will exit on : 67 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', 68 | set -e 69 | set -u 70 | set -o pipefail 71 | 72 | # train_set=train_trim_sp 73 | # train_dev=dev_trim 74 | 75 | download_dir=${decode_dir}/download 76 | 77 | if [ "${api}" = "v2" ] && [ "${backend}" = "chainer" ]; then 78 | echo "chainer backend does not support api v2." >&2 79 | exit 1; 80 | fi 81 | 82 | if [ -z $models ]; then 83 | if [ $use_lang_model = "true" ]; then 84 | if [[ -z $cmvn || -z $lang_model || -z $recog_model || -z $decode_config ]]; then 85 | echo 'Error: models or set of cmvn, lang_model, recog_model and decode_config are required.' >&2 86 | exit 1 87 | fi 88 | else 89 | if [[ -z $cmvn || -z $recog_model || -z $decode_config ]]; then 90 | echo 'Error: models or set of cmvn, recog_model and decode_config are required.' 
>&2 91 | exit 1 92 | fi 93 | fi 94 | fi 95 | 96 | dir=${download_dir}/${models} 97 | mkdir -p ${dir} 98 | 99 | 100 | function download_models () { 101 | if [ -z $models ]; then 102 | return 103 | fi 104 | 105 | file_ext="tar.gz" 106 | case "${models}" in 107 | "tedlium2.rnn.v1") share_url="https://drive.google.com/open?id=1UqIY6WJMZ4sxNxSugUqp3mrGb3j6h7xe"; api=v1 ;; 108 | "tedlium2.rnn.v2") share_url="https://drive.google.com/open?id=1cac5Uc09lJrCYfWkLQsF8eapQcxZnYdf"; api=v1 ;; 109 | "tedlium2.transformer.v1") share_url="https://drive.google.com/open?id=1cVeSOYY1twOfL9Gns7Z3ZDnkrJqNwPow" ;; 110 | "tedlium3.transformer.v1") share_url="https://drive.google.com/open?id=1zcPglHAKILwVgfACoMWWERiyIquzSYuU" ;; 111 | "librispeech.transformer.v1") share_url="https://drive.google.com/open?id=1BtQvAnsFvVi-dp_qsaFP7n4A_5cwnlR6" ;; 112 | "librispeech.transformer.v1.transformerlm.v1") share_url="https://drive.google.com/open?id=17cOOSHHMKI82e1MXj4r2ig8gpGCRmG2p" ;; 113 | "commonvoice.transformer.v1") share_url="https://drive.google.com/open?id=1tWccl6aYU67kbtkm8jv5H6xayqg1rzjh" ;; 114 | "csj.transformer.v1") share_url="https://drive.google.com/open?id=120nUQcSsKeY5dpyMWw_kI33ooMRGT2uF" ;; 115 | *) echo "No such models: ${models}"; exit 1 ;; 116 | esac 117 | 118 | if [ ! -e ${dir}/.complete ]; then 119 | download_from_google_drive.sh ${share_url} ${dir} ${file_ext} 120 | touch ${dir}/.complete 121 | fi 122 | } 123 | 124 | # Download trained models 125 | if [ -z "${cmvn}" ]; then 126 | #download_models 127 | cmvn=$(find ${download_dir}/${models} -name "cmvn.ark" | head -n 1) 128 | fi 129 | if [ -z "${lang_model}" ] && ${use_lang_model}; then 130 | #download_models 131 | lang_model=$(find ${download_dir}/${models} -name "rnnlm*.best*" | head -n 1) 132 | fi 133 | if [ -z "${recog_model}" ]; then 134 | #download_models 135 | if [ -z "${recog_dir}" ]; then 136 | recog_model=$(find ${download_dir}/${models} -name "model*.best*" | head -n 1) 137 | else 138 | recog_model=$(find "${recog_dir}/results" -name "model.acc.best" | head -n 1) 139 | fi 140 | echo "recog_model is ${recog_model}" 141 | fi 142 | if [ -z "${decode_config}" ]; then 143 | #download_models 144 | decode_config=$(find ${download_dir}/${models} -name "decode*.yaml" | head -n 1) 145 | fi 146 | 147 | 148 | # Check file existence 149 | if [ ! -f "${cmvn}" ]; then 150 | echo "No such CMVN file: ${cmvn}" 151 | exit 1 152 | fi 153 | if [ ! -f "${lang_model}" ] && ${use_lang_model}; then 154 | echo "No such language model: ${lang_model}" 155 | exit 1 156 | fi 157 | if [ ! -f "${recog_model}" ]; then 158 | echo "No such E2E model: ${recog_model}" 159 | exit 1 160 | fi 161 | if [ ! -f "${decode_config}" ]; then 162 | echo "No such config file: ${decode_config}" 163 | exit 1 164 | fi 165 | 166 | echo "stage ${stage}" 167 | 168 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 169 | ### Task dependent. You have to make data the following preparation part by yourself. 170 | ### But you can utilize Kaldi recipes in most cases 171 | echo "stage 0: Data preparation" 172 | if [ -z "${recog_dir}" ]; then 173 | local/prepare_test_data.sh ${recog_set} 174 | fi 175 | for dset in ${recog_set}; do 176 | utils/fix_data_dir.sh data/${dset}.orig 177 | utils/data/modify_speaker_info.sh --seconds-per-spk-max 180 data/${dset}.orig data/${dset} 178 | done 179 | fi 180 | 181 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 182 | ### Task dependent. You have to design training and dev sets by yourself. 
183 | ### But you can utilize Kaldi recipes in most cases 184 | echo "stage 1: Feature Generation" 185 | fbankdir=fbank 186 | 187 | # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame 188 | for x in ${recog_set}; do 189 | #utils/fix_data_dir.sh data/${x} 190 | steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj 32 --write_utt2num_frames true \ 191 | data/${x} exp/make_fbank/${x} ${fbankdir} 192 | utils/fix_data_dir.sh data/${x} 193 | done 194 | 195 | for rtask in ${recog_set}; do 196 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}; mkdir -p ${feat_recog_dir} 197 | dump.sh --cmd "$train_cmd" --nj 32 --do_delta ${do_delta} \ 198 | data/${rtask}/feats.scp ${cmvn} exp/dump_feats/recog/${rtask} \ 199 | ${feat_recog_dir} 200 | done 201 | fi 202 | 203 | dict=data/lang_char/train_trim_sp_${bpemode}${nbpe}_units.txt 204 | bpemodel=data/lang_char/train_trim_sp_${bpemode}${nbpe} 205 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 206 | ### Task dependent. You have to check non-linguistic symbols used in the corpus. 207 | echo "stage 2: Dictionary and Json Data Preparation" 208 | 209 | for rtask in ${recog_set}; do 210 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta} 211 | data2json.sh --feat ${feat_recog_dir}/feats.scp --bpecode ${bpemodel}.model\ 212 | data/${rtask} ${dict} > ${feat_recog_dir}/data_${bpemode}${nbpe}.json 213 | done 214 | fi 215 | 216 | if [ -z "${recog_dir}" ]; then 217 | expname=${models} 218 | expdir=exp/${expname} 219 | mkdir -p ${expdir} 220 | else 221 | expdir=${recog_dir} 222 | fi 223 | 224 | echo "${expdir}" 225 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 226 | echo "stage 3: Decoding" 227 | nj=32 228 | if ${use_lang_model}; then 229 | recog_opts="--rnnlm ${lang_model}" 230 | else 231 | recog_opts="" 232 | fi 233 | #feat_recog_dir=${decode_dir}/dump 234 | pids=() # initialize pids 235 | #trap 'rm -rf data/"${recog_set}" data/"${recog_set}.orig"' EXIT 236 | for rtask in ${recog_set}; do 237 | ( 238 | decode_dir=decode_${rtask}_decode_${lmtag} 239 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta} 240 | 241 | # # split data 242 | if "${need_decode}"; then 243 | splitjson.py --parts ${nj} ${feat_recog_dir}/data_${bpemode}${nbpe}.json 244 | fi 245 | orig_dir=${expdir}/decode_test-orig_decode_${lmtag} 246 | 247 | #### use CPU for decoding 248 | ngpu=0 249 | if "${need_decode}"; then 250 | ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \ 251 | asr_test.py \ 252 | --config ${decode_config} \ 253 | --ngpu ${ngpu} \ 254 | --backend ${backend} \ 255 | --debugmode ${debugmode} \ 256 | --verbose ${verbose} \ 257 | --recog-json ${feat_recog_dir}/split32utt/data_${bpemode}${nbpe}.JOB.json \ 258 | --result-label ${expdir}/${decode_dir}/data.JOB.json \ 259 | --model ${recog_model} \ 260 | --api ${api} \ 261 | --orig_dir ${orig_dir} \ 262 | --need_decode ${need_decode} \ 263 | --orig_flag ${orig_flag} \ 264 | --recog_set ${recog_set} \ 265 | ${recog_opts} 266 | 267 | else 268 | asr_test.py \ 269 | --config ${decode_config} \ 270 | --ngpu ${ngpu} \ 271 | --backend ${backend} \ 272 | --debugmode ${debugmode} \ 273 | --verbose ${verbose} \ 274 | --recog-json ${feat_recog_dir}/split${nj}utt/data_${bpemode}${nbpe}.JOB.json \ 275 | --result-label ${expdir}/${decode_dir}/data.JOB.json \ 276 | --model ${recog_model} \ 277 | --api ${api} \ 278 | --orig_dir ${orig_dir} \ 279 | --need_decode ${need_decode} \ 280 | --orig_flag ${orig_flag} \ 281 | --recog_set ${recog_set} \ 282 | ${recog_opts} 283 | fi 284 | 285 
| ) & 286 | pids+=($!) # store background pids 287 | done 288 | i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done 289 | [ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false 290 | 291 | fi 292 | 293 | 294 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 295 | echo "stage 4: Scoring" 296 | decode_dir=decode_${recog_set}_decode_${lmtag} 297 | if "${need_decode}"; then 298 | score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true --need_decode ${need_decode} --guide_type "gini" ${expdir}/${decode_dir} ${dict} 299 | else 300 | score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true --need_decode ${need_decode} --guide_type "gini" ${expdir}/${decode_dir} ${dict} 301 | fi 302 | echo "Finished" 303 | fi 304 | 305 | 306 | -------------------------------------------------------------------------------- /espnet2/asr/decoder/rnn_decoder.py: -------------------------------------------------------------------------------- 1 | import random 2 | import logging 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from typeguard import check_argument_types 7 | 8 | from espnet.nets.pytorch_backend.nets_utils import make_pad_mask 9 | from espnet.nets.pytorch_backend.nets_utils import to_device 10 | from espnet.nets.pytorch_backend.rnn.attentions import initial_att 11 | from espnet2.asr.decoder.abs_decoder import AbsDecoder 12 | from espnet2.utils.get_default_kwargs import get_default_kwargs 13 | 14 | 15 | def build_attention_list( 16 | eprojs: int, 17 | dunits: int, 18 | atype: str = "location", 19 | num_att: int = 1, 20 | num_encs: int = 1, 21 | aheads: int = 4, 22 | adim: int = 320, 23 | awin: int = 5, 24 | aconv_chans: int = 10, 25 | aconv_filts: int = 100, 26 | han_mode: bool = False, 27 | han_type=None, 28 | han_heads: int = 4, 29 | han_dim: int = 320, 30 | han_conv_chans: int = -1, 31 | han_conv_filts: int = 100, 32 | han_win: int = 5, 33 | ): 34 | 35 | att_list = torch.nn.ModuleList() 36 | if num_encs == 1: 37 | for i in range(num_att): 38 | att = initial_att( 39 | atype, 40 | eprojs, 41 | dunits, 42 | aheads, 43 | adim, 44 | awin, 45 | aconv_chans, 46 | aconv_filts, 47 | ) 48 | att_list.append(att) 49 | elif num_encs > 1: # no multi-speaker mode 50 | if han_mode: 51 | att = initial_att( 52 | han_type, 53 | eprojs, 54 | dunits, 55 | han_heads, 56 | han_dim, 57 | han_win, 58 | han_conv_chans, 59 | han_conv_filts, 60 | han_mode=True, 61 | ) 62 | return att 63 | else: 64 | att_list = torch.nn.ModuleList() 65 | for idx in range(num_encs): 66 | att = initial_att( 67 | atype[idx], 68 | eprojs, 69 | dunits, 70 | aheads[idx], 71 | adim[idx], 72 | awin[idx], 73 | aconv_chans[idx], 74 | aconv_filts[idx], 75 | ) 76 | att_list.append(att) 77 | else: 78 | raise ValueError( 79 | "Number of encoders needs to be more than one. 
{}".format(num_encs) 80 | ) 81 | return att_list 82 | 83 | 84 | 85 | class RNNDecoder(AbsDecoder): 86 | def __init__( 87 | self, 88 | vocab_size: int, 89 | encoder_output_size: int, 90 | rnn_type: str = "lstm", 91 | num_layers: int = 1, 92 | hidden_size: int = 320, 93 | sampling_probability: float = 0.0, 94 | dropout: float = 0.0, 95 | context_residual: bool = False, 96 | replace_sos: bool = False, 97 | num_encs: int = 1, 98 | att_conf: dict = get_default_kwargs(build_attention_list), 99 | ): 100 | # FIXME(kamo): The parts of num_spk should be refactored more more more 101 | assert check_argument_types() 102 | if rnn_type not in {"lstm", "gru"}: 103 | raise ValueError(f"Not supported: rnn_type={rnn_type}") 104 | 105 | super().__init__() 106 | eprojs = encoder_output_size 107 | self.dtype = rnn_type 108 | self.dunits = hidden_size 109 | self.dlayers = num_layers 110 | self.context_residual = context_residual 111 | self.sos = vocab_size - 1 112 | self.eos = vocab_size - 1 113 | self.odim = vocab_size 114 | self.sampling_probability = sampling_probability 115 | self.dropout = dropout 116 | self.num_encs = num_encs 117 | 118 | # for multilingual translation 119 | self.replace_sos = replace_sos 120 | 121 | self.embed = torch.nn.Embedding(vocab_size, hidden_size) 122 | self.dropout_emb = torch.nn.Dropout(p=dropout) 123 | 124 | self.decoder = torch.nn.ModuleList() 125 | self.dropout_dec = torch.nn.ModuleList() 126 | self.decoder += [ 127 | torch.nn.LSTMCell(hidden_size + eprojs, hidden_size) 128 | if self.dtype == "lstm" 129 | else torch.nn.GRUCell(hidden_size + eprojs, hidden_size) 130 | ] 131 | 132 | self.dropout_dec += [torch.nn.Dropout(p=dropout)] 133 | for _ in range(1, self.dlayers): 134 | self.decoder += [ 135 | torch.nn.LSTMCell(hidden_size, hidden_size) 136 | if self.dtype == "lstm" 137 | else torch.nn.GRUCell(hidden_size, hidden_size) 138 | ] 139 | self.dropout_dec += [torch.nn.Dropout(p=dropout)] 140 | # NOTE: dropout is applied only for the vertical connections 141 | # see https://arxiv.org/pdf/1409.2329.pdf 142 | if context_residual: 143 | self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size) 144 | else: 145 | self.output = torch.nn.Linear(hidden_size, vocab_size) 146 | 147 | self.att_list = build_attention_list( 148 | eprojs=eprojs, dunits=hidden_size, **att_conf 149 | ) 150 | 151 | def zero_state(self, hs_pad): 152 | return hs_pad.new_zeros(hs_pad.size(0), self.dunits) 153 | 154 | def caul_gini(self, softmax_value): 155 | gini = 0 156 | class_softmax = softmax_value.cpu().detach().numpy() 157 | for i in range(0, len(class_softmax)): 158 | gini += np.square(class_softmax[i]) 159 | return 1-gini 160 | 161 | def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev): 162 | if self.dtype == "lstm": 163 | z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0])) 164 | for i in range(1, self.dlayers): 165 | z_list[i], c_list[i] = self.decoder[i]( 166 | self.dropout_dec[i - 1](z_list[i - 1]), 167 | (z_prev[i], c_prev[i]), 168 | ) 169 | 170 | else: 171 | z_list[0] = self.decoder[0](ey, z_prev[0]) 172 | for i in range(1, self.dlayers): 173 | z_list[i] = self.decoder[i]( 174 | self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i] 175 | ) 176 | return z_list, c_list 177 | 178 | def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0): 179 | # to support mutiple encoder asr mode, in single encoder mode, 180 | # convert torch.Tensor to List of torch.Tensor 181 | if self.num_encs == 1: 182 | hs_pad = [hs_pad] 183 | hlens = [hlens] 184 | 185 | # attention index for the 
attention module 186 | # in SPA (speaker parallel attention), 187 | # att_idx is used to select attention module. In other cases, it is 0. 188 | att_idx = min(strm_idx, len(self.att_list) - 1) 189 | 190 | # hlens should be list of list of integer 191 | hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)] 192 | 193 | # get dim, length info 194 | 195 | olength = ys_in_pad.size(1) 196 | # initialization 197 | c_list = [self.zero_state(hs_pad[0])] 198 | z_list = [self.zero_state(hs_pad[0])] 199 | for _ in range(1, self.dlayers): 200 | c_list.append(self.zero_state(hs_pad[0])) 201 | z_list.append(self.zero_state(hs_pad[0])) 202 | z_all = [] 203 | if self.num_encs == 1: 204 | att_w = None 205 | self.att_list[att_idx].reset() # reset pre-computation of h 206 | else: 207 | att_w_list = [None] * (self.num_encs + 1) # atts + han 208 | att_c_list = [None] * self.num_encs # atts 209 | for idx in range(self.num_encs + 1): 210 | # reset pre-computation of h in atts and han 211 | self.att_list[idx].reset() 212 | 213 | # pre-computation of embedding 214 | eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim 215 | 216 | # loop for an output sequence 217 | # logging.info(f"olength is {olength}") 218 | for i in range(olength): 219 | if self.num_encs == 1: 220 | att_c, att_w = self.att_list[att_idx]( 221 | hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w 222 | ) 223 | else: 224 | for idx in range(self.num_encs): 225 | att_c_list[idx], att_w_list[idx] = self.att_list[idx]( 226 | hs_pad[idx], 227 | hlens[idx], 228 | self.dropout_dec[0](z_list[0]), 229 | att_w_list[idx], 230 | ) 231 | hs_pad_han = torch.stack(att_c_list, dim=1) 232 | hlens_han = [self.num_encs] * len(ys_in_pad) 233 | att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs]( 234 | hs_pad_han, 235 | hlens_han, 236 | self.dropout_dec[0](z_list[0]), 237 | att_w_list[self.num_encs], 238 | ) 239 | if i > 0 and random.random() < self.sampling_probability: 240 | z_out = self.output(z_all[-1]) 241 | z_out = np.argmax(z_out.detach().cpu(), axis=1) 242 | z_out = self.dropout_emb(self.embed(to_device(self, z_out))) 243 | ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim) 244 | else: 245 | # utt x (zdim + hdim) 246 | ey = torch.cat((eys[:, i, :], att_c), dim=1) 247 | z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) 248 | if self.context_residual: 249 | z_all.append( 250 | torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1) 251 | ) # utt x (zdim + hdim) 252 | else: 253 | z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim) 254 | 255 | z_all = torch.stack(z_all, dim=1) 256 | z_all = self.output(z_all) 257 | z_all.masked_fill_( 258 | make_pad_mask(ys_in_lens, z_all, 1), 259 | 0, 260 | ) 261 | return z_all, ys_in_lens 262 | 263 | def init_state(self, x): 264 | # to support mutiple encoder asr mode, in single encoder mode, 265 | # convert torch.Tensor to List of torch.Tensor 266 | if self.num_encs == 1: 267 | x = [x] 268 | 269 | c_list = [self.zero_state(x[0].unsqueeze(0))] 270 | z_list = [self.zero_state(x[0].unsqueeze(0))] 271 | for _ in range(1, self.dlayers): 272 | c_list.append(self.zero_state(x[0].unsqueeze(0))) 273 | z_list.append(self.zero_state(x[0].unsqueeze(0))) 274 | # TODO(karita): support strm_index for `asr_mix` 275 | strm_index = 0 276 | att_idx = min(strm_index, len(self.att_list) - 1) 277 | if self.num_encs == 1: 278 | a = None 279 | self.att_list[att_idx].reset() # reset pre-computation of h 280 | else: 281 | a = [None] * (self.num_encs + 1) # atts + 
han 282 | for idx in range(self.num_encs + 1): 283 | # reset pre-computation of h in atts and han 284 | self.att_list[idx].reset() 285 | return dict( 286 | c_prev=c_list[:], 287 | z_prev=z_list[:], 288 | a_prev=a, 289 | workspace=(att_idx, z_list, c_list), 290 | ) 291 | 292 | def score(self, yseq, state, x): 293 | # to support mutiple encoder asr mode, in single encoder mode, 294 | # convert torch.Tensor to List of torch.Tensor 295 | #print("call rnn decoder score:", yseq, state, x) 296 | if self.num_encs == 1: 297 | x = [x] 298 | 299 | att_idx, z_list, c_list = state["workspace"] 300 | vy = yseq[-1].unsqueeze(0) 301 | ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim 302 | if self.num_encs == 1: 303 | att_c, att_w = self.att_list[att_idx]( 304 | x[0].unsqueeze(0), 305 | [x[0].size(0)], 306 | self.dropout_dec[0](state["z_prev"][0]), 307 | state["a_prev"], 308 | ) 309 | else: 310 | att_w = [None] * (self.num_encs + 1) # atts + han 311 | att_c_list = [None] * self.num_encs # atts 312 | for idx in range(self.num_encs): 313 | att_c_list[idx], att_w[idx] = self.att_list[idx]( 314 | x[idx].unsqueeze(0), 315 | [x[idx].size(0)], 316 | self.dropout_dec[0](state["z_prev"][0]), 317 | state["a_prev"][idx], 318 | ) 319 | h_han = torch.stack(att_c_list, dim=1) 320 | att_c, att_w[self.num_encs] = self.att_list[self.num_encs]( 321 | h_han, 322 | [self.num_encs], 323 | self.dropout_dec[0](state["z_prev"][0]), 324 | state["a_prev"][self.num_encs], 325 | ) 326 | ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim) 327 | z_list, c_list = self.rnn_forward( 328 | ey, z_list, c_list, state["z_prev"], state["c_prev"] 329 | ) 330 | if self.context_residual: 331 | logits = self.output( 332 | torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1) 333 | ) 334 | else: 335 | logits = self.output(self.dropout_dec[-1](z_list[-1])) 336 | 337 | logp = F.log_softmax(logits, dim=1).squeeze(0) 338 | outp = F.softmax(logits, dim=1).squeeze(0) 339 | c=[o for o in outp.numpy()] 340 | sum_output = 0 341 | for o in c: 342 | sum_output += o 343 | # logging.info(f"outp is {outp}, sum is {sum_output}") 344 | # logging.info(f"z_list is {z_list}") 345 | return ( 346 | logp, 347 | dict( 348 | c_prev=c_list[:], 349 | z_prev=z_list[:], 350 | a_prev=att_w, 351 | workspace=(att_idx, z_list, c_list), 352 | ), 353 | outp, 354 | z_list, 355 | ) 356 | -------------------------------------------------------------------------------- /espnet/bin/asr_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | 4 | # Copyright 2017 Johns Hopkins University (Shinji Watanabe) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """End-to-end speech recognition model decoding script.""" 8 | 9 | import configargparse 10 | import logging 11 | import os 12 | import random 13 | import sys 14 | 15 | import numpy as np 16 | 17 | from espnet.utils.cli_utils import strtobool 18 | from espnet.utils.gini_guide import * 19 | from espnet.utils.gini_utils import * 20 | from espnet.utils.retrain_utils import * 21 | 22 | # NOTE: you need this func to generate our sphinx doc 23 | 24 | 25 | def get_parser(): 26 | """Get default arguments.""" 27 | parser = configargparse.ArgumentParser( 28 | description="Transcribe text from speech using " 29 | "a speech recognition model on one CPU or GPU", 30 | config_file_parser_class=configargparse.YAMLConfigFileParser, 31 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter, 32 | ) 33 | # general 
configuration 34 | parser.add("--config", is_config_file=True, help="Config file path") 35 | parser.add( 36 | "--config2", 37 | is_config_file=True, 38 | help="Second config file path that overwrites the settings in `--config`", 39 | ) 40 | parser.add( 41 | "--config3", 42 | is_config_file=True, 43 | help="Third config file path that overwrites the settings " 44 | "in `--config` and `--config2`", 45 | ) 46 | 47 | parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs") 48 | parser.add_argument( 49 | "--dtype", 50 | choices=("float16", "float32", "float64"), 51 | default="float32", 52 | help="Float precision (only available in --api v2)", 53 | ) 54 | parser.add_argument( 55 | "--backend", 56 | type=str, 57 | default="chainer", 58 | choices=["chainer", "pytorch"], 59 | help="Backend library", 60 | ) 61 | parser.add_argument("--debugmode", type=int, default=1, help="Debugmode") 62 | parser.add_argument("--seed", type=int, default=1, help="Random seed") 63 | parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option") 64 | parser.add_argument( 65 | "--batchsize", 66 | type=int, 67 | default=1, 68 | help="Batch size for beam search (0: means no batch processing)", 69 | ) 70 | parser.add_argument( 71 | "--preprocess-conf", 72 | type=str, 73 | default=None, 74 | help="The configuration file for the pre-processing", 75 | ) 76 | parser.add_argument( 77 | "--api", 78 | default="v1", 79 | choices=["v1", "v2"], 80 | help="Beam search APIs " 81 | "v1: Default API. It only supports the ASRInterface.recognize method " 82 | "and DefaultRNNLM. " 83 | "v2: Experimental API. It supports any models that implements ScorerInterface.", 84 | ) 85 | # task related 86 | parser.add_argument( 87 | "--recog-json", type=str, help="Filename of recognition data (json)" 88 | ) 89 | parser.add_argument( 90 | "--result-label", 91 | type=str, 92 | required=True, 93 | help="Filename of result label data (json)", 94 | ) 95 | # model (parameter) related 96 | parser.add_argument( 97 | "--model", type=str, required=True, help="Model file parameters to read" 98 | ) 99 | parser.add_argument( 100 | "--model-conf", type=str, default=None, help="Model config file" 101 | ) 102 | parser.add_argument( 103 | "--num-spkrs", 104 | type=int, 105 | default=1, 106 | choices=[1, 2], 107 | help="Number of speakers in the speech", 108 | ) 109 | parser.add_argument( 110 | "--num-encs", default=1, type=int, help="Number of encoders in the model." 111 | ) 112 | # search related 113 | parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") 114 | parser.add_argument("--beam-size", type=int, default=1, help="Beam size") 115 | parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty") 116 | parser.add_argument( 117 | "--maxlenratio", 118 | type=float, 119 | default=0.0, 120 | help="""Input length ratio to obtain max output length. 121 | If maxlenratio=0.0 (default), it uses a end-detect function 122 | to automatically find maximum hypothesis lengths""", 123 | ) 124 | parser.add_argument( 125 | "--minlenratio", 126 | type=float, 127 | default=0.0, 128 | help="Input length ratio to obtain min output length", 129 | ) 130 | parser.add_argument( 131 | "--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding" 132 | ) 133 | parser.add_argument( 134 | "--weights-ctc-dec", 135 | type=float, 136 | action="append", 137 | help="ctc weight assigned to each encoder during decoding." 
138 | "[in multi-encoder mode only]", 139 | ) 140 | parser.add_argument( 141 | "--ctc-window-margin", 142 | type=int, 143 | default=0, 144 | help="""Use CTC window with margin parameter to accelerate 145 | CTC/attention decoding especially on GPU. Smaller magin 146 | makes decoding faster, but may increase search errors. 147 | If margin=0 (default), this function is disabled""", 148 | ) 149 | # transducer related 150 | parser.add_argument( 151 | "--search-type", 152 | type=str, 153 | default="default", 154 | choices=["default", "nsc", "tsd", "alsd"], 155 | help="""Type of beam search implementation to use during inference. 156 | Can be either: default beam search, n-step constrained beam search ("nsc"), 157 | time-synchronous decoding ("tsd") or alignment-length synchronous decoding 158 | ("alsd"). 159 | Additional associated parameters: "nstep" + "prefix-alpha" (for nsc), 160 | "max-sym-exp" (for tsd) and "u-max" (for alsd)""", 161 | ) 162 | parser.add_argument( 163 | "--nstep", 164 | type=int, 165 | default=1, 166 | help="Number of expansion steps allowed in NSC beam search.", 167 | ) 168 | parser.add_argument( 169 | "--prefix-alpha", 170 | type=int, 171 | default=2, 172 | help="Length prefix difference allowed in NSC beam search.", 173 | ) 174 | parser.add_argument( 175 | "--max-sym-exp", 176 | type=int, 177 | default=2, 178 | help="Number of symbol expansions allowed in TSD decoding.", 179 | ) 180 | parser.add_argument( 181 | "--u-max", 182 | type=int, 183 | default=400, 184 | help="Length prefix difference allowed in ALSD beam search.", 185 | ) 186 | parser.add_argument( 187 | "--score-norm", 188 | type=strtobool, 189 | nargs="?", 190 | default=True, 191 | help="Normalize transducer scores by length", 192 | ) 193 | # rnnlm related 194 | parser.add_argument( 195 | "--rnnlm", type=str, default=None, help="RNNLM model file to read" 196 | ) 197 | parser.add_argument( 198 | "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read" 199 | ) 200 | parser.add_argument( 201 | "--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read" 202 | ) 203 | parser.add_argument( 204 | "--word-rnnlm-conf", 205 | type=str, 206 | default=None, 207 | help="Word RNNLM model config file to read", 208 | ) 209 | parser.add_argument("--word-dict", type=str, default=None, help="Word list to read") 210 | parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight") 211 | # ngram related 212 | parser.add_argument( 213 | "--ngram-model", type=str, default=None, help="ngram model file to read" 214 | ) 215 | parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight") 216 | parser.add_argument( 217 | "--ngram-scorer", 218 | type=str, 219 | default="part", 220 | choices=("full", "part"), 221 | help="""if the ngram is set as a part scorer, similar with CTC scorer, 222 | ngram scorer only scores topK hypethesis. 223 | if the ngram is set as full scorer, ngram scorer scores all hypthesis 224 | the decoding speed of part scorer is musch faster than full one""", 225 | ) 226 | # streaming related 227 | parser.add_argument( 228 | "--streaming-mode", 229 | type=str, 230 | default=None, 231 | choices=["window", "segment"], 232 | help="""Use streaming recognizer for inference. 
233 | `--batchsize` must be set to 0 to enable this mode""", 234 | ) 235 | parser.add_argument("--streaming-window", type=int, default=10, help="Window size") 236 | parser.add_argument( 237 | "--streaming-min-blank-dur", 238 | type=int, 239 | default=10, 240 | help="Minimum blank duration threshold", 241 | ) 242 | parser.add_argument( 243 | "--streaming-onset-margin", type=int, default=1, help="Onset margin" 244 | ) 245 | parser.add_argument( 246 | "--streaming-offset-margin", type=int, default=1, help="Offset margin" 247 | ) 248 | # non-autoregressive related 249 | # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail. 250 | parser.add_argument( 251 | "--maskctc-n-iterations", 252 | type=int, 253 | default=10, 254 | help="Number of decoding iterations." 255 | "For Mask CTC, set 0 to predict 1 mask/iter.", 256 | ) 257 | parser.add_argument( 258 | "--maskctc-probability-threshold", 259 | type=float, 260 | default=0.999, 261 | help="Threshold probability for CTC output", 262 | ) 263 | parser.add_argument("--orig_flag", type=str, default="False") 264 | parser.add_argument("--orig_dir", type=str, required=True) 265 | parser.add_argument( 266 | "--recog_set", 267 | type=str, 268 | help="Whether use gini or not", 269 | ) 270 | parser.add_argument( 271 | "--need_decode", 272 | type=str, 273 | default="True", 274 | ) 275 | 276 | return parser 277 | 278 | 279 | def str_2_bool(s): 280 | return True if s.lower() =='true' else False 281 | 282 | def main(args): 283 | """Run the main decoding function.""" 284 | parser = get_parser() 285 | args = parser.parse_args(args) 286 | orig_flag = str_2_bool(args.orig_flag) 287 | need_decode = str_2_bool(args.need_decode) 288 | if args.ngpu == 0 and args.dtype == "float16": 289 | raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.") 290 | 291 | # logging info 292 | if args.verbose == 1: 293 | logging.basicConfig( 294 | level=logging.INFO, 295 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 296 | ) 297 | elif args.verbose == 2: 298 | logging.basicConfig( 299 | level=logging.DEBUG, 300 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 301 | ) 302 | else: 303 | logging.basicConfig( 304 | level=logging.WARN, 305 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 306 | ) 307 | logging.warning("Skip DEBUG/INFO messages") 308 | 309 | # check CUDA_VISIBLE_DEVICES 310 | if args.ngpu > 0: 311 | cvd = os.environ.get("CUDA_VISIBLE_DEVICES") 312 | if cvd is None: 313 | logging.warning("CUDA_VISIBLE_DEVICES is not set.") 314 | elif args.ngpu != len(cvd.split(",")): 315 | logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") 316 | sys.exit(1) 317 | 318 | # TODO(mn5k): support of multiple GPUs 319 | if args.ngpu > 1: 320 | logging.error("The program only supports ngpu=1.") 321 | sys.exit(1) 322 | 323 | # display PYTHONPATH 324 | logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) 325 | 326 | # seed setting 327 | random.seed(args.seed) 328 | np.random.seed(args.seed) 329 | logging.info("set random seed = %d" % args.seed) 330 | 331 | # validate rnn options 332 | if args.rnnlm is not None and args.word_rnnlm is not None: 333 | logging.error( 334 | "It seems that both --rnnlm and --word-rnnlm are specified. " 335 | "Please use either option." 
336 | ) 337 | sys.exit(1) 338 | 339 | # recog 340 | logging.info("backend = " + args.backend) 341 | if args.num_spkrs == 1: 342 | if args.backend == "chainer": 343 | from espnet.asr.chainer_backend.asr import recog 344 | recog(args) 345 | elif args.backend == "pytorch": 346 | if args.num_encs == 1: 347 | # Experimental API that supports custom LMs 348 | if args.api == "v2": 349 | if need_decode: 350 | from espnet.asr.pytorch_backend.recog_test import recog_v2 351 | recog_v2(args, orig_flag) 352 | else: 353 | new_dir = "/".join(args.result_label.split("/")[0:-1]) 354 | orig_dir = new_dir.replace("new", "orig") 355 | aug_type = args.recog_set.split("-")[-1] 356 | gini_audio = sort_test_by_gini_v1(args.orig_dir, new_dir, aug_type) 357 | test_data_prep("data/test-orig", "data/" + args.recog_set, "ted2", gini_audio, "gini", 1155, args.recog_set) 358 | 359 | else: 360 | from espnet.asr.pytorch_backend.asr import recog 361 | 362 | if args.dtype != "float32": 363 | raise NotImplementedError( 364 | f"`--dtype {args.dtype}` is only available with `--api v2`" 365 | ) 366 | recog(args) 367 | else: 368 | if args.api == "v2": 369 | raise NotImplementedError( 370 | f"--num-encs {args.num_encs} > 1 is not supported in --api v2" 371 | ) 372 | else: 373 | from espnet.asr.pytorch_backend.asr import recog 374 | 375 | recog(args) 376 | else: 377 | raise ValueError("Only chainer and pytorch are supported.") 378 | elif args.num_spkrs == 2: 379 | if args.backend == "pytorch": 380 | from espnet.asr.pytorch_backend.asr_mix import recog 381 | 382 | recog(args) 383 | else: 384 | raise ValueError("Only pytorch is supported.") 385 | 386 | 387 | if __name__ == "__main__": 388 | main(sys.argv[1:]) 389 | 390 | -------------------------------------------------------------------------------- /espnet/utils/gini_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import json 5 | import numpy as np 6 | import math 7 | 8 | 9 | class SpeechSort(object): 10 | def __init__(self, name, s_count, len_diff, mix_sort): 11 | self.name = name 12 | self.s_count = s_count 13 | self.len_diff = len_diff 14 | self.mix_sort = mix_sort 15 | 16 | 17 | def cau_sum_ginis(token_int, all_ginis): 18 | sum_gini = 0 19 | gini_list = [] 20 | gini_key = "" 21 | if len(all_ginis) == 0: 22 | gini_list.append(0) 23 | else: 24 | for i in range(len(token_int)): 25 | token = token_int[i] 26 | gini_key += str(token) + " " 27 | if gini_key.strip() in all_ginis.keys(): 28 | sum_gini += all_ginis[gini_key.strip()] 29 | gini_list.append(all_ginis[gini_key.strip()]) 30 | else: 31 | key = gini_key.strip().split(" ") 32 | if len(key) > 1: 33 | key = " ".join(key[:-1]) 34 | else: 35 | key = key[-1] 36 | sum_gini += all_ginis[key] 37 | gini_list.append(all_ginis[key]) 38 | return sum_gini, gini_list 39 | 40 | 41 | def compare_sum_gini(source_sum_gini, new_sum_gini, T_s): 42 | if abs(source_sum_gini - new_sum_gini) > T_s: 43 | return True 44 | else: 45 | return False 46 | 47 | 48 | def caul_gini(softmax_value): 49 | gini = 0 50 | class_softmax = softmax_value.cpu().detach().numpy() 51 | #print("class_softmax:",class_softmax) 52 | for i in range(0, len(class_softmax)): 53 | gini += np.square(class_softmax[i]) 54 | #print("gini value: ", 1-gini) 55 | return 1-gini 56 | 57 | 58 | def read_wrd_trn(file_path): 59 | file = open(file_path, 'r') 60 | word_list = [] 61 | line = file.readline() 62 | word_count = 0 63 | while line: 64 | idx = line.find("(") 65 | line = line[:idx] 66 | 
word_list.append(line.split(" ")) 67 | word_count += len(line.split(" ")) 68 | line = file.readline() 69 | print(word_count) 70 | return word_list, word_count 71 | 72 | 73 | def load_token_file(file_path): 74 | f = open(file_path, "r") 75 | line = f.readline() 76 | token_dict = {} 77 | while line: 78 | line = line[:-1] 79 | split_line = line.split(" ") 80 | name = split_line[0] 81 | if "gini" in file_path: 82 | token_dict[name] = [float(gini) for gini in split_line[1:]] 83 | else: 84 | token_dict[name] = split_line[1:] 85 | line = f.readline() 86 | 87 | return token_dict 88 | 89 | def get_word_gini(token_list, token_gini_list): 90 | all_word_list = {} 91 | all_gini_list = {} 92 | for name in token_list.keys(): 93 | token = token_list[name] 94 | token_gini = token_gini_list[name] 95 | word_list = [] 96 | gini_list = [] 97 | word = "" 98 | word_gini = 0 99 | #c_count = 0 100 | for i in range(len(token)): 101 | if token[i] != token[0]: 102 | word = word + token[i] 103 | if token_gini[i] > word_gini: 104 | word_gini = token_gini[i] 105 | #c_count += 1 106 | else: 107 | if len(word) != 0: 108 | gini_list.append(word_gini) 109 | word_list.append(word) 110 | word = "" 111 | #c_count = 0 112 | word_gini = 0 113 | if len(word) != 0: 114 | gini_list.append(word_gini) 115 | word_list.append(word) 116 | all_word_list[name] = word_list 117 | all_gini_list[name] = gini_list 118 | 119 | return all_word_list, all_gini_list 120 | 121 | def get_token_list(json_data): 122 | all_token = {} 123 | for name in json_data["utts"]: 124 | token = json_data["utts"][name]["output"][0]["rec_token"].split(" ") 125 | all_token[name.lower()] = token[:-1] 126 | return all_token 127 | 128 | 129 | def get_word_gini_v1(json_data, token_gini_list, hyp_dict): 130 | all_word_list = {} 131 | all_gini_list = {} 132 | for name in json_data["utts"]: 133 | token = json_data["utts"][name]["output"][0]["rec_token"].split(" ") 134 | token_gini = token_gini_list[name.lower()] 135 | hyp_word = hyp_dict[name.lower()] 136 | word_list = list(hyp_word) 137 | for i in range(len(hyp_word)): 138 | if "*" in hyp_word[i]: 139 | word_list.remove(hyp_word[i]) 140 | #word_list = [] 141 | gini_list = [] 142 | word = "" 143 | word_gini = 0 144 | w_idx = 0 145 | #c_count = 0 146 | for i in range(len(token)): 147 | if token[i] == "": 148 | break 149 | if (i == 0) & (token[i][0]!= "▁"): 150 | token[i] = "▁" + token[i] 151 | word = word + token[i] 152 | if token_gini[i] > word_gini: 153 | word_gini = token_gini[i] 154 | if len(word) > 1: 155 | if word[0:2] == "▁▁": 156 | word = word[1:] 157 | print(name) 158 | if word[1:] == word_list[w_idx].lower(): 159 | gini_list.append(word_gini) 160 | w_idx += 1 161 | word = "" 162 | word_gini = 0 163 | 164 | if w_idx != len(word_list): 165 | gini_list.append(word_gini) 166 | if len(word_list) != len(gini_list): 167 | print("name:", name) 168 | print("word, w_idx, word_gini", word, w_idx, word_gini) 169 | print(word_list, len(word_list)) 170 | print(gini_list, len(gini_list)) 171 | all_word_list[name.lower()] = word_list 172 | all_gini_list[name.lower()] = gini_list 173 | 174 | return all_word_list, all_gini_list 175 | 176 | 177 | def get_gini_threshold(all_gini_list, eval_dict, orig_word_list): 178 | s_gini = 0 179 | i_gini = 0 180 | c_gini = 0 181 | s_count = 0 182 | i_count = 0 183 | c_count = 0 184 | d_count = 0 185 | for name in eval_dict.keys(): 186 | if name not in all_gini_list.keys(): 187 | gini_name = name.upper() 188 | gini_list = all_gini_list[gini_name] 189 | else: 190 | gini_list = all_gini_list[name] 
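# The loop below walks each utterance's sclite Eval labels (C/S/I/D) and accumulates the per-word Gini values by label; deletions ("D") get a placeholder gini of -1 inserted so gini_list stays index-aligned with eval_list, and the per-label averages (plus a sorting flag) are returned at the end.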
191 | eval_list = eval_dict[name] 192 | for i in range(len(eval_list)): 193 | if eval_list[i] == "D": 194 | d_count += 1 195 | gini_list.insert(i, -1) 196 | if eval_list[i] == "S": 197 | s_gini += gini_list[i] 198 | s_count += 1 199 | if eval_list[i] == "I": 200 | #print("insert_gini:", name, gini_list[i]) 201 | i_gini += gini_list[i] 202 | i_count += 1 203 | if eval_list[i] == "C": 204 | c_gini += gini_list[i] 205 | c_count += 1 206 | print(s_count, i_count, c_count, d_count) 207 | if s_count > i_count + d_count: 208 | flag = "s_first" 209 | else: 210 | flag = "len_first" 211 | if i_gini != 0: 212 | i_t = i_gini/i_count 213 | else: 214 | i_t = 0 215 | print("flag: ", flag) 216 | return s_gini/s_count, i_t, c_gini/c_count, flag 217 | 218 | 219 | def judge_res(ref_word, hyp_word): 220 | eval_list = [] 221 | for i in range(len(ref_word)): 222 | if ref_word[i] == hyp_word[i]: 223 | eval_list.append("C") 224 | elif "*" in ref_word[i]: 225 | eval_list.append("I") 226 | elif "*" in hyp_word[i]: 227 | eval_list.append("D") 228 | else: 229 | eval_list.append("S") 230 | return eval_list 231 | 232 | 233 | def load_result_txt(file_path): 234 | f = open(file_path, 'r') 235 | ref_dict = {} 236 | hyp_dict = {} 237 | eval_dict = {} 238 | line = f.readline() 239 | while line: 240 | if "id: (" in line: 241 | name = line[:-1].replace("id: (", "").replace(")", "") 242 | name = "-".join(name.split("-")[1:]) 243 | if "REF: " in line: 244 | ref = line[:-1].split()[1:] 245 | ref_dict[name] = ref 246 | if "HYP: " in line: 247 | hyp = line[:-1].split()[1:] 248 | hyp_dict[name] = hyp 249 | if "Eval: " in line: 250 | eval_list = judge_res(ref, hyp) 251 | eval_dict[name] = eval_list 252 | line = f.readline() 253 | return ref_dict, hyp_dict, eval_dict 254 | 255 | 256 | def gini_sort_v1(s_t, all_gini_list, orig_token, new_token, flag): 257 | print("flag is", flag) 258 | sort_list = [] 259 | for key in new_token.keys(): 260 | if "&" not in key: 261 | diff = 0 262 | else: 263 | orig_name = key.split("&")[0] + "-" + "-".join(key.split("&")[1].split("-")[1:]) 264 | orig_len = len(orig_token[orig_name]) 265 | new_len = len(new_token[key]) 266 | diff = abs(new_len - orig_len) 267 | s_count = 0 268 | for gini in all_gini_list[key]: 269 | if gini > s_t: 270 | s_count += 1 271 | if len(all_gini_list[key]) == 0: 272 | sort_list.append(SpeechSort(key, 1.0, 1.0, 1.0)) 273 | else: 274 | sort_list.append(SpeechSort(key, s_count/len(all_gini_list[key]), diff/len(all_gini_list[key]), (s_count + diff)/len(all_gini_list[key]))) 275 | sort_list.sort(key=lambda x:x.mix_sort, reverse=True) 276 | 277 | return sort_list 278 | 279 | 280 | def gini_sort(s_t, all_gini_list, orig_token, new_token, flag): 281 | sort_list = [] 282 | s_count_list = [] 283 | len_diff_list = [] 284 | 285 | for key in new_token.keys(): 286 | if "&" not in key: 287 | diff = 0 288 | else: 289 | orig_name = key.split("&")[0] 290 | orig_len = len(orig_token[orig_name]) 291 | new_len = len(new_token[key]) 292 | #diff = abs(new_len - orig_len)/orig_len 293 | diff = abs(new_len - orig_len) 294 | s_count = 0 295 | for gini in all_gini_list[key]: 296 | if gini > s_t: 297 | s_count += 1 298 | sort_list.append(SpeechSort(key, s_count/len(all_gini_list[key]), diff, (s_count + diff)/len(all_gini_list[key]))) 299 | sort_list.sort(key=lambda x:x.mix_sort, reverse=True) 300 | 301 | return sort_list 302 | 303 | 304 | def parse_hypothesis(hyp, char_list, gini_list): 305 | """Parse hypothesis. 306 | Args: 307 | hyp (list[dict[str, Any]]): Recognition hypothesis. 
308 | char_list (list[str]): List of characters. 309 | Returns: 310 | tuple(str, str, str, float, str) 311 | """ 312 | # remove sos and get results 313 | tokenid_as_list = list(map(int, hyp["yseq"][1:])) 314 | token_as_list = [char_list[idx] for idx in tokenid_as_list] 315 | score = float(hyp["score"]) 316 | 317 | # convert to string 318 | tokenid = " ".join([str(idx) for idx in tokenid_as_list]) 319 | token = " ".join(token_as_list) 320 | text = "".join(token_as_list).replace("<space>", " ") 321 | ginilist = " ".join([str(v) for v in gini_list]) 322 | return text, token, tokenid, score, ginilist 323 | 324 | 325 | def add_results_to_json(js, nbest_hyps, char_list, sum_gini, gini_list): 326 | """Add N-best results to json. 327 | Args: 328 | js (dict[str, Any]): Groundtruth utterance dict. 329 | nbest_hyps (list[dict[str, Any]]): 330 | List of N-best recognition hypotheses. 331 | char_list (list[str]): List of characters. 332 | Returns: 333 | dict[str, Any]: N-best results added utterance dict. 334 | """ 335 | # copy old json info 336 | new_js = dict() 337 | new_js["utt2spk"] = js["utt2spk"] 338 | new_js["output"] = [] 339 | 340 | for n, hyp in enumerate(nbest_hyps, 1): 341 | # parse hypothesis 342 | rec_text, rec_token, rec_tokenid, score, gini_list = parse_hypothesis(hyp, char_list, gini_list) 343 | 344 | # copy ground-truth 345 | if len(js["output"]) > 0: 346 | out_dic = dict(js["output"][0].items()) 347 | else: 348 | # for no reference case (e.g., speech translation) 349 | out_dic = {"name": ""} 350 | 351 | # update name 352 | out_dic["name"] += "[%d]" % n 353 | 354 | # add recognition results 355 | out_dic["rec_text"] = rec_text 356 | out_dic["rec_token"] = rec_token 357 | out_dic["rec_tokenid"] = rec_tokenid 358 | out_dic["score"] = score 359 | out_dic["sum_gini"] = sum_gini 360 | out_dic["gini_list"] = gini_list 361 | 362 | # add to list of N-best result dicts 363 | new_js["output"].append(out_dic) 364 | 365 | # show 1-best result 366 | if n == 1: 367 | if "text" in out_dic.keys(): 368 | logging.info("groundtruth: %s" % out_dic["text"]) 369 | logging.info("prediction : %s" % out_dic["rec_text"]) 370 | 371 | return new_js 372 | 373 | 374 | def load_gini(file_path): 375 | f = open(file_path, 'r') 376 | line = f.readline() 377 | orig_gini = {} 378 | while line: 379 | line = line[:-1] 380 | key = line.split(" ")[0] 381 | if "sum_gini" in file_path: 382 | orig_gini[key] = float(line.split(" ")[1]) 383 | else: 384 | if len(line[:-1].split(" ")[1:]) > 0: 385 | gini_list = [float(x) for x in line.split(" ")[1:]] 386 | else: 387 | gini_list = [0.0] 388 | orig_gini[key] = gini_list 389 | line = f.readline() 390 | return orig_gini 391 | 392 | 393 | def load_gini_v1(file_path): 394 | f = open(file_path, 'r') 395 | line = f.readline() 396 | orig_gini = {} 397 | while line: 398 | line = line[:-1] 399 | key = line.split(" ")[0].lower() 400 | if "sum_gini" in file_path: 401 | orig_gini[key] = float(line.split(" ")[1]) 402 | else: 403 | if len(line[:-1].split(" ")[1:]) > 0: 404 | gini_list = [float(x) for x in line.split(" ")[1:]] 405 | else: 406 | gini_list = [0.0] 407 | orig_gini[key] = gini_list 408 | line = f.readline() 409 | return orig_gini 410 | 411 | 412 | def load_text(file_path): 413 | f = open(file_path, 'r') 414 | line = f.readline() 415 | orig_text = {} 416 | while line: 417 | line = line[:-1] 418 | key = line.split(" ")[0] 419 | orig_text[key] = line.split(" ")[1:] 420 | line = f.readline() 421 | return orig_text 422 | 423 | 424 | def load_score(file_path): 425 | f =
open(file_path, 'r') 426 | line = f.readline() 427 | orig_score = {} 428 | while line: 429 | line = line[:-1] 430 | key = line.split(" ")[0] 431 | orig_score[key] = line.split(" ")[1] 432 | line = f.readline() 433 | return orig_score 434 | 435 | 436 | def load_token(file_path): 437 | f = open(file_path, 'r') 438 | line = f.readline() 439 | orig_token = {} 440 | while line: 441 | line = line[:-1] 442 | key = line.split(" ")[0] 443 | orig_token[key] = line.split(" ")[1:] 444 | line = f.readline() 445 | return orig_token 446 | 447 | 448 | def mkdir(path): 449 | folder = os.path.exists(path) 450 | if not folder: 451 | os.makedirs(path) 452 | 453 | 454 | def write_new_info(new_dir, file_name, result): 455 | mkdir(new_dir) 456 | f = open(new_dir + "/" + file_name, 'w') 457 | for key in result.keys(): 458 | if isinstance(result[key], list): 459 | line = key + " " + " ".join([str(x) for x in result[key]]) + "\n" 460 | else: 461 | line = key + " " + str(result[key]) + "\n" 462 | f.write(line) 463 | 464 | 465 | def write_gini_select_result(new_dir, file_name, result): 466 | mkdir(new_dir) 467 | f = open(new_dir + "/" + file_name, 'w') 468 | for item in result: 469 | line = item.name + " " + str(item.len_diff) + " " + str(item.s_count)+ "\n" 470 | f.write(line) 471 | 472 | 473 | def update_data_json(json_data, key_name): 474 | new_json = copy.deepcopy(json_data) 475 | new_json["utts"].pop(key_name) 476 | return new_json 477 | 478 | 479 | def load_data_json(path): 480 | with open(path, 'r', encoding='utf8')as fp: 481 | json_data = json.load(fp) 482 | return json_data 483 | 484 | def write_json_data(path, new_json) : 485 | mkdir(path) 486 | with open(path + "/" + "data.json", "w", encoding='utf8')as fp: 487 | json.dump(new_json, fp, ensure_ascii=False) 488 | 489 | 490 | def load_hyp_ref(path): 491 | f = open(path, 'r') 492 | line = f.readline() 493 | orig_dict = {} 494 | while line: 495 | value = str(line) 496 | line = line[:-1] 497 | line = line.replace("\t", " ") 498 | key = line.split(" ")[-1][1:-1] 499 | idx = key.find("-") 500 | key = key[idx + 1:] 501 | orig_dict[key] = value 502 | line = f.readline() 503 | return orig_dict 504 | 505 | 506 | def write_hyp_ref(path, name, result): 507 | mkdir(path) 508 | f = open(path + "/" + name, 'w') 509 | for key in result.keys(): 510 | f.write(result[key]) 511 | f.close() 512 | -------------------------------------------------------------------------------- /espnet/utils/gini_guide.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from espnet.utils.gini_utils import * 4 | import random 5 | import os 6 | import shutil 7 | 8 | def sort_test_by_gini(old_dir, new_dir, aug_type, selected_num): 9 | new_token_int = load_gini(new_dir + "/token_int") 10 | new_token = load_token(new_dir + "/token") 11 | orig_token = load_token(old_dir + "/token") 12 | new_token_gini = load_gini(new_dir + "/token_gini") 13 | orig_token_gini = load_gini(old_dir + "/token_gini") 14 | orig_word_list, orig_gini_list = get_word_gini(orig_token, orig_token_gini) 15 | new_word_list, new_gini_list = get_word_gini(new_token, new_token_gini) 16 | ref_dict, hyp_dict, orig_eval_dict = load_result_txt(old_dir + "/score_wer/result.txt") 17 | #s_t, i_t, c_t, flag = get_gini_threshold(orig_gini_list, orig_eval_dict, orig_word_list) 18 | s_t, i_t, c_t, flag = get_gini_threshold(orig_token_gini, orig_eval_dict, orig_token) 19 | #flag = "len_first" 20 | sort_list = gini_sort(s_t, new_gini_list, orig_token, new_token, flag) 21 | new_text = 
load_text(new_dir + "/text") 22 | new_score = load_score(new_dir + "/score") 23 | new_wer_ref = load_hyp_ref(new_dir + "/score_wer/ref.trn") 24 | new_wer_hyp = load_hyp_ref(new_dir + "/score_wer/hyp.trn") 25 | new_cer_ref = load_hyp_ref(new_dir + "/score_cer/ref.trn") 26 | new_cer_hyp = load_hyp_ref(new_dir + "/score_cer/hyp.trn") 27 | folder = os.path.exists(new_dir + "/score_ter/") 28 | if folder: 29 | new_ter_ref = load_hyp_ref(new_dir + "/score_ter/ref.trn") 30 | new_ter_hyp = load_hyp_ref(new_dir + "/score_ter/hyp.trn") 31 | ter_ref_result = dict(new_ter_ref) 32 | ter_hyp_result = dict(new_ter_hyp) 33 | 34 | token_result = dict() 35 | token_int_result = dict() 36 | token_gini_result = dict() 37 | #sum_gini_result = dict() 38 | text_result = dict() 39 | score_result = dict() 40 | wer_ref_result = dict() 41 | wer_hyp_result = dict() 42 | cer_ref_result = dict() 43 | cer_hyp_result = dict() 44 | 45 | for item in sort_list[:selected_num]: 46 | key = item.name 47 | token_result[key] = new_token[key] 48 | token_int_result[key] = new_token_int[key] 49 | token_gini_result[key] = new_token_gini[key] 50 | text_result[key] = new_text[key] 51 | score_result[key] = new_score[key] 52 | wer_hyp_result[key] = new_wer_hyp[key] 53 | wer_ref_result[key] = new_wer_ref[key] 54 | cer_hyp_result[key] = new_cer_hyp[key] 55 | cer_ref_result[key] = new_cer_ref[key] 56 | if folder: 57 | ter_hyp_result[key] = new_ter_hyp[key] 58 | ter_ref_result[key] = new_ter_ref[key] 59 | 60 | result_dir = new_dir + "/gini-" + str(selected_num) 61 | write_new_info(result_dir, "text", text_result) 62 | write_new_info(result_dir, "score", score_result) 63 | write_new_info(result_dir, "token", token_result) 64 | write_new_info(result_dir, "token_gini", token_gini_result) 65 | write_new_info(result_dir, "token_int", token_int_result) 66 | write_hyp_ref(result_dir + "/score_wer/", "hyp.trn", wer_hyp_result) 67 | write_hyp_ref(result_dir + "/score_wer/", "ref.trn", wer_ref_result) 68 | write_hyp_ref(result_dir + "/score_cer/", "hyp.trn", cer_hyp_result) 69 | write_hyp_ref(result_dir + "/score_cer/", "ref.trn", cer_ref_result) 70 | if folder: 71 | write_hyp_ref(result_dir + "/score_ter/", "hyp.trn", ter_hyp_result) 72 | write_hyp_ref(result_dir + "/score_ter/", "ref.trn", ter_ref_result) 73 | return token_result 74 | 75 | def sort_test_by_gini_v1(old_dir, new_dir, aug_type): 76 | orig_token_gini = load_gini_v1(old_dir + "/token_gini") 77 | selected_num = len(orig_token_gini) 78 | new_token_gini = load_gini_v1(new_dir + "/token_gini") 79 | print(new_dir + "/data.json") 80 | new_json_data = load_data_json(new_dir + "/data.json") 81 | orig_json_data = load_data_json(old_dir + "/data.json") 82 | res_json = copy.deepcopy(new_json_data) 83 | token_gini_result = dict() 84 | orig_ref_dict, orig_hyp_dict, orig_eval_dict = load_result_txt(old_dir + "/result.wrd.txt") 85 | new_ref_dict, new_hyp_dict, new_eval_dict = load_result_txt(new_dir + "/result.wrd.txt") 86 | orig_token_list = get_token_list(orig_json_data) 87 | new_token_list = get_token_list(new_json_data) 88 | orig_word_list, orig_gini_list = get_word_gini_v1(orig_json_data, orig_token_gini, orig_hyp_dict) 89 | new_word_list, new_gini_list = get_word_gini_v1(new_json_data, new_token_gini, new_hyp_dict) 90 | s_t, i_t, c_t, flag = get_gini_threshold(orig_gini_list, orig_eval_dict, orig_word_list) 91 | print(s_t, i_t, c_t) 92 | sort_list = gini_sort_v1(s_t, new_gini_list, orig_word_list, new_word_list, flag)[:selected_num] 93 | res_key = [] 94 | for item in sort_list: 95 | key = item.name 
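# Only the Gini-ranked utterances kept in sort_list survive: their token-level Gini values are collected here, and every utterance whose key is not in res_key is removed from the decoded json below.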
96 | token_gini_result[key] = new_token_gini[key] 97 | res_key.append(key) 98 | 99 | for key in new_json_data["utts"]: 100 | if key.lower() not in res_key: 101 | del res_json["utts"][key] 102 | 103 | result_dir = new_dir + "/gini-" + str(selected_num) 104 | write_new_info(result_dir, "token_gini", token_gini_result) 105 | write_json_data(result_dir, res_json) 106 | return token_gini_result 107 | 108 | 109 | def sort_type_by_diff(aug_type, selected_num, diffs): 110 | class_dict = {} 111 | for key in diffs.keys(): 112 | seq = key.split("&")[1].replace(".wav", "") 113 | if ("f" in seq) | ("n" in seq) | ("a" in seq): 114 | seq = key.split("&")[1].replace(".wav", "")[-1] 115 | if aug_type == "feature": 116 | seq = int(seq) 117 | if aug_type == "noise": 118 | seq = int(seq) % 8 119 | if aug_type == "room": 120 | seq = int(int(seq)/ 4) 121 | if seq not in class_dict.keys(): 122 | class_dict[seq] = list() 123 | item = (key, diffs[key]) 124 | class_dict[seq].append(item) 125 | 126 | result = [] 127 | num = int(selected_num/len(class_dict.keys())) 128 | for key in class_dict.keys(): 129 | sort_result = sorted(class_dict[key], reverse=True, key=lambda kv:(kv[1], kv[0])) 130 | sort_result = sort_result[:num] 131 | result = result + sort_result 132 | 133 | while len(result) < selected_num: 134 | key = random.choice(list(diffs.keys())) 135 | item = (key,diffs[key]) 136 | if item not in result: 137 | result.append(item) 138 | total_gini = 0 139 | for item in result: 140 | total_gini += item[1] 141 | print(aug_type, "selects ", total_gini/len(result)) 142 | 143 | return result 144 | 145 | 146 | def sort_type_by_diff_v1(aug_type, selected_num, diffs): 147 | class_dict = {} 148 | for key in diffs.keys(): 149 | seq = key.split("&")[1].replace(".wav", "").split("-")[0] 150 | if ("f" in seq) | ("n" in seq) | ("a" in seq): 151 | seq = key.split("&")[1].replace(".wav", "")[-1] 152 | if aug_type == "feature": 153 | seq = int(seq) 154 | if aug_type == "noise": 155 | seq = int(seq) % 8 156 | if aug_type == "room": 157 | seq = int(int(seq)/ 4) 158 | if seq not in class_dict.keys(): 159 | class_dict[seq] = list() 160 | item = (key, float(diffs[key])) 161 | class_dict[seq].append(item) 162 | 163 | result = [] 164 | num = int(selected_num/len(class_dict.keys())) 165 | for key in class_dict.keys(): 166 | sort_result = sorted(class_dict[key], reverse=True, key=lambda kv:(kv[1], kv[0])) 167 | sort_result = sort_result[:num] 168 | result = result + sort_result 169 | 170 | while len(result) < selected_num: 171 | key = random.choice(list(diffs.keys())) 172 | item = (key,diffs[key]) 173 | if item not in result: 174 | result.append(item) 175 | total_gini = 0 176 | for item in result: 177 | total_gini += float(item[1]) 178 | print(aug_type, "selects ", total_gini/len(result)) 179 | return result 180 | 181 | 182 | def gen_new_prep_file(new_file, new_audio, dataset, flag, selected_num, test_set): 183 | f2 = open(new_file, "r") 184 | new_result = dict() 185 | line = f2.readline() 186 | while line: 187 | line = line[:-1] 188 | key = line.split(" ")[0] 189 | if "ted" in dataset: 190 | key = key.lower() 191 | new_result[key] = line 192 | line = f2.readline() 193 | 194 | retrain_result = dict() 195 | print(new_audio) 196 | for key in new_audio.keys(): 197 | if ("spk2gender" in new_file) & ("TIMIT" in dataset): 198 | key = key.split("_")[0] 199 | retrain_result[key] = new_result[key] 200 | 201 | new_path = new_file.split("/") 202 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num) 203 | if "ted" in dataset: 204 | 
new_path[-2] = new_path[-2] + ".orig" 205 | mkdir("/".join(new_path[:-1])) 206 | new_path = "/".join(new_path) 207 | w = open(new_path, "w") 208 | 209 | for key in retrain_result.keys(): 210 | w.write(retrain_result[key] + "\n") 211 | w.close() 212 | return retrain_result 213 | 214 | def gen_new_wavscp(new_wavscp, new_audio, dataset, flag, selected_num, test_set): 215 | f2 = open(new_wavscp, "r") 216 | new_scp = dict() 217 | line = f2.readline() 218 | while line: 219 | line = line[:-1] 220 | key = line.split(" ")[0] 221 | if "ted" in dataset: 222 | key = key.lower() 223 | new_scp[key] = line 224 | line = f2.readline() 225 | retrain_scp = set() 226 | for key in new_audio.keys(): 227 | if "ted" in dataset: 228 | key = key.split("-")[0] 229 | retrain_scp.add(new_scp[key]) 230 | new_path = new_wavscp.split("/") 231 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num) 232 | if "ted" in dataset: 233 | new_path[-2] = new_path[-2] + ".orig" 234 | mkdir("/".join(new_path[:-1])) 235 | new_path = "/".join(new_path) 236 | w = open(new_path, "w") 237 | for line in retrain_scp: 238 | w.write(line + "\n") 239 | w.close() 240 | 241 | def gen_new_recog2file(new_recogfile, new_audio, dataset, flag, selected_num, test_set): 242 | f2 = open(new_recogfile, "r") 243 | new_recog = dict() 244 | line = f2.readline() 245 | while line: 246 | line = line[:-1] 247 | key = line.split(" ")[0].lower() 248 | new_recog[key] = line 249 | line = f2.readline() 250 | recog_result = set() 251 | for key in new_audio.keys(): 252 | if "ted" in dataset: 253 | key = key.split("-")[0] 254 | recog_result.add(new_recog[key]) 255 | new_path = new_recogfile.split("/") 256 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num) 257 | if "ted" in dataset: 258 | new_path[-2] = new_path[-2] + ".orig" 259 | mkdir("/".join(new_path[:-1])) 260 | new_path = "/".join(new_path) 261 | w = open(new_path, "w") 262 | for line in recog_result: 263 | w.write(line + "\n") 264 | w.close() 265 | 266 | 267 | def mkdir(path): 268 | folder = os.path.exists(path) 269 | if not folder: 270 | os.makedirs(path) 271 | 272 | def gen_new_spk2utt(orgi_spk2utt, new_audio, dataset, flag, selected_num, test_set): 273 | f = open(orgi_spk2utt, "r") 274 | spk2utt = dict() 275 | line = f.readline() 276 | while line: 277 | line = line[:-1] 278 | key = line.split(" ")[0] 279 | spk2utt[key] = line.split(" ")[1:] 280 | line = f.readline() 281 | for key in new_audio.keys(): 282 | if "an4" in dataset: 283 | spk = key.split("-")[0] 284 | if "TIMIT" in dataset: 285 | spk = key.split("_")[0] 286 | spk2utt[spk].append(key.replace(".wav", "")) 287 | new_path = orgi_spk2utt.split("/") 288 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num) 289 | if "ted" in dataset: 290 | new_path[-2] = new_path[-2] + ".orig" 291 | mkdir("/".join(new_path[:-1])) 292 | new_path = "/".join(new_path) 293 | w = open(new_path, "w") 294 | for key in spk2utt.keys(): 295 | w.write(key.replace(".wav", "") + " " + " ".join(spk2utt[key]) + "\n") 296 | w.close() 297 | 298 | 299 | def gen_new_spk2utt_v1(new_utt2spk, new_spk2utt, dataset, flag, selected_num, test_set): 300 | spk2utt = dict() 301 | for key in new_utt2spk.keys(): 302 | spk = new_utt2spk[key].split(" ")[-1] 303 | if spk in spk2utt.keys(): 304 | spk2utt[spk].append(key) 305 | else: 306 | spk2utt[spk] = [] 307 | spk2utt[spk].append(key) 308 | 309 | new_path = new_spk2utt.split("/") 310 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num) 311 | if "ted" in dataset: 312 | new_path[-2] = new_path[-2] + ".orig" 313 | 
mkdir("/".join(new_path[:-1])) 314 | new_path = "/".join(new_path) 315 | w = open(new_path, "w") 316 | for key in spk2utt.keys(): 317 | w.write(key + " " + " ".join(spk2utt[key]) + "\n") 318 | w.close() 319 | 320 | 321 | def gen_new_stm(new_stm, new_audio, dataset, flag, selected_num, test_set): 322 | f2 = open(new_stm, "r") 323 | orgi_stm = dict() 324 | line = f2.readline() 325 | start = [] 326 | while ";;" in line: 327 | start.append(line) 328 | line = f2.readline() 329 | while line: 330 | line = line[:-1] 331 | key = line.split(" ")[0] 332 | if "ted" in dataset: 333 | key = key.lower() 334 | if key in orgi_stm.keys(): 335 | orgi_stm[key].append(line) 336 | else: 337 | orgi_stm[key] = [] 338 | orgi_stm[key].append(line) 339 | line = f2.readline() 340 | stm_result = dict() 341 | for key in new_audio.keys(): 342 | key = key.split("-")[0] 343 | if key not in stm_result.keys(): 344 | stm_result[key] = orgi_stm[key] 345 | 346 | new_path = new_stm.split("/") 347 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num) 348 | if "ted" in dataset: 349 | new_path[-2] = new_path[-2] + ".orig" 350 | mkdir("/".join(new_path[:-1])) 351 | new_path = "/".join(new_path) 352 | w = open(new_path, "w") 353 | for line in start: 354 | w.write(line + "\n") 355 | for key in stm_result.keys(): 356 | for line in stm_result[key]: 357 | w.write(line + "\n") 358 | w.close() 359 | 360 | def copy_file(source, test_set, flag, selected_num): 361 | destination = source.split("/") 362 | destination[-2] = test_set + "-"+ flag+ "-" + str(selected_num) 363 | destination[-2] = destination[-2] + ".orig" 364 | mkdir("/".join(destination[:-1])) 365 | destination = "/".join(destination) 366 | shutil.copyfile(source, destination) 367 | if os.path.exists(destination): 368 | logging.info("copy success") 369 | 370 | def copy_dir(source, test_set, flag, selected_num): 371 | destination = source.split("/") 372 | destination[-2] = test_set + "-"+ flag+ "-" + str(selected_num) 373 | mkdir("/".join(destination[:-1])) 374 | destination = "/".join(destination) 375 | shutil.copytree(source, destination) 376 | 377 | 378 | def test_data_prep(orgi_dir, new_dir, dataset, new_audio, flag, selected_num, test_set): 379 | if "an4" in dataset: 380 | gen_new_prep_file(new_dir + "/text", new_audio, dataset, flag, selected_num, test_set) 381 | gen_new_prep_file(new_dir + "/utt2spk", new_audio, dataset, flag, selected_num, test_set) 382 | gen_new_wavscp(new_dir + "/wav.scp", new_audio, dataset, flag, selected_num, test_set) 383 | gen_new_spk2utt(orgi_dir + "/spk2utt", new_audio, dataset, flag, selected_num, test_set) 384 | 385 | if "TIMIT" in dataset: 386 | gen_new_prep_file(new_dir + "/text", new_audio, dataset, flag, selected_num, test_set) 387 | gen_new_prep_file(new_dir + "/utt2spk", new_audio, dataset, flag, selected_num, test_set) 388 | gen_new_prep_file(new_dir + "/spk2gender", new_audio, dataset, flag, selected_num, test_set) 389 | gen_new_wavscp(new_dir + "/wav.scp", new_audio, dataset, flag, selected_num, test_set) 390 | gen_new_spk2utt(orgi_dir + "/spk2utt", new_audio, dataset, flag, selected_num, test_set) 391 | 392 | if "ted" in dataset: 393 | new_utt2spk = gen_new_prep_file(new_dir + "/utt2spk", new_audio, dataset, flag, selected_num, test_set) 394 | gen_new_spk2utt_v1(new_utt2spk, new_dir + "/spk2utt", dataset, flag, selected_num, test_set) 395 | gen_new_wavscp(new_dir + "/wav.scp", new_audio, dataset, flag, selected_num, test_set) 396 | gen_new_prep_file(new_dir + "/text", new_audio, dataset, flag, selected_num, test_set) 397 | 
gen_new_prep_file(new_dir + "/segments", new_audio, dataset, flag, selected_num, test_set) 398 | #gen_new_prep_file(new_dir + "/utt2dur", new_audio, dataset, flag, selected_num, test_set) 399 | #gen_new_prep_file(new_dir + "/utt2num_frames", new_audio, dataset, flag, selected_num, test_set) 400 | gen_new_stm(new_dir + "/stm", new_audio, dataset, flag, selected_num, test_set) 401 | gen_new_recog2file(new_dir + "/reco2file_and_channel", new_audio, dataset, flag, selected_num, test_set) 402 | #gen_new_prep_file(new_dir + "/feats.scp", new_audio, dataset, flag, selected_num, test_set) 403 | copy_file(new_dir + "/glm", test_set, flag, selected_num) 404 | #copy_file(new_dir + "/frame_shift",test_set, flag, selected_num) 405 | #copy_dir(new_dir + "/conf", test_set, flag, selected_num) 406 | 407 | 408 | 409 | 410 | 411 | 412 | -------------------------------------------------------------------------------- /espnet2/bin/asr_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import logging 4 | from pathlib import Path 5 | import sys 6 | from typing import Optional 7 | from typing import Sequence 8 | from typing import Tuple 9 | from typing import Union 10 | import os 11 | import numpy as np 12 | import torch 13 | from typeguard import check_argument_types 14 | from typeguard import check_return_type 15 | from typing import List 16 | 17 | from espnet.nets.batch_beam_search import BatchBeamSearch 18 | from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim 19 | # from espnet.nets.beam_search import BeamSearch 20 | from espnet.nets.test_beam_search import TestBeamSearch 21 | from espnet.nets.test_beam_search import TestHypothesis 22 | from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError 23 | from espnet.nets.scorer_interface import BatchScorerInterface 24 | from espnet.nets.scorers.ctc import CTCPrefixScorer 25 | from espnet.nets.scorers.length_bonus import LengthBonus 26 | from espnet.utils.cli_utils import get_commandline_args 27 | from espnet2.fileio.datadir_writer import DatadirWriter 28 | from espnet2.tasks.asr import ASRTask 29 | from espnet2.tasks.lm import LMTask 30 | from espnet2.text.build_tokenizer import build_tokenizer 31 | from espnet2.text.token_id_converter import TokenIDConverter 32 | from espnet2.torch_utils.device_funcs import to_device 33 | from espnet2.torch_utils.set_all_random_seed import set_all_random_seed 34 | from espnet2.utils import config_argparse 35 | from espnet2.utils.types import str2bool 36 | from espnet2.utils.types import str2triple_str 37 | from espnet2.utils.types import str_or_none 38 | from espnet.utils.gini_utils import * 39 | from espnet.utils.gini_guide import * 40 | from espnet.utils.retrain_utils import * 41 | 42 | class Speech2Text: 43 | """Speech2Text class 44 | 45 | Examples: 46 | >>> import soundfile 47 | >>> speech2text = Speech2Text("asr_config.yml", "asr.pth") 48 | >>> audio, rate = soundfile.read("speech.wav") 49 | >>> speech2text(audio) 50 | [(text, token, token_int, hypothesis object), ...] 
51 | 52 | """ 53 | 54 | def __init__( 55 | self, 56 | asr_train_config: Union[Path, str], 57 | asr_model_file: Union[Path, str] = None, 58 | lm_train_config: Union[Path, str] = None, 59 | lm_file: Union[Path, str] = None, 60 | token_type: str = None, 61 | bpemodel: str = None, 62 | device: str = "cpu", 63 | maxlenratio: float = 0.0, 64 | minlenratio: float = 0.0, 65 | batch_size: int = 1, 66 | dtype: str = "float32", 67 | beam_size: int = 20, 68 | ctc_weight: float = 0.5, 69 | lm_weight: float = 1.0, 70 | penalty: float = 0.0, 71 | nbest: int = 1, 72 | streaming: bool = False, 73 | ): 74 | assert check_argument_types() 75 | 76 | # 1. Build ASR model 77 | scorers = {} 78 | asr_model, asr_train_args = ASRTask.build_model_from_file( 79 | asr_train_config, asr_model_file, device 80 | ) 81 | asr_model.to(dtype=getattr(torch, dtype)).eval() 82 | decoder = asr_model.decoder 83 | ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) 84 | token_list = asr_model.token_list 85 | scorers.update( 86 | decoder=decoder, 87 | ctc=ctc, 88 | length_bonus=LengthBonus(len(token_list)), 89 | ) 90 | logging.info(f"asr_model: {asr_model}") 91 | # 2. Build Language model 92 | if lm_train_config is not None: 93 | lm, lm_train_args = LMTask.build_model_from_file( 94 | lm_train_config, lm_file, device 95 | ) 96 | scorers["lm"] = lm.lm 97 | 98 | # 3. Build BeamSearch object 99 | weights = dict( 100 | decoder=1.0 - ctc_weight, 101 | ctc=ctc_weight, 102 | lm=lm_weight, 103 | length_bonus=penalty, 104 | ) 105 | test_beam_search = TestBeamSearch( 106 | beam_size=beam_size, 107 | weights=weights, 108 | scorers=scorers, 109 | sos=asr_model.sos, 110 | eos=asr_model.eos, 111 | vocab_size=len(token_list), 112 | token_list=token_list, 113 | pre_beam_score_key=None if ctc_weight == 1.0 else "full", 114 | ) 115 | # TODO(karita): make all scorers batchfied 116 | if batch_size == 1: 117 | non_batch = [ 118 | k 119 | for k, v in test_beam_search.full_scorers.items() 120 | if not isinstance(v, BatchScorerInterface) 121 | ] 122 | if len(non_batch) == 0: 123 | if streaming: 124 | test_beam_search.__class__ = BatchBeamSearchOnlineSim 125 | test_beam_search.set_streaming_config(asr_train_config) 126 | logging.info("BatchBeamSearchOnlineSim implementation is selected.") 127 | else: 128 | test_beam_search.__class__ = BatchBeamSearch 129 | logging.info("BatchBeamSearch implementation is selected.") 130 | else: 131 | logging.warning( 132 | f"As non-batch scorers {non_batch} are found, " 133 | f"fall back to non-batch implementation." 134 | ) 135 | test_beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() 136 | for scorer in scorers.values(): 137 | if isinstance(scorer, torch.nn.Module): 138 | scorer.to(device=device, dtype=getattr(torch, dtype)).eval() 139 | logging.info(f"test_beam_search: {test_beam_search}") 140 | logging.info(f"Decoding device={device}, dtype={dtype}") 141 | 142 | # 4. [Optional] Build Text converter: e.g. 
bpe-sym -> Text 143 | if token_type is None: 144 | token_type = asr_train_args.token_type 145 | if bpemodel is None: 146 | bpemodel = asr_train_args.bpemodel 147 | 148 | if token_type is None: 149 | tokenizer = None 150 | elif token_type == "bpe": 151 | if bpemodel is not None: 152 | tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) 153 | else: 154 | tokenizer = None 155 | else: 156 | tokenizer = build_tokenizer(token_type=token_type) 157 | converter = TokenIDConverter(token_list=token_list) 158 | logging.info(f"Text tokenizer: {tokenizer}") 159 | 160 | self.asr_model = asr_model 161 | self.asr_train_args = asr_train_args 162 | self.converter = converter 163 | self.tokenizer = tokenizer 164 | self.test_beam_search = test_beam_search 165 | self.maxlenratio = maxlenratio 166 | self.minlenratio = minlenratio 167 | self.device = device 168 | self.dtype = dtype 169 | self.nbest = nbest 170 | 171 | @torch.no_grad() 172 | def __call__( 173 | self, speech: Union[torch.Tensor, np.ndarray] 174 | ) -> List[Tuple[Optional[str], List[str], List[int], TestHypothesis]]: 175 | """Inference 176 | 177 | Args: 178 | data: Input speech data 179 | Returns: 180 | text, token, token_int, hyp 181 | 182 | """ 183 | assert check_argument_types() 184 | 185 | # Input as audio signal 186 | if isinstance(speech, np.ndarray): 187 | speech = torch.tensor(speech) 188 | 189 | # data: (Nsamples,) -> (1, Nsamples) 190 | speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) 191 | # lenghts: (1,) 192 | lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) 193 | batch = {"speech": speech, "speech_lengths": lengths} 194 | 195 | # a. To device 196 | batch = to_device(batch, device=self.device) 197 | 198 | # b. Forward Encoder 199 | enc, _ = self.asr_model.encode(**batch) 200 | assert len(enc) == 1, len(enc) 201 | 202 | # c. 
Passed the encoder result and the beam search 203 | nbest_hyps, all_ginis = self.test_beam_search( 204 | x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio 205 | ) 206 | nbest_hyps = nbest_hyps[: self.nbest] 207 | 208 | results = [] 209 | for hyp in nbest_hyps: 210 | assert isinstance(hyp, TestHypothesis), type(hyp) 211 | 212 | # remove sos/eos and get results 213 | token_int = hyp.yseq[1:-1].tolist() 214 | 215 | # remove blank symbol id, which is assumed to be 0 216 | token_int = list(filter(lambda x: x != 0, token_int)) 217 | sum_gini, gini_list = cau_sum_ginis(token_int, all_ginis) 218 | # Change integer-ids to tokens 219 | token = self.converter.ids2tokens(token_int) 220 | if self.tokenizer is not None: 221 | text = self.tokenizer.tokens2text(token) 222 | else: 223 | text = None 224 | results.append((text, token, token_int, hyp)) 225 | 226 | assert check_return_type(results) 227 | return results, sum_gini, gini_list 228 | 229 | def str_2_bool(s): 230 | return True if s.lower() =='true' else False 231 | 232 | def test( 233 | output_dir: str, 234 | maxlenratio: float, 235 | minlenratio: float, 236 | batch_size: int, 237 | dtype: str, 238 | beam_size: int, 239 | ngpu: int, 240 | seed: int, 241 | ctc_weight: float, 242 | lm_weight: float, 243 | penalty: float, 244 | nbest: int, 245 | num_workers: int, 246 | log_level: Union[int, str], 247 | data_path_and_name_and_type: Sequence[Tuple[str, str, str]], 248 | key_file: Optional[str], 249 | asr_train_config: str, 250 | asr_model_file: str, 251 | lm_train_config: Optional[str], 252 | lm_file: Optional[str], 253 | word_lm_train_config: Optional[str], 254 | word_lm_file: Optional[str], 255 | token_type: Optional[str], 256 | bpemodel: Optional[str], 257 | allow_variable_data_keys: bool, 258 | streaming: bool, 259 | orig_flag: str, 260 | orig_dir: str, 261 | need_decode: bool, 262 | ): 263 | assert check_argument_types() 264 | if batch_size > 1: 265 | raise NotImplementedError("batch decoding is not implemented") 266 | if word_lm_train_config is not None: 267 | raise NotImplementedError("Word LM is not implemented") 268 | if ngpu > 1: 269 | raise NotImplementedError("only single GPU decoding is supported") 270 | 271 | logging.basicConfig( 272 | level=log_level, 273 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 274 | ) 275 | 276 | if ngpu >= 1: 277 | device = "cuda" 278 | else: 279 | device = "cpu" 280 | 281 | orig_flag = str_2_bool(orig_flag) 282 | # 1. Set random-seed 283 | set_all_random_seed(seed) 284 | 285 | # 2. Build speech2text 286 | speech2text = Speech2Text( 287 | asr_train_config=asr_train_config, 288 | asr_model_file=asr_model_file, 289 | lm_train_config=lm_train_config, 290 | lm_file=lm_file, 291 | token_type=token_type, 292 | bpemodel=bpemodel, 293 | device=device, 294 | maxlenratio=maxlenratio, 295 | minlenratio=minlenratio, 296 | dtype=dtype, 297 | beam_size=beam_size, 298 | ctc_weight=ctc_weight, 299 | lm_weight=lm_weight, 300 | penalty=penalty, 301 | nbest=nbest, 302 | streaming=streaming, 303 | ) 304 | 305 | # 3. 
Build data-iterator 306 | loader = ASRTask.build_streaming_iterator( 307 | data_path_and_name_and_type, 308 | dtype=dtype, 309 | batch_size=batch_size, 310 | key_file=key_file, 311 | num_workers=num_workers, 312 | preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), 313 | collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), 314 | allow_variable_data_keys=allow_variable_data_keys, 315 | inference=True, 316 | ) 317 | 318 | # 7 .Start for-loop 319 | # FIXME(kamo): The output format should be discussed about 320 | if orig_flag: 321 | orig_sum_gini = {} 322 | orig_gini_list = {} 323 | else: 324 | new_sum_gini = {} 325 | new_gini_list = {} 326 | with DatadirWriter(output_dir) as writer: 327 | for keys, batch in loader: 328 | key = keys[0] 329 | assert isinstance(batch, dict), type(batch) 330 | assert all(isinstance(s, str) for s in keys), keys 331 | _bs = len(next(iter(batch.values()))) 332 | assert len(keys) == _bs, f"{len(keys)} != {_bs}" 333 | batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} 334 | 335 | # N-best list of (text, token, token_int, hyp_object) 336 | try: 337 | results, sum_gini, gini_list = speech2text(**batch) 338 | #logging.info(f"zlist_tensor is {zlist_tensor}") 339 | except TooShortUttError as e: 340 | logging.warning(f"Utterance {keys} {e}") 341 | hyp = TestHypothesis(score=0.0, scores={}, states={}, yseq=[]) 342 | results = [[" ", [""], [2], hyp]] * nbest 343 | 344 | # Only supporting batch_size==1 345 | if orig_flag: 346 | orig_sum_gini[key] = str(sum_gini) 347 | orig_gini_list[key] = gini_list 348 | else: 349 | new_sum_gini[key] = str(sum_gini) 350 | new_gini_list[key] = gini_list 351 | cov = cau_coverage(all_z_list, 0) 352 | #act_cell = get_activate_cell(all_z_list, 0) 353 | if need_decode: 354 | for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): 355 | # Create a directory: outdir/{n}best_recog 356 | ibest_writer = writer[f"{n}best_recog"] 357 | # Write the result to each file 358 | ibest_writer["token"][key] = " ".join(token) 359 | ibest_writer["token_int"][key] = " ".join(map(str, token_int)) 360 | ibest_writer["score"][key] = str(hyp.score) 361 | ibest_writer["sum_gini"][key] = str(sum_gini) 362 | ibest_writer["token_gini"][key] = " ".join(map(str, gini_list)) 363 | if text is not None: 364 | ibest_writer["text"][key] = text 365 | if orig_flag: 366 | return orig_sum_gini, orig_gini_list 367 | else: 368 | 369 | return new_sum_gini, new_gini_list 370 | 371 | 372 | def merge_result(file_path, gini_dict): 373 | f = open(file_path, 'w') 374 | for key in gini_dict.keys(): 375 | if isinstance(gini_dict[key], str): 376 | line = key + " " + gini_dict[key] + "\n" 377 | else: 378 | line = key + " " + " ".join(map(str, gini_dict[key])) + "\n" 379 | f.write(line) 380 | f.close() 381 | 382 | 383 | 384 | 385 | def get_parser(): 386 | parser = config_argparse.ArgumentParser( 387 | description="ASR Decoding", 388 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 389 | ) 390 | 391 | # Note(kamo): Use '_' instead of '-' as separator. 392 | # '-' is confusing if written in yaml. 393 | parser.add_argument( 394 | "--log_level", 395 | type=lambda x: x.upper(), 396 | default="INFO", 397 | choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), 398 | help="The verbose level of logging", 399 | ) 400 | 401 | parser.add_argument("--output_dir", type=str, required=True) 402 | parser.add_argument( 403 | "--ngpu", 404 | type=int, 405 | default=0, 406 | help="The number of gpus. 
0 indicates CPU mode", 407 | ) 408 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 409 | parser.add_argument( 410 | "--dtype", 411 | default="float32", 412 | choices=["float16", "float32", "float64"], 413 | help="Data type", 414 | ) 415 | parser.add_argument( 416 | "--num_workers", 417 | type=int, 418 | default=1, 419 | help="The number of workers used for DataLoader", 420 | ) 421 | 422 | group = parser.add_argument_group("Input data related") 423 | group.add_argument( 424 | "--data_path_and_name_and_type", 425 | type=str2triple_str, 426 | required=True, 427 | action="append", 428 | ) 429 | group.add_argument("--key_file", type=str_or_none) 430 | group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) 431 | 432 | group = parser.add_argument_group("The model configuration related") 433 | group.add_argument("--asr_train_config", type=str, required=True) 434 | group.add_argument("--asr_model_file", type=str, required=True) 435 | group.add_argument("--lm_train_config", type=str) 436 | group.add_argument("--lm_file", type=str) 437 | group.add_argument("--word_lm_train_config", type=str) 438 | group.add_argument("--word_lm_file", type=str) 439 | group.add_argument("--orig_flag", type=str, default="False") 440 | group.add_argument("--orig_dir", type=str, required=True) 441 | group = parser.add_argument_group("Beam-search related") 442 | group.add_argument( 443 | "--batch_size", 444 | type=int, 445 | default=1, 446 | help="The batch size for inference", 447 | ) 448 | group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") 449 | group.add_argument("--beam_size", type=int, default=20, help="Beam size") 450 | group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") 451 | group.add_argument( 452 | "--maxlenratio", 453 | type=float, 454 | default=0.0, 455 | help="Input length ratio to obtain max output length. " 456 | "If maxlenratio=0.0 (default), it uses a end-detect " 457 | "function " 458 | "to automatically find maximum hypothesis lengths", 459 | ) 460 | group.add_argument( 461 | "--minlenratio", 462 | type=float, 463 | default=0.0, 464 | help="Input length ratio to obtain min output length", 465 | ) 466 | group.add_argument( 467 | "--ctc_weight", 468 | type=float, 469 | default=0.5, 470 | help="CTC weight in joint decoding", 471 | ) 472 | group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") 473 | group.add_argument("--streaming", type=str2bool, default=False) 474 | 475 | group = parser.add_argument_group("Text converter related") 476 | group.add_argument( 477 | "--token_type", 478 | type=str_or_none, 479 | default=None, 480 | choices=["char", "bpe", None], 481 | help="The token type for ASR model. " 482 | "If not given, refers from the training args", 483 | ) 484 | group.add_argument( 485 | "--bpemodel", 486 | type=str_or_none, 487 | default=None, 488 | help="The model path of sentencepiece. 
" 489 | "If not given, refers from the training args", 490 | ) 491 | group.add_argument( 492 | "--need_decode", 493 | type=str, 494 | ) 495 | group.add_argument( 496 | "--selected_num", 497 | type=int, 498 | default=None, 499 | ) 500 | group.add_argument( 501 | "--dataset", 502 | type=str, 503 | default=None, 504 | ) 505 | 506 | 507 | return parser 508 | 509 | 510 | 511 | def main(cmd=None): 512 | print(get_commandline_args(), file=sys.stderr) 513 | parser = get_parser() 514 | args = parser.parse_args(cmd) 515 | args.need_decode = str2bool(args.need_decode) 516 | selected_num = args.selected_num 517 | dataset = args.dataset 518 | kwargs = vars(args) 519 | kwargs.pop("config", None) 520 | kwargs.pop("selected_num") 521 | kwargs.pop("dataset") 522 | test_set=args.data_path_and_name_and_type[0][0].split("/")[-1] 523 | if args.need_decode: 524 | sum_ginis, gini_lists = test(**kwargs) 525 | else: 526 | if ("test" in str(args.data_path_and_name_and_type)) & ("gini" not in test_set) : 527 | aug_type = test_set.split("-")[1] 528 | gini_audio = sort_test_by_gini(args.orig_dir + "/test-orig", args.orig_dir + "/" + test_set, aug_type, selected_num) 529 | test_data_prep("data/test-orig", "data/" + test_set, dataset, gini_audio, "gini", selected_num, test_set) 530 | 531 | 532 | 533 | if __name__ == "__main__": 534 | main() 535 | -------------------------------------------------------------------------------- /espnet/nets/test_beam_search.py: -------------------------------------------------------------------------------- 1 | """Beam search module.""" 2 | 3 | from itertools import chain 4 | import logging 5 | from typing import Any 6 | from typing import Dict 7 | from typing import List 8 | from typing import NamedTuple 9 | from typing import Tuple 10 | from typing import Union 11 | import numpy as np 12 | import torch 13 | 14 | from espnet.nets.e2e_asr_common import end_detect 15 | from espnet.nets.scorer_interface import PartialScorerInterface 16 | from espnet.nets.scorer_interface import ScorerInterface 17 | from espnet.utils.gini_utils import * 18 | 19 | 20 | class TestHypothesis(NamedTuple): 21 | """Test Hypothesis data type.""" 22 | 23 | yseq: torch.Tensor 24 | score: Union[float, torch.Tensor] = 0 25 | scores: Dict[str, Union[float, torch.Tensor]] = dict() 26 | states: Dict[str, Any] = dict() 27 | def asdict(self) -> dict: 28 | """Convert data to JSON-friendly dict.""" 29 | return self._replace( 30 | yseq=self.yseq.tolist(), 31 | score=float(self.score), 32 | scores={k: float(v) for k, v in self.scores.items()}, 33 | )._asdict() 34 | 35 | 36 | class TestBeamSearch(torch.nn.Module): 37 | """Test Beam search implementation.""" 38 | 39 | def __init__( 40 | self, 41 | scorers: Dict[str, ScorerInterface], 42 | weights: Dict[str, float], 43 | beam_size: int, 44 | vocab_size: int, 45 | sos: int, 46 | eos: int, 47 | token_list: List[str] = None, 48 | pre_beam_ratio: float = 1.5, 49 | pre_beam_score_key: str = None, 50 | ): 51 | """Initialize beam search. 
52 | 53 | Args: 54 | scorers (dict[str, ScorerInterface]): Dict of decoder modules 55 | e.g., Decoder, CTCPrefixScorer, LM 56 | The scorer will be ignored if it is `None` 57 | weights (dict[str, float]): Dict of weights for each scorers 58 | The scorer will be ignored if its weight is 0 59 | beam_size (int): The number of hypotheses kept during search 60 | vocab_size (int): The number of vocabulary 61 | sos (int): Start of sequence id 62 | eos (int): End of sequence id 63 | token_list (list[str]): List of tokens for debug log 64 | pre_beam_score_key (str): key of scores to perform pre-beam search 65 | pre_beam_ratio (float): beam size in the pre-beam search 66 | will be `int(pre_beam_ratio * beam_size)` 67 | 68 | """ 69 | super().__init__() 70 | # set scorers 71 | self.weights = weights 72 | self.scorers = dict() 73 | self.full_scorers = dict() 74 | self.part_scorers = dict() 75 | # this module dict is required for recursive cast 76 | # `self.to(device, dtype)` in `recog.py` 77 | self.nn_dict = torch.nn.ModuleDict() 78 | for k, v in scorers.items(): 79 | w = weights.get(k, 0) 80 | if w == 0 or v is None: 81 | continue 82 | assert isinstance( 83 | v, ScorerInterface 84 | ), f"{k} ({type(v)}) does not implement ScorerInterface" 85 | self.scorers[k] = v 86 | if isinstance(v, PartialScorerInterface): 87 | self.part_scorers[k] = v 88 | else: 89 | self.full_scorers[k] = v 90 | if isinstance(v, torch.nn.Module): 91 | self.nn_dict[k] = v 92 | 93 | # set configurations 94 | self.sos = sos 95 | self.eos = eos 96 | self.token_list = token_list 97 | self.pre_beam_size = int(pre_beam_ratio * beam_size) 98 | self.beam_size = beam_size 99 | self.n_vocab = vocab_size 100 | if ( 101 | pre_beam_score_key is not None 102 | and pre_beam_score_key != "full" 103 | and pre_beam_score_key not in self.full_scorers 104 | ): 105 | raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}") 106 | self.pre_beam_score_key = pre_beam_score_key 107 | self.do_pre_beam = ( 108 | self.pre_beam_score_key is not None 109 | and self.pre_beam_size < self.n_vocab 110 | and len(self.part_scorers) > 0 111 | ) 112 | 113 | def init_hyp(self, x: torch.Tensor) -> List[TestHypothesis]: 114 | """Get an initial hypothesis data. 115 | 116 | Args: 117 | x (torch.Tensor): The encoder output feature 118 | 119 | Returns: 120 | Hypothesis: The initial hypothesis. 121 | 122 | """ 123 | init_states = dict() 124 | init_scores = dict() 125 | for k, d in self.scorers.items(): 126 | init_states[k] = d.init_state(x) 127 | init_scores[k] = 0.0 128 | return [ 129 | TestHypothesis( 130 | score=0.0, 131 | scores=init_scores, 132 | states=init_states, 133 | yseq=torch.tensor([self.sos], device=x.device), 134 | ) 135 | ] 136 | 137 | @staticmethod 138 | def append_token(xs: torch.Tensor, x: int) -> torch.Tensor: 139 | """Append new token to prefix tokens. 140 | 141 | Args: 142 | xs (torch.Tensor): The prefix token 143 | x (int): The new token to append 144 | 145 | Returns: 146 | torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device 147 | 148 | """ 149 | x = torch.tensor([x], dtype=xs.dtype, device=xs.device) 150 | #logging.info(f"call append_token,get {torch.cat((xs, x))}") 151 | return torch.cat((xs, x)) 152 | 153 | def score_full( 154 | self, hyp: TestHypothesis, x: torch.Tensor 155 | ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: 156 | """Score new hypothesis by `self.full_scorers`. 
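        Unlike the stock implementation, this variant also returns the raw decoder
        outputs (`outps`) of the scored step, which `search` feeds into the Gini
        computation.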
157 | call rnn decoder 158 | Args: 159 | hyp (Hypothesis): Hypothesis with prefix tokens to score 160 | x (torch.Tensor): Corresponding input feature 161 | 162 | Returns: 163 | Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of 164 | score dict of `hyp` that has string keys of `self.full_scorers` 165 | and tensor score values of shape: `(self.n_vocab,)`, 166 | and state dict that has string keys 167 | and state values of `self.full_scorers` 168 | 169 | """ 170 | #logging.info("call score_full") 171 | scores = dict() 172 | states = dict() 173 | """ 174 | full_scorers:{'decoder': RNNDecoder( 175 | (embed): Embedding(42, 320) 176 | (dropout_emb): Dropout(p=0.0, inplace=False) 177 | (decoder): ModuleList( 178 | (0): LSTMCell(640, 320) 179 | ) 180 | (dropout_dec): ModuleList( 181 | (0): Dropout(p=0.0, inplace=False) 182 | ) 183 | (output): Linear(in_features=320, out_features=42, bias=True) 184 | (att_list): ModuleList( 185 | (0): AttLoc( 186 | (mlp_enc): Linear(in_features=320, out_features=320, bias=True) 187 | (mlp_dec): Linear(in_features=320, out_features=320, bias=False) 188 | (mlp_att): Linear(in_features=10, out_features=320, bias=False) 189 | (loc_conv): Conv2d(1, 10, kernel_size=(1, 201), stride=(1, 1), padding=(0, 100), bias=False) 190 | (gvec): Linear(in_features=320, out_features=1, bias=True) 191 | ) 192 | ) 193 | )} 194 | """ 195 | for k, d in self.full_scorers.items(): 196 | scores[k], states[k], outps = d.score(hyp.yseq, hyp.states[k], x) 197 | #logging.info("score_full get {scores},{outps}") 198 | return scores, states, outps 199 | 200 | def score_partial( 201 | self, hyp: TestHypothesis, ids: torch.Tensor, x: torch.Tensor 202 | ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: 203 | """Score new hypothesis by `self.part_scorers`. 204 | 205 | Args: 206 | hyp (Hypothesis): Hypothesis with prefix tokens to score 207 | ids (torch.Tensor): 1D tensor of new partial tokens to score 208 | x (torch.Tensor): Corresponding input feature 209 | 210 | Returns: 211 | Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of 212 | score dict of `hyp` that has string keys of `self.part_scorers` 213 | and tensor score values of shape: `(len(ids),)`, 214 | and state dict that has string keys 215 | and state values of `self.part_scorers` 216 | 217 | """ 218 | scores = dict() 219 | states = dict() 220 | for k, d in self.part_scorers.items(): 221 | scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x) 222 | return scores, states 223 | 224 | def beam( 225 | self, weighted_scores: torch.Tensor, ids: torch.Tensor 226 | ) -> Tuple[torch.Tensor, torch.Tensor]: 227 | """Compute topk full token ids and partial token ids. 228 | 229 | Args: 230 | weighted_scores (torch.Tensor): The weighted sum scores for each tokens. 231 | Its shape is `(self.n_vocab,)`. 232 | ids (torch.Tensor): The partial token ids to compute topk 233 | 234 | Returns: 235 | Tuple[torch.Tensor, torch.Tensor]: 236 | The topk full token ids and partial token ids. 
237 | Their shapes are `(self.beam_size,)` 238 | 239 | """ 240 | # no pre beam performed 241 | if weighted_scores.size(0) == ids.size(0): 242 | top_ids = weighted_scores.topk(self.beam_size)[1] 243 | return top_ids, top_ids 244 | 245 | # mask pruned in pre-beam not to select in topk 246 | tmp = weighted_scores[ids] 247 | weighted_scores[:] = -float("inf") 248 | weighted_scores[ids] = tmp 249 | top_ids = weighted_scores.topk(self.beam_size)[1] 250 | local_ids = weighted_scores[ids].topk(self.beam_size)[1] 251 | return top_ids, local_ids 252 | 253 | @staticmethod 254 | def merge_scores( 255 | prev_scores: Dict[str, float], 256 | next_full_scores: Dict[str, torch.Tensor], 257 | full_idx: int, 258 | next_part_scores: Dict[str, torch.Tensor], 259 | part_idx: int, 260 | ) -> Dict[str, torch.Tensor]: 261 | """Merge scores for new hypothesis. 262 | 263 | Args: 264 | prev_scores (Dict[str, float]): 265 | The previous hypothesis scores by `self.scorers` 266 | next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers` 267 | full_idx (int): The next token id for `next_full_scores` 268 | next_part_scores (Dict[str, torch.Tensor]): 269 | scores of partial tokens by `self.part_scorers` 270 | part_idx (int): The new token id for `next_part_scores` 271 | 272 | Returns: 273 | Dict[str, torch.Tensor]: The new score dict. 274 | Its keys are names of `self.full_scorers` and `self.part_scorers`. 275 | Its values are scalar tensors by the scorers. 276 | 277 | """ 278 | new_scores = dict() 279 | for k, v in next_full_scores.items(): 280 | new_scores[k] = prev_scores[k] + v[full_idx] 281 | for k, v in next_part_scores.items(): 282 | new_scores[k] = prev_scores[k] + v[part_idx] 283 | return new_scores 284 | 285 | def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any: 286 | """Merge states for new hypothesis. 287 | 288 | Args: 289 | states: states of `self.full_scorers` 290 | part_states: states of `self.part_scorers` 291 | part_idx (int): The new token id for `part_scores` 292 | 293 | Returns: 294 | Dict[str, torch.Tensor]: The new score dict. 295 | Its keys are names of `self.full_scorers` and `self.part_scorers`. 296 | Its values are states of the scorers. 297 | 298 | """ 299 | new_states = dict() 300 | for k, v in states.items(): 301 | new_states[k] = v 302 | for k, d in self.part_scorers.items(): 303 | new_states[k] = d.select_state(part_states[k], part_idx) 304 | return new_states 305 | 306 | def search( 307 | self, running_hyps: List[TestHypothesis], x: torch.Tensor 308 | ) -> List[TestHypothesis]: 309 | """Search new tokens for running hypotheses and encoded speech x. 
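        Besides the pruned hypotheses, this variant also returns a dict that maps
        each hypothesis prefix (its token ids joined by spaces) to the Gini value
        computed from the decoder output at that step.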
310 | 311 | Args: 312 | running_hyps (List[Hypothesis]): Running hypotheses on beam 313 | x (torch.Tensor): Encoded speech feature (T, D) 314 | 315 | Returns: 316 | List[Hypotheses]: Best sorted hypotheses 317 | 318 | """ 319 | #logging.info("call search") 320 | best_hyps = [] 321 | new_ginis = {} 322 | part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam 323 | for hyp in running_hyps: 324 | # scoring 325 | weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device) 326 | scores, states, outps = self.score_full(hyp, x) 327 | for k in self.full_scorers: 328 | weighted_scores += self.weights[k] * scores[k] 329 | yseq = hyp.yseq.numpy() 330 | if len(yseq) > 1: 331 | yseq = yseq[1:] 332 | gini = caul_gini(outps) 333 | yseq_key = " ".join('%s' %seq for seq in yseq) 334 | new_ginis[yseq_key] = gini 335 | 336 | # partial scoring 337 | if self.do_pre_beam: 338 | pre_beam_scores = ( 339 | weighted_scores 340 | if self.pre_beam_score_key == "full" 341 | else scores[self.pre_beam_score_key] 342 | ) 343 | part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1] 344 | part_scores, part_states = self.score_partial(hyp, part_ids, x) 345 | for k in self.part_scorers: 346 | weighted_scores[part_ids] += self.weights[k] * part_scores[k] 347 | # add previous hyp score 348 | weighted_scores += hyp.score 349 | # update hyps 350 | for j, part_j in zip(*self.beam(weighted_scores, part_ids)): 351 | # will be (2 x beam at most) 352 | #logging.info(f"{j},{part_j},{hyp.yseq}") 353 | best_hyps.append( 354 | TestHypothesis( 355 | score=weighted_scores[j], 356 | yseq=self.append_token(hyp.yseq, j), 357 | scores=self.merge_scores( 358 | hyp.scores, scores, j, part_scores, part_j 359 | ), 360 | states=self.merge_states(states, part_states, part_j), 361 | ) 362 | ) 363 | 364 | # sort and prune 2 x beam -> beam 365 | best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ 366 | : min(len(best_hyps), self.beam_size) 367 | ] 368 | for best in best_hyps: 369 | last_seq = best.yseq.numpy()[-1] 370 | 371 | ''' 372 | best_hyps:Sort by probability 373 | ''' 374 | return best_hyps,new_ginis 375 | 376 | def forward( 377 | self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 378 | ) -> List[TestHypothesis]: 379 | #logging.info("call forward") 380 | """Perform beam search. 381 | 382 | Args: 383 | x (torch.Tensor): Encoded speech feature (T, D) 384 | maxlenratio (float): Input length ratio to obtain max output length. 385 | If maxlenratio=0.0 (default), it uses a end-detect function 386 | to automatically find maximum hypothesis lengths 387 | minlenratio (float): Input length ratio to obtain min output length. 
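        Note:
            Unlike espnet's BeamSearch.forward, this variant returns a tuple of the
            sorted N-best hypotheses and the dict of per-prefix Gini values
            accumulated over all search iterations.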
388 | 389 | Returns: 390 | list[Hypothesis]: N-best decoding results 391 | 392 | """ 393 | # set length bounds 394 | if maxlenratio == 0: 395 | maxlen = x.shape[0] 396 | else: 397 | maxlen = max(1, int(maxlenratio * x.size(0))) 398 | minlen = int(minlenratio * x.size(0)) 399 | logging.info("decoder input length: " + str(x.shape[0])) 400 | logging.info("max output length: " + str(maxlen)) 401 | logging.info("min output length: " + str(minlen)) 402 | 403 | # main loop of prefix search 404 | running_hyps = self.init_hyp(x) 405 | ended_hyps = [] 406 | all_ginis = {} 407 | for i in range(maxlen): 408 | best, new_ginis = self.search(running_hyps, x) 409 | if len(new_ginis) > 0: 410 | all_ginis.update(new_ginis) 411 | # post process of one iteration 412 | running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) 413 | # end detection 414 | if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): 415 | logging.info(f"end detected at {i}") 416 | break 417 | if len(running_hyps) == 0: 418 | logging.info("no hypothesis. Finish decoding.") 419 | break 420 | else: 421 | logging.debug(f"remained hypotheses: {len(running_hyps)}") 422 | nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) 423 | # check the number of hypotheses reaching to eos 424 | if len(nbest_hyps) == 0: 425 | logging.warning( 426 | "there is no N-best results, perform recognition " 427 | "again with smaller minlenratio." 428 | ) 429 | return ( 430 | [] 431 | if minlenratio < 0.1 432 | else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)) 433 | ) 434 | 435 | # report the best result 436 | best = nbest_hyps[0] 437 | for k, v in best.scores.items(): 438 | logging.info( 439 | f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}" 440 | ) 441 | logging.info(f"total log probability: {best.score:.2f}") 442 | logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}") 443 | logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}") 444 | logging.info(f"yseq: {best.yseq.shape}") 445 | if self.token_list is not None: 446 | logging.info( 447 | "best hypo: " 448 | + "".join([self.token_list[x] for x in best.yseq[1:-1]]) 449 | + "\n" 450 | ) 451 | #logging.info(f"call forward,get {nbest_hyps[0]}") 452 | logging.info(f"forward return nbest_hyps, the length is {len(nbest_hyps)}") 453 | return nbest_hyps, all_ginis 454 | 455 | 456 | 457 | def post_process( 458 | self, 459 | i: int, 460 | maxlen: int, 461 | maxlenratio: float, 462 | running_hyps: List[TestHypothesis], 463 | ended_hyps: List[TestHypothesis], 464 | ) -> List[TestHypothesis]: 465 | """Perform post-processing of beam search iterations. 466 | Args: 467 | i (int): The length of hypothesis tokens. 468 | maxlen (int): The maximum length of tokens in beam search. 469 | maxlenratio (int): The maximum length ratio in beam search. 470 | running_hyps (List[Hypothesis]): The running hypotheses in beam search. 471 | ended_hyps (List[Hypothesis]): The ended hypotheses in beam search. 472 | 473 | Returns: 474 | List[Hypothesis]: The new running hypotheses. 
475 | 476 | """ 477 | #logging.info("call post process") 478 | logging.debug(f"the number of running hypotheses: {len(running_hyps)}") 479 | if self.token_list is not None: 480 | logging.debug( 481 | "best hypo: " 482 | + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) 483 | ) 484 | # add eos in the final loop to avoid that there are no ended hyps 485 | if i == maxlen - 1: 486 | logging.info("adding in the last position in the loop") 487 | running_hyps = [ 488 | h._replace(yseq=self.append_token(h.yseq, self.eos)) 489 | for h in running_hyps 490 | ] 491 | 492 | # add ended hypotheses to a final list, and removed them from current hypotheses 493 | # (this will be a problem, number of hyps < beam) 494 | remained_hyps = [] 495 | for hyp in running_hyps: 496 | if hyp.yseq[-1] == self.eos: 497 | # e.g., Word LM needs to add final score 498 | for k, d in chain(self.full_scorers.items(), self.part_scorers.items()): 499 | s = d.final_score(hyp.states[k]) 500 | hyp.scores[k] += s 501 | hyp = hyp._replace(score=hyp.score + self.weights[k] * s) 502 | ended_hyps.append(hyp) 503 | else: 504 | remained_hyps.append(hyp) 505 | return remained_hyps 506 | 507 | 508 | def beam_search( 509 | x: torch.Tensor, 510 | sos: int, 511 | eos: int, 512 | beam_size: int, 513 | vocab_size: int, 514 | scorers: Dict[str, ScorerInterface], 515 | weights: Dict[str, float], 516 | token_list: List[str] = None, 517 | maxlenratio: float = 0.0, 518 | minlenratio: float = 0.0, 519 | pre_beam_ratio: float = 1.5, 520 | pre_beam_score_key: str = "full", 521 | ) -> list: 522 | """Perform beam search with scorers. 523 | 524 | Args: 525 | x (torch.Tensor): Encoded speech feature (T, D) 526 | sos (int): Start of sequence id 527 | eos (int): End of sequence id 528 | beam_size (int): The number of hypotheses kept during search 529 | vocab_size (int): The number of vocabulary 530 | scorers (dict[str, ScorerInterface]): Dict of decoder modules 531 | e.g., Decoder, CTCPrefixScorer, LM 532 | The scorer will be ignored if it is `None` 533 | weights (dict[str, float]): Dict of weights for each scorers 534 | The scorer will be ignored if its weight is 0 535 | token_list (list[str]): List of tokens for debug log 536 | maxlenratio (float): Input length ratio to obtain max output length. 537 | If maxlenratio=0.0 (default), it uses a end-detect function 538 | to automatically find maximum hypothesis lengths 539 | minlenratio (float): Input length ratio to obtain min output length. 540 | pre_beam_score_key (str): key of scores to perform pre-beam search 541 | pre_beam_ratio (float): beam size in the pre-beam search 542 | will be `int(pre_beam_ratio * beam_size)` 543 | 544 | Returns: 545 | list: N-best decoding results 546 | 547 | """ 548 | ret = BeamSearch( 549 | scorers, 550 | weights, 551 | beam_size=beam_size, 552 | vocab_size=vocab_size, 553 | pre_beam_ratio=pre_beam_ratio, 554 | pre_beam_score_key=pre_beam_score_key, 555 | sos=sos, 556 | eos=eos, 557 | token_list=token_list, 558 | ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) 559 | return [h.asdict() for h in ret] 560 | 561 | -------------------------------------------------------------------------------- /TEMPLATE/asr1/asr_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Set bash to 'debug' mode, it will exit on : 4 | # -e 'error', -u 'undefined variable', -o ... 
'error in pipeline', -x 'print commands', 5 | set -e 6 | set -u 7 | set -o pipefail 8 | 9 | log() { 10 | local fname=${BASH_SOURCE[1]##*/} 11 | echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 12 | } 13 | min() { 14 | local a b 15 | a=$1 16 | for b in "$@"; do 17 | if [ "${b}" -le "${a}" ]; then 18 | a="${b}" 19 | fi 20 | done 21 | echo "${a}" 22 | } 23 | SECONDS=0 24 | 25 | # General configuration 26 | stage=1 # Processes starts from the specified stage. 27 | stop_stage=10000 # Processes is stopped at the specified stage. 28 | skip_data_prep=false # Skip data preparation stages. 29 | skip_train=false # Skip training stages. 30 | skip_eval=false # Skip decoding and evaluation stages. 31 | skip_upload=true # Skip packing and uploading stages. 32 | ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu). 33 | num_nodes=1 # The number of nodes. 34 | nj=32 # The number of parallel jobs. 35 | inference_nj=32 # The number of parallel jobs in decoding. 36 | gpu_inference=false # Whether to perform gpu decoding. 37 | dumpdir=dump # Directory to dump features. 38 | expdir=exp # Directory to save experiments. 39 | python=python3 # Specify python to execute espnet commands. 40 | 41 | # Data preparation related 42 | local_data_opts= # The options given to local/data.sh. 43 | 44 | # Speed perturbation related 45 | speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space). 46 | 47 | # Feature extraction related 48 | feats_type=raw # Feature type (raw or fbank_pitch). 49 | audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw). 50 | fs=16k # Sampling rate. 51 | min_wav_duration=0.1 # Minimum duration in second. 52 | max_wav_duration=20 # Maximum duration in second. 53 | 54 | # Tokenization related 55 | token_type=bpe # Tokenization type (char or bpe). 56 | nbpe=30 # The number of BPE vocabulary. 57 | bpemode=unigram # Mode of BPE (unigram or bpe). 58 | oov="" # Out of vocabulary symbol. 59 | blank="" # CTC blank symbol 60 | sos_eos="" # sos and eos symbole 61 | bpe_input_sentence_size=100000000 # Size of input sentence for BPE. 62 | bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE 63 | bpe_char_cover=1.0 # character coverage when modeling BPE 64 | 65 | # Language model related 66 | use_lm=false # Use language model for ASR decoding. 67 | lm_tag= # Suffix to the result dir for language model training. 68 | lm_exp= # Specify the direcotry path for LM experiment. 69 | # If this option is specified, lm_tag is ignored. 70 | lm_stats_dir= # Specify the direcotry path for LM statistics. 71 | lm_config= # Config for language model training. 72 | lm_args= # Arguments for language model training, e.g., "--max_epoch 10". 73 | # Note that it will overwrite args in lm config. 74 | use_word_lm=false # Whether to use word language model. 75 | num_splits_lm=1 # Number of splitting for lm corpus. 76 | # shellcheck disable=SC2034 77 | word_vocab_size=10000 # Size of word vocabulary. 78 | 79 | # ASR model related 80 | asr_tag= # Suffix to the result dir for asr model training. 81 | asr_exp= # Specify the direcotry path for ASR experiment. 82 | # If this option is specified, asr_tag is ignored. 83 | asr_stats_dir= # Specify the direcotry path for ASR statistics. 84 | asr_config= # Config for asr model training. 85 | asr_args= # Arguments for asr model training, e.g., "--max_epoch 10". 86 | # Note that it will overwrite args in asr config. 87 | feats_normalize=global_mvn # Normalizaton layer type. 
88 | num_splits_asr=1 # Number of splitting for lm corpus. 89 | 90 | # Decoding related 91 | inference_tag= # Suffix to the result dir for decoding. 92 | inference_config= # Config for decoding. 93 | inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1". 94 | # Note that it will overwrite args in inference config. 95 | inference_lm=valid.loss.ave.pth # Language modle path for decoding. 96 | inference_asr_model=valid.acc.ave.pth # ASR model path for decoding. 97 | # e.g. 98 | # inference_asr_model=train.loss.best.pth 99 | # inference_asr_model=3epoch.pth 100 | # inference_asr_model=valid.acc.best.pth 101 | # inference_asr_model=valid.loss.ave.pth 102 | download_model= # Download a model from Model Zoo and use it for decoding. 103 | 104 | # [Task dependent] Set the datadir name created by local/data.sh 105 | train_set= # Name of training set. 106 | valid_set= # Name of validation set used for monitoring/tuning network training. 107 | test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified. 108 | bpe_train_text= # Text file path of bpe training set. 109 | lm_train_text= # Text file path of language model training set. 110 | lm_dev_text= # Text file path of language model development set. 111 | lm_test_text= # Text file path of language model evaluation set. 112 | nlsyms_txt=none # Non-linguistic symbol list if existing. 113 | cleaner=none # Text cleaner. 114 | g2p=none # g2p method (needed if token_type=phn). 115 | lang=en # The language type of corpus. 116 | score_opts= # The options given to sclite scoring 117 | local_score_opts= # The options given to local/score.sh. 118 | asr_speech_fold_length=800 # fold_length for speech data during ASR training. 119 | asr_text_fold_length=150 # fold_length for text data during ASR training. 120 | lm_fold_length=150 # fold_length for LM training. 121 | 122 | # gini_related 123 | orig_flag= 124 | need_decode= 125 | selected_num= 126 | dataset= 127 | help_message=$(cat << EOF 128 | Usage: $0 --train-set "" --valid-set "" --test_sets "" 129 | 130 | Options: 131 | # General configuration 132 | --stage # Processes starts from the specified stage (default="${stage}"). 133 | --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}"). 134 | --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}"). 135 | --skip_train # Skip training stages (default="${skip_train}"). 136 | --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}"). 137 | --skip_upload # Skip packing and uploading stages (default="${skip_upload}"). 138 | --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}"). 139 | --num_nodes # The number of nodes (default="${num_nodes}"). 140 | --nj # The number of parallel jobs (default="${nj}"). 141 | --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}"). 142 | --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}"). 143 | --dumpdir # Directory to dump features (default="${dumpdir}"). 144 | --expdir # Directory to save experiments (default="${expdir}"). 145 | --python # Specify python to execute espnet commands (default="${python}"). 146 | 147 | # Data preparation related 148 | --local_data_opts # The options given to local/data.sh (default="${local_data_opts}"). 149 | 150 | # Speed perturbation related 151 | --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}"). 
152 | 153 | # Feature extraction related 154 | --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}"). 155 | --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}"). 156 | --fs # Sampling rate (default="${fs}"). 157 | --min_wav_duration # Minimum duration in second (default="${min_wav_duration}"). 158 | --max_wav_duration # Maximum duration in second (default="${max_wav_duration}"). 159 | 160 | # Tokenization related 161 | --token_type # Tokenization type (char or bpe, default="${token_type}"). 162 | --nbpe # The number of BPE vocabulary (default="${nbpe}"). 163 | --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}"). 164 | --oov # Out of vocabulary symbol (default="${oov}"). 165 | --blank # CTC blank symbol (default="${blank}"). 166 | --sos_eos # sos and eos symbole (default="${sos_eos}"). 167 | --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}"). 168 | --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}"). 169 | --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}"). 170 | 171 | # Language model related 172 | --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}"). 173 | --lm_exp # Specify the direcotry path for LM experiment. 174 | # If this option is specified, lm_tag is ignored (default="${lm_exp}"). 175 | --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}"). 176 | --lm_config # Config for language model training (default="${lm_config}"). 177 | --lm_args # Arguments for language model training (default="${lm_args}"). 178 | # e.g., --lm_args "--max_epoch 10" 179 | # Note that it will overwrite args in lm config. 180 | --use_word_lm # Whether to use word language model (default="${use_word_lm}"). 181 | --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}"). 182 | --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}"). 183 | 184 | # ASR model related 185 | --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}"). 186 | --asr_exp # Specify the direcotry path for ASR experiment. 187 | # If this option is specified, asr_tag is ignored (default="${asr_exp}"). 188 | --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}"). 189 | --asr_config # Config for asr model training (default="${asr_config}"). 190 | --asr_args # Arguments for asr model training (default="${asr_args}"). 191 | # e.g., --asr_args "--max_epoch 10" 192 | # Note that it will overwrite args in asr config. 193 | --feats_normalize # Normalizaton layer type (default="${feats_normalize}"). 194 | --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}"). 195 | 196 | # Decoding related 197 | --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}"). 198 | --inference_config # Config for decoding (default="${inference_config}"). 199 | --inference_args # Arguments for decoding (default="${inference_args}"). 200 | # e.g., --inference_args "--lm_weight 0.1" 201 | # Note that it will overwrite args in inference config. 202 | --inference_lm # Language modle path for decoding (default="${inference_lm}"). 203 | --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}"). 
204 | --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}"). 205 | 206 | # [Task dependent] Set the datadir name created by local/data.sh 207 | --train_set # Name of training set (required). 208 | --valid_set # Name of validation set used for monitoring/tuning network training (required). 209 | --test_sets # Names of test sets. 210 | # Multiple items (e.g., both dev and eval sets) can be specified (required). 211 | --bpe_train_text # Text file path of bpe training set. 212 | --lm_train_text # Text file path of language model training set. 213 | --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}"). 214 | --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}"). 215 | --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}"). 216 | --cleaner # Text cleaner (default="${cleaner}"). 217 | --g2p # g2p method (default="${g2p}"). 218 | --lang # The language type of corpus (default=${lang}). 219 | --score_opts # The options given to sclite scoring (default="{score_opts}"). 220 | --local_score_opts # The options given to local/score.sh (default="{local_score_opts}"). 221 | --asr_speech_fold_length # fold_length for speech data during ASR training (default="${asr_speech_fold_length}"). 222 | --asr_text_fold_length # fold_length for text data during ASR training (default="${asr_text_fold_length}"). 223 | --lm_fold_length # fold_length for LM training (default="${lm_fold_length}"). 224 | EOF 225 | ) 226 | 227 | # log "asr_test $0 $*" 228 | # Save command line args for logging (they will be lost after utils/parse_options.sh) 229 | run_args=$(pyscripts/utils/print_args.py $0 "$@") 230 | . utils/parse_options.sh 231 | 232 | if [ $# -ne 0 ]; then 233 | log "${help_message}" 234 | log "Error: No positional arguments are required." 235 | exit 2 236 | fi 237 | 238 | . ./path.sh 239 | . ./cmd.sh 240 | 241 | 242 | # Check required arguments 243 | [ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; }; 244 | 245 | # Check feature type 246 | if [ "${feats_type}" = raw ]; then 247 | data_feats=${dumpdir}/raw 248 | elif [ "${feats_type}" = fbank_pitch ]; then 249 | data_feats=${dumpdir}/fbank_pitch 250 | elif [ "${feats_type}" = fbank ]; then 251 | data_feats=${dumpdir}/fbank 252 | elif [ "${feats_type}" == extracted ]; then 253 | data_feats=${dumpdir}/extracted 254 | else 255 | log "${help_message}" 256 | log "Error: not supported: --feats_type ${feats_type}" 257 | exit 2 258 | fi 259 | 260 | # Use the text of the 1st evaldir if lm_test is not specified 261 | [ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text" 262 | 263 | # Check tokenization type 264 | if [ "${lang}" != noinfo ]; then 265 | token_listdir=data/${lang}_token_list 266 | else 267 | token_listdir=data/token_list 268 | fi 269 | 270 | bpedir="${token_listdir}/bpe_${bpemode}${nbpe}" 271 | bpeprefix="${bpedir}"/bpe 272 | bpemodel="${bpeprefix}".model 273 | bpetoken_list="${bpedir}"/tokens.txt 274 | chartoken_list="${token_listdir}"/char/tokens.txt 275 | # NOTE: keep for future development. 
276 | # shellcheck disable=SC2034 277 | wordtoken_list="${token_listdir}"/word/tokens.txt 278 | 279 | if [ "${token_type}" = bpe ]; then 280 | token_list="${bpetoken_list}" 281 | elif [ "${token_type}" = char ]; then 282 | token_list="${chartoken_list}" 283 | bpemodel=none 284 | elif [ "${token_type}" = word ]; then 285 | token_list="${wordtoken_list}" 286 | bpemodel=none 287 | else 288 | log "Error: not supported --token_type '${token_type}'" 289 | exit 2 290 | fi 291 | if ${use_word_lm}; then 292 | log "Error: Word LM is not supported yet" 293 | exit 2 294 | 295 | lm_token_list="${wordtoken_list}" 296 | lm_token_type=word 297 | else 298 | lm_token_list="${token_list}" 299 | lm_token_type="${token_type}" 300 | fi 301 | 302 | 303 | # Set tag for naming of model directory 304 | if [ -z "${asr_tag}" ]; then 305 | if [ -n "${asr_config}" ]; then 306 | asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}" 307 | else 308 | asr_tag="train_${feats_type}" 309 | fi 310 | if [ "${lang}" != noinfo ]; then 311 | asr_tag+="_${lang}_${token_type}" 312 | else 313 | asr_tag+="_${token_type}" 314 | fi 315 | if [ "${token_type}" = bpe ]; then 316 | asr_tag+="${nbpe}" 317 | fi 318 | # Add overwritten arg's info 319 | if [ -n "${asr_args}" ]; then 320 | asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")" 321 | fi 322 | if [ -n "${speed_perturb_factors}" ]; then 323 | asr_tag+="_sp" 324 | fi 325 | fi 326 | echo "${lang}, ${asr_tag}" 327 | if [ -z "${lm_tag}" ]; then 328 | if [ -n "${lm_config}" ]; then 329 | lm_tag="$(basename "${lm_config}" .yaml)" 330 | else 331 | lm_tag="train" 332 | fi 333 | if [ "${lang}" != noinfo ]; then 334 | lm_tag+="_${lang}_${lm_token_type}" 335 | else 336 | lm_tag+="_${lm_token_type}" 337 | fi 338 | if [ "${lm_token_type}" = bpe ]; then 339 | lm_tag+="${nbpe}" 340 | fi 341 | # Add overwritten arg's info 342 | if [ -n "${lm_args}" ]; then 343 | lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")" 344 | fi 345 | fi 346 | 347 | # The directory used for collect-stats mode 348 | if [ -z "${asr_stats_dir}" ]; then 349 | if [ "${lang}" != noinfo ]; then 350 | asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}" 351 | else 352 | asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}" 353 | fi 354 | if [ "${token_type}" = bpe ]; then 355 | asr_stats_dir+="${nbpe}" 356 | fi 357 | if [ -n "${speed_perturb_factors}" ]; then 358 | asr_stats_dir+="_sp" 359 | fi 360 | fi 361 | if [ -z "${lm_stats_dir}" ]; then 362 | if [ "${lang}" != noinfo ]; then 363 | lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}" 364 | else 365 | lm_stats_dir="${expdir}/lm_stats_${lm_token_type}" 366 | fi 367 | if [ "${lm_token_type}" = bpe ]; then 368 | lm_stats_dir+="${nbpe}" 369 | fi 370 | fi 371 | # The directory used for training commands 372 | if [ -z "${asr_exp}" ]; then 373 | asr_exp="${expdir}/asr_${asr_tag}" 374 | fi 375 | if [ -z "${lm_exp}" ]; then 376 | lm_exp="${expdir}/lm_${lm_tag}" 377 | fi 378 | 379 | 380 | if [ -z "${inference_tag}" ]; then 381 | if [ -n "${inference_config}" ]; then 382 | inference_tag="$(basename "${inference_config}" .yaml)" 383 | else 384 | inference_tag=inference 385 | fi 386 | # Add overwritten arg's info 387 | if [ -n "${inference_args}" ]; then 388 | inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")" 389 | fi 390 | if "${use_lm}"; then 391 | inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" 392 | fi 393 | 
inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" 394 | fi 395 | 396 | # ========================== Main stages start from here. ========================== 397 | 398 | if ! "${skip_data_prep}"; then 399 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 400 | log "Stage 1: Data preparation for data/${test_sets}, etc." 401 | # [Task dependent] Need to create data.sh for new corpus 402 | local/test_data.sh ${local_data_opts} 403 | fi 404 | 405 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 406 | if [ "${feats_type}" = raw ]; then 407 | log "Stage 2: Format wav.scp: data/ -> ${data_feats}" 408 | 409 | # ====== Recreating "wav.scp" ====== 410 | # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |", 411 | # shouldn't be used in training process. 412 | # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file 413 | # and it can also change the audio-format and sampling rate. 414 | # If nothing is need, then format_wav_scp.sh does nothing: 415 | # i.e. the input file format and rate is same as the output. 416 | 417 | for dset in "${test_sets}"; do 418 | _suf="" 419 | utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" 420 | rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel} 421 | _opts= 422 | if [ -e data/"${dset}"/segments ]; then 423 | # "segments" is used for splitting wav files which are written in "wav".scp 424 | # into utterances. The file format of segments: 425 | # 426 | # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5" 427 | # Where the time is written in seconds. 428 | _opts+="--segments data/${dset}/segments " 429 | fi 430 | # shellcheck disable=SC2086 431 | scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ 432 | --audio-format "${audio_format}" --fs "${fs}" ${_opts} \ 433 | "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}" 434 | 435 | echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type" 436 | done 437 | 438 | elif [ "${feats_type}" = fbank_pitch ]; then 439 | log "[Require Kaldi] Stage 3: ${feats_type} extract: data/ -> ${data_feats}" 440 | 441 | for dset in "${test_sets}"; do 442 | _suf="" 443 | # 1. Copy datadir 444 | utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" 445 | 446 | # 2. Feature extract 447 | _nj=$(min "${nj}" "$(<"${data_feats}${_suf}/${dset}/utt2spk" wc -l)") 448 | steps/make_fbank_pitch.sh --nj "${_nj}" --cmd "${train_cmd}" "${data_feats}${_suf}/${dset}" 449 | utils/fix_data_dir.sh "${data_feats}${_suf}/${dset}" 450 | 451 | # 3. Derive the the frame length and feature dimension 452 | scripts/feats/feat_to_shape.sh --nj "${_nj}" --cmd "${train_cmd}" \ 453 | "${data_feats}${_suf}/${dset}/feats.scp" "${data_feats}${_suf}/${dset}/feats_shape" 454 | 455 | # 4. Write feats_dim 456 | head -n 1 "${data_feats}${_suf}/${dset}/feats_shape" | awk '{ print $2 }' \ 457 | | cut -d, -f2 > ${data_feats}${_suf}/${dset}/feats_dim 458 | 459 | # 5. Write feats_type 460 | echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type" 461 | done 462 | 463 | elif [ "${feats_type}" = fbank ]; then 464 | log "Stage 3: ${feats_type} extract: data/ -> ${data_feats}" 465 | log "${feats_type} is not supported yet." 
466 | exit 1 467 | 468 | elif [ "${feats_type}" = extracted ]; then 469 | log "Stage 3: ${feats_type} extract: data/ -> ${data_feats}" 470 | # Assumming you don't have wav.scp, but feats.scp is created by local/data.sh instead. 471 | 472 | for dset in "${test_sets}"; do 473 | _suf="" 474 | # Generate dummy wav.scp to avoid error by copy_data_dir.sh 475 | ") }' > data/"${dset}"/wav.scp 476 | utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" 477 | 478 | pyscripts/feats/feat-to-shape.py "scp:head -n 1 ${data_feats}${_suf}/${dset}/feats.scp |" - | \ 479 | awk '{ print $2 }' | cut -d, -f2 > "${data_feats}${_suf}/${dset}/feats_dim" 480 | 481 | echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type" 482 | done 483 | 484 | else 485 | log "Error: not supported: --feats_type ${feats_type}" 486 | exit 2 487 | fi 488 | fi 489 | 490 | else 491 | log "Skip the stages for data preparation" 492 | fi 493 | 494 | # ========================== Data preparation is done here. ========================== 495 | 496 | #download_model="kamo-naoyuki/timit_asr_train_asr_raw_word_valid.acc.ave" 497 | if [ -n "${download_model}" ]; then 498 | log "Use ${download_model} for decoding and evaluation" 499 | asr_exp="${expdir}/${download_model}" 500 | echo ${asr_exp} 501 | mkdir -p "${asr_exp}" 502 | 503 | # If the model already exists, you can skip downloading 504 | espnet_model_zoo_download --unpack true "${download_model}" > "${asr_exp}/config.txt" 505 | 506 | # Get the path of each file 507 | _asr_model_file=$(<"${asr_exp}/config.txt" sed -e "s/.*'asr_model_file': '\([^']*\)'.*$/\1/") 508 | _asr_train_config=$(<"${asr_exp}/config.txt" sed -e "s/.*'asr_train_config': '\([^']*\)'.*$/\1/") 509 | 510 | # Create symbolic links 511 | ln -sf "${_asr_model_file}" "${asr_exp}" 512 | ln -sf "${_asr_train_config}" "${asr_exp}" 513 | inference_asr_model=$(basename "${_asr_model_file}") 514 | 515 | if [ "$(<${asr_exp}/config.txt grep -c lm_file)" -gt 0 ]; then 516 | _lm_file=$(<"${asr_exp}/config.txt" sed -e "s/.*'lm_file': '\([^']*\)'.*$/\1/") 517 | _lm_train_config=$(<"${asr_exp}/config.txt" sed -e "s/.*'lm_train_config': '\([^']*\)'.*$/\1/") 518 | 519 | lm_exp="${expdir}/${download_model}/lm" 520 | mkdir -p "${lm_exp}" 521 | 522 | ln -sf "${_lm_file}" "${lm_exp}" 523 | ln -sf "${_lm_train_config}" "${lm_exp}" 524 | inference_lm=$(basename "${_lm_file}") 525 | fi 526 | 527 | fi 528 | 529 | 530 | if ! "${skip_eval}"; then 531 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 532 | log "Stage 3: Decoding: training_dir=${asr_exp}" 533 | 534 | if ${gpu_inference}; then 535 | _cmd="${cuda_cmd}" 536 | _ngpu=1 537 | else 538 | _cmd="${decode_cmd}" 539 | _ngpu=0 540 | fi 541 | 542 | _opts= 543 | if [ -n "${inference_config}" ]; then 544 | _opts+="--config ${inference_config} " 545 | fi 546 | if "${use_lm}"; then 547 | if "${use_word_lm}"; then 548 | _opts+="--word_lm_train_config ${lm_exp}/config.yaml " 549 | _opts+="--word_lm_file ${lm_exp}/${inference_lm} " 550 | else 551 | _opts+="--lm_train_config ${lm_exp}/config.yaml " 552 | _opts+="--lm_file ${lm_exp}/${inference_lm} " 553 | fi 554 | fi 555 | 556 | # 2. Generate run.sh 557 | log "Generate '${asr_exp}/${inference_tag}/run.sh'. You can resume the process from stage 11 using this script" 558 | mkdir -p "${asr_exp}/${inference_tag}"; echo "${run_args} --stage 11 \"\$@\"; exit \$?" 
> "${asr_exp}/${inference_tag}/run.sh"; chmod +x "${asr_exp}/${inference_tag}/run.sh" 559 | 560 | for dset in ${test_sets}; do 561 | _data="${data_feats}/${dset}" 562 | _dir="${asr_exp}/${inference_tag}/${dset}" 563 | _logdir="${_dir}/logdir" 564 | mkdir -p "${_logdir}" 565 | 566 | if ! "${need_decode}"; then 567 | log "Do not use decoding " 568 | ${python} -m espnet2.bin.asr_test \ 569 | --ngpu "${_ngpu}" \ 570 | --data_path_and_name_and_type "${_data},speech,sound" \ 571 | --key_file "${_logdir}"/keys.JOB.scp \ 572 | --asr_train_config "${asr_exp}"/config.yaml \ 573 | --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ 574 | --orig_dir "${asr_exp}/${inference_tag}" \ 575 | --orig_flag "${orig_flag}" \ 576 | --need_decode "${need_decode}" \ 577 | --output_dir "${_logdir}"/output.JOB \ 578 | --selected_num ${selected_num} \ 579 | --dataset ${dataset} \ 580 | ${_opts} ${inference_args} 581 | else 582 | _feats_type="$(<${_data}/feats_type)" 583 | if [ "${_feats_type}" = raw ]; then 584 | _scp=wav.scp 585 | if [[ "${audio_format}" == *ark* ]]; then 586 | _type=kaldi_ark 587 | else 588 | _type=sound 589 | fi 590 | else 591 | _scp=feats.scp 592 | _type=kaldi_ark 593 | fi 594 | log "Decoding started... log: '${_logdir}/asr_test.*.log', data path and name: '${_data}/${_scp},speech,${_type}'" 595 | key_file=${_data}/${_scp} 596 | split_scps="" 597 | _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)") 598 | for n in $(seq "${_nj}"); do 599 | split_scps+=" ${_logdir}/keys.${n}.scp" 600 | done 601 | # shellcheck disable=SC2086 602 | utils/split_scp.pl "${key_file}" ${split_scps} 603 | ${_cmd} --gpu "${_ngpu}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ 604 | ${python} -m espnet2.bin.asr_test \ 605 | --ngpu "${_ngpu}" \ 606 | --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \ 607 | --key_file "${_logdir}"/keys.JOB.scp \ 608 | --asr_train_config "${asr_exp}"/config.yaml \ 609 | --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ 610 | --output_dir "${_logdir}"/output.JOB \ 611 | --orig_dir "${asr_exp}/${inference_tag}" \ 612 | --orig_flag "${orig_flag}" \ 613 | --need_decode "${need_decode}" \ 614 | ${_opts} ${inference_args} 615 | # 3. Concatenates the output files from each jobs 616 | for f in token token_int score text sum_gini token_gini cov ; do 617 | for i in $(seq "${_nj}"); do 618 | cat "${_logdir}/output.${i}/1best_recog/${f}" 619 | done | LC_ALL=C sort -k1 >"${_dir}/${f}" 620 | done 621 | fi 622 | done 623 | fi 624 | 625 | 626 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 627 | log "Stage 4: Scoring" 628 | if [ "${token_type}" = pnh ]; then 629 | log "Error: Not implemented for token_type=phn" 630 | exit 1 631 | fi 632 | 633 | for dset in ${test_sets}; do 634 | _data="${data_feats}/${dset}" 635 | if ! "${need_decode}"; then 636 | dirs=("${asr_exp}/${inference_tag}/${dset}/gini-${selected_num}" "${asr_exp}/${inference_tag}/${dset}/random-${selected_num} ${asr_exp}/${inference_tag}/${dset}/cov-${selected_num}") 637 | else 638 | dirs=("${asr_exp}/${inference_tag}/${dset}") 639 | fi 640 | for _dir in ${dirs[*]} ; do 641 | for _type in cer wer ter; do 642 | [ "${_type}" = ter ] && [ ! 
-f "${bpemodel}" ] && continue 643 | 644 | _scoredir="${_dir}/score_${_type}" 645 | mkdir -p "${_scoredir}" 646 | 647 | if "${need_decode}"; then 648 | if [ "${_type}" = wer ]; then 649 | # Tokenize text to word level 650 | paste \ 651 | <(<"${_data}/text" \ 652 | ${python} -m espnet2.bin.tokenize_text \ 653 | -f 2- --input - --output - \ 654 | --token_type word \ 655 | --non_linguistic_symbols "${nlsyms_txt}" \ 656 | --remove_non_linguistic_symbols true \ 657 | --cleaner "${cleaner}" \ 658 | ) \ 659 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \ 660 | >"${_scoredir}/ref.trn" 661 | 662 | # NOTE(kamo): Don't use cleaner for hyp 663 | paste \ 664 | <(<"${_dir}/text" \ 665 | ${python} -m espnet2.bin.tokenize_text \ 666 | -f 2- --input - --output - \ 667 | --token_type word \ 668 | --non_linguistic_symbols "${nlsyms_txt}" \ 669 | --remove_non_linguistic_symbols true \ 670 | ) \ 671 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \ 672 | >"${_scoredir}/hyp.trn" 673 | 674 | 675 | elif [ "${_type}" = cer ]; then 676 | # Tokenize text to char level 677 | paste \ 678 | <(<"${_data}/text" \ 679 | ${python} -m espnet2.bin.tokenize_text \ 680 | -f 2- --input - --output - \ 681 | --token_type char \ 682 | --non_linguistic_symbols "${nlsyms_txt}" \ 683 | --remove_non_linguistic_symbols true \ 684 | --cleaner "${cleaner}" \ 685 | ) \ 686 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \ 687 | >"${_scoredir}/ref.trn" 688 | 689 | # NOTE(kamo): Don't use cleaner for hyp 690 | paste \ 691 | <(<"${_dir}/text" \ 692 | ${python} -m espnet2.bin.tokenize_text \ 693 | -f 2- --input - --output - \ 694 | --token_type char \ 695 | --non_linguistic_symbols "${nlsyms_txt}" \ 696 | --remove_non_linguistic_symbols true \ 697 | ) \ 698 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \ 699 | >"${_scoredir}/hyp.trn" 700 | 701 | elif [ "${_type}" = ter ]; then 702 | # Tokenize text using BPE 703 | paste \ 704 | <(<"${_data}/text" \ 705 | ${python} -m espnet2.bin.tokenize_text \ 706 | -f 2- --input - --output - \ 707 | --token_type bpe \ 708 | --bpemodel "${bpemodel}" \ 709 | --cleaner "${cleaner}" \ 710 | ) \ 711 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \ 712 | >"${_scoredir}/ref.trn" 713 | 714 | # NOTE(kamo): Don't use cleaner for hyp 715 | paste \ 716 | <(<"${_dir}/text" \ 717 | ${python} -m espnet2.bin.tokenize_text \ 718 | -f 2- --input - --output - \ 719 | --token_type bpe \ 720 | --bpemodel "${bpemodel}" \ 721 | ) \ 722 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \ 723 | >"${_scoredir}/hyp.trn" 724 | 725 | fi 726 | fi 727 | 728 | sclite \ 729 | ${score_opts} \ 730 | -r "${_scoredir}/ref.trn" trn \ 731 | -h "${_scoredir}/hyp.trn" trn \ 732 | -i rm -o all stdout > "${_scoredir}/result.txt" 733 | 734 | log "Write ${_type} result in ${_scoredir}/result.txt" 735 | grep -e Avg -e SPKR -m 2 "${_scoredir}/result.txt" 736 | done 737 | done 738 | done 739 | 740 | [ -f local/score.sh ] && local/score.sh ${local_score_opts} "${asr_exp}" 741 | 742 | # Show results in Markdown syntax 743 | scripts/utils/show_asr_result.sh "${asr_exp}" > "${asr_exp}"/RESULTS.md 744 | cat "${asr_exp}"/RESULTS.md 745 | 746 | fi 747 | else 748 | log "Skip the evaluation stages" 749 | fi 750 | 751 | 752 | log "Successfully finished. [elapsed=${SECONDS}s]" 753 | 754 | --------------------------------------------------------------------------------