├── eval ├── ABX_src │ ├── __init__.py │ ├── test_data │ │ ├── 23.npy │ │ ├── 42.npy │ │ ├── 2107.npy │ │ ├── 407.npy │ │ ├── dummy_item_within.item │ │ └── dummy_item_file.item │ ├── setup.py │ ├── dtw.pyx │ ├── abx_group_computation.py │ └── unit_tests.py ├── PER_src │ ├── __init__.py │ ├── setup.py │ ├── per_operator.pyx │ ├── seq_alignment.py │ └── simplePhonemLearner.py ├── WER_data │ └── letters.lst ├── WER_src │ ├── letter_ctc.py │ ├── wl_decoder.py │ └── simple_dataset.py ├── README.md ├── eval_ABX.py └── CPC_loader.py ├── data_preparation ├── metadata_completion │ ├── __init__.py │ ├── GenreScrapper.py │ ├── genre_folding.py │ ├── ReaderScapper.py │ ├── text_cleaner.py │ └── DuplicateSearch.py ├── rebuild_limited_train │ ├── split.sh │ ├── clean_texts.py │ ├── get_stats.py │ ├── split_1h_in10min.py │ ├── README.md │ ├── select_1h.py │ ├── sample_10h.py │ └── utils.py ├── text_retrieval │ ├── guttenberg.py │ ├── __init__.py │ ├── archive_org.py │ ├── main_lesson.py │ ├── hathitrust.py │ └── bartleby.py ├── unit_tests.py ├── split_librilight │ ├── extract_test_speakers.py │ ├── README.md │ ├── materialize_split.py │ ├── prepare_vads_tests.py │ ├── prepare_vads.py │ └── split.py ├── plot.py ├── build_all_stats.py ├── unzip_and_convert.py ├── cut_by_vad.py ├── make_vad_inputs.py ├── complete_metadata.py ├── calculate_snr.py └── README.md ├── baselines ├── TDS │ ├── data │ │ ├── letters.lst │ │ └── phones.lst │ ├── experiments │ │ ├── arch │ │ │ ├── TDS_20M.arch │ │ │ └── TDS_37M.arch │ │ └── config │ │ │ ├── decoding │ │ │ ├── 1h_phone_20M_TDS.cfg │ │ │ ├── 10h_letter_20M_TDS.cfg │ │ │ ├── 10h_phone_20M_TDS.cfg │ │ │ ├── 1h_letter_20M_TDS.cfg │ │ │ ├── 1h+pseudo-label_phone_37M_TDS.cfg │ │ │ ├── 10h+pseudo-label_phone_37M_TDS.cfg │ │ │ ├── 1h+pseudo-label_letter_37M_TDS.cfg │ │ │ └── 10h+pseudo-label_letter_37M_TDS.cfg │ │ │ └── training │ │ │ ├── 1h_phone_20M_TDS.cfg │ │ │ ├── 10h_phone_20M_TDS.cfg │ │ │ ├── 1h_letter_20M_TDS.cfg │ │ │ ├── 10h_letter_20M_TDS.cfg │ │ │ ├── 1h+pseudo-label_phone_37M_TDS.cfg │ │ │ ├── 10h+pseudo-label_phone_37M_TDS.cfg │ │ │ ├── 1h+pseudo-label_letter_37M_TDS.cfg │ │ │ └── 10h+pseudo-label_letter_37M_TDS.cfg │ └── README.md └── README.md ├── environment.yml ├── CITATION ├── LICENSE ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md └── README.md /eval/ABX_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /eval/PER_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /data_preparation/metadata_completion/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | -------------------------------------------------------------------------------- /eval/ABX_src/test_data/23.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/libri-light/HEAD/eval/ABX_src/test_data/23.npy -------------------------------------------------------------------------------- /eval/ABX_src/test_data/42.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/libri-light/HEAD/eval/ABX_src/test_data/42.npy -------------------------------------------------------------------------------- /eval/ABX_src/test_data/2107.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/libri-light/HEAD/eval/ABX_src/test_data/2107.npy -------------------------------------------------------------------------------- /eval/ABX_src/test_data/407.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/libri-light/HEAD/eval/ABX_src/test_data/407.npy -------------------------------------------------------------------------------- /eval/WER_data/letters.lst: -------------------------------------------------------------------------------- 1 | s 2 | h 3 | e 4 | | 5 | k 6 | p 7 | t 8 | i 9 | m 10 | d 11 | o 12 | w 13 | n 14 | b 15 | y 16 | r 17 | f 18 | u 19 | g 20 | q 21 | a 22 | l 23 | j 24 | c 25 | v 26 | z 27 | ' 28 | x 29 | -------------------------------------------------------------------------------- /baselines/TDS/data/letters.lst: -------------------------------------------------------------------------------- 1 | | 2 | ' 3 | a 4 | b 5 | c 6 | d 7 | e 8 | f 9 | g 10 | h 11 | i 12 | j 13 | k 14 | l 15 | m 16 | n 17 | o 18 | p 19 | q 20 | r 21 | s 22 | t 23 | u 24 | v 25 | w 26 | x 27 | y 28 | z 29 | -------------------------------------------------------------------------------- /baselines/README.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | This directory contains pretrained on Libri-Light, in addition to information on reproducing the results. It contains the following models: 4 | - [TDS models trained with wav2letter](./TDS/README.md) 5 | -------------------------------------------------------------------------------- /eval/PER_src/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from distutils.core import setup 3 | from Cython.Build import cythonize 4 | 5 | setup( 6 | ext_modules=cythonize("per_operator.pyx") 7 | ) 8 | -------------------------------------------------------------------------------- /eval/ABX_src/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from distutils.core import setup 3 | from Cython.Build import cythonize 4 | import numpy 5 | 6 | setup( 7 | include_dirs=[numpy.get_include()], 8 | ext_modules=cythonize("dtw.pyx") 9 | ) 10 | -------------------------------------------------------------------------------- /eval/ABX_src/test_data/dummy_item_within.item: -------------------------------------------------------------------------------- 1 | #file onset offset #phone prev-phone next-phone speaker 2 | 2107 0. 
0.2 n p d 8193 3 | 2107 0.3225 0.5225 n ae d 8193 4 | 2107 0.6 0.75 n ae d 8193 5 | 2107 0.4225 0.5925 d n l 2222 6 | 42 0.4525 0.6525 d n l 2222 7 | 42 0.1301 0.2501 q n l 2222 8 | 42 0.5225 0.7325 d n l 8193 9 | 42 0.0025 0.3561 d p l 2222 10 | 42 0.5925 0.8725 d p l 8193 11 | -------------------------------------------------------------------------------- /eval/ABX_src/test_data/dummy_item_file.item: -------------------------------------------------------------------------------- 1 | #file onset offset #phone prev-phone next-phone speaker 2 | 2107 0.3225 0.5225 n ae d 8193 3 | 2107 0.4225 0.5925 d n l 2222 4 | 42 0.4525 0.6525 d n l 2222 5 | 42 0.5225 0.7325 ih l n 8193 6 | 42 0.5925 0.8725 n ih s 8193 7 | 23 0.6525 1.1025 s n ax 8193 8 | 23 0.7325 1.1925 s n ax 2222 9 | 407 0.8725 1.2425 s ax dh 2222 10 | 2107 1.1025 1.2925 dh s ax 12 11 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: libri-light 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - pytorch 9 | - torchvision 10 | - cudatoolkit=9.2 11 | - torchaudio 12 | - numpy 13 | - pysoundfile 14 | - pip 15 | - openblas-devel 16 | - tqdm 17 | - nose 18 | - pip: 19 | - progressbar2 20 | - matplotlib 21 | - torchaudio 22 | - termcolor 23 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/arch/TDS_20M.arch: -------------------------------------------------------------------------------- 1 | V -1 NFEAT 1 0 2 | C2 1 10 21 1 2 1 -1 -1 3 | R 4 | DO 0.4 5 | LN 3 6 | TDS 10 21 80 0.4 800 7 | TDS 10 21 80 0.4 800 8 | C2 10 14 21 1 1 1 -1 -1 9 | R 10 | DO 0.4 11 | LN 3 12 | TDS 14 21 80 0.4 1120 13 | TDS 14 21 80 0.4 1120 14 | C2 14 18 21 1 1 1 -1 -1 15 | R 16 | DO 0.4 17 | LN 3 18 | TDS 18 21 80 0.4 1440 19 | TDS 18 21 80 0.4 1440 20 | TDS 18 21 80 0.4 1440 21 | V 0 1440 1 0 22 | RO 1 0 3 2 23 | L 1440 NLABEL 24 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/decoding/1h_phone_20M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --lm=/lm-4g.bin 2 | --am=/1h_phone_20M_TDS.bin 3 | --test=/librivox.lst 4 | --decodertype=wrd 5 | --lmtype=kenlm 6 | --lmweight=2.03 7 | --wordscore=1.61 8 | --beamsize=1000 9 | --beamthreshold=15 10 | --smearing=max 11 | --maxload=-1 12 | --nthread_decoder=5 13 | --minisz=100 14 | --maxisz=36000 15 | --uselexicon=true 16 | --wordseparator=# 17 | --show=true 18 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/decoding/10h_letter_20M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --lm=/lm-4g.bin 2 | --am=/10h_letter_20M_TDS.bin 3 | --test=/librivox.lst 4 | --decodertype=wrd 5 | --lmtype=kenlm 6 | --lmweight=2.38 7 | --wordscore=0.0 8 | --beamsize=1000 9 | --beamthreshold=15 10 | --smearing=max 11 | --maxload=-1 12 | --nthread_decoder=5 13 | --minisz=100 14 | --maxisz=36000 15 | --uselexicon=true 16 | --wordseparator=| 17 | --show=true 18 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/decoding/10h_phone_20M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --lm=/lm-4g.bin 2 | --am=/10h_phone_20M_TDS.bin 3 | --test=/librivox.lst 4 | 
--decodertype=wrd 5 | --lmtype=kenlm 6 | --lmweight=4.18 7 | --wordscore=1.58 8 | --beamsize=1000 9 | --beamthreshold=15 10 | --smearing=max 11 | --maxload=-1 12 | --nthread_decoder=5 13 | --minisz=100 14 | --maxisz=36000 15 | --uselexicon=true 16 | --wordseparator=# 17 | --show=true 18 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/decoding/1h_letter_20M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --lm=/lm-4g.bin 2 | --am=/1h_letter_20M_TDS.bin 3 | --test=/librivox.lst 4 | --decodertype=wrd 5 | --lmtype=kenlm 6 | --lmweight=1.32 7 | --wordscore=0.0 8 | --beamsize=1000 9 | --beamthreshold=15 10 | --smearing=max 11 | --maxload=-1 12 | --nthread_decoder=5 13 | --minisz=100 14 | --maxisz=36000 15 | --uselexicon=true 16 | --wordseparator=| 17 | --show=true 18 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/decoding/1h+pseudo-label_phone_37M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --lm=/lm-4g.bin 2 | --am=/1h+pseudo-label_phone_37M_TDS.bin 3 | --test=/dev-other.lst 4 | --decodertype=wrd 5 | --lmtype=kenlm 6 | --lmweight=1.47 7 | --wordscore=1.47 8 | --beamsize=1000 9 | --beamthreshold=15 10 | --smearing=max 11 | --maxload=-1 12 | --nthread_decoder=5 13 | --minisz=100 14 | --maxisz=36000 15 | --uselexicon=true 16 | --wordseparator=# 17 | --show=true 18 | -------------------------------------------------------------------------------- /data_preparation/rebuild_limited_train/split.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | python sample_10h.py --target_dir=10h_temp 3 | python select_1h.py --root_10h=10h_temp --target_dir=1h 4 | python split_1h_in10min.py --root_1h=1h --target_dir=6x10min 5 | 6 | mkdir librispeech_release 7 | mv 10h_temp ./librispeech_release/9h/ 8 | mv 6x10min ./librispeech_release/1h/ 9 | 10 | python clean_texts.py --root=librispeech_release/ 11 | 12 | rm -rf 1h 13 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/decoding/10h+pseudo-label_phone_37M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --lm=/lm-4g.bin 2 | --am=/10h+pseudo-label_phone_37M_TDS.bin 3 | --test=/dev-other.lst 4 | --decodertype=wrd 5 | --lmtype=kenlm 6 | --lmweight=1.57 7 | --wordscore=2.33 8 | --beamsize=1000 9 | --beamthreshold=15 10 | --smearing=max 11 | --maxload=-1 12 | --nthread_decoder=5 13 | --minisz=100 14 | --maxisz=36000 15 | --uselexicon=true 16 | --wordseparator=# 17 | --show=true 18 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/decoding/1h+pseudo-label_letter_37M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --lm=/lm-4g.bin 2 | --am=/1h+pseudo-label_letter_37M_TDS.bin 3 | --test=/dev-other.lst 4 | --decodertype=wrd 5 | --lmtype=kenlm 6 | --lmweight=0.70 7 | --wordscore=1.83 8 | --beamsize=1000 9 | --beamthreshold=15 10 | --silweight=-2.89 11 | --smearing=max 12 | --maxload=-1 13 | --nthread_decoder=5 14 | --minisz=100 15 | --maxisz=36000 16 | --uselexicon=true 17 | --wordseparator=| 18 | --show=true 19 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/decoding/10h+pseudo-label_letter_37M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --lm=/lm-4g.bin 2 | --am=/10h+pseudo-label_letter_37M_TDS.bin 3 | --test=/dev-other.lst 4 | --decodertype=wrd 5 | --lmtype=kenlm 6 | --lmweight=1.08 7 | --wordscore=2.12 8 | --beamsize=1000 9 | --beamthreshold=15 10 | --silweight=-1.66 11 | --smearing=max 12 | --maxload=-1 13 | --nthread_decoder=5 14 | --minisz=100 15 | --maxisz=36000 16 | --uselexicon=true 17 | --wordseparator=| 18 | --show=true 19 | -------------------------------------------------------------------------------- /CITATION: -------------------------------------------------------------------------------- 1 | @inproceedings{kahn2020libri, 2 | title={Libri-light: A benchmark for asr with limited or no supervision}, 3 | author={Kahn, Jacob and Riviere, Morgane and Zheng, Weiyi and Kharitonov, Evgeny and Xu, Qiantong and Mazar{\'e}, Pierre-Emmanuel and Karadayi, Julien and Liptchinsky, Vitaliy and Collobert, Ronan and Fuegen, Christian and others}, 4 | booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 5 | pages={7669--7673}, 6 | year={2020}, 7 | organization={IEEE} 8 | } 9 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/arch/TDS_37M.arch: -------------------------------------------------------------------------------- 1 | V -1 NFEAT 1 0 2 | C2 1 10 21 1 2 1 -1 -1 3 | R 4 | DO 0.2 5 | LN 3 6 | TDS 10 21 80 0.1 800 7 | TDS 10 21 80 0.1 800 8 | C2 10 14 21 1 1 1 -1 -1 9 | R 10 | DO 0.2 11 | LN 3 12 | TDS 14 21 80 0.1 1120 13 | TDS 14 21 80 0.1 1120 14 | TDS 14 21 80 0.1 1120 15 | C2 14 18 21 1 1 1 -1 -1 16 | R 17 | DO 0.2 18 | LN 3 
19 | TDS 18 21 80 0.1 1440 20 | TDS 18 21 80 0.1 1440 21 | TDS 18 21 80 0.1 1440 22 | TDS 18 21 80 0.1 1440 23 | TDS 18 21 80 0.1 1440 24 | TDS 18 21 80 0.1 1440 25 | V 0 1440 1 0 26 | RO 1 0 3 2 27 | L 1440 NLABEL 28 | -------------------------------------------------------------------------------- /data_preparation/text_retrieval/guttenberg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import requests 3 | 4 | 5 | def is_guttenberg_url(url): 6 | return url.find('http://www.gutenberg.org') == 0 or \ 7 | url.find('https://www.gutenberg.org') == 0 or \ 8 | url.find("http://gutenberg.org") == 0 9 | 10 | 11 | def get_guttenberg_data(url): 12 | txtID = url.split('/')[-1] 13 | targetURL = f'http://www.gutenberg.org/cache/epub/{txtID}/pg{txtID}.txt' 14 | return requests.get(targetURL)._content.decode("utf-8") 15 | -------------------------------------------------------------------------------- /baselines/TDS/data/phones.lst: -------------------------------------------------------------------------------- 1 | | 2 | EY2 3 | EY1 4 | EY0 5 | UH0 6 | UH1 7 | UH2 8 | P 9 | EH1 10 | EH0 11 | EH2 12 | S 13 | OW0 14 | OW2 15 | OW1 16 | AH2 17 | AH0 18 | AH1 19 | IH1 20 | IH0 21 | IH2 22 | AW2 23 | AW1 24 | AW0 25 | HH 26 | OY2 27 | OY1 28 | OY0 29 | ER1 30 | ER0 31 | ER2 32 | CH 33 | SIL 34 | AE2 35 | AE0 36 | AE1 37 | Y 38 | N 39 | ZH 40 | AA2 41 | AA0 42 | AA1 43 | K 44 | AO2 45 | AO0 46 | AO1 47 | M 48 | DH 49 | UW1 50 | UW0 51 | UW2 52 | AY1 53 | AY0 54 | AY2 55 | V 56 | SH 57 | Z 58 | TH 59 | GARBAGE 60 | B 61 | G 62 | W 63 | IY0 64 | IY2 65 | IY1 66 | R 67 | L 68 | D 69 | JH 70 | NG 71 | T 72 | F 73 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/training/1h_phone_20M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --runname=1h_phone_20M_TDS 2 | --rundir= 3 | --arch=TDS_20M.arch 4 | --archdir=/experiments/arch 5 | --batchsize=2 6 | --tokensdir=/data 7 | --tokens=phones.lst 8 | --lr=0.1 9 | --momentum=0.5 10 | --maxgradnorm=1 11 | --onorm=target 12 | --sqnorm=true 13 | --mfsc=true 14 | --nthread=10 15 | --criterion=ctc 16 | --memstepsize=8338608 17 | --listdata=true 18 | --lexicon=/data/phone.lex 19 | --datadir= 20 | --train=/data/1h.lst 21 | --valid=dev-other.lst,dev-clean.lst 22 | --wordseparator= 23 | --target=ltr 24 | --filterbanks=80 25 | --gamma=0.5 26 | --enable_distributed=true 27 | --iter=2000 28 | --framesizems=30 29 | --framestridems=10 30 | --seed=2 31 | --lrcosine=true 32 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/training/10h_phone_20M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --runname=10h_phone_20M_TDS 2 | --rundir= 3 | --arch=TDS_20M.arch 4 | --archdir=/experiments/arch 5 | --batchsize=8 6 | --tokensdir=/data 7 | --tokens=phones.lst 8 | --lr=0.1 9 | --momentum=0.5 10 | --maxgradnorm=1 11 | --onorm=target 12 | --sqnorm=true 13 | --mfsc=true 14 | --nthread=10 15 | --criterion=ctc 16 | --memstepsize=8338608 17 | --listdata=true 18 | --lexicon=/data/phone.lex 19 | --datadir= 20 | --train=/data/10h.lst 21 | --valid=dev-other.lst,dev-clean.lst 22 | --wordseparator= 23 | --target=ltr 24 | --filterbanks=80 25 | --gamma=0.5 26 | --enable_distributed=true 27 | --iter=1500 28 | --framesizems=30 29 | --framestridems=10 30 | --seed=2 31 | 
--lrcosine=true 32 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/training/1h_letter_20M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --runname=1h_letter_20M_TDS 2 | --rundir= 3 | --arch=TDS_20M.arch 4 | --archdir=/experiments/arch 5 | --batchsize=2 6 | --tokensdir=/data 7 | --tokens=letters.lst 8 | --lr=0.1 9 | --momentum=0.5 10 | --maxgradnorm=1 11 | --onorm=target 12 | --sqnorm=true 13 | --mfsc=true 14 | --nthread=10 15 | --criterion=ctc 16 | --memstepsize=8338608 17 | --listdata=true 18 | --lexicon=/data/letter.lex 19 | --datadir= 20 | --train=/data/1h.lst 21 | --valid=dev-other.lst,dev-clean.lst 22 | --wordseparator= 23 | --target=ltr 24 | --filterbanks=80 25 | --gamma=0.5 26 | --enable_distributed=true 27 | --iter=2000 28 | --framesizems=30 29 | --framestridems=10 30 | --seed=2 31 | --lrcosine=true 32 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/training/10h_letter_20M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --runname=10h_letter_20M_TDS 2 | --rundir= 3 | --arch=TDS_20M.arch 4 | --archdir=/experiments/arch 5 | --batchsize=8 6 | --tokensdir=/data 7 | --tokens=letters.lst 8 | --lr=0.1 9 | --momentum=0.5 10 | --maxgradnorm=1 11 | --onorm=target 12 | --sqnorm=true 13 | --mfsc=true 14 | --nthread=10 15 | --criterion=ctc 16 | --memstepsize=8338608 17 | --listdata=true 18 | --lexicon=/data/letter.lex 19 | --datadir= 20 | --train=/data/10h.lst 21 | --valid=dev-other.lst,dev-clean.lst 22 | --wordseparator= 23 | --target=ltr 24 | --filterbanks=80 25 | --gamma=0.5 26 | --enable_distributed=true 27 | --iter=1500 28 | --framesizems=30 29 | --framestridems=10 30 | --seed=2 31 | --lrcosine=true 32 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/training/1h+pseudo-label_phone_37M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --runname=1h+PL_phone_37M_TDS 2 | --rundir= 3 | --arch=TDS_37M.arch 4 | --archdir=/experiments/arch 5 | --batchsize=8 6 | --tokensdir=/data 7 | --tokens=phones.lst 8 | --lr=0.1 9 | --momentum=0.5 10 | --maxgradnorm=1 11 | --onorm=target 12 | --sqnorm=true 13 | --mfsc=true 14 | --nthread=10 15 | --criterion=ctc 16 | --memstepsize=8338608 17 | --listdata=true 18 | --lexicon=/data/phone.lex 19 | --datadir= 20 | --train=/data/1h.lst,/data/librivox_pseudo_label/1h_phn.lst 21 | --valid=dev-other.lst,dev-clean.lst 22 | --wordseparator= 23 | --target=ltr 24 | --filterbanks=80 25 | --gamma=0.5 26 | --enable_distributed=true 27 | --iter=300 28 | --stepsize=30 29 | --framesizems=25 30 | --framestridems=10 31 | --seed=2 32 | --reportiters=2000 33 | --minisz=200 34 | --mintsz=1 35 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/training/10h+pseudo-label_phone_37M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --runname=10h+PL_phone_37M_TDS 2 | --rundir= 3 | --arch=TDS_37M.arch 4 | --archdir=/experiments/arch 5 | --batchsize=8 6 | --tokensdir=/data 7 | --tokens=phones.lst 8 | --lr=0.1 9 | --momentum=0.5 10 | --maxgradnorm=1 11 | --onorm=target 12 | --sqnorm=true 13 | --mfsc=true 14 | --nthread=10 15 | --criterion=ctc 16 | --memstepsize=8338608 17 | --listdata=true 18 | --lexicon=/data/phone.lex 19 | --datadir= 20 | 
--train=/data/10h.lst,/data/librivox_pseudo_label/10h_phn.lst.new 21 | --valid=dev-other.lst,dev-clean.lst 22 | --wordseparator= 23 | --target=ltr 24 | --filterbanks=80 25 | --gamma=0.5 26 | --enable_distributed=true 27 | --iter=300 28 | --stepsize=30 29 | --framesizems=25 30 | --framestridems=10 31 | --seed=2 32 | --reportiters=2000 33 | --minisz=200 34 | --mintsz=1 35 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/training/1h+pseudo-label_letter_37M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --runname=1h+PL_letter_37M_TDS 2 | --rundir= 3 | --arch=TDS_37M.arch 4 | --archdir=/experiments/arch 5 | --batchsize=8 6 | --tokensdir=/data 7 | --tokens=letters.lst 8 | --lr=0.1 9 | --momentum=0.5 10 | --maxgradnorm=1 11 | --onorm=target 12 | --sqnorm=true 13 | --mfsc=true 14 | --nthread=10 15 | --criterion=ctc 16 | --memstepsize=8338608 17 | --listdata=true 18 | --lexicon=/data/letter.lex 19 | --datadir= 20 | --train=/data/1h.lst,/data/librivox_pseudo_label/1h_ltr.lst 21 | --valid=dev-other.lst,dev-clean.lst 22 | --wordseparator=| 23 | --target=ltr 24 | --filterbanks=80 25 | --gamma=0.5 26 | --enable_distributed=true 27 | --iter=300 28 | --stepsize=30 29 | --framesizems=25 30 | --framestridems=10 31 | --seed=2 32 | --reportiters=2000 33 | --minisz=200 34 | --mintsz=1 35 | --maxtsz=600 36 | -------------------------------------------------------------------------------- /baselines/TDS/experiments/config/training/10h+pseudo-label_letter_37M_TDS.cfg: -------------------------------------------------------------------------------- 1 | --runname=10h+PL_letter_37M_TDS 2 | --rundir= 3 | --arch=TDS_37M.arch 4 | --archdir=/experiments/arch 5 | --batchsize=8 6 | --tokensdir=/data 7 | --tokens=letters.lst 8 | --lr=0.1 9 | --momentum=0.5 10 | --maxgradnorm=1 11 | --onorm=target 12 | --sqnorm=true 13 | --mfsc=true 14 | --nthread=10 15 | --criterion=ctc 16 | --memstepsize=8338608 17 | --listdata=true 18 | --lexicon=/data/letter.lex 19 | --datadir= 20 | --train=/data/10h.lst,/data/librivox_pseudo_label/10h_ltr.lst 21 | --valid=dev-other.lst,dev-clean.lst 22 | --wordseparator=| 23 | --target=ltr 24 | --filterbanks=80 25 | --gamma=0.5 26 | --enable_distributed=true 27 | --iter=300 28 | --stepsize=30 29 | --framesizems=25 30 | --framestridems=10 31 | --seed=2 32 | --reportiters=2000 33 | --minisz=200 34 | --mintsz=1 35 | --maxtsz=600 36 | -------------------------------------------------------------------------------- /data_preparation/text_retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from .archive_org import is_archive_org_url, get_archive_org_text_data 3 | from .main_lesson import is_main_lesson_url, get_all_text_from_main_lesson 4 | from .hathitrust import is_hathitrust_url, load_hathitrust_book 5 | from .bartleby import is_bartheleby_url, get_bartheleby_data 6 | from .guttenberg import is_guttenberg_url, get_guttenberg_data 7 | 8 | 9 | def get_text_data(url): 10 | if is_guttenberg_url(url): 11 | return get_guttenberg_data(url) 12 | elif is_archive_org_url(url): 13 | return get_archive_org_text_data(url) 14 | elif is_bartheleby_url(url): 15 | return get_bartheleby_data(url) 16 | elif is_main_lesson_url(url): 17 | return get_all_text_from_main_lesson(url) 18 | elif is_hathitrust_url(url): 19 | return load_hathitrust_book(url) 20 | else: 21 | raise RuntimeError(f'Unknown web API {url}') 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data_preparation/rebuild_limited_train/clean_texts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import argparse 3 | import pathlib 4 | 5 | """ 6 | Cleans the *.txt files, removing the texts that correspond to the flac's not in the directory. 7 | Assumes Librispeech-like outline. 
8 | """ 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--dataset', type=str) 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | if __name__ == '__main__': 19 | args = get_args() 20 | 21 | txt_names = pathlib.Path(args.dataset).rglob(f"*.txt") 22 | txt_names = sorted(list(txt_names)) 23 | 24 | for txt_name in txt_names: 25 | print(txt_name) 26 | 27 | parent_dir = txt_name.parent 28 | 29 | siblings = list(parent_dir.glob('*.flac')) 30 | assert len(siblings) > 0 31 | 32 | sibling_names = set([s.stem for s in siblings]) 33 | 34 | with open(txt_name, 'r') as f: 35 | txt = f.readlines() 36 | 37 | with open(txt_name, 'w') as f: 38 | for l in txt: 39 | if l.split()[0] in sibling_names: 40 | f.write(l) 41 | -------------------------------------------------------------------------------- /eval/PER_src/per_operator.pyx: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import numpy as np 3 | cimport numpy as np 4 | cimport cython 5 | from cpython cimport bool 6 | ctypedef np.float32_t FLOAT_t # cost type 7 | ctypedef np.intp_t IND_t # array index type 8 | ctypedef np.int32_t INT_t # array index type 9 | CTYPE = np.float32 # cost type 10 | 11 | def needleman_wunsch_align_score(seq1, seq2, d, m, r, normalize = True): 12 | 13 | N1, N2 = seq1.shape[0], seq2.shape[0] 14 | return _needleman_wunsch_align_score(N1, N2, seq1, seq2, d, m, r, normalize) 15 | 16 | cpdef _needleman_wunsch_align_score(IND_t N1, IND_t N2, INT_t[:] seq1, INT_t[:] seq2, 17 | FLOAT_t d, FLOAT_t m, FLOAT_t r, bool normalized): 18 | 19 | # Fill up the errors 20 | cdef IND_t i, j 21 | cdef FLOAT_t match, v1, v2, v3, res 22 | cdef FLOAT_t[:,:] tmpRes_ = np.empty((N1 + 1, N2 + 1), dtype=CTYPE) 23 | 24 | for i in range(0, N1 + 1): 25 | tmpRes_[i][0] = i * d 26 | for j in range(0, N2 + 1): 27 | tmpRes_[0][j] = j * d 28 | 29 | for i in range(0, N1): 30 | for j in range(0, N2): 31 | match = r if seq1[i] == seq2[j] else m 32 | v1 = tmpRes_[i][j] + match 33 | v2 = tmpRes_[i + 1][j] + d 34 | v3 = tmpRes_[i][j + 1] + d 35 | tmpRes_[i + 1][j + 1] = max(v1, max(v2, v3)) 36 | 37 | res = -tmpRes_[N1][N2] 38 | if normalized: 39 | res /= float(N1) 40 | return res 41 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Librilight 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | ... (in particular how this is synced with internal changes to the project) 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `master`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. 
Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * 2 spaces for indentation rather than tabs 34 | * 80 character line length 35 | * ... 36 | 37 | ## License 38 | By contributing to Librilight, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. 40 | -------------------------------------------------------------------------------- /data_preparation/unit_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import unittest 3 | from nose.tools import eq_, ok_ 4 | import numpy as np 5 | from cutDB import cutWithVAD, greedyMerge 6 | 7 | 8 | class TestCutDB(unittest.TestCase): 9 | 10 | def setUp(self): 11 | 12 | self.vad = np.array([0.8, 0.9, 1., 0.49, 0.4, 0.2, 13 | 0.1, 1., 1., 0.9, 0.0, 0.99]) 14 | self.stepVAD = 3 15 | self.data = np.array(list(range(36))) 16 | 17 | def testCutDB(self): 18 | outCuts = list(cutWithVAD(self.data, self.vad, 0.5, self.stepVAD)) 19 | expectedOutput = [(0, 9), (21, 30), (33, 36)] 20 | 21 | eq_(len(outCuts), len(expectedOutput)) 22 | for index in range(len(outCuts)): 23 | eq_(outCuts[index], expectedOutput[index]) 24 | 25 | def testGreedyMerge(self): 26 | 27 | cutsIndex = [(0, 9), (21, 30), (33, 36), (24, 49), (53, 117), 28 | (201, 222), (230, 240)] 29 | sizeMultiplier = 0.5 30 | targetSize = 20 31 | 32 | mergeIndexes = greedyMerge(cutsIndex, targetSize, sizeMultiplier) 33 | expectedOutput = [(46, [(0, 9), (21, 30), (33, 36), (24, 49)]), 34 | (64, [(53, 117)]), 35 | (31, [(201, 222), (230, 240)])] 36 | 37 | eq_(len(mergeIndexes), len(expectedOutput)) 38 | for index in range(len(mergeIndexes)): 39 | eq_(mergeIndexes[index][0], expectedOutput[index][0]) 40 | eq_(len(mergeIndexes[index][1]), len(expectedOutput[index][1])) 41 | size = len(mergeIndexes[index][1]) 42 | for p in range(size): 43 | eq_(mergeIndexes[index][1][p], expectedOutput[index][1][p]) 44 | -------------------------------------------------------------------------------- /data_preparation/split_librilight/extract_test_speakers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import argparse 7 | import json 8 | import pathlib 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser( 13 | description="Prepare list of speakers that appear in Librispeech test/dev datasets") 14 | parser.add_argument('--librispeech_meta', type=str, required=True) 15 | parser.add_argument('--output', type=str, default='test_speakers.json') 16 | 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def extract_holdout_speakers(librispeech_meta_path): 22 | chapters_file = pathlib.Path(librispeech_meta_path) / 'SPEAKERS.TXT' 23 | speakers_to_exclude = [] 24 | 25 | with open(chapters_file, 'r') as f: 26 | for line in f.readlines(): 27 | if line.startswith(';'): 28 | continue 29 | line = line.split("|") 30 | speaker_id, subset = int(line[0].strip()), line[2].strip() 31 | assert subset in ['train-other-500', 'train-clean-100', 32 | 'dev-other', 'dev-clean', 'test-other', 'test-clean', 'train-clean-360'], subset 33 | 34 | if subset in ['dev-other', 'dev-clean', 'test-other', 'test-clean']: 35 | speakers_to_exclude.append(speaker_id) 36 | 37 | return sorted(speakers_to_exclude) 38 | 39 | 40 | if __name__ == '__main__': 41 | args = get_args() 42 | 43 | speakers_to_exclude = extract_holdout_speakers(args.librispeech_meta) 44 | print('Speakers to exclude: ', len(speakers_to_exclude)) 45 | 46 | with open(args.output, 'w') as f: 47 | f.write(json.dumps(dict( 48 | test_speakers=speakers_to_exclude, 49 | ), indent=1)) 50 | -------------------------------------------------------------------------------- /data_preparation/rebuild_limited_train/get_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from utils import get_histogram, get_speakers, traverse_tree, full_records, print_stats, materialize 3 | import pathlib 4 | import argparse 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--root', type=str) 11 | parser.add_argument('--meta_path', type=str) 12 | 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | if __name__ == '__main__': 18 | args = get_args() 19 | speakers = get_speakers(pathlib.Path(args.meta_path) / 'SPEAKERS.TXT') 20 | 21 | fname2length = traverse_tree(args.root) 22 | records = full_records(speakers, fname2length, subset_name=None) 23 | print(f'Utterances: {len(records)}') 24 | 25 | time_by_gender = get_histogram( 26 | records, lambda_key=lambda r: r.speaker.gender, lambda_value=lambda r: r.length / 16000) 27 | print('Time by gender, seconds', time_by_gender) 28 | 29 | time_by_subset = get_histogram( 30 | records, lambda_key=lambda r: r.speaker.subset, lambda_value=lambda r: r.length / 16000) 31 | print('Time by subset, seconds', time_by_subset) 32 | 33 | speaker_freq = get_histogram( 34 | records, lambda_key=lambda r: r.speaker.id, lambda_value=lambda r: 1) 35 | print('Number of uniq speakers', len(speaker_freq)) 36 | 37 | book_lengths = get_histogram( 38 | records, lambda_key=lambda r: r.book, lambda_value=lambda r: r.length) 39 | 40 | scaler = 1.0 / 16000 41 | max_length = max(book_lengths.values()) * scaler 42 | min_length = min(book_lengths.values()) * scaler 43 | mean_length = sum(book_lengths.values()) / len(book_lengths) * scaler 44 | 45 | print( 46 | f'Book length disrtibution, seconds, min: {min_length}, mean: {mean_length}, max: {max_length}; n_books={len(book_lengths)}') 47 | -------------------------------------------------------------------------------- 
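The sampling and statistics scripts in `rebuild_limited_train` import `get_histogram`, `get_speakers`, `traverse_tree`, `full_records`, `print_stats` and `materialize` from `utils.py`, which is listed in the tree above but not reproduced in this listing. As a minimal sketch of the grouping helper only, inferred from its call sites in `get_stats.py` (a hypothetical reimplementation, not the repository's actual `utils.py`):

```python
from collections import defaultdict


def get_histogram(records, lambda_key, lambda_value):
    # Accumulate lambda_value(record) over every record that shares the same
    # lambda_key(record), e.g. seconds of audio per speaker gender or per book.
    histogram = defaultdict(float)
    for record in records:
        histogram[lambda_key(record)] += lambda_value(record)
    return dict(histogram)
```

Used as in `get_stats.py`, `get_histogram(records, lambda_key=lambda r: r.speaker.gender, lambda_value=lambda r: r.length / 16000)` would then map each gender to its total duration in seconds (lengths are stored in samples at 16 kHz).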
/data_preparation/rebuild_limited_train/split_1h_in10min.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from utils import get_histogram, get_speakers, traverse_tree, full_records, print_stats, materialize 3 | import pathlib 4 | import argparse 5 | import random 6 | 7 | 8 | def do_split(records): 9 | speakers = set([r.speaker.id for r in records]) 10 | 11 | for speaker in speakers: 12 | speaker_records = [r for r in records if r.speaker.id == speaker] 13 | yield speaker_records 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--target_dir', type=str) 19 | parser.add_argument('--seed', type=int, default=7) 20 | parser.add_argument('--root_1h', type=str) 21 | parser.add_argument('--meta_path', type=str) 22 | 23 | args = parser.parse_args() 24 | 25 | return args 26 | 27 | 28 | if __name__ == '__main__': 29 | args = get_args() 30 | random.seed(args.seed) 31 | 32 | speakers = get_speakers(pathlib.Path(args.meta_path) / 'SPEAKERS.TXT') 33 | fname2length = traverse_tree(args.root_1h) 34 | all_records = full_records(speakers, fname2length) 35 | print(f'Got {len(all_records)} records') 36 | 37 | for gender in ['f', 'm']: 38 | for t, tag in zip(['train-clean-100', 'train-other-500'], ['clean', 'other']): 39 | records = list(all_records) 40 | print(f'Selecting from {t}, gender {gender}') 41 | 42 | records = filter(lambda x: x.speaker.gender.lower() 43 | == gender and x.speaker.subset == t, records) 44 | records = list(records) 45 | print(f'{len(records)} utterances in the split') 46 | 47 | for i, split in enumerate(do_split(records)): 48 | print_stats(split) 49 | 50 | if args.target_dir: 51 | materialize(split, args.target_dir + f'/{i}/', tag=tag) 52 | -------------------------------------------------------------------------------- /data_preparation/rebuild_limited_train/README.md: -------------------------------------------------------------------------------- 1 | # Getting a Librispeech subsample 2 | 3 | This collection of scripts was used to subsample Librispeech data into 10h, 1h, and 10 min chunks. 4 | These chunks are then used for fine-tuning the models trained with Librilight. 5 | 6 | The goal of the splitting process is to ensure that the chunks are approximately balanced w.r.t. 7 | the gender of the speakers and noise levels. 8 | 9 | Since the samples are nested, the workflow assumes that they are generated in turn, starting from the largest. 10 | 11 | # 10 hours sample 12 | ``` 13 | python sample_10h.py --root_clean= --root_other= --meta_path= 14 | ``` 15 | The script will generate the subsample and output some statistics for it, but it will not be written to disk unless 16 | you provide the `--target_dir` option, in which case it is materialized on disk: 17 | ``` 18 | python sample_10h.py --root_clean= --root_other= --meta_path= --target_dir=10h 19 | ``` 20 | 21 | # 1 hour sample 22 | The next step is selecting the 1h sample from the 10h one obtained above: 23 | ``` 24 | python select_1h.py --meta_path= --root_10h= 25 | ``` 26 | Again, to actually materialize it: 27 | ``` 28 | python select_1h.py --meta_path= --root_10h= --target_dir=./1h 29 | ``` 30 | As a result, the files are moved from `root_10h` (making it effectively 9h).
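For a quick sanity check of what any of the produced directories contains, `get_stats.py` (see the Other section below) prints per-gender and per-subset durations together with speaker and book statistics; a usage sketch, with the Librispeech metadata path left blank as in the commands above:
```
python get_stats.py --root=10h --meta_path=
```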
31 | 32 | # Splitting 1 hour into 6 x 10 minutes 33 | Finally, we split the 1h sample into 6 samples of 10 minutes each: 34 | ``` 35 | python split_1h_in10min.py --root_1h=1h --target_dir=6x10min --meta_path= 36 | ``` 37 | 38 | # Other 39 | `get_stats.py` outputs the stats for a particular directory, and `clean_texts.py` prunes all texts that correspond to excluded files. 40 | 41 | `split.sh` is the script for re-generating the Librispeech samples we release. 42 | -------------------------------------------------------------------------------- /data_preparation/text_retrieval/archive_org.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | from internetarchive import get_item, download 4 | 5 | 6 | def download_text_data(textID, outDir): 7 | 8 | item = get_item(textID) 9 | namesFile = [] 10 | for data in item.files: 11 | name = data['name'] 12 | if os.path.splitext(name)[1] == ".txt": 13 | namesFile.append(name) 14 | 15 | if len(namesFile) == 0: 16 | return False, [] 17 | 18 | return download(textID, files=namesFile, destdir=outDir), namesFile 19 | 20 | 21 | def get_archive_id(textURL): 22 | 23 | indexStart = textURL.find("archive.org/details/") \ 24 | + len("archive.org/details/") 25 | if indexStart < 0: 26 | return False 27 | 28 | indexEnd = textURL[indexStart:].find("/") 29 | if indexEnd < 0: 30 | return textURL[indexStart:] 31 | return textURL[indexStart:(indexStart + indexEnd)] 32 | 33 | 34 | def get_archive_org_text_data(url): 35 | 36 | ID = get_archive_id(url) 37 | tmpDir = "tmp" 38 | status, fileNames = download_text_data(ID, tmpDir) 39 | 40 | if len(fileNames) == 0: 41 | raise RuntimeError("Invalid URL") 42 | 43 | fullText = "" 44 | for fileName in fileNames: 45 | fullPath = os.path.join(tmpDir, os.path.join(ID, fileName)) 46 | with open(fullPath, 'r', encoding="ISO-8859-1") as file: 47 | data = file.read() 48 | 49 | os.remove(fullPath) 50 | fullText += data.replace('\\n', '\n') + '\n' 51 | 52 | return fullText 53 | 54 | 55 | def is_archive_org_url(url): 56 | if url.find("https://archive.org/stream/") == 0 \ 57 | or url.find("http://archive.org/stream/") == 0: 58 | url = url.replace("archive.org/stream/", "archive.org/details/") 59 | return url.find("archive.org/details/") >= 0 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | testID = "completepoetical00byro" 65 | outDIR = "tmp" 66 | fullURL = "https://archive.org/details/slaveryourtimes00tolsiala/page/n8" 67 | print(get_archive_org_text_data(fullURL)) 68 | -------------------------------------------------------------------------------- /eval/WER_src/letter_ctc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | import torch 3 | 4 | 5 | def cut_data(seq, seq_len): 6 | max_len = seq_len.max() 7 | seq = seq[:, :max_len] 8 | return seq 9 | 10 | 11 | class LetterClassifier(torch.nn.Module): 12 | 13 | def __init__(self, feature_maker, dim_encoder, n_letters, kernel_size=8, p_dropout=0.0): 14 | super().__init__() 15 | self.feature_maker = feature_maker 16 | self.feature_maker.eval() 17 | self.dropout = torch.nn.Dropout2d(p=p_dropout) 18 | self.lstm = torch.nn.LSTM(dim_encoder, dim_encoder // 2, bidirectional=True, 19 | num_layers=1, batch_first=True) 20 | self.classifier = torch.nn.Conv1d( 21 | dim_encoder, n_letters + 1, kernel_size, stride=kernel_size // 2) 22 | 23 | def forward(self, raw): 24 | with torch.no_grad(): 25 | features = self.feature_maker(raw) 26 | 27 | self.lstm.flatten_parameters() 28 | x = self.lstm(features)[0] 29 | x = x.permute(0, 2, 1) 30 | x = self.dropout(x) 31 | return self.classifier(x).permute(0, 2, 1) 32 | 33 | 34 | class CTCLetterCriterion(torch.nn.Module): 35 | 36 | def __init__(self, letter_classifier, n_letters): 37 | super().__init__() 38 | self.letter_classifier = letter_classifier 39 | self.loss = torch.nn.CTCLoss(blank=n_letters, 40 | zero_infinity=True) 41 | 42 | def forward(self, features, feature_size, label, label_size): 43 | predictions = self.letter_classifier(features) 44 | predictions = cut_data(predictions, feature_size) 45 | feature_size = torch.clamp(feature_size, max=predictions.size(1)) 46 | label = cut_data(label, label_size) 47 | assert label_size.min() > 0 48 | predictions = torch.nn.functional.log_softmax(predictions, dim=2) 49 | predictions = predictions.permute(1, 0, 2) 50 | loss = self.loss(predictions, label, feature_size, 51 | label_size).view(1, -1) 52 | 53 | assert not (torch.isinf(loss).any() or torch.isnan(loss).any()) 54 | 55 | return loss 56 | -------------------------------------------------------------------------------- /data_preparation/plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | plt.switch_backend('agg') 5 | 6 | 7 | def plot_hist(seq, nBins, pathOut, y_label="", title="", 8 | x_label="", normalized=True, y_scale=None, x_scale=None): 9 | 10 | if isinstance(seq, list): 11 | seq = np.array(seq) 12 | 13 | counts, bins = np.histogram(seq, bins=nBins) 14 | if normalized: 15 | counts = counts / np.sum(counts) 16 | plt.style.use('seaborn') 17 | plt.clf() 18 | plt.hist(bins[:-1], bins, weights=counts) 19 | plt.ylabel(y_label) 20 | plt.xlabel(x_label) 21 | plt.title(title) 22 | if y_scale is not None: 23 | plt.yscale(y_scale) 24 | if x_scale is not None: 25 | plt.xscale(x_scale) 26 | plt.tight_layout() 27 | plt.savefig(pathOut) 28 | 29 | 30 | def plot_scatter(seqs, xLabel, pathOut, x_label="", y_label="", title=""): 31 | plt.clf() 32 | for i in range(seqs.shape[0]): 33 | plt.scatter(xLabel, seqs[i]) 34 | plt.xlabel(x_label) 35 | plt.ylabel(y_label) 36 | plt.title(title) 37 | plt.tight_layout() 38 | plt.savefig(pathOut) 39 | 40 | 41 | def plot_seq(seqs, xLabel, pathOut, x_label="", y_label="", title="", 42 | xscale="linear", yscale="linear", legend=None): 43 | plt.clf() 44 | for i in range(seqs.shape[0]): 45 | plt.plot(xLabel, seqs[i]) 46 | plt.xscale(xscale) 47 | plt.yscale(yscale) 48 | plt.xlabel(x_label) 49 | plt.ylabel(y_label) 50 | plt.title(title) 51 | if legend is not None: 52 | plt.gca().legend(legend) 53 | plt.tight_layout() 54 | plt.savefig(pathOut) 55 | 56 | 57 | def plot_pie(data, pathOut, title=""): 58 | 59 | labels = list(data.keys()) 60 | sizes = [data[x] for x in labels] 61 | 62 | plt.clf() 63 | plt.style.use('classic') 64 | patches, texts, _ = plt.pie(sizes, autopct=lambda p: '{:.0f}'.format(p * sum(sizes) / 100), 65 | shadow=False, startangle=90, pctdistance=1.1) 66 | #plt.axes([0.3, 0.3, .5, .5]) 67 | plt.legend(patches, labels, loc='lower right', 68 | fontsize=8) 69 | plt.tight_layout() 70 | plt.title(title) 71 | plt.savefig(pathOut, bbox_inches='tight') 72 | -------------------------------------------------------------------------------- /data_preparation/split_librilight/README.md: -------------------------------------------------------------------------------- 1 | # Split and filter LibriLight data 2 | 3 | ## Preparation 4 | The first step is to find the speakers that appear in the Librispeech test and dev datasets: 5 | 6 | ```console 7 | python extract_test_speakers.py --librispeech_meta= 8 | ``` 9 | 10 | Next, taking the existing VAD files with per-frame silence probabilities, we find segments 11 | of speech that are not shorter than L ms and are separated by at least L ms of silence. We use 12 | L = 6 (frames) * 80 (ms / frame) = 480 ms. 13 | 14 | ```console 15 | python prepare_vads.py --vad_root= 16 | ``` 17 | 18 | The next step is to build the audio-file metadata, which contains both the book metadata and 19 | each individual file's SNR/VAD records. To do that, we run 20 | 21 | ```console 22 | python puts_json.py --librivox_dir= --librivox_processed= \ 23 | --vad_preprocessed= --snr_preprocessed= --test_speakers= 24 | ``` 25 | 26 | This command will (a) generate a file `processing_results.json` containing some diagnostic statistics, and 27 | (b) place json files with metadata alongside the audio files. 28 | 29 | After that, we can decide on the data split.
This command will produce three json files, each describing one of the 30 | selected (nested) sets of files, each set having 10x less audio time than the previous one: 31 | 32 | ```console 33 | python split.py --librivox_processed= --sampling_steps=3 --divisor=10 34 | ``` 35 | The produced files are named `split_0.json` (largest), `split_1.json` (second largest), etc. They also 36 | contain some rudimentary statistics of the selected data. 37 | 38 | Finally, you can actually copy the selected files to a specified directory ("materialize") by running: 39 | ```console 40 | python materialize_split.py --src_dir --dst_dir= --json=split_2.json 41 | ``` 42 | If you want to exclude other splits (e.g. make the `medium` split directory not contain files from the `small` one), you can use the `--minus` parameter: 43 | ```console 44 | python materialize_split.py --src_dir --dst_dir= --json=split_1.json --minus=split_2.json 45 | ``` 46 | -------------------------------------------------------------------------------- /eval/ABX_src/dtw.pyx: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | import numpy as np 4 | cimport numpy as np 5 | cimport cython 6 | from cpython cimport bool 7 | ctypedef np.float32_t CTYPE_t # cost type 8 | ctypedef np.intp_t IND_t # array index type 9 | CTYPE = np.float32 # cost type 10 | 11 | 12 | 13 | def dtw_batch(x,y, sx, sy, dist_mat, ignore_diag=False, symetric=False): 14 | 15 | Nx = dist_mat.shape[0] 16 | Ny = dist_mat.shape[1] 17 | 18 | out = torch.zeros((Nx, Ny)) 19 | 20 | for i in range(Nx): 21 | start_index = i if symetric else 0 22 | i_sx = sx[i] 23 | for j in range(start_index, Ny): 24 | 25 | j_sy = sy[j] 26 | if ignore_diag and i == j: 27 | continue 28 | distance = _dtw(i_sx, j_sy, dist_mat[i,j,:i_sx,:j_sy],True) 29 | out[i][j] = distance 30 | if symetric and i != j: 31 | out[j][i] = out[i][j] 32 | 33 | return out 34 | 35 | 36 | 37 | cpdef _dtw(IND_t N, IND_t M, CTYPE_t[:,:] dist_array, bool normalized): 38 | cdef IND_t i, j 39 | cdef CTYPE_t[:,:] cost = np.empty((N, M), dtype=CTYPE) 40 | cdef CTYPE_t final_cost, c_diag, c_left, c_up 41 | # initialization 42 | cost[0,0] = dist_array[0,0] 43 | for i in range(1,N): 44 | cost[i,0] = dist_array[i,0] + cost[i-1,0] 45 | for j in range(1,M): 46 | cost[0,j] = dist_array[0,j] + cost[0,j-1] 47 | # the dynamic programming loop 48 | for i in range(1,N): 49 | for j in range(1,M): 50 | cost[i,j] = dist_array[i,j] + min(cost[i-1,j], cost[i-1,j-1], cost[i,j-1]) 51 | 52 | final_cost = cost[N-1, M-1] 53 | if normalized: 54 | path_len = 1 55 | i = N-1 56 | j = M-1 57 | while i > 0 and j > 0: 58 | c_up = cost[i - 1, j] 59 | c_left = cost[i, j-1] 60 | c_diag = cost[i-1, j-1] 61 | if c_diag <= c_left and c_diag <= c_up: 62 | i -= 1 63 | j -= 1 64 | elif c_left <= c_up: 65 | j -= 1 66 | else: 67 | i -= 1 68 | path_len += 1 69 | if i == 0: 70 | path_len += j 71 | if j == 0: 72 | path_len += i 73 | final_cost /= path_len 74 | return final_cost 75 | -------------------------------------------------------------------------------- /data_preparation/rebuild_limited_train/select_1h.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | from utils import get_histogram, get_speakers, traverse_tree, full_records, print_stats, materialize 3 | import argparse 4 | import pathlib 5 | import random 6 | 7 | 8 | def do_split(records, seconds_per_speaker): 9 | speakers = list(set([r.speaker.id for r in records])) 10 | random.shuffle(speakers) 11 | records_filtered = [] 12 | 13 | for speaker in speakers: 14 | time_taken = 0 15 | speaker_records = [r for r in records if r.speaker.id == speaker] 16 | 17 | for r in speaker_records: 18 | if time_taken > seconds_per_speaker * 16000: 19 | break 20 | time_taken += r.length 21 | 22 | records_filtered.append(r) 23 | 24 | return records_filtered 25 | 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--max_seconds_per_speaker', type=int, default=150) 30 | parser.add_argument('--target_dir', type=str) 31 | parser.add_argument('--seed', type=int, default=7) 32 | parser.add_argument('--root_10h', type=str) 33 | parser.add_argument('--meta_path', type=str) 34 | 35 | args = parser.parse_args() 36 | 37 | return args 38 | 39 | 40 | if __name__ == '__main__': 41 | args = get_args() 42 | random.seed(args.seed) 43 | 44 | speakers = get_speakers(pathlib.Path(args.meta_path) / 'SPEAKERS.TXT') 45 | 46 | print(f'Total {len(speakers)} speakers') 47 | fname2length = traverse_tree(args.root_10h) 48 | all_records = full_records(speakers, fname2length) 49 | print(f'Got {len(all_records)} records') 50 | 51 | for gender in ['f', 'm']: 52 | for t, tag in zip(['train-clean-100', 'train-other-500'], ['clean', 'other']): 53 | records = list(all_records) 54 | print(f'Selecting from {t}, gender {gender}, tag {tag}') 55 | 56 | records = filter(lambda x: x.speaker.gender.lower() 57 | == gender and x.speaker.subset == t, records) 58 | records = list(records) 59 | 60 | records_filtered = do_split(records, args.max_seconds_per_speaker) 61 | print_stats(records_filtered) 62 | 63 | if args.target_dir: 64 | materialize(records_filtered, args.target_dir, 65 | tag=tag, move=True) 66 | -------------------------------------------------------------------------------- /data_preparation/split_librilight/materialize_split.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import argparse 3 | import json 4 | import pathlib 5 | import shutil 6 | import os 7 | import multiprocessing 8 | 9 | 10 | def _apply(task): 11 | file, args, action = task 12 | 13 | src = pathlib.Path(args.src_dir) 14 | dst = pathlib.Path(args.dst_dir) 15 | 16 | file = pathlib.Path(file) 17 | with open(file, 'r') as f: 18 | meta = json.loads(f.read()) 19 | speaker = meta['speaker'] 20 | 21 | dst_dir = dst / speaker / file.parent.name 22 | dst_dir.mkdir(exist_ok=True, parents=True) 23 | 24 | # move/copy json file 25 | src_file = src / file.parent.name / file.name 26 | dst_file = dst_dir / file.name 27 | action(src_file, dst_file) 28 | 29 | # move/copy flac file 30 | src_file = src / file.parent.name / (file.stem + '.flac') 31 | dst_file = dst_dir / (file.stem + '.flac') 32 | action(src_file, dst_file) 33 | 34 | 35 | def get_args(): 36 | parser = argparse.ArgumentParser( 37 | description="A script to copy prepared data splits to releasable folders") 38 | parser.add_argument('--src_dir', type=str, required=True) 39 | parser.add_argument('--dst_dir', type=str, required=True) 40 | parser.add_argument('--json', type=str, required=True) 41 | parser.add_argument('--minus', type=str, action='append', default=[]) 42 | parser.add_argument('--n_workers', type=int, default=16) 43 | parser.add_argument('--mode', type=str, 44 | choices=['copy', 'print'], default='print') 45 | 46 | args = parser.parse_args() 47 | 48 | assert args.json and args.dst_dir 49 | 50 | return args 51 | 52 | 53 | # lambda-functions are un-pickable 54 | def _print(src, dst): 55 | print(src, '->', dst) 56 | 57 | 58 | if __name__ == '__main__': 59 | args = get_args() 60 | 61 | with open(args.json, 'r') as f: 62 | files = json.loads(f.read())['files'] 63 | 64 | files_minus = [] 65 | for fname in args.minus: 66 | with open(fname, 'r') as f: 67 | files_minus.extend(json.loads(f.read())['files']) 68 | 69 | files_minus = set(files_minus) 70 | 71 | if args.mode == 'copy': 72 | action = shutil.copy 73 | elif args.mode == 'print': 74 | action = _print 75 | 76 | tasks = [(file, args, action) for file in files if file not in files_minus] 77 | 78 | with multiprocessing.Pool(processes=args.n_workers) as pool: 79 | pool.map(_apply, tasks) 80 | -------------------------------------------------------------------------------- /data_preparation/metadata_completion/GenreScrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import json 3 | import os 4 | from html.parser import HTMLParser 5 | import requests 6 | import progressbar 7 | 8 | 9 | class GenreScapper(HTMLParser): 10 | 11 | def __init__(self): 12 | super(GenreScapper, self).__init__() 13 | self.genre = None 14 | self.inGoodTag = False 15 | self.takeData = False 16 | self.getTitle = False 17 | 18 | def handle_starttag(self, tag, attr): 19 | if tag == "p": 20 | if ('class', 'book-page-genre') in attr: 21 | self.inGoodTag = True 22 | if tag == "span" and self.inGoodTag: 23 | self.getTitle = True 24 | 25 | def handle_endtag(self, tag): 26 | if tag == "p": 27 | self.inGoodTag = False 28 | self.getTitle = False 29 | 30 | def handle_data(self, data): 31 | if self.getTitle: 32 | if data.find("Genre") >= 0: 33 | self.takeData = True 34 | self.getTitle = False 35 | elif self.takeData and self.genre is None: 36 | self.genre = data 37 | 38 | def getGenre(self): 39 | if self.genre is None: 40 | return None 41 | allGenres = self.genre.replace('\\', '').split(',') 42 | output = [] 43 | for item in allGenres: 44 | if len(item) == 0 or item == ' ': 45 | continue 46 | output.append(item.lstrip().rstrip()) 47 | 48 | if len(output) == 0: 49 | return None 50 | return output 51 | 52 | 53 | def getGenreFromMetadata(metadata): 54 | 55 | urlLibriVoxPage = metadata["url_librivox"] 56 | parser = GenreScapper() 57 | req = requests.get(urlLibriVoxPage) 58 | parser.feed(str(req._content)) 59 | return parser.getGenre() 60 | 61 | 62 | def gather_all_genres(pathDIR, metadataList): 63 | out = [] 64 | print("Retrieving all books' genres...") 65 | bar = progressbar.ProgressBar(maxval=len(metadataList)) 66 | bar.start() 67 | for index, fileName in enumerate(metadataList): 68 | 69 | bar.update(index) 70 | pathMetadata = os.path.join(pathDIR, fileName) 71 | with open(pathMetadata, 'rb') as file: 72 | metadata = json.load(file) 73 | 74 | try: 75 | genre = getGenreFromMetadata(metadata) 76 | except KeyboardInterrupt: 77 | break 78 | except: 79 | genre = None 80 | 81 | out.append((fileName, genre)) 82 | 83 | bar.finish() 84 | return out 85 | -------------------------------------------------------------------------------- /data_preparation/build_all_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import argparse 3 | from pathlib import Path 4 | import metadata_completion.utilities as ut 5 | import plot 6 | 7 | if __name__ == "__main__": 8 | 9 | parser = argparse.ArgumentParser(description="Build the statistics on LibriBig") 10 | parser.add_argument('path_data', type=str, 11 | help="Path to the directory containing the data") 12 | parser.add_argument('out_dir', type=str, 13 | help="Path to the output directory") 14 | parser.add_argument('--ignore_cache', action='store_true') 15 | args = parser.parse_args() 16 | 17 | # Build the output directory 18 | args.out_dir = Path(args.out_dir) 19 | Path.mkdir(args.out_dir, exist_ok=True) 20 | 21 | # Build the cache directory 22 | path_cache = args.out_dir / ".cache" 23 | Path.mkdir(path_cache, exist_ok=True) 24 | 25 | # Get the list of all metadata 26 | print("Gathering the list of metadata") 27 | path_cache_metadata = path_cache / "metadata.pkl" 28 | list_metadata = ut.load_cache(path_cache_metadata, 29 | ut.get_all_metadata, 30 | args=(args.path_data, ".json"), 31 | ignore_cache=args.ignore_cache) 32 | print(f"{len(list_metadata)} files found") 33 | 34 | # Get the genre statistics 35 | print("Building the genre statistics") 36 | path_genre_stats = path_cache / "meta_genre_stats.json" 37 | genre_data = ut.load_cache(path_genre_stats, 38 | ut.get_hour_tag_repartition, 39 | args=(list_metadata, 40 | "meta_genre", ".flac"), 41 | ignore_cache=args.ignore_cache) 42 | 43 | path_tags_hist = args.out_dir / "meta_genres.png" 44 | plot.plot_pie(genre_data, str(path_tags_hist), 45 | title="Genre's categories (in hours)") 46 | 47 | print("done.") 48 | 49 | # Get the speaker statistics 50 | print("Building the speaker statistics") 51 | path_speaker_cache = path_cache / "speaker_stats.json" 52 | speaker_data = ut.load_cache(path_speaker_cache, 53 | ut.get_speaker_hours_data, 54 | args=(list_metadata, 55 | ".flac")) 56 | 57 | speaker_hours = [x for _, x in speaker_data.items()] 58 | path_speaker_hist = args.out_dir / "speaker_data.png" 59 | n_bins = 100 60 | plot.plot_hist(speaker_hours, n_bins, str(path_speaker_hist), 61 | title="Time spoken per speaker", 62 | y_label="Number of speakers", normalized=False, 63 | y_scale='log', x_label="Time spoken in hours") 64 | print("done.") 65 | -------------------------------------------------------------------------------- /data_preparation/split_librilight/prepare_vads_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
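# A minimal usage sketch (not part of the original tests) of the function
# exercised below, split_vad from prepare_vads.py; the threshold values mirror
# the ones used in the tests, and indices are VAD frames:
#
#   from prepare_vads import split_vad
#   probs = [1.0] * 10 + [0.0] * 20          # 10 silent frames, then speech
#   split_vad(silence_probs=probs, p_silence_threshold=0.999, len_threshold=6)
#   # -> [(10, 30)]: speech starts at frame 10 and runs to the end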
5 | 6 | import unittest 7 | from prepare_vads import split_vad 8 | 9 | 10 | class TestSplit(unittest.TestCase): 11 | def test_all_silence(self): 12 | p_silence = [1.0] * 100 13 | segments = split_vad(silence_probs=p_silence, 14 | p_silence_threshold=0.999, len_threshold=6) 15 | 16 | self.assertFalse(segments) 17 | 18 | def test_all_speech(self): 19 | p_silence = [0.0] * 100 20 | segments = split_vad(silence_probs=p_silence, 21 | p_silence_threshold=0.999, len_threshold=6) 22 | 23 | self.assertEqual(len(segments), 1) 24 | self.assertEqual(segments[0], (0, 100)) 25 | 26 | def test_half_speech(self): 27 | p_silence = [1.0] * 50 + [0.0] * 50 28 | segments = split_vad(silence_probs=p_silence, 29 | p_silence_threshold=0.999, len_threshold=6) 30 | 31 | self.assertEqual(len(segments), 1) 32 | self.assertEqual(segments[0], (50, 100)) 33 | 34 | def test_short_speech(self): 35 | """Speech segment shorter than len_threshold""" 36 | p_silence = [1.0] * 50 + [0.0] * 5 + [1.0] * 50 37 | segments = split_vad(silence_probs=p_silence, 38 | p_silence_threshold=0.999, len_threshold=6) 39 | 40 | self.assertFalse(segments) 41 | 42 | def test_short_silence(self): 43 | """Silence segment shorter than len_threshold""" 44 | p_silence = [0.0] * 50 + [1.0] * 5 + [0.0] * 50 45 | segments = split_vad(silence_probs=p_silence, 46 | p_silence_threshold=0.999, len_threshold=6) 47 | 48 | self.assertEqual(len(segments), 1) 49 | self.assertEqual(segments[0], (0, 105)) 50 | 51 | def test_few_segments(self): 52 | # positions 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20..29 30..39 53 | p_silence = [0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 54 | 0, 0, 1, 1, 1, 1, 1, 1] + [0] * 10 + [1] * 10 55 | # 1st: | . short silence | <- taken 56 | # 2nd: ^ those 10, taken 57 | self.assertEqual(len(p_silence), 40) 58 | 59 | segments = split_vad(silence_probs=p_silence, 60 | p_silence_threshold=0.999, len_threshold=6) 61 | 62 | self.assertEqual(len(segments), 2) 63 | self.assertEqual(segments[0], (0, 14)) 64 | self.assertEqual(segments[1], (20, 30)) 65 | 66 | def test_final_silence(self): 67 | p_silence = [1.0] * 50 68 | p_silence[40] = 0 69 | segments = split_vad(silence_probs=p_silence, 70 | p_silence_threshold=0.999, len_threshold=6) 71 | 72 | self.assertFalse(segments) 73 | 74 | 75 | if __name__ == '__main__': 76 | unittest.main() 77 | -------------------------------------------------------------------------------- /data_preparation/rebuild_limited_train/sample_10h.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from utils import get_histogram, get_speakers, traverse_tree, full_records, print_stats, materialize 3 | import argparse 4 | import pathlib 5 | import random 6 | 7 | 8 | def do_split_10h(records, speakers, max_seconds_per_speaker, min_seconds_per_speaker, total_seconds): 9 | """ 10 | Greedily selecting speakers, provided we don't go over budget 11 | """ 12 | scaler = 1.0 / 16000 # sampling rate 13 | speaker2time = get_histogram(records, lambda_key=lambda r: r.speaker.id, 14 | lambda_value=lambda r: r.length * scaler) 15 | 16 | speakers = set([r.speaker.id for r in records]) 17 | speakers = sorted(speakers) 18 | random.shuffle(speakers) 19 | 20 | time_taken = 0.0 21 | speakers_taken = [] 22 | 23 | for speaker in speakers: 24 | current_speaker_time = speaker2time[speaker] 25 | if min_seconds_per_speaker <= current_speaker_time <= max_seconds_per_speaker and current_speaker_time < total_seconds - time_taken: 26 | speakers_taken.append(speaker) 27 | time_taken += current_speaker_time 28 | 29 | speakers_taken = set(speakers_taken) 30 | 31 | records_filtered = [r for r in records if r.speaker.id in speakers_taken] 32 | return records_filtered 33 | 34 | 35 | def get_args(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--max_minutes_per_speaker', type=int, default=30) 38 | parser.add_argument('--min_minutes_per_speaker', type=int, default=20) 39 | parser.add_argument('--total_minutes', type=int, default=150) 40 | 41 | parser.add_argument('--target_dir', type=str) 42 | 43 | parser.add_argument('--seed', type=int, default=179) 44 | 45 | parser.add_argument('--root_clean', type=str) 46 | parser.add_argument('--root_other', type=str) 47 | parser.add_argument('--meta_path', type=str) 48 | 49 | args = parser.parse_args() 50 | 51 | if args.max_minutes_per_speaker <= 0: 52 | args.max_minutes_per_speaker = float('inf') 53 | return args 54 | 55 | 56 | if __name__ == '__main__': 57 | args = get_args() 58 | random.seed(args.seed) 59 | 60 | speakers = get_speakers(pathlib.Path(args.meta_path) / 'SPEAKERS.TXT') 61 | print('Found speakers', len(speakers)) 62 | 63 | for gender in ['m', 'f']: 64 | for root, tag in zip([args.root_clean, args.root_other], ['clean', 'other']): 65 | print(f'Selecting from {root}, gender {gender}, tag {tag}') 66 | 67 | fname2length = traverse_tree(root) 68 | records = full_records(speakers, fname2length) 69 | 70 | records = filter(lambda x: x.speaker.gender.lower() 71 | == gender, records) 72 | records = list(records) 73 | 74 | records_filtered = do_split_10h( 75 | records, speakers, args.max_minutes_per_speaker * 60, args.min_minutes_per_speaker * 60, args.total_minutes * 60) 76 | print_stats(records_filtered) 77 | 78 | if args.target_dir: 79 | materialize(records_filtered, args.target_dir, tag=tag) 80 | -------------------------------------------------------------------------------- /eval/WER_src/wl_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import math 4 | import os 5 | import struct 6 | import sys 7 | 8 | import numpy as np 9 | from wav2letter.common import Dictionary, create_word_dict, load_words, tkn_to_idx 10 | from wav2letter.decoder import ( 11 | CriterionType, 12 | DecoderOptions, 13 | KenLM, 14 | SmearingMode, 15 | Trie, 16 | WordLMDecoder, 17 | ) 18 | 19 | 20 | class WlDecoder: 21 | """ 22 | Wav2Letter-based decoder. 
Follows the official examples for the python bindings, 23 | see https://github.com/facebookresearch/wav2letter/blob/master/bindings/python/examples/decoder_example.py 24 | """ 25 | 26 | def __init__(self, 27 | lm_weight=2.0, 28 | lexicon_path="WER_data/lexicon.txt", 29 | token_path="WER_data/letters.lst", 30 | lm_path="WER_data/4-gram.bin"): 31 | lexicon = load_words(lexicon_path) 32 | word_dict = create_word_dict(lexicon) 33 | 34 | self.token_dict = Dictionary(token_path) 35 | self.lm = KenLM(lm_path, word_dict) 36 | 37 | self.sil_idx = self.token_dict.get_index("|") 38 | self.unk_idx = word_dict.get_index("") 39 | self.token_dict.add_entry("#") 40 | self.blank_idx = self.token_dict.get_index('#') 41 | 42 | self.trie = Trie(self.token_dict.index_size(), self.sil_idx) 43 | start_state = self.lm.start(start_with_nothing=False) 44 | 45 | for word, spellings in lexicon.items(): 46 | usr_idx = word_dict.get_index(word) 47 | _, score = self.lm.score(start_state, usr_idx) 48 | for spelling in spellings: 49 | # max_reps should be 1; using 0 here to match DecoderTest bug 50 | spelling_idxs = tkn_to_idx( 51 | spelling, self.token_dict, max_reps=0) 52 | self.trie.insert(spelling_idxs, usr_idx, score) 53 | 54 | self.trie.smear(SmearingMode.MAX) 55 | self.opts = DecoderOptions( 56 | beam_size=2500, beam_threshold=100.0, lm_weight=lm_weight, 57 | word_score=2.0, unk_score=-math.inf, log_add=False, sil_weight=-1, criterion_type=CriterionType.CTC) 58 | 59 | def collapse(self, prediction): 60 | result = [] 61 | 62 | for p in prediction: 63 | if result and p == result[-1]: 64 | continue 65 | result.append(p) 66 | 67 | blank = '#' 68 | space = '|' 69 | 70 | result = [x for x in result if x != blank] 71 | result = [(x if x != space else ' ') for x in result if x != blank] 72 | return result 73 | 74 | def predictions(self, emissions): 75 | t, n = emissions.size() 76 | 77 | emissions = emissions.cpu().numpy() 78 | decoder = WordLMDecoder( 79 | self.opts, self.trie, self.lm, self.sil_idx, self.blank_idx, self.unk_idx, []) 80 | results = decoder.decode(emissions.ctypes.data, t, n) 81 | 82 | prediction = [self.token_dict.get_entry( 83 | x) for x in results[0].tokens if x >= 0] 84 | prediction = self.collapse(prediction) 85 | 86 | return prediction 87 | -------------------------------------------------------------------------------- /baselines/TDS/README.md: -------------------------------------------------------------------------------- 1 | # TDS Baselines 2 | 3 | This directory provides codes to reproduce TDS baselines in the paper. You should use them together with [wav2letter](https://github.com/facebookresearch/wav2letter). 4 | 5 | ## Data 6 | - Two lists of supervised training data with 10 hours and 1 hour. 7 | - Two sets of tokens for phonemes and characters. 8 | - Two lexicons to map words to phonemes and characters: 9 | - [Letter lexicon](https://dl.fbaipublicfiles.com/librilight/data/tds_data/letter.lex) 10 | - [Phone lexicon](https://dl.fbaipublicfiles.com/librilight/data/tds_data/phone.lex) 11 | 12 | ## Experiments 13 | ### Model Architectures 14 | - A TDS model with 20 million parameters is provided for training on the limited supervised data. 15 | - A TDS model with 37 million parameters is provided for training on both supervised data and pseudo labels. 16 | 17 | ### Configurations 18 | #### Acoustic model 19 | Acoustic model training config files for each set-up. Note that the 20-millioin-parameter TDS models are trained on 8 GPUs each, while the 37-millioin-parameter ones are on 64 GPUs. 
See [wav2letter tutorials](https://github.com/facebookresearch/wav2letter/blob/master/docs/train.md#distributed) about how to run distributed training. 20 | 21 | Sample command: 22 | ```sh 23 | /wav2letter/build/Train \ 24 | --flagsfile=/libri-light/TDS/experiments/config/acoustic_model/10h+pseudo-label_letter_37M_TDS.cfg \ 25 | --enable_distributed=true 26 | ``` 27 | 28 | #### Decoding 29 | Optimal decoding parameters of each model. You can use wav2letter decoder to 30 | - Get optimal WER 31 | - Generate pseudo-labels. 32 | 33 | We use the official Librispeech 4-gram language model for all decoding experiments. The model can be downloaded [here](http://www.openslr.org/11/). 34 | 35 | Sample command: 36 | ```sh 37 | /wav2letter/build/Decode \ 38 | --flagsfile=/libri-light/TDS/experiments/config/decoding/10h+pseudo-label_letter_37M_TDS.cfg \ 39 | --sclite= 40 | ``` 41 | 42 | #### Pretrained Models 43 | | Supervised data | LibriVox | Target unit | Architecture | Model | 44 | | - | - | - | - | - | 45 | | 10 hours | Y | letter | 37M | [10h+pseudo-label_letter_37M_TDS.bin](https://dl.fbaipublicfiles.com/librilight/TDS_pseudo_label_checkpoints/10h%2Bpseudo-label_letter_37M_TDS.bin) | 46 | | 10 hours | Y | phonemes | 37M | [10h+pseudo-label_phone_37M_TDS.bin](https://dl.fbaipublicfiles.com/librilight/TDS_pseudo_label_checkpoints/10h%2Bpseudo-label_phone_37M_TDS.bin) | 47 | | 10 hours | N | letter | 20M | [10h_letter_20M_TDS.bin](https://dl.fbaipublicfiles.com/librilight/TDS_pseudo_label_checkpoints/10h_letter_20M_TDS.bin) | 48 | | 10 hours | N | phonemes | 20M | [10h_phone_20M_TDS.bin](https://dl.fbaipublicfiles.com/librilight/TDS_pseudo_label_checkpoints/10h_phone_20M_TDS.bin) | 49 | | 1 hour | Y | letter | 37M | [1h+pseudo-label_letter_37M_TDS.bin](https://dl.fbaipublicfiles.com/librilight/TDS_pseudo_label_checkpoints/1h%2Bpseudo-label_letter_37M_TDS.bin) | 50 | | 1 hour | Y | phonemes | 37M | [1h+pseudo-label_phone_37M_TDS.bin](https://dl.fbaipublicfiles.com/librilight/TDS_pseudo_label_checkpoints/1h%2Bpseudo-label_phone_37M_TDS.bin) | 51 | | 1 hour | N | letter | 20M | [1h_letter_20M_TDS.bin](https://dl.fbaipublicfiles.com/librilight/TDS_pseudo_label_checkpoints/1h_letter_20M_TDS.bin) | 52 | | 1 hour | N | phonemes | 20M | [1h_phone_20M_TDS.bin](https://dl.fbaipublicfiles.com/librilight/TDS_pseudo_label_checkpoints/1h_phone_20M_TDS.bin) | 53 | -------------------------------------------------------------------------------- /data_preparation/metadata_completion/genre_folding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | UNIQUE_GENRE_FOLDING = {'Religion': ['Religious Fiction', 'Christianity - Biographies', 'Christian Fiction', 'Weymouth New Testament', 'Other religions', 'World English Bible', 'Christianity - Commentary', 'Atheism & Agnosticism', 'Bibles', 'Religion', 'Christianity - Other'], 3 | 'Poetry': ['Poetry', 'Ballads', 'Antiquity', 'Lyric', 'Anthologies', 'Multi-version (Weekly and Fortnightly poetry)', 'Elegies & Odes', 'Free Verse', 'Sonnets'], 4 | 'Theater': ['Tragedy', 'Plays', 'Comedy', 'Performing Arts'], 5 | 'Ancient': ['Classics (Greek & Latin Antiquity)', 'Medieval', 'Ancient'], 6 | 'Fiction': ['General Fiction', 'Gothic Fiction', 'Sagas', 'Fantasy Fiction', 'Legends & Fairy Tales', 'Suspense', 'Espionage', 'Political & Thrillers', 'Myths', 'Fantastic Fiction', 'Fictional Biographies & Memoirs', 'Literary Collections', 'True Crime', 'Short Stories', 'Epics', 'Culture & Heritage Fiction', 'Nature & Animal Fiction', 'War & Military Fiction', 'Sports Fiction', 'Romance', 'Action & Adventure', "Children's Fiction", 'Erotica', 'Drama', 'Historical Fiction', 'Westerns', 'Detective Fiction', 'Literary Fiction', 'Horror & Supernatural Fiction', 'Crime & Mystery Fiction', 'Science Fiction', 'Narratives', 'Travel Fiction', 'Epistolary Fiction', 'Action & Adventure Fiction', 'Nautical & Marine Fiction'], 7 | 'Non fiction': ['Biography & Autobiography', 'Middle Ages/Middle History', 'Letters', 'Historical', 'Exploration', "Children's Non-fiction", 'Short non-fiction', 'History', 'Memoirs', 'Non fiction'], 8 | 'Humor': ['Humor', 'Satire', 'Humorous Fiction'], 9 | 'Essay': ['Travel & Geography', 'Essays & Short Works', 'Law', 'Family Life', 'Psychology', 'School', 'War & Military', 'Transportation', 'Games', 'Art', 'Family & Relationships', 'Essays', 'Business & Economics', 'Education', 'Arts', 'Animals & Nature', 'Philosophy', 'Literary Criticism', 'Writing & Linguistics'], 10 | 'Craft': ['Cooking', 'Crafts & Hobbies', 'Self-Help', 'Music', 'Gardening', 'Sports & Recreation', 'House & Home', 'Health & Fitness', 'Design & Architecture'], 11 | 'Dramatic Readings': ['Dramatic Readings'], 12 | 'Science': ['Mathematics', 'Science', 'Medical', 'Astronomy', 'Earth Sciences', 'Physics & Mechanics', 'Political Science', 'Life Sciences', 'Chemistry', 'Language learning', 'Social Science (Culture & Anthropology)', 'Technology & Engineering', 'Nature', 'Animals'], 13 | 'Undefined': ['Modern (20th C)', 'Reference', 'Douay-Rheims Version', 'Short works', 'Modern', 'Single Author Collections', 'null', 'General', 'Early Modern', 'Modern (19th C)', 'Published 1800 -1900', '*Non-fiction', 'Contemporary', 'Family', 'Single author', 'Published before 1800', 'Published 1900 onward', "Young's Literal Translation", 'King James Version']} 14 | 15 | GENDER_ORDERING = ['Dramatic Readings', 'Poetry', 'Religion', 'Theater', 'Fiction', 16 | 'Science', 'Essay', 'Humor', 'Ancient', 'Non fiction', 'Craft', 'Undefined'] 17 | 18 | SUPER_GENDER_FOLDING = {'Science, Craft & Essay': ['Essay', 'Science', 'Craft'], 19 | 'Literature': ['Fiction', 'Non fiction', 'Humor']} 20 | 21 | SUPER_GENDER_ORDERING = ['Dramatic Readings', 'Poetry', 'Religion', 22 | 'Literature', 'Science, Craft & Essay', 'Theater', 'Ancient', 'Undefined'] 23 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming 
environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Libri-Light: A Benchmark for ASR with Limited or No Supervision 2 | You can track papers that use Libri-Light and their relative performance on Papers With Code: 3 | [[test-clean]](https://paperswithcode.com/sota/speech-recognition-on-libri-light-test-clean) 4 | [[test-other]](https://paperswithcode.com/sota/speech-recognition-on-libri-light-test-clean) 5 | 6 | ## Description 7 | 8 | This repository contains code and models associated with the Libri-Light dataset, which can be [downloaded and prepared here](./data_preparation/README.md). More information about dataset creation and baselines can be found in this [arXiv Paper](https://arxiv.org/abs/1912.07875). Contained here is code for data preparation, pretrained models, and evaluation resources: 9 | 10 | 11 | data_preparation/ # code to download the data; VAD and SNR code; json generation; stats; audio segmentation 12 | eval/ # ABX, PER, WER (evaluation metrics on LibriSpeech dev-clean, dev-other, test-clean, test-other) 13 | baselines/ # code, pretrained wav2letter models, baselines, and examples 14 | 15 | To get started, first clone the repository: 16 | 17 | git clone https://github.com/facebookresearch/libri-light 18 | 19 | The environment is easiest to set up with Anaconda. Requirements can be installed by running: 20 | 21 | conda env create -f environment.yml && conda activate libri-light 22 | 23 | If you don't have `conda` you can get it [here](https://docs.anaconda.com/anaconda/install/). 24 | 25 | ## Goals and structure 26 | 27 | Libri-Light offers 60+ k hours of unlabelled speech, a small training set for limited supervision (10h, 1h or 10 minutes of labelled speech), and a common set of metrics to evaluated three settings: 28 | 29 | 1. the unsupervised/zero-resource setting. Here, models are trained only on unlabelleds speech and attempt to construct 'good' speech representations. They are evaluated with the ABX metric. 30 | 2. the semi-supervised setting. Here, models are trained with the limited supervision dataset and exploit the unlabelled in various ways (as pretraining, to get pseudo-labels, etc). The models are evaluated using either PER or WER. 31 | 3. the distant supervision setting. Here, models can use additional unaligned text to build a decoder. These models are evaluated using WER. 32 | 33 | 34 | ## Documentation 35 | 36 | Documentation for downloading Libri-Light or preparing the source files from scratch can be found in [`data_preparation`](./data_preparation/README.md). 37 | 38 | The [`eval`](./eval/README.md) directory contains ABX, PER and WER evaluations on pretrained CPC models. 39 | 40 | The [`baselines`](./baselines/README.md) directory contains pretrained [wav2letter](https://github.com/facebookresearch/wav2letter/) baseline models and information about reproduction. 41 | 42 | 43 | ## Citing 44 | ``` 45 | @INPROCEEDINGS{librilight, 46 | author={J. {Kahn} and M. {Rivière} and W. {Zheng} and E. {Kharitonov} and Q. {Xu} and P. E. 
{Mazaré} and J. {Karadayi} and V. {Liptchinsky} and R. {Collobert} and C. {Fuegen} and T. {Likhomanenko} and G. {Synnaeve} and A. {Joulin} and A. {Mohamed} and E. {Dupoux}}, 47 | booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 48 | title={Libri-Light: A Benchmark for ASR with Limited or No Supervision}, 49 | year={2020}, 50 | pages={7669-7673}, 51 | note = {\url{https://github.com/facebookresearch/libri-light}}, 52 | } 53 | ``` 54 | 55 | ## License 56 | 57 | The Libri-light code is released under the [MIT license](https://opensource.org/licenses/MIT). See LICENSE for additional details. 58 | -------------------------------------------------------------------------------- /data_preparation/text_retrieval/main_lesson.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from html.parser import HTMLParser 3 | import requests 4 | 5 | 6 | def get_tag_value_in_url(url, tag): 7 | baseUrl = "display.php?" 8 | argBegin = url.find(baseUrl) 9 | if argBegin < 0: 10 | raise RuntimeError("Invalid url") 11 | 12 | argBegin += len(baseUrl) 13 | args = url[argBegin:].split('&') 14 | 15 | for item in args: 16 | if item.find(f'{tag}=') == 0: 17 | return item.split(f'{tag}=')[1] 18 | 19 | raise RuntimeError("{tag} not found") 20 | 21 | 22 | def get_full_url(author, book, chapter): 23 | return f"http://www.gatewaytotheclassics.com/browse/display.php?author={author}&book={book}&story={chapter}" 24 | 25 | 26 | class ToCParser(HTMLParser): 27 | 28 | def __init__(self): 29 | 30 | self.chaptersList = [] 31 | self.inLinkBlock = False 32 | super(ToCParser, self).__init__() 33 | 34 | def handle_starttag(self, tag, attrs): 35 | if tag == "div" and ("class", "lhlink") in attrs: 36 | self.inLinkBlock = True 37 | elif tag == 'a' and self.inLinkBlock: 38 | for name, value in attrs: 39 | if name == "href": 40 | self.chaptersList.append( 41 | get_tag_value_in_url(value, 'story')) 42 | 43 | def handle_endtag(self, tag): 44 | if tag == "div": 45 | self.inLinkBlock = False 46 | 47 | 48 | class ChapterParser(HTMLParser): 49 | 50 | def __init__(self): 51 | 52 | self.text = "" 53 | self.title = None 54 | self.getData = False 55 | self.getTitle = False 56 | super(ChapterParser, self).__init__() 57 | 58 | def handle_starttag(self, tag, attrs): 59 | if tag == "h1" and ("align", "CENTER") in attrs: 60 | self.getTitle = True 61 | 62 | if tag == "table": 63 | self.getData = False 64 | 65 | def handle_endtag(self, tag): 66 | if tag == "h1": 67 | self.getTitle = False 68 | self.getData = True 69 | if tag == "table" and self.title is not None: 70 | self.getData = True 71 | 72 | def handle_data(self, data): 73 | if self.getTitle: 74 | self.title = data.replace("\\", "") 75 | elif self.getData: 76 | self.text += data.replace('\\n', '\n').replace("\\", "") 77 | 78 | def get_full_text(self): 79 | 80 | if self.title is None: 81 | raise RuntimeError("No title found") 82 | 83 | return self.title + '\n' + self.text 84 | 85 | 86 | def get_all_text_from_main_lesson(url): 87 | 88 | book = get_tag_value_in_url(url, 'book') 89 | author = get_tag_value_in_url(url, 'author') 90 | 91 | tocUrl = get_full_url(author, book, '_contents') 92 | 93 | parserToC = ToCParser() 94 | req = requests.get(tocUrl) 95 | parserToC.feed(str(req._content)) 96 | 97 | fullText = "" 98 | 99 | for chapterName in parserToC.chaptersList: 100 | 101 | txtUrl = get_full_url(author, book, chapterName) 102 | 
parserChapter = ChapterParser() 103 | req = requests.get(txtUrl) 104 | parserChapter.feed(str(req._content)) 105 | fullText += parserChapter.get_full_text() 106 | 107 | return fullText 108 | 109 | 110 | def is_main_lesson_url(url): 111 | return url.find("mainlesson.com") >= 0 112 | 113 | 114 | if __name__ == "__main__": 115 | url = "http://www.gatewaytotheclassics.com/browse/display.php?author=marshall&book=beowulf&story=_contents" 116 | print(get_all_text_from_main_lesson(url)) 117 | -------------------------------------------------------------------------------- /data_preparation/metadata_completion/ReaderScapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from html.parser import HTMLParser 3 | import requests 4 | import json 5 | import os 6 | from .utilities import get_speaker_data_name 7 | from copy import deepcopy 8 | import progressbar 9 | 10 | 11 | class ReaderScrapper(HTMLParser): 12 | 13 | def __init__(self): 14 | 15 | self.readerName = None 16 | self.inHeader = False 17 | self.getData = False 18 | super(ReaderScrapper, self).__init__() 19 | 20 | def handle_starttag(self, tag, attrs): 21 | if tag == "div": 22 | if ('class', 'page author-page') in attrs: 23 | self.inHeader = True 24 | 25 | if tag == "h1" and self.inHeader: 26 | self.getData = True 27 | 28 | def handle_endtag(self, tag): 29 | if tag == "div": 30 | self.inHeader = False 31 | if tag == "h1": 32 | self.getData = False 33 | 34 | def handle_data(self, data): 35 | if self.getData: 36 | self.readerName = data 37 | 38 | 39 | def get_librivox_reader_from_id(readerID): 40 | url = f"https://librivox.org/reader/{readerID}" 41 | parser = ReaderScrapper() 42 | req = requests.get(url) 43 | parser.feed(str(req._content)) 44 | return parser.readerName 45 | 46 | 47 | def updateDataWithNames(speakerData, idMatch): 48 | 49 | newData = deepcopy(speakerData) 50 | 51 | if speakerData["readers"] is None: 52 | newData["readers_names"] = None 53 | return newData 54 | 55 | newData["readers_names"] = [] 56 | 57 | for item in speakerData["readers"]: 58 | if item is not None: 59 | for ID in item: 60 | if ID not in idMatch: 61 | try: 62 | idMatch[ID] = get_librivox_reader_from_id(ID) 63 | except RuntimeError: 64 | idMatch[ID] = None 65 | 66 | for item in speakerData["readers"]: 67 | if item is None: 68 | newData["readers_names"].append(None) 69 | else: 70 | all_names = [] 71 | for ID in item: 72 | all_names.append(idMatch[ID]) 73 | newData["readers_names"].append(all_names) 74 | 75 | return newData 76 | 77 | 78 | def update_all_speaker_data(listMetadata, pathInDir, pathOutDir): 79 | 80 | print("Updating the speaker data, this is going to be looong....") 81 | pathInDir = os.path.abspath(pathInDir) 82 | pathOutDir = os.path.abspath(pathOutDir) 83 | assert(pathInDir != pathOutDir) 84 | 85 | if not os.path.isdir(pathOutDir): 86 | os.mkdir(pathOutDir) 87 | 88 | bar = progressbar.ProgressBar(maxval=len(listMetadata)) 89 | bar.start() 90 | 91 | idMatch = {None: None} 92 | 93 | for index, pathMetadata in enumerate(listMetadata): 94 | bar.update(index) 95 | 96 | pathSpeakerData = get_speaker_data_name(pathMetadata) 97 | fullPathSpeakerData = os.path.join(pathInDir, pathSpeakerData) 98 | with open(fullPathSpeakerData, 'rb') as file: 99 | speakerData = json.load(file) 100 | 101 | outData = updateDataWithNames(speakerData, idMatch) 102 | pathOutData = os.path.join(pathOutDir, pathSpeakerData) 103 | 104 | assert(fullPathSpeakerData != 
pathOutData) 105 | 106 | with open(pathOutData, 'w') as file: 107 | json.dump(outData, file, indent=2) 108 | 109 | bar.finish() 110 | 111 | 112 | if __name__ == "__main__": 113 | pathIn = "/checkpoint/mriviere/LibriVox/" 114 | pathOut = "/checkpoint/mriviere/LibriVox_updatedSpeakers/" 115 | update_all_speaker_data(pathIn, pathOut) 116 | -------------------------------------------------------------------------------- /data_preparation/unzip_and_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | import subprocess 4 | import argparse 5 | import multiprocessing 6 | 7 | 8 | def unzip(args): 9 | 10 | if args.path_out is None: 11 | args.path_out = args.path_in 12 | 13 | if not os.path.isdir(args.path_out): 14 | os.mkdir(args.path_out) 15 | 16 | files_in = [f for f in os.listdir(args.path_in) 17 | if os.path.splitext(f)[1] == '.zip'] 18 | 19 | print(f"{len(files_in)} files found") 20 | 21 | for file_name in files_in: 22 | full_path_in = os.path.join(args.path_in, file_name) 23 | full_path_out = os.path.join(args.path_out, 24 | os.path.splitext(file_name)[0]) 25 | 26 | if not os.path.isdir(full_path_out): 27 | os.mkdir(full_path_out) 28 | 29 | subprocess.run(["unzip", full_path_in, "-d", full_path_out]) 30 | 31 | 32 | def _convert_dir(task): 33 | dir_name, args = task 34 | valid_formats = ['.mp3', '.ogg', '.flac', '.wav'] 35 | 36 | full_path_in = os.path.join(args.path_in, dir_name) 37 | files_list = [f for f in os.listdir(full_path_in) 38 | if os.path.splitext(f)[1] in valid_formats] 39 | 40 | full_path_out = os.path.join(args.path_out, dir_name) 41 | if not os.path.isdir(full_path_out): 42 | os.mkdir(full_path_out) 43 | 44 | for file_name in files_list: 45 | base_name, format = os.path.splitext(file_name) 46 | path_out_file = os.path.join( 47 | full_path_out, base_name + args.format) 48 | path_in_file = os.path.join(full_path_in, file_name) 49 | 50 | subprocess.run(["ffmpeg", "-i", path_in_file, 51 | "-ac", "1", 52 | "-ar", str(args.sample_rate), path_out_file], 53 | stdout=subprocess.DEVNULL) 54 | 55 | 56 | def convert(args, n_processes=16): 57 | 58 | if args.path_out is None: 59 | args.path_out = args.path_in 60 | 61 | if not os.path.isdir(args.path_out): 62 | os.mkdir(args.path_out) 63 | 64 | dirs_in = [f for f in os.listdir(args.path_in) 65 | if os.path.isdir(os.path.join(args.path_in, f))] 66 | print(f"{len(dirs_in)} books found") 67 | 68 | pool = multiprocessing.Pool(processes=n_processes) 69 | pool.map(_convert_dir, [(dir_name, args) for dir_name in dirs_in]) 70 | 71 | 72 | if __name__ == "__main__": 73 | 74 | parser = argparse.ArgumentParser( 75 | description='Unzip and Convert Libri-Light') 76 | subparsers = parser.add_subparsers(dest='command') 77 | 78 | parser_unzip = subparsers.add_parser('unzip', 79 | help='Unzip the Libri-Light dataset') 80 | parser_unzip.add_argument('path_in', type=str) 81 | parser_unzip.add_argument('-o', '--path_out', type=str, default=None) 82 | 83 | parser_convert = subparsers.add_parser('convert', 84 | help="Convert the " 85 | "Librilight_dataset into the " 86 | "desired format.") 87 | parser_convert.add_argument('path_in', type=str) 88 | parser_convert.add_argument('-o', '--path_out', type=str, default=None) 89 | parser_convert.add_argument('-f', '--format', type=str, default=".flac") 90 | parser_convert.add_argument('-s', '--sample_rate', type=int, default=16000) 91 | parser_convert.add_argument('-j', '--n_processes', 
type=int, default=16, 92 | help="Number of worker processes") 93 | 94 | args = parser.parse_args() 95 | 96 | if args.command == 'unzip': 97 | unzip(args) 98 | elif args.command == 'convert': 99 | convert(args) 100 | -------------------------------------------------------------------------------- /data_preparation/cut_by_vad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import pathlib 3 | import soundfile as sf 4 | import numpy as np 5 | import json 6 | import multiprocessing 7 | import argparse 8 | import tqdm 9 | 10 | def save(seq, fname, index, extension): 11 | output = np.hstack(seq) 12 | file_name = fname.parent / (fname.stem + f"_{index:04}{extension}") 13 | fname.parent.mkdir(exist_ok=True, parents=True) 14 | sf.write(file_name, output, samplerate=16000) 15 | 16 | 17 | def cut_sequence(path, vad, path_out, target_len_sec, out_extension): 18 | data, samplerate = sf.read(path) 19 | 20 | assert len(data.shape) == 1 21 | assert samplerate == 16000 22 | 23 | to_stitch = [] 24 | length_accumulated = 0.0 25 | 26 | i = 0 27 | for start, end in vad: 28 | start_index = int(start * samplerate) 29 | end_index = int(end * samplerate) 30 | slice = data[start_index:end_index] 31 | 32 | # if a slice is longer than target_len_sec, we put it entirely in it's own piece 33 | if length_accumulated + (end - start) > target_len_sec and length_accumulated > 0: 34 | save(to_stitch, path_out, i, out_extension) 35 | to_stitch = [] 36 | i += 1 37 | length_accumulated = 0 38 | 39 | to_stitch.append(slice) 40 | length_accumulated += end - start 41 | 42 | if to_stitch: 43 | save(to_stitch, path_out, i, out_extension) 44 | 45 | 46 | def cut_book(task): 47 | path_book, root_out, target_len_sec, extension = task 48 | 49 | speaker = pathlib.Path(path_book.parent.name) 50 | 51 | for i, meta_file_path in enumerate(path_book.glob('*.json')): 52 | with open(meta_file_path, 'r') as f: 53 | meta = json.loads(f.read()) 54 | book_id = meta['book_meta']['id'] 55 | vad = meta['voice_activity'] 56 | 57 | sound_file = meta_file_path.parent / (meta_file_path.stem + '.flac') 58 | 59 | path_out = root_out / speaker / book_id / (meta_file_path.stem) 60 | cut_sequence(sound_file, vad, path_out, target_len_sec, extension) 61 | 62 | 63 | def cut(input_dir, 64 | output_dir, 65 | target_len_sec=30, 66 | n_process=32, 67 | out_extension='.flac'): 68 | 69 | list_dir = pathlib.Path(input_dir).glob('*/*') 70 | list_dir = [x for x in list_dir if x.is_dir()] 71 | 72 | print(f"{len(list_dir)} directories detected") 73 | print(f"Launching {n_process} processes") 74 | 75 | tasks = [(path_book, output_dir, target_len_sec, out_extension) for path_book in list_dir] 76 | 77 | with multiprocessing.Pool(processes=n_process) as pool: 78 | for _ in tqdm.tqdm(pool.imap_unordered(cut_book, tasks), total=len(tasks)): 79 | pass 80 | 81 | 82 | def parse_args(): 83 | 84 | parser = argparse.ArgumentParser(description="Cut a dataset in small " 85 | "sequences using VAD files") 86 | parser.add_argument('--input_dir', type=str, default=None, 87 | help="Path to the input directory", required=True) 88 | parser.add_argument('--output_dir', type=str, default=None, 89 | help="Path to the output directory", required=True) 90 | 91 | parser.add_argument('--target_len_sec', type=int, default=60, 92 | help="Target time, in seconds of each output sequence" 93 | "(default is 60)") 94 | parser.add_argument('--n_workers', type=int, default=32, 95 | help="Number 
of parallel worker processes") 96 | parser.add_argument('--out_extension', type=str, default=".flac", 97 | choices=[".wav", ".flac", ".mp3"], 98 | help="Output extension") 99 | 100 | 101 | return parser.parse_args() 102 | 103 | 104 | if __name__ == "__main__": 105 | args = parse_args() 106 | pathlib.Path(args.output_dir).mkdir(exist_ok=True, parents=True) 107 | 108 | cut(args.input_dir, args.output_dir, args.target_len_sec, 109 | args.n_workers, args.out_extension) 110 | -------------------------------------------------------------------------------- /data_preparation/text_retrieval/hathitrust.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from html.parser import HTMLParser 3 | import requests 4 | 5 | 6 | class HathitrustParser(HTMLParser): 7 | 8 | def __init__(self): 9 | 10 | self.nextUrl = None 11 | self.tmpUrl = None 12 | self.text = "" 13 | self.getNextP = False 14 | self.getTextData = False 15 | self.emptyPage = False 16 | super(HathitrustParser, self).__init__() 17 | 18 | def handle_starttag(self, tag, attrs): 19 | if tag == "a": 20 | for name, value in attrs: 21 | if name == "href": 22 | self.tmpUrl = value 23 | 24 | if tag == "div" and ("id", "mdpPage") in attrs: 25 | self.getNextP = True 26 | 27 | if tag == "div" and ("id", "mdpTextEmpty") in attrs: 28 | self.emptyPage = True 29 | 30 | if tag == "p" and self.getNextP: 31 | self.getTextData = True 32 | 33 | def handle_data(self, data): 34 | if self.tmpUrl is not None and data.find("Next Page") >= 0: 35 | self.nextUrl = self.tmpUrl 36 | if self.getTextData: 37 | self.text += data 38 | 39 | def handle_endtag(self, tag): 40 | if tag == "a": 41 | self.tmpUrl = None 42 | elif tag == "div": 43 | self.getNextP = False 44 | elif tag == "p": 45 | self.getTextData = False 46 | self.getNextP = False 47 | 48 | 49 | class CatalogParser(HTMLParser): 50 | 51 | def __init__(self): 52 | self.candidatesID = [] 53 | super(CatalogParser, self).__init__() 54 | 55 | def handle_starttag(self, tag, attrs): 56 | 57 | attrs = dict(attrs) 58 | if tag == "a": 59 | if attrs["href"].find("handle.net") >= 0: 60 | self.candidatesID.append(attrs["data-hdl"]) 61 | 62 | 63 | def load_whole_book(bookID): 64 | 65 | baseUrl = "https://babel.hathitrust.org/cgi/ssd?" 66 | nextUrl = f"{baseUrl}id={bookID};page=ssd;view=plaintext;seq=1;num=" 67 | 68 | fullText = "" 69 | 70 | while True: 71 | parserChapter = HathitrustParser() 72 | req = requests.get(nextUrl) 73 | parserChapter.feed(req._content.decode('utf-8')) 74 | if parserChapter.nextUrl is None: 75 | break 76 | nextUrl = f"https://babel.hathitrust.org{parserChapter.nextUrl}" 77 | 78 | if not parserChapter.emptyPage: 79 | fullText += parserChapter.text 80 | 81 | return fullText 82 | 83 | 84 | def is_hathitrust_url(url): 85 | return url.find("hathitrust.org") >= 0 86 | 87 | 88 | def load_hathitrust_book(url): 89 | 90 | candidatesID = None 91 | if url.find("catalog.hathitrust.org") >= 0: 92 | catalogParser = CatalogParser() 93 | req = requests.get(url) 94 | catalogParser.feed(req._content.decode('utf-8')) 95 | 96 | if len(catalogParser.candidatesID) == 0: 97 | raise RuntimeError("Invalid url") 98 | 99 | candidatesID = catalogParser.candidatesID 100 | 101 | else: 102 | key = "cgi/ssd?" 
103 | startOffset = url.find(key) 104 | 105 | if startOffset < 0: 106 | raise RuntimeError("Invalid url") 107 | 108 | startOffset += len(key) 109 | markers = url[startOffset:].split(';') 110 | 111 | for data in markers: 112 | name, value = data.split('=') 113 | if name == "id": 114 | candidatesID = [value] 115 | break 116 | 117 | if candidatesID is None: 118 | raise RuntimeError("Invalid url") 119 | 120 | text = None 121 | for id in candidatesID: 122 | 123 | try: 124 | text = load_whole_book(id) 125 | except RuntimeError: 126 | continue 127 | 128 | if text is None: 129 | raise RuntimeError("Couldn't find any transcription") 130 | 131 | return text 132 | 133 | 134 | if __name__ == "__main__": 135 | 136 | url1 = "https://babel.hathitrust.org/cgi/ssd?id=coo.31924074296884;page=ssd;view=plaintext;seq=110;num=104" 137 | url2 = "http://catalog.hathitrust.org/Record/002242980" 138 | print(load_hathitrust_book(url1)) 139 | -------------------------------------------------------------------------------- /data_preparation/split_librilight/prepare_vads.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List, Tuple 7 | import argparse 8 | import json 9 | import pathlib 10 | import multiprocessing 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser( 15 | description="Transform per-frame VAD files into segments with voice activity") 16 | 17 | parser.add_argument('--vad_root', type=str, required=True) 18 | parser.add_argument('--time_step_ms', type=float, default=80) 19 | parser.add_argument('--p_threshold', type=float, default=0.999) 20 | parser.add_argument('--len_threshold_frames', type=int, 21 | default=6) # 6 frames ~ 0.5s 22 | parser.add_argument('--n_workers', type=int, default=32) 23 | 24 | parser.add_argument('--output', type=str, default='vads.json') 25 | 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def parse_vad(fname): 31 | with open(fname, 'r') as f: 32 | probs = f.read() 33 | probs = [float(x) for x in probs.split()] 34 | return probs 35 | 36 | 37 | def split_vad(silence_probs: List[float], p_silence_threshold: float, len_threshold: int) -> List[Tuple[int, int]]: 38 | """Given a sequence `p_probs` of silence probabilities p, this function 39 | returns intervals of speech activity, such that (a) those intervals are separated by 40 | at least `len_threshold` of silent frames (p > `p_silence_threshold`), 41 | (b) are themselves longer than `len_threshold`. 
42 | 43 | Arguments: 44 | silence_probs -- list of silence probabilities 45 | p_silence_threshold -- all frames with silence probability above this thresholds 46 | are considered as silence 47 | len_threshold -- minimal length of silence and non-silence segments 48 | 49 | Returns: list of tuples (start_speech_frame, first_silence_frame_after_start or end_of_sequence) 50 | """ 51 | segments = [] 52 | 53 | start = None 54 | i = 0 55 | n = len(silence_probs) 56 | 57 | while i < len(silence_probs) and silence_probs[i] > p_silence_threshold: 58 | i += 1 59 | # supported invariants: `start` points to the frame where speech starts, i >= start 60 | start = i 61 | 62 | while i < n: 63 | # scroll until first silence frame 64 | if silence_probs[i] < p_silence_threshold: 65 | i += 1 66 | continue 67 | 68 | # now i points to the first silence frame 69 | # look ahead: do we have at least len_threshold silence frames? 70 | all_silence = True 71 | for j in range(i + 1, min(i + len_threshold, n)): 72 | all_silence = all_silence and silence_probs[j] > p_silence_threshold 73 | if not all_silence: 74 | break 75 | 76 | if not all_silence: 77 | # no we don't: disregard the silence, go further 78 | # starting from the first non-silence frame 79 | i = j 80 | else: 81 | # we do have enough silence for a split 82 | if i - start > len_threshold: 83 | segments.append((start, i)) 84 | 85 | while i < n and silence_probs[i] > p_silence_threshold: 86 | i += 1 87 | start = i 88 | i += 1 89 | 90 | if i - start > len_threshold and start < n: 91 | segments.append((start, i)) 92 | 93 | return segments 94 | 95 | 96 | def process(task): 97 | name, args = task 98 | vads = parse_vad(name) 99 | segments = split_vad(vads, args.p_threshold, args.len_threshold_frames) 100 | name = str(name.parent.name) + '/' + str(name.name) 101 | return (name, (segments, len(vads))) 102 | 103 | 104 | if __name__ == '__main__': 105 | fname2segments = {} 106 | 107 | args = get_args() 108 | 109 | tasks = [(x, args) for x in pathlib.Path(args.vad_root).rglob("*.vad")] 110 | print(f'Found {len(tasks)} vad files') 111 | 112 | with multiprocessing.Pool(processes=args.n_workers) as pool: 113 | fname2segments = pool.map(process, tasks) 114 | fname2segments = dict(fname2segments) 115 | 116 | with open(args.output, 'w') as f: 117 | f.write(json.dumps(fname2segments, sort_keys=True)) 118 | -------------------------------------------------------------------------------- /eval/WER_src/simple_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from typing import Dict, List 3 | 4 | import torch 5 | 6 | from torch.utils.data import Dataset 7 | import torchaudio 8 | from copy import deepcopy 9 | import time 10 | from pathlib import Path 11 | 12 | 13 | def parse_ctc_labels_from_root(root, letters_path="./WER_data/letters.lst"): 14 | letter2index = {} 15 | index2letter = {} 16 | 17 | with open(letters_path, 'r') as f: 18 | for i, line in enumerate(f.readlines()): 19 | line = line.strip() 20 | 21 | letter2index[line] = i 22 | index2letter[i] = line 23 | 24 | result = {} 25 | 26 | for file in Path(root).rglob("*.txt"): 27 | with open(file, 'r') as f: 28 | for line in f.readlines(): 29 | line = line.rstrip() 30 | p = line.find(' ') 31 | assert p > 0 32 | fname = line[:p] 33 | 34 | chars = line[p+1:].replace(' ', '|').lower() 35 | decoded = [] 36 | 37 | for c in chars: 38 | decoded.append(letter2index[c]) 39 | result[fname] = decoded 40 | 41 | return result, len(letter2index), (letter2index, index2letter) 42 | 43 | 44 | def find_seqs(dir_name, extension='.flac'): 45 | sequences = [] 46 | for file in Path(dir_name).rglob('*' + extension): 47 | speaker = file.parent.parent.stem 48 | sequences.append((speaker, file)) 49 | 50 | speakers = set(x[0] for x in sequences) 51 | return sequences, speakers 52 | 53 | 54 | class SingleSequenceDataset(Dataset): 55 | 56 | def __init__(self, 57 | root: str, 58 | labels: Dict[str, List[int]]): 59 | """ 60 | root {str} -- Directory that contains the dataset files 61 | labels {Dict[str, List[int]]} -- Dict mapping a filename (without extension) to a list of 62 | integer-encoded labels. 63 | """ 64 | self.seq_names, _ = find_seqs(root) 65 | self.labels_dict = deepcopy(labels) 66 | 67 | self.seq_offsets = [0] 68 | self.labels = [] 69 | self.label_offsets = [0] 70 | self.data = [] 71 | self.max_len_wave = 0 72 | self.max_len_labels = 0 73 | 74 | self.load_seqs() 75 | 76 | def load_seqs(self): 77 | data = [] 78 | 79 | start_time = time.time() 80 | for _, seq in self.seq_names: 81 | name = Path(seq).stem 82 | wave = torchaudio.load(seq)[0].view(-1) 83 | data.append((name, wave)) 84 | 85 | data.sort() 86 | 87 | temp_data = [] 88 | total_size = 0 89 | for name, wave in data: 90 | self.labels.extend(self.labels_dict[name]) 91 | self.label_offsets.append(len(self.labels)) 92 | self.max_len_labels = max( 93 | self.max_len_labels, len(self.labels_dict[name])) 94 | wave_length = wave.size(0) 95 | self.max_len_wave = max(self.max_len_wave, wave_length) 96 | total_size += wave_length 97 | temp_data.append(wave) 98 | self.seq_offsets.append(self.seq_offsets[-1] + wave_length) 99 | 100 | self.data = torch.cat(temp_data, dim=0) 101 | self.labels = torch.tensor(self.labels, dtype=torch.long) 102 | 103 | print(f'Loaded {len(self.label_offsets)} sequences ' 104 | f'in {time.time() - start_time:.2f} seconds') 105 | print(f'max_len_wave: {self.max_len_wave}') 106 | print(f'max_len_labels: {self.max_len_labels}') 107 | print(f'Total size dataset {total_size / (16000 * 3600)} hours') 108 | 109 | def __getitem__(self, idx): 110 | wave_tart = self.seq_offsets[idx] 111 | wave_end = self.seq_offsets[idx + 1] 112 | labels_start = self.label_offsets[idx] 113 | labels_end = self.label_offsets[idx + 1] 114 | 115 | wave_len = wave_end - wave_tart 116 | label_len = labels_end - labels_start 117 | 118 | wave = torch.zeros((1, self.max_len_wave)) 119 | labels = torch.zeros((self.max_len_labels), dtype=torch.long) 120 | 121 | wave[0, :wave_len] = self.data[wave_tart:wave_end] 122 | labels[:label_len] = 
self.labels[labels_start:labels_end] 123 | 124 | return wave, torch.tensor([wave_len], dtype=torch.long), labels, torch.tensor([label_len], dtype=torch.long) 125 | 126 | def __len__(self): 127 | return len(self.seq_offsets) - 1 128 | -------------------------------------------------------------------------------- /data_preparation/metadata_completion/text_cleaner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from text_retrieval.guttenberg import is_guttenberg_url 3 | from .utilities import get_txt_name 4 | import os 5 | import json 6 | import progressbar 7 | import sys 8 | sys.path.append('..') 9 | 10 | 11 | def loadData(pathFile): 12 | 13 | with open(pathFile, 'r') as file: 14 | data = file.readlines() 15 | 16 | indexStartProject = -1 17 | indexProducedBy = -1 18 | indexEndProject = -1 19 | for index, line in enumerate(data): 20 | if indexStartProject < 0: 21 | value = line.replace(' ', '').find("***START") 22 | if value >= 0: 23 | indexStartProject = index 24 | elif line.find("CONTENTS") >= 0: 25 | indexStartProject = index 26 | else: 27 | continue 28 | 29 | value = line.replace(' ', '').find("***END") 30 | if value >= 0: 31 | indexEndProject = index 32 | break 33 | 34 | if indexProducedBy < 0: 35 | value = line.find("Produced by") 36 | if value >= 0: 37 | indexProducedBy = index 38 | 39 | if indexStartProject < 0: 40 | return None 41 | 42 | if indexEndProject < 0: 43 | indexEndProject = len(data) 44 | 45 | startIndex = indexProducedBy + 1 if indexProducedBy > 0 \ 46 | else indexStartProject + 1 47 | while startIndex < len(data) and data[startIndex] == '\n': 48 | startIndex += 1 49 | 50 | return ''.join(data[startIndex:indexEndProject]) 51 | 52 | 53 | def find404Error(pathFile): 54 | with open(pathFile, 'r') as file: 55 | data = file.readlines() 56 | 57 | return len(data) == 1 and \ 58 | data[0] == "
404 Not Found
File not found.
" 59 | 60 | 61 | def clean_all_text_data(metadataList, pathInDir, pathOutDir): 62 | 63 | pathInDir = os.path.abspath(pathInDir) 64 | pathOutDir = os.path.abspath(pathOutDir) 65 | 66 | if pathInDir == pathOutDir: 67 | raise ValueError("Can't save the data in the same directory \ 68 | as the originals") 69 | 70 | bar = progressbar.ProgressBar(maxval=len(metadataList)) 71 | bar.start() 72 | nCleaned = 0 73 | nMissing = 0 74 | nNotWorking = 0 75 | emptyTxt = [] 76 | out = [] 77 | 78 | for index, metadataName in enumerate(metadataList): 79 | bar.update(index) 80 | textFileName = get_txt_name(metadataName) 81 | pathInFile = os.path.join(pathInDir, textFileName) 82 | outPathFile = os.path.join(pathOutDir, textFileName) 83 | 84 | if not os.path.isfile(pathInFile): 85 | status = "missing" 86 | nMissing += 1 87 | else: 88 | 89 | assert(pathInFile != outPathFile) 90 | 91 | with open(os.path.join(pathInDir, metadataName), 'rb') as file: 92 | urlSource = json.load(file)["url_text_source"] 93 | 94 | if not is_guttenberg_url(urlSource): 95 | os.popen(f'cp {pathInFile} {outPathFile}') 96 | status = "clear" 97 | else: 98 | outData = loadData(pathInFile) 99 | 100 | if outData is None: 101 | nNotWorking += 1 102 | if find404Error(pathInFile): 103 | emptyTxt.append(pathInFile) 104 | status = "missing" 105 | else: 106 | status = "noisy" 107 | else: 108 | with open(outPathFile, 'w') as file: 109 | file.write(outData) 110 | status = "clear" 111 | out.append((metadataName, status)) 112 | nCleaned += 1 113 | 114 | bar.finish() 115 | print(f"Out of {len(metadataList)} items") 116 | print(f"{nCleaned} files were cleaned and saved to {pathOutDir}") 117 | print(f"{nNotWorking} files didn't match the good format among which {len(emptyTxt)} were empty") 118 | print(f"{nMissing} files were missing") 119 | return out 120 | 121 | 122 | if __name__ == "__main__": 123 | 124 | pathDirData = "/checkpoint/mriviere/LibriVox/" 125 | pathOutData = "/checkpoint/mriviere/LibriVox_cleanTxt/" 126 | 127 | if not os.path.isdir(pathOutData): 128 | os.mkdir(pathOutData) 129 | 130 | clean_all_text_data(pathDirData, pathOutData) 131 | 132 | # pathTestFile = "/checkpoint/mriviere/LibriVox/sadhana_realisation_librivox_64kb_mp3_text.txt" 133 | # print(find404Error(pathTestFile)) 134 | -------------------------------------------------------------------------------- /data_preparation/rebuild_limited_train/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import pathlib 3 | from collections import namedtuple 4 | import torchaudio 5 | import shutil 6 | 7 | 8 | Speaker = namedtuple('Speaker', ['id', 'gender', 'subset']) 9 | FileRecord = namedtuple( 10 | 'FileRecord', ['fname', 'length', 'speaker', 'book', 'text_file']) 11 | 12 | 13 | def get_speakers(speaker_path): 14 | all_speakers = [] 15 | with open(speaker_path) as f: 16 | for line in f: 17 | if line.startswith(';'): 18 | continue 19 | 20 | line = line.split('|') 21 | speaker_id, gender, subset = [x.strip() for x in line[0:3]] 22 | speaker_id = int(speaker_id) 23 | 24 | assert subset in ['test-clean', 'train-clean-360', 'train-clean-100', 25 | 'test-other', 'dev-clean', 'train-other-500', 'dev-other'], subset 26 | 27 | speaker = Speaker(id=speaker_id, gender=gender, subset=subset) 28 | all_speakers.append(speaker) 29 | return all_speakers 30 | 31 | 32 | def get_filelength(fname): 33 | info = torchaudio.info(fname)[0] 34 | return info.length 35 | 36 | 37 | def traverse_tree(root, ext='flac'): 38 | fnames = pathlib.Path(root).rglob(f"*.{ext}") 39 | fnames = sorted(list(fnames)) 40 | 41 | lengths = [] 42 | for file in fnames: 43 | file = str(file.resolve()) 44 | length = get_filelength(file) 45 | lengths.append(length) 46 | 47 | return list(zip(fnames, lengths)) 48 | 49 | 50 | def get_speaker_fname(fname): 51 | stemmed = fname.stem 52 | speaker, book, seq = stemmed.split('-') 53 | return int(speaker), int(book) 54 | 55 | 56 | def full_records(speakers, fname2length, subset_name=None): 57 | all_records = [] 58 | 59 | speakers = dict((speaker.id, speaker) for speaker in speakers) 60 | 61 | for fname, length in fname2length: 62 | speaker, book = get_speaker_fname(fname) 63 | assert speaker in speakers, f'Unknown speaker! {speaker}' 64 | 65 | speaker = speakers[speaker] 66 | 67 | if subset_name is not None: 68 | assert subset_name == speaker.subset 69 | # hacky 70 | text_file = fname.parent / f'{speaker.id}-{book}.trans.txt' 71 | frecord = FileRecord(speaker=speaker, length=length, 72 | fname=fname, book=book, text_file=text_file) 73 | all_records.append(frecord) 74 | 75 | return all_records 76 | 77 | 78 | def get_histogram(records, lambda_key, lambda_value): 79 | from collections import defaultdict 80 | key_value = defaultdict(int) 81 | 82 | for record in records: 83 | key = lambda_key(record) 84 | value = lambda_value(record) 85 | 86 | key_value[key] += value 87 | 88 | return key_value 89 | 90 | 91 | def materialize(records, target_dir, tag=None, move=False): 92 | target_dir = pathlib.Path(target_dir) 93 | 94 | to_copy = set() 95 | to_move = set() 96 | 97 | for record in records: 98 | # outline: 99 | # target_dir / speaker / book / file 100 | if tag is None: 101 | target_book_dir = target_dir / \ 102 | str(record.speaker.id) / str(record.book) 103 | else: 104 | target_book_dir = target_dir / tag / \ 105 | str(record.speaker.id) / str(record.book) 106 | target_book_dir.mkdir(exist_ok=True, parents=True) 107 | 108 | if not move: 109 | to_copy.add((record.fname, target_book_dir / record.fname.name)) 110 | else: 111 | to_move.add((record.fname, target_book_dir / record.fname.name)) 112 | 113 | to_copy.add((record.text_file, target_book_dir / record.text_file.name)) 114 | 115 | to_copy = sorted(list(to_copy)) 116 | for src, dst in to_copy: 117 | shutil.copy(src, dst) 118 | 119 | if len(to_move) > 0: 120 | to_move = sorted(list(to_move)) 121 | for src, dst in to_move: 122 | shutil.move(src, dst) 123 | 124 | 125 | def print_stats(records): 126 | def lambda_speaker(r): return 
r.speaker.id 127 | def lambda_time(r): return r.length / 16000.0 128 | 129 | speaker_time = get_histogram( 130 | records, lambda_key=lambda_speaker, lambda_value=lambda_time) 131 | print(f'Unique speakers: {len(speaker_time)}') 132 | times = speaker_time.values() 133 | min_time, max_time, mean_time, total_time = min( 134 | times), max(times), sum(times) / len(times), sum(times) 135 | min_time, max_time, mean_time, total_time = map( 136 | int, [min_time, max_time, mean_time, total_time]) 137 | print( 138 | f'Min/Mean/Max/Total, seconds: {min_time}/{mean_time}/{max_time}/{total_time}') 139 | print(f'n_utterances: {len(records)}') 140 | -------------------------------------------------------------------------------- /data_preparation/make_vad_inputs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | from pathlib import Path 4 | import torchaudio 5 | import progressbar 6 | import argparse 7 | import torch 8 | import tqdm 9 | 10 | 11 | def findAllSeqs(dirName, 12 | extension='.flac', 13 | loadCache=False): 14 | r""" 15 | Lists all the sequences with the given extension in the dirName directory. 16 | Output: 17 | outSequences, speakers 18 | 19 | outSequence 20 | A list of tuples seq_path, speaker where: 21 | - seq_path is the relative path of each sequence relative to the 22 | parent directory 23 | - speaker is the corresponding speaker index 24 | 25 | outSpeakers 26 | The speaker labels (in order) 27 | 28 | The speaker labels are organized the following way 29 | \dirName 30 | \speaker_label 31 | \.. 32 | ... 33 | seqName.extension 34 | """ 35 | cache_path = os.path.join(dirName, '_seqs_cache.txt') 36 | if loadCache: 37 | try: 38 | outSequences, speakers = torch.load(cache_path) 39 | print(f'Loaded from cache {cache_path} successfully') 40 | return outSequences, speakers 41 | except OSError as err: 42 | print(f'Ran in an error while loading {cache_path}: {err}') 43 | print('Could not load cache, rebuilding') 44 | 45 | if dirName[-1] != os.sep: 46 | dirName += os.sep 47 | prefixSize = len(dirName) 48 | speakersTarget = {} 49 | outSequences = [] 50 | for root, dirs, filenames in tqdm.tqdm(os.walk(dirName)): 51 | filtered_files = [f for f in filenames if f.endswith(extension)] 52 | 53 | if len(filtered_files) > 0: 54 | speakerStr = root[prefixSize:].split(os.sep)[0] 55 | if speakerStr not in speakersTarget: 56 | speakersTarget[speakerStr] = len(speakersTarget) 57 | speaker = speakersTarget[speakerStr] 58 | for filename in filtered_files: 59 | full_path = os.path.join(root[prefixSize:], filename) 60 | outSequences.append((speaker, full_path)) 61 | outSpeakers = [None for x in speakersTarget] 62 | for key, index in speakersTarget.items(): 63 | outSpeakers[index] = key 64 | try: 65 | torch.save((outSequences, outSpeakers), cache_path) 66 | print(f'Saved cache file at {cache_path}') 67 | except OSError as err: 68 | print(f'Ran in an error while saving {cache_path}: {err}') 69 | return outSequences, outSpeakers 70 | 71 | 72 | def get_file_duration_ms(path_file): 73 | info = torchaudio.info(path_file)[0] 74 | return 1000*(info.length // (info.rate)) 75 | 76 | 77 | def get_lst(path_db, file_list): 78 | 79 | bar = progressbar.ProgressBar(maxval=len(file_list)) 80 | bar.start() 81 | 82 | path_db = Path(path_db) 83 | out = [] 84 | 85 | for index, file_name in enumerate(file_list): 86 | 87 | bar.update(index) 88 | full_path = str(path_db / file_name) 89 | duration = 
get_file_duration_ms(full_path) 90 | out.append((full_path, full_path, int(duration))) 91 | 92 | bar.finish() 93 | return out 94 | 95 | 96 | def save_lst(data, path_out): 97 | 98 | with open(path_out, 'w') as file: 99 | for id, path, val in data: 100 | file.write(' '.join((id, path, str(val))) + '\n') 101 | 102 | 103 | def reorder_vad(path_vad, lst): 104 | 105 | path_vad = Path(path_vad) 106 | 107 | for id, full_path_wav, _ in lst: 108 | 109 | full_path_vad = (path_vad / id).with_suffix('.vad') 110 | full_path_out = Path(full_path_wav).with_suffix('.vad') 111 | full_path_vad.replace(full_path_out) 112 | 113 | full_path_vad.with_suffix('.fwt').unlink(missing_ok=True) 114 | full_path_vad.with_suffix('.tsc').unlink(missing_ok=True) 115 | full_path_vad.with_suffix('.sts').unlink(missing_ok=True) 116 | 117 | 118 | if __name__ == "__main__": 119 | 120 | parser = argparse.ArgumentParser(description="Build the vad inputs") 121 | 122 | parser.add_argument('path_db', type=str, 123 | help="Path to the dataset directory") 124 | parser.add_argument('path_out', type=str) 125 | parser.add_argument('--ignore_cache', action='store_true') 126 | parser.add_argument('--debug', action='store_true') 127 | parser.add_argument('--extension', type=str, default='.wav') 128 | 129 | args = parser.parse_args() 130 | 131 | seqList, _ = findAllSeqs(args.path_db, extension=args.extension, 132 | loadCache=not args.ignore_cache) 133 | if args.debug: 134 | seqList = seqList[:10] 135 | 136 | seqList = [i[1] for i in seqList] 137 | 138 | vad_data = get_lst(args.path_db, seqList) 139 | save_lst(vad_data, args.path_out) 140 | -------------------------------------------------------------------------------- /data_preparation/complete_metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import argparse 3 | from pathlib import Path 4 | import sys 5 | import json 6 | 7 | import metadata_completion.utilities as ut 8 | from metadata_completion.GenreScrapper import gather_all_genres 9 | from metadata_completion.ReaderScapper import update_all_speaker_data 10 | from metadata_completion.genre_folding import UNIQUE_GENRE_FOLDING, \ 11 | SUPER_GENDER_FOLDING, \ 12 | SUPER_GENDER_ORDERING 13 | from metadata_completion.DuplicateSearch import get_books_duplicates 14 | from metadata_completion.text_cleaner import clean_all_text_data 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description="Upgrade LibriBIG's metadata") 19 | parser.add_argument('--path_metadata', type=str, 20 | help="Path to the directory containing the metadata", 21 | default="/checkpoint/mriviere/LibriVox") 22 | parser.add_argument('--ignore_cache', action='store_true') 23 | parser.add_argument('--debug', action='store_true') 24 | parser_out = parser.add_mutually_exclusive_group(required=False) 25 | parser_out.add_argument('--out_dir', type=str, default=None, 26 | help="Path to the output directory") 27 | parser_out.add_argument('-i', '--in_place', action='store_true') 28 | return parser 29 | 30 | 31 | def main(argv): 32 | parser = parse_args() 33 | args = parser.parse_args(argv) 34 | 35 | if args.in_place: 36 | path_out = Path(args.path_metadata) 37 | elif args.out_dir is not None: 38 | path_out = Path(args.out_dir) 39 | Path.mkdir(path_out, exist_ok=True) 40 | else: 41 | print(f"You must input either an output directory or activate the " 42 | "inplace flag") 43 | parser.print_help() 44 | sys.exit() 45 | 46 | path_cache = path_out / ".cache" 47 | Path.mkdir(path_cache, exist_ok=True) 48 | 49 | path_global_data_dir = path_out / "global" 50 | Path.mkdir(path_global_data_dir, exist_ok=True) 51 | 52 | # Get the list of all metadata 53 | print("Gathering the list of metadata") 54 | path_cache_metadata = path_cache / "metadata.pkl" 55 | list_metadata = ut.load_cache(path_cache_metadata, 56 | ut.get_all_metadata, 57 | args=(args.path_metadata,), 58 | ignore_cache=args.ignore_cache) 59 | 60 | if args.debug: 61 | list_metadata = list_metadata[:10] 62 | 63 | # Retrieve the genres 64 | genre_list = gather_all_genres(args.path_metadata, 65 | list_metadata) 66 | 67 | ut.get_updated_metadata(genre_list, args.path_metadata, path_out, "genre") 68 | 69 | # Fold the genres 70 | reverse_folding_unique = ut.build_reverse_folding(UNIQUE_GENRE_FOLDING) 71 | reverse_folding_super = ut.build_reverse_folding(SUPER_GENDER_FOLDING) 72 | final_reverse_folding = ut.combine_reverse_foldings(reverse_folding_super, 73 | reverse_folding_unique) 74 | 75 | # Convert the "dramatic reading" option into a binary tag 76 | has_dramatic_reading = [(name, 'Dramatic Readings' in vals) 77 | for name, vals in genre_list] 78 | ut.get_updated_metadata(has_dramatic_reading, path_out, 79 | path_out, 'Dramatic Readings') 80 | genre_list = [(name, ut.remove_tag(vals, 'Dramatic Readings', 'Undefined')) 81 | for name, vals in genre_list] 82 | 83 | #dramatric_reading = [(name, ut.has_tag(tag_str, tag))] 84 | folded_genres = [(name, ut.remove_multiple_tags(ut.apply_folding('+'.join(vals), 85 | final_reverse_folding), 86 | SUPER_GENDER_ORDERING)) 87 | for name, vals in genre_list] 88 | 89 | ut.get_updated_metadata(folded_genres, path_out, 90 | path_out, "meta_genre") 91 | 92 | # Retrieve the readers names 93 | update_all_speaker_data(list_metadata, args.path_metadata, path_out) 94 | 95 | # Look for duplicates 96 | 
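    # Duplicate detection first groups books that share an author id, then
    # compares normalized titles within each group (see
    # metadata_completion/DuplicateSearch.py); the resulting groups of
    # probable duplicates are written to duplicates.json below.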
duplicate_list = get_books_duplicates(args.path_metadata, list_metadata) 97 | path_out_duplicates = path_global_data_dir / "duplicates.json" 98 | print(f"Saving the duplicates index at {path_out_duplicates}") 99 | with open(path_out_duplicates, 'w') as file: 100 | json.dump(duplicate_list, file, indent=2) 101 | 102 | # Clean text data when possible 103 | text_status = clean_all_text_data(list_metadata, args.path_metadata, 104 | str(path_out)) 105 | ut.get_updated_metadata(text_status, path_out, 106 | path_out, "trancription_status") 107 | 108 | 109 | if __name__ == "__main__": 110 | args = sys.argv[1:] 111 | main(args) 112 | -------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | # Eval 2 | 3 | You will find here all relevant evaluation launched on the LibriLight-dataset. 4 | 5 | ## ABX 6 | 7 | ABX is an evaluation metric for unsupervised representation learning. It evaluates feature files based on its ability to distinguish sounds like /i/ and /e/ as in "bit" versus "bet". 8 | 9 | ### Setup 10 | 11 | To setup the ABX evaluation script you need to: 12 | 13 | 1. compile the cython code. Just do: 14 | 15 | ```console 16 | cd ABX_src 17 | python setup.py build_ext --inplace 18 | ``` 19 | 20 | 2. Check that everything works properly with: 21 | ```console 22 | cd ABX_src 23 | nosetests -d 24 | ``` 25 | 26 | 3. Download the Librilight `.item` files here: [ABX_data.tgz](https://dl.fbaipublicfiles.com/librilight/data/ABX_data.tgz). 27 | 28 | This archive contains four `.item` files constructed from the Librispeech dev and test set: `dev-clean.item`, `dev-other.item`, `test-clean.item`, and `test-other.item`, which provide the labels for the ABX evaluation. 29 | 30 | ### How to run the ABX evaluation ? 31 | 32 | Dump your features in .pt (torch), .npz or .npy (numpy) format somewhere. Your features dataset should look like this: 33 | 34 | ```console 35 | \data_dir 36 | file_name_0.extension 37 | file_name_1.extension 38 | ... 39 | ``` 40 | 41 | Each file should contain a 2D-vector of shape Sequence_size x Feature_dimension. 42 | 43 | Then run: 44 | ```console 45 | python eval_ABX.py $PATH_FEATURE_DIR $PATH_TO_ABX_ITEMS/$DB_NAME.item --file_extension $EXTENSION --out $OUTPUT_DIR --feature_size $FEATURE_SIZE 46 | ``` 47 | 48 | Where `$DB_NAME` is one of the 4 evaluation datasets (`dev-clean`, `dev-other`, `test-clean`, `test-other`) and `$FEATURE_SIZE` is the duration (in s) of one feature of the model (for a `10ms` frame rate, this would be `0.01`). 49 | 50 | 51 | ## Pre-computed checkpoints 52 | 53 | Some pre-computed model trained with CPC are available for use ! In order to load a model just use CPC_loader.py, for example to retrieve the model trained on the 60k hours dataset: 54 | 55 | ```console 56 | python CPC_loader.py 60k $PATH_OUTPUT_CHECKPOINT 57 | ``` 58 | 59 | You can directly evaluate the ABX score on this checkpoint by running: 60 | ```console 61 | python eval_ABX.py $PATH_AUDIO_DIR ABX_data/$DB_NAME.item --file_extension $EXTENSION --out $OUTPUT_DIR --path_checkpoint $PATH_OUTPUT_CHECKPOINT 62 | ``` 63 | 64 | Where $EXTENSION corresponds to an audio foramt (.wav, .flac ...) 65 | 66 | ## Linear Classification PER 67 | 68 | Representations can also be evaluated by how easy it is to train a linear phoneme classifier. 69 | 70 | ### Setup 71 | 72 | To setup the PER evaluation script you need to compile the cython code it relies on. 
Just do: 73 | ```console 74 | cd PER_src 75 | python setup.py build_ext --inplace 76 | ``` 77 | 78 | You will also need to download the [10h labelled data](https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz). 79 | 80 | ### How to run the PER evaluation ? 81 | 82 | First you need to train a linear classifier on your features. For example, if you want to evaluate a model fine-tuned on the 10h dataset, just run: 83 | ```console 84 | python eval_PER.py train $PATH_TO_10h_AUDIO_DATA_DIR $PATH_TO_10h_PHONE_DATA $PATH_TO_THE_JSON_PHONE_CONVERTER $PATH_TO_THE_CPC_MODEL -o $PATH_OUT 85 | ``` 86 | 87 | Then you can run the PER computation, for example on librispeech100/test-clean: 88 | ```console 89 | python eval_PER.py per $PATH_OUT/checkpoint.pt $PATH_TO_TEST_CLEAN $PATH_TO_TEST_CLEAN_PHONES --file_extension .flac 90 | ``` 91 | 92 | 93 | ## WER 94 | 95 | We provide here a test of representations based on word error rate. 96 | 97 | ### Setup 98 | * wav2letter python bindings: [(how-to)](https://github.com/facebookresearch/wav2letter/tree/master/bindings/python). 99 | * KenLM-based Librispeech language model, can be found [here](http://www.openslr.org/11/) or downloaded [here](https://dl.fbaipublicfiles.com/librilight/data/4-gram.bin); it should be placed into `WER_data/`. 100 | * lexicon, [download](https://dl.fbaipublicfiles.com/librilight/data/lexicon.txt.gz); it should be placed into `WER_data/`. 101 | * jiwer, installable via `pip install jiwer`. 102 | 103 | ### How to run the WER evaluation? 104 | 105 | Training a letter classifier on top of a pre-trained CPC model: 106 | ```console 107 | python eval_WER.py --path_train=$PATH_FINETUNING --path_val=$PATH_TO_DEV_CLEAN --path_checkpoint=$PATH_OUT/checkpoint.pt --lr=1e-3 --n_epochs=50 --p_dropout=0.1 --output=$OUTPUT_DIR 108 | 109 | ``` 110 | Evaluating it with wav2letter decoder: 111 | ```console 112 | python eval_WER.py --path_checkpoint=$PATH_OUT/checkpoint.pt --lr=1e-3 --n_epochs=50 --p_dropout=0.1 --output=$OUTPUT_DIR --path_wer=$PATH_TO_TEST_CLEAN 113 | ``` 114 | 115 | You can also train and evaluate afterwards, in a single command: 116 | ```console 117 | python eval_WER.py --path_train=$PATH_FINETUNING --path_val=$PATH_TO_DEV_CLEAN --path_checkpoint=$PATH_OUT/checkpoint.pt --lr=1e-3 --n_epochs=50 --p_dropout=0.1 --output=$OUTPUT_DIR --path_wer=$PATH_TO_TEST_CLEAN 118 | ``` 119 | -------------------------------------------------------------------------------- /data_preparation/calculate_snr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import sys 3 | import time 4 | import os 5 | import numpy as np 6 | from scipy.io import wavfile 7 | import multiprocessing 8 | import argparse 9 | 10 | 11 | _INT16_MAX_VALUE = float(np.abs(np.iinfo(np.int16).min)) 12 | _INT32_MAX_VALUE = float(np.abs(np.iinfo(np.int32).min)) 13 | 14 | 15 | def convert_wav_buf_f32(data): 16 | if data.dtype == np.float32: 17 | pass 18 | elif data.dtype == np.int16: 19 | data = data.astype(np.float32) / _INT16_MAX_VALUE 20 | elif data.dtype == np.int32: 21 | data = data.astype(np.float32) / _INT32_MAX_VALUE 22 | else: 23 | raise ValueError( 24 | "Expecting dtype to be float32/int16/int32" 25 | + "current type is {}".format(str(data.dtype)) 26 | ) 27 | return data.astype(np.float32) 28 | 29 | 30 | def cal_signal_power(start, end, audio, fs): 31 | signal_start = int(round(float(start) * fs)) 32 | signal_length = int(round(float(end) * fs)) 33 | signal_energy = np.sum( 34 | np.power(audio[signal_start: signal_start + signal_length], 2) 35 | ) 36 | return signal_energy, signal_length 37 | 38 | 39 | def calculate_snr(sample, vad, fs=16000, noise_th=0.995, speech_th=0.8, vad_window_ms=80): 40 | sample = convert_wav_buf_f32(sample) 41 | sample_chunk = np.split(sample, range( 42 | 0, len(sample), int(vad_window_ms * fs / 1000))) 43 | speech_chunk = [] 44 | noise_chunk = [] 45 | leftover_chunk = [] 46 | speech_continue_chunk = 2 # heuristic, 240ms 47 | for x, v in zip(sample_chunk, vad): 48 | if v < speech_th or speech_continue_chunk >= 0: 49 | speech_chunk.append(x) 50 | if v < speech_th: 51 | speech_continue_chunk = 2 52 | else: 53 | speech_continue_chunk -= 1 54 | elif v > noise_th: 55 | noise_chunk.append(x) 56 | else: 57 | leftover_chunk.append(x) 58 | speech_chunk = np.concatenate(speech_chunk) 59 | speech_energy = np.sum(np.power(speech_chunk, 2)) 60 | speech_time = len(speech_chunk)/fs 61 | speech_power = speech_energy/speech_time 62 | if len(noise_chunk) == 0: 63 | print("no noise?", file=sys.stderr) 64 | return [float('nan'), speech_power, float('nan')] 65 | noise_chunk = np.concatenate(noise_chunk) 66 | leftover_chunk = np.concatenate(leftover_chunk) 67 | noise_energy = np.sum(np.power(noise_chunk, 2)) 68 | noise_time = len(noise_chunk)/fs 69 | noise_power = noise_energy/noise_time 70 | snr = 10 * np.log10((speech_power)/noise_power) 71 | return [snr, speech_power, noise_power] 72 | 73 | 74 | def calculate_file_snr(file_name, speech_th, noise_th): 75 | vad_file = file_name[:-4] + '.vad' 76 | if not os.path.exists(vad_file): 77 | return file_name, None 78 | try: 79 | fs, signal = wavfile.read(file_name) 80 | except: 81 | print("ignoring {}, wrong format".format(file_name), file=sys.stderr) 82 | return file_name, None 83 | with open(vad_file, 'r') as fh: 84 | vad_string = fh.read() 85 | vad = np.fromstring(vad_string, sep=' ') 86 | return file_name, calculate_snr(signal, vad, speech_th=speech_th, noise_th=noise_th, fs=fs) 87 | 88 | 89 | def cal_snr_librivox(file_name): 90 | return calculate_file_snr(file_name, speech_th=0.8, noise_th=0.995) 91 | 92 | 93 | def mp_file_snr(lst_file, records=None, nproc=60): 94 | with open(lst_file, 'r') as fh: 95 | fnames = [line.split()[0] for line in fh] 96 | if records is not None: 97 | with open(records, 'r') as fh: 98 | existing_fname = [line.split()[0] for line in fh] 99 | fnames = set(fnames) - set(existing_fname) 100 | print("loaded {} file to process".format(len(fnames)), file=sys.stderr) 101 | pool = multiprocessing.Pool(nproc) 102 | print("processing librivox format", file=sys.stderr) 103 | 
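    # Pool.imap_unordered yields each (fname, snr_fields) result as soon as a
    # worker finishes, so SNR lines are streamed to stdout incrementally
    # instead of waiting for the whole file list to complete.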
it = pool.imap_unordered(cal_snr_librivox, fnames) 104 | st = time.time() 105 | cnt = 0 106 | for fname, snr_fields in it: 107 | cnt += 1 108 | if snr_fields is not None: 109 | sys.stdout.write('{}\t{}\t{}\t{}\n'.format( 110 | *([fname] + snr_fields))) 111 | if cnt % 1000 == 0 and cnt != 0: 112 | dur = time.time() - st 113 | print("{} file/s".format(cnt / dur), file=sys.stderr) 114 | st = time.time() 115 | cnt = 0 116 | 117 | 118 | if __name__ == "__main__": 119 | usage = """ 120 | example: python calculate_snr.py librivox.lst > snr_output.tsv 121 | """ 122 | parser = argparse.ArgumentParser(description=usage) 123 | parser.add_argument("wav_list", type=str, 124 | help="list path to wavs. oneline per file") 125 | parser.add_argument("--resume_from", type=str, 126 | help="if specified, all entries in the resume-from file will be skipped") 127 | parser.add_argument("--numproc", type=int, default=40, 128 | help="num of processes") 129 | args = parser.parse_args() 130 | mp_file_snr(args.wav_list, records=args.resume_from, nproc=args.numproc) 131 | -------------------------------------------------------------------------------- /eval/PER_src/seq_alignment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from copy import deepcopy 3 | import torch 4 | import progressbar 5 | import math 6 | import numpy as np 7 | from torch.multiprocessing import Lock, Manager 8 | from .PER_src import per_operator 9 | 10 | 11 | def cut_data(seq, sizeSeq): 12 | maxSeq = sizeSeq.max() 13 | seq = seq[:, :maxSeq] 14 | return seq 15 | 16 | 17 | def beam_search(score_preds, nKeep, blankLabel): 18 | 19 | T, P = score_preds.shape 20 | beams = set(['']) 21 | pb_t_1 = {"": 1} 22 | pnb_t_1 = {"": 0} 23 | 24 | def getLastNumber(b): 25 | return int(b.split(',')[-1]) 26 | 27 | for t in range(T): 28 | 29 | nextBeams = set() 30 | pb_t = {} 31 | pnb_t = {} 32 | for i_beam, b in enumerate(beams): 33 | if b not in pb_t: 34 | pb_t[b] = 0 35 | pnb_t[b] = 0 36 | 37 | if len(b) > 0: 38 | pnb_t[b] += pnb_t_1[b] * score_preds[t, getLastNumber(b)] 39 | pb_t[b] = (pnb_t_1[b] + pb_t_1[b]) * score_preds[t, blankLabel] 40 | nextBeams.add(b) 41 | 42 | for c in range(P): 43 | if c == blankLabel: 44 | continue 45 | 46 | b_ = b + "," + str(c) 47 | if b_ not in pb_t: 48 | pb_t[b_] = 0 49 | pnb_t[b_] = 0 50 | 51 | if b != "" and getLastNumber(b) == c: 52 | pnb_t[b_] += pb_t_1[b] * score_preds[t, c] 53 | else: 54 | pnb_t[b_] += (pb_t_1[b] + pnb_t_1[b]) * score_preds[t, c] 55 | nextBeams.add(b_) 56 | 57 | allPreds = [(pb_t[b] + pnb_t[b], b) for b in nextBeams] 58 | allPreds.sort(reverse=True) 59 | 60 | beams = [x[1] for x in allPreds[:nKeep]] 61 | pb_t_1 = deepcopy(pb_t) 62 | pnb_t_1 = deepcopy(pnb_t) 63 | 64 | output = [] 65 | for score, x in allPreds[:nKeep]: 66 | output.append((score, [int(y) for y in x.split(',') if len(y) > 0])) 67 | return output 68 | 69 | 70 | def get_seq_PER(seqLabels, detectedLabels): 71 | return per_operator.needleman_wunsch_align_score(seqLabels, detectedLabels, 72 | -1, -1, 0, 73 | normalize=True) 74 | 75 | 76 | def prepare_data(data): 77 | seq, sizeSeq, phone, sizePhone = data 78 | seq = seq.cuda(non_blocking=True) 79 | phone = phone.cuda(non_blocking=True) 80 | sizeSeq = sizeSeq.cuda(non_blocking=True).view(-1) 81 | sizePhone = sizePhone.cuda(non_blocking=True).view(-1) 82 | 83 | seq = cut_data(seq, sizeSeq) 84 | 85 | return seq, sizeSeq, phone, sizePhone 86 | 87 | 88 | def get_local_per(pool, mutex, p_, 
gt_seq, BLANK_LABEL): 89 | predSeq = np.array(beam_search(p_, 10, BLANK_LABEL)[0][1], dtype=np.int32) 90 | per = get_seq_PER(gt_seq, predSeq) 91 | mutex.acquire() 92 | pool.append(per) 93 | mutex.release() 94 | 95 | 96 | def per_step(valLoader, 97 | model, 98 | criterion, 99 | downsamplingFactor): 100 | 101 | model.eval() 102 | criterion.eval() 103 | 104 | avgPER = 0 105 | varPER = 0 106 | nItems = 0 107 | 108 | print("Starting the PER computation through beam search") 109 | bar = progressbar.ProgressBar(maxval=len(valLoader)) 110 | bar.start() 111 | 112 | for index, data in enumerate(valLoader): 113 | 114 | bar.update(index) 115 | 116 | with torch.no_grad(): 117 | seq, sizeSeq, phone, sizePhone = prepare_data(data) 118 | c_feature = model(seq) 119 | sizeSeq = sizeSeq / downsamplingFactor 120 | predictions = torch.nn.functional.softmax(criterion.getPrediction(c_feature), 121 | dim=2).cpu() 122 | c_feature = c_feature 123 | phone = phone.cpu() 124 | sizeSeq = sizeSeq.cpu() 125 | sizePhone = sizePhone.cpu() 126 | 127 | mutex = Lock() 128 | manager = Manager() 129 | poolData = manager.list() 130 | 131 | processes = [] 132 | for b in range(sizeSeq.size(0)): 133 | l_ = min(sizeSeq[b] // 4, predictions.size(1)) 134 | s_ = sizePhone[b] 135 | p = torch.multiprocessing.Process(target=get_local_per, 136 | args=(poolData, mutex, predictions[b, :l_].view(l_, -1).numpy(), 137 | phone[b, :s_].view(-1).numpy().astype(np.int32), criterion.BLANK_LABEL)) 138 | p.start() 139 | processes.append(p) 140 | for p in processes: 141 | p.join() 142 | 143 | avgPER += sum([x for x in poolData]) 144 | varPER += sum([x*x for x in poolData]) 145 | nItems += len(poolData) 146 | 147 | bar.finish() 148 | 149 | avgPER /= nItems 150 | varPER /= nItems 151 | 152 | varPER -= avgPER**2 153 | print(f"Average PER {avgPER}") 154 | print(f"Standard deviation PER {math.sqrt(varPER)}") 155 | -------------------------------------------------------------------------------- /eval/ABX_src/abx_group_computation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch 3 | import math 4 | from .ABX_src import dtw 5 | import progressbar 6 | 7 | 8 | def get_distance_function_from_name(name_str): 9 | if name_str == 'euclidian': 10 | return get_euclidian_distance_batch 11 | if name_str == 'cosine': 12 | return get_cosine_distance_batch 13 | if name_str == 'kl': 14 | return get_kl_distance_batch 15 | if name_str == 'kl_symmetric': 16 | return get_kl_distance_symmetric_batch 17 | raise ValueError(f"Invalid distance mode") 18 | 19 | 20 | def check_dtw_group_validity(a, b, x): 21 | assert(len(a.size()) == len(b.size())) 22 | assert(len(a.size()) == len(x.size())) 23 | assert(a.size(2) == x.size(2)) 24 | assert(a.size(2) == b.size(2)) 25 | 26 | def get_kl_distance_batch(a1, a2, epsilon=1e-6): 27 | N1, S1, D = a1.size() # Batch x Seq x Channel 28 | N2, S2, D = a2.size() # Batch x Seq x Channel 29 | 30 | # (P * (P / Q).log()).sum() 31 | div = (a1.view(N1, 1, S1, 1, D) + epsilon) / (a2.view(1, N2, 1, S2, D) + epsilon) 32 | prod = (a1.view(N1, 1, S1, 1, D)) * div.log() 33 | 34 | return prod.sum(dim=4) 35 | 36 | def get_kl_distance_symmetric_batch(a1, a2, epsilon=1e-6): 37 | N1, S1, D = a1.size() 38 | N2, S2, D = a2.size() 39 | 40 | div1 = (a1.view(N1, 1, S1, 1, D) + epsilon) / (a2.view(1, N2, 1, S2, D) + epsilon) 41 | div2 = (a2.view(1, N2, 1, S2, D) + epsilon) / (a1.view(N1, 1, S1, 1, D) + epsilon) 42 | 43 | prod1 = (a1.view(N1, 1, S1, 1, D)) * div1.log() 44 | prod2 = (a2.view(1, N2, 1, S2, D)) * div2.log() 45 | 46 | return (0.5*prod1 + 0.5*prod2).sum(dim=4) 47 | 48 | def get_cosine_distance_batch(a1, a2, epsilon=1e-8): 49 | r""" a1 and a2 must be normalized""" 50 | N1, S1, D = a1.size() # Batch x Seq x Channel 51 | N2, S2, D = a2.size() # Batch x Seq x Channel 52 | 53 | prod = (a1.view(N1, 1, S1, 1, D)) * (a2.view(1, N2, 1, S2, D)) 54 | # Sum accross the channel dimension 55 | prod = torch.clamp(prod.sum(dim=4), -1, 1).acos() / math.pi 56 | 57 | return prod 58 | 59 | 60 | def get_euclidian_distance_batch(a1, a2): 61 | N1, S1, D = a1.size() 62 | N2, S2, D = a2.size() 63 | diff = a1.view(N1, 1, S1, 1, D) - a2.view(1, N2, 1, S2, D) 64 | return torch.sqrt((diff**2).sum(dim=4)) 65 | 66 | 67 | def get_distance_group_dtw(a1, a2, size1, size2, 68 | ignore_diag=False, symmetric=False, 69 | distance_function=get_cosine_distance_batch): 70 | 71 | N1, S1, D = a1.size() 72 | N2, S2, D = a2.size() 73 | if size1.size(0) != N1: 74 | print(a1.size(), size1.size()) 75 | print(a2.size(), size2.size()) 76 | assert(size1.size(0) == N1) 77 | assert(size2.size(0) == N2) 78 | 79 | distance_mat = distance_function(a1, a2).detach().cpu().numpy() 80 | return dtw.dtw_batch(a1, a2, size1, size2, 81 | distance_mat, 82 | ignore_diag, symmetric) 83 | 84 | 85 | def get_theta_group_dtw(a, b, x, sa, sb, sx, distance_function, symmetric): 86 | 87 | check_dtw_group_validity(a, b, x) 88 | 89 | dxb = get_distance_group_dtw( 90 | x, b, sx, sb, distance_function=distance_function) 91 | dxa = get_distance_group_dtw(x, a, sx, sa, ignore_diag=symmetric, 92 | symmetric=symmetric, 93 | distance_function=distance_function) 94 | 95 | Nx, Na = dxa.size() 96 | Nx, Nb = dxb.size() 97 | 98 | if symmetric: 99 | n_pos = Na * (Na - 1) 100 | max_val = dxb.max().item() 101 | for i in range(Na): 102 | dxa[i, i] = max_val + 1 103 | else: 104 | n_pos = Na * Nx 105 | 106 | dxb = dxb.view(Nx, 1, Nb).expand(Nx, Na, Nb) 107 | dxa = dxa.view(Nx, Na, 1).expand(Nx, Na, Nb) 108 | 109 | sc = (dxa < dxb).sum() + 0.5 * (dxa == dxb).sum() 110 | sc /= (n_pos * Nb) 111 | 112 | return sc.item() 113 | 
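# get_theta_group_dtw returns the fraction of (x, a, b) comparisons for which
# the DTW distance d(x, a) is smaller than d(x, b), with ties counted as 0.5;
# in other words, how often X is matched to the correct category A rather than
# the wrong one B. loc_dtw below turns this score into an error by reporting
# 1 - theta.
#
# Minimal usage sketch (illustrative only; shapes are assumptions):
#   a, b, x      : padded batches of shape (N, S, D)
#   sa, sb, sx   : 1D tensors holding the true sequence lengths
#   theta = get_theta_group_dtw(a, b, x, sa, sb, sx,
#                               get_cosine_distance_batch, symmetric=False)
#   abx_error = 1 - theta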
114 | 115 | def loc_dtw(data, distance_function, symmetric): 116 | coords, group_a, group_b, group_x = data 117 | group_a_data, group_a_size = group_a 118 | group_b_data, group_b_size = group_b 119 | group_x_data, group_x_size = group_x 120 | theta = get_theta_group_dtw(group_a_data, 121 | group_b_data, 122 | group_x_data, 123 | group_a_size, 124 | group_b_size, 125 | group_x_size, 126 | distance_function, 127 | symmetric) 128 | 129 | return (coords, 1 - theta) 130 | 131 | 132 | def get_abx_scores_dtw_on_group(group_iterator, 133 | distance_function, 134 | symmetric): 135 | 136 | data_list = [] 137 | coords_list = [] 138 | bar = progressbar.ProgressBar(maxval=len(group_iterator)) 139 | bar.start() 140 | 141 | with torch.no_grad(): 142 | for index, group in enumerate(group_iterator): 143 | bar.update(index) 144 | coords, abx = loc_dtw(group, distance_function, symmetric) 145 | data_list.append(abx) 146 | coords_list.append(coords) 147 | bar.finish() 148 | 149 | return torch.sparse.FloatTensor(torch.LongTensor(coords_list).t(), 150 | torch.FloatTensor(data_list), 151 | group_iterator.get_board_size()) 152 | -------------------------------------------------------------------------------- /data_preparation/text_retrieval/bartleby.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from html.parser import HTMLParser 3 | import requests 4 | import time 5 | 6 | 7 | class BarthelebyParser(HTMLParser): 8 | 9 | from enum import Enum 10 | GLOBAL_STATUS = Enum('STATUS', 'NONE IN_TITLE IN_CHAPTER') 11 | LOCAL_STATUS = Enum('STATUS', 'NONE PARAGRAPH') 12 | 13 | def __init__(self): 14 | super(BarthelebyParser, self).__init__() 15 | self.text = "" 16 | self.global_status = BarthelebyParser.GLOBAL_STATUS.NONE 17 | self.local_status = BarthelebyParser.LOCAL_STATUS.NONE 18 | self.title = "" 19 | self.ignore = False 20 | self.textFound = False 21 | 22 | def handle_comment(self, tag): 23 | if tag.find("BEGIN CHAPTERTITLE") >= 0: 24 | self.global_status = BarthelebyParser.GLOBAL_STATUS.IN_TITLE 25 | elif tag.find("END CHAPTERTITLE") >= 0: 26 | if not self.global_status == BarthelebyParser.GLOBAL_STATUS.IN_TITLE: 27 | raise RuntimeError("Page of invalid format") 28 | self.global_status = BarthelebyParser.GLOBAL_STATUS.NONE 29 | elif tag.find("BEGIN CHAPTER") >= 0 or tag.find("END MAIN HEADER CODE") >= 0: 30 | self.global_status = BarthelebyParser.GLOBAL_STATUS.IN_CHAPTER 31 | self.local_status = BarthelebyParser.LOCAL_STATUS.NONE 32 | self.textFound = True 33 | elif tag.find("END CHAPTER") >= 0 or tag.find("AMAZON") >= 0: 34 | if not self.global_status == BarthelebyParser.GLOBAL_STATUS.IN_CHAPTER: 35 | raise RuntimeError("Page of invalid format") 36 | self.global_status = BarthelebyParser.GLOBAL_STATUS.NONE 37 | 38 | def handle_data(self, data): 39 | if self.global_status == BarthelebyParser.GLOBAL_STATUS.IN_TITLE: 40 | if self.local_status == BarthelebyParser.LOCAL_STATUS.PARAGRAPH: 41 | self.title += data 42 | elif self.global_status == BarthelebyParser.GLOBAL_STATUS.IN_CHAPTER: 43 | if self.local_status == BarthelebyParser.LOCAL_STATUS.PARAGRAPH and not self.ignore: 44 | self.text += data 45 | 46 | def handle_starttag(self, tag, attrs): 47 | if self.global_status == BarthelebyParser.GLOBAL_STATUS.IN_TITLE: 48 | if tag == 'b': 49 | self.local_status = BarthelebyParser.LOCAL_STATUS.PARAGRAPH 50 | elif self.global_status == BarthelebyParser.GLOBAL_STATUS.IN_CHAPTER: 51 | if tag == 'tr': 52 | if 
self.local_status != BarthelebyParser.LOCAL_STATUS.NONE: 53 | self.text += '\n' 54 | self.local_status = BarthelebyParser.LOCAL_STATUS.PARAGRAPH 55 | if tag == "i": 56 | self.ignore = True 57 | 58 | def handle_endtag(self, tag): 59 | if self.global_status == BarthelebyParser.GLOBAL_STATUS.IN_TITLE and tag == 'b': 60 | self.local_status = BarthelebyParser.LOCAL_STATUS.NONE 61 | if tag == "i": 62 | self.ignore = False 63 | 64 | def getCleanText(self): 65 | return self.text.replace('\\n', '\n').replace("\\'", "'") 66 | 67 | 68 | class BarthelebyTitleParser(HTMLParser): 69 | 70 | def __init__(self): 71 | super(BarthelebyTitleParser, self).__init__() 72 | self.titleFound = False 73 | self.load = False 74 | self.title = "" 75 | 76 | def handle_starttag(self, tag, attr): 77 | if tag == "title": 78 | self.titleFound = True 79 | self.load = True 80 | 81 | def handle_endtag(self, tag): 82 | if tag == "title": 83 | self.load = False 84 | 85 | def handle_data(self, data): 86 | if self.load: 87 | self.title = data 88 | 89 | 90 | def get_bartheleby_data(url): 91 | 92 | extension = url.split('.')[-1] 93 | isUniquePage = extension == 'html' 94 | 95 | def loadText(locUrl): 96 | parser = BarthelebyParser() 97 | req = requests.get(locUrl) 98 | parser.feed(str(req._content)) 99 | time.sleep(1) 100 | if not parser.textFound: 101 | return None 102 | return parser.title + '\n' + '\n' + parser.getCleanText() 103 | 104 | if not isUniquePage: 105 | 106 | # Load title 107 | parser = BarthelebyTitleParser() 108 | req = requests.get(url) 109 | parser.feed(str(req._content)) 110 | 111 | if not parser.titleFound: 112 | raise RuntimeError("No title found") 113 | 114 | fullText = parser.title + '\n' + '\n' 115 | 116 | if url[-1] != '/': 117 | url += '/' 118 | data = url.split('/') 119 | 120 | try: 121 | int(data[-2]) 122 | except ValueError: 123 | raise RuntimeError("Invalid url") 124 | 125 | index = 1 126 | while True: 127 | nextUrl = f"{url}{index}.html" 128 | textData = loadText(nextUrl) 129 | if textData is None: 130 | break 131 | fullText += '\n\n' + textData 132 | index += 1 133 | 134 | return fullText 135 | 136 | text = loadText(url) 137 | if text is None: 138 | raise RuntimeError("Couldn't find the page") 139 | return text 140 | 141 | 142 | def is_bartheleby_url(url): 143 | return url.find("bartleby.com") >= 0 144 | 145 | 146 | if __name__ == "__main__": 147 | 148 | url = "https://www.bartleby.com/95/1.html" 149 | data = get_bartheleby_data(url) 150 | with open('coin.txt', 'w') as file: 151 | file.write(data) 152 | #parser = BarthelebyParser() 153 | #req = requests.get(url) 154 | # parser.feed(str(req._content)) 155 | # print(parser.title) 156 | #print(parser.text.replace('\\n', '\n').replace("\\'", "'")) 157 | -------------------------------------------------------------------------------- /data_preparation/metadata_completion/DuplicateSearch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import os 3 | import json 4 | import progressbar 5 | 6 | 7 | def getSameAuthorGroups(listMetadata, pathDIR): 8 | 9 | output = {} 10 | nEmpty = 0 11 | nSeverals = 0 12 | for metadata_name in listMetadata: 13 | 14 | fullPath = os.path.join(pathDIR, metadata_name) 15 | with open(fullPath, 'rb') as file: 16 | authorsData = json.load(file)['authors'] 17 | 18 | if len(authorsData) == 0: 19 | authorIDs = [-1] 20 | nEmpty += 1 21 | else: 22 | authorIDs = [] 23 | for author in authorsData: 24 | id = author["id"] 25 | if id is None: 26 | id = -1 27 | else: 28 | id = int(id) 29 | authorIDs.append(id) 30 | 31 | if len(authorIDs) > 1: 32 | nSeverals += 1 33 | 34 | for id in authorIDs: 35 | if id not in output: 36 | output[id] = set() 37 | output[id].add(metadata_name) 38 | 39 | print(f"{nEmpty} books without author, {nSeverals} with several authors") 40 | return output 41 | 42 | 43 | def getBaseStringData(in_str): 44 | 45 | in_str = in_str.lower() 46 | tmp = in_str.split() 47 | out = [] 48 | 49 | for word in tmp: 50 | word = ''.join([char for char in word if char.isalnum()]) 51 | if len(word) == 0: 52 | continue 53 | out.append(word) 54 | 55 | return out 56 | 57 | 58 | def getTitleSimilarityScore(title1, title2): 59 | 60 | title1 = set(getBaseStringData(title1)) 61 | title2 = set(getBaseStringData(title2)) 62 | 63 | nCommon = len(title1.intersection(title2)) 64 | nUnion = len(title1.union(title2)) 65 | 66 | if nUnion == 0: 67 | return 0 68 | 69 | return nCommon / nUnion 70 | 71 | 72 | def getBaseTitle(title): 73 | 74 | in_str = title.lower() 75 | 76 | labelWords = ["dramatic reading", "abridged"] 77 | tags = {} 78 | 79 | for label in labelWords: 80 | if in_str.find(label) >= 0: 81 | tags[label] = True 82 | in_str = in_str.replace(label, '') 83 | 84 | tmp = in_str.split() 85 | baseTitle = "" 86 | 87 | index = 0 88 | nItems = len(tmp) 89 | tmp = [''.join([char for char in word if char.isalnum()]) for word in tmp] 90 | 91 | keyWords = ["version", "vol", "chapter", "part", "volume", "book"] 92 | forbiddenWords = ["a", "the", "of", "in"] 93 | 94 | while index < nItems: 95 | word = tmp[index] 96 | if word in keyWords and index < nItems - 1: 97 | if tmp[index+1].isdigit(): 98 | tags[word] = int(tmp[index+1]) 99 | index += 2 100 | continue 101 | elif len(word) > 0 and word not in forbiddenWords: 102 | if len(baseTitle) > 0: 103 | baseTitle += " " 104 | baseTitle += word 105 | index += 1 106 | 107 | return baseTitle, tags 108 | 109 | 110 | def prepareMatches(listMetadata, pathDIR): 111 | 112 | authorGroups = getSameAuthorGroups(listMetadata, pathDIR) 113 | authorGroups = [list(authorGroups[x]) 114 | for x in authorGroups if len(authorGroups[x]) > 1] 115 | print(f"{len(authorGroups)} groups of books with the same author") 116 | print("Preparing the data...") 117 | 118 | output = [] 119 | 120 | bar = progressbar.ProgressBar(len(authorGroups)) 121 | bar.start() 122 | 123 | for index, group in enumerate(authorGroups): 124 | bar.update(index) 125 | nItems = len(group) 126 | match = [] 127 | for i in range(nItems): 128 | pathMetadata = os.path.join(pathDIR, group[i]) 129 | with open(pathMetadata, 'rb') as file: 130 | title_i = json.load(file)["title"] 131 | baseTitle, code = getBaseTitle(title_i) 132 | match.append((baseTitle, code, group[i])) 133 | output.append(match) 134 | bar.finish() 135 | 136 | return output 137 | 138 | 139 | def getPossibleMatches(allGroups): 140 | 141 | output = [] 142 | for group in allGroups: 143 | group.sort(key=lambda x: x[0]) 144 | groupSize = len(group) 145 | 
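        # The group is sorted by normalized title, so duplicates appear as
        # consecutive entries: sweep indexStart forward over each run of
        # identical base titles and keep the run as one set of candidates
        # (version, abridged and dramatic-reading tags are ignored when
        # comparing).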
indexStart = 0 146 | while indexStart < groupSize - 1: 147 | currMatch = [] 148 | currTitle, currTags, currMetdataName = group[indexStart] 149 | for indexEnd in range(indexStart + 1, groupSize): 150 | nextTitle, nextTags, nextMetadataName = group[indexEnd] 151 | isSame = True 152 | if currTitle == nextTitle: 153 | for tag in currTags: 154 | if tag in ["version", "abridged", "dramatic reading"]: 155 | continue 156 | if nextTags.get(tag, None) != currTags[tag]: 157 | isSame = False 158 | break 159 | if isSame: 160 | currMatch.append(nextMetadataName) 161 | else: 162 | break 163 | indexStart = indexEnd 164 | if len(currMatch) > 0: 165 | currMatch.append(currMetdataName) 166 | output.append(currMatch) 167 | return output 168 | 169 | 170 | def get_books_duplicates(pathDIRMetadata, listMetadata): 171 | 172 | matches = prepareMatches(listMetadata, pathDIRMetadata) 173 | print("Retriveing the possible matches") 174 | matches = getPossibleMatches(matches) 175 | return matches 176 | 177 | 178 | if __name__ == "__main__": 179 | from sumUp import getAllMetadata 180 | 181 | pathDIRMetadata = "/checkpoint/mriviere/LibriVox_full_metadata/" 182 | pathOut = "/checkpoint/mriviere/LibriVox_titleDuplicates.json" 183 | listMetadata = getAllMetadata(pathDIRMetadata) 184 | 185 | matches = get_books_duplicates(pathDIRMetadata, listMetadata) 186 | 187 | with open(pathOut, 'w') as file: 188 | json.dump(matches, file, indent=2) 189 | -------------------------------------------------------------------------------- /data_preparation/split_librilight/split.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from collections import defaultdict 3 | import argparse 4 | import json 5 | import pathlib 6 | import math 7 | 8 | 9 | def get_stats(fnames, fnames2jsons): 10 | 11 | total_seconds = sum(fnames2jsons[fname] 12 | ['file_length_sec'] for fname in fnames) 13 | 14 | files_per_genre = defaultdict(int) 15 | seconds_per_genre = defaultdict(int) 16 | snr_per_genre = defaultdict(float) 17 | 18 | mean_snr = 0.0 19 | 20 | unique_speakers = set() 21 | unique_books = set() 22 | 23 | for fname in fnames: 24 | data = fnames2jsons[fname] 25 | 26 | snr = data['snr'] if not math.isnan(data['snr']) else 0.0 27 | seconds = data['file_length_sec'] 28 | 29 | mean_snr += snr * seconds 30 | 31 | if 'genre' not in data['book_meta'] or data['book_meta']['genre'] is None: 32 | file_genres = [''] 33 | else: 34 | file_genres = data['book_meta']['genre'] 35 | 36 | for genre in file_genres: 37 | files_per_genre[genre] += 1 38 | seconds_per_genre[genre] += seconds 39 | snr_per_genre[genre] += snr * seconds 40 | 41 | unique_speakers.add(data['speaker']) 42 | unique_books.add(data['book_meta']['id']) 43 | 44 | for g in snr_per_genre: 45 | snr_per_genre[g] /= seconds_per_genre[g] 46 | 47 | mean_snr /= total_seconds 48 | return seconds_per_genre, files_per_genre, snr_per_genre, total_seconds, unique_books, unique_speakers, mean_snr 49 | 50 | 51 | def get_genre2time(fnames, fnames2jsons): 52 | seconds_per_genre = defaultdict(int) 53 | 54 | for fname in fnames: 55 | data = fnames2jsons[fname] 56 | if 'genre' not in data['book_meta'] or data['book_meta']['genre'] is None: 57 | file_genres = [''] 58 | else: 59 | file_genres = data['book_meta']['genre'] 60 | 61 | for genre in file_genres: 62 | seconds_per_genre[genre] += data['file_length_sec'] 63 | 64 | return seconds_per_genre 65 | 66 | 67 | def get_genre2files(fnames, fnames2jsons): 68 | 
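    # Inverse mapping: genre -> list of metadata files carrying that genre.
    # Files whose book_meta has no genre entry are grouped under ''.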
genre_files = defaultdict(list) 69 | 70 | for fname in fnames: 71 | data = fnames2jsons[fname] 72 | if 'genre' not in data['book_meta'] or data['book_meta']['genre'] is None: 73 | file_genres = [''] 74 | else: 75 | file_genres = data['book_meta']['genre'] 76 | 77 | for genre in file_genres: 78 | genre_files[genre].append(fname) 79 | 80 | return genre_files 81 | 82 | 83 | def get_fname2json(fnames): 84 | fname2json = {} 85 | for fname in fnames: 86 | with open(fname, 'r') as f: 87 | data = json.load(f) 88 | fname2json[fname] = data 89 | return fname2json 90 | 91 | 92 | def subselect(fnames, files2jsons, divisor=10): 93 | overall_time = sum( 94 | fnames2jsons[fname]['file_length_sec'] for fname in fnames) 95 | print('Selecting from', overall_time / 60 / 60, 'hours') 96 | 97 | genre2time = get_genre2time(fnames, fnames2jsons) 98 | 99 | genre2budget = {} 100 | for genre, time in genre2time.items(): 101 | genre2budget[genre] = time // divisor 102 | 103 | time_selected = 0 104 | selected_files = [] 105 | 106 | for fname in fnames: 107 | if time_selected > overall_time // divisor: 108 | break 109 | 110 | data = fnames2jsons[fname] 111 | if 'genre' not in data['book_meta'] or data['book_meta']['genre'] is None: 112 | file_genres = [''] 113 | else: 114 | file_genres = data['book_meta']['genre'] 115 | length = data['file_length_sec'] 116 | 117 | fits = True 118 | for file_genre in file_genres: 119 | fits = fits and ( 120 | file_genre not in genre2budget or genre2budget[file_genre] > length) 121 | 122 | if fits: 123 | time_selected += length 124 | selected_files.append(fname) 125 | for file_genre in file_genres: 126 | if file_genre in genre2budget: 127 | genre2budget[file_genre] -= length 128 | 129 | overall_time = sum( 130 | fnames2jsons[fname]['file_length_sec'] for fname in selected_files) 131 | print('Selected', overall_time / 60 / 60, 'hours') 132 | 133 | return selected_files 134 | 135 | 136 | def take_n(x, n): 137 | for i, k in enumerate(x): 138 | yield k 139 | 140 | if i == n - 1: 141 | break 142 | 143 | 144 | def get_args(): 145 | parser = argparse.ArgumentParser(description='Reads a direcctory with flac/meta-data files and decides how to split them in ' 146 | 'three nested sets, roughly balancing genres') 147 | parser.add_argument('--librivox_processed', type=str) 148 | parser.add_argument('--sampling_steps', type=int, default=3) 149 | parser.add_argument('--size_divisor', type=int, default=10) 150 | parser.add_argument('--debug', action='store_true') 151 | 152 | args = parser.parse_args() 153 | return args 154 | 155 | 156 | if __name__ == '__main__': 157 | args = get_args() 158 | 159 | fnames = list( 160 | take_n(pathlib.Path(args.librivox_processed).rglob('*.json'), n=1000 if args.debug else -1)) 161 | fnames2jsons = get_fname2json(fnames) 162 | 163 | for sampling_step in range(args.sampling_steps): 164 | seconds_per_genre, files_per_genre, snr_per_genre, total_seconds, unique_books, unique_speakers, mean_snr = get_stats( 165 | fnames, fnames2jsons) 166 | 167 | print('Total seconds', total_seconds, ' = ', 168 | total_seconds / 60 / 60, ' hours') 169 | print('Unique speakers', len(unique_speakers), ' unique books', 170 | len(unique_books), ' files ', len(fnames)) 171 | print('Time-weighted snr', mean_snr) 172 | 173 | with open(f'split_{sampling_step}.json', 'w') as f: 174 | dump = [(genre, {'seconds': seconds, 'hours': seconds / 60 / 60, 175 | 'files': files_per_genre[genre], 176 | 'mean_snr': snr_per_genre[genre]}) for (genre, seconds) in seconds_per_genre.items()] 177 | 178 | 
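            # json.dumps cannot serialize pathlib.Path objects, so the file
            # names are converted to plain strings first.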
fnames_as_str = [str(f) for f in fnames] 179 | f.write(json.dumps({ 180 | 'distribution': dump, 181 | 'files': fnames_as_str, 182 | 'n_speakers': len(unique_speakers), 183 | 'n_books': len(unique_books), 184 | 'n_files': len(fnames), 185 | 'time_weighted_snr': mean_snr}, 186 | indent=1)) 187 | 188 | fnames = subselect( 189 | fnames, fnames2jsons, divisor=args.size_divisor) 190 | -------------------------------------------------------------------------------- /eval/PER_src/simplePhonemLearner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torchaudio 3 | from copy import deepcopy 4 | import torch 5 | import time 6 | from pathlib import Path 7 | from torch.utils.data import Dataset 8 | from torch.multiprocessing import Pool 9 | 10 | 11 | def load(path_item): 12 | seq_name = path_item.stem 13 | data = torchaudio.load(str(path_item))[0].view(1, -1) 14 | return seq_name, data 15 | 16 | 17 | class SingleSequenceDataset(Dataset): 18 | 19 | def __init__(self, 20 | pathDB, 21 | seqNames, 22 | phoneLabelsDict, 23 | inDim=1, 24 | transpose=True): 25 | """ 26 | Args: 27 | - path (string): path to the training dataset 28 | - sizeWindow (int): size of the sliding window 29 | - seqNames (list): sequences to load 30 | - phoneLabels (dictionnary): if not None, a dictionnary with the 31 | following entries 32 | 33 | "step": size of a labelled window 34 | "$SEQ_NAME": list of phonem labels for 35 | the sequence $SEQ_NAME 36 | """ 37 | self.seqNames = deepcopy(seqNames) 38 | self.pathDB = pathDB 39 | self.phoneLabelsDict = deepcopy(phoneLabelsDict) 40 | self.inDim = inDim 41 | self.transpose = transpose 42 | self.loadSeqs() 43 | 44 | def loadSeqs(self): 45 | 46 | # Labels 47 | self.seqOffset = [0] 48 | self.phoneLabels = [] 49 | self.phoneOffsets = [0] 50 | self.data = [] 51 | self.maxSize = 0 52 | self.maxSizePhone = 0 53 | 54 | # Data 55 | 56 | nprocess = min(30, len(self.seqNames)) 57 | 58 | start_time = time.time() 59 | to_load = [Path(self.pathDB) / x for _, x in self.seqNames] 60 | 61 | with Pool(nprocess) as p: 62 | poolData = p.map(load, to_load) 63 | 64 | tmpData = [] 65 | poolData.sort() 66 | 67 | totSize = 0 68 | minSizePhone = 1000000 69 | for seqName, seq in poolData: 70 | self.phoneLabels += self.phoneLabelsDict[seqName] 71 | self.phoneOffsets.append(len(self.phoneLabels)) 72 | self.maxSizePhone = max(self.maxSizePhone, 73 | len(self.phoneLabelsDict[seqName])) 74 | minSizePhone = min(minSizePhone, len( 75 | self.phoneLabelsDict[seqName])) 76 | sizeSeq = seq.size(1) 77 | self.maxSize = max(self.maxSize, sizeSeq) 78 | totSize += sizeSeq 79 | tmpData.append(seq) 80 | self.seqOffset.append(self.seqOffset[-1] + sizeSeq) 81 | del seq 82 | self.data = torch.cat(tmpData, dim=1) 83 | self.phoneLabels = torch.tensor(self.phoneLabels, dtype=torch.long) 84 | print(f'Loaded {len(self.phoneOffsets)} sequences ' 85 | f'in {time.time() - start_time:.2f} seconds') 86 | print(f'maxSizeSeq : {self.maxSize}') 87 | print(f'maxSizePhone : {self.maxSizePhone}') 88 | print(f"minSizePhone : {minSizePhone}") 89 | print(f'Total size dataset {totSize / (16000 * 3600)} hours') 90 | 91 | def __getitem__(self, idx): 92 | 93 | offsetStart = self.seqOffset[idx] 94 | offsetEnd = self.seqOffset[idx+1] 95 | offsetPhoneStart = self.phoneOffsets[idx] 96 | offsetPhoneEnd = self.phoneOffsets[idx + 1] 97 | 98 | sizeSeq = int(offsetEnd - offsetStart) 99 | sizePhone = int(offsetPhoneEnd - offsetPhoneStart) 100 | 101 
| outSeq = torch.zeros((self.inDim, self.maxSize)) 102 | outPhone = torch.zeros((self.maxSizePhone)) 103 | 104 | outSeq[:, :sizeSeq] = self.data[:, offsetStart:offsetEnd] 105 | outPhone[:sizePhone] = self.phoneLabels[offsetPhoneStart:offsetPhoneEnd] 106 | 107 | return outSeq, torch.tensor([sizeSeq], dtype=torch.long), outPhone.long(), torch.tensor([sizePhone], dtype=torch.long) 108 | 109 | def __len__(self): 110 | return len(self.seqOffset) - 1 111 | 112 | 113 | class CTCPhoneCriterion(torch.nn.Module): 114 | 115 | def __init__(self, dimEncoder, nPhones, LSTM=False, sizeKernel=8, 116 | seqNorm=False, dropout=False, reduction='mean'): 117 | 118 | super(CTCPhoneCriterion, self).__init__() 119 | self.seqNorm = seqNorm 120 | self.epsilon = 1e-8 121 | self.dropout = torch.nn.Dropout2d( 122 | p=0.5, inplace=False) if dropout else None 123 | self.conv1 = torch.nn.LSTM(dimEncoder, dimEncoder, 124 | num_layers=1, batch_first=True) 125 | self.PhoneCriterionClassifier = torch.nn.Conv1d( 126 | dimEncoder, nPhones + 1, sizeKernel, stride=sizeKernel // 2) 127 | self.lossCriterion = torch.nn.CTCLoss(blank=nPhones, 128 | reduction=reduction, 129 | zero_infinity=True) 130 | self.relu = torch.nn.ReLU() 131 | self.BLANK_LABEL = nPhones 132 | self.useLSTM = LSTM 133 | 134 | def getPrediction(self, cFeature): 135 | B, S, H = cFeature.size() 136 | if self.seqNorm: 137 | m = cFeature.mean(dim=1, keepdim=True) 138 | v = cFeature.var(dim=1, keepdim=True) 139 | cFeature = (cFeature - m) / torch.sqrt(v + self.epsilon) 140 | if self.useLSTM: 141 | cFeature = self.conv1(cFeature)[0] 142 | 143 | cFeature = cFeature.permute(0, 2, 1) 144 | 145 | if self.dropout is not None: 146 | cFeature = self.dropout(cFeature) 147 | 148 | return self.PhoneCriterionClassifier(cFeature).permute(0, 2, 1) 149 | 150 | def forward(self, cFeature, featureSize, label, labelSize): 151 | 152 | # cFeature.size() : batchSize x seq Size x hidden size 153 | B, S, H = cFeature.size() 154 | predictions = self.getPrediction(cFeature) 155 | featureSize /= 4 156 | predictions = cutData(predictions, featureSize) 157 | featureSize = torch.clamp(featureSize, max=predictions.size(1)) 158 | label = cutData(label, labelSize) 159 | if labelSize.min() <= 0: 160 | print(label, labelSize) 161 | predictions = torch.nn.functional.log_softmax(predictions, dim=2) 162 | predictions = predictions.permute(1, 0, 2) 163 | loss = self.lossCriterion(predictions, label, 164 | featureSize, labelSize).view(1, -1) 165 | 166 | if torch.isinf(loss).sum() > 0 or torch.isnan(loss).sum() > 0: 167 | loss = 0 168 | 169 | return loss 170 | 171 | 172 | def cutData(seq, sizeSeq): 173 | maxSeq = sizeSeq.max() 174 | seq = seq[:, :maxSeq] 175 | return seq 176 | 177 | 178 | def prepareData(data): 179 | seq, sizeSeq, phone, sizePhone = data 180 | seq = seq.cuda(non_blocking=True) 181 | phone = phone.cuda(non_blocking=True) 182 | sizeSeq = sizeSeq.cuda(non_blocking=True).view(-1) 183 | sizePhone = sizePhone.cuda(non_blocking=True).view(-1) 184 | 185 | seq = cutData(seq, sizeSeq) 186 | 187 | return seq, sizeSeq, phone, sizePhone 188 | 189 | 190 | def trainStep(trainLoader, 191 | model, 192 | criterion, 193 | optimizer, 194 | downsamplingFactor): 195 | 196 | if model.optimize: 197 | model.train() 198 | 199 | criterion.train() 200 | avg_loss = 0 201 | nItems = 0 202 | 203 | for data in trainLoader: 204 | optimizer.zero_grad() 205 | seq, sizeSeq, phone, sizePhone = prepareData(data) 206 | c_feature = model(seq) 207 | sizeSeq = sizeSeq / downsamplingFactor 208 | loss = criterion(c_feature, sizeSeq, 
phone, sizePhone) 209 | loss.mean().backward() 210 | 211 | avg_loss += loss.mean().item() 212 | nItems += 1 213 | optimizer.step() 214 | 215 | return avg_loss / nItems 216 | 217 | 218 | def valStep(valLoader, 219 | model, 220 | criterion, 221 | downsamplingFactor): 222 | 223 | model.eval() 224 | criterion.eval() 225 | avg_loss = 0 226 | nItems = 0 227 | 228 | for data in valLoader: 229 | with torch.no_grad(): 230 | seq, sizeSeq, phone, sizePhone = prepareData(data) 231 | c_feature = model(seq) 232 | sizeSeq = sizeSeq / downsamplingFactor 233 | loss = criterion(c_feature, sizeSeq, phone, sizePhone) 234 | avg_loss += loss.mean().item() 235 | nItems += 1 236 | 237 | return avg_loss / nItems 238 | -------------------------------------------------------------------------------- /eval/eval_ABX.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import argparse 3 | import sys 4 | import torch 5 | import json 6 | import os 7 | import numpy as np 8 | import ABX_src.abx_group_computation as abx_g 9 | import ABX_src.abx_iterators as abx_it 10 | from CPC_loader import load_cpc_features, build_feature_from_file 11 | from pathlib import Path 12 | 13 | 14 | def find_all_files(path_dir, extension): 15 | out = [] 16 | for root, dirs, filenames in os.walk(path_dir): 17 | for f in filenames: 18 | if f.endswith(extension): 19 | out.append(((str(Path(f).stem)), os.path.join(root, f))) 20 | return out 21 | 22 | 23 | def reduce_sparse_data(quotient, divisor): 24 | return quotient / (1e-08 * (divisor == 0) + divisor) 25 | 26 | 27 | def load_pt(x): 28 | data = torch.load(x, 'cpu') 29 | assert(len(data.size()) == 2) 30 | return data 31 | 32 | 33 | def load_npy(x): 34 | data = torch.tensor(np.load(x)) 35 | assert(len(data.size()) == 2) 36 | return data 37 | 38 | 39 | def ABX(feature_function, 40 | path_item_file, 41 | seq_list, 42 | distance_mode, 43 | step_feature, 44 | modes, 45 | cuda=False, 46 | max_x_across=5, 47 | max_size_group=30): 48 | 49 | # ABX dataset 50 | ABXDataset = abx_it.ABXFeatureLoader(path_item_file, seq_list, 51 | feature_function, step_feature, True) 52 | 53 | if cuda: 54 | ABXDataset.cuda() 55 | 56 | # Distance function 57 | distance_function = abx_g.get_distance_function_from_name(distance_mode) 58 | 59 | # Output 60 | scores = {} 61 | 62 | # ABX within 63 | if 'within' in modes: 64 | print("Computing ABX within speakers...") 65 | ABXIterator = ABXDataset.get_iterator('within', max_size_group) 66 | group_confusion = abx_g.get_abx_scores_dtw_on_group(ABXIterator, 67 | distance_function, 68 | ABXIterator.symmetric) 69 | n_data = group_confusion._values().size(0) 70 | index_ = torch.sparse.LongTensor(group_confusion._indices(), 71 | torch.ones((n_data), 72 | dtype=torch.float), 73 | group_confusion.size()) 74 | divisor_context = torch.sparse.sum(index_, dim=3).to_dense() 75 | group_confusion = torch.sparse.sum(group_confusion, dim=3).to_dense() 76 | group_confusion = reduce_sparse_data(group_confusion, divisor_context) 77 | S, p1, p2 = group_confusion.size() 78 | 79 | index_speaker = divisor_context > 0 80 | divisor_speaker = index_speaker.sum(dim=0) 81 | phone_confusion = reduce_sparse_data(group_confusion.sum(dim=0), 82 | divisor_speaker) 83 | 84 | scores['within'] = (phone_confusion.sum() / 85 | (divisor_speaker > 0).sum()).item() 86 | print(f"...done. 
ABX within : {scores['within']}") 87 | 88 | # ABX across 89 | if 'across' in modes: 90 | print("Computing ABX across speakers...") 91 | ABXIterator = ABXDataset.get_iterator('across', max_size_group) 92 | ABXIterator.max_x = max_x_across 93 | group_confusion = abx_g.get_abx_scores_dtw_on_group(ABXIterator, 94 | distance_function, 95 | ABXIterator.symmetric) 96 | n_data = group_confusion._values().size(0) 97 | index_ = torch.sparse.LongTensor(group_confusion._indices(), 98 | torch.ones((n_data), 99 | dtype=torch.float), 100 | group_confusion.size()) 101 | divisor_context = torch.sparse.sum(index_, dim=[3, 4]).to_dense() 102 | group_confusion = torch.sparse.sum( 103 | group_confusion, dim=[3, 4]).to_dense() 104 | group_confusion = reduce_sparse_data(group_confusion, divisor_context) 105 | S, p1, p2 = group_confusion.size() 106 | 107 | index_speaker = divisor_context > 0 108 | divisor_speaker = index_speaker.sum(dim=0) 109 | phone_confusion = reduce_sparse_data(group_confusion.sum(dim=0), 110 | divisor_speaker) 111 | scores['across'] = (phone_confusion.sum() / 112 | (divisor_speaker > 0).sum()).item() 113 | print(f"...done. ABX across : {scores['across']}") 114 | 115 | return scores 116 | 117 | 118 | def parse_args(argv): 119 | 120 | parser = argparse.ArgumentParser(description='ABX metric') 121 | 122 | parser.add_argument('path_data', type=str, 123 | help="Path to directory containing the data") 124 | parser.add_argument('path_item_file', type=str, 125 | help="Path to the .item file") 126 | parser.add_argument('--path_checkpoint', type=str, default=None, 127 | help="Path to a CPC checkpoint. If set, the apply the " 128 | "model to the input data to compute the features") 129 | parser.add_argument('--file_extension', type=str, default='.pt', 130 | choices=['.pt', '.npy', '.wav', '.flac', '.mp3']) 131 | parser.add_argument('--feature_size', type=float, default=0.01, 132 | help="Size (in s) of one feature") 133 | parser.add_argument('--cuda', action='store_true', 134 | help="Use the GPU to compute distances") 135 | parser.add_argument('--mode', type=str, default='all', 136 | choices=['all', 'within', 'across'], 137 | help="Choose the mode of the ABX score to compute") 138 | parser.add_argument('--distance_mode', type=str, default='cosine', 139 | choices=['euclidian', 'cosine', 'kl', 'kl_symmetric'], 140 | help="Choose the kind of distance to use to compute " 141 | "the ABX score.") 142 | parser.add_argument("--max_size_group", type=int, default=10, 143 | help="Max size of a group while computing the" 144 | "ABX score. A small value will make the code " 145 | "faster but less precise.") 146 | parser.add_argument("--max_x_across", type=int, default=5, 147 | help="When computing the ABX across score, maximum" 148 | "number of speaker X to sample per couple A,B. 
" 149 | " A small value will make the code faster but " 150 | "less precise.") 151 | parser.add_argument("--out", type=str, default=None, 152 | help="Path where the results should be saved") 153 | 154 | # multi-gpu / multi-node 155 | return parser.parse_args(argv) 156 | 157 | 158 | def main(argv): 159 | 160 | args = parse_args(argv) 161 | 162 | if args.path_checkpoint is None: 163 | if args.file_extension == '.pt': 164 | feature_function = load_pt 165 | elif args.file_extension == '.npy': 166 | feature_function = load_npy 167 | else: 168 | state_dict = torch.load(args.path_checkpoint) 169 | feature_maker = load_cpc_features(state_dict) 170 | feature_maker.cuda() 171 | def feature_function( 172 | x): return build_feature_from_file(x, feature_maker) 173 | 174 | # Modes 175 | if args.mode == 'all': 176 | modes = ["within", "across"] 177 | else: 178 | modes = [args.mode] 179 | 180 | step_feature = 1 / args.feature_size 181 | 182 | # Get the list of sequences 183 | seq_list = find_all_files(args.path_data, args.file_extension) 184 | 185 | scores = ABX(feature_function, args.path_item_file, 186 | seq_list, args.distance_mode, 187 | step_feature, modes, 188 | cuda=args.cuda, 189 | max_x_across=args.max_x_across, 190 | max_size_group=args.max_size_group) 191 | 192 | out_dir = Path(args.path_checkpoint).parent if args.out is None \ 193 | else Path(args.out) 194 | out_dir.mkdir(exist_ok=True) 195 | 196 | path_score = out_dir / 'ABX_scores.json' 197 | with open(path_score, 'w') as file: 198 | json.dump(scores, file, indent=2) 199 | 200 | path_args = out_dir / 'ABX_args.json' 201 | with open(path_args, 'w') as file: 202 | json.dump(vars(args), file, indent=2) 203 | 204 | 205 | if __name__ == "__main__": 206 | args = sys.argv[1:] 207 | main(args) 208 | -------------------------------------------------------------------------------- /eval/CPC_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import argparse 3 | import torch 4 | import torchaudio 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def download_state_dict(model_name): 10 | 11 | base_url = "https://dl.fbaipublicfiles.com/librilight/CPC_checkpoints" 12 | return torch.hub.load_state_dict_from_url(f"{base_url}/{model_name}") 13 | 14 | 15 | def load_cpc_features(state_dict): 16 | 17 | config = state_dict["config"] 18 | weights = state_dict["weights"] 19 | encoder = CPCEncoder(config["hiddenEncoder"]) 20 | ar_net = CPCAR(config["hiddenEncoder"], config["hiddenGar"], False, 21 | config["nLevelsGRU"]) 22 | 23 | model = CPCModel(encoder, ar_net) 24 | model.load_state_dict(weights, strict=False) 25 | output = FeatureModule(model, False) 26 | output.config = config 27 | return output 28 | 29 | 30 | def get_features_state_dict(feature_module): 31 | config = feature_module.config 32 | if config is None: 33 | raise ValueError("The input feature_module should have config defined") 34 | weights = feature_module.model.state_dict() 35 | return {"config": config, "weights": weights} 36 | 37 | 38 | def build_feature_from_file(file_path, feature_maker, max_size_seq=64000): 39 | r""" 40 | Apply the featureMaker to the given file. 
41 | Arguments: 42 | - file_path (FeatureModule): model to apply 43 | - file_path (string): path of the sequence to load 44 | - seq_norm (bool): if True, normalize the output along the time 45 | dimension to get chunks of mean zero and var 1 46 | - max_size_seq (int): maximal size of a chunk 47 | Return: 48 | a torch vector of size 1 x Seq_size x Feature_dim 49 | """ 50 | seq = torchaudio.load(file_path)[0] 51 | sizeSeq = seq.size(1) 52 | start = 0 53 | out = [] 54 | while start < sizeSeq: 55 | if start + max_size_seq > sizeSeq: 56 | break 57 | end = min(sizeSeq, start + max_size_seq) 58 | subseq = (seq[:, start:end]).view(1, 1, -1).cuda(device=0) 59 | with torch.no_grad(): 60 | features = feature_maker(subseq) 61 | out.append(features.detach().cpu()) 62 | start += max_size_seq 63 | 64 | if start < sizeSeq: 65 | subseq = (seq[:, -max_size_seq:]).view(1, 1, -1).cuda(device=0) 66 | with torch.no_grad(): 67 | features = feature_maker(subseq) 68 | df = subseq.size(2) // features.size(1) 69 | delta = (sizeSeq - start) // df 70 | out.append(features[:, -delta:].detach().cpu()) 71 | 72 | out = torch.cat(out, dim=1) 73 | return out.view(out.size(1), out.size(2)) 74 | 75 | ############################################################################## 76 | # Minimal code to load a CPC checkpoint 77 | ############################################################################## 78 | 79 | 80 | class ChannelNorm(nn.Module): 81 | 82 | def __init__(self, 83 | numFeatures, 84 | epsilon=1e-05, 85 | affine=True): 86 | 87 | super(ChannelNorm, self).__init__() 88 | if affine: 89 | self.weight = nn.parameter.Parameter( 90 | torch.Tensor(1, numFeatures, 1)) 91 | self.bias = nn.parameter.Parameter(torch.Tensor(1, numFeatures, 1)) 92 | else: 93 | self.weight = None 94 | self.bias = None 95 | self.epsilon = epsilon 96 | self.p = 0 97 | self.affine = affine 98 | self.reset_parameters() 99 | 100 | def reset_parameters(self): 101 | if self.affine: 102 | torch.nn.init.ones_(self.weight) 103 | torch.nn.init.zeros_(self.bias) 104 | 105 | def forward(self, x): 106 | 107 | cumMean = x.mean(dim=1, keepdim=True) 108 | cumVar = x.var(dim=1, keepdim=True) 109 | x = (x - cumMean)*torch.rsqrt(cumVar + self.epsilon) 110 | 111 | if self.weight is not None: 112 | x = x * self.weight + self.bias 113 | return x 114 | 115 | 116 | class CPCEncoder(nn.Module): 117 | 118 | def __init__(self, 119 | sizeHidden=512): 120 | 121 | super(CPCEncoder, self).__init__() 122 | normLayer = ChannelNorm 123 | 124 | self.conv0 = nn.Conv1d(1, sizeHidden, 10, stride=5, padding=3) 125 | self.batchNorm0 = normLayer(sizeHidden) 126 | self.conv1 = nn.Conv1d(sizeHidden, sizeHidden, 8, stride=4, padding=2) 127 | self.batchNorm1 = normLayer(sizeHidden) 128 | self.conv2 = nn.Conv1d(sizeHidden, sizeHidden, 4, 129 | stride=2, padding=1) 130 | self.batchNorm2 = normLayer(sizeHidden) 131 | self.conv3 = nn.Conv1d(sizeHidden, sizeHidden, 4, stride=2, padding=1) 132 | self.batchNorm3 = normLayer(sizeHidden) 133 | self.conv4 = nn.Conv1d(sizeHidden, sizeHidden, 4, stride=2, padding=1) 134 | self.batchNorm4 = normLayer(sizeHidden) 135 | self.DOWNSAMPLING = 160 136 | 137 | def getDimOutput(self): 138 | return self.conv4.out_channels 139 | 140 | def forward(self, x): 141 | x = F.relu(self.batchNorm0(self.conv0(x))) 142 | x = F.relu(self.batchNorm1(self.conv1(x))) 143 | x = F.relu(self.batchNorm2(self.conv2(x))) 144 | x = F.relu(self.batchNorm3(self.conv3(x))) 145 | x = F.relu(self.batchNorm4(self.conv4(x))) 146 | return x 147 | 148 | 149 | class CPCAR(nn.Module): 150 | 
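    # Autoregressive context network of the CPC model: despite the `nLevelsGRU`
    # argument name, the recurrent network below is an LSTM (`nn.LSTM`) applied
    # on top of the encoder output.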
151 | def __init__(self, 152 | dimEncoded, 153 | dimOutput, 154 | keepHidden, 155 | nLevelsGRU): 156 | 157 | super(CPCAR, self).__init__() 158 | self.baseNet = nn.LSTM(dimEncoded, dimOutput, 159 | num_layers=nLevelsGRU, batch_first=True) 160 | self.hidden = None 161 | self.keepHidden = keepHidden 162 | 163 | def getDimOutput(self): 164 | return self.baseNet.hidden_size 165 | 166 | def forward(self, x): 167 | 168 | try: 169 | self.baseNet.flatten_parameters() 170 | except RuntimeError: 171 | pass 172 | x, h = self.baseNet(x, self.hidden) 173 | if self.keepHidden: 174 | if isinstance(h, tuple): 175 | self.hidden = tuple(x.detach() for x in h) 176 | else: 177 | self.hidden = h.detach() 178 | return x 179 | 180 | 181 | class CPCModel(nn.Module): 182 | 183 | def __init__(self, 184 | encoder, 185 | AR): 186 | 187 | super(CPCModel, self).__init__() 188 | self.gEncoder = encoder 189 | self.gAR = AR 190 | 191 | def forward(self, batchData, label): 192 | encodedData = self.gEncoder(batchData).permute(0, 2, 1) 193 | cFeature = self.gAR(encodedData) 194 | return cFeature, encodedData, label 195 | 196 | 197 | class FeatureModule(torch.nn.Module): 198 | r""" 199 | A simpler interface to handle CPC models. Useful for a smooth workflow when 200 | working with CPC trained features. 201 | """ 202 | 203 | def __init__(self, featureMaker, get_encoded, 204 | seq_norm=True): 205 | super(FeatureModule, self).__init__() 206 | self.get_encoded = get_encoded 207 | self.model = featureMaker 208 | self.seq_norm = seq_norm 209 | self.config = None 210 | 211 | def forward(self, batch_data): 212 | # Input Size : BatchSize x 1 x SeqSize 213 | # Feature size: BatchSize x SeqSize x ChannelSize 214 | if self.is_cuda: 215 | batch_data = batch_data.cuda() 216 | cFeature, encoded, _ = self.model(batch_data, None) 217 | if self.get_encoded: 218 | cFeature = encoded 219 | if self.seq_norm: 220 | mean = cFeature.mean(dim=1, keepdim=True) 221 | var = cFeature.var(dim=1, keepdim=True) 222 | cFeature = (cFeature - mean) / torch.sqrt(var + 1e-08) 223 | return cFeature 224 | 225 | def cuda(self): 226 | self.is_cuda = True 227 | super(FeatureModule, self).cuda() 228 | 229 | def cpu(self): 230 | self.is_cuda = False 231 | super(FeatureModule, self).cuda() 232 | 233 | def get_output_dim(self): 234 | if self.get_encoded: 235 | return self.config["hiddenEncoder"] 236 | return self.config["hiddenGar"] 237 | 238 | 239 | if __name__ == "__main__": 240 | 241 | parser = argparse.ArgumentParser(description='Download model') 242 | parser.add_argument('model_name', type=str, 243 | choices=["600h", "6kh", "60kh"]) 244 | parser.add_argument('output', type=str) 245 | args = parser.parse_args() 246 | 247 | CPC_MODELS_NAMES = {"60kh": "60k_epoch4-d0f474de.pt", 248 | "600h": "600h-bdd7ced6.pt", 249 | "6kh":"6k_epoch30-9df0493c.pt"} 250 | state_dict = download_state_dict(CPC_MODELS_NAMES[args.model_name]) 251 | torch.save(state_dict, args.output) 252 | -------------------------------------------------------------------------------- /eval/ABX_src/unit_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import unittest 3 | import torch 4 | from nose.tools import eq_, ok_ 5 | from . import abx_group_computation 6 | from . 
import abx_iterators 7 | import numpy as np 8 | import math 9 | 10 | 11 | class TestDistancesDTW(unittest.TestCase): 12 | 13 | def testDTWFunction(self): 14 | X = torch.tensor([[[0, 1], [0, 0], [1, 1], [42, 42]], 15 | [[0, 2], [0, 1], [1, 1], [-1, 0]], 16 | [[0, 0], [0, 1], [0, 0], [21, 211]]], 17 | dtype=torch.float) 18 | 19 | X_size = torch.tensor([3, 4, 2]) 20 | 21 | Y = torch.tensor([[[0, 1], [1, 2], [0, 0]]], dtype=torch.float) 22 | Y_size = torch.tensor([3]) 23 | 24 | distance_mode = abx_group_computation.get_euclidian_distance_batch 25 | dist = abx_group_computation.get_distance_group_dtw(X, Y, 26 | X_size, Y_size, 27 | distance_function=distance_mode) 28 | eq_(dist.size(), (3, 1)) 29 | expected_dist = [[(math.sqrt(2)) / 2], [3 / 4], 30 | [(2 + math.sqrt(2)) / 3]] 31 | for i in range(3): 32 | ok_(abs(expected_dist[i][0] - dist[i].item()) < 1e-4) 33 | 34 | def testThetaDTWFunctionSymetric(self): 35 | A = torch.tensor([[[0, 1], [0, 0], [1, 1], [42, 42]], 36 | [[0, 2], [0, 1], [1, 1], [-1, 0]], 37 | [[0, 0], [0, 1], [0, 0], [21, 211]]], 38 | dtype=torch.float) 39 | A_size = torch.tensor([3, 4, 2]) 40 | B = torch.tensor([[[0, 1], [1, 2], [0, 0]]], dtype=torch.float) 41 | B_size = torch.tensor([3]) 42 | 43 | distance_mode = abx_group_computation.get_euclidian_distance_batch 44 | symetric = True 45 | theta = abx_group_computation.get_theta_group_dtw(A, B, A, A_size, 46 | B_size, A_size, 47 | distance_mode, 48 | symetric) 49 | eq_(theta, 0.5) 50 | 51 | 52 | class testSingularityNormalization(unittest.TestCase): 53 | 54 | def testCosineNormalized(self): 55 | x = torch.tensor([[[1., 0., 0., 0.], [0., 0., 0., 0.]], 56 | [[0., 0., -1., 0.], [0.5, -0.5, 0.5, -0.5]]]) 57 | y = torch.tensor( 58 | [[-0.5, -0.5, -0.5, 0.5], [0., 0., 0., 0.], [0., 1., 0., 0.]]) 59 | norm_x = [] 60 | for i in range(2): 61 | norm_x.append(abx_iterators.normalize_with_singularity(x[i]).view(1, 2, 5)) 62 | norm_x = torch.cat(norm_x, dim=0) 63 | norm_y = abx_iterators.normalize_with_singularity(y).view(1, 3, 5) 64 | dist = abx_group_computation.get_cosine_distance_batch(norm_x, norm_y) 65 | 66 | eq_(dist.size(), (2, 1, 2, 3)) 67 | ok_(abs(dist[0, 0, 0, 0] - 0.6667) < 1e-4) 68 | ok_(abs(dist[0, 0, 0, 1] - 1.) < 1e-4) 69 | ok_(abs(dist[0, 0, 0, 2] - 0.5) < 1e-4) 70 | 71 | ok_(abs(dist[0, 0, 1, 0] - 1) < 1e-4) 72 | ok_(abs(dist[0, 0, 1, 1]) < 1e-4) 73 | ok_(abs(dist[0, 0, 1, 2] - 1) < 1e-4) 74 | 75 | ok_(abs(dist[1, 0, 0, 0] - 0.3333) < 1e-4) 76 | ok_(abs(dist[1, 0, 0, 1] - 1.) < 1e-4) 77 | ok_(abs(dist[1, 0, 0, 2] - 0.5) < 1e-4) 78 | 79 | ok_(abs(dist[1, 0, 1, 0]-0.6667) < 1e-4) 80 | ok_(abs(dist[1, 0, 1, 1] - 1.) 
< 1e-4) 81 | ok_(abs(dist[1, 0, 1, 2] - 0.6667) < 1e-4) 82 | 83 | 84 | class testGroupMaker(unittest.TestCase): 85 | 86 | def test1DGroupMaker(self): 87 | 88 | data = [[0], [1], [2], [3], [4], [2], [2], [2]] 89 | order = [0] 90 | out_index, out_data = abx_iterators.get_features_group(data, order) 91 | 92 | expected_index = [0, 1, 2, 5, 6, 7, 3, 4] 93 | eq_(out_index, expected_index) 94 | 95 | expected_output = [(0, 1), (1, 2), (2, 6), (6, 7), (7, 8)] 96 | eq_(out_data, expected_output) 97 | 98 | def test2DGroupMaker(self): 99 | 100 | data = [[0, 1], [1, 2], [2, 3], [3, 3], 101 | [4, 0], [2, 2], [4, 2], [2, 2], [0, 3]] 102 | 103 | order = [1, 0] 104 | out_index, out_data = abx_iterators.get_features_group(data, order) 105 | expected_index = [4, 0, 1, 5, 7, 6, 8, 2, 3] 106 | eq_(out_index, expected_index) 107 | expected_output = [[(0, 1)], 108 | [(1, 2)], 109 | [(2, 3), (3, 5), (5, 6)], 110 | [(6, 7), (7, 8), (8, 9)]] 111 | eq_(out_data, expected_output) 112 | 113 | def test3DGroupMaker(self): 114 | 115 | data = [[0, 0, 0, 1], 116 | [41, 1, 0, 2], 117 | [-23, 0, 3, 1], 118 | [220, 1, -2, 3], 119 | [40, 2, 1, 0], 120 | [200, 0, 0, 1]] 121 | 122 | order = [1, 3, 2] 123 | out_index, out_data = abx_iterators.get_features_group(data, order) 124 | expected_index = [0, 5, 2, 1, 3, 4] 125 | eq_(out_index, expected_index) 126 | 127 | expected_output = [[[(0, 2), (2, 3)]], [ 128 | [(3, 4)], [(4, 5)]], [[(5, 6)]]] 129 | eq_(out_data, expected_output) 130 | 131 | 132 | class testItemLoader(unittest.TestCase): 133 | 134 | def testLoadItemFile(self): 135 | path_item_file = "test_data/dummy_item_file.item" 136 | out, context_match, phone_match, speaker_match = \ 137 | abx_iterators.load_item_file(path_item_file) 138 | 139 | eq_(len(out), 4) 140 | eq_(len(phone_match), 5) 141 | eq_(len(speaker_match), 3) 142 | 143 | expected_phones = {'n': 0, 'd': 1, 'ih': 2, 144 | 's': 3, 'dh': 4} 145 | eq_(phone_match, expected_phones) 146 | 147 | expected_speakers = {'8193': 0, '2222': 1, '12': 2} 148 | eq_(speaker_match, expected_speakers) 149 | 150 | expected_context = {'ae+d': 0, 'n+l': 1, 'l+n': 2, 'ih+s': 3, 151 | 'n+ax': 4, 'ax+dh': 5, 's+ax': 6} 152 | eq_(context_match, expected_context) 153 | 154 | expected_output = {'2107': [[0.3225, 0.5225, 0, 0, 0], 155 | [0.4225, 0.5925, 1, 1, 1], 156 | [1.1025, 1.2925, 6, 4, 2]], 157 | '42': [[0.4525, 0.6525, 1, 1, 1], 158 | [0.5225, 0.7325, 2, 2, 0], 159 | [0.5925, 0.8725, 3, 0, 0]], 160 | '23': [[0.6525, 1.1025, 4, 3, 0], 161 | [0.7325, 1.1925, 4, 3, 1]], 162 | '407': [[0.8725, 1.2425, 5, 3, 1]]} 163 | 164 | eq_(expected_output, out) 165 | 166 | def testLoadWithinItemFile(self): 167 | path_item_file = "test_data/dummy_item_within.item" 168 | out, context_match, phone_match, speaker_match = \ 169 | abx_iterators.load_item_file(path_item_file) 170 | 171 | expected_output = {'2107': [[0., 0.2, 0, 0, 0], 172 | [0.3225, 0.5225, 1, 0, 0], 173 | [0.6, 0.75, 1, 0, 0], 174 | [0.4225, 0.5925, 2, 1, 1]], 175 | '42': [[0.4525, 0.6525, 2, 1, 1], 176 | [0.1301, 0.2501, 2, 2, 1], 177 | [0.5225, 0.7325, 2, 1, 0], 178 | [0.0025, 0.3561, 3, 1, 1], 179 | [0.5925, 0.8725, 3, 1, 0]]} 180 | eq_(expected_output, out) 181 | 182 | 183 | class testABXFeatureLoader(unittest.TestCase): 184 | 185 | def setUp(self): 186 | self.stepFeature = 10 187 | 188 | def dummy_feature_maker(path_file, *args): 189 | data = torch.tensor(np.load(path_file)) 190 | assert(len(data.size()) == 1) 191 | return data.view(-1, 1) 192 | 193 | def testBaseLoader(self): 194 | seqList = [('2107', 'test_data/2107.npy'), 195 | 
('42', 'test_data/42.npy'), 196 | ('23', 'test_data/23.npy'), 197 | ('407', 'test_data/407.npy')] 198 | 199 | dataset = abx_iterators.ABXFeatureLoader("test_data/dummy_item_file.item", 200 | seqList, 201 | testABXFeatureLoader.dummy_feature_maker, 202 | self.stepFeature, 203 | False) 204 | print(dataset.features) 205 | eq_(dataset.feature_dim, 1) 206 | eq_(len(dataset), 9) 207 | eq_(len(dataset.data.size()), 2) 208 | eq_(len(dataset.data), 16) 209 | data, size, coords = dataset[0] 210 | eq_(size, 1) 211 | eq_(coords, (0, 0, 0)) 212 | eq_(data.tolist(), [[3]]) 213 | 214 | data, size, coords = dataset[3] 215 | eq_(size, 1) 216 | eq_(coords, (1, 1, 1)) 217 | eq_(data.tolist(), [[5]]) 218 | 219 | def testWithinIterator(self): 220 | seqList = [('2107', 'test_data/2107.npy'), 221 | ('42', 'test_data/42.npy')] 222 | dataset = abx_iterators.ABXFeatureLoader("test_data/dummy_item_within.item", 223 | seqList, 224 | testABXFeatureLoader.dummy_feature_maker, 225 | self.stepFeature, 226 | False) 227 | iterator = dataset.get_iterator('within', 40) 228 | eq_(iterator.index_csp, [0, 1, 2, 6, 3, 4, 5, 8, 7]) 229 | eq_(iterator.groups_csp, [[[(0, 1)]], [[(1, 3)]], [ 230 | [(3, 4)], [(4, 6), (6, 7)]], [[(7, 8)], [(8, 9)]]]) 231 | eq_(len(iterator), 1) 232 | 233 | it = iter(iterator) 234 | c1, a_01, b_01, x_01 = next(it) 235 | eq_(c1, (1, 1, 2, 2)) 236 | a_1, s_a = a_01 237 | eq_(s_a.tolist(), [1, 1]) 238 | eq_(a_1.tolist(), [[[4.]], [[5.]]]) 239 | eq_(x_01[0].tolist(), a_1.tolist()) 240 | eq_(x_01[1].tolist(), s_a.tolist()) 241 | eq_(b_01[0].tolist(), [[[1.]]]) 242 | eq_(b_01[1].item(), 1) 243 | 244 | eq_(next(it, False), False) 245 | eq_(iterator.get_board_size(), (2, 3, 3, 4)) 246 | -------------------------------------------------------------------------------- /data_preparation/README.md: -------------------------------------------------------------------------------- 1 | # Libri-light Data Preparation and Download 2 | 3 | Below, we provide the scripts needed to reconstruct the raw dataset from scratch, including data download, conversion into flac, 4 | Voice Activity Detection (VAD) and Signal-to-Noise Ratio (SNR) computation, metadata construction, and dataset filtering and splitting. 5 | 6 | All scripts mentioned below can be found in this directory. 7 | 8 | ## Downloading the Data 9 | 10 | ### 1. Getting and Preparing the Unlabeled Audio 11 | 12 | #### 1A. Downloading 13 | 14 | The unlabelled data is split into 3 subsets of increasing sizes (small, medium, large) in order to facilitate experimentation on smaller amounts of data (further, downloading the large dataset may take several days). 15 | 16 | - [small.tar (577 hours, 35 GB)](https://dl.fbaipublicfiles.com/librilight/data/small.tar) 17 | - md5: `c49207eb86a8e8ac895561c37232041e` 18 | - [medium.tar (5193 hours, 321 GB)](https://dl.fbaipublicfiles.com/librilight/data/medium.tar) 19 | - md5: `c75e7ac62471bfbf2db77528d62a9b74` 20 | - [large.tar (51934 hours, 3.05 TB)](https://dl.fbaipublicfiles.com/librilight/data/large.tar) 21 | - md5: `4dfbac018f50b99797ece101fc9f0c30` 22 | 23 | We additionally provide a 4th subset containing potentially duplicated books.
24 | 25 | - [unlab_duplicate.tar (4500 hours, 274 GB)](https://dl.fbaipublicfiles.com/librilight/data/duplicate.tar) 26 | 27 | The directory structure of the audio archives mirrors that of [LibriSpeech](http://www.openslr.org/12): 28 | 29 | dataset_name/speakerID/book_name/ 30 | 31 | where `dataset_name` is `small`, `medium`, `large`, or `duplicate`, `speakerID` is the LibriVox speakerID (a number), and `book_name` is the name of the original LibriVox audiobook file. Inside each directory is a set of `.flac` and `.json` files. See below for the format of the `.json` files. 32 | 33 | By combining these subsets, one can construct the 3 splits described in the Libri-Light paper: 34 | 35 | - *unlab-60k* : small + medium + large 36 | - *unlab-6k* : small + medium 37 | - *unlab-600* : small 38 | 39 | 40 | Once the dataset is downloaded, untarred and organized into a directory (`UNLAB_DIR`), you can check its statistics by running: 41 | ```console 42 | python build_all_stats.py UNLAB_DIR OUTPUT_DIR 43 | ``` 44 | This will construct, in `OUTPUT_DIR`, two `.png` files (in addition to `.json` files in a `.cache` directory). 45 | 46 | #### 1B. Segmenting 47 | 48 | Original audio files are long and may not fit into memory. As a final step, we provide a script to segment the files into roughly 60-second sequences obtained by concatenating consecutive chunks containing voice activity: 49 | 50 | ```console 51 | python cut_by_vad.py --input_dir INPUT_DIR --output_dir OUTPUT_DIR 52 | ``` 53 | 54 | `OUTPUT_DIR` will have the same structure as above, but each `file_name` directory will contain a set of smaller `.flac` files. You can modify this step to fit your pipeline and model. 55 | 56 | ### 2. Get the Limited-Supervision Train Data 57 | 58 | The limited-supervision training sets are built on LibriSpeech. They consist of 10h, 1h, and 10-minute splits with orthographic transcriptions and aligned phoneme transcriptions, which can be used to train small models or fine-tune pretrained ones. These can be downloaded here: 59 | 60 | - [librispeech_finetuning.tgz (10 hours, 0.6 GB)](https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz) 61 | 62 | The directory structure is as follows: 63 | 64 | 1h/ # data of the 1h split (made up of 6 folds of 10 min) 65 | 0/ # first 10 min fold 66 | clean/ # 2 speakers, clean 67 | other/ # 2 speakers, other 68 | ... 69 | 5/ # last 10 min fold 70 | clean/ # 2 speakers, clean 71 | other/ # 2 speakers, other 72 | 9h/ # remaining data of the 10h split (10h=1h+9h) 73 | clean/ # 12 speakers, clean 74 | other/ # 12 speakers, other 75 | phones/ # phoneme alignment for all of the files 76 | 77 | 78 | The 10h split is created by combining the data from the `9h/` and `1h/` directories. The 1h split is itself made of 6 folds of 10-min splits. The `phones/` directory contains the frame-wise phoneme transcription of the various splits (the IPA phone mappings are in `phone_mapping.json`). There is also a phoneme transcription of the LibriSpeech `dev` and `test` sets. 79 | Alternatively, one can reconstruct the dataset by downloading LibriSpeech by hand and running the scripts in `rebuild_limited_train/`. 80 | 81 | 82 | ### 3. Get the Dev and Test Sets (for Evaluation)
83 | 84 | The standard LibriSpeech dev and test sets are used for evaluation, and can be found at: 85 | 86 | wget http://www.openslr.org/resources/12/dev-clean.tar.gz 87 | wget http://www.openslr.org/resources/12/dev-other.tar.gz 88 | wget http://www.openslr.org/resources/12/test-clean.tar.gz 89 | wget http://www.openslr.org/resources/12/test-other.tar.gz 90 | 91 | 92 | 93 | ## Regenerating the Dataset from Scratch (or with books in Another Language!) 94 | 95 | Below, we provide the steps needed to completely reproduce the generation of the dataset starting from raw LibriVox audio. 96 | 97 | First, download the audio data from [LibriVox](https://librivox.org/). 98 | ```console 99 | python download_librivox.py $OUTPUT_DOWNLOAD 100 | ``` 101 | 102 | Data can also be downloaded in another language. To do so, pass `--language` to the above script (for example `--language French`). The amount of available data for each language may differ. 103 | 104 | To unzip the data, run: 105 | ```console 106 | python unzip_and_convert.py unzip $OUTPUT_DOWNLOAD -o $OUTPUT_MP3 107 | ``` 108 | 109 | Then convert the files from `.mp3` to `.flac`: 110 | ```console 111 | python unzip_and_convert.py convert $OUTPUT_MP3 -o $OUTPUT_FLAC -f .flac 112 | ``` 113 | 114 | ### Running Voice Activity Detection and SNR Computation 115 | 116 | Voice Activity Detection (VAD) is accomplished using [wav2letter](https://github.com/facebookresearch/wav2letter/). Once you've [downloaded and installed wav2letter](https://github.com/facebookresearch/wav2letter/wiki/General-building-instructions) and its [dependencies](https://github.com/facebookresearch/wav2letter/wiki/Dependencies), make sure the [VAD and Audio Analysis suite](https://github.com/facebookresearch/wav2letter/tree/master/tools#voice-activity-detection-and-audio-analysis) is built. 117 | 118 | #### Prepare the Input List File 119 | The wav2letter VAD pipeline expects an [input list file](https://github.com/facebookresearch/wav2letter/wiki/Data-Preparation#audio-and-transcriptions-data) that contains lines with ordered tuples of `[sample ID] [path to sample] [duration]` for each sample. 120 | 121 | Run `make_vad_inputs.py` to prepare the data in a list file format for VAD input: 122 | ```console 123 | python make_vad_inputs.py \ 124 | --path_db [path to dataset] \ 125 | --extension [audio extension] \ 126 | --path_out [path to output file] 127 | ``` 128 | If `--path_out` is not specified, each output file will be placed in the same directory as the sample from which it originated. 129 | 130 | #### Running VAD 131 | 132 | To extract the VAD, given some input list file, follow the instructions to [**run the analysis script**](https://github.com/facebookresearch/wav2letter/blob/master/tools/README.md#voice-activity-detection-and-audio-analysis) with the list file generated in the previous step as the file passed to `--test`. 133 | 134 | If `make_vad_inputs.py` is used to generate the input list file, then the [analysis output files](https://github.com/facebookresearch/wav2letter/blob/master/tools/README.md#voice-activity-detection-and-audio-analysis) for each sample will be placed in the same directory as the sample's original audio.
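For reference, the sketch below shows the kind of processing the list-file preparation step performs: walk the dataset, read each file's duration, and write one `[sample ID] [path to sample] [duration]` line per sample. This is an illustrative re-implementation rather than the repository's `make_vad_inputs.py`: the `soundfile` dependency, the `build_vad_list` helper name, and the use of seconds for the duration field are assumptions, so defer to the actual script and the wav2letter documentation for the exact expected format.

```python
# Illustrative sketch only (NOT the repository's make_vad_inputs.py):
# builds a wav2letter-style list file with one "[sample ID] [path] [duration]"
# line per audio file. The duration unit (seconds here) is an assumption.
import argparse
from pathlib import Path

import soundfile as sf


def build_vad_list(path_db, extension, path_out):
    with open(path_out, 'w') as f:
        for audio_path in sorted(Path(path_db).rglob(f"*{extension}")):
            info = sf.info(str(audio_path))
            duration = info.frames / info.samplerate
            f.write(f"{audio_path.stem} {audio_path} {duration:.2f}\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Toy VAD input list builder")
    parser.add_argument("--path_db", required=True)
    parser.add_argument("--extension", default=".flac")
    parser.add_argument("--path_out", default="librivox_vad_input.lst")
    args = parser.parse_args()
    build_vad_list(args.path_db, args.extension, args.path_out)
```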
135 | 136 | ### Running SNR 137 | 138 | To extract the SNR, run: 139 | ```console 140 | python calculate_snr.py librivox.lst > calculated_snr.tsv 141 | ``` 142 | 143 | or, if the job failed and you need to resume, you can run: 144 | ```console 145 | python calculate_snr.py librivox.lst calculated_snr.tsv > calculated_snr_leftover.tsv 146 | ``` 147 | This program looks at the VAD output, classifies speech and non-speech frames based on a dataset-specific threshold, removes unclassified frames, and calculates the SNR based on (speech power / non-speech power). 148 | 149 | Prerequisites: 150 | - `librivox.lst` is a plain list of all wav file paths downloaded from librivox.org, which looks like this: 151 | ```console 152 | /some/path/to/audio1.wav 153 | /some/path/to/audio2.wav 154 | ``` 155 | - You must have finished the VAD step first, so that the corresponding `.vad` file has been generated in the same folder as each `.wav` file. 156 | - If you have retrained the VAD model or are running on a different dataset, looking at the histogram over some audio files to decide on the threshold is essential for good performance. 157 | - Input wav files are required to be sampled at 16 kHz. 158 | 159 | ### Preparing Metadata Files 160 | 161 | To create metadata files: 162 | ```console 163 | python complete_metadata.py --path_metadata $OUTPUT_DOWNLOAD --out_dir $OUTPUT_FLAC --ignore_cache 164 | ``` 165 | 166 | This command will also build the list of all duplicate books and save it at $OUTPUT_FLAC/global/duplicates.json. 167 | 168 | ![pipeline](data_preparation_pipeline.svg) 169 | Figure 1. Complete data preparation pipeline. 170 | 171 | ## Metadata JSON File Format 172 | 173 | For each LibriVox audio file, we create one JSON metadata file. This differs from the LibriVox distribution, which contains one JSON per book, where a single book may have multiple associated audio files. 174 | 175 | Below is a labeled example of output metadata produced by the pipeline: 176 | 177 | { 178 | "speaker" : "960" # LibriVox speaker ID 179 | "book_meta": { # LibriVox metadata concerning the book this file belongs to 180 | "id": "319" # LibriVox book ID 181 | "title": "History of Holland" # LibriVox book title 182 | ... 183 | "genre": [ # LibriVox genre 184 | "*Non-fiction", 185 | "History" 186 | ], # from this point on, this is our own Libri-Light metadata: 187 | "Dramatic Readings": false, # boolean for dramatic vs normal reading 188 | "meta_genre" : "Literature" # meta-genre, among 7 possibilities 189 | }, # ["Literature", "Science, Craft & Essay", "Undefined", "Religion", "Poetry", "Theater", "Ancient"] 190 | "snr": 5.391, # Signal-to-Noise Ratio computed on the basis of Voice Activity Detection 191 | "voice_activity": [ # onsets and offsets (in seconds) of each VAD segment 192 | [ 193 | 0.4, 194 | 12.32 195 | ], 196 | ... 197 | --------------------------------------------------------------------------------
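As a concrete reading of the SNR step described in the README above (speech power over non-speech power on VAD-classified frames), a minimal sketch might look like the following. It is not `calculate_snr.py`: the `estimate_snr_db` name, the per-frame VAD score format, and the 0.8/0.2 thresholds are illustrative assumptions, and the actual dataset-specific threshold should be chosen from the score histogram as noted above.

```python
# Minimal sketch of the SNR idea from the README above, NOT calculate_snr.py.
# Assumes `vad_scores` holds one speech score per frame, aligned with `frames`;
# the 0.8 / 0.2 thresholds are illustrative, not the dataset-specific values.
import numpy as np


def estimate_snr_db(frames, vad_scores, speech_thr=0.8, noise_thr=0.2):
    """frames: (n_frames, frame_len) float array; vad_scores: (n_frames,)."""
    frame_power = np.mean(frames ** 2, axis=1)
    speech = frame_power[vad_scores >= speech_thr]
    noise = frame_power[vad_scores <= noise_thr]
    # Frames falling between the two thresholds are unclassified and dropped.
    if len(speech) == 0 or len(noise) == 0:
        return float("nan")
    return 10.0 * np.log10(speech.mean() / (noise.mean() + 1e-12))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    speech_like = rng.normal(0.0, 1.0, size=(50, 160))
    noise_like = rng.normal(0.0, 0.1, size=(50, 160))
    frames = np.concatenate([speech_like, noise_like])
    scores = np.concatenate([np.ones(50), np.zeros(50)])
    print(f"SNR: {estimate_snr_db(frames, scores):.1f} dB")  # roughly 20 dB
```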