├── an4
│   ├── test.sh
│   └── local
│       ├── test_data.sh
│       └── test_data_prep.py
├── tedlium2
│   ├── local
│   │   ├── join_suffix.py
│   │   └── prepare_test_data.sh
│   └── test.sh
├── tedlium3
│   ├── local
│   │   ├── join_suffix.py
│   │   └── prepare_test_data.sh
│   └── test.sh
├── timit
│   ├── test.sh
│   └── local
│       ├── timit_format_test_data.sh
│       ├── test_data.sh
│       ├── timit_norm_trans.pl
│       └── timit_test_data_prep.sh
├── README.md
├── utils
│   └── score_sclite.sh
├── espnet2
│   ├── asr
│   │   └── decoder
│   │       └── rnn_decoder.py
│   └── bin
│       └── asr_test.py
├── espnet
│   ├── bin
│   │   └── asr_test.py
│   ├── utils
│   │   ├── gini_utils.py
│   │   └── gini_guide.py
│   └── nets
│       └── test_beam_search.py
└── TEMPLATE
    └── asr1
        └── asr_test.sh
/an4/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Set bash to 'debug' mode, it will exit on :
3 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
4 | set -e
5 | set -u
6 | set -o pipefail
7 |
8 | test_sets="test-room"
9 | orig_flag=false
10 | need_decode=true
11 | selected_num=130
12 | stage=1
13 | stop_stage=4
14 | dataset="an4"
15 | . utils/parse_options.sh
16 |
17 | ./asr_test.sh \
18 | --use_lm false \
19 | --test_sets ${test_sets} \
20 |     --local_data_opts ${test_sets} \
21 | --orig_flag ${orig_flag} \
22 | --need_decode ${need_decode} \
23 | --stage ${stage} \
24 | --stop_stage ${stop_stage} \
25 | --token_type bpe \
26 | --asr_exp "exp/asr_train_raw_en_bpe30" \
27 | --dataset ${dataset} \
28 | --selected_num ${selected_num}
29 |
--------------------------------------------------------------------------------
/tedlium2/local/join_suffix.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright 2014 Nickolay V. Shmyrev
4 | # 2016 Johns Hopkins University (author: Daniel Povey)
5 | # Apache 2.0
6 |
7 |
8 | import sys
9 |
10 | # This script joins together pairs of split-up words like "you 're" -> "you're".
11 | # The TEDLIUM transcripts are normalized in a way that's not traditional for
12 | # speech recognition.
13 |
14 | prev_line = ""
15 | for line in sys.stdin:
16 | if line == prev_line:
17 | continue
18 | items = line.split()
19 | new_items = []
20 | i = 0
21 | while i < len(items):
22 | if i < len(items) - 1 and items[i + 1][0] == "'":
23 | new_items.append(items[i] + items[i + 1])
24 |             i = i + 2  # skip the suffix token that was just merged
25 | else:
26 | new_items.append(items[i])
27 | i = i + 1
28 | print(" ".join(new_items))
29 | prev_line = line
30 |
--------------------------------------------------------------------------------
/tedlium3/local/join_suffix.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright 2014 Nickolay V. Shmyrev
4 | # 2016 Johns Hopkins University (author: Daniel Povey)
5 | # Apache 2.0
6 |
7 |
8 | import sys
9 |
10 | # This script joins together pairs of split-up words like "you 're" -> "you're".
11 | # The TEDLIUM transcripts are normalized in a way that's not traditional for
12 | # speech recognition.
13 |
14 | prev_line = ""
15 | for line in sys.stdin:
16 | if line == prev_line:
17 | continue
18 | items = line.split()
19 | new_items = []
20 | i = 0
21 | while i < len(items):
22 | if i < len(items) - 1 and items[i + 1][0] == "'":
23 | new_items.append(items[i] + items[i + 1])
24 |             i = i + 2  # skip the suffix token that was just merged
25 | else:
26 | new_items.append(items[i])
27 | i = i + 1
28 | print(" ".join(new_items))
29 | prev_line = line
30 |
--------------------------------------------------------------------------------
/timit/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Set bash to 'debug' mode, it will exit on :
3 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
4 |
5 | set -e
6 | set -u
7 | set -o pipefail
8 |
9 | test_sets="test-feature"
10 | model_name="kamo-naoyuki/timit_asr_train_asr_raw_word_valid.acc.ave"
11 | asr_exp="exp/${model_name}"
12 | selected_num=192
13 | orig_flag=false
14 | need_decode=true
15 | stage=1
16 | stop_stage=4
17 | train_flag=false
18 | dataset="TIMIT"
19 |
20 | # Set this to one of ["phn", "char"] depending on your requirement
21 | trans_type=phn
22 | if [ "${trans_type}" = phn ]; then
23 | # If the transcription is "phn" type, the token splitting should be done in word level
24 | token_type=word
25 | else
26 | token_type="${trans_type}"
27 | fi
28 |
29 | asr_config=conf/train_asr.yaml
30 | lm_config=conf/train_lm_rnn.yaml
31 | inference_config=conf/decode_asr.yaml
32 |
33 | . utils/parse_options.sh
34 |
35 | ./asr_test.sh \
36 | --token_type "${token_type}" \
37 | --test_sets "${test_sets}" \
38 | --use_lm false \
39 | --asr_config "${asr_config}" \
40 | --lm_config "${lm_config}" \
41 | --inference_config "${inference_config}" \
42 | --local_data_opts "--trans_type ${trans_type} --train_flag ${train_flag} --test_sets ${test_sets}" \
43 | --selected_num ${selected_num} \
44 | --orig_flag ${orig_flag} \
45 | --need_decode ${need_decode} \
46 | --stage ${stage} \
47 | --stop_stage ${stop_stage} \
48 | --asr_exp ${asr_exp} \
49 | --dataset ${dataset}
50 |
--------------------------------------------------------------------------------
/an4/local/test_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Set bash to 'debug' mode, it will exit on :
3 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
4 | set -e
5 | set -u
6 | set -o pipefail
7 |
8 | log() {
9 | local fname=${BASH_SOURCE[1]##*/}
10 | echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
11 | }
12 | SECONDS=0
13 |
14 | stage=1
15 | stop_stage=100
16 |
17 | datadir=./downloads
18 | an4_root=${datadir}/an4
19 | data_url=http://www.speech.cs.cmu.edu/databases/an4/
20 | ndev_utt=100
21 | test_set=${1}
22 | log "$0 $*"
23 | echo ${test_set}
24 | . utils/parse_options.sh
25 |
26 | # if [ $# -ne 0 ]; then
27 | # log "Error: No positional arguments are required."
28 | # exit 2
29 | # fi
30 |
31 | . ./path.sh
32 | . ./cmd.sh
33 |
34 |
35 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
36 | log "stage 1: Data Download"
37 | mkdir -p ${datadir}
38 | local/download_and_untar.sh ${datadir} ${data_url}
39 | fi
40 |
41 |
42 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
43 | log "stage 2: Data preparation"
44 | mkdir -p data/${test_set}
45 |
46 | if [[ ${test_set} =~ "130" ]]
47 | then
48 |         log "Skipping data preparation for ${test_set}"
49 | else
50 | python3 local/test_data_prep.py ${an4_root} ${test_set}
51 | fi
52 | for x in ${test_set}; do
53 | for f in text wav.scp utt2spk; do
54 | sort data/${x}/${f} -o data/${x}/${f}
55 | done
56 | utils/utt2spk_to_spk2utt.pl data/${x}/utt2spk > "data/${x}/spk2utt"
57 | done
58 |
59 | fi
60 |
61 | log "Successfully finished. [elapsed=${SECONDS}s]"
62 |
63 |
--------------------------------------------------------------------------------
/timit/local/timit_format_test_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2013 (Author: Daniel Povey)
4 | # Apache 2.0
5 |
6 | # This script takes data prepared in a corpus-dependent way
7 | # in data/local/, and converts it into the "canonical" form,
8 | # in various subdirectories of data/, e.g. data/lang, data/train, etc.
9 |
10 | . ./path.sh || exit 1;
11 |
12 | echo "Preparing test data"
13 | srcdir=data/local/data
14 | lmdir=data/local/nist_lm
15 | tmpdir=data/local/lm_tmp
16 | lexicon=data/local/dict/lexicon.txt
17 | mkdir -p $tmpdir
18 | # str1="$1"
19 | # str2="ORGI"
20 | train_flag=$1
21 | test=$(echo $2 | tr '[A-Z]' '[a-z]')
22 | # result=$(echo $str1 | grep "${str2}")
23 |
24 | # if [[ "$result" != "" ]] ; then
25 | # if ! "${train_flag}"; then
26 | # test=test-orgi
27 | # else
28 | # test=train-orgi
29 | # fi
30 | # else
31 | # if ! "${train_flag}"; then
32 | # test=test-new
33 | # else
34 | # test=train-new
35 | # fi
36 | # fi
37 |
38 | for x in $test dev; do
39 | echo $x
40 | mkdir -p data/$x
41 | if [[ $test =~ "192" ]]
42 | then
43 |         echo "Skipping copy of prepared data files"
44 | else
45 | cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1;
46 | cp $srcdir/$x.text data/$x/text || exit 1;
47 | cp $srcdir/$x.spk2utt data/$x/spk2utt || exit 1;
48 | cp $srcdir/$x.utt2spk data/$x/utt2spk || exit 1;
49 | utils/filter_scp.pl data/$x/spk2utt $srcdir/$x.spk2gender > data/$x/spk2gender || exit 1;
50 | [ -e $srcdir/${x}.stm ] && cp $srcdir/${x}.stm data/$x/stm
51 | [ -e $srcdir/${x}.glm ] && cp $srcdir/${x}.glm data/$x/glm
52 | fi
53 | utils/fix_data_dir.sh data/$x
54 | utils/validate_data_dir.sh --no-feats data/$x || exit 1
55 | done
56 |
--------------------------------------------------------------------------------
/timit/local/test_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright IIIT-Bangalore (Shreekantha Nadig)
3 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
4 | set -euo pipefail
5 | SECONDS=0
6 | log() {
7 | local fname=${BASH_SOURCE[1]##*/}
8 | echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 | }
10 |
11 |
12 | stage=0 # start from 0 if you need to start from data preparation
13 | stop_stage=100
14 | train_flag=false
15 | trans_type=phn
16 | test_sets=
17 | log "$0 $*"
18 | . utils/parse_options.sh
19 |
20 | if [ $# -ne 0 ]; then
21 | log "Error: No positional arguments are required."
22 | exit 2
23 | fi
24 |
25 | . ./path.sh
26 | . ./cmd.sh
27 | . ./db.sh
28 |
29 | echo "train flag is ${train_flag}"
30 |
31 | # general configuration
32 | if [ -z "${TIMIT}" ]; then
33 | log "Fill the value of 'TIMIT' of db.sh"
34 | exit 1
35 | fi
36 |
37 | log "data preparation started"
38 | #TIMIT=/root/espnet/egs2/timit/asr1/data/local/data
39 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
40 | ### Task dependent. You have to make data the following preparation part by yourself.
41 | ### But you can utilize Kaldi recipes in most cases
42 |     log "stage 0: Preparing data for TIMIT for ${trans_type} level transcripts"
43 | echo $TIMIT
44 | if [[ ${test_sets} =~ "192" ]]
45 | then
46 |         log "Skipping data preparation for ${test_sets}"
47 | else
48 | local/timit_test_data_prep.sh ${TIMIT} ${trans_type} ${train_flag} ${test_sets}
49 | fi
50 | fi
51 |
52 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
53 |     log "stage 1: Formatting TIMIT directories"
54 | ### Task dependent. You have to make data the following preparation part by yourself.
55 | ### But you can utilize Kaldi recipes in most cases
56 | local/timit_format_test_data.sh ${train_flag} ${test_sets}
57 | fi
58 |
59 | log "Successfully finished. [elapsed=${SECONDS}s]"
60 |
--------------------------------------------------------------------------------
/an4/local/test_data_prep.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import re
5 | import sys
6 |
7 | if len(sys.argv) != 3:
8 | print("Usage: python test_data_prep.py [an4_root] [test_dir]")
9 | sys.exit(1)
10 | an4_root = sys.argv[1]
11 | test_set = sys.argv[2]
12 | wav_dir = {"test": test_set}
13 | for x in ["test"]:
14 | with open(
15 | os.path.join(an4_root, "etc", "an4_" + test_set + ".transcription")
16 | ) as transcript_f, open(os.path.join("data", wav_dir[x], "text"), "w") as text_f, open(
17 | os.path.join("data", wav_dir[x], "wav.scp"), "w"
18 | ) as wav_scp_f, open(
19 | os.path.join("data", wav_dir[x], "utt2spk"), "w"
20 | ) as utt2spk_f:
21 |
22 | text_f.truncate()
23 | wav_scp_f.truncate()
24 | utt2spk_f.truncate()
25 |
26 | lines = sorted(transcript_f.readlines(), key=lambda s: s.split(" ")[0])
27 | for line in lines:
28 | line = line.strip()
29 | if not line:
30 | continue
31 | words = re.search(r"^(.*) \(", line).group(1)
32 |             if words[:4] == "<s> ":
33 |                 words = words[4:]
34 |             if words[-5:] == " </s>":
35 |                 words = words[:-5]
36 | source = re.search(r"\((.*)\)", line).group(1)
37 | pre, mid, last = source.split("-")
38 | utt_id = "-".join([mid, pre, last])
39 | text_f.write(utt_id + " " + words + "\n")
40 | if "test-orgi" in test_set:
41 | wav_scp_f.write(
42 | utt_id
43 | + " "
44 | + "sph2pipe"
45 | + " -f wav -p -c 1 "
46 | + os.path.join(an4_root, "wav", wav_dir[x], mid, source + ".sph")
47 | + " |\n"
48 | )
49 |             elif ("dev-orig" in test_set) or ("train-orig" in test_set):
50 | wav_scp_f.write(
51 | utt_id
52 | + " "
53 | + os.path.join(an4_root, "wav", wav_dir[x], mid, source + ".wav")
54 | + "\n"
55 | )
56 | else:
57 | wav_scp_f.write(
58 | utt_id
59 | + " "
60 | + os.path.join(an4_root, "wav", wav_dir[x], mid, source)
61 | + "\n"
62 | )
63 | utt2spk_f.write(utt_id + " " + mid + "\n")
64 |
65 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ASRTest: Automated Testing for Deep-Neural-Network-Driven Speech Recognition Systems
2 | We evaluate ASRTest with the ASR toolkit ESPnet (https://github.com/espnet/espnet).
3 |
4 | ## Generated data
5 | - We generated test data 8 times the size of the original test sets, which is sufficient for all kinds of model testing.
6 | - AN4: https://pan.baidu.com/s/1KCDb1LKdmaExqqZqlrfAyw?pwd=xk95
7 | - TIMIT: https://pan.baidu.com/s/12WJJLdnKejEVTW1Ft4DX2A?pwd=vg5w
8 | - TEDLIUM2: https://pan.baidu.com/s/1lSo1zuf-QPwKz6MOeYN6iQ?pwd=pyqf
9 | - TEDLIUM3: https://pan.baidu.com/s/1-21psFg2oZrt_RSgabdbHg?pwd=p8xu
10 |
11 | ## How to use the generated data for testing
12 | ### Install ESPnet (https://github.com/espnet/espnet)
13 | - You can install it by following the installation instructions provided by ESPnet.
14 | ### Add scripts to your espnet directory
15 | - For guidance utils:
16 | - ASRTest/utils/* ---> your espnet directory/utils
17 | - ASRTest/espnet/* ---> your espnet directory/espnet
18 | - ASRTest/espnet2/* ---> your espnet directory/espnet2
19 | - For egs2 (Scripts are also available for other models in egs2):
20 | - ASRTest/TEMPLATE/asr1/asr_test.sh ---> your espnet directory/egs2/TEMPLATE/asr1/asr_test.sh
21 | - ASRTest/an4/* ---> your espnet directory/egs2/an4/asr1/
22 | - ASRTest/timit/* ---> your espnet directory/egs2/timit/asr1/
23 | - For egs:
24 | - ASRTest/tedlium2/* ---> your espnet directory/egs/tedlium2/asr1/
25 | - ASRTest/tedlium3/* ---> your espnet directory/egs/tedlium3/asr1/
26 | ### Decode the generated data
27 | - For egs2:
28 | - cd egs2/xxx/asr1/
29 | - ln -s ../../TEMPLATE/asr1/asr_test.sh ./asr_test.sh
30 |     - decode the original test set: sh test.sh --test_sets "test-orig" --need_decode true --orig_flag true --dataset "xxx" --stage 1 --stop_stage 4
31 |     - decode all the transformed test sets (see the full example at the end of this README): sh test.sh --test_sets "xxxx" --need_decode true --orig_flag false --dataset "xxx" --stage 1 --stop_stage 4
32 | - obtain the result on the test set transformed by ASRTest: sh test.sh --test_sets "xxxx" --need_decode false --orig_flag false --dataset "xxx" --stage 3 --stop_stage 4
33 | - For egs:
34 | - cd egs/xxx/asr1/
35 |     - decode the original test set: sh test.sh --recog_set "test-orig" --need_decode true --orig_flag true --stage 0 --stop_stage 4
36 | - decode the transformed test set: sh test.sh --recog_set "xxxx" --need_decode true --orig_flag false --stage 0 --stop_stage 4
37 | - obtain the result on the test set transformed by ASRTest: sh test.sh --recog_set "xxxx" --need_decode false --orig_flag false --stage 3 --stop_stage 4
38 | - test_set/recog_set: test-feature, test-noise, test-room, test-orig
39 | - orig_flag: if test_set/recog_set is test-orig, orig_flag is true; otherwise, it is false
40 |
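41 | ### Example: testing the an4 recipe in egs2
42 | A minimal end-to-end walk-through of the egs2 workflow above (a sketch: it assumes the an4 recipe with the `exp/asr_train_raw_en_bpe30` model used in `an4/test.sh`, and the `test-room` transformed set listed above):
43 | ```bash
44 | cd egs2/an4/asr1/
45 | ln -s ../../TEMPLATE/asr1/asr_test.sh ./asr_test.sh
46 | # decode the original test set
47 | sh test.sh --test_sets "test-orig" --need_decode true --orig_flag true --dataset "an4" --stage 1 --stop_stage 4
48 | # decode a transformed test set
49 | sh test.sh --test_sets "test-room" --need_decode true --orig_flag false --dataset "an4" --stage 1 --stop_stage 4
50 | # score the subset selected by ASRTest
51 | sh test.sh --test_sets "test-room" --need_decode false --orig_flag false --dataset "an4" --stage 3 --stop_stage 4
52 | ```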
--------------------------------------------------------------------------------
/tedlium3/local/prepare_test_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # Copyright 2014 Nickolay V. Shmyrev
4 | # 2014 Brno University of Technology (Author: Karel Vesely)
5 | # 2016 Johns Hopkins University (Author: Daniel Povey)
6 | # Apache 2.0
7 |
8 | # To be run from one directory above this script.
9 |
10 | . ./path.sh
11 |
12 | export LC_ALL=C
13 |
14 | sph2pipe=sph2pipe
15 |
16 | data_type=$1
17 | recog_set=$2
18 | # Prepare: test, train,
19 | for set in ${recog_set}; do
20 | dir=data/$set.orig
21 | mkdir -p $dir
22 |
23 | # Merge transcripts into a single 'stm' file, do some mappings:
24 |     # - <F0_M> -> <o,f0,male> : map dev stm labels to be coherent with train + test,
25 |     # - <F0_F> -> <o,f0,female> : --||--
26 |     # - (2) -> null : remove pronunciation variants in transcripts, keep in dictionary
27 |     # - <sil> -> null : remove marked <sil>, it is modelled implicitly (in kaldi)
28 | # - (...) -> null : remove utterance names from end-lines of train
29 | # - it 's -> it's : merge words that contain apostrophe (if compound in dictionary, local/join_suffix.py)
30 | { # Add STM header, so sclite can prepare the '.lur' file
31 | echo ';;
32 | ;; LABEL "o" "Overall" "Overall results"
33 | ;; LABEL "f0" "f0" "Wideband channel"
34 | ;; LABEL "f2" "f2" "Telephone channel"
35 | ;; LABEL "male" "Male" "Male Talkers"
36 | ;; LABEL "female" "Female" "Female Talkers"
37 | ;;'
38 | # Process the STMs
39 | cat db/TEDLIUM_release-3/${data_type}/$set/stm/*.stm | sort -k1,1 -k2,2 -k4,4n | \
40 |     sed -e 's:<F0_M>:<o,f0,male>:' \
41 |         -e 's:<F0_F>:<o,f0,female>:' \
42 |         -e 's:([0-9])::g' \
43 |         -e 's:<sil>::g' \
44 | -e 's:([^ ]*)$::' | \
45 | awk '{ $2 = "A"; print $0; }'
46 | } | local/join_suffix.py > data/$set.orig/stm
47 |
48 | # Prepare 'text' file
49 | # - {NOISE} -> [NOISE] : map the tags to match symbols in dictionary
50 | cat $dir/stm | grep -v -e 'ignore_time_segment_in_scoring' -e ';;' | \
51 | awk '{ printf ("%s-%07d-%07d", $1, $4*100, $5*100);
52 | for (i=7;i<=NF;i++) { printf(" %s", $i); }
53 | printf("\n");
54 | }' | tr '{}' '[]' | sort -k1,1 > $dir/text || exit 1
55 |
56 | # Prepare 'segments', 'utt2spk', 'spk2utt'
57 | cat $dir/text | cut -d" " -f 1 | awk -F"-" '{printf("%s %s %07.2f %07.2f\n", $0, $1, $2/100.0, $3/100.0)}' > $dir/segments
58 | cat $dir/segments | awk '{print $1, $2}' > $dir/utt2spk
59 | cat $dir/utt2spk | utils/utt2spk_to_spk2utt.pl > $dir/spk2utt
60 |
61 | # Prepare 'wav.scp', 'reco2file_and_channel'
62 | if [ $set == "test-orig" ]; then
63 | cat $dir/spk2utt | awk -v data_type=$data_type -v set=$set -v pwd=$PWD '{ printf("%s '$sph2pipe' -f wav -p %s/db/TEDLIUM_release-3/%s/%s/sph/%s.sph |\n", $1, pwd, data_type, set, $1); }' > $dir/wav.scp
64 | else
65 | cat $dir/spk2utt | awk -v data_type=$data_type -v set=$set -v pwd=$PWD '{ printf("%s %s/db/TEDLIUM_release-3/%s/%s/wav/%s.wav\n", $1, pwd, data_type, set, $1); }' > $dir/wav.scp
66 | fi
67 | cat $dir/wav.scp | awk '{ print $1, $1, "A"; }' > $dir/reco2file_and_channel
68 |
69 | # Create empty 'glm' file
70 | echo ';; empty.glm
71 | [FAKE] => %HESITATION / [ ] __ [ ] ;; hesitation token
72 | ' > data/$set.orig/glm
73 |
74 | # Check that data dirs are okay!
75 | utils/validate_data_dir.sh --no-feats $dir || exit 1
76 | done
77 |
78 |
--------------------------------------------------------------------------------
/tedlium2/local/prepare_test_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # Copyright 2014 Nickolay V. Shmyrev
4 | # 2014 Brno University of Technology (Author: Karel Vesely)
5 | # 2016 Johns Hopkins University (Author: Daniel Povey)
6 | # Apache 2.0
7 |
8 | # To be run from one directory above this script.
9 |
10 | . ./path.sh
11 |
12 | export LC_ALL=C
13 |
14 | sph2pipe=sph2pipe
15 | recog_set=${1}
16 |
17 | # Prepare: test, train,
18 | for set in ${recog_set}; do
19 | dir=data/$set.orig
20 | mkdir -p $dir
21 |
22 | # Merge transcripts into a single 'stm' file, do some mappings:
23 |     # - <F0_M> -> <o,f0,male> : map dev stm labels to be coherent with train + test,
24 |     # - <F0_F> -> <o,f0,female> : --||--
25 |     # - (2) -> null : remove pronunciation variants in transcripts, keep in dictionary
26 |     # - <sil> -> null : remove marked <sil>, it is modelled implicitly (in kaldi)
27 | # - (...) -> null : remove utterance names from end-lines of train
28 | # - it 's -> it's : merge words that contain apostrophe (if compound in dictionary, local/join_suffix.py)
29 | { # Add STM header, so sclite can prepare the '.lur' file
30 | echo ';;
31 | ;; LABEL "o" "Overall" "Overall results"
32 | ;; LABEL "f0" "f0" "Wideband channel"
33 | ;; LABEL "f2" "f2" "Telephone channel"
34 | ;; LABEL "male" "Male" "Male Talkers"
35 | ;; LABEL "female" "Female" "Female Talkers"
36 | ;;'
37 | # Process the STMs
38 | cat db/TEDLIUM_release2/$set/stm/*.stm | sort -k1,1 -k2,2 -k4,4n | \
39 |     sed -e 's:<F0_M>:<o,f0,male>:' \
40 |         -e 's:<F0_F>:<o,f0,female>:' \
41 |         -e 's:([0-9])::g' \
42 |         -e 's:<sil>::g' \
43 | -e 's:([^ ]*)$::' | \
44 | awk '{ $2 = "A"; print $0; }'
45 | } | local/join_suffix.py > data/$set.orig/stm
46 |
47 | # Prepare 'text' file
48 | # - {NOISE} -> [NOISE] : map the tags to match symbols in dictionary
49 | cat $dir/stm | grep -v -e 'ignore_time_segment_in_scoring' -e ';;' | \
50 | awk '{ printf ("%s-%07d-%07d", $1, $4*100, $5*100);
51 | for (i=7;i<=NF;i++) { printf(" %s", $i); }
52 | printf("\n");
53 | }' | tr '{}' '[]' | sort -k1,1 > $dir/text || exit 1
54 |
55 | # Prepare 'segments', 'utt2spk', 'spk2utt'
56 | cat $dir/text | cut -d" " -f 1 | awk -F"-" '{printf("%s %s %07.2f %07.2f\n", $0, $1, $2/100.0, $3/100.0)}' > $dir/segments
57 | cat $dir/segments | awk '{print $1, $2}' > $dir/utt2spk
58 | cat $dir/utt2spk | utils/utt2spk_to_spk2utt.pl > $dir/spk2utt
59 |
60 | # Prepare 'wav.scp', 'reco2file_and_channel'
61 | if [ $set == "test-orgi" ]; then
62 | cat $dir/spk2utt | awk -v set=$set -v pwd=$PWD '{ printf("%s '$sph2pipe' -f wav -p %s/db/TEDLIUM_release2/%s/sph/%s.sph |\n", $1, pwd, set, $1); }' > $dir/wav.scp
63 | else
64 | cat $dir/spk2utt | awk -v set=$set -v pwd=$PWD '{ printf("%s %s/db/TEDLIUM_release2/%s/wav/%s.wav\n", $1, pwd, set, $1); }' > $dir/wav.scp
65 | fi
66 | cat $dir/wav.scp | awk '{ print $1, $1, "A"; }' > $dir/reco2file_and_channel
67 |
68 | # Create empty 'glm' file
69 | echo ';; empty.glm
70 | [FAKE] => %HESITATION / [ ] __ [ ] ;; hesitation token
71 | ' > data/$set.orig/glm
72 |
73 | # The training set seems to not have enough silence padding in the segmentations,
74 | # especially at the beginning of segments. Extend the times.
75 | if [ $set == "train" ]; then
76 | mv data/$set.orig/segments data/$set.orig/segments.temp
77 | utils/data/extend_segment_times.py --start-padding=0.15 \
78 |             --end-padding=0.1 <data/$set.orig/segments.temp >data/$set.orig/segments || exit 1
79 | rm data/$set.orig/segments.temp
80 | fi
81 |
82 | # Check that data dirs are okay!
83 | utils/validate_data_dir.sh --no-feats $dir || exit 1
84 | done
85 |
86 |
87 |
--------------------------------------------------------------------------------
/timit/local/timit_norm_trans.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | use warnings; #sed replacement for -w perl parameter
3 |
4 | # Copyright 2012 Arnab Ghoshal
5 |
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15 | # MERCHANTABLITY OR NON-INFRINGEMENT.
16 | # See the Apache 2 License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | # This script normalizes the TIMIT phonetic transcripts that have been
21 | # extracted in a format where each line contains an utterance ID followed by
22 | # the transcript, e.g.:
23 | # fcke0_si1111 h# hh ah dx ux w iy dcl d ix f ay n ih q h#
24 |
25 | my $usage = "Usage: timit_norm_trans.pl -i transcript -m phone_map -from [60|48] -to [48|39] > normalized\n
26 | Normalizes phonetic transcriptions for TIMIT, by mapping the phones to a
27 | smaller set defined by the -m option. This script assumes that the mapping is
28 | done in the \"standard\" fashion, i.e. to 48 or 39 phones. The input is
29 | assumed to have 60 phones (+1 for glottal stop, which is deleted), but that can
30 | be changed using the -from option. The input format is assumed to be utterance
31 | ID followed by transcript on the same line.\n";
32 |
33 | use strict;
34 | use Getopt::Long;
35 | die "$usage" unless(@ARGV >= 1);
36 | my ($in_trans, $phone_map, $num_phones_out);
37 | my $num_phones_in = 60;
38 | GetOptions ("i=s" => \$in_trans, # Input transcription
39 | "m=s" => \$phone_map, # File containing phone mappings
40 | "from=i" => \$num_phones_in, # Input #phones: must be 60 or 48
41 | "to=i" => \$num_phones_out ); # Output #phones: must be 48 or 39
42 |
43 | die $usage unless(defined($in_trans) && defined($phone_map) &&
44 | defined($num_phones_out));
45 | if ($num_phones_in != 60 && $num_phones_in != 48) {
46 |   die "Can only use 60 or 48 for -from (used $num_phones_in)."
47 | }
48 | if ($num_phones_out != 48 && $num_phones_out != 39) {
49 |   die "Can only use 48 or 39 for -to (used $num_phones_out)."
50 | }
51 | unless ($num_phones_out < $num_phones_in) {
52 | die "Argument to -from ($num_phones_in) must be greater than that to -to ($num_phones_out)."
53 | }
54 |
55 |
56 | open(M, "<$phone_map") or die "Cannot open mappings file '$phone_map': $!";
57 | my (%phonemap, %seen_phones);
58 | my $num_seen_phones = 0;
59 | while (<M>) {
60 | chomp;
61 | next if ($_ =~ /^q\s*.*$/); # Ignore glottal stops.
62 | m:^(\S+)\s+(\S+)\s+(\S+)$: or die "Bad line: $_";
63 | my $mapped_from = ($num_phones_in == 60)? $1 : $2;
64 | my $mapped_to = ($num_phones_out == 48)? $2 : $3;
65 | if (!defined($seen_phones{$mapped_to})) {
66 | $seen_phones{$mapped_to} = 1;
67 | $num_seen_phones += 1;
68 | }
69 | $phonemap{$mapped_from} = $mapped_to;
70 | }
71 | if ($num_seen_phones != $num_phones_out) {
72 | die "Trying to map to $num_phones_out phones, but seen only $num_seen_phones";
73 | }
74 |
75 | open(T, "<$in_trans") or die "Cannot open transcription file '$in_trans': $!";
76 | while (<T>) {
77 | chomp;
78 | $_ =~ m:^(\S+)\s+(.+): or die "Bad line: $_";
79 | my $utt_id = $1;
80 | my $trans = $2;
81 |
82 | $trans =~ s/q//g; # Remove glottal stops.
83 | $trans =~ s/^\s*//; $trans =~ s/\s*$//; # Normalize spaces
84 |
85 | print $utt_id;
86 | for my $phone (split(/\s+/, $trans)) {
87 | if(exists $phonemap{$phone}) { print " $phonemap{$phone}"; }
88 | if(not exists $phonemap{$phone}) { print " $phone"; }
89 | }
90 | print "\n";
91 | }
92 |
--------------------------------------------------------------------------------
/utils/score_sclite.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5 |
6 | [ -f ./path.sh ] && . ./path.sh
7 |
8 | nlsyms=""
9 | wer=false
10 | bpe=""
11 | bpemodel=""
12 | remove_blank=true
13 | filter=""
14 | num_spkrs=1
15 | need_decode=
16 | guide_type=
17 | help_message="Usage: $0 <data-dir> <dict>"
18 |
19 | . utils/parse_options.sh
20 |
21 | if [ $# != 2 ]; then
22 | echo "${help_message}"
23 | exit 1;
24 | fi
25 |
26 | dir=$1
27 | dic=$2
28 |
29 | if "${need_decode}"; then
30 | concatjson.py ${dir}/data.*.json > ${dir}/data.json
31 | else
32 | dir=$1/${guide_type}-1155
33 | fi
34 |
35 | if [ $num_spkrs -eq 1 ]; then
36 | json2trn.py ${dir}/data.json ${dic} --num-spkrs ${num_spkrs} --refs ${dir}/ref.trn --hyps ${dir}/hyp.trn
37 |
38 | if ${remove_blank}; then
39 | sed -i.bak2 -r 's/ //g' ${dir}/hyp.trn
40 | fi
41 | if [ -n "${nlsyms}" ]; then
42 | cp ${dir}/ref.trn ${dir}/ref.trn.org
43 | cp ${dir}/hyp.trn ${dir}/hyp.trn.org
44 | filt.py -v ${nlsyms} ${dir}/ref.trn.org > ${dir}/ref.trn
45 | filt.py -v ${nlsyms} ${dir}/hyp.trn.org > ${dir}/hyp.trn
46 | fi
47 | if [ -n "${filter}" ]; then
48 | sed -i.bak3 -f ${filter} ${dir}/hyp.trn
49 | sed -i.bak3 -f ${filter} ${dir}/ref.trn
50 | fi
51 |
52 | sclite -r ${dir}/ref.trn trn -h ${dir}/hyp.trn trn -i rm -o all stdout > ${dir}/result.txt
53 |
54 | echo "write a CER (or TER) result in ${dir}/result.txt"
55 | grep -e Avg -e SPKR -m 2 ${dir}/result.txt
56 |
57 | if ${wer}; then
58 | if [ -n "$bpe" ]; then
59 | spm_decode --model=${bpemodel} --input_format=piece < ${dir}/ref.trn | sed -e "s/▁/ /g" > ${dir}/ref.wrd.trn
60 | spm_decode --model=${bpemodel} --input_format=piece < ${dir}/hyp.trn | sed -e "s/▁/ /g" > ${dir}/hyp.wrd.trn
61 | else
62 |             sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/ref.trn > ${dir}/ref.wrd.trn
63 |             sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/hyp.trn > ${dir}/hyp.wrd.trn
64 | fi
65 | sclite -r ${dir}/ref.wrd.trn trn -h ${dir}/hyp.wrd.trn trn -i rm -o all stdout > ${dir}/result.wrd.txt
66 |
67 | echo "write a WER result in ${dir}/result.wrd.txt"
68 | grep -e Avg -e SPKR -m 2 ${dir}/result.wrd.txt
69 | fi
70 | elif [ ${num_spkrs} -lt 4 ]; then
71 | ref_trns=""
72 | hyp_trns=""
73 | for i in $(seq ${num_spkrs}); do
74 | ref_trns=${ref_trns}"${dir}/ref${i}.trn "
75 | hyp_trns=${hyp_trns}"${dir}/hyp${i}.trn "
76 | done
77 | json2trn.py ${dir}/data.json ${dic} --num-spkrs ${num_spkrs} --refs ${ref_trns} --hyps ${hyp_trns}
78 |
79 | for n in $(seq ${num_spkrs}); do
80 | if ${remove_blank}; then
81 | sed -i.bak2 -r 's/ //g' ${dir}/hyp${n}.trn
82 | fi
83 | if [ -n "${nlsyms}" ]; then
84 | cp ${dir}/ref${n}.trn ${dir}/ref${n}.trn.org
85 | cp ${dir}/hyp${n}.trn ${dir}/hyp${n}.trn.org
86 | filt.py -v ${nlsyms} ${dir}/ref${n}.trn.org > ${dir}/ref${n}.trn
87 | filt.py -v ${nlsyms} ${dir}/hyp${n}.trn.org > ${dir}/hyp${n}.trn
88 | fi
89 | if [ -n "${filter}" ]; then
90 | sed -i.bak3 -f ${filter} ${dir}/hyp${n}.trn
91 | sed -i.bak3 -f ${filter} ${dir}/ref${n}.trn
92 | fi
93 | done
94 |
95 | results_str=""
96 | for (( i=0; i<$((num_spkrs * num_spkrs)); i++ )); do
97 | ind_r=$((i / num_spkrs + 1))
98 | ind_h=$((i % num_spkrs + 1))
99 | results_str=${results_str}"${dir}/result_r${ind_r}h${ind_h}.txt "
100 | sclite -r ${dir}/ref${ind_r}.trn trn -h ${dir}/hyp${ind_h}.trn trn -i rm -o all stdout > ${dir}/result_r${ind_r}h${ind_h}.txt
101 | done
102 |
103 | echo "write CER (or TER) results in ${dir}/result_r*h*.txt"
104 | eval_perm_free_error.py --num-spkrs ${num_spkrs} \
105 | ${results_str} > ${dir}/min_perm_result.json
106 | sed -n '2,4p' ${dir}/min_perm_result.json
107 |
108 | if ${wer}; then
109 | for n in $(seq ${num_spkrs}); do
110 | if [ -n "$bpe" ]; then
111 | spm_decode --model=${bpemodel} --input_format=piece < ${dir}/ref${n}.trn | sed -e "s/▁/ /g" > ${dir}/ref${n}.wrd.trn
112 | spm_decode --model=${bpemodel} --input_format=piece < ${dir}/hyp${n}.trn | sed -e "s/▁/ /g" > ${dir}/hyp${n}.wrd.trn
113 | else
114 |                 sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/ref${n}.trn > ${dir}/ref${n}.wrd.trn
115 |                 sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/hyp${n}.trn > ${dir}/hyp${n}.wrd.trn
116 | fi
117 | done
118 | results_str=""
119 | for (( i=0; i<$((num_spkrs * num_spkrs)); i++ )); do
120 | ind_r=$((i / num_spkrs + 1))
121 | ind_h=$((i % num_spkrs + 1))
122 | results_str=${results_str}"${dir}/result_r${ind_r}h${ind_h}.wrd.txt "
123 | sclite -r ${dir}/ref${ind_r}.wrd.trn trn -h ${dir}/hyp${ind_h}.wrd.trn trn -i rm -o all stdout > ${dir}/result_r${ind_r}h${ind_h}.wrd.txt
124 | done
125 |
126 | echo "write WER results in ${dir}/result_r*h*.wrd.txt"
127 | eval_perm_free_error.py --num-spkrs ${num_spkrs} \
128 | ${results_str} > ${dir}/min_perm_result.wrd.json
129 | sed -n '2,4p' ${dir}/min_perm_result.wrd.json
130 | fi
131 | fi
132 |
133 |
--------------------------------------------------------------------------------
/timit/local/timit_test_data_prep.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | create_glm_stm=false
5 |
6 | if [ $# -le 0 ]; then
7 | echo "Argument should be the Timit directory, see ../run.sh for example."
8 | exit 1;
9 | fi
10 |
11 | dir=`pwd`/data/local/data
12 | lmdir=`pwd`/data/local/nist_lm
13 | mkdir -p $dir $lmdir
14 | local=`pwd`/local
15 | utils=`pwd`/utils
16 | conf=`pwd`/conf
17 | train_flag=$3
18 | if [ $2 ]; then
19 | if [[ $2 = "char" || $2 = "phn" ]]; then
20 | trans_type=$2
21 | else
22 | echo "Transcript type must be one of [phn, char]"
23 | echo $2
24 | fi
25 | else
26 | trans_type=phn
27 | fi
28 |
29 | . ./path.sh # Needed for KALDI_ROOT
30 | export PATH=$PATH:$KALDI_ROOT/tools/irstlm/bin
31 | sph2pipe=sph2pipe
32 | if ! command -v "${sph2pipe}" &> /dev/null; then
33 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe";
34 | exit 1;
35 | fi
36 |
37 | [ -f $conf/test_spk.list ] || error_exit "$PROG: Eval-set speaker list not found.";
38 |
39 | # First check if the test directories exist (these can either be upper-
40 | # or lower-cased
41 | # if [ ! -d $1/TEST-NEW ] && [ ! -d $1/test-new ] && [ ! -d $1/TEST-ORGI ] && [ ! -d $1/test-orgi ]; then
42 | # # if [ ! -d $1/TEST ] && [ ! -d $1/test ] && [ ! -d $1/TEST ] && [ ! -d $1/test ]; then
43 | # echo "timit_test_data_prep.sh: Spot check of command line argument failed"
44 | # echo "Command line argument must be absolute pathname to TIMIT directory"
45 | # echo "with name like /export/corpora5/LDC/LDC93S1/timit/TIMIT"
46 | # exit 1;
47 | # fi
48 |
49 | # Now check what case the directory structure is
50 |
51 | uppercased=true
52 | test=$4
53 | test_dir=$(echo $test | tr '[a-z]' '[A-Z]')
54 | # if [ -d $1/TEST-NEW ] || [ -d $1/TEST-ORGI ] ; then
55 | # uppercased=true
56 | # fi
57 |
58 | # str1="$1"
59 | # str2="ORGI"
60 |
61 | # result=$(echo $str1 | grep "${str2}")
62 | # if [[ "$result" != "" ]] ; then
63 | # if ! "${train_flag}"; then
64 | # test=test-orgi
65 | # test_dir=TEST-ORGI
66 | # else
67 | # test=train-orgi
68 | # test_dir=TRAIN-ORGI
69 | # fi
70 | # else
71 | # if ! "${train_flag}"; then
72 | # test=test-new
73 | # test_dir=TEST-NEW
74 | # else
75 | # test=train-new
76 | # test_dir=TRAIN-NEW
77 | # fi
78 | # fi
79 |
80 | tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
81 | trap 'rm -rf "$tmpdir"' EXIT
82 | # Get the list of speakers. The list of speakers in the 24-speaker core test
83 | # set and the 50-speaker development set must be supplied to the script. All
84 | # speakers in the 'train' directory are used for training.
85 | if $uppercased; then
86 | tr '[:lower:]' '[:upper:]' < $conf/dev_spk.list > $tmpdir/dev_spk
87 | if ! "${train_flag}"; then
88 | tr '[:lower:]' '[:upper:]' < $conf/test_spk.list > $tmpdir/${test}_spk
89 | else
90 | ls -d "$1"/$test_dir/DR*/* | sed -e "s:^.*/::" > $tmpdir/${test}_spk
91 | fi
92 | else
93 | tr '[:upper:]' '[:lower:]' < $conf/dev_spk.list > $tmpdir/dev_spk
94 | if ! "${train_flag}"; then
95 | tr '[:lower:]' '[:upper:]' < $conf/test_spk.list > $tmpdir/${test}_spk
96 | else
97 | ls -d "$1"/$test_dir/DR*/* | sed -e "s:^.*/::" > $tmpdir/${test}_spk
98 | fi
99 | fi
100 |
101 | cd $dir
102 | for x in $test dev; do
103 | # First, find the list of audio files (use only si & sx utterances).
104 | # Note: train & test sets are under different directories, but doing find on
105 | # both and grepping for the speakers will work correctly.
106 | echo "test_dir is $1/$test_dir"
107 | find $1/$test_dir -not \( -iname 'SA*' \) -iname '*.WAV' \
108 | | grep -f $tmpdir/${x}_spk > ${x}_sph.flist
109 |
110 | sed -e 's:.*/\(.*\)/\(.*\).WAV$:\1_\2:i' ${x}_sph.flist \
111 | > $tmpdir/${x}_sph.uttids
112 | paste $tmpdir/${x}_sph.uttids ${x}_sph.flist \
113 | | sort -k1,1 > ${x}_sph.scp
114 |
115 | cat ${x}_sph.scp | awk '{print $1}' > ${x}.uttids
116 |
117 | # Now, Convert the transcripts into our format (no normalization yet)
118 | # Get the transcripts: each line of the output contains an utterance
119 | # ID followed by the transcript.
120 |
121 | if [ $trans_type = "phn" ]
122 | then
123 | echo "phone transcript!"
124 | find $1/$test_dir -not \( -iname 'SA*' \) -iname '*.PHN' \
125 | | grep -f $tmpdir/${x}_spk > $tmpdir/${x}_phn.flist
126 | sed -e 's:.*/\(.*\)/\(.*\).PHN$:\1_\2:i' $tmpdir/${x}_phn.flist \
127 | > $tmpdir/${x}_phn.uttids
128 | while read line; do
129 | [ -f $line ] || error_exit "Cannot find transcription file '$line'";
130 | cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;'
131 | done < $tmpdir/${x}_phn.flist > $tmpdir/${x}_phn.trans
132 | paste $tmpdir/${x}_phn.uttids $tmpdir/${x}_phn.trans \
133 | | sort -k1,1 > ${x}.trans
134 |
135 | elif [ $trans_type = "char" ]
136 | then
137 | echo "char transcript!"
138 | find $1/$test_dir -not \( -iname 'SA*' \) -iname '*.WRD' \
139 | | grep -f $tmpdir/${x}_spk > $tmpdir/${x}_wrd.flist
140 | sed -e 's:.*/\(.*\)/\(.*\).WRD$:\1_\2:i' $tmpdir/${x}_wrd.flist \
141 | > $tmpdir/${x}_wrd.uttids
142 | while read line; do
143 | [ -f $line ] || error_exit "Cannot find transcription file '$line'";
144 | cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;' | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z A-Z]//g'
145 | done < $tmpdir/${x}_wrd.flist > $tmpdir/${x}_wrd.trans
146 | paste $tmpdir/${x}_wrd.uttids $tmpdir/${x}_wrd.trans \
147 | | sort -k1,1 > ${x}.trans
148 | else
149 | echo "WRONG!"
150 | echo $trans_type
151 |         exit 1;
152 | fi
153 |
154 | # Do normalization steps.
155 | cat ${x}.trans | $local/timit_norm_trans.pl -i - -m $conf/phones.60-48-39.map -to 39 | sort > $x.text || exit 1;
156 |
157 | # Create wav.scp
158 | awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < ${x}_sph.scp > ${x}_wav.scp
159 |
160 | # Make the utt2spk and spk2utt files.
161 | cut -f1 -d'_' $x.uttids | paste -d' ' $x.uttids - > $x.utt2spk
162 | cat $x.utt2spk | $utils/utt2spk_to_spk2utt.pl > $x.spk2utt || exit 1;
163 |
164 | # Prepare gender mapping
165 | cat $x.spk2utt | awk '{print $1}' | perl -ane 'chop; m:^.:; $g = lc($&); print "$_ $g\n";' > $x.spk2gender
166 |
167 |
168 | if "${create_glm_stm}"; then
169 | # Prepare STM file for sclite:
170 | wav-to-duration --read-entire-file=true scp:${x}_wav.scp ark,t:${x}_dur.ark || exit 1
171 | awk -v dur=${x}_dur.ark \
172 | 'BEGIN{
173 | while(getline < dur) { durH[$1]=$2; }
174 | print ";; LABEL \"O\" \"Overall\" \"Overall\"";
175 | print ";; LABEL \"F\" \"Female\" \"Female speakers\"";
176 | print ";; LABEL \"M\" \"Male\" \"Male speakers\"";
177 | }
178 | { wav=$1; spk=wav; sub(/_.*/,"",spk); $1=""; ref=$0;
179 | gender=(substr(spk,0,1) == "f" ? "F" : "M");
180 |       printf("%s 1 %s 0.0 %f <O,%s> %s\n", wav, spk, durH[wav], gender, ref);
181 | }
182 | ' ${x}.text >${x}.stm || exit 1
183 |
184 | # Create dummy GLM file for sclite:
185 | echo ';; empty.glm
186 | [FAKE] => %HESITATION / [ ] __ [ ] ;; hesitation token
187 | ' > ${x}.glm
188 | fi
189 | done
190 |
191 | echo "Data preparation succeeded"
--------------------------------------------------------------------------------
/tedlium3/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2019 Nagoya University (Masao Someki)
4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5 |
6 | . ./path.sh || exit 1;
7 | . ./cmd.sh || exit 1;
8 |
9 | # general configuration
10 | backend=pytorch
11 | stage=3 # start from -1 if you need to start from data download
12 | stop_stage=3
13 | ngpu=0 # number of gpus ("0" uses cpu, otherwise use gpu)
14 | debugmode=1
15 | dumpdir=dump # directory to dump full features
16 | N=0 # number of minibatches to be used (mainly for debugging). "0" uses all minibatches.
17 | verbose=1 # verbose option
18 | resume= # Resume the training from snapshot
19 |
20 | # feature configuration
21 | do_delta=false
22 | cmvn=
23 | # rnnlm related
24 | lm_resume= # specify a snapshot file to resume LM training
25 | lmtag= # tag for managing LMs
26 | use_lang_model=false
27 | lang_model=
28 |
29 | # decoding parameter
30 | p=0.005
31 | recog_model=
32 | recog_dir=
33 | decode_config=
34 | decode_dir=decode
35 | api=v2
36 |
37 | # bpemode (unigram or bpe)
38 | nbpe=500
39 | bpemode=unigram
40 |
41 | # exp tag
42 | tag="" # tag for managing experiments.
43 |
44 | train_config=
45 | decode_config=
46 | preprocess_config=
47 | lm_config=
48 | models=tedlium3.conformer
49 | # gini related
50 | orig_flag=false
51 | orig_dir=
52 | need_decode=false
53 |
54 | data_type=legacy
55 | train_set=train_trim_sp
56 | recog_set=dev-new
57 |
58 | . utils/parse_options.sh || exit 1;
59 |
60 | # Set bash to 'debug' mode, it will exit on :
61 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
62 | set -e
63 | set -u
64 | set -o pipefail
65 |
66 |
67 | # legacy setup
68 |
69 |
70 | #adaptation setup
71 | #data_type=speaker-adaptation
72 | #train_set=train_adapt_trim_sp
73 | #train_dev=dev_adapt_trim
74 | # recog_set="dev_adapt test_adapt"
75 |
76 | download_dir=${decode_dir}/download
77 |
78 | if [ "${api}" = "v2" ] && [ "${backend}" = "chainer" ]; then
79 | echo "chainer backend does not support api v2." >&2
80 | exit 1;
81 | fi
82 |
83 | if [ -z $models ]; then
84 | if [ $use_lang_model = "true" ]; then
85 | if [[ -z $cmvn || -z $lang_model || -z $recog_model || -z $decode_config ]]; then
86 | echo 'Error: models or set of cmvn, lang_model, recog_model and decode_config are required.' >&2
87 | exit 1
88 | fi
89 | else
90 | if [[ -z $cmvn || -z $recog_model || -z $decode_config ]]; then
91 | echo 'Error: models or set of cmvn, recog_model and decode_config are required.' >&2
92 | exit 1
93 | fi
94 | fi
95 | fi
96 |
97 | dir=${download_dir}/${models}
98 | mkdir -p ${dir}
99 |
100 | # Download trained models
101 | if [ -z "${cmvn}" ]; then
102 | #download_models
103 | cmvn=$(find ${download_dir}/${models} -name "cmvn.ark" | head -n 1)
104 | fi
105 | if [ -z "${lang_model}" ] && ${use_lang_model}; then
106 | #download_models
107 | lang_model=$(find ${download_dir}/${models} -name "rnnlm*.best*" | head -n 1)
108 | fi
109 | if [ -z "${recog_model}" ]; then
110 | #download_models
111 | if [ -z "${recog_dir}" ]; then
112 | recog_model=$(find ${download_dir}/${models} -name "model*.best*" | head -n 1)
113 | else
114 | recog_model=$(find "${recog_dir}/results" -name "model.loss.best" | head -n 1)
115 | fi
116 | echo "recog_model is ${recog_model}"
117 | fi
118 | if [ -z "${decode_config}" ]; then
119 | #download_models
120 | decode_config=$(find ${download_dir}/${models} -name "decode*.yaml" | head -n 1)
121 | fi
122 |
123 | # Check file existence
124 | if [ ! -f "${cmvn}" ]; then
125 | echo "No such CMVN file: ${cmvn}"
126 | exit 1
127 | fi
128 | if [ ! -f "${lang_model}" ] && ${use_lang_model}; then
129 | echo "No such language model: ${lang_model}"
130 | exit 1
131 | fi
132 | if [ ! -f "${recog_model}" ]; then
133 | echo "No such E2E model: ${recog_model}"
134 | exit 1
135 | fi
136 | if [ ! -f "${decode_config}" ]; then
137 | echo "No such config file: ${decode_config}"
138 | exit 1
139 | fi
140 |
141 |
142 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
143 | ### Task dependent. You have to make data the following preparation part by yourself.
144 | ### But you can utilize Kaldi recipes in most cases
145 | echo "stage 0: Data preparation"
146 | if [ -z "${recog_dir}" ]; then
147 | local/prepare_test_data.sh ${data_type} ${recog_set}
148 | fi
149 | for dset in ${recog_set}; do
150 | utils/fix_data_dir.sh data/${dset}.orig
151 | utils/data/modify_speaker_info.sh --seconds-per-spk-max 180 data/${dset}.orig data/${dset}
152 | done
153 | fi
154 |
155 | # feat_tr_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${feat_tr_dir}
156 | # feat_dt_dir=${dumpdir}/${train_dev}/delta${do_delta}; mkdir -p ${feat_dt_dir}
157 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
158 | ### Task dependent. You have to design training and dev sets by yourself.
159 | ### But you can utilize Kaldi recipes in most cases
160 | echo "stage 1: Feature Generation"
161 | fbankdir=fbank
162 | # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame
163 | for x in ${recog_set}; do
164 | steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj 32 --write_utt2num_frames true \
165 | data/${x} exp/make_fbank/${x} ${fbankdir}
166 | utils/fix_data_dir.sh data/${x}
167 | done
168 |
169 | for rtask in ${recog_set}; do
170 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}; mkdir -p ${feat_recog_dir}
171 | dump.sh --cmd "$train_cmd" --nj 32 --do_delta ${do_delta} \
172 | data/${rtask}/feats.scp ${cmvn} exp/dump_feats/recog/${rtask} \
173 | ${feat_recog_dir}
174 | done
175 | fi
176 |
177 | dict=data/lang_char/train_trim_sp_${bpemode}${nbpe}_units.txt
178 | bpemodel=data/lang_char/train_trim_sp_${bpemode}${nbpe}
179 | echo "dictionary: ${dict}"
180 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
181 | ### Task dependent. You have to check non-linguistic symbols used in the corpus.
182 | echo "stage 2: Dictionary and Json Data Preparation"
183 |
184 | # make json labels
185 | for rtask in ${recog_set}; do
186 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}
187 | data2json.sh --feat ${feat_recog_dir}/feats.scp --bpecode ${bpemodel}.model\
188 | data/${rtask} ${dict} > ${feat_recog_dir}/data_${bpemode}${nbpe}.json
189 | done
190 | fi
191 |
192 | # It takes a few days. If you just want to end-to-end ASR without LM,
193 | # you can skip this and remove --rnnlm option in the recognition (stage 5)
194 |
195 | if [ -z "${recog_dir}" ]; then
196 | expname=${models}
197 | expdir=exp/${expname}
198 | mkdir -p ${expdir}
199 | else
200 | expdir=${recog_dir}
201 | fi
202 |
203 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
204 | echo "stage 3: Decoding"
205 | nj=32
206 | if ${use_lang_model}; then
207 | recog_opts="--rnnlm ${lang_model}"
208 | else
209 | recog_opts=""
210 | fi
211 | #feat_recog_dir=${decode_dir}/dump
212 | pids=() # initialize pids
213 | #trap 'rm -rf data/"${recog_set}" data/"${recog_set}.orig"' EXIT
214 | for rtask in ${recog_set}; do
215 | (
216 | decode_dir=decode_${rtask}_decode_${lmtag}
217 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}
218 |
219 | # # split data
220 | if "${need_decode}"; then
221 | splitjson.py --parts ${nj} ${feat_recog_dir}/data_${bpemode}${nbpe}.json
222 | fi
223 | orig_dir=${expdir}/decode_test-orig_decode_${lmtag}
224 |
225 | #### use CPU for decoding
226 | ngpu=0
227 | if "${need_decode}"; then
228 | ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
229 | asr_test.py \
230 | --config ${decode_config} \
231 | --ngpu ${ngpu} \
232 | --backend ${backend} \
233 | --debugmode ${debugmode} \
234 | --verbose ${verbose} \
235 | --recog-json ${feat_recog_dir}/split32utt/data_${bpemode}${nbpe}.JOB.json \
236 | --result-label ${expdir}/${decode_dir}/data.JOB.json \
237 | --model ${recog_model} \
238 | --api ${api} \
239 | --orig_dir ${orig_dir} \
240 | --need_decode ${need_decode} \
241 | --orig_flag ${orig_flag} \
242 | --recog_set ${recog_set} \
243 | ${recog_opts}
244 |
245 | else
246 | asr_test.py \
247 | --config ${decode_config} \
248 | --ngpu ${ngpu} \
249 | --backend ${backend} \
250 | --debugmode ${debugmode} \
251 | --verbose ${verbose} \
252 | --recog-json ${feat_recog_dir}/split${nj}utt/data_${bpemode}${nbpe}.JOB.json \
253 | --result-label ${expdir}/${decode_dir}/data.JOB.json \
254 | --model ${recog_model} \
255 | --api ${api} \
256 | --orig_dir ${orig_dir} \
257 | --need_decode ${need_decode} \
258 | --orig_flag ${orig_flag} \
259 | --recog_set ${recog_set} \
260 | ${recog_opts}
261 | fi
262 |
263 | ) &
264 | pids+=($!) # store background pids
265 | done
266 | i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
267 | [ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false
268 |
269 | fi
270 |
271 |
272 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
273 | echo "stage 4: Scoring"
274 | decode_dir=decode_${recog_set}_decode_${lmtag}
275 | if "${need_decode}"; then
276 | score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true --need_decode ${need_decode} --guide_type "gini" ${expdir}/${decode_dir} ${dict}
277 | else
278 | score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true --need_decode ${need_decode} --guide_type "gini" ${expdir}/${decode_dir} ${dict}
279 | fi
280 | echo "Finished"
281 | fi
282 |
--------------------------------------------------------------------------------
/tedlium2/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5 |
6 | . ./path.sh || exit 1;
7 | . ./cmd.sh || exit 1;
8 |
9 | # general configuration
10 | backend=pytorch
11 | stage=3 # start from -1 if you need to start from data download
12 | stop_stage=4
13 | ngpu=0 # number of gpus ("0" uses cpu, otherwise use gpu)
14 | debugmode=1
15 | dumpdir=dump # directory to dump full features
16 | N=0 # number of minibatches to be used (mainly for debugging). "0" uses all minibatches.
17 | verbose=1 # verbose option
18 | resume= # Resume the training from snapshot
19 |
20 | # feature configuration
21 | do_delta=false
22 | cmvn=
23 |
24 | preprocess_config=conf/specaug.yaml
25 | train_config=
26 | lm_config=
27 |
28 | # rnnlm related
29 | skip_lm_training=true # for only using end-to-end ASR model without LM
30 | lm_resume= # specify a snapshot file to resume LM training
31 | lmtag= # tag for managing LMs
32 | use_lang_model=false
33 | lang_model=
34 |
35 | # test related
36 | models=tedlium2.transformer.v1
37 | # decoding parameter
38 | p=0.005
39 | recog_model=
40 | recog_dir=
41 | decode_config=
42 | decode_dir=decode
43 | api=v2
44 |
45 |
46 | # model average realted (only for transformer)
47 | # n_average=10 # the number of ASR models to be averaged
48 | # use_valbest_average=true # if true, the validation `n_average`-best ASR models will be averaged.
49 | # # if false, the last `n_average` ASR models will be averaged.
50 |
51 | # bpemode (unigram or bpe)
52 | nbpe=500
53 | bpemode=unigram
54 |
55 | # exp tag
56 | tag="" # tag for managing experiments.
57 | recog_set=test-noise
58 |
59 | # gini related
60 | orig_flag=false
61 | orig_dir=
62 | need_decode=false
63 |
64 | . utils/parse_options.sh || exit 1;
65 |
66 | # Set bash to 'debug' mode, it will exit on :
67 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
68 | set -e
69 | set -u
70 | set -o pipefail
71 |
72 | # train_set=train_trim_sp
73 | # train_dev=dev_trim
74 |
75 | download_dir=${decode_dir}/download
76 |
77 | if [ "${api}" = "v2" ] && [ "${backend}" = "chainer" ]; then
78 | echo "chainer backend does not support api v2." >&2
79 | exit 1;
80 | fi
81 |
82 | if [ -z $models ]; then
83 | if [ $use_lang_model = "true" ]; then
84 | if [[ -z $cmvn || -z $lang_model || -z $recog_model || -z $decode_config ]]; then
85 | echo 'Error: models or set of cmvn, lang_model, recog_model and decode_config are required.' >&2
86 | exit 1
87 | fi
88 | else
89 | if [[ -z $cmvn || -z $recog_model || -z $decode_config ]]; then
90 | echo 'Error: models or set of cmvn, recog_model and decode_config are required.' >&2
91 | exit 1
92 | fi
93 | fi
94 | fi
95 |
96 | dir=${download_dir}/${models}
97 | mkdir -p ${dir}
98 |
99 |
100 | function download_models () {
101 | if [ -z $models ]; then
102 | return
103 | fi
104 |
105 | file_ext="tar.gz"
106 | case "${models}" in
107 | "tedlium2.rnn.v1") share_url="https://drive.google.com/open?id=1UqIY6WJMZ4sxNxSugUqp3mrGb3j6h7xe"; api=v1 ;;
108 | "tedlium2.rnn.v2") share_url="https://drive.google.com/open?id=1cac5Uc09lJrCYfWkLQsF8eapQcxZnYdf"; api=v1 ;;
109 | "tedlium2.transformer.v1") share_url="https://drive.google.com/open?id=1cVeSOYY1twOfL9Gns7Z3ZDnkrJqNwPow" ;;
110 | "tedlium3.transformer.v1") share_url="https://drive.google.com/open?id=1zcPglHAKILwVgfACoMWWERiyIquzSYuU" ;;
111 | "librispeech.transformer.v1") share_url="https://drive.google.com/open?id=1BtQvAnsFvVi-dp_qsaFP7n4A_5cwnlR6" ;;
112 | "librispeech.transformer.v1.transformerlm.v1") share_url="https://drive.google.com/open?id=17cOOSHHMKI82e1MXj4r2ig8gpGCRmG2p" ;;
113 | "commonvoice.transformer.v1") share_url="https://drive.google.com/open?id=1tWccl6aYU67kbtkm8jv5H6xayqg1rzjh" ;;
114 | "csj.transformer.v1") share_url="https://drive.google.com/open?id=120nUQcSsKeY5dpyMWw_kI33ooMRGT2uF" ;;
115 | *) echo "No such models: ${models}"; exit 1 ;;
116 | esac
117 |
118 | if [ ! -e ${dir}/.complete ]; then
119 | download_from_google_drive.sh ${share_url} ${dir} ${file_ext}
120 | touch ${dir}/.complete
121 | fi
122 | }
123 |
124 | # Download trained models
125 | if [ -z "${cmvn}" ]; then
126 | #download_models
127 | cmvn=$(find ${download_dir}/${models} -name "cmvn.ark" | head -n 1)
128 | fi
129 | if [ -z "${lang_model}" ] && ${use_lang_model}; then
130 | #download_models
131 | lang_model=$(find ${download_dir}/${models} -name "rnnlm*.best*" | head -n 1)
132 | fi
133 | if [ -z "${recog_model}" ]; then
134 | #download_models
135 | if [ -z "${recog_dir}" ]; then
136 | recog_model=$(find ${download_dir}/${models} -name "model*.best*" | head -n 1)
137 | else
138 | recog_model=$(find "${recog_dir}/results" -name "model.acc.best" | head -n 1)
139 | fi
140 | echo "recog_model is ${recog_model}"
141 | fi
142 | if [ -z "${decode_config}" ]; then
143 | #download_models
144 | decode_config=$(find ${download_dir}/${models} -name "decode*.yaml" | head -n 1)
145 | fi
146 |
147 |
148 | # Check file existence
149 | if [ ! -f "${cmvn}" ]; then
150 | echo "No such CMVN file: ${cmvn}"
151 | exit 1
152 | fi
153 | if [ ! -f "${lang_model}" ] && ${use_lang_model}; then
154 | echo "No such language model: ${lang_model}"
155 | exit 1
156 | fi
157 | if [ ! -f "${recog_model}" ]; then
158 | echo "No such E2E model: ${recog_model}"
159 | exit 1
160 | fi
161 | if [ ! -f "${decode_config}" ]; then
162 | echo "No such config file: ${decode_config}"
163 | exit 1
164 | fi
165 |
166 | echo "stage ${stage}"
167 |
168 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
169 | ### Task dependent. You have to make data the following preparation part by yourself.
170 | ### But you can utilize Kaldi recipes in most cases
171 | echo "stage 0: Data preparation"
172 | if [ -z "${recog_dir}" ]; then
173 | local/prepare_test_data.sh ${recog_set}
174 | fi
175 | for dset in ${recog_set}; do
176 | utils/fix_data_dir.sh data/${dset}.orig
177 | utils/data/modify_speaker_info.sh --seconds-per-spk-max 180 data/${dset}.orig data/${dset}
178 | done
179 | fi
180 |
181 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
182 | ### Task dependent. You have to design training and dev sets by yourself.
183 | ### But you can utilize Kaldi recipes in most cases
184 | echo "stage 1: Feature Generation"
185 | fbankdir=fbank
186 |
187 | # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame
188 | for x in ${recog_set}; do
189 | #utils/fix_data_dir.sh data/${x}
190 | steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj 32 --write_utt2num_frames true \
191 | data/${x} exp/make_fbank/${x} ${fbankdir}
192 | utils/fix_data_dir.sh data/${x}
193 | done
194 |
195 | for rtask in ${recog_set}; do
196 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}; mkdir -p ${feat_recog_dir}
197 | dump.sh --cmd "$train_cmd" --nj 32 --do_delta ${do_delta} \
198 | data/${rtask}/feats.scp ${cmvn} exp/dump_feats/recog/${rtask} \
199 | ${feat_recog_dir}
200 | done
201 | fi
202 |
203 | dict=data/lang_char/train_trim_sp_${bpemode}${nbpe}_units.txt
204 | bpemodel=data/lang_char/train_trim_sp_${bpemode}${nbpe}
205 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
206 | ### Task dependent. You have to check non-linguistic symbols used in the corpus.
207 | echo "stage 2: Dictionary and Json Data Preparation"
208 |
209 | for rtask in ${recog_set}; do
210 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}
211 | data2json.sh --feat ${feat_recog_dir}/feats.scp --bpecode ${bpemodel}.model\
212 | data/${rtask} ${dict} > ${feat_recog_dir}/data_${bpemode}${nbpe}.json
213 | done
214 | fi
215 |
216 | if [ -z "${recog_dir}" ]; then
217 | expname=${models}
218 | expdir=exp/${expname}
219 | mkdir -p ${expdir}
220 | else
221 | expdir=${recog_dir}
222 | fi
223 |
224 | echo "${expdir}"
225 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
226 | echo "stage 3: Decoding"
227 | nj=32
228 | if ${use_lang_model}; then
229 | recog_opts="--rnnlm ${lang_model}"
230 | else
231 | recog_opts=""
232 | fi
233 | #feat_recog_dir=${decode_dir}/dump
234 | pids=() # initialize pids
235 | #trap 'rm -rf data/"${recog_set}" data/"${recog_set}.orig"' EXIT
236 | for rtask in ${recog_set}; do
237 | (
238 | decode_dir=decode_${rtask}_decode_${lmtag}
239 | feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}
240 |
241 | # # split data
242 | if "${need_decode}"; then
243 | splitjson.py --parts ${nj} ${feat_recog_dir}/data_${bpemode}${nbpe}.json
244 | fi
245 | orig_dir=${expdir}/decode_test-orig_decode_${lmtag}
246 |
247 | #### use CPU for decoding
248 | ngpu=0
249 | if "${need_decode}"; then
250 | ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
251 | asr_test.py \
252 | --config ${decode_config} \
253 | --ngpu ${ngpu} \
254 | --backend ${backend} \
255 | --debugmode ${debugmode} \
256 | --verbose ${verbose} \
257 |                 --recog-json ${feat_recog_dir}/split${nj}utt/data_${bpemode}${nbpe}.JOB.json \
258 | --result-label ${expdir}/${decode_dir}/data.JOB.json \
259 | --model ${recog_model} \
260 | --api ${api} \
261 | --orig_dir ${orig_dir} \
262 | --need_decode ${need_decode} \
263 | --orig_flag ${orig_flag} \
264 | --recog_set ${recog_set} \
265 | ${recog_opts}
266 |
267 | else
268 | asr_test.py \
269 | --config ${decode_config} \
270 | --ngpu ${ngpu} \
271 | --backend ${backend} \
272 | --debugmode ${debugmode} \
273 | --verbose ${verbose} \
274 | --recog-json ${feat_recog_dir}/split${nj}utt/data_${bpemode}${nbpe}.JOB.json \
275 | --result-label ${expdir}/${decode_dir}/data.JOB.json \
276 | --model ${recog_model} \
277 | --api ${api} \
278 | --orig_dir ${orig_dir} \
279 | --need_decode ${need_decode} \
280 | --orig_flag ${orig_flag} \
281 | --recog_set ${recog_set} \
282 | ${recog_opts}
283 | fi
284 |
285 | ) &
286 | pids+=($!) # store background pids
287 | done
288 | i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
289 |     [ ${i} -gt 0 ] && echo "$0: ${i} background jobs failed." && false
290 |
291 | fi
292 |
293 |
294 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
295 | echo "stage 4: Scoring"
296 | decode_dir=decode_${recog_set}_decode_${lmtag}
297 | if "${need_decode}"; then
298 | score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true --need_decode ${need_decode} --guide_type "gini" ${expdir}/${decode_dir} ${dict}
299 | else
300 | score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true --need_decode ${need_decode} --guide_type "gini" ${expdir}/${decode_dir} ${dict}
301 | fi
302 | echo "Finished"
303 | fi
304 |
305 |
306 |
--------------------------------------------------------------------------------
/espnet2/asr/decoder/rnn_decoder.py:
--------------------------------------------------------------------------------
1 | import random
2 | import logging
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from typeguard import check_argument_types
7 |
8 | from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
9 | from espnet.nets.pytorch_backend.nets_utils import to_device
10 | from espnet.nets.pytorch_backend.rnn.attentions import initial_att
11 | from espnet2.asr.decoder.abs_decoder import AbsDecoder
12 | from espnet2.utils.get_default_kwargs import get_default_kwargs
13 |
14 |
15 | def build_attention_list(
16 | eprojs: int,
17 | dunits: int,
18 | atype: str = "location",
19 | num_att: int = 1,
20 | num_encs: int = 1,
21 | aheads: int = 4,
22 | adim: int = 320,
23 | awin: int = 5,
24 | aconv_chans: int = 10,
25 | aconv_filts: int = 100,
26 | han_mode: bool = False,
27 | han_type=None,
28 | han_heads: int = 4,
29 | han_dim: int = 320,
30 | han_conv_chans: int = -1,
31 | han_conv_filts: int = 100,
32 | han_win: int = 5,
33 | ):
34 |
35 | att_list = torch.nn.ModuleList()
36 | if num_encs == 1:
37 | for i in range(num_att):
38 | att = initial_att(
39 | atype,
40 | eprojs,
41 | dunits,
42 | aheads,
43 | adim,
44 | awin,
45 | aconv_chans,
46 | aconv_filts,
47 | )
48 | att_list.append(att)
49 | elif num_encs > 1: # no multi-speaker mode
50 | if han_mode:
51 | att = initial_att(
52 | han_type,
53 | eprojs,
54 | dunits,
55 | han_heads,
56 | han_dim,
57 | han_win,
58 | han_conv_chans,
59 | han_conv_filts,
60 | han_mode=True,
61 | )
62 | return att
63 | else:
64 | att_list = torch.nn.ModuleList()
65 | for idx in range(num_encs):
66 | att = initial_att(
67 | atype[idx],
68 | eprojs,
69 | dunits,
70 | aheads[idx],
71 | adim[idx],
72 | awin[idx],
73 | aconv_chans[idx],
74 | aconv_filts[idx],
75 | )
76 | att_list.append(att)
77 | else:
78 | raise ValueError(
79 | "Number of encoders needs to be more than one. {}".format(num_encs)
80 | )
81 | return att_list
82 |
83 |
84 |
85 | class RNNDecoder(AbsDecoder):
86 | def __init__(
87 | self,
88 | vocab_size: int,
89 | encoder_output_size: int,
90 | rnn_type: str = "lstm",
91 | num_layers: int = 1,
92 | hidden_size: int = 320,
93 | sampling_probability: float = 0.0,
94 | dropout: float = 0.0,
95 | context_residual: bool = False,
96 | replace_sos: bool = False,
97 | num_encs: int = 1,
98 | att_conf: dict = get_default_kwargs(build_attention_list),
99 | ):
100 |         # FIXME(kamo): The parts of num_spk should be refactored more
101 | assert check_argument_types()
102 | if rnn_type not in {"lstm", "gru"}:
103 | raise ValueError(f"Not supported: rnn_type={rnn_type}")
104 |
105 | super().__init__()
106 | eprojs = encoder_output_size
107 | self.dtype = rnn_type
108 | self.dunits = hidden_size
109 | self.dlayers = num_layers
110 | self.context_residual = context_residual
111 | self.sos = vocab_size - 1
112 | self.eos = vocab_size - 1
113 | self.odim = vocab_size
114 | self.sampling_probability = sampling_probability
115 | self.dropout = dropout
116 | self.num_encs = num_encs
117 |
118 | # for multilingual translation
119 | self.replace_sos = replace_sos
120 |
121 | self.embed = torch.nn.Embedding(vocab_size, hidden_size)
122 | self.dropout_emb = torch.nn.Dropout(p=dropout)
123 |
124 | self.decoder = torch.nn.ModuleList()
125 | self.dropout_dec = torch.nn.ModuleList()
126 | self.decoder += [
127 | torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
128 | if self.dtype == "lstm"
129 | else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
130 | ]
131 |
132 | self.dropout_dec += [torch.nn.Dropout(p=dropout)]
133 | for _ in range(1, self.dlayers):
134 | self.decoder += [
135 | torch.nn.LSTMCell(hidden_size, hidden_size)
136 | if self.dtype == "lstm"
137 | else torch.nn.GRUCell(hidden_size, hidden_size)
138 | ]
139 | self.dropout_dec += [torch.nn.Dropout(p=dropout)]
140 | # NOTE: dropout is applied only for the vertical connections
141 | # see https://arxiv.org/pdf/1409.2329.pdf
142 | if context_residual:
143 | self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
144 | else:
145 | self.output = torch.nn.Linear(hidden_size, vocab_size)
146 |
147 | self.att_list = build_attention_list(
148 | eprojs=eprojs, dunits=hidden_size, **att_conf
149 | )
150 |
151 | def zero_state(self, hs_pad):
152 | return hs_pad.new_zeros(hs_pad.size(0), self.dunits)
153 |
154 |     def caul_gini(self, softmax_value):  # Gini impurity: 1 - sum_i p_i^2
155 |         gini = 0
156 |         class_softmax = softmax_value.cpu().detach().numpy()
157 |         for i in range(0, len(class_softmax)):
158 |             gini += np.square(class_softmax[i])
159 |         return 1 - gini
160 |
161 | def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
162 | if self.dtype == "lstm":
163 | z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
164 | for i in range(1, self.dlayers):
165 | z_list[i], c_list[i] = self.decoder[i](
166 | self.dropout_dec[i - 1](z_list[i - 1]),
167 | (z_prev[i], c_prev[i]),
168 | )
169 |
170 | else:
171 | z_list[0] = self.decoder[0](ey, z_prev[0])
172 | for i in range(1, self.dlayers):
173 | z_list[i] = self.decoder[i](
174 | self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
175 | )
176 | return z_list, c_list
177 |
178 | def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
179 |         # to support multiple encoder asr mode, in single encoder mode,
180 | # convert torch.Tensor to List of torch.Tensor
181 | if self.num_encs == 1:
182 | hs_pad = [hs_pad]
183 | hlens = [hlens]
184 |
185 | # attention index for the attention module
186 | # in SPA (speaker parallel attention),
187 | # att_idx is used to select attention module. In other cases, it is 0.
188 | att_idx = min(strm_idx, len(self.att_list) - 1)
189 |
190 | # hlens should be list of list of integer
191 | hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]
192 |
193 | # get dim, length info
194 |
195 | olength = ys_in_pad.size(1)
196 | # initialization
197 | c_list = [self.zero_state(hs_pad[0])]
198 | z_list = [self.zero_state(hs_pad[0])]
199 | for _ in range(1, self.dlayers):
200 | c_list.append(self.zero_state(hs_pad[0]))
201 | z_list.append(self.zero_state(hs_pad[0]))
202 | z_all = []
203 | if self.num_encs == 1:
204 | att_w = None
205 | self.att_list[att_idx].reset() # reset pre-computation of h
206 | else:
207 | att_w_list = [None] * (self.num_encs + 1) # atts + han
208 | att_c_list = [None] * self.num_encs # atts
209 | for idx in range(self.num_encs + 1):
210 | # reset pre-computation of h in atts and han
211 | self.att_list[idx].reset()
212 |
213 | # pre-computation of embedding
214 | eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim
215 |
216 | # loop for an output sequence
217 | # logging.info(f"olength is {olength}")
218 | for i in range(olength):
219 | if self.num_encs == 1:
220 | att_c, att_w = self.att_list[att_idx](
221 | hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
222 | )
223 | else:
224 | for idx in range(self.num_encs):
225 | att_c_list[idx], att_w_list[idx] = self.att_list[idx](
226 | hs_pad[idx],
227 | hlens[idx],
228 | self.dropout_dec[0](z_list[0]),
229 | att_w_list[idx],
230 | )
231 | hs_pad_han = torch.stack(att_c_list, dim=1)
232 | hlens_han = [self.num_encs] * len(ys_in_pad)
233 | att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
234 | hs_pad_han,
235 | hlens_han,
236 | self.dropout_dec[0](z_list[0]),
237 | att_w_list[self.num_encs],
238 | )
239 | if i > 0 and random.random() < self.sampling_probability:
240 | z_out = self.output(z_all[-1])
241 | z_out = np.argmax(z_out.detach().cpu(), axis=1)
242 | z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
243 | ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim)
244 | else:
245 | # utt x (zdim + hdim)
246 | ey = torch.cat((eys[:, i, :], att_c), dim=1)
247 | z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
248 | if self.context_residual:
249 | z_all.append(
250 | torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
251 | ) # utt x (zdim + hdim)
252 | else:
253 | z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim)
254 |
255 | z_all = torch.stack(z_all, dim=1)
256 | z_all = self.output(z_all)
257 | z_all.masked_fill_(
258 | make_pad_mask(ys_in_lens, z_all, 1),
259 | 0,
260 | )
261 | return z_all, ys_in_lens
262 |
263 | def init_state(self, x):
264 |         # to support multiple encoder asr mode, in single encoder mode,
265 | # convert torch.Tensor to List of torch.Tensor
266 | if self.num_encs == 1:
267 | x = [x]
268 |
269 | c_list = [self.zero_state(x[0].unsqueeze(0))]
270 | z_list = [self.zero_state(x[0].unsqueeze(0))]
271 | for _ in range(1, self.dlayers):
272 | c_list.append(self.zero_state(x[0].unsqueeze(0)))
273 | z_list.append(self.zero_state(x[0].unsqueeze(0)))
274 | # TODO(karita): support strm_index for `asr_mix`
275 | strm_index = 0
276 | att_idx = min(strm_index, len(self.att_list) - 1)
277 | if self.num_encs == 1:
278 | a = None
279 | self.att_list[att_idx].reset() # reset pre-computation of h
280 | else:
281 | a = [None] * (self.num_encs + 1) # atts + han
282 | for idx in range(self.num_encs + 1):
283 | # reset pre-computation of h in atts and han
284 | self.att_list[idx].reset()
285 | return dict(
286 | c_prev=c_list[:],
287 | z_prev=z_list[:],
288 | a_prev=a,
289 | workspace=(att_idx, z_list, c_list),
290 | )
291 |
292 | def score(self, yseq, state, x):
293 |         # to support multiple encoder asr mode, in single encoder mode,
294 | # convert torch.Tensor to List of torch.Tensor
295 | #print("call rnn decoder score:", yseq, state, x)
296 | if self.num_encs == 1:
297 | x = [x]
298 |
299 | att_idx, z_list, c_list = state["workspace"]
300 | vy = yseq[-1].unsqueeze(0)
301 | ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
302 | if self.num_encs == 1:
303 | att_c, att_w = self.att_list[att_idx](
304 | x[0].unsqueeze(0),
305 | [x[0].size(0)],
306 | self.dropout_dec[0](state["z_prev"][0]),
307 | state["a_prev"],
308 | )
309 | else:
310 | att_w = [None] * (self.num_encs + 1) # atts + han
311 | att_c_list = [None] * self.num_encs # atts
312 | for idx in range(self.num_encs):
313 | att_c_list[idx], att_w[idx] = self.att_list[idx](
314 | x[idx].unsqueeze(0),
315 | [x[idx].size(0)],
316 | self.dropout_dec[0](state["z_prev"][0]),
317 | state["a_prev"][idx],
318 | )
319 | h_han = torch.stack(att_c_list, dim=1)
320 | att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
321 | h_han,
322 | [self.num_encs],
323 | self.dropout_dec[0](state["z_prev"][0]),
324 | state["a_prev"][self.num_encs],
325 | )
326 | ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
327 | z_list, c_list = self.rnn_forward(
328 | ey, z_list, c_list, state["z_prev"], state["c_prev"]
329 | )
330 | if self.context_residual:
331 | logits = self.output(
332 | torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
333 | )
334 | else:
335 | logits = self.output(self.dropout_dec[-1](z_list[-1]))
336 |
337 | logp = F.log_softmax(logits, dim=1).squeeze(0)
338 | outp = F.softmax(logits, dim=1).squeeze(0)
339 |         c = [o for o in outp.numpy()]
340 |         sum_output = 0  # only used by the commented-out debug logging below
341 |         for o in c:
342 |             sum_output += o
343 | # logging.info(f"outp is {outp}, sum is {sum_output}")
344 | # logging.info(f"z_list is {z_list}")
345 | return (
346 | logp,
347 | dict(
348 | c_prev=c_list[:],
349 | z_prev=z_list[:],
350 | a_prev=att_w,
351 | workspace=(att_idx, z_list, c_list),
352 | ),
353 | outp,
354 | z_list,
355 | )
356 |
--------------------------------------------------------------------------------
/espnet/bin/asr_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # encoding: utf-8
3 |
4 | # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """End-to-end speech recognition model decoding script."""
8 |
9 | import configargparse
10 | import logging
11 | import os
12 | import random
13 | import sys
14 |
15 | import numpy as np
16 |
17 | from espnet.utils.cli_utils import strtobool
18 | from espnet.utils.gini_guide import *
19 | from espnet.utils.gini_utils import *
20 | from espnet.utils.retrain_utils import *
21 |
22 | # NOTE: you need this func to generate our sphinx doc
23 |
24 |
25 | def get_parser():
26 | """Get default arguments."""
27 | parser = configargparse.ArgumentParser(
28 | description="Transcribe text from speech using "
29 | "a speech recognition model on one CPU or GPU",
30 | config_file_parser_class=configargparse.YAMLConfigFileParser,
31 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
32 | )
33 | # general configuration
34 | parser.add("--config", is_config_file=True, help="Config file path")
35 | parser.add(
36 | "--config2",
37 | is_config_file=True,
38 | help="Second config file path that overwrites the settings in `--config`",
39 | )
40 | parser.add(
41 | "--config3",
42 | is_config_file=True,
43 | help="Third config file path that overwrites the settings "
44 | "in `--config` and `--config2`",
45 | )
46 |
47 | parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
48 | parser.add_argument(
49 | "--dtype",
50 | choices=("float16", "float32", "float64"),
51 | default="float32",
52 | help="Float precision (only available in --api v2)",
53 | )
54 | parser.add_argument(
55 | "--backend",
56 | type=str,
57 | default="chainer",
58 | choices=["chainer", "pytorch"],
59 | help="Backend library",
60 | )
61 | parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
62 | parser.add_argument("--seed", type=int, default=1, help="Random seed")
63 | parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
64 | parser.add_argument(
65 | "--batchsize",
66 | type=int,
67 | default=1,
68 | help="Batch size for beam search (0: means no batch processing)",
69 | )
70 | parser.add_argument(
71 | "--preprocess-conf",
72 | type=str,
73 | default=None,
74 | help="The configuration file for the pre-processing",
75 | )
76 | parser.add_argument(
77 | "--api",
78 | default="v1",
79 | choices=["v1", "v2"],
80 | help="Beam search APIs "
81 | "v1: Default API. It only supports the ASRInterface.recognize method "
82 | "and DefaultRNNLM. "
83 | "v2: Experimental API. It supports any models that implements ScorerInterface.",
84 | )
85 | # task related
86 | parser.add_argument(
87 | "--recog-json", type=str, help="Filename of recognition data (json)"
88 | )
89 | parser.add_argument(
90 | "--result-label",
91 | type=str,
92 | required=True,
93 | help="Filename of result label data (json)",
94 | )
95 | # model (parameter) related
96 | parser.add_argument(
97 | "--model", type=str, required=True, help="Model file parameters to read"
98 | )
99 | parser.add_argument(
100 | "--model-conf", type=str, default=None, help="Model config file"
101 | )
102 | parser.add_argument(
103 | "--num-spkrs",
104 | type=int,
105 | default=1,
106 | choices=[1, 2],
107 | help="Number of speakers in the speech",
108 | )
109 | parser.add_argument(
110 | "--num-encs", default=1, type=int, help="Number of encoders in the model."
111 | )
112 | # search related
113 | parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
114 | parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
115 | parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty")
116 | parser.add_argument(
117 | "--maxlenratio",
118 | type=float,
119 | default=0.0,
120 | help="""Input length ratio to obtain max output length.
121 |         If maxlenratio=0.0 (default), it uses an end-detect function
122 | to automatically find maximum hypothesis lengths""",
123 | )
124 | parser.add_argument(
125 | "--minlenratio",
126 | type=float,
127 | default=0.0,
128 | help="Input length ratio to obtain min output length",
129 | )
130 | parser.add_argument(
131 | "--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding"
132 | )
133 | parser.add_argument(
134 | "--weights-ctc-dec",
135 | type=float,
136 | action="append",
137 | help="ctc weight assigned to each encoder during decoding."
138 | "[in multi-encoder mode only]",
139 | )
140 | parser.add_argument(
141 | "--ctc-window-margin",
142 | type=int,
143 | default=0,
144 | help="""Use CTC window with margin parameter to accelerate
145 |         CTC/attention decoding especially on GPU. Smaller margin
146 | makes decoding faster, but may increase search errors.
147 | If margin=0 (default), this function is disabled""",
148 | )
149 | # transducer related
150 | parser.add_argument(
151 | "--search-type",
152 | type=str,
153 | default="default",
154 | choices=["default", "nsc", "tsd", "alsd"],
155 | help="""Type of beam search implementation to use during inference.
156 | Can be either: default beam search, n-step constrained beam search ("nsc"),
157 | time-synchronous decoding ("tsd") or alignment-length synchronous decoding
158 | ("alsd").
159 | Additional associated parameters: "nstep" + "prefix-alpha" (for nsc),
160 | "max-sym-exp" (for tsd) and "u-max" (for alsd)""",
161 | )
162 | parser.add_argument(
163 | "--nstep",
164 | type=int,
165 | default=1,
166 | help="Number of expansion steps allowed in NSC beam search.",
167 | )
168 | parser.add_argument(
169 | "--prefix-alpha",
170 | type=int,
171 | default=2,
172 | help="Length prefix difference allowed in NSC beam search.",
173 | )
174 | parser.add_argument(
175 | "--max-sym-exp",
176 | type=int,
177 | default=2,
178 | help="Number of symbol expansions allowed in TSD decoding.",
179 | )
180 | parser.add_argument(
181 | "--u-max",
182 | type=int,
183 | default=400,
184 | help="Length prefix difference allowed in ALSD beam search.",
185 | )
186 | parser.add_argument(
187 | "--score-norm",
188 | type=strtobool,
189 | nargs="?",
190 | default=True,
191 | help="Normalize transducer scores by length",
192 | )
193 | # rnnlm related
194 | parser.add_argument(
195 | "--rnnlm", type=str, default=None, help="RNNLM model file to read"
196 | )
197 | parser.add_argument(
198 | "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
199 | )
200 | parser.add_argument(
201 | "--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read"
202 | )
203 | parser.add_argument(
204 | "--word-rnnlm-conf",
205 | type=str,
206 | default=None,
207 | help="Word RNNLM model config file to read",
208 | )
209 | parser.add_argument("--word-dict", type=str, default=None, help="Word list to read")
210 | parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight")
211 | # ngram related
212 | parser.add_argument(
213 | "--ngram-model", type=str, default=None, help="ngram model file to read"
214 | )
215 | parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight")
216 | parser.add_argument(
217 | "--ngram-scorer",
218 | type=str,
219 | default="part",
220 | choices=("full", "part"),
221 | help="""if the ngram is set as a part scorer, similar with CTC scorer,
222 | ngram scorer only scores topK hypethesis.
223 | if the ngram is set as full scorer, ngram scorer scores all hypthesis
224 | the decoding speed of part scorer is musch faster than full one""",
225 | )
226 | # streaming related
227 | parser.add_argument(
228 | "--streaming-mode",
229 | type=str,
230 | default=None,
231 | choices=["window", "segment"],
232 | help="""Use streaming recognizer for inference.
233 | `--batchsize` must be set to 0 to enable this mode""",
234 | )
235 | parser.add_argument("--streaming-window", type=int, default=10, help="Window size")
236 | parser.add_argument(
237 | "--streaming-min-blank-dur",
238 | type=int,
239 | default=10,
240 | help="Minimum blank duration threshold",
241 | )
242 | parser.add_argument(
243 | "--streaming-onset-margin", type=int, default=1, help="Onset margin"
244 | )
245 | parser.add_argument(
246 | "--streaming-offset-margin", type=int, default=1, help="Offset margin"
247 | )
248 | # non-autoregressive related
249 | # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
250 | parser.add_argument(
251 | "--maskctc-n-iterations",
252 | type=int,
253 | default=10,
254 | help="Number of decoding iterations."
255 | "For Mask CTC, set 0 to predict 1 mask/iter.",
256 | )
257 | parser.add_argument(
258 | "--maskctc-probability-threshold",
259 | type=float,
260 | default=0.999,
261 | help="Threshold probability for CTC output",
262 | )
263 | parser.add_argument("--orig_flag", type=str, default="False")
264 | parser.add_argument("--orig_dir", type=str, required=True)
265 | parser.add_argument(
266 | "--recog_set",
267 | type=str,
268 | help="Whether use gini or not",
269 | )
270 | parser.add_argument(
271 | "--need_decode",
272 | type=str,
273 | default="True",
274 | )
275 |
276 | return parser
277 |
278 |
279 | def str_2_bool(s):
280 |     return s.lower() == "true"
281 |
282 | def main(args):
283 | """Run the main decoding function."""
284 | parser = get_parser()
285 | args = parser.parse_args(args)
286 | orig_flag = str_2_bool(args.orig_flag)
287 | need_decode = str_2_bool(args.need_decode)
288 | if args.ngpu == 0 and args.dtype == "float16":
289 | raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.")
290 |
291 | # logging info
292 | if args.verbose == 1:
293 | logging.basicConfig(
294 | level=logging.INFO,
295 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
296 | )
297 | elif args.verbose == 2:
298 | logging.basicConfig(
299 | level=logging.DEBUG,
300 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
301 | )
302 | else:
303 | logging.basicConfig(
304 | level=logging.WARN,
305 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
306 | )
307 | logging.warning("Skip DEBUG/INFO messages")
308 |
309 | # check CUDA_VISIBLE_DEVICES
310 | if args.ngpu > 0:
311 | cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
312 | if cvd is None:
313 | logging.warning("CUDA_VISIBLE_DEVICES is not set.")
314 | elif args.ngpu != len(cvd.split(",")):
315 | logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
316 | sys.exit(1)
317 |
318 | # TODO(mn5k): support of multiple GPUs
319 | if args.ngpu > 1:
320 | logging.error("The program only supports ngpu=1.")
321 | sys.exit(1)
322 |
323 | # display PYTHONPATH
324 | logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
325 |
326 | # seed setting
327 | random.seed(args.seed)
328 | np.random.seed(args.seed)
329 | logging.info("set random seed = %d" % args.seed)
330 |
331 | # validate rnn options
332 | if args.rnnlm is not None and args.word_rnnlm is not None:
333 | logging.error(
334 | "It seems that both --rnnlm and --word-rnnlm are specified. "
335 | "Please use either option."
336 | )
337 | sys.exit(1)
338 |
339 | # recog
340 | logging.info("backend = " + args.backend)
341 | if args.num_spkrs == 1:
342 | if args.backend == "chainer":
343 | from espnet.asr.chainer_backend.asr import recog
344 | recog(args)
345 | elif args.backend == "pytorch":
346 | if args.num_encs == 1:
347 | # Experimental API that supports custom LMs
348 | if args.api == "v2":
349 | if need_decode:
350 | from espnet.asr.pytorch_backend.recog_test import recog_v2
351 | recog_v2(args, orig_flag)
352 | else:
353 | new_dir = "/".join(args.result_label.split("/")[0:-1])
354 | orig_dir = new_dir.replace("new", "orig")
355 | aug_type = args.recog_set.split("-")[-1]
356 | gini_audio = sort_test_by_gini_v1(args.orig_dir, new_dir, aug_type)
357 | test_data_prep("data/test-orig", "data/" + args.recog_set, "ted2", gini_audio, "gini", 1155, args.recog_set)
358 |
359 | else:
360 | from espnet.asr.pytorch_backend.asr import recog
361 |
362 | if args.dtype != "float32":
363 | raise NotImplementedError(
364 | f"`--dtype {args.dtype}` is only available with `--api v2`"
365 | )
366 | recog(args)
367 | else:
368 | if args.api == "v2":
369 | raise NotImplementedError(
370 | f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
371 | )
372 | else:
373 | from espnet.asr.pytorch_backend.asr import recog
374 |
375 | recog(args)
376 | else:
377 | raise ValueError("Only chainer and pytorch are supported.")
378 | elif args.num_spkrs == 2:
379 | if args.backend == "pytorch":
380 | from espnet.asr.pytorch_backend.asr_mix import recog
381 |
382 | recog(args)
383 | else:
384 | raise ValueError("Only pytorch is supported.")
385 |
386 |
387 | if __name__ == "__main__":
388 | main(sys.argv[1:])
389 |
390 |
--------------------------------------------------------------------------------
/espnet/utils/gini_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | import json
5 | import numpy as np
6 | import math
7 |
8 |
9 | class SpeechSort(object):
10 | def __init__(self, name, s_count, len_diff, mix_sort):
11 | self.name = name
12 | self.s_count = s_count
13 | self.len_diff = len_diff
14 | self.mix_sort = mix_sort
15 |
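# SpeechSort bundles the per-utterance ranking keys produced by gini_sort()/gini_sort_v1():
# roughly, `s_count` is the fraction of word/token ginis above the substitution threshold,
# `len_diff` the (possibly normalized) hypothesis-length difference against the original
# utterance, and `mix_sort` the combined key that the lists are sorted by in descending
# order. Minimal illustrative use (values are made up):
#   items = [SpeechSort("utt1", 0.4, 2, 0.6), SpeechSort("utt2", 0.1, 0, 0.1)]
#   items.sort(key=lambda x: x.mix_sort, reverse=True)  # "utt1" is ranked first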
16 |
17 | def cau_sum_ginis(token_int, all_ginis):
18 | sum_gini = 0
19 | gini_list = []
20 | gini_key = ""
21 | if len(all_ginis) == 0:
22 | gini_list.append(0)
23 | else:
24 | for i in range(len(token_int)):
25 | token = token_int[i]
26 | gini_key += str(token) + " "
27 | if gini_key.strip() in all_ginis.keys():
28 | sum_gini += all_ginis[gini_key.strip()]
29 | gini_list.append(all_ginis[gini_key.strip()])
30 | else:
31 | key = gini_key.strip().split(" ")
32 | if len(key) > 1:
33 | key = " ".join(key[:-1])
34 | else:
35 | key = key[-1]
36 | sum_gini += all_ginis[key]
37 | gini_list.append(all_ginis[key])
38 | return sum_gini, gini_list
39 |
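# cau_sum_ginis() appears to expect `all_ginis` keyed by space-joined prefixes of the
# token-id sequence ("12", "12 7", "12 7 305", ...), each mapped to the gini value of that
# decoding step; if a prefix is missing, the value of the next-shorter prefix is reused.
# Illustrative call (hypothetical values):
#   all_ginis = {"12": 0.10, "12 7": 0.35, "12 7 305": 0.05}
#   cau_sum_ginis([12, 7, 305], all_ginis)  # -> (0.50, [0.10, 0.35, 0.05])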
40 |
41 | def compare_sum_gini(source_sum_gini, new_sum_gini, T_s):
42 | if abs(source_sum_gini - new_sum_gini) > T_s:
43 | return True
44 | else:
45 | return False
46 |
47 |
48 | def caul_gini(softmax_value):
49 | gini = 0
50 | class_softmax = softmax_value.cpu().detach().numpy()
51 | #print("class_softmax:",class_softmax)
52 | for i in range(0, len(class_softmax)):
53 | gini += np.square(class_softmax[i])
54 | #print("gini value: ", 1-gini)
55 |     return 1 - gini
56 |
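# caul_gini() above returns the Gini impurity 1 - sum_i p_i^2 of a softmax distribution:
# 0 for a one-hot (fully confident) output, approaching 1 - 1/K for a uniform one over K
# classes. Worked example in plain numpy (illustrative only):
#   probs = np.array([0.7, 0.2, 0.1])
#   impurity = 1.0 - np.sum(np.square(probs))  # 1 - 0.54 = 0.46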
57 |
58 | def read_wrd_trn(file_path):
59 | file = open(file_path, 'r')
60 | word_list = []
61 | line = file.readline()
62 | word_count = 0
63 | while line:
64 | idx = line.find("(")
65 | line = line[:idx]
66 | word_list.append(line.split(" "))
67 | word_count += len(line.split(" "))
68 | line = file.readline()
69 | print(word_count)
70 | return word_list, word_count
71 |
72 |
73 | def load_token_file(file_path):
74 | f = open(file_path, "r")
75 | line = f.readline()
76 | token_dict = {}
77 | while line:
78 | line = line[:-1]
79 | split_line = line.split(" ")
80 | name = split_line[0]
81 | if "gini" in file_path:
82 | token_dict[name] = [float(gini) for gini in split_line[1:]]
83 | else:
84 | token_dict[name] = split_line[1:]
85 | line = f.readline()
86 |
87 | return token_dict
88 |
89 | def get_word_gini(token_list, token_gini_list):
90 | all_word_list = {}
91 | all_gini_list = {}
92 | for name in token_list.keys():
93 | token = token_list[name]
94 | token_gini = token_gini_list[name]
95 | word_list = []
96 | gini_list = []
97 | word = ""
98 | word_gini = 0
99 | #c_count = 0
100 | for i in range(len(token)):
101 | if token[i] != token[0]:
102 | word = word + token[i]
103 | if token_gini[i] > word_gini:
104 | word_gini = token_gini[i]
105 | #c_count += 1
106 | else:
107 | if len(word) != 0:
108 | gini_list.append(word_gini)
109 | word_list.append(word)
110 | word = ""
111 | #c_count = 0
112 | word_gini = 0
113 | if len(word) != 0:
114 | gini_list.append(word_gini)
115 | word_list.append(word)
116 | all_word_list[name] = word_list
117 | all_gini_list[name] = gini_list
118 |
119 | return all_word_list, all_gini_list
120 |
121 | def get_token_list(json_data):
122 | all_token = {}
123 | for name in json_data["utts"]:
124 | token = json_data["utts"][name]["output"][0]["rec_token"].split(" ")
125 | all_token[name.lower()] = token[:-1]
126 | return all_token
127 |
128 |
129 | def get_word_gini_v1(json_data, token_gini_list, hyp_dict):
130 | all_word_list = {}
131 | all_gini_list = {}
132 | for name in json_data["utts"]:
133 | token = json_data["utts"][name]["output"][0]["rec_token"].split(" ")
134 | token_gini = token_gini_list[name.lower()]
135 | hyp_word = hyp_dict[name.lower()]
136 | word_list = list(hyp_word)
137 | for i in range(len(hyp_word)):
138 | if "*" in hyp_word[i]:
139 | word_list.remove(hyp_word[i])
140 | #word_list = []
141 | gini_list = []
142 | word = ""
143 | word_gini = 0
144 | w_idx = 0
145 | #c_count = 0
146 | for i in range(len(token)):
147 | if token[i] == "":
148 | break
149 | if (i == 0) & (token[i][0]!= "▁"):
150 | token[i] = "▁" + token[i]
151 | word = word + token[i]
152 | if token_gini[i] > word_gini:
153 | word_gini = token_gini[i]
154 | if len(word) > 1:
155 | if word[0:2] == "▁▁":
156 | word = word[1:]
157 | print(name)
158 | if word[1:] == word_list[w_idx].lower():
159 | gini_list.append(word_gini)
160 | w_idx += 1
161 | word = ""
162 | word_gini = 0
163 |
164 | if w_idx != len(word_list):
165 | gini_list.append(word_gini)
166 | if len(word_list) != len(gini_list):
167 | print("name:", name)
168 | print("word, w_idx, word_gini", word, w_idx, word_gini)
169 | print(word_list, len(word_list))
170 | print(gini_list, len(gini_list))
171 | all_word_list[name.lower()] = word_list
172 | all_gini_list[name.lower()] = gini_list
173 |
174 | return all_word_list, all_gini_list
175 |
176 |
177 | def get_gini_threshold(all_gini_list, eval_dict, orig_word_list):
178 | s_gini = 0
179 | i_gini = 0
180 | c_gini = 0
181 | s_count = 0
182 | i_count = 0
183 | c_count = 0
184 | d_count = 0
185 | for name in eval_dict.keys():
186 | if name not in all_gini_list.keys():
187 | gini_name = name.upper()
188 | gini_list = all_gini_list[gini_name]
189 | else:
190 | gini_list = all_gini_list[name]
191 | eval_list = eval_dict[name]
192 | for i in range(len(eval_list)):
193 | if eval_list[i] == "D":
194 | d_count += 1
195 | gini_list.insert(i, -1)
196 | if eval_list[i] == "S":
197 | s_gini += gini_list[i]
198 | s_count += 1
199 | if eval_list[i] == "I":
200 | #print("insert_gini:", name, gini_list[i])
201 | i_gini += gini_list[i]
202 | i_count += 1
203 | if eval_list[i] == "C":
204 | c_gini += gini_list[i]
205 | c_count += 1
206 | print(s_count, i_count, c_count, d_count)
207 | if s_count > i_count + d_count:
208 | flag = "s_first"
209 | else:
210 | flag = "len_first"
211 | if i_gini != 0:
212 | i_t = i_gini/i_count
213 | else:
214 | i_t = 0
215 | print("flag: ", flag)
216 | return s_gini/s_count, i_t, c_gini/c_count, flag
217 |
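# get_gini_threshold() walks the sclite-style Eval tags (C/S/I/D) of the original test set,
# inserts a placeholder gini of -1 for deletions so the lists stay aligned, and averages
# the gini values of substituted, inserted and correct words; gini_sort()/gini_sort_v1()
# later use the substitution average s_t as the threshold above which a word counts
# towards the ranking score. It is called as in gini_guide.py:
#   s_t, i_t, c_t, flag = get_gini_threshold(orig_gini_list, orig_eval_dict, orig_word_list)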
218 |
219 | def judge_res(ref_word, hyp_word):
220 | eval_list = []
221 | for i in range(len(ref_word)):
222 | if ref_word[i] == hyp_word[i]:
223 | eval_list.append("C")
224 | elif "*" in ref_word[i]:
225 | eval_list.append("I")
226 | elif "*" in hyp_word[i]:
227 | eval_list.append("D")
228 | else:
229 | eval_list.append("S")
230 | return eval_list
231 |
232 |
233 | def load_result_txt(file_path):
234 | f = open(file_path, 'r')
235 | ref_dict = {}
236 | hyp_dict = {}
237 | eval_dict = {}
238 | line = f.readline()
239 | while line:
240 | if "id: (" in line:
241 | name = line[:-1].replace("id: (", "").replace(")", "")
242 | name = "-".join(name.split("-")[1:])
243 | if "REF: " in line:
244 | ref = line[:-1].split()[1:]
245 | ref_dict[name] = ref
246 | if "HYP: " in line:
247 | hyp = line[:-1].split()[1:]
248 | hyp_dict[name] = hyp
249 | if "Eval: " in line:
250 | eval_list = judge_res(ref, hyp)
251 | eval_dict[name] = eval_list
252 | line = f.readline()
253 | return ref_dict, hyp_dict, eval_dict
254 |
255 |
256 | def gini_sort_v1(s_t, all_gini_list, orig_token, new_token, flag):
257 | print("flag is", flag)
258 | sort_list = []
259 | for key in new_token.keys():
260 | if "&" not in key:
261 | diff = 0
262 | else:
263 | orig_name = key.split("&")[0] + "-" + "-".join(key.split("&")[1].split("-")[1:])
264 | orig_len = len(orig_token[orig_name])
265 | new_len = len(new_token[key])
266 | diff = abs(new_len - orig_len)
267 | s_count = 0
268 | for gini in all_gini_list[key]:
269 | if gini > s_t:
270 | s_count += 1
271 | if len(all_gini_list[key]) == 0:
272 | sort_list.append(SpeechSort(key, 1.0, 1.0, 1.0))
273 | else:
274 | sort_list.append(SpeechSort(key, s_count/len(all_gini_list[key]), diff/len(all_gini_list[key]), (s_count + diff)/len(all_gini_list[key])))
275 | sort_list.sort(key=lambda x:x.mix_sort, reverse=True)
276 |
277 | return sort_list
278 |
279 |
280 | def gini_sort(s_t, all_gini_list, orig_token, new_token, flag):
281 | sort_list = []
282 | s_count_list = []
283 | len_diff_list = []
284 |
285 | for key in new_token.keys():
286 | if "&" not in key:
287 | diff = 0
288 | else:
289 | orig_name = key.split("&")[0]
290 | orig_len = len(orig_token[orig_name])
291 | new_len = len(new_token[key])
292 | #diff = abs(new_len - orig_len)/orig_len
293 | diff = abs(new_len - orig_len)
294 | s_count = 0
295 | for gini in all_gini_list[key]:
296 | if gini > s_t:
297 | s_count += 1
298 | sort_list.append(SpeechSort(key, s_count/len(all_gini_list[key]), diff, (s_count + diff)/len(all_gini_list[key])))
299 | sort_list.sort(key=lambda x:x.mix_sort, reverse=True)
300 |
301 | return sort_list
302 |
303 |
304 | def parse_hypothesis(hyp, char_list, gini_list):
305 | """Parse hypothesis.
306 | Args:
307 |         hyp (dict[str, Any]): Recognition hypothesis.
308 | char_list (list[str]): List of characters.
309 | Returns:
310 |         tuple(str, str, str, float, str): text, token, token id, score, gini list.
311 | """
312 | # remove sos and get results
313 | tokenid_as_list = list(map(int, hyp["yseq"][1:]))
314 | token_as_list = [char_list[idx] for idx in tokenid_as_list]
315 | score = float(hyp["score"])
316 |
317 | # convert to string
318 | tokenid = " ".join([str(idx) for idx in tokenid_as_list])
319 | token = " ".join(token_as_list)
320 | text = "".join(token_as_list).replace("", " ")
321 | ginilist = " ".join([str(v) for v in gini_list])
322 | return text, token, tokenid, score, ginilist
323 |
324 |
325 | def add_results_to_json(js, nbest_hyps, char_list, sum_gini, gini_list):
326 | """Add N-best results to json.
327 | Args:
328 | js (dict[str, Any]): Groundtruth utterance dict.
329 |         nbest_hyps (list[dict[str, Any]]):
330 |             List of N-best hypotheses for the utterance.
331 | char_list (list[str]): List of characters.
332 | Returns:
333 | dict[str, Any]: N-best results added utterance dict.
334 | """
335 | # copy old json info
336 | new_js = dict()
337 | new_js["utt2spk"] = js["utt2spk"]
338 | new_js["output"] = []
339 |
340 | for n, hyp in enumerate(nbest_hyps, 1):
341 | # parse hypothesis
342 | rec_text, rec_token, rec_tokenid, score, gini_list = parse_hypothesis(hyp, char_list, gini_list)
343 |
344 | # copy ground-truth
345 | if len(js["output"]) > 0:
346 | out_dic = dict(js["output"][0].items())
347 | else:
348 | # for no reference case (e.g., speech translation)
349 | out_dic = {"name": ""}
350 |
351 | # update name
352 | out_dic["name"] += "[%d]" % n
353 |
354 | # add recognition results
355 | out_dic["rec_text"] = rec_text
356 | out_dic["rec_token"] = rec_token
357 | out_dic["rec_tokenid"] = rec_tokenid
358 | out_dic["score"] = score
359 | out_dic["sum_gini"] = sum_gini
360 | out_dic["gini_list"] = gini_list
361 |
362 | # add to list of N-best result dicts
363 | new_js["output"].append(out_dic)
364 |
365 | # show 1-best result
366 | if n == 1:
367 | if "text" in out_dic.keys():
368 | logging.info("groundtruth: %s" % out_dic["text"])
369 | logging.info("prediction : %s" % out_dic["rec_text"])
370 |
371 | return new_js
372 |
373 |
374 | def load_gini(file_path):
375 | f = open(file_path, 'r')
376 | line = f.readline()
377 | orig_gini = {}
378 | while line:
379 | line = line[:-1]
380 | key = line.split(" ")[0]
381 | if "sum_gini" in file_path:
382 | orig_gini[key] = float(line.split(" ")[1])
383 | else:
384 | if len(line[:-1].split(" ")[1:]) > 0:
385 | gini_list = [float(x) for x in line.split(" ")[1:]]
386 | else:
387 | gini_list = [0.0]
388 | orig_gini[key] = gini_list
389 | line = f.readline()
390 | return orig_gini
391 |
392 |
393 | def load_gini_v1(file_path):
394 | f = open(file_path, 'r')
395 | line = f.readline()
396 | orig_gini = {}
397 | while line:
398 | line = line[:-1]
399 | key = line.split(" ")[0].lower()
400 | if "sum_gini" in file_path:
401 | orig_gini[key] = float(line.split(" ")[1])
402 | else:
403 | if len(line[:-1].split(" ")[1:]) > 0:
404 | gini_list = [float(x) for x in line.split(" ")[1:]]
405 | else:
406 | gini_list = [0.0]
407 | orig_gini[key] = gini_list
408 | line = f.readline()
409 | return orig_gini
410 |
411 |
412 | def load_text(file_path):
413 | f = open(file_path, 'r')
414 | line = f.readline()
415 | orig_text = {}
416 | while line:
417 | line = line[:-1]
418 | key = line.split(" ")[0]
419 | orig_text[key] = line.split(" ")[1:]
420 | line = f.readline()
421 | return orig_text
422 |
423 |
424 | def load_score(file_path):
425 | f = open(file_path, 'r')
426 | line = f.readline()
427 | orig_score = {}
428 | while line:
429 | line = line[:-1]
430 | key = line.split(" ")[0]
431 | orig_score[key] = line.split(" ")[1]
432 | line = f.readline()
433 | return orig_score
434 |
435 |
436 | def load_token(file_path):
437 | f = open(file_path, 'r')
438 | line = f.readline()
439 | orig_token = {}
440 | while line:
441 | line = line[:-1]
442 | key = line.split(" ")[0]
443 | orig_token[key] = line.split(" ")[1:]
444 | line = f.readline()
445 | return orig_token
446 |
447 |
448 | def mkdir(path):
449 | folder = os.path.exists(path)
450 | if not folder:
451 | os.makedirs(path)
452 |
453 |
454 | def write_new_info(new_dir, file_name, result):
455 | mkdir(new_dir)
456 | f = open(new_dir + "/" + file_name, 'w')
457 | for key in result.keys():
458 | if isinstance(result[key], list):
459 | line = key + " " + " ".join([str(x) for x in result[key]]) + "\n"
460 | else:
461 | line = key + " " + str(result[key]) + "\n"
462 | f.write(line)
463 |
464 |
465 | def write_gini_select_result(new_dir, file_name, result):
466 | mkdir(new_dir)
467 | f = open(new_dir + "/" + file_name, 'w')
468 | for item in result:
469 | line = item.name + " " + str(item.len_diff) + " " + str(item.s_count)+ "\n"
470 | f.write(line)
471 |
472 |
473 | def update_data_json(json_data, key_name):
474 | new_json = copy.deepcopy(json_data)
475 | new_json["utts"].pop(key_name)
476 | return new_json
477 |
478 |
479 | def load_data_json(path):
480 | with open(path, 'r', encoding='utf8')as fp:
481 | json_data = json.load(fp)
482 | return json_data
483 |
484 | def write_json_data(path, new_json) :
485 | mkdir(path)
486 | with open(path + "/" + "data.json", "w", encoding='utf8')as fp:
487 | json.dump(new_json, fp, ensure_ascii=False)
488 |
489 |
490 | def load_hyp_ref(path):
491 | f = open(path, 'r')
492 | line = f.readline()
493 | orig_dict = {}
494 | while line:
495 | value = str(line)
496 | line = line[:-1]
497 | line = line.replace("\t", " ")
498 | key = line.split(" ")[-1][1:-1]
499 | idx = key.find("-")
500 | key = key[idx + 1:]
501 | orig_dict[key] = value
502 | line = f.readline()
503 | return orig_dict
504 |
505 |
506 | def write_hyp_ref(path, name, result):
507 | mkdir(path)
508 | f = open(path + "/" + name, 'w')
509 | for key in result.keys():
510 | f.write(result[key])
511 | f.close()
512 |
--------------------------------------------------------------------------------
/espnet/utils/gini_guide.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | from espnet.utils.gini_utils import *
4 | import random
5 | import os
6 | import shutil
7 |
8 | def sort_test_by_gini(old_dir, new_dir, aug_type, selected_num):
9 | new_token_int = load_gini(new_dir + "/token_int")
10 | new_token = load_token(new_dir + "/token")
11 | orig_token = load_token(old_dir + "/token")
12 | new_token_gini = load_gini(new_dir + "/token_gini")
13 | orig_token_gini = load_gini(old_dir + "/token_gini")
14 | orig_word_list, orig_gini_list = get_word_gini(orig_token, orig_token_gini)
15 | new_word_list, new_gini_list = get_word_gini(new_token, new_token_gini)
16 | ref_dict, hyp_dict, orig_eval_dict = load_result_txt(old_dir + "/score_wer/result.txt")
17 | #s_t, i_t, c_t, flag = get_gini_threshold(orig_gini_list, orig_eval_dict, orig_word_list)
18 | s_t, i_t, c_t, flag = get_gini_threshold(orig_token_gini, orig_eval_dict, orig_token)
19 | #flag = "len_first"
20 | sort_list = gini_sort(s_t, new_gini_list, orig_token, new_token, flag)
21 | new_text = load_text(new_dir + "/text")
22 | new_score = load_score(new_dir + "/score")
23 | new_wer_ref = load_hyp_ref(new_dir + "/score_wer/ref.trn")
24 | new_wer_hyp = load_hyp_ref(new_dir + "/score_wer/hyp.trn")
25 | new_cer_ref = load_hyp_ref(new_dir + "/score_cer/ref.trn")
26 | new_cer_hyp = load_hyp_ref(new_dir + "/score_cer/hyp.trn")
27 | folder = os.path.exists(new_dir + "/score_ter/")
28 | if folder:
29 | new_ter_ref = load_hyp_ref(new_dir + "/score_ter/ref.trn")
30 | new_ter_hyp = load_hyp_ref(new_dir + "/score_ter/hyp.trn")
31 | ter_ref_result = dict(new_ter_ref)
32 | ter_hyp_result = dict(new_ter_hyp)
33 |
34 | token_result = dict()
35 | token_int_result = dict()
36 | token_gini_result = dict()
37 | #sum_gini_result = dict()
38 | text_result = dict()
39 | score_result = dict()
40 | wer_ref_result = dict()
41 | wer_hyp_result = dict()
42 | cer_ref_result = dict()
43 | cer_hyp_result = dict()
44 |
45 | for item in sort_list[:selected_num]:
46 | key = item.name
47 | token_result[key] = new_token[key]
48 | token_int_result[key] = new_token_int[key]
49 | token_gini_result[key] = new_token_gini[key]
50 | text_result[key] = new_text[key]
51 | score_result[key] = new_score[key]
52 | wer_hyp_result[key] = new_wer_hyp[key]
53 | wer_ref_result[key] = new_wer_ref[key]
54 | cer_hyp_result[key] = new_cer_hyp[key]
55 | cer_ref_result[key] = new_cer_ref[key]
56 | if folder:
57 | ter_hyp_result[key] = new_ter_hyp[key]
58 | ter_ref_result[key] = new_ter_ref[key]
59 |
60 | result_dir = new_dir + "/gini-" + str(selected_num)
61 | write_new_info(result_dir, "text", text_result)
62 | write_new_info(result_dir, "score", score_result)
63 | write_new_info(result_dir, "token", token_result)
64 | write_new_info(result_dir, "token_gini", token_gini_result)
65 | write_new_info(result_dir, "token_int", token_int_result)
66 | write_hyp_ref(result_dir + "/score_wer/", "hyp.trn", wer_hyp_result)
67 | write_hyp_ref(result_dir + "/score_wer/", "ref.trn", wer_ref_result)
68 | write_hyp_ref(result_dir + "/score_cer/", "hyp.trn", cer_hyp_result)
69 | write_hyp_ref(result_dir + "/score_cer/", "ref.trn", cer_ref_result)
70 | if folder:
71 | write_hyp_ref(result_dir + "/score_ter/", "hyp.trn", ter_hyp_result)
72 | write_hyp_ref(result_dir + "/score_ter/", "ref.trn", ter_ref_result)
73 | return token_result
74 |
75 | def sort_test_by_gini_v1(old_dir, new_dir, aug_type):
76 | orig_token_gini = load_gini_v1(old_dir + "/token_gini")
77 | selected_num = len(orig_token_gini)
78 | new_token_gini = load_gini_v1(new_dir + "/token_gini")
79 | print(new_dir + "/data.json")
80 | new_json_data = load_data_json(new_dir + "/data.json")
81 | orig_json_data = load_data_json(old_dir + "/data.json")
82 | res_json = copy.deepcopy(new_json_data)
83 | token_gini_result = dict()
84 | orig_ref_dict, orig_hyp_dict, orig_eval_dict = load_result_txt(old_dir + "/result.wrd.txt")
85 | new_ref_dict, new_hyp_dict, new_eval_dict = load_result_txt(new_dir + "/result.wrd.txt")
86 | orig_token_list = get_token_list(orig_json_data)
87 | new_token_list = get_token_list(new_json_data)
88 | orig_word_list, orig_gini_list = get_word_gini_v1(orig_json_data, orig_token_gini, orig_hyp_dict)
89 | new_word_list, new_gini_list = get_word_gini_v1(new_json_data, new_token_gini, new_hyp_dict)
90 | s_t, i_t, c_t, flag = get_gini_threshold(orig_gini_list, orig_eval_dict, orig_word_list)
91 | print(s_t, i_t, c_t)
92 | sort_list = gini_sort_v1(s_t, new_gini_list, orig_word_list, new_word_list, flag)[:selected_num]
93 | res_key = []
94 | for item in sort_list:
95 | key = item.name
96 | token_gini_result[key] = new_token_gini[key]
97 | res_key.append(key)
98 |
99 | for key in new_json_data["utts"]:
100 | if key.lower() not in res_key:
101 | del res_json["utts"][key]
102 |
103 | result_dir = new_dir + "/gini-" + str(selected_num)
104 | write_new_info(result_dir, "token_gini", token_gini_result)
105 | write_json_data(result_dir, res_json)
106 | return token_gini_result
107 |
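# sort_test_by_gini_v1() is the selection entry point called from espnet/bin/asr_test.py
# when --need_decode is false: it loads per-token ginis and result.wrd.txt from the original
# ("orig") and augmented ("new") decode directories, derives the substitution threshold from
# the original set, ranks the augmented utterances with gini_sort_v1(), and writes the
# selected subset's token_gini plus a pruned data.json into <new_dir>/gini-<selected_num>.
# Rough usage mirroring the caller (paths are illustrative):
#   selected = sort_test_by_gini_v1("exp/.../decode_test-orig_...", "exp/.../decode_test-room_...", "room")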
108 |
109 | def sort_type_by_diff(aug_type, selected_num, diffs):
110 | class_dict = {}
111 | for key in diffs.keys():
112 | seq = key.split("&")[1].replace(".wav", "")
113 | if ("f" in seq) | ("n" in seq) | ("a" in seq):
114 | seq = key.split("&")[1].replace(".wav", "")[-1]
115 | if aug_type == "feature":
116 | seq = int(seq)
117 | if aug_type == "noise":
118 | seq = int(seq) % 8
119 | if aug_type == "room":
120 | seq = int(int(seq)/ 4)
121 | if seq not in class_dict.keys():
122 | class_dict[seq] = list()
123 | item = (key, diffs[key])
124 | class_dict[seq].append(item)
125 |
126 | result = []
127 | num = int(selected_num/len(class_dict.keys()))
128 | for key in class_dict.keys():
129 | sort_result = sorted(class_dict[key], reverse=True, key=lambda kv:(kv[1], kv[0]))
130 | sort_result = sort_result[:num]
131 | result = result + sort_result
132 |
133 | while len(result) < selected_num:
134 | key = random.choice(list(diffs.keys()))
135 | item = (key,diffs[key])
136 | if item not in result:
137 | result.append(item)
138 | total_gini = 0
139 | for item in result:
140 | total_gini += item[1]
141 | print(aug_type, "selects ", total_gini/len(result))
142 |
143 | return result
144 |
145 |
146 | def sort_type_by_diff_v1(aug_type, selected_num, diffs):
147 | class_dict = {}
148 | for key in diffs.keys():
149 | seq = key.split("&")[1].replace(".wav", "").split("-")[0]
150 | if ("f" in seq) | ("n" in seq) | ("a" in seq):
151 | seq = key.split("&")[1].replace(".wav", "")[-1]
152 | if aug_type == "feature":
153 | seq = int(seq)
154 | if aug_type == "noise":
155 | seq = int(seq) % 8
156 | if aug_type == "room":
157 | seq = int(int(seq)/ 4)
158 | if seq not in class_dict.keys():
159 | class_dict[seq] = list()
160 | item = (key, float(diffs[key]))
161 | class_dict[seq].append(item)
162 |
163 | result = []
164 | num = int(selected_num/len(class_dict.keys()))
165 | for key in class_dict.keys():
166 | sort_result = sorted(class_dict[key], reverse=True, key=lambda kv:(kv[1], kv[0]))
167 | sort_result = sort_result[:num]
168 | result = result + sort_result
169 |
170 | while len(result) < selected_num:
171 | key = random.choice(list(diffs.keys()))
172 | item = (key,diffs[key])
173 | if item not in result:
174 | result.append(item)
175 | total_gini = 0
176 | for item in result:
177 | total_gini += float(item[1])
178 | print(aug_type, "selects ", total_gini/len(result))
179 | return result
180 |
181 |
182 | def gen_new_prep_file(new_file, new_audio, dataset, flag, selected_num, test_set):
183 | f2 = open(new_file, "r")
184 | new_result = dict()
185 | line = f2.readline()
186 | while line:
187 | line = line[:-1]
188 | key = line.split(" ")[0]
189 | if "ted" in dataset:
190 | key = key.lower()
191 | new_result[key] = line
192 | line = f2.readline()
193 |
194 | retrain_result = dict()
195 | print(new_audio)
196 | for key in new_audio.keys():
197 | if ("spk2gender" in new_file) & ("TIMIT" in dataset):
198 | key = key.split("_")[0]
199 | retrain_result[key] = new_result[key]
200 |
201 | new_path = new_file.split("/")
202 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num)
203 | if "ted" in dataset:
204 | new_path[-2] = new_path[-2] + ".orig"
205 | mkdir("/".join(new_path[:-1]))
206 | new_path = "/".join(new_path)
207 | w = open(new_path, "w")
208 |
209 | for key in retrain_result.keys():
210 | w.write(retrain_result[key] + "\n")
211 | w.close()
212 | return retrain_result
213 |
214 | def gen_new_wavscp(new_wavscp, new_audio, dataset, flag, selected_num, test_set):
215 | f2 = open(new_wavscp, "r")
216 | new_scp = dict()
217 | line = f2.readline()
218 | while line:
219 | line = line[:-1]
220 | key = line.split(" ")[0]
221 | if "ted" in dataset:
222 | key = key.lower()
223 | new_scp[key] = line
224 | line = f2.readline()
225 | retrain_scp = set()
226 | for key in new_audio.keys():
227 | if "ted" in dataset:
228 | key = key.split("-")[0]
229 | retrain_scp.add(new_scp[key])
230 | new_path = new_wavscp.split("/")
231 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num)
232 | if "ted" in dataset:
233 | new_path[-2] = new_path[-2] + ".orig"
234 | mkdir("/".join(new_path[:-1]))
235 | new_path = "/".join(new_path)
236 | w = open(new_path, "w")
237 | for line in retrain_scp:
238 | w.write(line + "\n")
239 | w.close()
240 |
241 | def gen_new_recog2file(new_recogfile, new_audio, dataset, flag, selected_num, test_set):
242 | f2 = open(new_recogfile, "r")
243 | new_recog = dict()
244 | line = f2.readline()
245 | while line:
246 | line = line[:-1]
247 | key = line.split(" ")[0].lower()
248 | new_recog[key] = line
249 | line = f2.readline()
250 | recog_result = set()
251 | for key in new_audio.keys():
252 | if "ted" in dataset:
253 | key = key.split("-")[0]
254 | recog_result.add(new_recog[key])
255 | new_path = new_recogfile.split("/")
256 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num)
257 | if "ted" in dataset:
258 | new_path[-2] = new_path[-2] + ".orig"
259 | mkdir("/".join(new_path[:-1]))
260 | new_path = "/".join(new_path)
261 | w = open(new_path, "w")
262 | for line in recog_result:
263 | w.write(line + "\n")
264 | w.close()
265 |
266 |
267 | def mkdir(path):
268 | folder = os.path.exists(path)
269 | if not folder:
270 | os.makedirs(path)
271 |
272 | def gen_new_spk2utt(orgi_spk2utt, new_audio, dataset, flag, selected_num, test_set):
273 | f = open(orgi_spk2utt, "r")
274 | spk2utt = dict()
275 | line = f.readline()
276 | while line:
277 | line = line[:-1]
278 | key = line.split(" ")[0]
279 | spk2utt[key] = line.split(" ")[1:]
280 | line = f.readline()
281 | for key in new_audio.keys():
282 | if "an4" in dataset:
283 | spk = key.split("-")[0]
284 | if "TIMIT" in dataset:
285 | spk = key.split("_")[0]
286 | spk2utt[spk].append(key.replace(".wav", ""))
287 | new_path = orgi_spk2utt.split("/")
288 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num)
289 | if "ted" in dataset:
290 | new_path[-2] = new_path[-2] + ".orig"
291 | mkdir("/".join(new_path[:-1]))
292 | new_path = "/".join(new_path)
293 | w = open(new_path, "w")
294 | for key in spk2utt.keys():
295 | w.write(key.replace(".wav", "") + " " + " ".join(spk2utt[key]) + "\n")
296 | w.close()
297 |
298 |
299 | def gen_new_spk2utt_v1(new_utt2spk, new_spk2utt, dataset, flag, selected_num, test_set):
300 | spk2utt = dict()
301 | for key in new_utt2spk.keys():
302 | spk = new_utt2spk[key].split(" ")[-1]
303 | if spk in spk2utt.keys():
304 | spk2utt[spk].append(key)
305 | else:
306 | spk2utt[spk] = []
307 | spk2utt[spk].append(key)
308 |
309 | new_path = new_spk2utt.split("/")
310 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num)
311 | if "ted" in dataset:
312 | new_path[-2] = new_path[-2] + ".orig"
313 | mkdir("/".join(new_path[:-1]))
314 | new_path = "/".join(new_path)
315 | w = open(new_path, "w")
316 | for key in spk2utt.keys():
317 | w.write(key + " " + " ".join(spk2utt[key]) + "\n")
318 | w.close()
319 |
320 |
321 | def gen_new_stm(new_stm, new_audio, dataset, flag, selected_num, test_set):
322 | f2 = open(new_stm, "r")
323 | orgi_stm = dict()
324 | line = f2.readline()
325 | start = []
326 | while ";;" in line:
327 | start.append(line)
328 | line = f2.readline()
329 | while line:
330 | line = line[:-1]
331 | key = line.split(" ")[0]
332 | if "ted" in dataset:
333 | key = key.lower()
334 | if key in orgi_stm.keys():
335 | orgi_stm[key].append(line)
336 | else:
337 | orgi_stm[key] = []
338 | orgi_stm[key].append(line)
339 | line = f2.readline()
340 | stm_result = dict()
341 | for key in new_audio.keys():
342 | key = key.split("-")[0]
343 | if key not in stm_result.keys():
344 | stm_result[key] = orgi_stm[key]
345 |
346 | new_path = new_stm.split("/")
347 | new_path[-2] = test_set + "-"+ flag+ "-" + str(selected_num)
348 | if "ted" in dataset:
349 | new_path[-2] = new_path[-2] + ".orig"
350 | mkdir("/".join(new_path[:-1]))
351 | new_path = "/".join(new_path)
352 | w = open(new_path, "w")
353 | for line in start:
354 | w.write(line + "\n")
355 | for key in stm_result.keys():
356 | for line in stm_result[key]:
357 | w.write(line + "\n")
358 | w.close()
359 |
360 | def copy_file(source, test_set, flag, selected_num):
361 | destination = source.split("/")
362 | destination[-2] = test_set + "-"+ flag+ "-" + str(selected_num)
363 | destination[-2] = destination[-2] + ".orig"
364 | mkdir("/".join(destination[:-1]))
365 | destination = "/".join(destination)
366 | shutil.copyfile(source, destination)
367 | if os.path.exists(destination):
368 | logging.info("copy success")
369 |
370 | def copy_dir(source, test_set, flag, selected_num):
371 | destination = source.split("/")
372 | destination[-2] = test_set + "-"+ flag+ "-" + str(selected_num)
373 | mkdir("/".join(destination[:-1]))
374 | destination = "/".join(destination)
375 | shutil.copytree(source, destination)
376 |
377 |
378 | def test_data_prep(orgi_dir, new_dir, dataset, new_audio, flag, selected_num, test_set):
379 | if "an4" in dataset:
380 | gen_new_prep_file(new_dir + "/text", new_audio, dataset, flag, selected_num, test_set)
381 | gen_new_prep_file(new_dir + "/utt2spk", new_audio, dataset, flag, selected_num, test_set)
382 | gen_new_wavscp(new_dir + "/wav.scp", new_audio, dataset, flag, selected_num, test_set)
383 | gen_new_spk2utt(orgi_dir + "/spk2utt", new_audio, dataset, flag, selected_num, test_set)
384 |
385 | if "TIMIT" in dataset:
386 | gen_new_prep_file(new_dir + "/text", new_audio, dataset, flag, selected_num, test_set)
387 | gen_new_prep_file(new_dir + "/utt2spk", new_audio, dataset, flag, selected_num, test_set)
388 | gen_new_prep_file(new_dir + "/spk2gender", new_audio, dataset, flag, selected_num, test_set)
389 | gen_new_wavscp(new_dir + "/wav.scp", new_audio, dataset, flag, selected_num, test_set)
390 | gen_new_spk2utt(orgi_dir + "/spk2utt", new_audio, dataset, flag, selected_num, test_set)
391 |
392 | if "ted" in dataset:
393 | new_utt2spk = gen_new_prep_file(new_dir + "/utt2spk", new_audio, dataset, flag, selected_num, test_set)
394 | gen_new_spk2utt_v1(new_utt2spk, new_dir + "/spk2utt", dataset, flag, selected_num, test_set)
395 | gen_new_wavscp(new_dir + "/wav.scp", new_audio, dataset, flag, selected_num, test_set)
396 | gen_new_prep_file(new_dir + "/text", new_audio, dataset, flag, selected_num, test_set)
397 | gen_new_prep_file(new_dir + "/segments", new_audio, dataset, flag, selected_num, test_set)
398 | #gen_new_prep_file(new_dir + "/utt2dur", new_audio, dataset, flag, selected_num, test_set)
399 | #gen_new_prep_file(new_dir + "/utt2num_frames", new_audio, dataset, flag, selected_num, test_set)
400 | gen_new_stm(new_dir + "/stm", new_audio, dataset, flag, selected_num, test_set)
401 | gen_new_recog2file(new_dir + "/reco2file_and_channel", new_audio, dataset, flag, selected_num, test_set)
402 | #gen_new_prep_file(new_dir + "/feats.scp", new_audio, dataset, flag, selected_num, test_set)
403 | copy_file(new_dir + "/glm", test_set, flag, selected_num)
404 | #copy_file(new_dir + "/frame_shift",test_set, flag, selected_num)
405 | #copy_dir(new_dir + "/conf", test_set, flag, selected_num)
406 |
407 |
408 |
409 |
410 |
411 |
412 |
--------------------------------------------------------------------------------
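The helpers above rebuild a Kaldi-style data directory (text, utt2spk, wav.scp, spk2utt, and for TEDLIUM additionally segments, stm and reco2file_and_channel) restricted to a gini-selected subset of utterances. A minimal, hypothetical driver is sketched below; it mirrors the call in espnet2/bin/asr_test.py further down. The paths, the dataset name, and the exact return value of sort_test_by_gini (a star-imported gini utility that is not part of this listing) are assumptions, not the author's actual configuration:

    # Hypothetical driver; all paths and the behaviour of sort_test_by_gini are assumptions.
    orig_dir = "exp/asr_train_raw_en_bpe30/decode_asr"  # assumed location of the gini decode outputs
    dataset = "an4"
    test_set = "test-room"          # an augmented test set; the augmentation type is the 2nd field
    selected_num = 130
    aug_type = test_set.split("-")[1]

    # Rank the augmented utterances against the original test set by gini score
    # and keep the selected_num highest-ranked ones.
    gini_audio = sort_test_by_gini(orig_dir + "/test-orig", orig_dir + "/" + test_set,
                                   aug_type, selected_num)

    # Rebuild the data directory for the selected subset, e.g. data/test-room-gini-130
    # (with a ".orig" suffix for the TEDLIUM recipes).
    test_data_prep("data/test-orig", "data/" + test_set, dataset,
                   gini_audio, "gini", selected_num, test_set)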
/espnet2/bin/asr_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import logging
4 | from pathlib import Path
5 | import sys
6 | from typing import Optional
7 | from typing import Sequence
8 | from typing import Tuple
9 | from typing import Union
10 | import os
11 | import numpy as np
12 | import torch
13 | from typeguard import check_argument_types
14 | from typeguard import check_return_type
15 | from typing import List
16 |
17 | from espnet.nets.batch_beam_search import BatchBeamSearch
18 | from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
19 | # from espnet.nets.beam_search import BeamSearch
20 | from espnet.nets.test_beam_search import TestBeamSearch
21 | from espnet.nets.test_beam_search import TestHypothesis
22 | from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError
23 | from espnet.nets.scorer_interface import BatchScorerInterface
24 | from espnet.nets.scorers.ctc import CTCPrefixScorer
25 | from espnet.nets.scorers.length_bonus import LengthBonus
26 | from espnet.utils.cli_utils import get_commandline_args
27 | from espnet2.fileio.datadir_writer import DatadirWriter
28 | from espnet2.tasks.asr import ASRTask
29 | from espnet2.tasks.lm import LMTask
30 | from espnet2.text.build_tokenizer import build_tokenizer
31 | from espnet2.text.token_id_converter import TokenIDConverter
32 | from espnet2.torch_utils.device_funcs import to_device
33 | from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
34 | from espnet2.utils import config_argparse
35 | from espnet2.utils.types import str2bool
36 | from espnet2.utils.types import str2triple_str
37 | from espnet2.utils.types import str_or_none
38 | from espnet.utils.gini_utils import *
39 | from espnet.utils.gini_guide import *
40 | from espnet.utils.retrain_utils import *
41 |
42 | class Speech2Text:
43 | """Speech2Text class
44 |
45 | Examples:
46 | >>> import soundfile
47 | >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
48 | >>> audio, rate = soundfile.read("speech.wav")
49 | >>> speech2text(audio)
50 |         ([(text, token, token_int, hypothesis object), ...], sum_gini, gini_list)
51 |
52 | """
53 |
54 | def __init__(
55 | self,
56 | asr_train_config: Union[Path, str],
57 | asr_model_file: Union[Path, str] = None,
58 | lm_train_config: Union[Path, str] = None,
59 | lm_file: Union[Path, str] = None,
60 | token_type: str = None,
61 | bpemodel: str = None,
62 | device: str = "cpu",
63 | maxlenratio: float = 0.0,
64 | minlenratio: float = 0.0,
65 | batch_size: int = 1,
66 | dtype: str = "float32",
67 | beam_size: int = 20,
68 | ctc_weight: float = 0.5,
69 | lm_weight: float = 1.0,
70 | penalty: float = 0.0,
71 | nbest: int = 1,
72 | streaming: bool = False,
73 | ):
74 | assert check_argument_types()
75 |
76 | # 1. Build ASR model
77 | scorers = {}
78 | asr_model, asr_train_args = ASRTask.build_model_from_file(
79 | asr_train_config, asr_model_file, device
80 | )
81 | asr_model.to(dtype=getattr(torch, dtype)).eval()
82 | decoder = asr_model.decoder
83 | ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
84 | token_list = asr_model.token_list
85 | scorers.update(
86 | decoder=decoder,
87 | ctc=ctc,
88 | length_bonus=LengthBonus(len(token_list)),
89 | )
90 | logging.info(f"asr_model: {asr_model}")
91 | # 2. Build Language model
92 | if lm_train_config is not None:
93 | lm, lm_train_args = LMTask.build_model_from_file(
94 | lm_train_config, lm_file, device
95 | )
96 | scorers["lm"] = lm.lm
97 |
98 | # 3. Build BeamSearch object
99 | weights = dict(
100 | decoder=1.0 - ctc_weight,
101 | ctc=ctc_weight,
102 | lm=lm_weight,
103 | length_bonus=penalty,
104 | )
105 | test_beam_search = TestBeamSearch(
106 | beam_size=beam_size,
107 | weights=weights,
108 | scorers=scorers,
109 | sos=asr_model.sos,
110 | eos=asr_model.eos,
111 | vocab_size=len(token_list),
112 | token_list=token_list,
113 | pre_beam_score_key=None if ctc_weight == 1.0 else "full",
114 | )
115 | # TODO(karita): make all scorers batchfied
116 | if batch_size == 1:
117 | non_batch = [
118 | k
119 | for k, v in test_beam_search.full_scorers.items()
120 | if not isinstance(v, BatchScorerInterface)
121 | ]
122 | if len(non_batch) == 0:
123 | if streaming:
124 | test_beam_search.__class__ = BatchBeamSearchOnlineSim
125 | test_beam_search.set_streaming_config(asr_train_config)
126 | logging.info("BatchBeamSearchOnlineSim implementation is selected.")
127 | else:
128 | test_beam_search.__class__ = BatchBeamSearch
129 | logging.info("BatchBeamSearch implementation is selected.")
130 | else:
131 | logging.warning(
132 | f"As non-batch scorers {non_batch} are found, "
133 | f"fall back to non-batch implementation."
134 | )
135 | test_beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
136 | for scorer in scorers.values():
137 | if isinstance(scorer, torch.nn.Module):
138 | scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
139 | logging.info(f"test_beam_search: {test_beam_search}")
140 | logging.info(f"Decoding device={device}, dtype={dtype}")
141 |
142 | # 4. [Optional] Build Text converter: e.g. bpe-sym -> Text
143 | if token_type is None:
144 | token_type = asr_train_args.token_type
145 | if bpemodel is None:
146 | bpemodel = asr_train_args.bpemodel
147 |
148 | if token_type is None:
149 | tokenizer = None
150 | elif token_type == "bpe":
151 | if bpemodel is not None:
152 | tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
153 | else:
154 | tokenizer = None
155 | else:
156 | tokenizer = build_tokenizer(token_type=token_type)
157 | converter = TokenIDConverter(token_list=token_list)
158 | logging.info(f"Text tokenizer: {tokenizer}")
159 |
160 | self.asr_model = asr_model
161 | self.asr_train_args = asr_train_args
162 | self.converter = converter
163 | self.tokenizer = tokenizer
164 | self.test_beam_search = test_beam_search
165 | self.maxlenratio = maxlenratio
166 | self.minlenratio = minlenratio
167 | self.device = device
168 | self.dtype = dtype
169 | self.nbest = nbest
170 |
171 | @torch.no_grad()
172 | def __call__(
173 | self, speech: Union[torch.Tensor, np.ndarray]
174 | ) -> List[Tuple[Optional[str], List[str], List[int], TestHypothesis]]:
175 | """Inference
176 |
177 | Args:
178 |             speech: Input speech data
179 |         Returns:
180 |             results [(text, token, token_int, hyp), ...], sum_gini, gini_list
181 |
182 | """
183 | assert check_argument_types()
184 |
185 | # Input as audio signal
186 | if isinstance(speech, np.ndarray):
187 | speech = torch.tensor(speech)
188 |
189 | # data: (Nsamples,) -> (1, Nsamples)
190 | speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
191 |         # lengths: (1,)
192 | lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
193 | batch = {"speech": speech, "speech_lengths": lengths}
194 |
195 | # a. To device
196 | batch = to_device(batch, device=self.device)
197 |
198 | # b. Forward Encoder
199 | enc, _ = self.asr_model.encode(**batch)
200 | assert len(enc) == 1, len(enc)
201 |
202 |         # c. Pass the encoder output to the beam search
203 | nbest_hyps, all_ginis = self.test_beam_search(
204 | x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
205 | )
206 | nbest_hyps = nbest_hyps[: self.nbest]
207 |
208 | results = []
209 | for hyp in nbest_hyps:
210 | assert isinstance(hyp, TestHypothesis), type(hyp)
211 |
212 | # remove sos/eos and get results
213 | token_int = hyp.yseq[1:-1].tolist()
214 |
215 | # remove blank symbol id, which is assumed to be 0
216 | token_int = list(filter(lambda x: x != 0, token_int))
217 | sum_gini, gini_list = cau_sum_ginis(token_int, all_ginis)
218 | # Change integer-ids to tokens
219 | token = self.converter.ids2tokens(token_int)
220 | if self.tokenizer is not None:
221 | text = self.tokenizer.tokens2text(token)
222 | else:
223 | text = None
224 | results.append((text, token, token_int, hyp))
225 |
226 | assert check_return_type(results)
227 | return results, sum_gini, gini_list
228 |
229 | def str_2_bool(s):
230 |     return s.lower() == "true"
231 |
232 | def test(
233 | output_dir: str,
234 | maxlenratio: float,
235 | minlenratio: float,
236 | batch_size: int,
237 | dtype: str,
238 | beam_size: int,
239 | ngpu: int,
240 | seed: int,
241 | ctc_weight: float,
242 | lm_weight: float,
243 | penalty: float,
244 | nbest: int,
245 | num_workers: int,
246 | log_level: Union[int, str],
247 | data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
248 | key_file: Optional[str],
249 | asr_train_config: str,
250 | asr_model_file: str,
251 | lm_train_config: Optional[str],
252 | lm_file: Optional[str],
253 | word_lm_train_config: Optional[str],
254 | word_lm_file: Optional[str],
255 | token_type: Optional[str],
256 | bpemodel: Optional[str],
257 | allow_variable_data_keys: bool,
258 | streaming: bool,
259 | orig_flag: str,
260 | orig_dir: str,
261 | need_decode: bool,
262 | ):
263 | assert check_argument_types()
264 | if batch_size > 1:
265 | raise NotImplementedError("batch decoding is not implemented")
266 | if word_lm_train_config is not None:
267 | raise NotImplementedError("Word LM is not implemented")
268 | if ngpu > 1:
269 | raise NotImplementedError("only single GPU decoding is supported")
270 |
271 | logging.basicConfig(
272 | level=log_level,
273 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
274 | )
275 |
276 | if ngpu >= 1:
277 | device = "cuda"
278 | else:
279 | device = "cpu"
280 |
281 | orig_flag = str_2_bool(orig_flag)
282 | # 1. Set random-seed
283 | set_all_random_seed(seed)
284 |
285 | # 2. Build speech2text
286 | speech2text = Speech2Text(
287 | asr_train_config=asr_train_config,
288 | asr_model_file=asr_model_file,
289 | lm_train_config=lm_train_config,
290 | lm_file=lm_file,
291 | token_type=token_type,
292 | bpemodel=bpemodel,
293 | device=device,
294 | maxlenratio=maxlenratio,
295 | minlenratio=minlenratio,
296 | dtype=dtype,
297 | beam_size=beam_size,
298 | ctc_weight=ctc_weight,
299 | lm_weight=lm_weight,
300 | penalty=penalty,
301 | nbest=nbest,
302 | streaming=streaming,
303 | )
304 |
305 | # 3. Build data-iterator
306 | loader = ASRTask.build_streaming_iterator(
307 | data_path_and_name_and_type,
308 | dtype=dtype,
309 | batch_size=batch_size,
310 | key_file=key_file,
311 | num_workers=num_workers,
312 | preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
313 | collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
314 | allow_variable_data_keys=allow_variable_data_keys,
315 | inference=True,
316 | )
317 |
318 |     # 7. Start for-loop
319 |     # FIXME(kamo): The output format should be discussed further
320 | if orig_flag:
321 | orig_sum_gini = {}
322 | orig_gini_list = {}
323 | else:
324 | new_sum_gini = {}
325 | new_gini_list = {}
326 | with DatadirWriter(output_dir) as writer:
327 | for keys, batch in loader:
328 | key = keys[0]
329 | assert isinstance(batch, dict), type(batch)
330 | assert all(isinstance(s, str) for s in keys), keys
331 | _bs = len(next(iter(batch.values())))
332 | assert len(keys) == _bs, f"{len(keys)} != {_bs}"
333 | batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
334 |
335 | # N-best list of (text, token, token_int, hyp_object)
336 | try:
337 | results, sum_gini, gini_list = speech2text(**batch)
338 | #logging.info(f"zlist_tensor is {zlist_tensor}")
339 | except TooShortUttError as e:
340 | logging.warning(f"Utterance {keys} {e}")
341 | hyp = TestHypothesis(score=0.0, scores={}, states={}, yseq=[])
342 |                 results = [[" ", ["<space>"], [2], hyp]] * nbest
343 |                 sum_gini, gini_list = 0.0, []  # fall back when decoding failed (avoids NameError below)
344 | # Only supporting batch_size==1
345 | if orig_flag:
346 | orig_sum_gini[key] = str(sum_gini)
347 | orig_gini_list[key] = gini_list
348 | else:
349 | new_sum_gini[key] = str(sum_gini)
350 | new_gini_list[key] = gini_list
351 |             #cov = cau_coverage(all_z_list, 0)  # disabled: all_z_list is undefined here and cov was unused
352 | #act_cell = get_activate_cell(all_z_list, 0)
353 | if need_decode:
354 | for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
355 | # Create a directory: outdir/{n}best_recog
356 | ibest_writer = writer[f"{n}best_recog"]
357 | # Write the result to each file
358 | ibest_writer["token"][key] = " ".join(token)
359 | ibest_writer["token_int"][key] = " ".join(map(str, token_int))
360 | ibest_writer["score"][key] = str(hyp.score)
361 | ibest_writer["sum_gini"][key] = str(sum_gini)
362 | ibest_writer["token_gini"][key] = " ".join(map(str, gini_list))
363 | if text is not None:
364 | ibest_writer["text"][key] = text
365 | if orig_flag:
366 | return orig_sum_gini, orig_gini_list
367 |     else:
368 |         return new_sum_gini, new_gini_list
369 | 
370 |
371 |
372 | def merge_result(file_path, gini_dict):
373 | f = open(file_path, 'w')
374 | for key in gini_dict.keys():
375 | if isinstance(gini_dict[key], str):
376 | line = key + " " + gini_dict[key] + "\n"
377 | else:
378 | line = key + " " + " ".join(map(str, gini_dict[key])) + "\n"
379 | f.write(line)
380 | f.close()
381 |
382 |
383 |
384 |
385 | def get_parser():
386 | parser = config_argparse.ArgumentParser(
387 | description="ASR Decoding",
388 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
389 | )
390 |
391 | # Note(kamo): Use '_' instead of '-' as separator.
392 | # '-' is confusing if written in yaml.
393 | parser.add_argument(
394 | "--log_level",
395 | type=lambda x: x.upper(),
396 | default="INFO",
397 | choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
398 | help="The verbose level of logging",
399 | )
400 |
401 | parser.add_argument("--output_dir", type=str, required=True)
402 | parser.add_argument(
403 | "--ngpu",
404 | type=int,
405 | default=0,
406 | help="The number of gpus. 0 indicates CPU mode",
407 | )
408 | parser.add_argument("--seed", type=int, default=0, help="Random seed")
409 | parser.add_argument(
410 | "--dtype",
411 | default="float32",
412 | choices=["float16", "float32", "float64"],
413 | help="Data type",
414 | )
415 | parser.add_argument(
416 | "--num_workers",
417 | type=int,
418 | default=1,
419 | help="The number of workers used for DataLoader",
420 | )
421 |
422 | group = parser.add_argument_group("Input data related")
423 | group.add_argument(
424 | "--data_path_and_name_and_type",
425 | type=str2triple_str,
426 | required=True,
427 | action="append",
428 | )
429 | group.add_argument("--key_file", type=str_or_none)
430 | group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
431 |
432 | group = parser.add_argument_group("The model configuration related")
433 | group.add_argument("--asr_train_config", type=str, required=True)
434 | group.add_argument("--asr_model_file", type=str, required=True)
435 | group.add_argument("--lm_train_config", type=str)
436 | group.add_argument("--lm_file", type=str)
437 | group.add_argument("--word_lm_train_config", type=str)
438 | group.add_argument("--word_lm_file", type=str)
439 | group.add_argument("--orig_flag", type=str, default="False")
440 | group.add_argument("--orig_dir", type=str, required=True)
441 | group = parser.add_argument_group("Beam-search related")
442 | group.add_argument(
443 | "--batch_size",
444 | type=int,
445 | default=1,
446 | help="The batch size for inference",
447 | )
448 | group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
449 | group.add_argument("--beam_size", type=int, default=20, help="Beam size")
450 | group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
451 | group.add_argument(
452 | "--maxlenratio",
453 | type=float,
454 | default=0.0,
455 | help="Input length ratio to obtain max output length. "
456 |         "If maxlenratio=0.0 (default), it uses an end-detect "
457 | "function "
458 | "to automatically find maximum hypothesis lengths",
459 | )
460 | group.add_argument(
461 | "--minlenratio",
462 | type=float,
463 | default=0.0,
464 | help="Input length ratio to obtain min output length",
465 | )
466 | group.add_argument(
467 | "--ctc_weight",
468 | type=float,
469 | default=0.5,
470 | help="CTC weight in joint decoding",
471 | )
472 | group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
473 | group.add_argument("--streaming", type=str2bool, default=False)
474 |
475 | group = parser.add_argument_group("Text converter related")
476 | group.add_argument(
477 | "--token_type",
478 | type=str_or_none,
479 | default=None,
480 | choices=["char", "bpe", None],
481 | help="The token type for ASR model. "
482 | "If not given, refers from the training args",
483 | )
484 | group.add_argument(
485 | "--bpemodel",
486 | type=str_or_none,
487 | default=None,
488 | help="The model path of sentencepiece. "
489 | "If not given, refers from the training args",
490 | )
491 | group.add_argument(
492 | "--need_decode",
493 |         type=str, help="Whether to run decoding ('true'/'false'); converted with str2bool in main()",
494 | )
495 | group.add_argument(
496 | "--selected_num",
497 | type=int,
498 | default=None,
499 | )
500 | group.add_argument(
501 | "--dataset",
502 | type=str,
503 | default=None,
504 | )
505 |
506 |
507 | return parser
508 |
509 |
510 |
511 | def main(cmd=None):
512 | print(get_commandline_args(), file=sys.stderr)
513 | parser = get_parser()
514 | args = parser.parse_args(cmd)
515 | args.need_decode = str2bool(args.need_decode)
516 | selected_num = args.selected_num
517 | dataset = args.dataset
518 | kwargs = vars(args)
519 | kwargs.pop("config", None)
520 | kwargs.pop("selected_num")
521 | kwargs.pop("dataset")
522 |     test_set = args.data_path_and_name_and_type[0][0].split("/")[-1]
523 | if args.need_decode:
524 | sum_ginis, gini_lists = test(**kwargs)
525 | else:
526 |         if ("test" in str(args.data_path_and_name_and_type)) and ("gini" not in test_set):
527 | aug_type = test_set.split("-")[1]
528 | gini_audio = sort_test_by_gini(args.orig_dir + "/test-orig", args.orig_dir + "/" + test_set, aug_type, selected_num)
529 | test_data_prep("data/test-orig", "data/" + test_set, dataset, gini_audio, "gini", selected_num, test_set)
530 |
531 |
532 |
533 | if __name__ == "__main__":
534 | main()
535 |
--------------------------------------------------------------------------------
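asr_test.py accumulates a per-utterance "gini" score from the decoder outputs: caul_gini and cau_sum_ginis come from the star-imported gini utilities (espnet/utils/gini_utils.py and gini_guide.py), which are not included in this listing. Assuming the usual Gini-impurity definition over the decoder's per-step token posterior, a plausible stand-in looks like the sketch below; the real implementation may differ in normalization or in which tensor it receives:

    import torch

    def gini_impurity(logprobs: torch.Tensor) -> float:
        # logprobs: (n_vocab,) log-softmax scores of one decoding step (an assumption
        # about what `outps` contains). The Gini impurity 1 - sum(p_i^2) is near 0 for
        # a confident prediction and near 1 - 1/n_vocab for a uniform (uncertain) one.
        p = logprobs.exp()
        return float(1.0 - torch.sum(p * p))

A hypothesis with a high accumulated gini is one the model was uncertain about at many steps, which is what the gini-based test selection in main() keys on.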
/espnet/nets/test_beam_search.py:
--------------------------------------------------------------------------------
1 | """Beam search module."""
2 |
3 | from itertools import chain
4 | import logging
5 | from typing import Any
6 | from typing import Dict
7 | from typing import List
8 | from typing import NamedTuple
9 | from typing import Tuple
10 | from typing import Union
11 | import numpy as np
12 | import torch
13 |
14 | from espnet.nets.e2e_asr_common import end_detect
15 | from espnet.nets.scorer_interface import PartialScorerInterface
16 | from espnet.nets.scorer_interface import ScorerInterface
17 | from espnet.utils.gini_utils import *
18 |
19 |
20 | class TestHypothesis(NamedTuple):
21 | """Test Hypothesis data type."""
22 |
23 | yseq: torch.Tensor
24 | score: Union[float, torch.Tensor] = 0
25 | scores: Dict[str, Union[float, torch.Tensor]] = dict()
26 | states: Dict[str, Any] = dict()
27 | def asdict(self) -> dict:
28 | """Convert data to JSON-friendly dict."""
29 | return self._replace(
30 | yseq=self.yseq.tolist(),
31 | score=float(self.score),
32 | scores={k: float(v) for k, v in self.scores.items()},
33 | )._asdict()
34 |
35 |
36 | class TestBeamSearch(torch.nn.Module):
37 | """Test Beam search implementation."""
38 |
39 | def __init__(
40 | self,
41 | scorers: Dict[str, ScorerInterface],
42 | weights: Dict[str, float],
43 | beam_size: int,
44 | vocab_size: int,
45 | sos: int,
46 | eos: int,
47 | token_list: List[str] = None,
48 | pre_beam_ratio: float = 1.5,
49 | pre_beam_score_key: str = None,
50 | ):
51 | """Initialize beam search.
52 |
53 | Args:
54 | scorers (dict[str, ScorerInterface]): Dict of decoder modules
55 | e.g., Decoder, CTCPrefixScorer, LM
56 | The scorer will be ignored if it is `None`
57 | weights (dict[str, float]): Dict of weights for each scorers
58 | The scorer will be ignored if its weight is 0
59 | beam_size (int): The number of hypotheses kept during search
60 | vocab_size (int): The number of vocabulary
61 | sos (int): Start of sequence id
62 | eos (int): End of sequence id
63 | token_list (list[str]): List of tokens for debug log
64 | pre_beam_score_key (str): key of scores to perform pre-beam search
65 | pre_beam_ratio (float): beam size in the pre-beam search
66 | will be `int(pre_beam_ratio * beam_size)`
67 |
68 | """
69 | super().__init__()
70 | # set scorers
71 | self.weights = weights
72 | self.scorers = dict()
73 | self.full_scorers = dict()
74 | self.part_scorers = dict()
75 | # this module dict is required for recursive cast
76 | # `self.to(device, dtype)` in `recog.py`
77 | self.nn_dict = torch.nn.ModuleDict()
78 | for k, v in scorers.items():
79 | w = weights.get(k, 0)
80 | if w == 0 or v is None:
81 | continue
82 | assert isinstance(
83 | v, ScorerInterface
84 | ), f"{k} ({type(v)}) does not implement ScorerInterface"
85 | self.scorers[k] = v
86 | if isinstance(v, PartialScorerInterface):
87 | self.part_scorers[k] = v
88 | else:
89 | self.full_scorers[k] = v
90 | if isinstance(v, torch.nn.Module):
91 | self.nn_dict[k] = v
92 |
93 | # set configurations
94 | self.sos = sos
95 | self.eos = eos
96 | self.token_list = token_list
97 | self.pre_beam_size = int(pre_beam_ratio * beam_size)
98 | self.beam_size = beam_size
99 | self.n_vocab = vocab_size
100 | if (
101 | pre_beam_score_key is not None
102 | and pre_beam_score_key != "full"
103 | and pre_beam_score_key not in self.full_scorers
104 | ):
105 | raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
106 | self.pre_beam_score_key = pre_beam_score_key
107 | self.do_pre_beam = (
108 | self.pre_beam_score_key is not None
109 | and self.pre_beam_size < self.n_vocab
110 | and len(self.part_scorers) > 0
111 | )
112 |
113 | def init_hyp(self, x: torch.Tensor) -> List[TestHypothesis]:
114 | """Get an initial hypothesis data.
115 |
116 | Args:
117 | x (torch.Tensor): The encoder output feature
118 |
119 | Returns:
120 | Hypothesis: The initial hypothesis.
121 |
122 | """
123 | init_states = dict()
124 | init_scores = dict()
125 | for k, d in self.scorers.items():
126 | init_states[k] = d.init_state(x)
127 | init_scores[k] = 0.0
128 | return [
129 | TestHypothesis(
130 | score=0.0,
131 | scores=init_scores,
132 | states=init_states,
133 | yseq=torch.tensor([self.sos], device=x.device),
134 | )
135 | ]
136 |
137 | @staticmethod
138 | def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
139 | """Append new token to prefix tokens.
140 |
141 | Args:
142 | xs (torch.Tensor): The prefix token
143 | x (int): The new token to append
144 |
145 | Returns:
146 | torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
147 |
148 | """
149 | x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
150 | #logging.info(f"call append_token,get {torch.cat((xs, x))}")
151 | return torch.cat((xs, x))
152 |
153 | def score_full(
154 | self, hyp: TestHypothesis, x: torch.Tensor
155 |     ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any], Any]:
156 | """Score new hypothesis by `self.full_scorers`.
157 | call rnn decoder
158 | Args:
159 | hyp (Hypothesis): Hypothesis with prefix tokens to score
160 | x (torch.Tensor): Corresponding input feature
161 |
162 | Returns:
163 |             Tuple[Dict[str, torch.Tensor], Dict[str, Any], Any]: Tuple of
164 |                 the score dict of `hyp` (string keys of `self.full_scorers`,
165 |                 tensor score values of shape `(self.n_vocab,)`),
166 |                 the state dict of `self.full_scorers`, and the raw decoder
167 |                 outputs (`outps`) used for the gini computation
168 |
169 | """
170 | #logging.info("call score_full")
171 | scores = dict()
172 | states = dict()
173 | """
174 | full_scorers:{'decoder': RNNDecoder(
175 | (embed): Embedding(42, 320)
176 | (dropout_emb): Dropout(p=0.0, inplace=False)
177 | (decoder): ModuleList(
178 | (0): LSTMCell(640, 320)
179 | )
180 | (dropout_dec): ModuleList(
181 | (0): Dropout(p=0.0, inplace=False)
182 | )
183 | (output): Linear(in_features=320, out_features=42, bias=True)
184 | (att_list): ModuleList(
185 | (0): AttLoc(
186 | (mlp_enc): Linear(in_features=320, out_features=320, bias=True)
187 | (mlp_dec): Linear(in_features=320, out_features=320, bias=False)
188 | (mlp_att): Linear(in_features=10, out_features=320, bias=False)
189 | (loc_conv): Conv2d(1, 10, kernel_size=(1, 201), stride=(1, 1), padding=(0, 100), bias=False)
190 | (gvec): Linear(in_features=320, out_features=1, bias=True)
191 | )
192 | )
193 | )}
194 | """
195 | for k, d in self.full_scorers.items():
196 |             scores[k], states[k], outps = d.score(hyp.yseq, hyp.states[k], x)  # each full scorer must return (score, state, outputs)
197 | #logging.info("score_full get {scores},{outps}")
198 | return scores, states, outps
199 |
200 | def score_partial(
201 | self, hyp: TestHypothesis, ids: torch.Tensor, x: torch.Tensor
202 | ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
203 | """Score new hypothesis by `self.part_scorers`.
204 |
205 | Args:
206 | hyp (Hypothesis): Hypothesis with prefix tokens to score
207 | ids (torch.Tensor): 1D tensor of new partial tokens to score
208 | x (torch.Tensor): Corresponding input feature
209 |
210 | Returns:
211 | Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
212 | score dict of `hyp` that has string keys of `self.part_scorers`
213 | and tensor score values of shape: `(len(ids),)`,
214 | and state dict that has string keys
215 | and state values of `self.part_scorers`
216 |
217 | """
218 | scores = dict()
219 | states = dict()
220 | for k, d in self.part_scorers.items():
221 | scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
222 | return scores, states
223 |
224 | def beam(
225 | self, weighted_scores: torch.Tensor, ids: torch.Tensor
226 | ) -> Tuple[torch.Tensor, torch.Tensor]:
227 | """Compute topk full token ids and partial token ids.
228 |
229 | Args:
230 | weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
231 | Its shape is `(self.n_vocab,)`.
232 | ids (torch.Tensor): The partial token ids to compute topk
233 |
234 | Returns:
235 | Tuple[torch.Tensor, torch.Tensor]:
236 | The topk full token ids and partial token ids.
237 | Their shapes are `(self.beam_size,)`
238 |
239 | """
240 | # no pre beam performed
241 | if weighted_scores.size(0) == ids.size(0):
242 | top_ids = weighted_scores.topk(self.beam_size)[1]
243 | return top_ids, top_ids
244 |
245 | # mask pruned in pre-beam not to select in topk
246 | tmp = weighted_scores[ids]
247 | weighted_scores[:] = -float("inf")
248 | weighted_scores[ids] = tmp
249 | top_ids = weighted_scores.topk(self.beam_size)[1]
250 | local_ids = weighted_scores[ids].topk(self.beam_size)[1]
251 | return top_ids, local_ids
252 |
253 | @staticmethod
254 | def merge_scores(
255 | prev_scores: Dict[str, float],
256 | next_full_scores: Dict[str, torch.Tensor],
257 | full_idx: int,
258 | next_part_scores: Dict[str, torch.Tensor],
259 | part_idx: int,
260 | ) -> Dict[str, torch.Tensor]:
261 | """Merge scores for new hypothesis.
262 |
263 | Args:
264 | prev_scores (Dict[str, float]):
265 | The previous hypothesis scores by `self.scorers`
266 | next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
267 | full_idx (int): The next token id for `next_full_scores`
268 | next_part_scores (Dict[str, torch.Tensor]):
269 | scores of partial tokens by `self.part_scorers`
270 | part_idx (int): The new token id for `next_part_scores`
271 |
272 | Returns:
273 | Dict[str, torch.Tensor]: The new score dict.
274 | Its keys are names of `self.full_scorers` and `self.part_scorers`.
275 | Its values are scalar tensors by the scorers.
276 |
277 | """
278 | new_scores = dict()
279 | for k, v in next_full_scores.items():
280 | new_scores[k] = prev_scores[k] + v[full_idx]
281 | for k, v in next_part_scores.items():
282 | new_scores[k] = prev_scores[k] + v[part_idx]
283 | return new_scores
284 |
285 | def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
286 | """Merge states for new hypothesis.
287 |
288 | Args:
289 | states: states of `self.full_scorers`
290 | part_states: states of `self.part_scorers`
291 | part_idx (int): The new token id for `part_scores`
292 |
293 | Returns:
294 | Dict[str, torch.Tensor]: The new score dict.
295 | Its keys are names of `self.full_scorers` and `self.part_scorers`.
296 | Its values are states of the scorers.
297 |
298 | """
299 | new_states = dict()
300 | for k, v in states.items():
301 | new_states[k] = v
302 | for k, d in self.part_scorers.items():
303 | new_states[k] = d.select_state(part_states[k], part_idx)
304 | return new_states
305 |
306 | def search(
307 | self, running_hyps: List[TestHypothesis], x: torch.Tensor
308 |     ) -> Tuple[List[TestHypothesis], Dict[str, Any]]:
309 | """Search new tokens for running hypotheses and encoded speech x.
310 |
311 | Args:
312 | running_hyps (List[Hypothesis]): Running hypotheses on beam
313 | x (torch.Tensor): Encoded speech feature (T, D)
314 |
315 | Returns:
316 |             Tuple[List[TestHypothesis], Dict[str, Any]]: Best sorted hypotheses and the gini score of each newly scored prefix
317 |
318 | """
319 | #logging.info("call search")
320 | best_hyps = []
321 | new_ginis = {}
322 | part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
323 | for hyp in running_hyps:
324 | # scoring
325 | weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
326 | scores, states, outps = self.score_full(hyp, x)
327 | for k in self.full_scorers:
328 | weighted_scores += self.weights[k] * scores[k]
329 |             yseq = hyp.yseq.cpu().numpy()
330 | if len(yseq) > 1:
331 | yseq = yseq[1:]
332 | gini = caul_gini(outps)
333 |             yseq_key = " ".join(str(seq) for seq in yseq)
334 | new_ginis[yseq_key] = gini
335 |
336 | # partial scoring
337 | if self.do_pre_beam:
338 | pre_beam_scores = (
339 | weighted_scores
340 | if self.pre_beam_score_key == "full"
341 | else scores[self.pre_beam_score_key]
342 | )
343 | part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
344 | part_scores, part_states = self.score_partial(hyp, part_ids, x)
345 | for k in self.part_scorers:
346 | weighted_scores[part_ids] += self.weights[k] * part_scores[k]
347 | # add previous hyp score
348 | weighted_scores += hyp.score
349 | # update hyps
350 | for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
351 | # will be (2 x beam at most)
352 | #logging.info(f"{j},{part_j},{hyp.yseq}")
353 | best_hyps.append(
354 | TestHypothesis(
355 | score=weighted_scores[j],
356 | yseq=self.append_token(hyp.yseq, j),
357 | scores=self.merge_scores(
358 | hyp.scores, scores, j, part_scores, part_j
359 | ),
360 | states=self.merge_states(states, part_states, part_j),
361 | )
362 | )
363 |
364 | # sort and prune 2 x beam -> beam
365 | best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
366 | : min(len(best_hyps), self.beam_size)
367 | ]
368 |         # NOTE: disabled; `last_seq` was computed here but never used.
369 |         # for best in best_hyps: last_seq = best.yseq.cpu().numpy()[-1]
370 |
371 |         # best_hyps is sorted by total score in descending order,
372 |         # i.e. the most probable hypothesis comes first.
373 |         return best_hyps, new_ginis
374 | 
375 |
376 | def forward(
377 | self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
378 |     ) -> Tuple[List[TestHypothesis], Dict[str, Any]]:
379 | #logging.info("call forward")
380 | """Perform beam search.
381 |
382 | Args:
383 | x (torch.Tensor): Encoded speech feature (T, D)
384 | maxlenratio (float): Input length ratio to obtain max output length.
385 |             If maxlenratio=0.0 (default), it uses an end-detect function
386 | to automatically find maximum hypothesis lengths
387 | minlenratio (float): Input length ratio to obtain min output length.
388 |
389 | Returns:
390 |             Tuple[List[TestHypothesis], Dict[str, Any]]: N-best decoding results and per-prefix gini scores
391 |
392 | """
393 | # set length bounds
394 | if maxlenratio == 0:
395 | maxlen = x.shape[0]
396 | else:
397 | maxlen = max(1, int(maxlenratio * x.size(0)))
398 | minlen = int(minlenratio * x.size(0))
399 | logging.info("decoder input length: " + str(x.shape[0]))
400 | logging.info("max output length: " + str(maxlen))
401 | logging.info("min output length: " + str(minlen))
402 |
403 | # main loop of prefix search
404 | running_hyps = self.init_hyp(x)
405 | ended_hyps = []
406 | all_ginis = {}
407 | for i in range(maxlen):
408 | best, new_ginis = self.search(running_hyps, x)
409 | if len(new_ginis) > 0:
410 | all_ginis.update(new_ginis)
411 | # post process of one iteration
412 | running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
413 | # end detection
414 | if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
415 | logging.info(f"end detected at {i}")
416 | break
417 | if len(running_hyps) == 0:
418 | logging.info("no hypothesis. Finish decoding.")
419 | break
420 | else:
421 | logging.debug(f"remained hypotheses: {len(running_hyps)}")
422 | nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
423 | # check the number of hypotheses reaching to eos
424 | if len(nbest_hyps) == 0:
425 | logging.warning(
426 |                 "there are no N-best results; performing recognition "
427 | "again with smaller minlenratio."
428 | )
429 | return (
430 | []
431 | if minlenratio < 0.1
432 | else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
433 | )
434 |
435 | # report the best result
436 | best = nbest_hyps[0]
437 | for k, v in best.scores.items():
438 | logging.info(
439 | f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
440 | )
441 | logging.info(f"total log probability: {best.score:.2f}")
442 | logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
443 | logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
444 | logging.info(f"yseq: {best.yseq.shape}")
445 | if self.token_list is not None:
446 | logging.info(
447 | "best hypo: "
448 | + "".join([self.token_list[x] for x in best.yseq[1:-1]])
449 | + "\n"
450 | )
451 | #logging.info(f"call forward,get {nbest_hyps[0]}")
452 |         logging.info(f"forward returns nbest_hyps with length {len(nbest_hyps)}")
453 | return nbest_hyps, all_ginis
454 |
455 |
456 |
457 | def post_process(
458 | self,
459 | i: int,
460 | maxlen: int,
461 | maxlenratio: float,
462 | running_hyps: List[TestHypothesis],
463 | ended_hyps: List[TestHypothesis],
464 | ) -> List[TestHypothesis]:
465 | """Perform post-processing of beam search iterations.
466 | Args:
467 | i (int): The length of hypothesis tokens.
468 | maxlen (int): The maximum length of tokens in beam search.
469 | maxlenratio (int): The maximum length ratio in beam search.
470 | running_hyps (List[Hypothesis]): The running hypotheses in beam search.
471 | ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
472 |
473 | Returns:
474 | List[Hypothesis]: The new running hypotheses.
475 |
476 | """
477 | #logging.info("call post process")
478 | logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
479 | if self.token_list is not None:
480 | logging.debug(
481 | "best hypo: "
482 | + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
483 | )
484 |         # add eos in the final loop so that at least one hypothesis ends
485 | if i == maxlen - 1:
486 |             logging.info("adding <eos> in the last position in the loop")
487 | running_hyps = [
488 | h._replace(yseq=self.append_token(h.yseq, self.eos))
489 | for h in running_hyps
490 | ]
491 |
492 |         # add ended hypotheses to a final list, and remove them from the current hypotheses
493 | # (this will be a problem, number of hyps < beam)
494 | remained_hyps = []
495 | for hyp in running_hyps:
496 | if hyp.yseq[-1] == self.eos:
497 | # e.g., Word LM needs to add final score
498 | for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
499 | s = d.final_score(hyp.states[k])
500 | hyp.scores[k] += s
501 | hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
502 | ended_hyps.append(hyp)
503 | else:
504 | remained_hyps.append(hyp)
505 | return remained_hyps
506 |
507 |
508 | def beam_search(
509 | x: torch.Tensor,
510 | sos: int,
511 | eos: int,
512 | beam_size: int,
513 | vocab_size: int,
514 | scorers: Dict[str, ScorerInterface],
515 | weights: Dict[str, float],
516 | token_list: List[str] = None,
517 | maxlenratio: float = 0.0,
518 | minlenratio: float = 0.0,
519 | pre_beam_ratio: float = 1.5,
520 | pre_beam_score_key: str = "full",
521 | ) -> list:
522 | """Perform beam search with scorers.
523 |
524 | Args:
525 | x (torch.Tensor): Encoded speech feature (T, D)
526 | sos (int): Start of sequence id
527 | eos (int): End of sequence id
528 | beam_size (int): The number of hypotheses kept during search
529 | vocab_size (int): The number of vocabulary
530 | scorers (dict[str, ScorerInterface]): Dict of decoder modules
531 | e.g., Decoder, CTCPrefixScorer, LM
532 | The scorer will be ignored if it is `None`
533 | weights (dict[str, float]): Dict of weights for each scorers
534 | The scorer will be ignored if its weight is 0
535 | token_list (list[str]): List of tokens for debug log
536 | maxlenratio (float): Input length ratio to obtain max output length.
537 |             If maxlenratio=0.0 (default), it uses an end-detect function
538 | to automatically find maximum hypothesis lengths
539 | minlenratio (float): Input length ratio to obtain min output length.
540 | pre_beam_score_key (str): key of scores to perform pre-beam search
541 | pre_beam_ratio (float): beam size in the pre-beam search
542 | will be `int(pre_beam_ratio * beam_size)`
543 |
544 | Returns:
545 | list: N-best decoding results
546 |
547 | """
548 |     ret, _ = TestBeamSearch(  # BeamSearch is not imported here; use TestBeamSearch and drop the gini dict
549 | scorers,
550 | weights,
551 | beam_size=beam_size,
552 | vocab_size=vocab_size,
553 | pre_beam_ratio=pre_beam_ratio,
554 | pre_beam_score_key=pre_beam_score_key,
555 | sos=sos,
556 | eos=eos,
557 | token_list=token_list,
558 | ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
559 | return [h.asdict() for h in ret]
560 |
561 |
--------------------------------------------------------------------------------
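TestBeamSearch mirrors espnet's standard BeamSearch but returns an extra dictionary: forward() yields (nbest_hyps, all_ginis), where all_ginis maps a space-joined token-id prefix to the gini value computed from the decoder outputs when that prefix was scored. A minimal usage sketch, assuming scorers, weights, token_list, sos/eos ids and an encoder output enc have already been built (e.g. as in Speech2Text.__init__ above):

    from espnet.nets.test_beam_search import TestBeamSearch

    test_beam_search = TestBeamSearch(
        scorers=scorers,      # e.g. {"decoder": ..., "ctc": ..., "length_bonus": ...}
        weights=weights,      # e.g. {"decoder": 0.5, "ctc": 0.5, "length_bonus": 0.0}
        beam_size=20,
        vocab_size=len(token_list),
        sos=sos,
        eos=eos,
        token_list=token_list,
        pre_beam_score_key="full",
    )
    nbest_hyps, all_ginis = test_beam_search(x=enc[0], maxlenratio=0.0, minlenratio=0.0)
    # all_ginis is keyed by space-joined token-id prefixes, e.g. {"1 5 9": gini, ...};
    # cau_sum_ginis() in asr_test.py uses it to compute the total gini of the best hypothesis.

Note that score_full() also returns the raw decoder outputs, so the decoder (presumably the modified espnet2/asr/decoder/rnn_decoder.py in this repository) is expected to return a third value from score().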
/TEMPLATE/asr1/asr_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Set bash to 'debug' mode, it will exit on :
4 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
5 | set -e
6 | set -u
7 | set -o pipefail
8 |
9 | log() {
10 | local fname=${BASH_SOURCE[1]##*/}
11 | echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
12 | }
13 | min() {
14 | local a b
15 | a=$1
16 | for b in "$@"; do
17 | if [ "${b}" -le "${a}" ]; then
18 | a="${b}"
19 | fi
20 | done
21 | echo "${a}"
22 | }
23 | SECONDS=0
24 |
25 | # General configuration
26 | stage=1              # Processing starts from the specified stage.
27 | stop_stage=10000     # Processing stops at the specified stage.
28 | skip_data_prep=false # Skip data preparation stages.
29 | skip_train=false # Skip training stages.
30 | skip_eval=false # Skip decoding and evaluation stages.
31 | skip_upload=true # Skip packing and uploading stages.
32 | ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu).
33 | num_nodes=1 # The number of nodes.
34 | nj=32 # The number of parallel jobs.
35 | inference_nj=32 # The number of parallel jobs in decoding.
36 | gpu_inference=false # Whether to perform gpu decoding.
37 | dumpdir=dump # Directory to dump features.
38 | expdir=exp # Directory to save experiments.
39 | python=python3 # Specify python to execute espnet commands.
40 |
41 | # Data preparation related
42 | local_data_opts= # The options given to local/data.sh.
43 |
44 | # Speed perturbation related
45 | speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
46 |
47 | # Feature extraction related
48 | feats_type=raw # Feature type (raw or fbank_pitch).
49 | audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw).
50 | fs=16k # Sampling rate.
51 | min_wav_duration=0.1 # Minimum duration in second.
52 | max_wav_duration=20 # Maximum duration in second.
53 |
54 | # Tokenization related
55 | token_type=bpe # Tokenization type (char or bpe).
56 | nbpe=30 # The number of BPE vocabulary.
57 | bpemode=unigram # Mode of BPE (unigram or bpe).
58 | oov="<unk>"               # Out of vocabulary symbol.
59 | blank="<blank>"           # CTC blank symbol
60 | sos_eos="<sos/eos>"       # sos and eos symbols
61 | bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
62 | bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE
63 | bpe_char_cover=1.0 # character coverage when modeling BPE
64 |
65 | # Language model related
66 | use_lm=false # Use language model for ASR decoding.
67 | lm_tag= # Suffix to the result dir for language model training.
68 | lm_exp=       # Specify the directory path for LM experiment.
69 | # If this option is specified, lm_tag is ignored.
70 | lm_stats_dir= # Specify the directory path for LM statistics.
71 | lm_config= # Config for language model training.
72 | lm_args= # Arguments for language model training, e.g., "--max_epoch 10".
73 | # Note that it will overwrite args in lm config.
74 | use_word_lm=false # Whether to use word language model.
75 | num_splits_lm=1 # Number of splitting for lm corpus.
76 | # shellcheck disable=SC2034
77 | word_vocab_size=10000 # Size of word vocabulary.
78 |
79 | # ASR model related
80 | asr_tag= # Suffix to the result dir for asr model training.
81 | asr_exp=       # Specify the directory path for ASR experiment.
82 | # If this option is specified, asr_tag is ignored.
83 | asr_stats_dir= # Specify the directory path for ASR statistics.
84 | asr_config= # Config for asr model training.
85 | asr_args= # Arguments for asr model training, e.g., "--max_epoch 10".
86 | # Note that it will overwrite args in asr config.
87 | feats_normalize=global_mvn # Normalization layer type.
88 | num_splits_asr=1 # Number of splitting for lm corpus.
89 |
90 | # Decoding related
91 | inference_tag= # Suffix to the result dir for decoding.
92 | inference_config= # Config for decoding.
93 | inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1".
94 | # Note that it will overwrite args in inference config.
95 | inference_lm=valid.loss.ave.pth       # Language model path for decoding.
96 | inference_asr_model=valid.acc.ave.pth # ASR model path for decoding.
97 | # e.g.
98 | # inference_asr_model=train.loss.best.pth
99 | # inference_asr_model=3epoch.pth
100 | # inference_asr_model=valid.acc.best.pth
101 | # inference_asr_model=valid.loss.ave.pth
102 | download_model= # Download a model from Model Zoo and use it for decoding.
103 |
104 | # [Task dependent] Set the datadir name created by local/data.sh
105 | train_set= # Name of training set.
106 | valid_set= # Name of validation set used for monitoring/tuning network training.
107 | test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
108 | bpe_train_text= # Text file path of bpe training set.
109 | lm_train_text= # Text file path of language model training set.
110 | lm_dev_text= # Text file path of language model development set.
111 | lm_test_text= # Text file path of language model evaluation set.
112 | nlsyms_txt=none # Non-linguistic symbol list if existing.
113 | cleaner=none # Text cleaner.
114 | g2p=none # g2p method (needed if token_type=phn).
115 | lang=en # The language type of corpus.
116 | score_opts= # The options given to sclite scoring
117 | local_score_opts= # The options given to local/score.sh.
118 | asr_speech_fold_length=800 # fold_length for speech data during ASR training.
119 | asr_text_fold_length=150 # fold_length for text data during ASR training.
120 | lm_fold_length=150 # fold_length for LM training.
121 |
122 | # gini_related
123 | orig_flag=
124 | need_decode=
125 | selected_num=
126 | dataset=
127 | help_message=$(cat << EOF
128 | Usage: $0 --train-set "<train_set_name>" --valid-set "<valid_set_name>" --test_sets "<test_set_names>"
129 |
130 | Options:
131 | # General configuration
132 |     --stage # Processing starts from the specified stage (default="${stage}").
133 |     --stop_stage # Processing stops at the specified stage (default="${stop_stage}").
134 | --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
135 | --skip_train # Skip training stages (default="${skip_train}").
136 | --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}").
137 | --skip_upload # Skip packing and uploading stages (default="${skip_upload}").
138 | --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
139 | --num_nodes # The number of nodes (default="${num_nodes}").
140 | --nj # The number of parallel jobs (default="${nj}").
141 | --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}").
142 | --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}").
143 | --dumpdir # Directory to dump features (default="${dumpdir}").
144 | --expdir # Directory to save experiments (default="${expdir}").
145 | --python # Specify python to execute espnet commands (default="${python}").
146 |
147 | # Data preparation related
148 | --local_data_opts # The options given to local/data.sh (default="${local_data_opts}").
149 |
150 | # Speed perturbation related
151 | --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").
152 |
153 | # Feature extraction related
154 | --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
155 | --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}").
156 | --fs # Sampling rate (default="${fs}").
157 | --min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
158 | --max_wav_duration # Maximum duration in second (default="${max_wav_duration}").
159 |
160 | # Tokenization related
161 | --token_type # Tokenization type (char or bpe, default="${token_type}").
162 | --nbpe # The number of BPE vocabulary (default="${nbpe}").
163 | --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}").
164 | --oov # Out of vocabulary symbol (default="${oov}").
165 | --blank # CTC blank symbol (default="${blank}").
166 |     --sos_eos # sos and eos symbols (default="${sos_eos}").
167 | --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
168 | --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
169 | --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}").
170 |
171 | # Language model related
172 | --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}").
173 |     --lm_exp # Specify the directory path for LM experiment.
174 | # If this option is specified, lm_tag is ignored (default="${lm_exp}").
175 |     --lm_stats_dir # Specify the directory path for LM statistics (default="${lm_stats_dir}").
176 | --lm_config # Config for language model training (default="${lm_config}").
177 | --lm_args # Arguments for language model training (default="${lm_args}").
178 | # e.g., --lm_args "--max_epoch 10"
179 | # Note that it will overwrite args in lm config.
180 | --use_word_lm # Whether to use word language model (default="${use_word_lm}").
181 | --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
182 | --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}").
183 |
184 | # ASR model related
185 | --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}").
186 |     --asr_exp # Specify the directory path for ASR experiment.
187 | # If this option is specified, asr_tag is ignored (default="${asr_exp}").
188 |     --asr_stats_dir # Specify the directory path for ASR statistics (default="${asr_stats_dir}").
189 | --asr_config # Config for asr model training (default="${asr_config}").
190 | --asr_args # Arguments for asr model training (default="${asr_args}").
191 | # e.g., --asr_args "--max_epoch 10"
192 | # Note that it will overwrite args in asr config.
193 |     --feats_normalize # Normalization layer type (default="${feats_normalize}").
194 | --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}").
195 |
196 | # Decoding related
197 | --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}").
198 | --inference_config # Config for decoding (default="${inference_config}").
199 | --inference_args # Arguments for decoding (default="${inference_args}").
200 | # e.g., --inference_args "--lm_weight 0.1"
201 | # Note that it will overwrite args in inference config.
202 |     --inference_lm # Language model path for decoding (default="${inference_lm}").
203 | --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
204 | --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}").
205 |
206 | # [Task dependent] Set the datadir name created by local/data.sh
207 | --train_set # Name of training set (required).
208 | --valid_set # Name of validation set used for monitoring/tuning network training (required).
209 | --test_sets # Names of test sets.
210 | # Multiple items (e.g., both dev and eval sets) can be specified (required).
211 | --bpe_train_text # Text file path of bpe training set.
212 | --lm_train_text # Text file path of language model training set.
213 | --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}").
214 | --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}").
215 | --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
216 | --cleaner # Text cleaner (default="${cleaner}").
217 | --g2p # g2p method (default="${g2p}").
218 | --lang # The language type of corpus (default=${lang}).
219 |     --score_opts # The options given to sclite scoring (default="${score_opts}").
220 |     --local_score_opts # The options given to local/score.sh (default="${local_score_opts}").
221 | --asr_speech_fold_length # fold_length for speech data during ASR training (default="${asr_speech_fold_length}").
222 | --asr_text_fold_length # fold_length for text data during ASR training (default="${asr_text_fold_length}").
223 | --lm_fold_length # fold_length for LM training (default="${lm_fold_length}").
224 | EOF
225 | )
226 |
227 | # log "asr_test $0 $*"
228 | # Save command line args for logging (they will be lost after utils/parse_options.sh)
229 | run_args=$(pyscripts/utils/print_args.py $0 "$@")
230 | . utils/parse_options.sh
231 |
232 | if [ $# -ne 0 ]; then
233 | log "${help_message}"
234 | log "Error: No positional arguments are required."
235 | exit 2
236 | fi
237 |
238 | . ./path.sh
239 | . ./cmd.sh
240 |
241 |
242 | # Check required arguments
243 | [ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
244 |
245 | # Check feature type
246 | if [ "${feats_type}" = raw ]; then
247 | data_feats=${dumpdir}/raw
248 | elif [ "${feats_type}" = fbank_pitch ]; then
249 | data_feats=${dumpdir}/fbank_pitch
250 | elif [ "${feats_type}" = fbank ]; then
251 | data_feats=${dumpdir}/fbank
252 | elif [ "${feats_type}" == extracted ]; then
253 | data_feats=${dumpdir}/extracted
254 | else
255 | log "${help_message}"
256 | log "Error: not supported: --feats_type ${feats_type}"
257 | exit 2
258 | fi
259 |
260 | # Use the text of the 1st evaldir if lm_test is not specified
261 | [ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
262 |
263 | # Check tokenization type
264 | if [ "${lang}" != noinfo ]; then
265 | token_listdir=data/${lang}_token_list
266 | else
267 | token_listdir=data/token_list
268 | fi
269 |
270 | bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
271 | bpeprefix="${bpedir}"/bpe
272 | bpemodel="${bpeprefix}".model
273 | bpetoken_list="${bpedir}"/tokens.txt
274 | chartoken_list="${token_listdir}"/char/tokens.txt
275 | # NOTE: keep for future development.
276 | # shellcheck disable=SC2034
277 | wordtoken_list="${token_listdir}"/word/tokens.txt
278 |
279 | if [ "${token_type}" = bpe ]; then
280 | token_list="${bpetoken_list}"
281 | elif [ "${token_type}" = char ]; then
282 | token_list="${chartoken_list}"
283 | bpemodel=none
284 | elif [ "${token_type}" = word ]; then
285 | token_list="${wordtoken_list}"
286 | bpemodel=none
287 | else
288 | log "Error: not supported --token_type '${token_type}'"
289 | exit 2
290 | fi
291 | if ${use_word_lm}; then
292 | log "Error: Word LM is not supported yet"
293 | exit 2
294 |
295 | lm_token_list="${wordtoken_list}"
296 | lm_token_type=word
297 | else
298 | lm_token_list="${token_list}"
299 | lm_token_type="${token_type}"
300 | fi
301 |
302 |
303 | # Set tag for naming of model directory
304 | if [ -z "${asr_tag}" ]; then
305 | if [ -n "${asr_config}" ]; then
306 | asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
307 | else
308 | asr_tag="train_${feats_type}"
309 | fi
310 | if [ "${lang}" != noinfo ]; then
311 | asr_tag+="_${lang}_${token_type}"
312 | else
313 | asr_tag+="_${token_type}"
314 | fi
315 | if [ "${token_type}" = bpe ]; then
316 | asr_tag+="${nbpe}"
317 | fi
318 | # Add overwritten arg's info
319 | if [ -n "${asr_args}" ]; then
320 | asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
321 | fi
322 | if [ -n "${speed_perturb_factors}" ]; then
323 | asr_tag+="_sp"
324 | fi
325 | fi
326 | echo "${lang}, ${asr_tag}"
327 | if [ -z "${lm_tag}" ]; then
328 | if [ -n "${lm_config}" ]; then
329 | lm_tag="$(basename "${lm_config}" .yaml)"
330 | else
331 | lm_tag="train"
332 | fi
333 | if [ "${lang}" != noinfo ]; then
334 | lm_tag+="_${lang}_${lm_token_type}"
335 | else
336 | lm_tag+="_${lm_token_type}"
337 | fi
338 | if [ "${lm_token_type}" = bpe ]; then
339 | lm_tag+="${nbpe}"
340 | fi
341 | # Add overwritten arg's info
342 | if [ -n "${lm_args}" ]; then
343 | lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
344 | fi
345 | fi
346 |
347 | # The directory used for collect-stats mode
348 | if [ -z "${asr_stats_dir}" ]; then
349 | if [ "${lang}" != noinfo ]; then
350 | asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
351 | else
352 | asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
353 | fi
354 | if [ "${token_type}" = bpe ]; then
355 | asr_stats_dir+="${nbpe}"
356 | fi
357 | if [ -n "${speed_perturb_factors}" ]; then
358 | asr_stats_dir+="_sp"
359 | fi
360 | fi
361 | if [ -z "${lm_stats_dir}" ]; then
362 | if [ "${lang}" != noinfo ]; then
363 | lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
364 | else
365 | lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
366 | fi
367 | if [ "${lm_token_type}" = bpe ]; then
368 | lm_stats_dir+="${nbpe}"
369 | fi
370 | fi
371 | # The directory used for training commands
372 | if [ -z "${asr_exp}" ]; then
373 | asr_exp="${expdir}/asr_${asr_tag}"
374 | fi
375 | if [ -z "${lm_exp}" ]; then
376 | lm_exp="${expdir}/lm_${lm_tag}"
377 | fi
378 |
379 |
380 | if [ -z "${inference_tag}" ]; then
381 | if [ -n "${inference_config}" ]; then
382 | inference_tag="$(basename "${inference_config}" .yaml)"
383 | else
384 | inference_tag=inference
385 | fi
386 | # Add overwritten arg's info
387 | if [ -n "${inference_args}" ]; then
388 | inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
389 | fi
390 | if "${use_lm}"; then
391 | inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
392 | fi
393 | inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
394 | fi
395 |
396 | # ========================== Main stages start from here. ==========================
397 |
398 | if ! "${skip_data_prep}"; then
399 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
400 | log "Stage 1: Data preparation for data/${test_sets}, etc."
401 | # [Task dependent] Need to create data.sh for new corpus
402 | local/test_data.sh ${local_data_opts}
403 | fi
404 |
405 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
406 | if [ "${feats_type}" = raw ]; then
407 | log "Stage 2: Format wav.scp: data/ -> ${data_feats}"
408 |
409 | # ====== Recreating "wav.scp" ======
410 | # A Kaldi wav.scp entry may describe its audio with a unix pipe, like "cat /some/path |",
411 | # which shouldn't be used in the training process.
412 | # "format_wav_scp.sh" dumps such pipe-style wav entries to real audio files,
413 | # and it can also change the audio format and sampling rate.
414 | # If nothing is needed, format_wav_scp.sh does nothing:
415 | # i.e. the input file format and rate are the same as the output.
416 |
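    | # Illustrative wav.scp entries (hypothetical utterance id and paths):
    | #   pipe style  : utt0001 sph2pipe -f wav /some/path/utt0001.sph |
    | #   after dump  : utt0001 dump/raw/test/utt0001.wav   (a real file with the target format/fs)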
417 | for dset in ${test_sets}; do
418 | _suf=""
419 | utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
420 | rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel}
421 | _opts=
422 | if [ -e data/"${dset}"/segments ]; then
423 | # "segments" is used for splitting wav files, which are listed in wav.scp,
424 | # into utterances. The file format of segments:
425 | #
426 | # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
427 | # Where the time is written in seconds.
428 | _opts+="--segments data/${dset}/segments "
429 | fi
430 | # shellcheck disable=SC2086
431 | scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
432 | --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
433 | "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
434 |
435 | echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
436 | done
437 |
438 | elif [ "${feats_type}" = fbank_pitch ]; then
439 | log "[Require Kaldi] Stage 2: ${feats_type} extract: data/ -> ${data_feats}"
440 |
441 | for dset in ${test_sets}; do
442 | _suf=""
443 | # 1. Copy datadir
444 | utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
445 |
446 | # 2. Feature extract
447 | _nj=$(min "${nj}" "$(<"${data_feats}${_suf}/${dset}/utt2spk" wc -l)")
448 | steps/make_fbank_pitch.sh --nj "${_nj}" --cmd "${train_cmd}" "${data_feats}${_suf}/${dset}"
449 | utils/fix_data_dir.sh "${data_feats}${_suf}/${dset}"
450 |
451 | # 3. Derive the frame length and feature dimension
452 | scripts/feats/feat_to_shape.sh --nj "${_nj}" --cmd "${train_cmd}" \
453 | "${data_feats}${_suf}/${dset}/feats.scp" "${data_feats}${_suf}/${dset}/feats_shape"
454 |
455 | # 4. Write feats_dim
456 | head -n 1 "${data_feats}${_suf}/${dset}/feats_shape" | awk '{ print $2 }' \
457 | | cut -d, -f2 > ${data_feats}${_suf}/${dset}/feats_dim
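    | # Illustrative (hypothetical id): a feats_shape line "utt0001 1045,83" means 1045 frames of
    | # 83-dim features, so the awk/cut pipeline above writes "83" to feats_dim.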
458 |
459 | # 5. Write feats_type
460 | echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
461 | done
462 |
463 | elif [ "${feats_type}" = fbank ]; then
464 | log "Stage 2: ${feats_type} extract: data/ -> ${data_feats}"
465 | log "${feats_type} is not supported yet."
466 | exit 1
467 |
468 | elif [ "${feats_type}" = extracted ]; then
469 | log "Stage 2: ${feats_type} extract: data/ -> ${data_feats}"
470 | # Assuming you don't have wav.scp, but feats.scp was created by local/test_data.sh instead.
471 |
472 | for dset in ${test_sets}; do
473 | _suf=""
474 | # Generate dummy wav.scp to avoid error by copy_data_dir.sh
475 | <data/"${dset}"/feats.scp awk '{ print($1, "<DUMMY>") }' > data/"${dset}"/wav.scp
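    | # Illustrative (hypothetical id): a feats.scp line "utt0001 feats.ark:12" produces the dummy
    | # wav.scp line "utt0001 <DUMMY>", which is enough to satisfy copy_data_dir.sh validation.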
476 | utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
477 |
478 | pyscripts/feats/feat-to-shape.py "scp:head -n 1 ${data_feats}${_suf}/${dset}/feats.scp |" - | \
479 | awk '{ print $2 }' | cut -d, -f2 > "${data_feats}${_suf}/${dset}/feats_dim"
480 |
481 | echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
482 | done
483 |
484 | else
485 | log "Error: not supported: --feats_type ${feats_type}"
486 | exit 2
487 | fi
488 | fi
489 |
490 | else
491 | log "Skip the stages for data preparation"
492 | fi
493 |
494 | # ========================== Data preparation is done here. ==========================
495 |
496 | #download_model="kamo-naoyuki/timit_asr_train_asr_raw_word_valid.acc.ave"
497 | if [ -n "${download_model}" ]; then
498 | log "Use ${download_model} for decoding and evaluation"
499 | asr_exp="${expdir}/${download_model}"
500 | log "asr_exp: ${asr_exp}"
501 | mkdir -p "${asr_exp}"
502 |
503 | # If the model already exists, you can skip downloading
504 | espnet_model_zoo_download --unpack true "${download_model}" > "${asr_exp}/config.txt"
505 |
506 | # Get the path of each file
507 | _asr_model_file=$(<"${asr_exp}/config.txt" sed -e "s/.*'asr_model_file': '\([^']*\)'.*$/\1/")
508 | _asr_train_config=$(<"${asr_exp}/config.txt" sed -e "s/.*'asr_train_config': '\([^']*\)'.*$/\1/")
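    | # Illustrative config.txt content (hypothetical cache paths): a dict-like line such as
    | #   {'asr_train_config': '/cache/xxx/config.yaml', 'asr_model_file': '/cache/xxx/valid.acc.ave.pth', ...}
    | # from which the sed expressions above extract each quoted path.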
509 |
510 | # Create symbolic links
511 | ln -sf "${_asr_model_file}" "${asr_exp}"
512 | ln -sf "${_asr_train_config}" "${asr_exp}"
513 | inference_asr_model=$(basename "${_asr_model_file}")
514 |
515 | if [ "$(<${asr_exp}/config.txt grep -c lm_file)" -gt 0 ]; then
516 | _lm_file=$(<"${asr_exp}/config.txt" sed -e "s/.*'lm_file': '\([^']*\)'.*$/\1/")
517 | _lm_train_config=$(<"${asr_exp}/config.txt" sed -e "s/.*'lm_train_config': '\([^']*\)'.*$/\1/")
518 |
519 | lm_exp="${expdir}/${download_model}/lm"
520 | mkdir -p "${lm_exp}"
521 |
522 | ln -sf "${_lm_file}" "${lm_exp}"
523 | ln -sf "${_lm_train_config}" "${lm_exp}"
524 | inference_lm=$(basename "${_lm_file}")
525 | fi
526 |
527 | fi
528 |
529 |
530 | if ! "${skip_eval}"; then
531 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
532 | log "Stage 3: Decoding: training_dir=${asr_exp}"
533 |
534 | if ${gpu_inference}; then
535 | _cmd="${cuda_cmd}"
536 | _ngpu=1
537 | else
538 | _cmd="${decode_cmd}"
539 | _ngpu=0
540 | fi
541 |
542 | _opts=
543 | if [ -n "${inference_config}" ]; then
544 | _opts+="--config ${inference_config} "
545 | fi
546 | if "${use_lm}"; then
547 | if "${use_word_lm}"; then
548 | _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
549 | _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
550 | else
551 | _opts+="--lm_train_config ${lm_exp}/config.yaml "
552 | _opts+="--lm_file ${lm_exp}/${inference_lm} "
553 | fi
554 | fi
555 |
556 | # Generate run.sh
557 | log "Generate '${asr_exp}/${inference_tag}/run.sh'. You can resume the process from stage 3 using this script"
558 | mkdir -p "${asr_exp}/${inference_tag}"; echo "${run_args} --stage 3 \"\$@\"; exit \$?" > "${asr_exp}/${inference_tag}/run.sh"; chmod +x "${asr_exp}/${inference_tag}/run.sh"
559 |
560 | for dset in ${test_sets}; do
561 | _data="${data_feats}/${dset}"
562 | _dir="${asr_exp}/${inference_tag}/${dset}"
563 | _logdir="${_dir}/logdir"
564 | mkdir -p "${_logdir}"
565 |
566 | if ! "${need_decode}"; then
567 | log "Skipping decoding (need_decode=false)"
568 | ${python} -m espnet2.bin.asr_test \
569 | --ngpu "${_ngpu}" \
570 | --data_path_and_name_and_type "${_data},speech,sound" \
571 | --key_file "${_logdir}"/keys.JOB.scp \
572 | --asr_train_config "${asr_exp}"/config.yaml \
573 | --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
574 | --orig_dir "${asr_exp}/${inference_tag}" \
575 | --orig_flag "${orig_flag}" \
576 | --need_decode "${need_decode}" \
577 | --output_dir "${_logdir}"/output.JOB \
578 | --selected_num ${selected_num} \
579 | --dataset ${dataset} \
580 | ${_opts} ${inference_args}
581 | else
582 | _feats_type="$(<${_data}/feats_type)"
583 | if [ "${_feats_type}" = raw ]; then
584 | _scp=wav.scp
585 | if [[ "${audio_format}" == *ark* ]]; then
586 | _type=kaldi_ark
587 | else
588 | _type=sound
589 | fi
590 | else
591 | _scp=feats.scp
592 | _type=kaldi_ark
593 | fi
594 | log "Decoding started... log: '${_logdir}/asr_test.*.log', data path and name: '${_data}/${_scp},speech,${_type}'"
595 | key_file=${_data}/${_scp}
596 | split_scps=""
597 | _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
598 | for n in $(seq "${_nj}"); do
599 | split_scps+=" ${_logdir}/keys.${n}.scp"
600 | done
601 | # shellcheck disable=SC2086
602 | utils/split_scp.pl "${key_file}" ${split_scps}
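    | # Illustrative: with _nj=4, the key file is split into ${_logdir}/keys.1.scp ... keys.4.scp,
    | # and each parallel job below decodes only the utterances listed in its own keys.JOB.scp.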
603 | ${_cmd} --gpu "${_ngpu}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
604 | ${python} -m espnet2.bin.asr_test \
605 | --ngpu "${_ngpu}" \
606 | --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
607 | --key_file "${_logdir}"/keys.JOB.scp \
608 | --asr_train_config "${asr_exp}"/config.yaml \
609 | --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
610 | --output_dir "${_logdir}"/output.JOB \
611 | --orig_dir "${asr_exp}/${inference_tag}" \
612 | --orig_flag "${orig_flag}" \
613 | --need_decode "${need_decode}" \
614 | ${_opts} ${inference_args}
615 | # 3. Concatenate the output files from each job
616 | for f in token token_int score text sum_gini token_gini cov ; do
617 | for i in $(seq "${_nj}"); do
618 | cat "${_logdir}/output.${i}/1best_recog/${f}"
619 | done | LC_ALL=C sort -k1 >"${_dir}/${f}"
620 | done
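    | # Illustrative: ${_logdir}/output.1/1best_recog/text, output.2/1best_recog/text, ... are merged
    | # into ${_dir}/text (and likewise for token, score, etc.), sorted by utterance id.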
621 | fi
622 | done
623 | fi
624 |
625 |
626 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
627 | log "Stage 4: Scoring"
628 | if [ "${token_type}" = phn ]; then
629 | log "Error: Not implemented for token_type=phn"
630 | exit 1
631 | fi
632 |
633 | for dset in ${test_sets}; do
634 | _data="${data_feats}/${dset}"
635 | if ! "${need_decode}"; then
636 | dirs=("${asr_exp}/${inference_tag}/${dset}/gini-${selected_num}" "${asr_exp}/${inference_tag}/${dset}/random-${selected_num}" "${asr_exp}/${inference_tag}/${dset}/cov-${selected_num}")
637 | else
638 | dirs=("${asr_exp}/${inference_tag}/${dset}")
639 | fi
640 | for _dir in "${dirs[@]}"; do
641 | for _type in cer wer ter; do
642 | [ "${_type}" = ter ] && [ ! -f "${bpemodel}" ] && continue
643 |
644 | _scoredir="${_dir}/score_${_type}"
645 | mkdir -p "${_scoredir}"
646 |
647 | if "${need_decode}"; then
648 | if [ "${_type}" = wer ]; then
649 | # Tokenize text to word level
650 | paste \
651 | <(<"${_data}/text" \
652 | ${python} -m espnet2.bin.tokenize_text \
653 | -f 2- --input - --output - \
654 | --token_type word \
655 | --non_linguistic_symbols "${nlsyms_txt}" \
656 | --remove_non_linguistic_symbols true \
657 | --cleaner "${cleaner}" \
658 | ) \
659 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \
660 | >"${_scoredir}/ref.trn"
661 |
662 | # NOTE(kamo): Don't use cleaner for hyp
663 | paste \
664 | <(<"${_dir}/text" \
665 | ${python} -m espnet2.bin.tokenize_text \
666 | -f 2- --input - --output - \
667 | --token_type word \
668 | --non_linguistic_symbols "${nlsyms_txt}" \
669 | --remove_non_linguistic_symbols true \
670 | ) \
671 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \
672 | >"${_scoredir}/hyp.trn"
673 |
674 |
675 | elif [ "${_type}" = cer ]; then
676 | # Tokenize text to char level
677 | paste \
678 | <(<"${_data}/text" \
679 | ${python} -m espnet2.bin.tokenize_text \
680 | -f 2- --input - --output - \
681 | --token_type char \
682 | --non_linguistic_symbols "${nlsyms_txt}" \
683 | --remove_non_linguistic_symbols true \
684 | --cleaner "${cleaner}" \
685 | ) \
686 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \
687 | >"${_scoredir}/ref.trn"
688 |
689 | # NOTE(kamo): Don't use cleaner for hyp
690 | paste \
691 | <(<"${_dir}/text" \
692 | ${python} -m espnet2.bin.tokenize_text \
693 | -f 2- --input - --output - \
694 | --token_type char \
695 | --non_linguistic_symbols "${nlsyms_txt}" \
696 | --remove_non_linguistic_symbols true \
697 | ) \
698 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \
699 | >"${_scoredir}/hyp.trn"
700 |
701 | elif [ "${_type}" = ter ]; then
702 | # Tokenize text using BPE
703 | paste \
704 | <(<"${_data}/text" \
705 | ${python} -m espnet2.bin.tokenize_text \
706 | -f 2- --input - --output - \
707 | --token_type bpe \
708 | --bpemodel "${bpemodel}" \
709 | --cleaner "${cleaner}" \
710 | ) \
711 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \
712 | >"${_scoredir}/ref.trn"
713 |
714 | # NOTE(kamo): Don't use cleaner for hyp
715 | paste \
716 | <(<"${_dir}/text" \
717 | ${python} -m espnet2.bin.tokenize_text \
718 | -f 2- --input - --output - \
719 | --token_type bpe \
720 | --bpemodel "${bpemodel}" \
721 | ) \
722 | <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \
723 | >"${_scoredir}/hyp.trn"
724 |
725 | fi
726 | fi
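    | # Illustrative trn lines (hypothetical text and ids), as consumed by sclite below:
    | #   ref.trn: hello world (spk1-utt0001)
    | #   hyp.trn: hello word (spk1-utt0001)
    | # The trailing "(speaker-utterance)" id is what aligns each hypothesis with its reference.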
727 |
728 | sclite \
729 | ${score_opts} \
730 | -r "${_scoredir}/ref.trn" trn \
731 | -h "${_scoredir}/hyp.trn" trn \
732 | -i rm -o all stdout > "${_scoredir}/result.txt"
733 |
734 | log "Write ${_type} result in ${_scoredir}/result.txt"
735 | grep -e Avg -e SPKR -m 2 "${_scoredir}/result.txt"
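    | # Illustrative grep output (numbers are hypothetical):
    | #   | SPKR    | # Snt # Wrd | Corr  Sub  Del  Ins   Err S.Err |
    | #   | Sum/Avg |  130   1040 | 92.0  5.0  3.0  1.0   9.0  45.0 |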
736 | done
737 | done
738 | done
739 |
740 | [ -f local/score.sh ] && local/score.sh ${local_score_opts} "${asr_exp}"
741 |
742 | # Show results in Markdown syntax
743 | scripts/utils/show_asr_result.sh "${asr_exp}" > "${asr_exp}"/RESULTS.md
744 | cat "${asr_exp}"/RESULTS.md
745 |
746 | fi
747 | else
748 | log "Skip the evaluation stages"
749 | fi
750 |
751 |
752 | log "Successfully finished. [elapsed=${SECONDS}s]"
753 |
754 |
--------------------------------------------------------------------------------