├── DeepSpeech
│   └── README.md
├── README.md
├── logo.jpeg
├── speechbrain
│   ├── ASR
│   │   └── CTC
│   │       ├── README.md
│   │       ├── dvoice_prepare.py
│   │       ├── extra_requirements.txt
│   │       ├── hparams
│   │       │   ├── train_dar_with_wav2vec.yaml
│   │       │   └── train_sw_with_wav2vec.yaml
│   │       └── train_with_wav2vec.py
│   └── dvoice_prepare.py
└── wav2vec 2.0
    ├── README.md
    ├── __init__.py
    ├── config
    │   ├── finetuning
    │   │   ├── base_100h.yaml
    │   │   ├── base_10h.yaml
    │   │   ├── base_10m.yaml
    │   │   ├── base_1h.yaml
    │   │   ├── base_960h.yaml
    │   │   ├── vox_100h.yaml
    │   │   ├── vox_10h.yaml
    │   │   ├── vox_10m.yaml
    │   │   ├── vox_1h.yaml
    │   │   └── vox_960h.yaml
    │   └── pretraining
    │       ├── wav2vec2_base_librispeech.yaml
    │       ├── wav2vec2_large_librivox.yaml
    │       ├── wav2vec2_large_librivox_tpu-pod.yaml
    │       └── wav2vec2_large_librivox_tpu.yaml
    ├── libri_labels.py
    ├── scripts
    │   └── binarize_manifest.sh
    ├── unsupervised
    │   ├── README.md
    │   ├── __init__.py
    │   ├── config
    │   │   ├── finetuning
    │   │   │   └── w2v_finetune.yaml
    │   │   ├── gan
    │   │   │   └── w2vu.yaml
    │   │   ├── generate
    │   │   │   └── viterbi.yaml
    │   │   ├── timit_matched
    │   │   │   ├── test.uid
    │   │   │   ├── train.uid
    │   │   │   ├── train_text.uid
    │   │   │   └── valid.uid
    │   │   └── timit_unmatched
    │   │       ├── test.uid
    │   │       ├── train.uid
    │   │       ├── train_text.uid
    │   │       └── valid.uid
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── extracted_features_dataset.py
    │   │   └── random_input_dataset.py
    │   ├── kaldi_self_train
    │   │   ├── README.md
    │   │   └── st
    │   │       ├── cmd.sh
    │   │       ├── decode_phone.sh
    │   │       ├── decode_word_step1.sh
    │   │       ├── decode_word_step2.sh
    │   │       ├── local
    │   │       │   ├── copy_aligned_text.py
    │   │       │   ├── decode.sh
    │   │       │   ├── prepare_data_from_w2v.py
    │   │       │   ├── prepare_lang.sh
    │   │       │   ├── prepare_lang_word.sh
    │   │       │   ├── prepare_lm.sh
    │   │       │   ├── score.sh
    │   │       │   ├── show_wer.sh
    │   │       │   ├── train_subset_lgbeam.sh
    │   │       │   ├── unsup_select.py
    │   │       │   ├── unsup_select_decode.sh
    │   │       │   └── unsup_select_decode_word.sh
    │   │       ├── path.sh
    │   │       ├── steps
    │   │       ├── steps_gan
    │   │       │   ├── train_deltas.sh
    │   │       │   ├── train_lda_mllt.sh
    │   │       │   └── train_sat.sh
    │   │       ├── train.sh
    │   │       └── utils
    │   ├── models
    │   │   ├── __init__.py
    │   │   └── wav2vec_u.py
    │   ├── scripts
    │   │   ├── apply_pca.py
    │   │   ├── copy_labels.py
    │   │   ├── filter_lexicon.py
    │   │   ├── filter_tsv.py
    │   │   ├── g2p_wrd_to_phn.py
    │   │   ├── ltr_to_wrd.py
    │   │   ├── mean_pool.py
    │   │   ├── merge_clusters.py
    │   │   ├── normalize_and_filter_text.py
    │   │   ├── normalize_text.py
    │   │   ├── pca.py
    │   │   ├── phonemize_with_sil.py
    │   │   ├── prepare_audio.sh
    │   │   ├── prepare_text.sh
    │   │   ├── prepare_timit.sh
    │   │   ├── remove_silence.py
    │   │   ├── vads.py
    │   │   ├── wav2vec_apply_cluster_faiss.py
    │   │   ├── wav2vec_cluster_faiss.py
    │   │   ├── wav2vec_extract_features.py
    │   │   ├── wer.py
    │   │   └── wrd_to_ltr.py
    │   ├── tasks
    │   │   ├── __init__.py
    │   │   └── unpaired_audio_text.py
    │   └── w2vu_generate.py
    ├── vq-wav2vec_featurize.py
    ├── wav2vec_featurize.py
    ├── wav2vec_manifest.py
    └── xlsr_wav2vec2_darija_finetuning.ipynb
/DeepSpeech/README.md:
--------------------------------------------------------------------------------
1 | # 1. Model description
2 | [DeepSpeech](https://arxiv.org/abs/1412.5567) is an open-source speech recognition model developed by Mozilla. DeepSpeech is very flexible and adapts well to the training data, thanks to its ability to identify background noise and to the language model included at training time.
3 |
4 | # 2. Training your own model
5 | Training the model requires the following prerequisites:
6 | - Python 3.6.
7 | - A Mac or Linux environment.
8 | - CUDA 10.0 / CuDNN v7.6 (per the Dockerfile).
9 |
10 | Then follow these steps:
11 |
12 | ## 2.1. Setting up the environment
13 | ```shell script
14 | # Clone the repository
15 | git clone --branch v0.9.3 https://github.com/mozilla/DeepSpeech
16 |
17 | # Create a virtual environment
18 | $ python3 -m venv $HOME/tmp/deepspeech-train-venv/
19 | $ source $HOME/tmp/deepspeech-train-venv/bin/activate
20 |
21 | # Install the training code and all of its dependencies
22 | cd DeepSpeech
23 | pip3 install --upgrade pip==20.2.2 wheel==0.34.2 setuptools==49.6.0
24 | pip3 install --upgrade -e .
25 | sudo apt-get install python3-dev
26 |
27 | # Installing the GPU build of TensorFlow is recommended when an NVIDIA GPU with at least 8 GB of memory is available
28 | pip3 uninstall tensorflow
29 | pip3 install 'tensorflow-gpu==1.15.4'
30 |
31 | # For those who want to train inside a Docker environment
32 | make Dockerfile.train
33 | make Dockerfile.train DEEPSPEECH_REPO=git://your/fork DEEPSPEECH_SHA=origin/your-branch
34 |
35 | ```
36 |
37 | ## 2.2. Loading the dataset
38 | We assume here that the Mozilla Common Voice dataset is used.
39 |
40 | ```shell script
41 | bin/import_cv2.py --filter_alphabet path/to/some/alphabet.txt /path/to/extracted/language/archive
42 | python3 DeepSpeech.py --train_files ../data/CV/ar/clips/train.csv --dev_files ../data/CV/ar/clips/dev.csv --test_files ../data/CV/ar/clips/test.csv
43 | ```
44 |
45 | ## 2.3. Training the model
46 | Training the model is straightforward:
47 | ```shell script
48 | python3 DeepSpeech.py --helpfull
49 | ./bin/run-ldc93s1.sh
50 | ```
51 | For more control over training, we recommend following the steps described [here](https://deepspeech.readthedocs.io/en/r0.9/TRAINING.html).
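For reference, a fuller training run on the Common Voice CSVs imported above might look like the sketch below; the flags are standard DeepSpeech v0.9 training flags, while the paths and hyperparameter values are placeholders to adapt to your setup:

```shell script
python3 DeepSpeech.py \
  --train_files ../data/CV/ar/clips/train.csv \
  --dev_files ../data/CV/ar/clips/dev.csv \
  --test_files ../data/CV/ar/clips/test.csv \
  --alphabet_config_path path/to/some/alphabet.txt \
  --epochs 30 \
  --train_batch_size 32 \
  --learning_rate 0.0001 \
  --checkpoint_dir ./checkpoints \
  --export_dir ./exported_model
```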
52 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # 1. About
6 | Based in Rabat, London and Paris, AIOX-Labs mobilizes artificial intelligence technologies to meet the business needs and data projects of companies.
7 | - It serves the growth of groups, the optimization of processes and the improvement of the customer experience.
8 | - AIOX-Labs is multi-sector, from fintech to industry, including retail and consumer goods.
9 | - Business-ready data products built on a solid algorithmic base and adaptable to the specific needs of each client.
10 | - A complementary team of PhDs in AI and business experts, with a solid scientific grounding and international publications.
11 |
12 | Website: https://www.aiox-labs.com/
13 |
14 | Dialectal Voice is a community project initiated by AIOX-Labs to make voice recognition by intelligent systems easier. Today, the need for AI systems capable of recognizing the human voice is growing within communities. However, for some languages such as Darija, there are not enough voice-technology solutions. To meet this need, we set up this program to build, iteratively and interactively, a dialectal speech database open to everyone, in order to help improve voice recognition and voice generation models.
15 |
16 | Websites: https://dvoice.ma/, https://dvoice.sn/
17 |
18 | # 2. Contributing to the DVoice project
19 | There are two ways to contribute to the project:
20 | ## 2.1. Submit recordings "هضر"
21 | To help improve voice technologies for Moroccan dialectal Arabic, a large amount of data is needed to train the models. The data take the form of "voice + text transcription" pairs. You contribute to enriching the dataset by reading a Darija text displayed on the screen and then submitting the recording.
22 |
23 | N.B.: While reading the texts, please make sure that your microphone works properly and that there is not too much background noise.
24 |
25 | ## 2.2. Evaluate recordings "سمع"
26 | Another way to contribute to this initiative is to evaluate recordings. This consists of listening to a "voice + text transcription" sample and judging how closely the recorded voice matches the corresponding text. If the two are close enough, click "Yes"; otherwise click "No".
27 |
28 | # 3. Getting the dataset
29 | ## 3.1. DVoice-v1.0
30 | The database we have built is open to the public. It can be obtained in two different ways:
31 | - By subscribing on this site; it will then be sent to you by email shortly.
32 | - From the Zenodo repository via this link: [Download here](https://zenodo.org/record/5482551).
33 |
34 | ## 3.2. DVoice-v2.0
35 | This second version of the dataset, available [here](https://zenodo.org/record/6342622), includes the augmented data, which are easy to identify, as well as a Swahili dataset obtained by transfer learning from the multilingual [VoxLingua107](http://bark.phon.ioc.ee/voxlingua107/) dataset.
36 |
37 | N.B.: To be notified when a new update of the database becomes available, you can subscribe on the Subscription page.
38 |
39 | ## 3.3. Alternative dataset
40 | Training a speech recognition model usually requires, on the one hand, a set of voice recordings and, on the other hand, their text transcriptions. The DVoice initiative aims precisely at providing the community with these prerequisites in order to ease the adoption of voice technologies and the research around them.
41 |
42 | An alternative method consists of applying transfer learning to existing recordings by transcribing them into text. The procedure is as follows (see the sketch after this list):
43 | - Collect public video content from Youtube, Facebook, ...
44 | - Convert the videos to audio.
45 | - Segment the resulting audio files (segmentation on silence).
46 | - Pre-process these data (reduce noise, add silence if necessary, ...).
47 | - Label the audio files with the [SpeechRecognition](https://pypi.org/project/SpeechRecognition/) library.
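Below is a minimal sketch of this labelling pipeline. It assumes the `pydub` and `SpeechRecognition` packages are installed (plus `ffmpeg` for decoding); the file names, silence thresholds and recognition language are illustrative placeholders rather than values prescribed by the project.

```python
# Minimal sketch: split a long recording on silence and transcribe each chunk.
# Assumes: pip install pydub SpeechRecognition (and ffmpeg for decoding).
import os
import speech_recognition as sr
from pydub import AudioSegment
from pydub.silence import split_on_silence

audio = AudioSegment.from_file("video_audio.wav")  # audio extracted from a video (placeholder)
chunks = split_on_silence(
    audio,
    min_silence_len=700,                 # silence longer than 700 ms marks a boundary (illustrative)
    silence_thresh=audio.dBFS - 16,      # threshold relative to the clip's average loudness
    keep_silence=300,                    # keep a short margin of silence around each chunk
)

recognizer = sr.Recognizer()
os.makedirs("clips", exist_ok=True)
for i, chunk in enumerate(chunks):
    path = os.path.join("clips", f"chunk_{i:04d}.wav")
    chunk.export(path, format="wav")
    with sr.AudioFile(path) as source:
        recorded = recognizer.record(source)
    try:
        # Any recognizer supported by SpeechRecognition can be used here;
        # the language code is a placeholder to adapt to the target dialect.
        text = recognizer.recognize_google(recorded, language="ar-MA")
    except (sr.UnknownValueError, sr.RequestError):
        text = ""
    print(path, text)
```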
48 |
49 | # 4. Our vision for DVoice Africa
50 | The biggest novelty of this first update of the DVoice project is certainly the inclusion of other African languages, an important step toward the ultimate goal of this project. Today, on the data collection side, the project covers Darija and, through DVoice Senegal, supports six languages and dialects (Wolof, Serer, Diola, Mandinka, Pular, Soninke) spoken in Senegal and in several West African countries. On the modeling side, it supports Darija and Swahili.
51 |
52 | The first versions of the models we tested achieve very good scores. We strongly encourage everyone to take part in this community program by visiting our site [dvoice.ma](https://dvoice.ma/) as well as the brand-new [dvoice.sn](https://dvoice.sn/) to submit or validate a few recordings. This will go a long way toward building a large African speech database.
53 |
54 | On our side, we keep trying to go further, focusing on additional parameters/approaches to incorporate and on the possibility of integrating a language model that would further refine the transcriptions. The longer-term goal is to develop voice technologies for African languages and make them easier to adopt.
55 |
--------------------------------------------------------------------------------
/logo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIOXLABS/DVoice/19b760ca93ed9016d406bd9baf077e49bab384ec/logo.jpeg
--------------------------------------------------------------------------------
/speechbrain/ASR/CTC/README.md:
--------------------------------------------------------------------------------
1 | # DVoice ASR with CTC-based Seq2Seq models.
2 | This folder contains the scripts necessary to run an ASR experiment with the DVoice dataset: [Link](https://zenodo.org/record/6342622)
3 |
4 | # Data preparation
5 | [DVoice](https://dvoice.ma) attempts to provide automatic voice processing solutions for African languages and dialects. We use preprocessing techniques including voice augmentation to fill the data gap for each language.
6 |
7 | # How to run
8 | - First, get Speechbrain
9 |
10 | ``` bash
11 | git clone https://github.com/speechbrain/speechbrain
12 | ```
13 | - Place the DVoice/speechbrain folder inside the speechbrain/recipes folder
14 |
15 | - Go to speechbrain/recipes/DVoice/ASR/CTC, then run:
16 | ``` bash
17 | python train_with_wav2vec.py hparams/{hparam_file}.yaml
18 | ```
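If the `data_folder` entry of the chosen YAML file is left as a placeholder, it can also be passed on the command line, since SpeechBrain hyperparameter files accept command-line overrides; the path below is only an example:

``` bash
python train_with_wav2vec.py hparams/train_dar_with_wav2vec.yaml --data_folder=/path/to/DVoice
```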
19 |
20 |
21 | # Languages
22 | Here is the list of languages and dialects that we tested with the DVoice dataset and CTC:
23 | - Darija
24 | - Swahili
25 |
26 | # Results
27 |
28 | | Language | DVoice Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | HuggingFace link |
29 | | ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:|
30 | | Darija (Moroccan Arabic) | v2.0 | train_dar_with_wav2vec.yaml | No | 5.51 | 18.46 | 5.85 | 18.28 | [Link](https://huggingface.co/nairaxo/dvoice-darija) |
31 | | Swahili | v2.0 | train_sw_with_wav2vec.yaml | No | 8.83 | 22.78 | 9.46 | 23.16 | [Link](https://huggingface.co/nairaxo/dvoice-swahili) |
32 |
33 |
34 |
35 |
36 | ## How to simply use pretrained models to transcribe my audio file?
37 |
38 | SpeechBrain provides a simple interface to transcribe audio files with pretrained models. All the necessary information can be found on the different HuggingFace repositories (see the results table above) corresponding to our different models for DVoice.
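As an illustration, transcribing a file with the Darija model from the table above might look like the following sketch; it assumes SpeechBrain is installed and uses the pretrained `EncoderASR` interface, with the HuggingFace source name and audio path to be checked against the model card:

```python
from speechbrain.pretrained import EncoderASR

# The source below is the HuggingFace repository linked in the results table;
# see the model card for the exact identifier and recommended interface.
asr_model = EncoderASR.from_hparams(
    source="nairaxo/dvoice-darija",
    savedir="pretrained_models/dvoice-darija",
)
print(asr_model.transcribe_file("path/to/audio.wav"))
```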
39 |
40 | # **About SpeechBrain**
41 | - Website: https://speechbrain.github.io/
42 | - Code: https://github.com/speechbrain/speechbrain/
43 | - HuggingFace: https://huggingface.co/speechbrain/
44 |
45 |
46 | # **Citing SpeechBrain**
47 | Please cite SpeechBrain if you use it for your research or business.
48 |
49 | ```bibtex
50 | @misc{speechbrain,
51 | title={{SpeechBrain}: A General-Purpose Speech Toolkit},
52 | author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
53 | year={2021},
54 | eprint={2106.04624},
55 | archivePrefix={arXiv},
56 | primaryClass={eess.AS},
57 | note={arXiv:2106.04624}
58 | }
59 | ```
60 |
--------------------------------------------------------------------------------
/speechbrain/ASR/CTC/dvoice_prepare.py:
--------------------------------------------------------------------------------
1 | ../../dvoice_prepare.py
2 |
--------------------------------------------------------------------------------
/speechbrain/ASR/CTC/extra_requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.13
2 |
--------------------------------------------------------------------------------
/speechbrain/ASR/CTC/hparams/train_dar_with_wav2vec.yaml:
--------------------------------------------------------------------------------
1 | # ################################
2 | # Model: wav2vec2 + DNN + CTC
3 | # Augmentation: SpecAugment
4 | # Authors: Titouan Parcollet 2021
5 | # ################################
6 |
7 | # Seed needs to be set at top of yaml, before objects with parameters are made
8 | seed: 1234
9 | __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10 | output_folder: !ref results/wav2vec2_ctc_DAR/<seed>
11 | wer_file: !ref <output_folder>/wer.txt
12 | save_folder: !ref <output_folder>/save
13 | train_log: !ref <output_folder>/train_log.txt
14 |
15 | # URL for the multilingual XLSR-53 wav2vec 2.0 model.
16 | wav2vec2_hub: facebook/wav2vec2-large-xlsr-53
17 |
18 | # Data files
19 | data_folder: #!PLACEHOLDER # e.g, /dataset/
20 | train_csv_file: !ref <data_folder>/texts/train2.csv # Standard CommonVoice .tsv files
21 | dev_csv_file: !ref <data_folder>/texts/dev2.csv # Standard CommonVoice .tsv files
22 | test_csv_file: !ref <data_folder>/texts/test2.csv # Standard CommonVoice .tsv files
23 | accented_letters: True
24 | language: dar # use 'dar' for Darija, 'sw' for Swahili
25 | train_csv: !ref <save_folder>/train.csv
26 | valid_csv: !ref <save_folder>/dev.csv
27 | test_csv: !ref <save_folder>/test.csv
28 | skip_prep: False # Skip data preparation
29 |
30 | # We remove utterances longer than 15s in the train/dev/test sets as
31 | # longer sentences certainly correspond to "open microphones".
32 | avoid_if_longer_than: 15.0
33 |
34 | # Training parameters
35 | number_of_epochs: 30
36 | number_of_ctc_epochs: 15
37 | lr: 1.0
38 | lr_wav2vec: 0.0001
39 | ctc_weight: 0.3
40 | sorting: ascending
41 | auto_mix_prec: False
42 | sample_rate: 16000
43 | ckpt_interval_minutes: 30 # save checkpoint every N min
44 |
45 | # With data_parallel batch_size is split into N jobs
46 | # With DDP batch_size is multiplied by N jobs
47 | # Must be 6 per GPU to fit 16GB of VRAM
48 | batch_size: 4
49 | test_batch_size: 4
50 |
51 | dataloader_options:
52 | batch_size: !ref <batch_size>
53 | num_workers: 2
54 | test_dataloader_options:
55 | batch_size: !ref <test_batch_size>
56 | num_workers: 2
57 |
58 | # BPE parameters
59 | token_type: char # ["unigram", "bpe", "char"]
60 | character_coverage: 1.0
61 |
62 | # Model parameters
63 | activation: !name:torch.nn.LeakyReLU
64 | wav2vec_output_dim: 1024
65 | dnn_neurons: 1024
66 | freeze_wav2vec: False
67 |
68 | # Outputs
69 | output_neurons: 36 # BPE size, index(blank/eos/bos) = 0
70 |
71 | # Decoding parameters
72 | # Be sure that the bos and eos index match with the BPEs ones
73 | blank_index: 0
74 | bos_index: 1
75 | eos_index: 2
76 | min_decode_ratio: 0.0
77 | max_decode_ratio: 1.0
78 | beam_size: 80
79 | eos_threshold: 1.5
80 | using_max_attn_shift: True
81 | max_attn_shift: 140
82 | ctc_weight_decode: 0.0
83 | temperature: 1.50
84 |
85 | #
86 | # Functions and classes
87 | #
88 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
89 | limit: !ref <number_of_epochs>
90 |
91 | augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
92 | sample_rate: !ref <sample_rate>
93 | speeds: [95, 100, 105]
94 |
95 | enc: !new:speechbrain.nnet.containers.Sequential
96 | input_shape: [null, null, !ref <wav2vec_output_dim>]
97 | linear1: !name:speechbrain.nnet.linear.Linear
98 | n_neurons: 1024
99 | bias: True
100 | bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
101 | activation: !new:torch.nn.LeakyReLU
102 | drop: !new:torch.nn.Dropout
103 | p: 0.15
104 | linear2: !name:speechbrain.nnet.linear.Linear
105 | n_neurons: 1024
106 | bias: True
107 | bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
108 | activation2: !new:torch.nn.LeakyReLU
109 | drop2: !new:torch.nn.Dropout
110 | p: 0.15
111 | linear3: !name:speechbrain.nnet.linear.Linear
112 | n_neurons: 1024
113 | bias: True
114 | bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
115 | activation3: !new:torch.nn.LeakyReLU
116 |
117 | wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
118 | source: !ref <wav2vec2_hub>
119 | output_norm: True
120 | freeze: !ref <freeze_wav2vec>
121 | save_path: !ref <save_folder>/wav2vec2_checkpoint
122 |
123 | #####
124 | # Uncomment this block if you prefer to use a Fairseq pretrained model instead
125 | # of a HuggingFace one. Here, we provide an URL that is obtained from the
126 | # Fairseq github for the multilingual XLSR.
127 | #
128 | #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
129 | #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
130 | # pretrained_path: !ref <wav2vec2_url>
131 | # output_norm: True
132 | # freeze: False
133 | # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
134 | #####
135 |
136 |
137 | ctc_lin: !new:speechbrain.nnet.linear.Linear
138 | input_size: !ref <dnn_neurons>
139 | n_neurons: !ref <output_neurons>
140 |
141 | log_softmax: !new:speechbrain.nnet.activations.Softmax
142 | apply_log: True
143 |
144 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
145 | blank_index: !ref <blank_index>
146 |
147 | modules:
148 | wav2vec2: !ref <wav2vec2>
149 | enc: !ref <enc>
150 | ctc_lin: !ref <ctc_lin>
151 |
152 | model: !new:torch.nn.ModuleList
153 | - [!ref <enc>, !ref <ctc_lin>]
154 |
155 | model_opt_class: !name:torch.optim.Adadelta
156 | lr: !ref <lr>
157 | rho: 0.95
158 | eps: 1.e-8
159 |
160 | wav2vec_opt_class: !name:torch.optim.Adam
161 | lr: !ref <lr_wav2vec>
162 |
163 | lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
164 | initial_value: !ref <lr>
165 | improvement_threshold: 0.0025
166 | annealing_factor: 0.8
167 | patient: 0
168 |
169 | lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
170 | initial_value: !ref <lr_wav2vec>
171 | improvement_threshold: 0.0025
172 | annealing_factor: 0.9
173 | patient: 0
174 |
175 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
176 | checkpoints_dir: !ref <save_folder>
177 | recoverables:
178 | wav2vec2: !ref <wav2vec2>
179 | model: !ref <model>
180 | scheduler_model: !ref <lr_annealing_model>
181 | scheduler_wav2vec: !ref <lr_annealing_wav2vec>
182 | counter: !ref <epoch_counter>
183 |
184 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
185 | save_file: !ref <train_log>
186 |
187 | error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
188 |
189 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
190 | split_tokens: True
191 |
--------------------------------------------------------------------------------
/speechbrain/ASR/CTC/hparams/train_sw_with_wav2vec.yaml:
--------------------------------------------------------------------------------
1 | # ################################
2 | # Model: wav2vec2 + DNN + CTC
3 | # Augmentation: SpecAugment
4 | # Authors: Titouan Parcollet 2021
5 | # ################################
6 |
7 | # Seed needs to be set at top of yaml, before objects with parameters are made
8 | seed: 1234
9 | __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10 | output_folder: !ref results/wav2vec2_ctc_SW/<seed>
11 | wer_file: !ref <output_folder>/wer.txt
12 | save_folder: !ref <output_folder>/save
13 | train_log: !ref <output_folder>/train_log.txt
14 |
15 | # URL for the multilingual XLSR-53 wav2vec 2.0 model.
16 | wav2vec2_hub: facebook/wav2vec2-large-xlsr-53
17 |
18 | # Data files
19 | data_folder: #!PLACEHOLDER # e.g, /dataset/
20 | train_csv_file: !ref <data_folder>/texts/train2.csv # Standard CommonVoice .tsv files
21 | dev_csv_file: !ref <data_folder>/texts/dev2.csv # Standard CommonVoice .tsv files
22 | test_csv_file: !ref <data_folder>/texts/test2.csv # Standard CommonVoice .tsv files
23 | accented_letters: True
24 | language: sw # use 'dar' for Darija, 'sw' for Swahili
25 | train_csv: !ref <save_folder>/train.csv
26 | valid_csv: !ref <save_folder>/dev.csv
27 | test_csv: !ref <save_folder>/test.csv
28 | skip_prep: False # Skip data preparation
29 |
30 | # We remove utterances longer than 15s in the train/dev/test sets as
31 | # longer sentences certainly correspond to "open microphones".
32 | avoid_if_longer_than: 15.0
33 |
34 | # Training parameters
35 | number_of_epochs: 30
36 | number_of_ctc_epochs: 15
37 | lr: 1.0
38 | lr_wav2vec: 0.0001
39 | ctc_weight: 0.3
40 | sorting: ascending
41 | auto_mix_prec: False
42 | sample_rate: 16000
43 | ckpt_interval_minutes: 30 # save checkpoint every N min
44 |
45 | # With data_parallel batch_size is split into N jobs
46 | # With DDP batch_size is multiplied by N jobs
47 | # Must be 6 per GPU to fit 16GB of VRAM
48 | batch_size: 4
49 | test_batch_size: 4
50 |
51 | dataloader_options:
52 | batch_size: !ref <batch_size>
53 | num_workers: 2
54 | test_dataloader_options:
55 | batch_size: !ref <test_batch_size>
56 | num_workers: 2
57 |
58 | # BPE parameters
59 | token_type: char # ["unigram", "bpe", "char"]
60 | character_coverage: 1.0
61 |
62 | # Model parameters
63 | activation: !name:torch.nn.LeakyReLU
64 | wav2vec_output_dim: 1024
65 | dnn_neurons: 1024
66 | freeze_wav2vec: False
67 |
68 | # Outputs
69 | output_neurons: 38 # BPE size, index(blank/eos/bos) = 0
70 |
71 | # Decoding parameters
72 | # Be sure that the bos and eos index match with the BPEs ones
73 | blank_index: 0
74 | bos_index: 1
75 | eos_index: 2
76 | min_decode_ratio: 0.0
77 | max_decode_ratio: 1.0
78 | beam_size: 80
79 | eos_threshold: 1.5
80 | using_max_attn_shift: True
81 | max_attn_shift: 140
82 | ctc_weight_decode: 0.0
83 | temperature: 1.50
84 |
85 | #
86 | # Functions and classes
87 | #
88 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
89 | limit: !ref <number_of_epochs>
90 |
91 | augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
92 | sample_rate: !ref <sample_rate>
93 | speeds: [95, 100, 105]
94 |
95 | enc: !new:speechbrain.nnet.containers.Sequential
96 | input_shape: [null, null, !ref <wav2vec_output_dim>]
97 | linear1: !name:speechbrain.nnet.linear.Linear
98 | n_neurons: 1024
99 | bias: True
100 | bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
101 | activation: !new:torch.nn.LeakyReLU
102 | drop: !new:torch.nn.Dropout
103 | p: 0.15
104 | linear2: !name:speechbrain.nnet.linear.Linear
105 | n_neurons: 1024
106 | bias: True
107 | bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
108 | activation2: !new:torch.nn.LeakyReLU
109 | drop2: !new:torch.nn.Dropout
110 | p: 0.15
111 | linear3: !name:speechbrain.nnet.linear.Linear
112 | n_neurons: 1024
113 | bias: True
114 | bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
115 | activation3: !new:torch.nn.LeakyReLU
116 |
117 | wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
118 | source: !ref <wav2vec2_hub>
119 | output_norm: True
120 | freeze: !ref <freeze_wav2vec>
121 | save_path: !ref <save_folder>/wav2vec2_checkpoint
122 |
123 | #####
124 | # Uncomment this block if you prefer to use a Fairseq pretrained model instead
125 | # of a HuggingFace one. Here, we provide an URL that is obtained from the
126 | # Fairseq github for the multilingual XLSR.
127 | #
128 | #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
129 | #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
130 | # pretrained_path: !ref <wav2vec2_url>
131 | # output_norm: True
132 | # freeze: False
133 | # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
134 | #####
135 |
136 |
137 | ctc_lin: !new:speechbrain.nnet.linear.Linear
138 | input_size: !ref <dnn_neurons>
139 | n_neurons: !ref <output_neurons>
140 |
141 | log_softmax: !new:speechbrain.nnet.activations.Softmax
142 | apply_log: True
143 |
144 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
145 | blank_index: !ref <blank_index>
146 |
147 | modules:
148 | wav2vec2: !ref <wav2vec2>
149 | enc: !ref <enc>
150 | ctc_lin: !ref <ctc_lin>
151 |
152 | model: !new:torch.nn.ModuleList
153 | - [!ref <enc>, !ref <ctc_lin>]
154 |
155 | model_opt_class: !name:torch.optim.Adadelta
156 | lr: !ref <lr>
157 | rho: 0.95
158 | eps: 1.e-8
159 |
160 | wav2vec_opt_class: !name:torch.optim.Adam
161 | lr: !ref <lr_wav2vec>
162 |
163 | lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
164 | initial_value: !ref <lr>
165 | improvement_threshold: 0.0025
166 | annealing_factor: 0.8
167 | patient: 0
168 |
169 | lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
170 | initial_value: !ref <lr_wav2vec>
171 | improvement_threshold: 0.0025
172 | annealing_factor: 0.9
173 | patient: 0
174 |
175 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
176 | checkpoints_dir: !ref <save_folder>
177 | recoverables:
178 | wav2vec2: !ref <wav2vec2>
179 | model: !ref <model>
180 | scheduler_model: !ref <lr_annealing_model>
181 | scheduler_wav2vec: !ref <lr_annealing_wav2vec>
182 | counter: !ref <epoch_counter>
183 |
184 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
185 | save_file: !ref <train_log>
186 |
187 | error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
188 |
189 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
190 | split_tokens: True
191 |
--------------------------------------------------------------------------------
/speechbrain/dvoice_prepare.py:
--------------------------------------------------------------------------------
1 | """
2 | Data preparation.
3 | Download: https://dvoice.ma/
4 | Author
5 | ------
6 | Abdou Mohamed Naira
7 | """
8 |
9 | import os
10 | import csv
11 | import re
12 | import logging
13 | import torchaudio
14 | import unicodedata
15 | from tqdm.contrib import tzip
16 |
17 | import torch
18 | import random
19 | import pandas as pd
20 | from tqdm import tqdm
21 | import numpy as np
22 | from speechbrain.dataio.dataio import read_audio
23 | from speechbrain.processing.speech_augmentation import SpeedPerturb
24 | from speechbrain.processing.speech_augmentation import DropChunk
25 | from speechbrain.processing.speech_augmentation import DropFreq
26 | from speechbrain.processing.speech_augmentation import DoClip
27 | from speechbrain.lobes.augment import TimeDomainSpecAugment
28 |
29 |
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 | def prepare_dvoice(
34 | data_folder,
35 | save_folder,
36 | train_csv_file=None,
37 | dev_csv_file=None,
38 | test_csv_file=None,
39 | accented_letters=False,
40 | language="dar",
41 | skip_prep=False,
42 |
43 | ):
44 |
45 | if skip_prep:
46 | return
47 |
48 | # If not specified, point toward the standard location w.r.t. the DVoice tree
49 | if train_csv_file is None:
50 | train_csv_file = data_folder + "texts/train.csv"
51 | else:
52 | train_csv_file = train_csv_file
53 |
54 | if dev_csv_file is None:
55 | dev_csv_file = data_folder + "texts/dev.csv"
56 | else:
57 | dev_csv_file = dev_csv_file
58 |
59 | if test_csv_file is None:
60 | test_csv_file = data_folder + "texts/test.csv"
61 | else:
62 | test_csv_file = test_csv_file
63 |
64 | # Setting the save folder
65 | if not os.path.exists(save_folder):
66 | os.makedirs(save_folder)
67 |
68 | # Setting output files
69 | save_csv_train = save_folder + "/train.csv"
70 | save_csv_dev = save_folder + "/dev.csv"
71 | save_csv_test = save_folder + "/test.csv"
72 |
73 | # If csv already exists, we skip the data preparation
74 | if skip(save_csv_train, save_csv_dev, save_csv_test):
75 |
76 | msg = "%s already exists, skipping data preparation!" % (save_csv_train)
77 | logger.info(msg)
78 |
79 | msg = "%s already exists, skipping data preparation!" % (save_csv_dev)
80 | logger.info(msg)
81 |
82 | msg = "%s already exists, skipping data preparation!" % (save_csv_test)
83 | logger.info(msg)
84 |
85 | return
86 |
87 | # Additional checks to make sure the data folder contains the DVoice dataset
88 | check_commonvoice_folders(data_folder)
89 |
90 | # Creating csv file for training data
91 | if train_csv_file is not None:
92 |
93 | create_csv(
94 | train_csv_file,
95 | save_csv_train,
96 | data_folder,
97 | accented_letters,
98 | language,
99 | )
100 |
101 | # Creating csv file for dev data
102 | if dev_csv_file is not None:
103 |
104 | create_csv(
105 | dev_csv_file,
106 | save_csv_dev,
107 | data_folder,
108 | accented_letters,
109 | language,
110 | )
111 |
112 | # Creating csv file for test data
113 | if test_csv_file is not None:
114 |
115 | create_csv(
116 | test_csv_file,
117 | save_csv_test,
118 | data_folder,
119 | accented_letters,
120 | language,
121 | )
122 |
123 |
124 | def train_validate_test_split(
125 | df, train_percent=0.6, validate_percent=0.2, seed=None
126 | ):
127 | np.random.seed(seed)
128 | perm = np.random.permutation(df.index)
129 | m = len(df.index)
130 | train_end = int(train_percent * m)
131 | validate_end = int(validate_percent * m) + train_end
132 | train = df.iloc[perm[:train_end]]
133 | validate = df.iloc[perm[train_end:validate_end]]
134 | test = df.iloc[perm[validate_end:]]
135 | return train, validate, test
136 |
137 |
138 | def skip(save_csv_train, save_csv_dev, save_csv_test):
139 | """
140 | Detects if the DVoice data preparation has been already done.
141 | If the preparation has been done, we can skip it.
142 | Returns
143 | -------
144 | bool
145 | if True, the preparation phase can be skipped.
146 | if False, it must be done.
147 | """
148 |
149 | # Checking folders and save options
150 | skip = False
151 |
152 | if (
153 | os.path.isfile(save_csv_train)
154 | and os.path.isfile(save_csv_dev)
155 | and os.path.isfile(save_csv_test)
156 | ):
157 | skip = True
158 |
159 | return skip
160 |
161 |
162 | def create_csv(
163 | orig_csv_file, csv_file, data_folder, accented_letters=False, language="dar"
164 | ):
165 | """
166 | Creates the csv file given a list of wav files.
167 | Arguments
168 | ---------
169 | orig_csv_file : str
170 | Path to the DVoice csv file (standard file).
171 | data_folder : str
172 | Path of the DVoice dataset.
173 | accented_letters : bool, optional
174 | Defines if accented letters will be kept as individual letters or
175 | transformed to the closest non-accented letters.
176 | Returns
177 | -------
178 | None
179 | """
180 |
181 | # Check if the given files exists
182 | if not os.path.isfile(orig_csv_file):
183 | msg = "\t%s doesn't exist, verify your dataset!" % (orig_csv_file)
184 | logger.info(msg)
185 | raise FileNotFoundError(msg)
186 |
187 | # We load and skip the header
188 | loaded_csv = open(orig_csv_file, "r").readlines()[1:]
189 | nb_samples = str(len(loaded_csv))
190 | msg = "Preparing CSV files for %s samples ..." % (str(nb_samples))
191 | logger.info(msg)
192 |
193 | # Adding some Prints
194 | msg = "Creating csv lists in %s ..." % (csv_file)
195 | logger.info(msg)
196 |
197 | csv_lines = [["ID", "duration", "wav", "spk_id", "wrd"]]
198 |
199 | # Start processing lines
200 | total_duration = 0.0
201 | for line in tzip(loaded_csv):
202 |
203 | line = line[0]
204 | # The file name is at index 0 in the DVoice csv files, and the .wav
205 | # files are located in <data_folder>/wavs/
206 |
207 | mp3_path = data_folder + "/wavs/" + line.split("\t")[0]
208 | file_name = line.split("\t")[0]
209 | spk_id = line.split("\t")[0].replace(".wav", "")
210 | snt_id = file_name
211 |
212 | # Setting torchaudio backend to sox_io (needed to read the audio files)
213 | if torchaudio.get_audio_backend() != "sox_io":
214 | logger.warning("This recipe needs the sox-io backend of torchaudio")
215 | logger.warning("The torchaudio backend is changed to sox_io")
216 | torchaudio.set_audio_backend("sox_io")
217 |
218 | duration = float(line.split("\t")[2])
219 | total_duration += duration
220 |
221 | # Getting transcript
222 | words = line.split("\t")[1]
223 |
224 | # Unicode Normalization
225 | # words = unicode_normalisation(words)
226 |
227 | # !! Language specific cleaning !!
228 | # Important: feel free to specify the text normalization
229 | # corresponding to your alphabet.
230 |
231 | if language == "dar":
232 | HAMZA = "\u0621"
233 | ALEF_MADDA = "\u0622"
234 | ALEF_HAMZA_ABOVE = "\u0623"
235 | letters = (
236 | "ابتةثجحخدذرزسشصضطظعغفقكلمنهويءآأؤإئ"
237 | + HAMZA
238 | + ALEF_MADDA
239 | + ALEF_HAMZA_ABOVE
240 | )
241 | words = re.sub("[^" + letters + "]+", " ", words).upper()
242 |
243 | # # Remove accents if specified
244 | # if not accented_letters:
245 | # words = strip_accents(words)
246 | # words = words.replace("'", " ")
247 | # words = words.replace("’", " ")
248 |
249 | # # Remove multiple spaces
250 | # words = re.sub(" +", " ", words)
251 |
252 | # # Remove spaces at the beginning and the end of the sentence
253 | # words = words.lstrip().rstrip()
254 |
255 | # # Getting chars
256 | # chars = words.replace(" ", "_")
257 | # chars = " ".join([char for char in chars][:])
258 |
259 | # Remove too short sentences (or empty):
260 | # if len(words.split(" ")) < 3:
261 | # continue
262 |
263 | # Composition of the csv_line
264 | csv_line = [snt_id, str(duration), mp3_path, spk_id, str(words)]
265 |
266 | # Adding this line to the csv_lines list
267 | csv_lines.append(csv_line)
268 |
269 | # Writing the csv lines
270 | with open(csv_file, mode="w", encoding="utf-8") as csv_f:
271 | csv_writer = csv.writer(
272 | csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
273 | )
274 |
275 | for line in csv_lines:
276 | csv_writer.writerow(line)
277 |
278 | # Final prints
279 | msg = "%s successfully created!" % (csv_file)
280 | logger.info(msg)
281 | msg = "Number of samples: %s " % (str(len(loaded_csv)))
282 | logger.info(msg)
283 | msg = "Total duration: %s Hours" % (str(round(total_duration / 3600, 2)))
284 | logger.info(msg)
285 |
286 |
287 | def check_commonvoice_folders(data_folder):
288 | """
289 | Check if the data folder actually contains the DVoice dataset.
290 | If not, raises an error.
291 | Returns
292 | -------
293 | None
294 | Raises
295 | ------
296 | FileNotFoundError
297 | If data folder doesn't contain DVoice dataset.
298 | """
299 |
300 | files_str = "/wavs"
301 |
302 | # Checking clips
303 | if not os.path.exists(data_folder + files_str):
304 |
305 | err_msg = (
306 | "the folder %s does not exist (it is expected in "
307 | "the DVoice dataset)" % (data_folder + files_str)
308 | )
309 | raise FileNotFoundError(err_msg)
310 |
311 |
312 | def unicode_normalisation(text):
313 |
314 | try:
315 | text = unicode(text, "utf-8")
316 | except NameError: # unicode is a default on python 3
317 | pass
318 | return str(text)
319 |
320 |
321 | def strip_accents(text):
322 |
323 | text = (
324 | unicodedata.normalize("NFD", text)
325 | .encode("ascii", "ignore")
326 | .decode("utf-8")
327 | )
328 |
329 | return str(text)
330 |
331 |
--------------------------------------------------------------------------------
/wav2vec 2.0/README.md:
--------------------------------------------------------------------------------
1 | # 1. Model description
2 | In this project we present the steps for fine-tuning the XLSR-wav2vec 2.0 model. XLSR-53, for "Unsupervised Cross-Lingual Representation Learning For Speech Recognition", is a multilingual version of the [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477) model. It was trained on three datasets spanning 53 languages, for a total of 56,000 hours of audio.
3 |
4 | Model | Architecture | Hours | Languages | Datasets | Link
5 | |---|---|---|---|---|---
6 | XLSR-53 | Large | 56k | 53 | MLS, CommonVoice, BABEL | [Download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt)
7 |
8 | XLSR-53 learns its parameters from unlabeled data, as described in [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979). To fine-tune it on Darija and Arabic, Connectionist Temporal Classification (CTC) is used on labeled data. CTC is an algorithm used with neural networks for sequence-to-sequence problems, in our case automatic speech recognition; it makes it possible to align labels with the training audio samples.
9 |
10 |
11 |
12 | Version | Source | Size | Language | Speech Augmentation | WER | Duration | Epochs | Demo
13 | |---|---|---|---|---|---|---|---|---
14 | Beta | Mozilla Common Voice | 1500 | Arabic | No | 0.7 | 4h | 30 | -
15 | Beta 2 | Facebook / Youtube | 2400 | Darija | No | 0.9 | 5h | 30 | -
16 | Version 1.0 | Facebook / Youtube / Dvoice | 13000 | Darija | Yes | 0.3 | 12h | 10 | -
17 |
18 | - WER: Word Error Rate
65 |
66 | # 2. Training the model from scratch
67 | To train an XLSR-53 model from scratch, we recommend following the steps described [here](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec). That page covers all the models of the wav2vec project, including its multilingual version XLSR-53.
68 |
69 | - PyTorch version >= 1.5.0
70 | - Python version >= 3.6
71 | - Training new models requires an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl).
72 | - To install fairseq for local development:
73 |
74 | ``` bash
75 | git clone https://github.com/pytorch/fairseq
76 | cd fairseq
77 | pip install --editable ./
78 |
79 | # on MacOS:
80 | # CFLAGS="-stdlib=libc++" pip install --editable ./
81 |
82 | # to install the latest stable release (0.10.x)
83 | # pip install fairseq
84 | ```
85 | For faster training, install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
86 |
87 | ``` bash
88 | git clone https://github.com/NVIDIA/apex
89 | cd apex
90 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
91 | --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
92 | --global-option="--fast_multihead_attn" ./
93 | ```
94 | If the dataset is very large, install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`.
95 |
96 | #### a. Training from the command line
97 | The model takes audio files as input. It is recommended to split the recordings into chunks of 10 to 30 seconds.
98 | - Preparation
99 |
100 | Install the `soundfile` library:
101 | ```shell script
102 | pip install soundfile
103 | ```
104 |
105 | Then:
106 |
107 | ```shell script
108 | $ python wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid
109 | ```
110 | $ext: can be any audio format (mp3, flac, wav, ...) that soundfile can read.
111 |
112 | $valid: the fraction of the training data to hold out for validation (for example 10%).
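For instance, a concrete invocation of the command above could look like this (the paths are placeholders; only the flags already shown are used):

```shell script
python wav2vec_manifest.py /data/dvoice/wavs --dest /data/dvoice/manifest --ext wav --valid-percent 0.1
```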
113 |
114 | #### b. Training the base wav2vec 2.0 model
115 | The input audio files must be single-channel, with a sampling rate of 16 kHz. This is the configuration used in the wav2vec 2.0 paper.
116 |
117 | ```shell script
118 | $ fairseq-hydra-train \
119 | task.data=/path/to/data \
120 | --config-dir wav2vec/config/pretraining \
121 | --config-name wav2vec2_base_librispeech
122 | ```
123 | # 3. Fine-tuning
124 | To fine-tune the model, we recommend following the steps listed in this notebook: [Fine-tuning XLSR-53 on Darija](https://github.com/nairaxo/dialectal-voice-clone/blob/main/wav2vec%202.0/xlsr_wav2vec2_darija_finetuning.ipynb). A command-line alternative with fairseq is sketched below.
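For those who prefer the fairseq command line, fine-tuning a pretrained checkpoint with one of the configurations in `config/finetuning` follows the same pattern as pretraining. This is a sketch based on the fairseq wav2vec documentation; the paths are placeholders, and the manifest folder is expected to also contain the letter-level label files (`.ltr`, `.wrd`) and the corresponding `dict.ltr.txt`:

```shell script
fairseq-hydra-train \
    task.data=/path/to/manifest \
    model.w2v_path=/path/to/xlsr_53_56k.pt \
    --config-dir config/finetuning \
    --config-name base_100h
```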
125 |
--------------------------------------------------------------------------------
/wav2vec 2.0/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIOXLABS/DVoice/19b760ca93ed9016d406bd9baf077e49bab384ec/wav2vec 2.0/__init__.py
--------------------------------------------------------------------------------
/wav2vec 2.0/config/finetuning/base_100h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | no_epoch_checkpoints: true
10 | best_checkpoint_metric: wer
11 |
12 | task:
13 | _name: audio_finetuning
14 | data: ???
15 | normalize: false
16 | labels: ltr
17 |
18 | dataset:
19 | num_workers: 6
20 | max_tokens: 3200000
21 | skip_invalid_size_inputs_valid_test: true
22 | valid_subset: dev_other
23 |
24 | distributed_training:
25 | ddp_backend: legacy_ddp
26 | distributed_world_size: 2
27 |
28 | criterion:
29 | _name: ctc
30 | zero_infinity: true
31 |
32 | optimization:
33 | max_update: 80000
34 | lr: [0.00003]
35 | sentence_avg: true
36 | update_freq: [4]
37 |
38 | optimizer:
39 | _name: adam
40 | adam_betas: (0.9,0.98)
41 | adam_eps: 1e-08
42 |
43 | lr_scheduler:
44 | _name: tri_stage
45 | phase_ratio: [0.1, 0.4, 0.5]
46 | final_lr_scale: 0.05
47 |
48 | model:
49 | _name: wav2vec_ctc
50 | w2v_path: ???
51 | apply_mask: true
52 | mask_prob: 0.65
53 | mask_channel_prob: 0.5
54 | mask_channel_length: 64
55 | layerdrop: 0.1
56 | activation_dropout: 0.1
57 | feature_grad_mult: 0.0
58 | freeze_finetune_updates: 0
59 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/finetuning/base_10h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 50
10 | save_interval_updates: 10000
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: false
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 3200000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 50
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 2
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 20000
39 | lr: [0.00005]
40 | sentence_avg: true
41 | update_freq: [4]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.65
58 | mask_channel_prob: 0.5
59 | mask_channel_length: 64
60 | layerdrop: 0.05
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/finetuning/base_10m.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 1000
10 | save_interval_updates: 50
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: false
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 3200000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 1000
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 2
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 13000
39 | lr: [0.00005]
40 | sentence_avg: true
41 | update_freq: [4]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.65
58 | mask_channel_prob: 0.25
59 | mask_channel_length: 64
60 | layerdrop: 0.1
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/finetuning/base_1h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 1000
10 | save_interval_updates: 50
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: false
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 3200000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 1000
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 2
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 13000
39 | lr: [0.00005]
40 | sentence_avg: true
41 | update_freq: [4]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.65
58 | mask_channel_prob: 0.25
59 | mask_channel_length: 64
60 | layerdrop: 0.1
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/finetuning/base_960h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | no_epoch_checkpoints: true
10 | best_checkpoint_metric: wer
11 |
12 | task:
13 | _name: audio_finetuning
14 | data: ???
15 | normalize: false
16 | labels: ltr
17 |
18 | dataset:
19 | num_workers: 6
20 | max_tokens: 3200000
21 | skip_invalid_size_inputs_valid_test: true
22 | valid_subset: dev_other
23 |
24 | distributed_training:
25 | ddp_backend: legacy_ddp
26 | distributed_world_size: 8
27 |
28 | criterion:
29 | _name: ctc
30 | zero_infinity: true
31 |
32 | optimization:
33 | max_update: 320000
34 | lr: [0.0001]
35 | sentence_avg: true
36 |
37 | optimizer:
38 | _name: adam
39 | adam_betas: (0.9,0.98)
40 | adam_eps: 1e-08
41 |
42 | lr_scheduler:
43 | _name: tri_stage
44 | phase_ratio: [0.1, 0.4, 0.5]
45 | final_lr_scale: 0.05
46 |
47 | model:
48 | _name: wav2vec_ctc
49 | w2v_path: ???
50 | apply_mask: true
51 | mask_prob: 0.5
52 | mask_channel_prob: 0.1
53 | mask_channel_length: 64
54 | layerdrop: 0.1
55 | activation_dropout: 0.1
56 | feature_grad_mult: 0.0
57 | freeze_finetune_updates: 0
58 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/finetuning/vox_100h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | no_epoch_checkpoints: true
10 | best_checkpoint_metric: wer
11 |
12 | task:
13 | _name: audio_finetuning
14 | data: ???
15 | normalize: true
16 | labels: ltr
17 |
18 | dataset:
19 | num_workers: 6
20 | max_tokens: 1280000
21 | skip_invalid_size_inputs_valid_test: true
22 | valid_subset: dev_other
23 |
24 | distributed_training:
25 | ddp_backend: legacy_ddp
26 | distributed_world_size: 4
27 |
28 | criterion:
29 | _name: ctc
30 | zero_infinity: true
31 |
32 | optimization:
33 | max_update: 80000
34 | lr: [0.00003]
35 | sentence_avg: true
36 | update_freq: [5]
37 |
38 | optimizer:
39 | _name: adam
40 | adam_betas: (0.9,0.98)
41 | adam_eps: 1e-08
42 |
43 | lr_scheduler:
44 | _name: tri_stage
45 | phase_ratio: [0.1, 0.4, 0.5]
46 | final_lr_scale: 0.05
47 |
48 | model:
49 | _name: wav2vec_ctc
50 | w2v_path: ???
51 | apply_mask: true
52 | mask_prob: 0.5
53 | mask_channel_prob: 0.5
54 | mask_channel_length: 64
55 | layerdrop: 0.1
56 | activation_dropout: 0.1
57 | feature_grad_mult: 0.0
58 | freeze_finetune_updates: 10000
59 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/finetuning/vox_10h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 50
10 | save_interval_updates: 10000
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: true
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 1280000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 50
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 4
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 20000
39 | lr: [0.0001]
40 | sentence_avg: true
41 | update_freq: [5]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.75
58 | mask_channel_prob: 0.25
59 | mask_channel_length: 64
60 | layerdrop: 0.1
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/finetuning/vox_10m.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 1000
10 | save_interval_updates: 50
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: true
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 1280000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 1000
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 4
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 13000
39 | lr: [0.0001]
40 | sentence_avg: true
41 | update_freq: [5]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.65
58 | mask_channel_prob: 0.25
59 | mask_channel_length: 64
60 | layerdrop: 0.1
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/finetuning/vox_1h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 1000
10 | save_interval_updates: 50
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: true
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 1280000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 1000
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 4
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 13000
39 | lr: [0.0003]
40 | sentence_avg: true
41 | update_freq: [5]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.75
58 | mask_channel_prob: 0.25
59 | mask_channel_length: 64
60 | layerdrop: 0.1
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/finetuning/vox_960h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | no_epoch_checkpoints: true
10 | best_checkpoint_metric: wer
11 |
12 | task:
13 | _name: audio_finetuning
14 | data: ???
15 | normalize: true
16 | labels: ltr
17 |
18 | dataset:
19 | num_workers: 6
20 | max_tokens: 1280000
21 | skip_invalid_size_inputs_valid_test: true
22 | valid_subset: dev_other
23 |
24 | distributed_training:
25 | ddp_backend: legacy_ddp
26 | distributed_world_size: 24
27 |
28 | criterion:
29 | _name: ctc
30 | zero_infinity: true
31 |
32 | optimization:
33 | max_update: 320000
34 | lr: [0.00003]
35 | sentence_avg: true
36 |
37 | optimizer:
38 | _name: adam
39 | adam_betas: (0.9,0.98)
40 | adam_eps: 1e-08
41 |
42 | lr_scheduler:
43 | _name: tri_stage
44 | phase_ratio: [0.1, 0.4, 0.5]
45 | final_lr_scale: 0.05
46 |
47 | model:
48 | _name: wav2vec_ctc
49 | w2v_path: ???
50 | apply_mask: true
51 | mask_prob: 0.5
52 | mask_channel_prob: 0.25
53 | mask_channel_length: 64
54 | layerdrop: 0.1
55 | activation_dropout: 0.1
56 | feature_grad_mult: 0.0
57 | freeze_finetune_updates: 10000
58 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/pretraining/wav2vec2_base_librispeech.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval_updates: 25000
10 | keep_interval_updates: 1
11 | no_epoch_checkpoints: true
12 |
13 | task:
14 | _name: audio_pretraining
15 | data: ???
16 | max_sample_size: 250000
17 | min_sample_size: 32000
18 | normalize: false
19 |
20 | dataset:
21 | num_workers: 6
22 | max_tokens: 1400000
23 | skip_invalid_size_inputs_valid_test: true
24 |
25 | distributed_training:
26 | distributed_world_size: 64
27 | ddp_backend: legacy_ddp
28 |
29 | criterion:
30 | _name: wav2vec
31 | infonce: true
32 | log_keys: ["prob_perplexity","code_perplexity","temp"]
33 | loss_weights: [0.1, 10]
34 |
35 | optimization:
36 | max_update: 400000
37 | lr: [0.0005]
38 |
39 | optimizer:
40 | _name: adam
41 | adam_betas: (0.9,0.98)
42 | adam_eps: 1e-06
43 | weight_decay: 0.01
44 |
45 | lr_scheduler:
46 | _name: polynomial_decay
47 | warmup_updates: 32000
48 |
49 | model:
50 | _name: wav2vec2
51 | quantize_targets: true
52 | final_dim: 256
53 | encoder_layerdrop: 0.05
54 | dropout_input: 0.1
55 | dropout_features: 0.1
56 | feature_grad_mult: 0.1
57 | encoder_embed_dim: 768
58 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/pretraining/wav2vec2_large_librivox.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval_updates: 25000
10 | keep_interval_updates: 1
11 | no_epoch_checkpoints: true
12 |
13 | task:
14 | _name: audio_pretraining
15 | data: ???
16 | max_sample_size: 320000
17 | min_sample_size: 32000
18 | normalize: true
19 |
20 | dataset:
21 | batch_size: 4
22 | num_workers: 6
23 | max_tokens: 1200000
24 | skip_invalid_size_inputs_valid_test: true
25 |
26 | distributed_training:
27 | distributed_world_size: 128
28 | ddp_backend: legacy_ddp
29 |
30 | criterion:
31 | _name: wav2vec
32 | infonce: true
33 | log_keys: ["prob_perplexity","code_perplexity","temp"]
34 | loss_weights: [0.1, 0]
35 |
36 | optimization:
37 | max_update: 1000000
38 | lr: [0.005]
39 |
40 | optimizer:
41 | _name: adam
42 | adam_betas: (0.9,0.98)
43 | adam_eps: 1e-06
44 | weight_decay: 0.01
45 |
46 | lr_scheduler:
47 | _name: polynomial_decay
48 | warmup_updates: 32000
49 |
50 | model:
51 | _name: wav2vec2
52 | quantize_targets: true
53 | extractor_mode: layer_norm
54 | layer_norm_first: true
55 | final_dim: 768
56 | latent_temp: [2.0,0.1,0.999995]
57 | encoder_layerdrop: 0.00
58 | dropout_input: 0.0
59 | dropout_features: 0.0
60 | dropout: 0.0
61 | attention_dropout: 0.0
62 | conv_bias: true
63 |
64 | encoder_layers: 24
65 | encoder_embed_dim: 1024
66 | encoder_ffn_embed_dim: 4096
67 | encoder_attention_heads: 16
68 |
69 | feature_grad_mult: 1.0
70 |
71 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | tpu: true
5 | fp16: false
6 | log_format: json
7 | log_interval: 10
8 |
9 | checkpoint:
10 | save_interval_updates: 25000
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 |
14 | task:
15 | _name: audio_pretraining
16 | data: ???
17 | max_sample_size: 250000
18 | min_sample_size: 32000
19 | normalize: true
20 | num_batch_buckets: 3
21 | precompute_mask_indices: true
22 | enable_padding: true
23 |
24 | dataset:
25 | num_workers: 6
26 | max_tokens: 1200000
27 | skip_invalid_size_inputs_valid_test: true
28 |
29 | distributed_training:
30 | distributed_world_size: 128
31 | ddp_backend: legacy_ddp
32 |
33 | criterion:
34 | _name: wav2vec
35 | infonce: true
36 | log_keys: ["prob_perplexity","code_perplexity","temp"]
37 | loss_weights: [0.1, 0]
38 |
39 | optimization:
40 | max_update: 1000000
41 | lr: [0.005]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-06
47 | weight_decay: 0.01
48 |
49 | lr_scheduler:
50 | _name: polynomial_decay
51 | warmup_updates: 32000
52 |
53 | model:
54 | _name: wav2vec2
55 | quantize_targets: true
56 | extractor_mode: layer_norm
57 | layer_norm_first: true
58 | final_dim: 768
59 | latent_temp: [2.0,0.1,0.999995]
60 | encoder_layerdrop: 0.00
61 | dropout_input: 0.0
62 | dropout_features: 0.0
63 | dropout: 0.0
64 | attention_dropout: 0.0
65 | conv_bias: true
66 |
67 | encoder_layers: 24
68 | encoder_embed_dim: 1024
69 | encoder_ffn_embed_dim: 4096
70 | encoder_attention_heads: 16
71 |
72 | feature_grad_mult: 1.0
73 |
--------------------------------------------------------------------------------
/wav2vec 2.0/config/pretraining/wav2vec2_large_librivox_tpu.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | tpu: true
5 | fp16: false
6 | log_format: json
7 | log_interval: 10
8 |
9 | checkpoint:
10 | save_interval_updates: 25000
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 |
14 | task:
15 | _name: audio_pretraining
16 | data: ???
17 | max_sample_size: 250000
18 | min_sample_size: 32000
19 | normalize: true
20 | num_batch_buckets: 3
21 | precompute_mask_indices: true
22 | enable_padding: true
23 | inferred_w2v_config:
24 | mask_prob: 0.65
25 | mask_selection: 'static'
26 | mask_other: 0
27 | mask_channel_prob: 0.1
28 |
29 | dataset:
30 | num_workers: 6
31 | max_tokens: 1200000
32 | skip_invalid_size_inputs_valid_test: true
33 |
34 | distributed_training:
35 | distributed_world_size: 8
36 | ddp_backend: legacy_ddp
37 |
38 | criterion:
39 | _name: wav2vec
40 | infonce: true
41 | log_keys: ["prob_perplexity","code_perplexity","temp"]
42 | loss_weights: [0.1, 0]
43 |
44 | optimization:
45 | max_update: 1000000
46 | lr: [0.005]
47 |
48 | optimizer:
49 | _name: adam
50 | adam_betas: (0.9,0.98)
51 | adam_eps: 1e-06
52 | weight_decay: 0.01
53 |
54 | lr_scheduler:
55 | _name: polynomial_decay
56 | warmup_updates: 32000
57 |
58 | model:
59 | _name: wav2vec2
60 | quantize_targets: true
61 | extractor_mode: layer_norm
62 | layer_norm_first: true
63 | final_dim: 768
64 | latent_temp: [2.0,0.1,0.999995]
65 | encoder_layerdrop: 0.00
66 | dropout_input: 0.0
67 | dropout_features: 0.0
68 | dropout: 0.0
69 | attention_dropout: 0.0
70 | conv_bias: true
71 |
72 | encoder_layers: 24
73 | encoder_embed_dim: 1024
74 | encoder_ffn_embed_dim: 4096
75 | encoder_attention_heads: 16
76 |
77 | feature_grad_mult: 1.0
78 |
--------------------------------------------------------------------------------
/wav2vec 2.0/libri_labels.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Helper script to create letter (.ltr) and word (.wrd) label files for a flashlight (previously called wav2letter++) dataset
9 | """
10 |
11 | import argparse
12 | import os
13 |
14 |
15 | def main():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("tsv")
18 | parser.add_argument("--output-dir", required=True)
19 | parser.add_argument("--output-name", required=True)
20 | args = parser.parse_args()
21 |
22 | os.makedirs(args.output_dir, exist_ok=True)
23 |
24 | transcriptions = {}
25 |
26 | with open(args.tsv, "r") as tsv, open(
27 | os.path.join(args.output_dir, args.output_name + ".ltr"), "w"
28 | ) as ltr_out, open(
29 | os.path.join(args.output_dir, args.output_name + ".wrd"), "w"
30 | ) as wrd_out:
31 | root = next(tsv).strip()
32 | for line in tsv:
33 | line = line.strip()
34 | dir = os.path.dirname(line)
35 | if dir not in transcriptions:
36 | parts = dir.split(os.path.sep)
37 | trans_path = f"{parts[-2]}-{parts[-1]}.trans.txt"
38 | path = os.path.join(root, dir, trans_path)
39 | assert os.path.exists(path)
40 | texts = {}
41 | with open(path, "r") as trans_f:
42 | for tline in trans_f:
43 | items = tline.strip().split()
44 | texts[items[0]] = " ".join(items[1:])
45 | transcriptions[dir] = texts
46 | part = os.path.basename(line).split(".")[0]
47 | assert part in transcriptions[dir]
48 | print(transcriptions[dir][part], file=wrd_out)
49 | print(
50 | " ".join(list(transcriptions[dir][part].replace(" ", "|"))) + " |",
51 | file=ltr_out,
52 | )
53 |
54 |
55 | if __name__ == "__main__":
56 | main()
57 |
--------------------------------------------------------------------------------
/wav2vec 2.0/scripts/binarize_manifest.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # usage: bash binarize_manifest.sh <dest_dir> <train_split_prefix> <valid_split_prefix> <fairseq_root>
4 |
5 | DEST_DIR=$1
6 | TRAIN_SPLIT=$2
7 | VALID_SPLIT=$3
8 | FAIRSEQ_ROOT=$4
9 |
10 | mkdir -p $DEST_DIR
11 |
12 | # split file path and lengths into separate files
13 | cut -f1 $TRAIN_SPLIT.tsv > $DEST_DIR/train_fnames.txt
14 | cut -f1 $VALID_SPLIT.tsv > $DEST_DIR/valid_fnames.txt
15 | cut -f2 $TRAIN_SPLIT.tsv > $DEST_DIR/train.lengths
16 | cut -f2 $VALID_SPLIT.tsv > $DEST_DIR/valid.lengths
17 |
18 | # copy root directory
19 | head -1 $TRAIN_SPLIT.tsv > $DEST_DIR/train.root
20 | head -1 $VALID_SPLIT.tsv > $DEST_DIR/valid.root
21 |
22 | # remove root directory
23 | sed -i '1d' $DEST_DIR/train_fnames.txt
24 | sed -i '1d' $DEST_DIR/valid_fnames.txt
25 | sed -i '1d' $DEST_DIR/train.lengths
26 | sed -i '1d' $DEST_DIR/valid.lengths
27 |
28 | # insert spaces between characters
29 | sed -i -e 's/\(.\)/\1 /g' $DEST_DIR/train_fnames.txt
30 | sed -i -e 's/\(.\)/\1 /g' $DEST_DIR/valid_fnames.txt
31 |
32 | # run preprocessor
33 | PYTHONPATH=$FAIRSEQ_ROOT python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $DEST_DIR/train_fnames.txt --validpref $DEST_DIR/valid_fnames.txt --workers 60 --only-source --destdir $DEST_DIR
34 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/README.md:
--------------------------------------------------------------------------------
1 | # wav2vec Unsupervised (wav2vec-U)
2 |
3 | Wav2vec Unsupervised (wav2vec-U) is a framework for building speech recognition systems without any labeled training data as described in [Unsupervised Speech Recognition (Baevski et al., 2021)](https://ai.facebook.com/research/publications/unsupervised-speech-recognition). The model takes as input wav2vec 2.0 or XLSR representations (see [pretrained models](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec)) as well as unlabeled speech and text data.
4 |
5 | The wav2vec-U training procedure consists of three consecutive main steps:
6 | * Preparation of speech representations and text data
7 | * Generative adversarial training (GAN)
8 | * Iterative self-training + Kaldi LM-decoding
9 |
10 | ## Preparation of speech and text data
11 | Similar to [wav2vec 2.0](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md), data folders contain {train,valid,test}.{tsv,wrd,phn} files, where audio paths are stored in tsv files, and word, letter or phoneme transcriptions are stored in .{wrd,ltr,phn}.
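For reference, the *.tsv* manifests produced by `wav2vec_manifest.py` start with the root audio directory on the first line, followed by tab-separated relative paths and lengths in samples; *.wrd* files contain one word-level transcript per line, and *.ltr* files spell the same transcript out letter by letter with `|` marking word boundaries (the format written by `libri_labels.py`). A minimal illustration, with made-up paths and values:

```
# train.tsv
/path/to/audio
speaker1/utt1.wav	118320
speaker1/utt2.wav	96000

# train.wrd
hello world

# train.ltr
h e l l o | w o r l d |
```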
12 |
13 | In **/path/to/data/with_silence** you need a *train.tsv* file as well as (optionally) *{valid,test}.{tsv,wrd,phn}*. It is also useful to have *10h.{tsv,phn}* files there for reproducing the ablation study on layer selection. In **/path/to/data/without_silence** you have the same files, except that the *.tsv* files point to audio with silences removed using rVAD.
14 |
15 | Pre-requisites:
16 | * set the FAIRSEQ_ROOT environment variable to your fairseq installation
17 | * set the RVAD_ROOT environment variable to a checkout of [rVADfast](https://github.com/zhenghuatan/rVADfast)
18 | * set the KENLM_ROOT environment variable to the location of [KenLM](https://github.com/kpu/kenlm) binaries
19 | * install [PyKaldi](https://github.com/pykaldi/pykaldi) and set the KALDI_ROOT environment variable to the location of your Kaldi installation. To use the version bundled with PyKaldi, you can use /path/to/pykaldi/tools/kaldi
20 |
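For example, these could be exported as follows (all paths below are illustrative):

```shell
export FAIRSEQ_ROOT=/path/to/fairseq
export RVAD_ROOT=/path/to/rVADfast
export KENLM_ROOT=/path/to/kenlm/build/bin
export KALDI_ROOT=/path/to/pykaldi/tools/kaldi
```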
21 | Create new audio files without silences:
22 | ```shell
23 | # create a manifest file for the original set of audio files
24 | python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0
25 |
26 | python scripts/vads.py -r $RVAD_ROOT < /path/to/train.tsv > train.vads
27 |
28 | python scripts/remove_silence.py --tsv /path/to/train.tsv --vads train.vads --out /dir/to/save/audio/files
29 |
30 | python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0.01
31 | ```
32 |
33 | Next, we need to preprocess the audio data to better match phonemized text data:
34 |
35 | ```shell
36 | zsh scripts/prepare_audio.sh /dir/with/{train,test,valid}.tsv /output/dir /path/to/wav2vec2/model.pt 512 14
37 | ```
38 | Note that if you have splits different than train/valid/test, you will need to modify this script. The last two arguments are the PCA dimensionality and the 0-based index of the layer from which to extract representations.
39 |
40 | Now we need to prepare text data:
41 | ```shell
42 | zsh scripts/prepare_text.sh language /path/to/text/file /output/dir 1000 espeak /path/to/fasttext/lid/model
43 | ```
44 |
45 | The fourth argument is the minimum number of observations of phones to keep. If your text corpus is small, you might want to reduce this number.
46 |
47 | The fifth argument selects the phonemizer to use. Supported values are [espeak](http://espeak.sourceforge.net/), [espeak-ng](https://github.com/espeak-ng/espeak-ng), and [G2P](https://github.com/Kyubyong/g2p) (English only).
48 |
49 | Pre-trained fasttext LID models can be downloaded [here](https://fasttext.cc/docs/en/language-identification.html).
50 |
51 | ### Prepare TIMIT data
52 | TIMIT transcripts include silence. Therefore VAD is not used for audio preprocessing, and we do not wrap transcripts with silences or insert random silence in between words.
53 |
54 | To prepare TIMIT data for both the matched and unmatched setups:
55 | ```shell
56 | bash scripts/prepare_timit.sh /dir/to/timit/raw/data /output/dir /path/to/wav2vec2/model.pt
57 | ```
58 |
59 | Note that we assume the TIMIT distribution with capitalized directories and filenames is used (e.g., `TRAIN/DR1/FCJF0/SA1.PHN`).
60 |
61 | ## Generative adversarial training (GAN)
62 |
63 | We then use a GAN model to build a first unsupervised ASR model. The data preparation above, covering both speech features and text data, is what enables the generator to match speech to text in an unsupervised way.
64 |
65 | Launching GAN training on top of the preprocessed features with default hyperparameters can be done with:
66 |
67 | ```
68 | PREFIX=w2v_unsup_gan_xp
69 | TASK_DATA=/path/to/features/precompute_unfiltered_pca512_cls128_mean_pooled
70 | TEXT_DATA=/path/to/data/phones # path to fairseq-preprocessed GAN data (phones dir)
71 | KENLM_PATH=/path/to/data/phones/kenlm.phn.o4.bin # KenLM 4-gram phoneme language model (LM data = GAN data here)
72 |
73 | PYTHONPATH=$FAIRSEQ_ROOT PREFIX=$PREFIX fairseq-hydra-train \
74 | -m --config-dir config/gan \
75 | --config-name w2vu \
76 | task.data=${TASK_DATA} \
77 | task.text_data=${TEXT_DATA} \
78 | task.kenlm_path=${KENLM_PATH} \
79 | common.user_dir=${FAIRSEQ_ROOT}/examples/wav2vec/unsupervised \
80 | model.code_penalty=2,4 model.gradient_penalty=1.5,2.0 \
81 | model.smoothness_weight=0.5,0.75,1.0 'common.seed=range(0,5)'
82 | ```
83 |
84 |
85 | Once we find the best checkpoint (chosen using an unsupervised metric that combines language model perplexity and vocabulary usage), we can use it to generate phone labels (or word labels with an appropriate Kaldi WFST):
86 |
87 | ```shell
88 | python w2vu_generate.py --config-dir config/generate --config-name viterbi \
89 | fairseq.common.user_dir=${FAIRSEQ_ROOT}/examples/wav2vec/unsupervised \
90 | fairseq.task.data=/path/to/dir/with/features \
91 | fairseq.common_eval.path=/path/to/gan/checkpoint \
92 | fairseq.dataset.gen_subset=valid results_path=/where/to/save/transcriptions
93 | ```
94 |
95 | Decoding without an LM works best on the same adjacent-mean-pooled features that the GAN was trained on, while decoding with an LM works better on the features before the adjacent-timestep mean pooling (without the "_pooled" suffix).
96 |
97 | ## Iterative self-training + Kaldi LM-decoding
98 | After the GAN training provides a first unsupervised model, we progressively refine the quality of the transcriptions with several iterations of semi-supervised learning. We perform two iterations: first, we pseudo-label the training data with the unsupervised GAN model and train an HMM on these pseudo-labels. Second, we relabel the training data with the HMM and fine-tune the original wav2vec 2.0 model on the HMM pseudo-labels using a CTC loss. Note that the HMM models output phonemes, while wav2vec 2.0 outputs letters; both are decoded into words using WFST decoders.
99 |
100 |
101 | Please see [this README](kaldi_self_train/README.md) for more instructions on how to do iterative self-training + Kaldi LM-decoding.
102 |
103 | *** Note: these instructions are a work in progress and will be updated over the next few days
104 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIOXLABS/DVoice/19b760ca93ed9016d406bd9baf077e49bab384ec/wav2vec 2.0/unsupervised/__init__.py
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/config/finetuning/w2v_finetune.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 | tensorboard_logdir: tb
8 |
9 | checkpoint:
10 | no_epoch_checkpoints: true
11 | save_interval_updates: 20000
12 |
13 | task:
14 | _name: audio_finetuning
15 | data: ???
16 | normalize: true
17 | labels: ltr
18 |
19 | dataset:
20 | num_workers: 6
21 | max_tokens: 800000
22 | skip_invalid_size_inputs_valid_test: true
23 | train_subset: train
24 | valid_subset: valid
25 |
26 | distributed_training:
27 | ddp_backend: legacy_ddp
28 | distributed_world_size: 8
29 | find_unused_parameters: True
30 |
31 | criterion:
32 | _name: ctc
33 | zero_infinity: true
34 | post_process: letter
35 |
36 | optimization:
37 | max_update: 80000
38 | lr: [0.00003]
39 | sentence_avg: true
40 | update_freq: [1]
41 |
42 | optimizer:
43 | _name: adam
44 | adam_betas: (0.9,0.98)
45 | adam_eps: 1e-08
46 |
47 | lr_scheduler:
48 | _name: tri_stage
49 | phase_ratio: [0.1, 0.4, 0.5]
50 | final_lr_scale: 0.05
51 |
52 | model:
53 | _name: wav2vec_ctc
54 | w2v_path: ???
55 | apply_mask: true
56 | mask_prob: 0.25
57 | mask_channel_prob: 0.1
58 | mask_channel_length: 64
59 | layerdrop: 0.1
60 | activation_dropout: 0.1
61 | feature_grad_mult: 0.0
62 | freeze_finetune_updates: 0
63 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/config/gan/w2vu.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: false
5 | fp16_no_flatten_grads: true
6 | log_format: json
7 | log_interval: 100
8 | tensorboard_logdir: tb
9 | reset_logging: false
10 | suppress_crashes: false
11 |
12 | checkpoint:
13 | save_interval: 1000
14 | save_interval_updates: 1000
15 | no_epoch_checkpoints: true
16 | best_checkpoint_metric: weighted_lm_ppl
17 | save_dir: .
18 |
19 | distributed_training:
20 | distributed_world_size: 1
21 |
22 | task:
23 | _name: unpaired_audio_text
24 | data: ???
25 | text_data: ???
26 | labels: phn
27 | sort_by_length: false
28 | unfiltered: false
29 | max_length: null
30 | append_eos: false
31 | kenlm_path: ???
32 |
33 | dataset:
34 | num_workers: 6
35 | batch_size: 160
36 | skip_invalid_size_inputs_valid_test: true
37 | valid_subset: valid
38 | validate_interval: 1000
39 | validate_interval_updates: 1000
40 |
41 | criterion:
42 | _name: model
43 | log_keys:
44 | - accuracy_dense
45 | - accuracy_token
46 | - temp
47 | - code_ppl
48 |
49 | optimization:
50 | max_update: 150000
51 | clip_norm: 5.0
52 | lr: [0]
53 |
54 | optimizer:
55 | _name: composite
56 | groups:
57 | generator:
58 | lr: [0.0004]
59 | lr_float: null
60 | optimizer:
61 | _name: adam
62 | adam_betas: [0.5,0.98]
63 | adam_eps: 1e-06
64 | weight_decay: 0
65 | amsgrad: false
66 | lr_scheduler:
67 | _name: fixed
68 | warmup_updates: 0
69 | discriminator:
70 | lr: [ 0.0005 ]
71 | lr_float: null
72 | optimizer:
73 | _name: adam
74 | adam_betas: [0.5,0.98]
75 | adam_eps: 1e-06
76 | weight_decay: 0.0001
77 | amsgrad: false
78 | lr_scheduler:
79 | _name: fixed
80 | warmup_updates: 0
81 |
82 | lr_scheduler: pass_through
83 |
84 | model:
85 | _name: wav2vec_u
86 |
87 | discriminator_dim: 384
88 | discriminator_depth: 2
89 | discriminator_kernel: 6
90 | discriminator_linear_emb: false
91 | discriminator_causal: true
92 | discriminator_max_pool: false
93 | discriminator_act_after_linear: false
94 | discriminator_dropout: 0.0
95 | discriminator_weight_norm: false
96 |
97 | generator_stride: 1
98 | generator_kernel: 4
99 | generator_bias: false
100 | generator_dropout: 0.1
101 |
102 | smoothness_weight: 0.5
103 | smoothing: 0
104 | smoothing_one_sided: false
105 | gumbel: false
106 | hard_gumbel: false
107 | gradient_penalty: 1.5
108 | code_penalty: 4.0
109 | temp: [ 2,0.1,0.99995 ]
110 | input_dim: 512
111 |
112 | segmentation:
113 | type: JOIN
114 | mean_pool_join: false
115 | remove_zeros: false
116 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/config/generate/viterbi.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | fairseq:
4 | task:
5 | _name: unpaired_audio_text
6 | labels: phn
7 | data: ???
8 | sort_by_length: false
9 | shuffle: false
10 | text_data: ''
11 |
12 | common_eval:
13 | path: ???
14 | quiet: true
15 |
16 | dataset:
17 | gen_subset: valid
18 | batch_size: 1
19 |
20 | w2l_decoder: VITERBI
21 | post_process: silence
22 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/config/timit_matched/test.uid:
--------------------------------------------------------------------------------
1 | FDHC0_SI1559
2 | FDHC0_SI2189
3 | FDHC0_SI929
4 | FDHC0_SX119
5 | FDHC0_SX209
6 | FDHC0_SX29
7 | FDHC0_SX299
8 | FDHC0_SX389
9 | FELC0_SI1386
10 | FELC0_SI2016
11 | FELC0_SI756
12 | FELC0_SX126
13 | FELC0_SX216
14 | FELC0_SX306
15 | FELC0_SX36
16 | FELC0_SX396
17 | FJLM0_SI1043
18 | FJLM0_SI1673
19 | FJLM0_SI2303
20 | FJLM0_SX143
21 | FJLM0_SX233
22 | FJLM0_SX323
23 | FJLM0_SX413
24 | FJLM0_SX53
25 | FMGD0_SI1564
26 | FMGD0_SI2194
27 | FMGD0_SI934
28 | FMGD0_SX124
29 | FMGD0_SX214
30 | FMGD0_SX304
31 | FMGD0_SX34
32 | FMGD0_SX394
33 | FMLD0_SI2185
34 | FMLD0_SI822
35 | FMLD0_SI925
36 | FMLD0_SX115
37 | FMLD0_SX205
38 | FMLD0_SX25
39 | FMLD0_SX295
40 | FMLD0_SX385
41 | FNLP0_SI1308
42 | FNLP0_SI1938
43 | FNLP0_SI678
44 | FNLP0_SX138
45 | FNLP0_SX228
46 | FNLP0_SX318
47 | FNLP0_SX408
48 | FNLP0_SX48
49 | FPAS0_SI1272
50 | FPAS0_SI2204
51 | FPAS0_SI944
52 | FPAS0_SX134
53 | FPAS0_SX224
54 | FPAS0_SX314
55 | FPAS0_SX404
56 | FPAS0_SX44
57 | FPKT0_SI1538
58 | FPKT0_SI2168
59 | FPKT0_SI908
60 | FPKT0_SX188
61 | FPKT0_SX278
62 | FPKT0_SX368
63 | FPKT0_SX8
64 | FPKT0_SX98
65 | MBPM0_SI1577
66 | MBPM0_SI1584
67 | MBPM0_SI947
68 | MBPM0_SX137
69 | MBPM0_SX227
70 | MBPM0_SX317
71 | MBPM0_SX407
72 | MBPM0_SX47
73 | MCMJ0_SI1094
74 | MCMJ0_SI464
75 | MCMJ0_SI602
76 | MCMJ0_SX104
77 | MCMJ0_SX14
78 | MCMJ0_SX194
79 | MCMJ0_SX284
80 | MCMJ0_SX374
81 | MDAB0_SI1039
82 | MDAB0_SI1669
83 | MDAB0_SI2299
84 | MDAB0_SX139
85 | MDAB0_SX229
86 | MDAB0_SX319
87 | MDAB0_SX409
88 | MDAB0_SX49
89 | MGRT0_SI1450
90 | MGRT0_SI2080
91 | MGRT0_SI820
92 | MGRT0_SX10
93 | MGRT0_SX100
94 | MGRT0_SX190
95 | MGRT0_SX280
96 | MGRT0_SX370
97 | MJDH0_SI1354
98 | MJDH0_SI1984
99 | MJDH0_SI724
100 | MJDH0_SX184
101 | MJDH0_SX274
102 | MJDH0_SX364
103 | MJDH0_SX4
104 | MJDH0_SX94
105 | MJLN0_SI1449
106 | MJLN0_SI2079
107 | MJLN0_SI819
108 | MJLN0_SX189
109 | MJLN0_SX279
110 | MJLN0_SX369
111 | MJLN0_SX9
112 | MJLN0_SX99
113 | MJMP0_SI1535
114 | MJMP0_SI1791
115 | MJMP0_SI905
116 | MJMP0_SX185
117 | MJMP0_SX275
118 | MJMP0_SX365
119 | MJMP0_SX5
120 | MJMP0_SX95
121 | MKLT0_SI1213
122 | MKLT0_SI1843
123 | MKLT0_SI583
124 | MKLT0_SX133
125 | MKLT0_SX223
126 | MKLT0_SX313
127 | MKLT0_SX403
128 | MKLT0_SX43
129 | MLLL0_SI1363
130 | MLLL0_SI1993
131 | MLLL0_SI733
132 | MLLL0_SX103
133 | MLLL0_SX13
134 | MLLL0_SX193
135 | MLLL0_SX283
136 | MLLL0_SX373
137 | MLNT0_SI1574
138 | MLNT0_SI1902
139 | MLNT0_SI642
140 | MLNT0_SX102
141 | MLNT0_SX12
142 | MLNT0_SX192
143 | MLNT0_SX282
144 | MLNT0_SX372
145 | MNJM0_SI1580
146 | MNJM0_SI2210
147 | MNJM0_SI950
148 | MNJM0_SX140
149 | MNJM0_SX230
150 | MNJM0_SX320
151 | MNJM0_SX410
152 | MNJM0_SX50
153 | MPAM0_SI1189
154 | MPAM0_SI1819
155 | MPAM0_SI1961
156 | MPAM0_SX109
157 | MPAM0_SX19
158 | MPAM0_SX199
159 | MPAM0_SX289
160 | MPAM0_SX379
161 | MTAS1_SI1473
162 | MTAS1_SI2098
163 | MTAS1_SI838
164 | MTAS1_SX118
165 | MTAS1_SX208
166 | MTAS1_SX28
167 | MTAS1_SX298
168 | MTAS1_SX388
169 | MTLS0_SI1370
170 | MTLS0_SI2000
171 | MTLS0_SI740
172 | MTLS0_SX110
173 | MTLS0_SX20
174 | MTLS0_SX200
175 | MTLS0_SX290
176 | MTLS0_SX380
177 | MWBT0_SI1553
178 | MWBT0_SI2183
179 | MWBT0_SI923
180 | MWBT0_SX113
181 | MWBT0_SX203
182 | MWBT0_SX23
183 | MWBT0_SX293
184 | MWBT0_SX383
185 | MWEW0_SI1361
186 | MWEW0_SI1991
187 | MWEW0_SI731
188 | MWEW0_SX101
189 | MWEW0_SX11
190 | MWEW0_SX191
191 | MWEW0_SX281
192 | MWEW0_SX371
193 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/config/timit_matched/valid.uid:
--------------------------------------------------------------------------------
1 | FADG0_SI1279
2 | FADG0_SI1909
3 | FADG0_SI649
4 | FADG0_SX109
5 | FADG0_SX19
6 | FADG0_SX199
7 | FADG0_SX289
8 | FADG0_SX379
9 | FAKS0_SI1573
10 | FAKS0_SI2203
11 | FAKS0_SI943
12 | FAKS0_SX133
13 | FAKS0_SX223
14 | FAKS0_SX313
15 | FAKS0_SX403
16 | FAKS0_SX43
17 | FCAL1_SI1403
18 | FCAL1_SI2033
19 | FCAL1_SI773
20 | FCAL1_SX143
21 | FCAL1_SX233
22 | FCAL1_SX323
23 | FCAL1_SX413
24 | FCAL1_SX53
25 | FCMH0_SI1454
26 | FCMH0_SI2084
27 | FCMH0_SI824
28 | FCMH0_SX104
29 | FCMH0_SX14
30 | FCMH0_SX194
31 | FCMH0_SX284
32 | FCMH0_SX374
33 | FDAC1_SI1474
34 | FDAC1_SI2104
35 | FDAC1_SI844
36 | FDAC1_SX124
37 | FDAC1_SX214
38 | FDAC1_SX304
39 | FDAC1_SX34
40 | FDAC1_SX394
41 | FDMS0_SI1218
42 | FDMS0_SI1502
43 | FDMS0_SI1848
44 | FDMS0_SX138
45 | FDMS0_SX228
46 | FDMS0_SX318
47 | FDMS0_SX408
48 | FDMS0_SX48
49 | FDRW0_SI1283
50 | FDRW0_SI1423
51 | FDRW0_SI653
52 | FDRW0_SX113
53 | FDRW0_SX203
54 | FDRW0_SX23
55 | FDRW0_SX293
56 | FDRW0_SX383
57 | FEDW0_SI1084
58 | FEDW0_SI1653
59 | FEDW0_SI1714
60 | FEDW0_SX184
61 | FEDW0_SX274
62 | FEDW0_SX364
63 | FEDW0_SX4
64 | FEDW0_SX94
65 | FGJD0_SI1179
66 | FGJD0_SI549
67 | FGJD0_SI818
68 | FGJD0_SX189
69 | FGJD0_SX279
70 | FGJD0_SX369
71 | FGJD0_SX9
72 | FGJD0_SX99
73 | FJEM0_SI1264
74 | FJEM0_SI1894
75 | FJEM0_SI634
76 | FJEM0_SX184
77 | FJEM0_SX274
78 | FJEM0_SX364
79 | FJEM0_SX4
80 | FJEM0_SX94
81 | FJMG0_SI1181
82 | FJMG0_SI1811
83 | FJMG0_SI551
84 | FJMG0_SX101
85 | FJMG0_SX11
86 | FJMG0_SX191
87 | FJMG0_SX281
88 | FJMG0_SX371
89 | FJSJ0_SI1484
90 | FJSJ0_SI2114
91 | FJSJ0_SI854
92 | FJSJ0_SX134
93 | FJSJ0_SX224
94 | FJSJ0_SX314
95 | FJSJ0_SX404
96 | FJSJ0_SX44
97 | FKMS0_SI1490
98 | FKMS0_SI2120
99 | FKMS0_SI860
100 | FKMS0_SX140
101 | FKMS0_SX230
102 | FKMS0_SX320
103 | FKMS0_SX410
104 | FKMS0_SX50
105 | FMAH0_SI1289
106 | FMAH0_SI1919
107 | FMAH0_SI659
108 | FMAH0_SX119
109 | FMAH0_SX209
110 | FMAH0_SX29
111 | FMAH0_SX299
112 | FMAH0_SX389
113 | FMML0_SI1040
114 | FMML0_SI1670
115 | FMML0_SI2300
116 | FMML0_SX140
117 | FMML0_SX230
118 | FMML0_SX320
119 | FMML0_SX410
120 | FMML0_SX50
121 | FNMR0_SI1399
122 | FNMR0_SI2029
123 | FNMR0_SI769
124 | FNMR0_SX139
125 | FNMR0_SX229
126 | FNMR0_SX319
127 | FNMR0_SX409
128 | FNMR0_SX49
129 | FREW0_SI1030
130 | FREW0_SI1280
131 | FREW0_SI1910
132 | FREW0_SX110
133 | FREW0_SX20
134 | FREW0_SX200
135 | FREW0_SX290
136 | FREW0_SX380
137 | FSEM0_SI1198
138 | FSEM0_SI1828
139 | FSEM0_SI568
140 | FSEM0_SX118
141 | FSEM0_SX208
142 | FSEM0_SX28
143 | FSEM0_SX298
144 | FSEM0_SX388
145 | MAJC0_SI1946
146 | MAJC0_SI2095
147 | MAJC0_SI835
148 | MAJC0_SX115
149 | MAJC0_SX205
150 | MAJC0_SX25
151 | MAJC0_SX295
152 | MAJC0_SX385
153 | MBDG0_SI1463
154 | MBDG0_SI2093
155 | MBDG0_SI833
156 | MBDG0_SX113
157 | MBDG0_SX203
158 | MBDG0_SX23
159 | MBDG0_SX293
160 | MBDG0_SX383
161 | MBNS0_SI1220
162 | MBNS0_SI1850
163 | MBNS0_SI590
164 | MBNS0_SX140
165 | MBNS0_SX230
166 | MBNS0_SX320
167 | MBNS0_SX410
168 | MBNS0_SX50
169 | MBWM0_SI1304
170 | MBWM0_SI1934
171 | MBWM0_SI674
172 | MBWM0_SX134
173 | MBWM0_SX224
174 | MBWM0_SX314
175 | MBWM0_SX404
176 | MBWM0_SX44
177 | MCSH0_SI1549
178 | MCSH0_SI2179
179 | MCSH0_SI919
180 | MCSH0_SX109
181 | MCSH0_SX19
182 | MCSH0_SX199
183 | MCSH0_SX289
184 | MCSH0_SX379
185 | MDLF0_SI1583
186 | MDLF0_SI2213
187 | MDLF0_SI953
188 | MDLF0_SX143
189 | MDLF0_SX233
190 | MDLF0_SX323
191 | MDLF0_SX413
192 | MDLF0_SX53
193 | MDLS0_SI1628
194 | MDLS0_SI2258
195 | MDLS0_SI998
196 | MDLS0_SX188
197 | MDLS0_SX278
198 | MDLS0_SX368
199 | MDLS0_SX8
200 | MDLS0_SX98
201 | MDVC0_SI2174
202 | MDVC0_SI2196
203 | MDVC0_SI936
204 | MDVC0_SX126
205 | MDVC0_SX216
206 | MDVC0_SX306
207 | MDVC0_SX36
208 | MDVC0_SX396
209 | MERS0_SI1019
210 | MERS0_SI1649
211 | MERS0_SI497
212 | MERS0_SX119
213 | MERS0_SX209
214 | MERS0_SX29
215 | MERS0_SX299
216 | MERS0_SX389
217 | MGJF0_SI1901
218 | MGJF0_SI641
219 | MGJF0_SI776
220 | MGJF0_SX101
221 | MGJF0_SX11
222 | MGJF0_SX191
223 | MGJF0_SX281
224 | MGJF0_SX371
225 | MGLB0_SI1534
226 | MGLB0_SI2164
227 | MGLB0_SI904
228 | MGLB0_SX184
229 | MGLB0_SX274
230 | MGLB0_SX364
231 | MGLB0_SX4
232 | MGLB0_SX94
233 | MGWT0_SI1539
234 | MGWT0_SI2169
235 | MGWT0_SI909
236 | MGWT0_SX189
237 | MGWT0_SX279
238 | MGWT0_SX369
239 | MGWT0_SX9
240 | MGWT0_SX99
241 | MJAR0_SI1988
242 | MJAR0_SI2247
243 | MJAR0_SI728
244 | MJAR0_SX188
245 | MJAR0_SX278
246 | MJAR0_SX368
247 | MJAR0_SX8
248 | MJAR0_SX98
249 | MJFC0_SI1033
250 | MJFC0_SI1663
251 | MJFC0_SI2293
252 | MJFC0_SX133
253 | MJFC0_SX223
254 | MJFC0_SX313
255 | MJFC0_SX403
256 | MJFC0_SX43
257 | MJSW0_SI1010
258 | MJSW0_SI1640
259 | MJSW0_SI2270
260 | MJSW0_SX110
261 | MJSW0_SX20
262 | MJSW0_SX200
263 | MJSW0_SX290
264 | MJSW0_SX380
265 | MMDB1_SI1625
266 | MMDB1_SI2255
267 | MMDB1_SI995
268 | MMDB1_SX185
269 | MMDB1_SX275
270 | MMDB1_SX365
271 | MMDB1_SX5
272 | MMDB1_SX95
273 | MMDM2_SI1452
274 | MMDM2_SI1555
275 | MMDM2_SI2082
276 | MMDM2_SX102
277 | MMDM2_SX12
278 | MMDM2_SX192
279 | MMDM2_SX282
280 | MMDM2_SX372
281 | MMJR0_SI1648
282 | MMJR0_SI2166
283 | MMJR0_SI2278
284 | MMJR0_SX118
285 | MMJR0_SX208
286 | MMJR0_SX28
287 | MMJR0_SX298
288 | MMJR0_SX388
289 | MMWH0_SI1089
290 | MMWH0_SI1301
291 | MMWH0_SI459
292 | MMWH0_SX189
293 | MMWH0_SX279
294 | MMWH0_SX369
295 | MMWH0_SX9
296 | MMWH0_SX99
297 | MPDF0_SI1542
298 | MPDF0_SI2172
299 | MPDF0_SI912
300 | MPDF0_SX102
301 | MPDF0_SX12
302 | MPDF0_SX192
303 | MPDF0_SX282
304 | MPDF0_SX372
305 | MRCS0_SI1223
306 | MRCS0_SI1853
307 | MRCS0_SI593
308 | MRCS0_SX143
309 | MRCS0_SX233
310 | MRCS0_SX323
311 | MRCS0_SX413
312 | MRCS0_SX53
313 | MREB0_SI1375
314 | MREB0_SI2005
315 | MREB0_SI745
316 | MREB0_SX115
317 | MREB0_SX205
318 | MREB0_SX25
319 | MREB0_SX295
320 | MREB0_SX385
321 | MRJM4_SI1489
322 | MRJM4_SI2119
323 | MRJM4_SI859
324 | MRJM4_SX139
325 | MRJM4_SX229
326 | MRJM4_SX319
327 | MRJM4_SX409
328 | MRJM4_SX49
329 | MRJR0_SI1182
330 | MRJR0_SI1812
331 | MRJR0_SI2313
332 | MRJR0_SX102
333 | MRJR0_SX12
334 | MRJR0_SX192
335 | MRJR0_SX282
336 | MRJR0_SX372
337 | MROA0_SI1307
338 | MROA0_SI1970
339 | MROA0_SI677
340 | MROA0_SX137
341 | MROA0_SX227
342 | MROA0_SX317
343 | MROA0_SX407
344 | MROA0_SX47
345 | MRTK0_SI1093
346 | MRTK0_SI1723
347 | MRTK0_SI1750
348 | MRTK0_SX103
349 | MRTK0_SX13
350 | MRTK0_SX193
351 | MRTK0_SX283
352 | MRTK0_SX373
353 | MRWS1_SI1130
354 | MRWS1_SI1496
355 | MRWS1_SI500
356 | MRWS1_SX140
357 | MRWS1_SX230
358 | MRWS1_SX320
359 | MRWS1_SX410
360 | MRWS1_SX50
361 | MTAA0_SI1285
362 | MTAA0_SI1915
363 | MTAA0_SI596
364 | MTAA0_SX115
365 | MTAA0_SX205
366 | MTAA0_SX25
367 | MTAA0_SX295
368 | MTAA0_SX385
369 | MTDT0_SI1994
370 | MTDT0_SI2254
371 | MTDT0_SI994
372 | MTDT0_SX184
373 | MTDT0_SX274
374 | MTDT0_SX364
375 | MTDT0_SX4
376 | MTDT0_SX94
377 | MTEB0_SI1133
378 | MTEB0_SI2064
379 | MTEB0_SI503
380 | MTEB0_SX143
381 | MTEB0_SX233
382 | MTEB0_SX323
383 | MTEB0_SX413
384 | MTEB0_SX53
385 | MTHC0_SI1015
386 | MTHC0_SI1645
387 | MTHC0_SI2275
388 | MTHC0_SX115
389 | MTHC0_SX205
390 | MTHC0_SX25
391 | MTHC0_SX295
392 | MTHC0_SX385
393 | MWJG0_SI1124
394 | MWJG0_SI1754
395 | MWJG0_SI494
396 | MWJG0_SX134
397 | MWJG0_SX224
398 | MWJG0_SX314
399 | MWJG0_SX404
400 | MWJG0_SX44
401 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .extracted_features_dataset import ExtractedFeaturesDataset
7 | from .random_input_dataset import RandomInputDataset
8 |
9 |
10 | __all__ = [
11 | "ExtractedFeaturesDataset",
12 | "RandomInputDataset",
13 | ]
14 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/data/extracted_features_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | import logging
8 | import os
9 | import contextlib
10 |
11 | import numpy as np
12 | import torch
13 |
14 | from fairseq.data import FairseqDataset, data_utils
15 |
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class ExtractedFeaturesDataset(FairseqDataset):
21 | def __init__(
22 | self,
23 | path,
24 | split,
25 | min_length=3,
26 | max_length=None,
27 | labels=None,
28 | label_dict=None,
29 | shuffle=True,
30 | sort_by_length=True,
31 | ):
32 | super().__init__()
33 |
34 | self.min_length = min_length
35 | self.max_length = max_length
36 | self.shuffle = shuffle
37 | self.sort_by_length = sort_by_length
38 | self.label_dict = label_dict
39 |
40 | if labels is not None:
41 | assert label_dict is not None
42 |
43 | self.sizes = []
44 | self.offsets = []
45 | self.labels = []
46 |
47 | path = os.path.join(path, split)
48 | data_path = path
49 | self.data = np.load(data_path + ".npy", mmap_mode="r")
50 |
51 | offset = 0
52 | skipped = 0
53 |
54 | if not os.path.exists(path + f".{labels}"):
55 | labels = None
56 |
57 | with open(data_path + ".lengths", "r") as len_f, open(
58 | path + f".{labels}", "r"
59 | ) if labels is not None else contextlib.ExitStack() as lbl_f:
60 | for line in len_f:
61 | length = int(line.rstrip())
62 | lbl = None if labels is None else next(lbl_f).rstrip().split()
63 | if length >= min_length and (
64 | max_length is None or length <= max_length
65 | ):
66 | self.sizes.append(length)
67 | self.offsets.append(offset)
68 | if lbl is not None:
69 | self.labels.append(lbl)
70 | offset += length
71 |
72 | self.sizes = np.asarray(self.sizes)
73 | self.offsets = np.asarray(self.offsets)
74 |
75 | logger.info(f"loaded {len(self.offsets)}, skipped {skipped} samples")
76 |
77 | def __getitem__(self, index):
78 | offset = self.offsets[index]
79 | end = self.sizes[index] + offset
80 | feats = torch.from_numpy(self.data[offset:end].copy()).float()
81 |
82 | res = {"id": index, "features": feats}
83 | if len(self.labels) > 0:
84 | res["target"] = self.label_dict.encode_line(
85 | self.labels[index],
86 | line_tokenizer=lambda x: x,
87 | append_eos=False,
88 | )
89 |
90 | return res
91 |
92 | def __len__(self):
93 | return len(self.sizes)
94 |
95 | def collater(self, samples):
96 | if len(samples) == 0:
97 | return {}
98 |
99 | features = [s["features"] for s in samples]
100 | sizes = [len(s) for s in features]
101 |
102 | target_size = max(sizes)
103 |
104 | collated_features = features[0].new_zeros(
105 | len(features), target_size, features[0].size(-1)
106 | )
107 | padding_mask = torch.BoolTensor(collated_features.shape[:-1]).fill_(False)
108 | for i, (f, size) in enumerate(zip(features, sizes)):
109 | collated_features[i, :size] = f
110 | padding_mask[i, size:] = True
111 |
112 | res = {
113 | "id": torch.LongTensor([s["id"] for s in samples]),
114 | "net_input": {"features": collated_features, "padding_mask": padding_mask},
115 | }
116 |
117 | if len(self.labels) > 0:
118 | target = data_utils.collate_tokens(
119 | [s["target"] for s in samples],
120 | pad_idx=self.label_dict.pad(),
121 | left_pad=False,
122 | )
123 | res["target"] = target
124 | return res
125 |
126 | def num_tokens(self, index):
127 | return self.size(index)
128 |
129 | def size(self, index):
130 | return self.sizes[index]
131 |
132 | def ordered_indices(self):
133 | """Return an ordered list of indices. Batches will be constructed based
134 | on this order."""
135 | if self.shuffle:
136 | order = [np.random.permutation(len(self))]
137 | else:
138 | order = [np.arange(len(self))]
139 |
140 | if self.sort_by_length:
141 | order.append(self.sizes)
142 | return np.lexsort(order)[::-1]
143 | else:
144 | return order[0]
145 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/data/random_input_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import random
7 | from typing import List
8 |
9 | from fairseq.data import BaseWrapperDataset, data_utils
10 |
11 |
12 | class RandomInputDataset(BaseWrapperDataset):
13 | def __init__(
14 | self,
15 | dataset,
16 | random_input_dataset,
17 | input_key_path: List[str],
18 | add_to_input,
19 | pad_idx,
20 | ):
21 | super().__init__(dataset)
22 | self.random_input_dataset = random_input_dataset
23 | if isinstance(input_key_path, str):
24 | input_key_path = [input_key_path]
25 | assert len(input_key_path) > 0
26 | self.input_key_path = input_key_path
27 | self.add_to_input = add_to_input
28 | self.pad_idx = pad_idx
29 |
30 | def get_target(self, item):
31 | target_loc = item
32 | for p in self.input_key_path[:-1]:
33 | target_loc = target_loc[p]
34 | return self.input_key_path[-1], target_loc
35 |
36 | def get_target_value(self, item):
37 | k, target_loc = self.get_target(item)
38 | return target_loc[k]
39 |
40 | def __getitem__(self, index):
41 | item = self.dataset[index]
42 | k, target_loc = self.get_target(item)
43 | target_loc[k] = random.choice(self.random_input_dataset)
44 | return item
45 |
46 | def collater(self, samples):
47 | collated = self.dataset.collater(samples)
48 | if len(collated) == 0:
49 | return collated
50 | indices = set(collated["id"].tolist())
51 |
52 | random_inputs = data_utils.collate_tokens(
53 | [self.get_target_value(s) for s in samples if s["id"] in indices],
54 | pad_idx=self.pad_idx,
55 | left_pad=False,
56 | )
57 | k, target_loc = self.get_target(
58 | collated if not self.add_to_input else collated["net_input"]
59 | )
60 | target_loc[k] = random_inputs
61 |
62 | return collated
63 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/README.md:
--------------------------------------------------------------------------------
1 | # Self-Training with Kaldi HMM Models
2 | This folder contains recipes for self-training on pseudo phone transcripts and
3 | decoding into phones or words with [kaldi](https://github.com/kaldi-asr/kaldi).
4 |
5 | To start, download and install Kaldi following its instructions, and place this
6 | folder in `path/to/kaldi/egs`.
7 |
8 | ## Training
9 | Assuming the following has been prepared:
10 | - `w2v_dir`: contains features `{train,valid}.{npy,lengths}`, real transcripts `{train,valid}.${label}`, and dict `dict.${label}.txt`
11 | - `lab_dir`: contains pseudo labels `{train,valid}.txt`
12 | - `arpa_lm`: Arpa-format n-gram phone LM for decoding
13 | - `arpa_lm_bin`: Arpa-format n-gram phone LM for unsupervised model selection to be used with KenLM
14 |
15 | Set these variables in `train.sh`, as well as `out_dir`, the output directory,
16 | and then run it.
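For example, the variables at the top of `train.sh` might be filled in as follows (all paths are illustrative):

```
w2v_dir=/path/to/w2v_features     # {train,valid}.{npy,lengths}, real transcripts, dict
lab_dir=/path/to/pseudo_labels    # {train,valid}.txt
arpa_lm=/path/to/phone_lm.arpa
arpa_lm_bin=/path/to/phone_lm.bin
out_dir=./out
```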
17 |
18 | The output will be:
19 | ```
20 | ==== WER w.r.t. real transcript (select based on unsupervised metric)
21 | INFO:root:./out/exp/mono/decode_valid/scoring/14.0.0.tra.txt: score 0.9178 wer 28.71% lm_ppl 24.4500 gt_wer 25.57%
22 | INFO:root:./out/exp/tri1/decode_valid/scoring/17.1.0.tra.txt: score 0.9257 wer 26.99% lm_ppl 30.8494 gt_wer 21.90%
23 | INFO:root:./out/exp/tri2b/decode_valid/scoring/8.0.0.tra.txt: score 0.7506 wer 23.15% lm_ppl 25.5944 gt_wer 15.78%
24 | ```
25 | where `wer` is the word error rate with respect to the pseudo label, `gt_wer` to
26 | the ground truth label, `lm_ppl` the language model perplexity of HMM-predicted
27 | transcripts, and `score` is the unsupervised metric for model selection. We
28 | choose the model and the LM parameter of the one with the lowest score. In the
29 | example above, it is `tri2b`, `8.0.0`.
30 |
31 |
32 | ## Decoding into Phones
33 | In `decode_phone.sh`, set `out_dir` the same as used in `train.sh`, set
34 | `dec_exp` and `dec_lmparam` to the selected model and LM parameter (e.g.
35 | `tri2b` and `8.0.0` in the above example). `dec_script` needs to be set
36 | according to `dec_exp`: for mono/tri1/tri2b, use `decode.sh`; for tri3b, use
37 | `decode_fmllr.sh`.
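Continuing the example above, these might be set as (values are illustrative):

```
out_dir=./out                # same as in train.sh
dec_exp=tri2b                # selected model
dec_lmparam=8.0.0            # selected LM parameter
dec_script=steps/decode.sh   # steps/decode_fmllr.sh for tri3b
```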
38 |
39 | The output will be saved at `out_dir/dec_data`.
40 |
41 |
42 | ## Decoding into Words
43 | `decode_word_step1.sh` prepares WFSTs for word decoding. Besides the variables
44 | mentioned above, set
45 | - `wrd_arpa_lm`: Arpa-format n-gram word LM for decoding
46 | - `wrd_arpa_lm_bin`: Arpa-format n-gram word LM for unsupervised model selection
47 |
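For instance (paths are illustrative):

```
wrd_arpa_lm=/path/to/word_lm.arpa
wrd_arpa_lm_bin=/path/to/word_lm.bin
```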
48 | `decode_word_step1.sh` decodes the `train` and `valid` splits into words and runs
49 | unsupervised model selection using the `valid` split. The output is like:
50 | ```
51 | INFO:root:./out/exp/tri2b/decodeword_valid/scoring/17.0.0.tra.txt: score 1.8693 wer 24.97% lm_ppl 1785.5333 gt_wer 31.45%
52 | ```
53 |
54 | After determining the LM parameter (`17.0.0` in the example above), set it in
55 | `decode_word_step2.sh` and run it. The output will be saved at
56 | `out_dir/dec_data_word`.
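For example, with the value from the example above (illustrative):

```
out_dir=./out        # same as in train.sh
dec_lmparam=17.0.0
```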
57 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/cmd.sh:
--------------------------------------------------------------------------------
1 | # you can change cmd.sh depending on what type of queue you are using.
2 | # If you have no queueing system and want to run on a local machine, you
3 | # can change all instances of 'queue.pl' to 'run.pl' (but be careful and run
4 | # commands one by one: most recipes will exhaust the memory on your
5 | # machine). queue.pl works with GridEngine (qsub). slurm.pl works
6 | # with slurm. Different queues are configured differently, with different
7 | # queue names and different ways of specifying things like memory;
8 | # to account for these differences you can create and edit the file
9 | # conf/queue.conf to match your queue's configuration. Search for
10 | # conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information,
11 | # or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl.
12 |
13 | export train_cmd="run.pl --mem 2G"
14 | export decode_cmd="run.pl --mem 4G"
15 | export mkgraph_cmd="run.pl --mem 8G"
16 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/decode_phone.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # decode into phones (and prepare a new data directory for HMM outputs)
4 |
5 | . ./path.sh
6 |
7 | set -eu
8 |
9 | out_dir= # same as in train.sh
10 | dec_lmparam= # LM hyperparameters (e.g., 7.0.0)
11 | dec_exp=
12 | dec_script=
13 | dec_splits="train valid"
14 | dec_data_dir=$out_dir/dec_data # where to write HMM output
15 |
16 | data_dir=${out_dir}/data
17 |
18 | local/decode.sh --nj 40 --graph_name graph \
19 | --val_sets "$dec_splits" --decode_script $dec_script \
20 | $out_dir/exp/$dec_exp $data_dir $data_dir/lang_test
21 |
22 | if [ ! -z $dec_lmparam ]; then
23 | for x in $dec_splits; do
24 | mkdir -p $dec_data_dir/$x
25 | cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $dec_data_dir/$x/
26 |
27 | tra=$out_dir/exp/$dec_exp/decode_${x}/scoring/${dec_lmparam}.tra
28 | cat $tra | utils/int2sym.pl -f 2- $data_dir/lang/words.txt | \
29 | sed 's:<UNK>::g' | sed 's:<SIL>::g' > $dec_data_dir/${x}/text
30 | utils/fix_data_dir.sh $dec_data_dir/${x}
31 | echo "WER on ${x} is" $(compute-wer ark:$data_dir/${x}_gt/text ark:$dec_data_dir/$x/text | cut -d" " -f2-)
32 | done
33 | fi
34 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/decode_word_step1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # prepare word WFSTs, reference data, and decode
4 |
5 | set -eu
6 |
7 | w2v_dir= # same as in train.sh
8 | out_dir= # same as in train.sh
9 | lexicon= # word to phone mapping
10 | wrd_arpa_lm= # word LM
11 | wrd_arpa_lm_bin= # word LM for KenLM, used in unsupervised selection
12 |
13 | dec_exp= # what HMM stage to decode (e.g., tri3b)
14 | dec_script= # what decoding script to use (e.g., steps/decode_fmllr.sh)
15 | phn_label=phnc
16 | wrd_label=wrd
17 | dec_suffix=word
18 | dec_splits="train valid"
19 | valid_split="valid"
20 |
21 | data_dir=$out_dir/data
22 | wrd_data_dir=$out_dir/data_word
23 |
24 | lexicon_clean=$(mktemp)
25 | cat $lexicon | sort | uniq > $lexicon_clean
26 | local/prepare_lang_word.sh $w2v_dir/dict.${phn_label}.txt $data_dir $lexicon_clean && rm $lexicon_clean
27 | local/prepare_lm.sh --langdir $data_dir/lang_word --lmdir $data_dir/lang_test_word $wrd_arpa_lm $data_dir
28 |
29 | for x in $dec_splits; do
30 | x_gt=${x}_gt
31 | mkdir -p $wrd_data_dir/$x_gt
32 | cp $data_dir/$x_gt/{feats.scp,cmvn.scp,utt2spk,spk2utt} $wrd_data_dir/$x_gt/
33 | python local/copy_aligned_text.py < $w2v_dir/$x.$wrd_label > $wrd_data_dir/$x_gt/text
34 | done
35 |
36 | local/decode.sh --nj 40 --graph_name graph${dec_suffix} --decode_suffix $dec_suffix \
37 | --val_sets "$dec_splits" --decode_script $dec_script \
38 | $out_dir/exp/$dec_exp $data_dir $data_dir/lang_test_word
39 |
40 | local/unsup_select_decode_word.sh \
41 | --split $valid_split --kenlm_path $wrd_arpa_lm_bin \
42 | --ref_txt $wrd_data_dir/${valid_split}_gt/text \
43 | --psd_txt $data_dir/${valid_split}/text \
44 | --dec_name decode${dec_suffix} --graph_name graph${dec_suffix} \
45 | --phonemize_lexicon $data_dir/local/dict_word/lexicon.txt \
46 | $out_dir/exp
47 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/decode_word_step2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # prepare a new data directory of HMM word output
4 |
5 | . ./path.sh
6 |
7 | set -eu
8 |
9 | out_dir= # same as in train.sh
10 | dec_lmparam= # LM hyperparameters (e.g., 7.0.0)
11 |
12 | dec_exp=tri3b # what HMM stage to decode (e.g., tri3b)
13 | dec_suffix=word
14 | dec_splits="train valid"
15 | dec_data_dir=$out_dir/dec_data_word # where to write HMM output
16 |
17 | data_dir=$out_dir/data
18 | wrd_data_dir=$out_dir/data_word
19 |
20 | for x in $dec_splits; do
21 | mkdir -p $dec_data_dir/$x
22 | cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $dec_data_dir/$x/
23 |
24 | tra=$out_dir/exp/$dec_exp/decode${dec_suffix}_${x}/scoring/${dec_lmparam}.tra
25 | cat $tra | utils/int2sym.pl -f 2- $data_dir/lang_word/words.txt | \
26 | sed 's:<UNK>::g' | sed 's:<SIL>::g' > $dec_data_dir/$x/text
27 | utils/fix_data_dir.sh $dec_data_dir/$x
28 | echo "WER on $x is" $(compute-wer ark:$wrd_data_dir/${x}_gt/text ark:$dec_data_dir/$x/text | cut -d" " -f2-)
29 | done
30 |
31 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/copy_aligned_text.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | for idx, line in enumerate(sys.stdin):
4 | print(f"utt{idx:010d} {line}", end='')
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/decode.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -u
4 |
5 | val_sets="dev_other"
6 | graph_name=graph
7 | decode_suffix=""
8 | decode_script="steps/decode_fmllr.sh"
9 | decode_args=""
10 | nj=60
11 |
12 | . ./cmd.sh
13 | . ./path.sh
14 | . parse_options.sh
15 |
16 | set -x
17 | exp_dir=$1
18 | data_root=$2
19 | lang_test=$3
20 |
21 | graph=$exp_dir/$graph_name
22 |
23 | if [ ! -d $graph ]; then
24 | utils/mkgraph.sh $lang_test $exp_dir $graph
25 | fi
26 |
27 | for part in $val_sets; do
28 | dec_dir=$exp_dir/decode${decode_suffix}_${part}
29 | if [ ! -d $dec_dir ]; then
30 | echo "decoding $part for $exp_dir"
31 | $decode_script --nj $nj --cmd "$decode_cmd" $decode_args \
32 | $graph $data_root/$part $dec_dir &
33 | else
34 | echo "$dec_dir exists. skip"
35 | fi
36 | done
37 |
38 | wait
39 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/prepare_data_from_w2v.py:
--------------------------------------------------------------------------------
1 | import kaldi_io
2 | import numpy as np
3 | import os
4 |
5 |
6 | def get_parser():
7 | import argparse
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("w2v_dir", help="wav2vec feature and text directory")
10 | parser.add_argument("tar_root", help="output data directory in kaldi's format")
11 | parser.add_argument("split", help="name of the subset")
12 | parser.add_argument("--label", default="", help="if specified, copy labels too")
13 | return parser
14 |
15 | def main():
16 | parser = get_parser()
17 | args = parser.parse_args()
18 |
19 | tar_dir = os.path.join(args.tar_root, args.split)
20 | os.makedirs(tar_dir, exist_ok=True)
21 |
22 | lengths_path = os.path.join(args.w2v_dir, f"{args.split}.lengths")
23 | with open(lengths_path) as f:
24 | lengths = [int(line.rstrip()) for line in f]
25 | offsets = [0] + np.cumsum(lengths[:-1]).tolist()
26 | feats = np.load(
27 | os.path.join(args.w2v_dir, f"{args.split}.npy"),
28 | mmap_mode="r"
29 | )
30 | assert feats.shape[0] == sum(lengths), \
31 | f"lengths mismatch {feats.shape[0]} != {sum(lengths)}"
32 |
33 | ark_path = os.path.join(tar_dir, "feats.ark")
34 | scp_path = os.path.join(tar_dir, "feats.scp")
35 | wspec = f"ark:| copy-feats --compress=true ark:- ark,scp:{ark_path},{scp_path}"
36 | with kaldi_io.open_or_fd(wspec, "wb") as f:
37 | for idx, (offset, length) in enumerate(zip(offsets, lengths)):
38 | feat = feats[offset:offset+length]
39 | kaldi_io.write_mat(f, feat, key=f"utt{idx:010d}")
40 |
41 | u2s_path = os.path.join(tar_dir, "utt2spk")
42 | s2u_path = os.path.join(tar_dir, "spk2utt")
43 | with open(u2s_path, "w") as f_u2s, open(s2u_path, "w") as f_s2u:
44 | for idx in range(len(lengths)):
45 | f_u2s.write(f"utt{idx:010d} utt{idx:010d}\n")
46 | f_s2u.write(f"utt{idx:010d} utt{idx:010d}\n")
47 |
48 | if bool(args.label):
49 | lab_path = os.path.join(args.w2v_dir, f"{args.split}.{args.label}")
50 | txt_path = os.path.join(tar_dir, "text")
51 | with open(lab_path) as f_lab, open(txt_path, "w") as f_txt:
52 | for idx, line in enumerate(f_lab):
53 | f_txt.write(f"utt{idx:010d} {line}")
54 |
55 | if __name__ == "__main__":
56 | main()
57 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/prepare_lang.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | sil_prob=0.5
4 | num_sil_states=3
5 | num_nonsil_states=1
6 |
7 | . ./cmd.sh
8 | . ./path.sh
9 | . parse_options.sh
10 |
11 | set -eux
12 |
13 | dict=$1
14 | data_dir=$2
15 |
16 | dict_dir=$data_dir/local/dict
17 | tmplm_dir=$data_dir/local/lang_tmp
18 | lm_dir=$data_dir/lang
19 |
20 | mkdir -p $dict_dir $tmplm_dir $lm_dir
21 |
22 | # prepare dict
23 | echo "SIL" > $dict_dir/silence_phones.txt
24 | echo "SIL" > $dict_dir/optional_silence.txt
25 | awk '{print $1}' $dict > $dict_dir/nonsilence_phones.txt
26 |
27 | echo "SIL SIL" > $dict_dir/lexicon.txt
28 | echo " SIL" >> $dict_dir/lexicon.txt
29 | awk '{print $1" "$1}' $dict >> $dict_dir/lexicon.txt
30 |
31 | echo "SIL" > $dict_dir/extra_questions.txt
32 | awk '{printf $1" "} END {printf "\n"}' $dict >> $dict_dir/extra_questions.txt
33 |
34 | # prepare lang
35 | utils/prepare_lang.sh --sil-prob $sil_prob --position-dependent-phones false \
36 | --num_sil_states $num_sil_states --num_nonsil_states $num_nonsil_states \
37 | $dict_dir "" $tmplm_dir $lm_dir
38 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/prepare_lang_word.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | num_sil_states=3
4 | num_nonsil_states=1
5 |
6 | . ./cmd.sh
7 | . ./path.sh
8 | . parse_options.sh
9 |
10 | set -eux
11 |
12 | dict=$1
13 | data_dir=$2
14 | lexicon=$3
15 |
16 | dict_dir=$data_dir/local/dict_word
17 | tmplm_dir=$data_dir/local/lang_tmp_word
18 | lm_dir=$data_dir/lang_word
19 |
20 | mkdir -p $dict_dir $tmplm_dir $lm_dir
21 |
22 | # prepare dict
23 | echo "SIL" > $dict_dir/silence_phones.txt
24 | echo "SIL" > $dict_dir/optional_silence.txt
25 | awk '{print $1}' $dict > $dict_dir/nonsilence_phones.txt
26 |
27 | (echo "!SIL SIL"; echo " SIL";) | cat - $lexicon > $dict_dir/lexicon.txt
28 |
29 | echo "SIL" > $dict_dir/extra_questions.txt
30 | awk '{printf $1" "} END {printf "\n"}' $dict >> $dict_dir/extra_questions.txt
31 |
32 | # prepare lang
33 | utils/prepare_lang.sh --position-dependent-phones false \
34 | --num_sil_states $num_sil_states --num_nonsil_states $num_nonsil_states \
35 | $dict_dir "" $tmplm_dir $lm_dir
36 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/prepare_lm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | langdir=""
4 | lmdir=""
5 |
6 | . ./cmd.sh
7 | . ./path.sh
8 | . parse_options.sh
9 |
10 | arpa_lm=$1
11 | data=$2
12 |
13 | if [ -z $langdir ]; then
14 | langdir=$data/lang
15 | fi
16 | if [ -z $lmdir ]; then
17 | lmdir=$data/lang_test
18 | fi
19 |
20 | if [ ! -d $langdir ]; then
21 | echo "$langdir not found. run local/prepare_lang.sh first" && exit 1
22 | fi
23 |
24 | mkdir -p $lmdir
25 | cp -r $langdir/* $lmdir
26 |
27 | if [[ "$arpa_lm" == *.gz ]]; then
28 | gunzip -c $arpa_lm | arpa2fst --disambig-symbol=#0 --read-symbol-table=$lmdir/words.txt - $lmdir/G.fst
29 | else
30 | arpa2fst --disambig-symbol=#0 --read-symbol-table=$lmdir/words.txt $arpa_lm $lmdir/G.fst
31 | fi
32 | fstisstochastic $lmdir/G.fst
33 | utils/validate_lang.pl $lmdir || exit 1
34 |
35 | echo "done preparing lm ($lmdir)"
36 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/score.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
3 | # 2014 Guoguo Chen
4 | # Apache 2.0
5 |
6 | [ -f ./path.sh ] && . ./path.sh
7 |
8 | # begin configuration section.
9 | cmd=run.pl
10 | stage=0
11 | decode_mbr=true
12 | word_ins_penalty=0.0,0.5,1.0
13 | min_lmwt=7
14 | max_lmwt=17
15 | iter=final
16 | #end configuration section.
17 |
18 | [ -f ./path.sh ] && . ./path.sh
19 | . parse_options.sh || exit 1;
20 |
21 | if [ $# -ne 3 ]; then
22 | echo "Usage: local/score.sh [--cmd (run.pl|queue.pl...)] "
23 | echo " Options:"
24 | echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes."
25 | echo " --stage (0|1|2) # start scoring script from part-way through."
26 |   echo " --decode_mbr (true/false) # minimum Bayes risk decoding (confusion network)."
27 |   echo " --min_lmwt <int> # minimum LM-weight for lattice rescoring "
28 |   echo " --max_lmwt <int> # maximum LM-weight for lattice rescoring "
29 | exit 1;
30 | fi
31 |
32 | data=$1
33 | lang_or_graph=$2
34 | dir=$3
35 |
36 | symtab=$lang_or_graph/words.txt
37 |
38 | for f in $symtab $dir/lat.1.gz $data/text; do
39 | [ ! -f $f ] && echo "score.sh: no such file $f" && exit 1;
40 | done
41 |
42 | mkdir -p $dir/scoring/log
43 |
44 | cat $data/text | sed 's:<NOISE>::g' | sed 's:<SPOKEN_NOISE>::g' > $dir/scoring/test_filt.txt
45 |
46 | for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do
47 | $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/best_path.LMWT.$wip.log \
48 | lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \
49 | lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \
50 | lattice-best-path --word-symbol-table=$symtab \
51 | ark:- ark,t:$dir/scoring/LMWT.$wip.tra || exit 1;
52 | done
53 |
54 | # Note: the double level of quoting for the sed command
55 | for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do
56 | $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.$wip.log \
57 | cat $dir/scoring/LMWT.$wip.tra \| \
58 |     utils/int2sym.pl -f 2- $symtab \| sed 's:\<UNK\>::g' \| \
59 | compute-wer --text --mode=present \
60 | ark:$dir/scoring/test_filt.txt ark,p:- ">&" $dir/wer_LMWT_$wip || exit 1;
61 | done
62 |
63 | exit 0;
64 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/show_wer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | split="dev_other"
4 | ref_data=""
5 | get_best_wer=true
6 | dec_name="decode"
7 | graph_name="graph"
8 |
9 | . ./cmd.sh
10 | . ./path.sh
11 | . parse_options.sh
12 |
13 | exp_root=$1
14 |
15 | set -eu
16 |
17 | echo "==== WER w.r.t. pseudo transcript"
18 | for x in $exp_root/*/${dec_name}_${split}*; do grep WER $x/wer_* 2>/dev/null | utils/best_wer.sh; done
19 |
20 |
21 | if [ ! -z $ref_data ]; then
22 | echo "==== WER w.r.t. real transcript (select based on pseudo WER)"
23 | ref_txt=$ref_data/$split/text
24 | for x in $exp_root/*/${dec_name}_${split}*; do
25 | lang=$(dirname $x)/$graph_name
26 |
27 | lmwt=$(
28 | grep WER $x/wer_* 2>/dev/null | utils/best_wer.sh |
29 | sed 's/.*wer_\(.*\)$/\1/g' | sed 's/_/./g'
30 | )
31 | tra=$x/scoring/$lmwt.tra
32 |     cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' | sed 's:<SIL>::g' | \
33 | compute-wer --text --mode=present \
34 | ark:$ref_txt ark,p:- 2> /dev/null | grep WER | xargs -I{} echo {} $tra
35 | done
36 | fi
37 |
38 | if [ ! -z $ref_data ] && $get_best_wer; then
39 | echo "==== WER w.r.t. real transcript (select based on true WER)"
40 | ref_txt=$ref_data/$split/text
41 | for x in $exp_root/*/${dec_name}_${split}*; do
42 | lang=$(dirname $x)/$graph_name
43 |
44 | for tra in $x/scoring/*.tra; do
45 |       cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' | sed 's:<SIL>::g' | \
46 | compute-wer --text --mode=present \
47 | ark:$ref_txt ark,p:- 2> /dev/null | grep WER | xargs -I{} echo {} $tra
48 | done | sort -k2n | head -n1
49 | done
50 | fi
51 |
52 | exit 0;
53 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | out_root=/tmp
4 | out_name=train_${RANDOM}
5 | num_nonsil_states=1
6 |
7 | valid="dev_other"
8 | train="train"
9 | mono_size="-1" # 2000
10 | tri1_size="-1" # 5000
11 | tri2b_size="-1" # 10000
12 | tri3b_size="-1" # 10000
13 |
14 | # Acoustic model parameters
15 | numLeavesTri1=2000
16 | numGaussTri1=10000
17 | numLeavesMLLT=2500
18 | numGaussMLLT=15000
19 | numLeavesSAT=2500
20 | numGaussSAT=15000
21 |
22 | stage=1
23 | max_stage=1
24 |
25 | . ./cmd.sh
26 | . ./path.sh
27 | . parse_options.sh
28 |
29 | data=$1
30 | lang=$2
31 | lang_test=$3
32 |
33 | exp_root=$out_root/$out_name
34 |
35 | # you might not want to do this for interactive shells.
36 | set -e
37 |
38 |
39 | if [ $stage -le 1 ] && [ $max_stage -ge 1 ]; then
40 | # train a monophone system
41 | if [ ! $mono_size -eq -1 ]; then
42 | utils/subset_data_dir.sh $data/$train $mono_size $data/${train}_${mono_size}
43 | mono_train=${train}_${mono_size}
44 | else
45 | mono_train=${train}
46 | fi
47 |
48 | steps/train_mono.sh --boost-silence 1.25 --nj 20 --cmd "$train_cmd" \
49 | --initial-beam 40 --regular-beam 60 --retry-beam 120 \
50 | $data/$mono_train $lang $exp_root/mono
51 |
52 | utils/mkgraph.sh $lang_test $exp_root/mono $exp_root/mono/graph
53 | steps/decode.sh --nj 20 --cmd "$decode_cmd" \
54 | $exp_root/mono/graph $data/$valid $exp_root/mono/decode_$valid &
55 | fi
56 |
57 |
58 | if [ $stage -le 2 ] && [ $max_stage -ge 2 ]; then
59 | # train a first delta + delta-delta triphone system on a subset of 5000 utterances
60 | if [ ! $tri1_size -eq -1 ]; then
61 | utils/subset_data_dir.sh $data/$train $tri1_size $data/${train}_${tri1_size}
62 | tri1_train=${train}_${tri1_size}
63 | else
64 | tri1_train=${train}
65 | fi
66 |
67 | steps/align_si.sh --boost-silence 1.25 --nj 10 --cmd "$train_cmd" \
68 | $data/$tri1_train $lang \
69 | $exp_root/mono $exp_root/mono_ali_${tri1_train}
70 |
71 | steps_gan/train_deltas.sh --boost-silence 1.25 --cmd "$train_cmd" \
72 | --num_nonsil_states $num_nonsil_states $numLeavesTri1 $numGaussTri1 \
73 | $data/$tri1_train $lang \
74 | $exp_root/mono_ali_${tri1_train} $exp_root/tri1
75 |
76 | utils/mkgraph.sh $lang_test $exp_root/tri1 $exp_root/tri1/graph
77 | steps/decode.sh --nj 20 --cmd "$decode_cmd" \
78 | $exp_root/tri1/graph $data/$valid $exp_root/tri1/decode_$valid &
79 | fi
80 |
81 | if [ $stage -le 3 ] && [ $max_stage -ge 3 ]; then
82 | # train an LDA+MLLT system.
83 | if [ ! $tri2b_size -eq -1 ]; then
84 | utils/subset_data_dir.sh $data/$train $tri2b_size $data/${train}_${tri2b_size}
85 | tri2b_train=${train}_${tri2b_size}
86 | else
87 | tri2b_train=${train}
88 | fi
89 |
90 | steps/align_si.sh --nj 10 --cmd "$train_cmd" \
91 | $data/$tri2b_train $lang \
92 | $exp_root/tri1 $exp_root/tri1_ali_${tri2b_train}
93 |
94 | steps_gan/train_lda_mllt.sh --cmd "$train_cmd" \
95 | --num_nonsil_states $num_nonsil_states \
96 | --splice-opts "--left-context=3 --right-context=3" $numLeavesMLLT $numGaussMLLT \
97 | $data/$tri2b_train $lang \
98 | $exp_root/tri1_ali_${tri2b_train} $exp_root/tri2b
99 |
100 | utils/mkgraph.sh $lang_test $exp_root/tri2b $exp_root/tri2b/graph
101 | steps/decode.sh --nj 20 --cmd "$decode_cmd" \
102 | $exp_root/tri2b/graph $data/$valid $exp_root/tri2b/decode_$valid &
103 | fi
104 |
105 |
106 | if [ $stage -le 4 ] && [ $max_stage -ge 4 ]; then
107 | # Train tri3b, which is LDA+MLLT+SAT on 10k utts
108 | if [ ! $tri3b_size -eq -1 ]; then
109 | utils/subset_data_dir.sh $data/$train $tri3b_size $data/${train}_${tri3b_size}
110 | tri3b_train=${train}_${tri3b_size}
111 | else
112 | tri3b_train=${train}
113 | fi
114 |
115 | steps/align_si.sh --nj 10 --cmd "$train_cmd" --use-graphs true \
116 | $data/$tri3b_train $lang \
117 | $exp_root/tri2b $exp_root/tri2b_ali_${tri2b_train}
118 |
119 | steps_gan/train_sat.sh --cmd "$train_cmd" \
120 | --num_nonsil_states $num_nonsil_states $numLeavesSAT $numGaussSAT \
121 | $data/$tri3b_train $lang \
122 | $exp_root/tri2b_ali_${tri2b_train} $exp_root/tri3b
123 |
124 | utils/mkgraph.sh $lang_test $exp_root/tri3b $exp_root/tri3b/graph
125 | steps/decode_fmllr.sh --nj 20 --cmd "$decode_cmd" \
126 | $exp_root/tri3b/graph $data/$valid $exp_root/tri3b/decode_$valid &
127 | fi
128 |
129 | wait
130 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/unsup_select.py:
--------------------------------------------------------------------------------
1 | """
2 | Implement unsupervised metric for decoding hyperparameter selection:
3 | $$ alpha * LM_PPL + ViterbiUER(%) * 100 $$
4 | """
5 | import argparse
6 | import logging
7 | import math
8 | import sys
9 |
10 | import kenlm
11 | import editdistance
12 | from g2p_en import G2p
13 |
14 | logging.root.setLevel(logging.INFO)
15 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def get_parser():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("ref_tra", help="reference pseudo labels")
22 |     parser.add_argument("hyp_tra", help="decoded pseudo labels to be assessed")
23 | parser.add_argument("--kenlm_path", default="/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o5.bin", help="")
24 | parser.add_argument("--uppercase", action="store_true", help="")
25 | parser.add_argument("--skipwords", default="", help="")
26 | parser.add_argument("--gt_tra", default="", help="ground truth pseudo labels for computing oracle WER")
27 | parser.add_argument("--min_vt_uer", default=0.0, type=float)
28 | parser.add_argument("--phonemize", action="store_true", help="phonemize word hypotheses, used when reference is phone transcript")
29 | parser.add_argument("--phonemize_lexicon", default="", type=str, help="use a lexicon for phonemizing")
30 | return parser
31 |
32 | def load_tra(tra_path):
33 | with open(tra_path, "r") as f:
34 | uid_to_tra = {}
35 | for line in f:
36 | toks = line.rstrip().split()
37 | uid, tra = toks[0], " ".join(toks[1:])
38 | uid_to_tra[uid] = tra
39 | logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}")
40 | return uid_to_tra
41 |
42 | def load_lex(lex_path):
43 | with open(lex_path, "r") as f:
44 | w2p = {}
45 | for line in f:
46 | w, p = line.rstrip().split(None, 1)
47 | w2p[w] = p.split()
48 | return w2p
49 |
50 | def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p, g2p_dict):
51 | d_cnt = 0
52 | w_cnt = 0
53 | w_cnt_h = 0
54 | for uid in hyp_uid_to_tra:
55 | ref = ref_uid_to_tra[uid].split()
56 | if g2p_dict is not None:
57 | hyp = []
58 | for word in hyp_uid_to_tra[uid].split():
59 | if word in g2p_dict:
60 | hyp = hyp + g2p_dict[word]
61 | else:
62 | logger.warning(f"{word} not in g2p_dict")
63 | elif g2p is not None:
64 | hyp = g2p(hyp_uid_to_tra[uid])
65 | hyp = [p for p in hyp if p != "'" and p != " "]
66 | hyp = [p[:-1] if p[-1].isnumeric() else p for p in hyp]
67 | else:
68 | hyp = hyp_uid_to_tra[uid].split()
69 | logger.debug((
70 | f"======================\n"
71 | f"HYP: {' '.join(hyp)}\n"
72 | f"REF: {' '.join(ref)}"
73 | ))
74 | d_cnt += editdistance.eval(ref, hyp)
75 | w_cnt += len(ref)
76 | w_cnt_h += len(hyp)
77 | wer = float(d_cnt) / w_cnt
78 | logger.debug((
79 | f"wer = {wer*100:.2f}%; num. of ref words = {w_cnt}; "
80 | f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(ref_uid_to_tra)}"
81 | ))
82 | return wer
83 |
84 | def compute_lm_ppl(hyp_uid_to_tra, score_fn):
85 | lm_score = 0.
86 | w_cnt = 0
87 | for hyp in hyp_uid_to_tra.values():
88 | cur_score = score_fn(hyp)
89 |         cur_cnt = len(hyp.split()) + 1  # plus one for </s>
90 | lm_score += cur_score
91 | w_cnt += cur_cnt
92 | logger.debug((
93 | f"======================\n"
94 | f"score sum/avg = {cur_score:.2f}/{cur_score/cur_cnt:.2f}\n"
95 | f"hyp = {hyp}"
96 | ))
97 | lm_ppl = math.pow(10, -lm_score / w_cnt)
98 | logger.debug(f"lm ppl = {lm_ppl:.2f}; num. of words = {w_cnt}")
99 | return lm_ppl
100 |
101 | def main():
102 | args = get_parser().parse_args()
103 | logger.debug(f"Args: {args}")
104 |
105 | ref_uid_to_tra = load_tra(args.ref_tra)
106 | hyp_uid_to_tra = load_tra(args.hyp_tra)
107 | assert not bool(set(hyp_uid_to_tra.keys()) - set(ref_uid_to_tra.keys()))
108 |
109 | lm = kenlm.Model(args.kenlm_path)
110 | skipwords = set(args.skipwords.split(","))
111 | def compute_lm_score(s):
112 | s = " ".join(w for w in s.split() if w not in skipwords)
113 | s = s.upper() if args.uppercase else s
114 | return lm.score(s)
115 |
116 | g2p, g2p_dict = None, None
117 | if args.phonemize:
118 | if args.phonemize_lexicon:
119 | g2p_dict = load_lex(args.phonemize_lexicon)
120 | else:
121 | g2p = G2p()
122 |
123 | wer = compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p, g2p_dict)
124 | lm_ppl = compute_lm_ppl(hyp_uid_to_tra, compute_lm_score)
125 |
126 | gt_wer = -math.inf
127 | if args.gt_tra:
128 | gt_uid_to_tra = load_tra(args.gt_tra)
129 | gt_wer = compute_wer(gt_uid_to_tra, hyp_uid_to_tra, None, None)
130 |
131 | score = math.log(lm_ppl) * max(wer, args.min_vt_uer)
132 | logging.info(f"{args.hyp_tra}: score={score:.4f}; wer={wer*100:.2f}%; lm_ppl={lm_ppl:.4f}; gt_wer={gt_wer*100:.2f}%")
133 |
134 | if __name__ == "__main__":
135 | main()
136 |
--------------------------------------------------------------------------------
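The selection rule implemented above reduces to a single line, `score = math.log(lm_ppl) * max(wer, args.min_vt_uer)`, and the decoding configuration with the lowest score wins. A minimal sketch of that selection step, using made-up configuration names and `(lm_ppl, uer)` values:

```python
import math

# Hypothetical (lm_ppl, viterbi_uer) values for a few decoding configurations,
# e.g. different LM weights and word insertion penalties.
candidates = {
    "lmwt7_wip0.0": (42.1, 0.31),
    "lmwt11_wip0.5": (35.6, 0.29),
    "lmwt17_wip1.0": (55.3, 0.27),
}
min_vt_uer = 0.0  # same role as --min_vt_uer above


def selection_score(lm_ppl, uer):
    # mirrors: score = math.log(lm_ppl) * max(wer, args.min_vt_uer)
    return math.log(lm_ppl) * max(uer, min_vt_uer)


best = min(candidates, key=lambda name: selection_score(*candidates[name]))
print(best, selection_score(*candidates[best]))
```

This is the same quantity that `unsup_select_decode.sh` below sorts on (`sort -k3n | head -n1`) to pick a decoding configuration without ground-truth transcripts.
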
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | split="dev_other"
4 | ref_txt="" # ground truth transcript path
5 | psd_txt="" # pseudo transcript path
6 | get_best_wer=true
7 | dec_name="decode"
8 | graph_name="graph"
9 | kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin
10 |
11 | . ./cmd.sh
12 | . ./path.sh
13 | . parse_options.sh
14 |
15 | exp_root=$1
16 | unsup_args=""
17 | if [ $# -ge 2 ]; then
18 | unsup_args=$2
19 | fi
20 |
21 | set -eu
22 |
23 | if [ ! -z $ref_txt ] && $get_best_wer; then
24 | echo "==== WER w.r.t. real transcript (select based on unsupervised metric)"
25 | for x in $exp_root/*/${dec_name}_${split}*; do
26 | lang=$(dirname $x)/$graph_name
27 |
28 | (
29 | for tra in $x/scoring/*.tra; do
30 |         cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' | sed 's:<SIL>::g' > $tra.txt
31 | python local/unsup_select.py $psd_txt $tra.txt --kenlm_path $kenlm_path --gt_tra $ref_txt $unsup_args
32 | done 2>/dev/null | grep "score=" | sed 's/=/ /g' | sed 's/;//g' | sort -k3n | head -n1
33 | ) &
34 | done
35 | fi
36 | wait
37 |
38 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | split="dev_other"
4 | ref_txt="" # ground truth transcript path
5 | psd_txt="" # pseudo transcript path
6 | get_best_wer=true
7 | dec_name="decode"
8 | graph_name="graph"
9 | kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin
10 | phonemize_lexicon=""
11 |
12 | . ./cmd.sh
13 | . ./path.sh
14 | . parse_options.sh
15 | . /private/home/wnhsu/unsup_asr/fairseq-py-unsup/env.sh
16 |
17 | exp_root=$1
18 |
19 | set -eu
20 |
21 | if [ ! -z $ref_txt ] && $get_best_wer; then
22 | echo "==== WER w.r.t. real transcript (select based on unsupervised metric)"
23 | for x in $exp_root/*/${dec_name}_${split}*; do
24 | lang=$(dirname $x)/$graph_name
25 |
26 | for tra in $x/scoring/*.tra; do
27 |       cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:\<UNK\>::g' > $tra.txt
28 | python local/unsup_select.py $psd_txt $tra.txt \
29 | --kenlm_path $kenlm_path --gt_tra $ref_txt --phonemize \
30 | --phonemize_lexicon "$phonemize_lexicon"
31 | done | grep "score=" | sed 's/=/ /g' | sed 's/;//g' | sort -k3n | head -n1
32 | done
33 | fi
34 |
35 |
36 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/path.sh:
--------------------------------------------------------------------------------
1 | export KALDI_ROOT=`pwd`/../../..
2 | export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
3 | [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
4 | . $KALDI_ROOT/tools/config/common_path.sh
5 | export LC_ALL=C
6 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/steps:
--------------------------------------------------------------------------------
1 | ../../wsj/s5/steps
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
4 | # Apache 2.0
5 |
6 | # Begin configuration.
7 | stage=-4 #  This allows restarting partway through, when something went wrong.
8 | config=
9 | cmd=run.pl
10 | scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1"
11 | realign_iters="10 20 30";
12 | num_iters=35 # Number of iterations of training
13 | max_iter_inc=25 # Last iter to increase #Gauss on.
14 | beam=10
15 | careful=false
16 | retry_beam=40
17 | boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment
18 | power=0.25 # Exponent for number of gaussians according to occurrence counts
19 | cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves
20 | norm_vars=false # deprecated. Prefer --cmvn-opts "--norm-vars=true"
21 | # use the option --cmvn-opts "--norm-means=false"
22 | cmvn_opts=
23 | delta_opts=
24 | context_opts= # use "--context-width=5 --central-position=2" for quinphone
25 | num_nonsil_states=3
26 | # End configuration.
27 |
28 | echo "$0 $@" # Print the command line for logging
29 |
30 | [ -f path.sh ] && . ./path.sh;
31 | . parse_options.sh || exit 1;
32 |
33 | if [ $# != 6 ]; then
34 |   echo "Usage: steps/train_deltas.sh <num-leaves> <tot-gauss> <data-dir> <lang-dir> <alignment-dir> <exp-dir>"
35 | echo "e.g.: steps/train_deltas.sh 2000 10000 data/train_si84_half data/lang exp/mono_ali exp/tri1"
36 | echo "main options (for others, see top of script file)"
37 |   echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
38 |   echo " --config <config-file> # config containing options"
39 |   echo " --stage <stage> # stage to do partial re-run from."
40 | exit 1;
41 | fi
42 |
43 | numleaves=$1
44 | totgauss=$2
45 | data=$3
46 | lang=$4
47 | alidir=$5
48 | dir=$6
49 |
50 | for f in $alidir/final.mdl $alidir/ali.1.gz $data/feats.scp $lang/phones.txt; do
51 | [ ! -f $f ] && echo "train_deltas.sh: no such file $f" && exit 1;
52 | done
53 |
54 | numgauss=$numleaves
55 | incgauss=$[($totgauss-$numgauss)/$max_iter_inc] # per-iter increment for #Gauss
56 | oov=`cat $lang/oov.int` || exit 1;
57 | ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1;
58 | nj=`cat $alidir/num_jobs` || exit 1;
59 | mkdir -p $dir/log
60 | echo $nj > $dir/num_jobs
61 |
62 | utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1;
63 | cp $lang/phones.txt $dir || exit 1;
64 |
65 | sdata=$data/split$nj;
66 | split_data.sh $data $nj || exit 1;
67 |
68 |
69 | [ $(cat $alidir/cmvn_opts 2>/dev/null | wc -c) -gt 1 ] && [ -z "$cmvn_opts" ] && \
70 | echo "$0: warning: ignoring CMVN options from source directory $alidir"
71 | $norm_vars && cmvn_opts="--norm-vars=true $cmvn_opts"
72 | echo $cmvn_opts > $dir/cmvn_opts # keep track of options to CMVN.
73 | [ ! -z $delta_opts ] && echo $delta_opts > $dir/delta_opts
74 |
75 | feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |"
76 |
77 | rm $dir/.error 2>/dev/null
78 |
79 | if [ $stage -le -3 ]; then
80 | echo "$0: accumulating tree stats"
81 | $cmd JOB=1:$nj $dir/log/acc_tree.JOB.log \
82 | acc-tree-stats $context_opts \
83 | --ci-phones=$ciphonelist $alidir/final.mdl "$feats" \
84 | "ark:gunzip -c $alidir/ali.JOB.gz|" $dir/JOB.treeacc || exit 1;
85 | sum-tree-stats $dir/treeacc $dir/*.treeacc 2>$dir/log/sum_tree_acc.log || exit 1;
86 | rm $dir/*.treeacc
87 | fi
88 |
89 | if [ $stage -le -2 ]; then
90 | echo "$0: getting questions for tree-building, via clustering"
91 | # preparing questions, roots file...
92 | cluster-phones --pdf-class-list=$(($num_nonsil_states / 2)) $context_opts \
93 | $dir/treeacc $lang/phones/sets.int \
94 | $dir/questions.int 2> $dir/log/questions.log || exit 1;
95 | cat $lang/phones/extra_questions.int >> $dir/questions.int
96 | compile-questions $context_opts $lang/topo $dir/questions.int \
97 | $dir/questions.qst 2>$dir/log/compile_questions.log || exit 1;
98 |
99 | echo "$0: building the tree"
100 | $cmd $dir/log/build_tree.log \
101 | build-tree $context_opts --verbose=1 --max-leaves=$numleaves \
102 | --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \
103 | $dir/questions.qst $lang/topo $dir/tree || exit 1;
104 |
105 | $cmd $dir/log/init_model.log \
106 | gmm-init-model --write-occs=$dir/1.occs \
107 | $dir/tree $dir/treeacc $lang/topo $dir/1.mdl || exit 1;
108 | if grep 'no stats' $dir/log/init_model.log; then
109 | echo "** The warnings above about 'no stats' generally mean you have phones **"
110 | echo "** (or groups of phones) in your phone set that had no corresponding data. **"
111 | echo "** You should probably figure out whether something went wrong, **"
112 | echo "** or whether your data just doesn't happen to have examples of those **"
113 | echo "** phones. **"
114 | fi
115 |
116 | gmm-mixup --mix-up=$numgauss $dir/1.mdl $dir/1.occs $dir/1.mdl 2>$dir/log/mixup.log || exit 1;
117 | rm $dir/treeacc
118 | fi
119 |
120 | if [ $stage -le -1 ]; then
121 | # Convert the alignments.
122 | echo "$0: converting alignments from $alidir to use current tree"
123 | $cmd JOB=1:$nj $dir/log/convert.JOB.log \
124 | convert-ali $alidir/final.mdl $dir/1.mdl $dir/tree \
125 | "ark:gunzip -c $alidir/ali.JOB.gz|" "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1;
126 | fi
127 |
128 | if [ $stage -le 0 ]; then
129 | echo "$0: compiling graphs of transcripts"
130 | $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \
131 | compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $dir/1.mdl $lang/L.fst \
132 | "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $sdata/JOB/text |" \
133 | "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1;
134 | fi
135 |
136 | x=1
137 | while [ $x -lt $num_iters ]; do
138 | echo "$0: training pass $x"
139 | if [ $stage -le $x ]; then
140 | if echo $realign_iters | grep -w $x >/dev/null; then
141 | echo "$0: aligning data"
142 | mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |"
143 | $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \
144 | gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \
145 | "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \
146 | "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1;
147 | fi
148 | $cmd JOB=1:$nj $dir/log/acc.$x.JOB.log \
149 | gmm-acc-stats-ali $dir/$x.mdl "$feats" \
150 | "ark,s,cs:gunzip -c $dir/ali.JOB.gz|" $dir/$x.JOB.acc || exit 1;
151 | $cmd $dir/log/update.$x.log \
152 | gmm-est --mix-up=$numgauss --power=$power \
153 | --write-occs=$dir/$[$x+1].occs $dir/$x.mdl \
154 | "gmm-sum-accs - $dir/$x.*.acc |" $dir/$[$x+1].mdl || exit 1;
155 | rm $dir/$x.mdl $dir/$x.*.acc
156 | rm $dir/$x.occs
157 | fi
158 | [ $x -le $max_iter_inc ] && numgauss=$[$numgauss+$incgauss];
159 | x=$[$x+1];
160 | done
161 |
162 | rm $dir/final.mdl $dir/final.occs 2>/dev/null
163 | ln -s $x.mdl $dir/final.mdl
164 | ln -s $x.occs $dir/final.occs
165 |
166 | steps/diagnostic/analyze_alignments.sh --cmd "$cmd" $lang $dir
167 |
168 | # Summarize warning messages...
169 | utils/summarize_warnings.pl $dir/log
170 |
171 | steps/info/gmm_dir_info.pl $dir
172 |
173 | echo "$0: Done training system with delta+delta-delta features in $dir"
174 |
175 | exit 0
176 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
4 | #
5 | # LDA+MLLT refers to the way we transform the features after computing
6 | # the MFCCs: we splice across several frames, reduce the dimension (to 40
7 | # by default) using Linear Discriminant Analysis, and then later estimate,
8 | # over multiple iterations, a diagonalizing transform known as MLLT or STC.
9 | # See http://kaldi-asr.org/doc/transform.html for more explanation.
10 | #
11 | # Apache 2.0.
12 |
13 | # Begin configuration.
14 | cmd=run.pl
15 | config=
16 | stage=-5
17 | scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1"
18 | realign_iters="10 20 30";
19 | mllt_iters="2 4 6 12";
20 | num_iters=35 # Number of iterations of training
21 | max_iter_inc=25 # Last iter to increase #Gauss on.
22 | dim=40
23 | beam=10
24 | retry_beam=40
25 | careful=false
26 | boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment
27 | power=0.25 # Exponent for number of gaussians according to occurrence counts
28 | randprune=4.0 # This is approximately the ratio by which we will speed up the
29 | # LDA and MLLT calculations via randomized pruning.
30 | splice_opts=
31 | cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves
32 | norm_vars=false # deprecated. Prefer --cmvn-opts "--norm-vars=false"
33 | cmvn_opts=
34 | context_opts= # use "--context-width=5 --central-position=2" for quinphone.
35 | # End configuration.
36 | train_tree=true # if false, don't actually train the tree.
37 | use_lda_mat= # If supplied, use this LDA[+MLLT] matrix.
38 | num_nonsil_states=3
39 |
40 | echo "$0 $@" # Print the command line for logging
41 |
42 | [ -f path.sh ] && . ./path.sh
43 | . parse_options.sh || exit 1;
44 |
45 | if [ $# != 6 ]; then
46 |   echo "Usage: steps/train_lda_mllt.sh [options] <#leaves> <#gauss> <data> <lang> <alignments> <dir>"
47 | echo " e.g.: steps/train_lda_mllt.sh 2500 15000 data/train_si84 data/lang exp/tri1_ali_si84 exp/tri2b"
48 | echo "Main options (for others, see top of script file)"
49 |   echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
50 |   echo " --config <config-file> # config containing options"
51 |   echo " --stage <stage> # stage to do partial re-run from."
52 | exit 1;
53 | fi
54 |
55 | numleaves=$1
56 | totgauss=$2
57 | data=$3
58 | lang=$4
59 | alidir=$5
60 | dir=$6
61 |
62 | for f in $alidir/final.mdl $alidir/ali.1.gz $data/feats.scp $lang/phones.txt; do
63 | [ ! -f $f ] && echo "train_lda_mllt.sh: no such file $f" && exit 1;
64 | done
65 |
66 | numgauss=$numleaves
67 | incgauss=$[($totgauss-$numgauss)/$max_iter_inc] # per-iter #gauss increment
68 | oov=`cat $lang/oov.int` || exit 1;
69 | nj=`cat $alidir/num_jobs` || exit 1;
70 | silphonelist=`cat $lang/phones/silence.csl` || exit 1;
71 | ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1;
72 |
73 | mkdir -p $dir/log
74 |
75 | utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1;
76 | cp $lang/phones.txt $dir || exit 1;
77 |
78 | echo $nj >$dir/num_jobs
79 | echo "$splice_opts" >$dir/splice_opts # keep track of frame-splicing options
80 | # so that later stages of system building can know what they were.
81 |
82 |
83 | [ $(cat $alidir/cmvn_opts 2>/dev/null | wc -c) -gt 1 ] && [ -z "$cmvn_opts" ] && \
84 | echo "$0: warning: ignoring CMVN options from source directory $alidir"
85 | $norm_vars && cmvn_opts="--norm-vars=true $cmvn_opts"
86 | echo $cmvn_opts > $dir/cmvn_opts # keep track of options to CMVN.
87 |
88 | sdata=$data/split$nj;
89 | split_data.sh $data $nj || exit 1;
90 |
91 | splicedfeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- |"
92 | # Note: $feats gets overwritten later in the script.
93 | feats="$splicedfeats transform-feats $dir/0.mat ark:- ark:- |"
94 |
95 |
96 |
97 | if [ $stage -le -5 ]; then
98 | if [ -z "$use_lda_mat" ]; then
99 | echo "$0: Accumulating LDA statistics."
100 | rm $dir/lda.*.acc 2>/dev/null
101 | $cmd JOB=1:$nj $dir/log/lda_acc.JOB.log \
102 | ali-to-post "ark:gunzip -c $alidir/ali.JOB.gz|" ark:- \| \
103 | weight-silence-post 0.0 $silphonelist $alidir/final.mdl ark:- ark:- \| \
104 | acc-lda --rand-prune=$randprune $alidir/final.mdl "$splicedfeats" ark,s,cs:- \
105 | $dir/lda.JOB.acc || exit 1;
106 | est-lda --write-full-matrix=$dir/full.mat --dim=$dim $dir/0.mat $dir/lda.*.acc \
107 | 2>$dir/log/lda_est.log || exit 1;
108 | rm $dir/lda.*.acc
109 | else
110 | echo "$0: Using supplied LDA matrix $use_lda_mat"
111 | cp $use_lda_mat $dir/0.mat || exit 1;
112 | [ ! -z "$mllt_iters" ] && \
113 | echo "$0: Warning: using supplied LDA matrix $use_lda_mat but we will do MLLT," && \
114 | echo " which you might not want; to disable MLLT, specify --mllt-iters ''" && \
115 | sleep 5
116 | fi
117 | fi
118 |
119 | cur_lda_iter=0
120 |
121 | if [ $stage -le -4 ] && $train_tree; then
122 | echo "$0: Accumulating tree stats"
123 | $cmd JOB=1:$nj $dir/log/acc_tree.JOB.log \
124 | acc-tree-stats $context_opts \
125 | --ci-phones=$ciphonelist $alidir/final.mdl "$feats" \
126 | "ark:gunzip -c $alidir/ali.JOB.gz|" $dir/JOB.treeacc || exit 1;
127 | [ `ls $dir/*.treeacc | wc -w` -ne "$nj" ] && echo "$0: Wrong #tree-accs" && exit 1;
128 | $cmd $dir/log/sum_tree_acc.log \
129 | sum-tree-stats $dir/treeacc $dir/*.treeacc || exit 1;
130 | rm $dir/*.treeacc
131 | fi
132 |
133 |
134 | if [ $stage -le -3 ] && $train_tree; then
135 | echo "$0: Getting questions for tree clustering."
136 | # preparing questions, roots file...
137 | cluster-phones --pdf-class-list=$(($num_nonsil_states / 2)) $context_opts $dir/treeacc $lang/phones/sets.int \
138 | $dir/questions.int 2> $dir/log/questions.log || exit 1;
139 | cat $lang/phones/extra_questions.int >> $dir/questions.int
140 | compile-questions $context_opts $lang/topo $dir/questions.int \
141 | $dir/questions.qst 2>$dir/log/compile_questions.log || exit 1;
142 |
143 | echo "$0: Building the tree"
144 | $cmd $dir/log/build_tree.log \
145 | build-tree $context_opts --verbose=1 --max-leaves=$numleaves \
146 | --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \
147 | $dir/questions.qst $lang/topo $dir/tree || exit 1;
148 | fi
149 |
150 | if [ $stage -le -2 ]; then
151 | echo "$0: Initializing the model"
152 | if $train_tree; then
153 | gmm-init-model --write-occs=$dir/1.occs \
154 | $dir/tree $dir/treeacc $lang/topo $dir/1.mdl 2> $dir/log/init_model.log || exit 1;
155 | grep 'no stats' $dir/log/init_model.log && echo "This is a bad warning.";
156 | rm $dir/treeacc
157 | else
158 | cp $alidir/tree $dir/ || exit 1;
159 | $cmd JOB=1 $dir/log/init_model.log \
160 | gmm-init-model-flat $dir/tree $lang/topo $dir/1.mdl \
161 | "$feats subset-feats ark:- ark:-|" || exit 1;
162 | fi
163 | fi
164 |
165 |
166 | if [ $stage -le -1 ]; then
167 | # Convert the alignments.
168 | echo "$0: Converting alignments from $alidir to use current tree"
169 | $cmd JOB=1:$nj $dir/log/convert.JOB.log \
170 | convert-ali $alidir/final.mdl $dir/1.mdl $dir/tree \
171 | "ark:gunzip -c $alidir/ali.JOB.gz|" "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1;
172 | fi
173 |
174 | if [ $stage -le 0 ] && [ "$realign_iters" != "" ]; then
175 | echo "$0: Compiling graphs of transcripts"
176 | $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \
177 | compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $dir/1.mdl $lang/L.fst \
178 | "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $data/split$nj/JOB/text |" \
179 | "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1;
180 | fi
181 |
182 |
183 | x=1
184 | while [ $x -lt $num_iters ]; do
185 | echo Training pass $x
186 | if echo $realign_iters | grep -w $x >/dev/null && [ $stage -le $x ]; then
187 | echo Aligning data
188 | mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |"
189 | $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \
190 | gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \
191 | "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \
192 | "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1;
193 | fi
194 | if echo $mllt_iters | grep -w $x >/dev/null; then
195 | if [ $stage -le $x ]; then
196 | echo "$0: Estimating MLLT"
197 | $cmd JOB=1:$nj $dir/log/macc.$x.JOB.log \
198 | ali-to-post "ark:gunzip -c $dir/ali.JOB.gz|" ark:- \| \
199 | weight-silence-post 0.0 $silphonelist $dir/$x.mdl ark:- ark:- \| \
200 | gmm-acc-mllt --rand-prune=$randprune $dir/$x.mdl "$feats" ark:- $dir/$x.JOB.macc \
201 | || exit 1;
202 | est-mllt $dir/$x.mat.new $dir/$x.*.macc 2> $dir/log/mupdate.$x.log || exit 1;
203 | gmm-transform-means $dir/$x.mat.new $dir/$x.mdl $dir/$x.mdl \
204 | 2> $dir/log/transform_means.$x.log || exit 1;
205 | compose-transforms --print-args=false $dir/$x.mat.new $dir/$cur_lda_iter.mat $dir/$x.mat || exit 1;
206 | rm $dir/$x.*.macc
207 | fi
208 | feats="$splicedfeats transform-feats $dir/$x.mat ark:- ark:- |"
209 | cur_lda_iter=$x
210 | fi
211 |
212 | if [ $stage -le $x ]; then
213 | $cmd JOB=1:$nj $dir/log/acc.$x.JOB.log \
214 | gmm-acc-stats-ali $dir/$x.mdl "$feats" \
215 | "ark,s,cs:gunzip -c $dir/ali.JOB.gz|" $dir/$x.JOB.acc || exit 1;
216 | $cmd $dir/log/update.$x.log \
217 | gmm-est --write-occs=$dir/$[$x+1].occs --mix-up=$numgauss --power=$power \
218 | $dir/$x.mdl "gmm-sum-accs - $dir/$x.*.acc |" $dir/$[$x+1].mdl || exit 1;
219 | rm $dir/$x.mdl $dir/$x.*.acc $dir/$x.occs
220 | fi
221 | [ $x -le $max_iter_inc ] && numgauss=$[$numgauss+$incgauss];
222 | x=$[$x+1];
223 | done
224 |
225 | rm $dir/final.{mdl,mat,occs} 2>/dev/null
226 | ln -s $x.mdl $dir/final.mdl
227 | ln -s $x.occs $dir/final.occs
228 | ln -s $cur_lda_iter.mat $dir/final.mat
229 |
230 | steps/diagnostic/analyze_alignments.sh --cmd "$cmd" $lang $dir
231 |
232 | # Summarize warning messages...
233 | utils/summarize_warnings.pl $dir/log
234 |
235 | steps/info/gmm_dir_info.pl $dir
236 |
237 | echo "$0: Done training system with LDA+MLLT features in $dir"
238 |
239 | exit 0
240 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -eu
4 |
5 | w2v_dir= # contains features `{train,valid}.{npy,lengths}`, real transcripts `{train,valid}.${label}`, and dict `dict.${label}.txt`
6 | lab_dir= # contains pseudo labels `{train,valid}.txt`
7 | out_dir= # output root
8 | arpa_lm= # phone LM
9 | arpa_lm_bin= # (binary) phone LM for KenLM, used in unsupervised selection
10 |
11 | label=phnc
12 | train_name="train"
13 | valid_name="valid"
14 | data_dir=${out_dir}/data
15 |
16 | mkdir -p ${out_dir}/exp
17 | local/prepare_lang.sh $w2v_dir/dict.${label}.txt $data_dir
18 | local/prepare_lm.sh $arpa_lm $data_dir
19 |
20 | for x in $train_name $valid_name; do
21 | x_gt=${x}_gt
22 |
23 | # prepare pseudo data
24 | python local/prepare_data_from_w2v.py $w2v_dir $data_dir $x
25 | steps/compute_cmvn_stats.sh $data_dir/$x $out_dir/exp/make_feat/$x $out_dir/feats/$x
26 | python local/copy_aligned_text.py < $lab_dir/$x.txt > $data_dir/$x/text
27 |
28 | # prepare ground truth data
29 | mkdir $data_dir/$x_gt
30 | cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $data_dir/$x_gt/
31 | python local/copy_aligned_text.py < $w2v_dir/$x.$label > $data_dir/$x_gt/text
32 | done
33 |
34 | local/train_subset_lgbeam.sh \
35 | --out_root ${out_dir} --out_name exp --train $train_name --valid $valid_name \
36 | --mono_size 2000 --tri1_size 5000 --tri2b_size -1 --tri3b_size -1 \
37 | --stage 1 --max_stage 3 $data_dir $data_dir/lang $data_dir/lang_test
38 |
39 | local/unsup_select_decode.sh \
40 | --split $valid_name --kenlm_path $arpa_lm_bin \
41 | --ref_txt $data_dir/${valid_name}_gt/text \
42 | --psd_txt $data_dir/${valid_name}/text \
43 | $out_dir/exp
44 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/kaldi_self_train/st/utils:
--------------------------------------------------------------------------------
1 | ../../wsj/s5/utils
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .wav2vec_u import Wav2vec_U
7 |
8 |
9 | __all__ = [
10 | "Wav2vec_U",
11 | ]
12 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/apply_pca.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import os
9 | import os.path as osp
10 | import math
11 | import numpy as np
12 | import tqdm
13 | import torch
14 | from shutil import copyfile
15 |
16 | from npy_append_array import NpyAppendArray
17 |
18 |
19 | def get_parser():
20 | parser = argparse.ArgumentParser(
21 |         description="transforms features via a given PCA and stores them in the target dir"
22 | )
23 | # fmt: off
24 | parser.add_argument('source', help='directory with features')
25 | parser.add_argument('--split', help='which split to read', required=True)
26 | parser.add_argument('--save-dir', help='where to save the output', required=True)
27 | parser.add_argument('--pca-path', type=str, help='pca location. will append _A.npy and _b.npy', required=True)
28 | parser.add_argument('--batch-size', type=int, default=2048000, help='batch size')
29 | parser.add_argument('--unfiltered', action='store_true', help='process the unfiltered version')
30 | # fmt: on
31 |
32 | return parser
33 |
34 |
35 | def main():
36 | parser = get_parser()
37 | args = parser.parse_args()
38 |
39 | source_path = osp.join(args.source, args.split)
40 | data_poth = source_path + "_unfiltered" if args.unfiltered else source_path
41 |
42 | print(f"data path: {data_poth}")
43 |
44 | features = np.load(data_poth + ".npy", mmap_mode="r")
45 | pca_A = torch.from_numpy(np.load(args.pca_path + "_A.npy")).cuda()
46 | pca_b = torch.from_numpy(np.load(args.pca_path + "_b.npy")).cuda()
47 |
48 | os.makedirs(args.save_dir, exist_ok=True)
49 | save_path = osp.join(args.save_dir, args.split)
50 |
51 | copyfile(source_path + ".tsv", save_path + ".tsv")
52 | copyfile(data_poth + ".lengths", save_path + ".lengths")
53 |
54 | if osp.exists(source_path + ".phn"):
55 | copyfile(source_path + ".phn", save_path + ".phn")
56 |
57 | if osp.exists(source_path + ".wrd"):
58 | copyfile(source_path + ".wrd", save_path + ".wrd")
59 |
60 | if osp.exists(save_path + ".npy"):
61 | os.remove(save_path + ".npy")
62 | npaa = NpyAppendArray(save_path + ".npy")
63 |
64 | batches = math.ceil(features.shape[0] / args.batch_size)
65 |
66 | with torch.no_grad():
67 | for b in tqdm.trange(batches):
68 | start = b * args.batch_size
69 | end = start + args.batch_size
70 | x = torch.from_numpy(features[start:end]).cuda()
71 | x = torch.matmul(x, pca_A) + pca_b
72 | npaa.append(x.cpu().numpy())
73 |
74 |
75 | if __name__ == "__main__":
76 | main()
77 |
--------------------------------------------------------------------------------
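The batched loop above applies a plain affine projection to the features. A toy numpy sketch of the same operation with made-up shapes (the real `_A.npy` and `_b.npy` files come from `pca.py` further below):

```python
import numpy as np

features = np.random.rand(10, 1024).astype(np.float32)  # toy feature matrix
pca_A = np.random.rand(1024, 512).astype(np.float32)    # stands in for the *_A.npy file
pca_b = np.random.rand(512).astype(np.float32)          # stands in for the *_b.npy file

# equivalent of: x = torch.matmul(x, pca_A) + pca_b
projected = features @ pca_A + pca_b
print(projected.shape)  # (10, 512)
```
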
/wav2vec 2.0/unsupervised/scripts/copy_labels.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 |
9 | for idx, line in enumerate(sys.stdin):
10 | print(f"utt{idx:010d} {line}", end="")
11 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/filter_lexicon.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import sys
9 |
10 | from fairseq.data import Dictionary
11 |
12 |
13 | def get_parser():
14 | parser = argparse.ArgumentParser(
15 | description="filters a lexicon given a unit dictionary"
16 | )
17 | parser.add_argument("-d", "--unit-dict", help="unit dictionary", required=True)
18 | return parser
19 |
20 |
21 | def main():
22 | parser = get_parser()
23 | args = parser.parse_args()
24 |
25 | d = Dictionary.load(args.unit_dict)
26 | symbols = set(d.symbols)
27 |
28 | for line in sys.stdin:
29 | items = line.rstrip().split()
30 | skip = len(items) < 2
31 | for x in items[1:]:
32 | if x not in symbols:
33 | skip = True
34 | break
35 | if not skip:
36 | print(line, end="")
37 |
38 |
39 | if __name__ == "__main__":
40 | main()
41 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/filter_tsv.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import argparse
9 | import sys
10 |
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--tsv", required=True, type=str)
14 | parser.add_argument("--no-skip", action="store_true")
15 | parser.add_argument("--keep", action="store_true")
16 | params = parser.parse_args()
17 |
18 |
19 | def get_fname(line):
20 | p = os.path.basename(line.split("\t")[0])
21 | p = os.path.splitext(p)[0]
22 | return p
23 |
24 |
25 | # filenames to exclude
26 | seen = set()
27 | with open(params.tsv) as f:
28 | if not params.no_skip:
29 | root = next(f).rstrip()
30 | for line in f:
31 | seen.add(get_fname(line))
32 |
33 | for i, line in enumerate(sys.stdin):
34 | exists = get_fname(line) in seen
35 | keep = (exists and params.keep) or (not exists and not params.keep)
36 | if i == 0 or keep:
37 | print(line, end="")
38 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/g2p_wrd_to_phn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import sys
9 |
10 | from g2p_en import G2p
11 |
12 |
13 | def main():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument(
16 | "--compact",
17 | action="store_true",
18 | help="if set, compacts phones",
19 | )
20 | args = parser.parse_args()
21 |
22 | compact = args.compact
23 |
24 | wrd_to_phn = {}
25 | g2p = G2p()
26 | for line in sys.stdin:
27 | words = line.strip().split()
28 | phones = []
29 | for w in words:
30 | if w not in wrd_to_phn:
31 | wrd_to_phn[w] = g2p(w)
32 | if compact:
33 | wrd_to_phn[w] = [
34 | p[:-1] if p[-1].isnumeric() else p for p in wrd_to_phn[w]
35 | ]
36 | phones.extend(wrd_to_phn[w])
37 | try:
38 | print(" ".join(phones))
39 | except:
40 | print(wrd_to_phn, words, phones, file=sys.stderr)
41 | raise
42 |
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
--------------------------------------------------------------------------------
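The `--compact` flag above only strips the numeric stress markers that g2p_en attaches to vowel phones. A tiny sketch of just that post-processing step, using a hypothetical phone sequence so g2p_en itself is not required:

```python
phones = ["DH", "AH0", "K", "AE1", "T"]  # hypothetical g2p_en output for "the cat"
compact = [p[:-1] if p[-1].isnumeric() else p for p in phones]
print(" ".join(compact))  # DH AH K AE T
```
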
/wav2vec 2.0/unsupervised/scripts/ltr_to_wrd.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 |
9 |
10 | def main():
11 | for line in sys.stdin:
12 | print(line.replace(" ", "").replace("|", " ").strip())
13 |
14 |
15 | if __name__ == "__main__":
16 | main()
17 |
--------------------------------------------------------------------------------
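For reference, the one-liner above undoes the letter-level (`.ltr`) transcript format, in which letters are space-separated and `|` marks word boundaries; a made-up input:

```python
line = "H E L L O | W O R L D |"  # hypothetical .ltr-style transcript
print(line.replace(" ", "").replace("|", " ").strip())  # HELLO WORLD
```
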
/wav2vec 2.0/unsupervised/scripts/mean_pool.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import os
9 | import os.path as osp
10 | import math
11 | import numpy as np
12 | import tqdm
13 | import torch
14 | import torch.nn.functional as F
15 | from shutil import copyfile
16 |
17 | from npy_append_array import NpyAppendArray
18 |
19 |
20 | def get_parser():
21 | parser = argparse.ArgumentParser(
22 | description="mean pools representations by compressing uniform splits of the data"
23 | )
24 | # fmt: off
25 | parser.add_argument('source', help='directory with features')
26 | parser.add_argument('--split', help='which split to read', required=True)
27 | parser.add_argument('--save-dir', help='where to save the output', required=True)
28 |     parser.add_argument('--subsample-rate', type=float, default=0.5, help='rate to subsample data to (0.5 keeps roughly half the frames)')
29 |
30 |     parser.add_argument('--remove-extra', action='store_true', help='if true, removes extra frames that cannot be pooled, otherwise pads with copies of the last frame')
31 | # fmt: on
32 |
33 | return parser
34 |
35 |
36 | def main():
37 | parser = get_parser()
38 | args = parser.parse_args()
39 |
40 | source_path = osp.join(args.source, args.split)
41 |
42 | print(f"data path: {source_path}")
43 |
44 | features = np.load(source_path + ".npy", mmap_mode="r")
45 |
46 | os.makedirs(args.save_dir, exist_ok=True)
47 | save_path = osp.join(args.save_dir, args.split)
48 |
49 | copyfile(source_path + ".tsv", save_path + ".tsv")
50 |
51 | if os.path.exists(source_path + ".phn"):
52 | copyfile(source_path + ".phn", save_path + ".phn")
53 | if os.path.exists(source_path + ".wrd"):
54 | copyfile(source_path + ".wrd", save_path + ".wrd")
55 |
56 | if os.path.exists(osp.join(args.source, "dict.phn.txt")):
57 | copyfile(
58 | osp.join(args.source, "dict.phn.txt"),
59 | osp.join(args.save_dir, "dict.phn.txt"),
60 | )
61 |
62 | if osp.exists(save_path + ".npy"):
63 | os.remove(save_path + ".npy")
64 | npaa = NpyAppendArray(save_path + ".npy")
65 |
66 | with open(source_path + ".lengths", "r") as lf:
67 | lengths = lf.readlines()
68 |
69 | fsz = features.shape[-1]
70 | start = 0
71 | with torch.no_grad():
72 | with open(save_path + ".lengths", "w") as lengths_out:
73 | for length in tqdm.tqdm(lengths):
74 | length = int(length)
75 | end = start + length
76 | feats = features[start:end]
77 | start += length
78 | x = torch.from_numpy(feats).cuda()
79 | target_num = math.ceil(length * args.subsample_rate)
80 | rem = length % target_num
81 |
82 | if rem > 0:
83 | if args.remove_extra:
84 | to_rem = target_num - rem
85 | target_num -= 1
86 | x = x[:-to_rem]
87 | else:
88 | to_add = target_num - rem
89 | x = F.pad(x, [0, 0, 0, to_add])
90 | x[-to_add:] = x[-to_add - 1]
91 |
92 | x = x.view(target_num, -1, fsz)
93 | x = x.mean(dim=-2)
94 | print(target_num, file=lengths_out)
95 | npaa.append(x.cpu().numpy())
96 |
97 |
98 | if __name__ == "__main__":
99 | main()
100 |
--------------------------------------------------------------------------------
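The pooling arithmetic in the loop above is easier to follow on a toy tensor. A sketch with made-up shapes: with `--subsample-rate 0.5`, a 7-frame utterance gives `target_num = ceil(7 * 0.5) = 4`; since `7 % 4 = 3`, one padding frame (a copy of the last real frame) is appended unless `--remove-extra` is set, and every pair of frames is then averaged:

```python
import math

import torch
import torch.nn.functional as F

fsz, length, subsample_rate = 3, 7, 0.5          # toy feature dim, frame count, rate
x = torch.arange(length * fsz, dtype=torch.float32).view(length, fsz)

target_num = math.ceil(length * subsample_rate)  # 4 pooled frames
rem = length % target_num                        # 3 leftover frames

if rem > 0:                                      # same padding branch as above
    to_add = target_num - rem                    # 1 frame of padding
    x = F.pad(x, [0, 0, 0, to_add])
    x[-to_add:] = x[-to_add - 1]                 # repeat the last real frame

pooled = x.view(target_num, -1, fsz).mean(dim=-2)
print(pooled.shape)                              # torch.Size([4, 3])
```
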
/wav2vec 2.0/unsupervised/scripts/merge_clusters.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import os
9 | import os.path as osp
10 | import numpy as np
11 | import tqdm
12 | import torch
13 | import random
14 | from shutil import copyfile
15 |
16 | from npy_append_array import NpyAppendArray
17 |
18 |
19 | def get_parser():
20 | parser = argparse.ArgumentParser(
21 |         description="merges consecutive frames belonging to the same cluster and stores the pooled features in the target dir"
22 | )
23 | # fmt: off
24 | parser.add_argument('source', help='directory with features')
25 | parser.add_argument('--split', help='which split to read', required=True)
26 | parser.add_argument('--save-dir', help='where to save the output', required=True)
27 | parser.add_argument('--cluster-dir', help='where the clusters are')
28 | parser.add_argument('--pooling', type=str, default='mean', choices=['mean', 'sample'], help='how to pool')
29 | # fmt: on
30 |
31 | return parser
32 |
33 |
34 | def main():
35 | parser = get_parser()
36 | args = parser.parse_args()
37 |
38 | source_path = osp.join(args.source, args.split)
39 | cluster_path = osp.join(args.cluster_dir, args.split + ".src")
40 | print(f"data path: {source_path}")
41 |
42 | features = np.load(source_path + ".npy", mmap_mode="r")
43 | sizes = []
44 | offsets = []
45 | offset = 0
46 | with open(source_path + ".lengths", "r") as len_f:
47 | for line in len_f:
48 | length = int(line.rstrip())
49 | sizes.append(length)
50 | offsets.append(offset)
51 | offset += length
52 |
53 | clusters = []
54 | with open(cluster_path, "r") as cf:
55 | for line in cf:
56 | line = line.rstrip()
57 | items = line.split()
58 | items = list(map(int, items))
59 | clusters.append(items)
60 |
61 | os.makedirs(args.save_dir, exist_ok=True)
62 | save_path = osp.join(args.save_dir, args.split)
63 |
64 | copyfile(source_path + ".tsv", save_path + ".tsv")
65 |
66 | if os.path.exists(source_path + ".phn"):
67 | copyfile(source_path + ".phn", save_path + ".phn")
68 | if os.path.exists(osp.join(args.source, "dict.phn.txt")):
69 | copyfile(
70 | osp.join(args.source, "dict.phn.txt"),
71 | osp.join(args.save_dir, "dict.phn.txt"),
72 | )
73 | if os.path.exists(source_path + ".wrd"):
74 | copyfile(source_path + ".wrd", save_path + ".wrd")
75 |
76 | if osp.exists(save_path + ".npy"):
77 | os.remove(save_path + ".npy")
78 | npaa = NpyAppendArray(save_path + ".npy")
79 |
80 | def merge(feats, clust):
81 | feats = torch.from_numpy(feats.copy())
82 | clust = torch.LongTensor(clust)
83 | _, counts = clust.unique_consecutive(return_counts=True)
84 | curr = 0
85 |
86 | merged = []
87 | for c in counts:
88 | c = c.item()
89 | start = curr
90 | end = curr + c
91 | curr += c
92 | if args.pooling == "mean":
93 | new_x = feats[start:end].mean(dim=0)
94 | elif args.pooling == "sample":
95 | new_x = feats[start + int(random.random() * c)]
96 | else:
97 | raise NotImplementedError()
98 | merged.append(new_x)
99 |
100 | return torch.stack(merged, dim=0).numpy()
101 |
102 | with open(save_path + ".lengths", "w") as l_f:
103 | for size, offset, clust in tqdm.tqdm(
104 | zip(sizes, offsets, clusters), total=len(sizes)
105 | ):
106 | end = size + offset
107 | feats = features[offset:end]
108 | feats = merge(feats, clust)
109 | print(len(feats), file=l_f)
110 | npaa.append(feats)
111 |
112 |
113 | if __name__ == "__main__":
114 | main()
115 |
--------------------------------------------------------------------------------
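A toy illustration of the `merge` helper above with `--pooling mean`, using a made-up cluster sequence: runs of identical cluster ids are collapsed into a single mean-pooled vector each:

```python
import torch

feats = torch.arange(12, dtype=torch.float32).view(6, 2)  # 6 frames, 2 dims (toy)
clust = torch.LongTensor([5, 5, 3, 3, 3, 9])              # made-up cluster ids per frame

_, counts = clust.unique_consecutive(return_counts=True)  # run lengths: [2, 3, 1]

merged, curr = [], 0
for c in counts.tolist():
    merged.append(feats[curr:curr + c].mean(dim=0))       # mean-pool each run
    curr += c

print(torch.stack(merged, dim=0))  # 3 merged frames instead of 6
```
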
/wav2vec 2.0/unsupervised/scripts/normalize_and_filter_text.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import fasttext as ft
9 | import os
10 | import regex
11 | import sys
12 |
13 |
14 | def get_parser():
15 | parser = argparse.ArgumentParser(
16 | description="reads text from stdin and outputs normalized, lid-filtered version to stdout"
17 | )
18 | parser.add_argument(
19 | "--fasttext-model",
20 | help="path to fasttext model",
21 | default="lid.187.bin",
22 | )
23 | parser.add_argument("--lang", help="language id", required=True)
24 | parser.add_argument(
25 | "--lid-threshold",
26 | type=float,
27 | help="threshold for this lang id probability",
28 | default=0.4,
29 | )
30 |
31 | return parser
32 |
33 |
34 | def main():
35 | parser = get_parser()
36 | args = parser.parse_args()
37 | filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]")
38 |
39 | lg = args.lang.lower()
40 | lg_label = f"__label__{lg}"
41 | thresh = args.lid_threshold
42 |
43 | if os.path.exists(args.fasttext_model):
44 | model = ft.load_model(args.fasttext_model)
45 | else:
46 | print(
47 | f"fasttext language id model {args.fasttext_model} not found. Proceeding without language filtering. "
48 | f"To enable language filtering, please download the latest language id model "
49 | f"from https://fasttext.cc/docs/en/language-identification.html",
50 | file=sys.stderr,
51 | )
52 | model = None
53 |
54 | for line in sys.stdin:
55 | line = line.strip()
56 | line = filter_r.sub(" ", line)
57 | line = " ".join(line.split())
58 |
59 | if model is not None:
60 | lid, prob = model.predict(line, k=100)
61 | try:
62 | target_idx = lid.index(lg_label)
63 | except ValueError:
64 | continue
65 | if target_idx == 0 or prob[target_idx] >= thresh:
66 | print(line)
67 | else:
68 | print(line)
69 |
70 |
71 | if __name__ == "__main__":
72 | main()
73 |
--------------------------------------------------------------------------------
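As a small sanity check of the normalization above: the regex keeps letters, digits, combining marks, apostrophes, spaces and hyphens, replaces everything else with a space, and the whitespace is then collapsed. The sample lines are made up:

```python
import regex  # third-party 'regex' module, as used above

filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]")

for line in ["Héllo, wörld!! 123", "c'est   l'été (vraiment)"]:
    line = filter_r.sub(" ", line)
    print(" ".join(line.split()))
# Héllo wörld 123
# c'est l'été vraiment
```
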
/wav2vec 2.0/unsupervised/scripts/normalize_text.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import regex
8 | import sys
9 |
10 |
11 | def main():
12 | filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]")
13 |
14 | for line in sys.stdin:
15 | line = line.strip()
16 | line = filter_r.sub(" ", line)
17 | line = " ".join(line.split())
18 | print(line)
19 |
20 |
21 | if __name__ == "__main__":
22 | main()
23 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/pca.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import os
9 | import os.path as osp
10 | import numpy as np
11 |
12 | import faiss
13 |
14 |
15 |
16 | def get_parser():
17 | parser = argparse.ArgumentParser(
18 | description="compute a pca matrix given an array of numpy features"
19 | )
20 | # fmt: off
21 | parser.add_argument('data', help='numpy file containing features')
22 | parser.add_argument('--output', help='where to save the pca matrix', required=True)
23 | parser.add_argument('--dim', type=int, help='dim for pca reduction', required=True)
24 | parser.add_argument('--eigen-power', type=float, default=0, help='eigen power, -0.5 for whitening')
25 |
26 | return parser
27 |
28 |
29 | def main():
30 | parser = get_parser()
31 | args = parser.parse_args()
32 |
33 | print("Reading features")
34 | x = np.load(args.data, mmap_mode="r")
35 |
36 | print("Computing PCA")
37 | pca = faiss.PCAMatrix(x.shape[-1], args.dim, args.eigen_power)
38 | pca.train(x)
39 | b = faiss.vector_to_array(pca.b)
40 | A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in)
41 |
42 | os.makedirs(args.output, exist_ok=True)
43 |
44 | prefix = str(args.dim)
45 | if args.eigen_power != 0:
46 | prefix += f"_{args.eigen_power}"
47 |
48 | np.save(osp.join(args.output, f"{prefix}_pca_A"), A.T)
49 | np.save(osp.join(args.output, f"{prefix}_pca_b"), b)
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/phonemize_with_sil.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import numpy as np
9 | import sys
10 |
11 |
12 | def get_parser():
13 | parser = argparse.ArgumentParser(
14 |         description="converts words to phones, optionally inserting silences between and around words"
15 | )
16 | parser.add_argument(
17 | "--sil-prob",
18 | "-s",
19 | type=float,
20 | default=0,
21 | help="probability of inserting silence between each word",
22 | )
23 | parser.add_argument(
24 | "--surround",
25 | action="store_true",
26 | help="if set, surrounds each example with silence",
27 | )
28 | parser.add_argument(
29 | "--lexicon",
30 | help="lexicon to convert to phones",
31 | required=True,
32 | )
33 |
34 | return parser
35 |
36 |
37 | def main():
38 | parser = get_parser()
39 | args = parser.parse_args()
40 |
41 | sil_prob = args.sil_prob
42 | surround = args.surround
43 |     sil = "<SIL>"
44 |
45 | wrd_to_phn = {}
46 |
47 | with open(args.lexicon, "r") as lf:
48 | for line in lf:
49 | items = line.rstrip().split()
50 | assert len(items) > 1, line
51 | assert items[0] not in wrd_to_phn, items
52 | wrd_to_phn[items[0]] = items[1:]
53 |
54 | for line in sys.stdin:
55 | words = line.strip().split()
56 |
57 | if not all(w in wrd_to_phn for w in words):
58 | continue
59 |
60 | phones = []
61 | if surround:
62 | phones.append(sil)
63 |
64 | sample_sil_probs = None
65 | if sil_prob > 0 and len(words) > 1:
66 | sample_sil_probs = np.random.random(len(words) - 1)
67 |
68 | for i, w in enumerate(words):
69 | phones.extend(wrd_to_phn[w])
70 | if (
71 | sample_sil_probs is not None
72 | and i < len(sample_sil_probs)
73 | and sample_sil_probs[i] < sil_prob
74 | ):
75 | phones.append(sil)
76 |
77 | if surround:
78 | phones.append(sil)
79 | print(" ".join(phones))
80 |
81 |
82 | if __name__ == "__main__":
83 | main()
84 |
--------------------------------------------------------------------------------
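A toy illustration of the conversion performed by phonemize_with_sil.py, with hypothetical lexicon entries: with `-s 0` no silence is sampled between words, and `--surround` adds the `<SIL>` token at both ends of each example.

```python
# hypothetical lexicon entries; the real lexicon is read from --lexicon
wrd_to_phn = {"hello": ["HH", "AH", "L", "OW"], "world": ["W", "ER", "L", "D"]}

phones = ["<SIL>"] + wrd_to_phn["hello"] + wrd_to_phn["world"] + ["<SIL>"]
print(" ".join(phones))  # <SIL> HH AH L OW W ER L D <SIL>
```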
/wav2vec 2.0/unsupervised/scripts/prepare_audio.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env zsh
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | source_dir=$1
8 | tgt_dir=$2
9 | model=$3
10 |
11 | if [ -z "$4" ]
12 | then
13 | dim=512
14 | else
15 | dim=$4
16 | fi
17 |
18 | echo "using $dim dim for PCA"
19 |
20 | if [ -z "$5" ]
21 | then
22 | layer=14
23 | else
24 | layer=$5
25 | fi
26 |
27 | echo "extracting from layer $layer"
28 |
29 | train_split=train
30 | valid_split=valid
31 | test_split=test
32 |
33 | all_splits=($train_split)
34 |
35 | if [[ -f "$source_dir/valid.tsv" ]]; then
36 | all_splits+=('valid')
37 | fi
38 |
39 | if [[ -f "$source_dir/test.tsv" ]]; then
40 | all_splits+=('test')
41 | fi
42 |
43 | echo "processing splits: $all_splits"
44 |
45 | mkdir -p $tgt_dir
46 |
47 | cp $source_dir/*.tsv $tgt_dir
48 | cp $source_dir/*.wrd $tgt_dir
49 | cp $source_dir/*.ltr $tgt_dir
50 | cp $source_dir/*.phn $tgt_dir
51 | cp $source_dir/dict* $tgt_dir
52 |
53 | setopt shwordsplit
54 |
55 | for split in $all_splits; do
56 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py $source_dir --split $split \
57 | --save-dir $tgt_dir --checkpoint $model --layer $layer
58 | done
59 |
60 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py $tgt_dir/${train_split}.tsv \
61 | --checkpoint $model --save-dir $tgt_dir -f "CLUS128" --sample-pct 1.0
62 |
63 | for split in $all_splits; do
64 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py $tgt_dir \
65 | --checkpoint $model --path $tgt_dir/CLUS128 --split $split
66 | done
67 |
68 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/pca.py $tgt_dir/${train_split}.npy --output $tgt_dir/pca --dim $dim
69 |
70 | for split in $all_splits; do
71 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/apply_pca.py $tgt_dir --split $split --save-dir $tgt_dir/precompute_pca$dim --pca-path $tgt_dir/pca/${dim}_pca --batch-size 1048000
72 |
73 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/merge_clusters.py $tgt_dir/precompute_pca$dim --cluster-dir $tgt_dir/CLUS128 \
74 | --split $split --save-dir $tgt_dir/precompute_pca${dim}_cls128_mean --pooling mean
75 |
76 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/mean_pool.py $tgt_dir/precompute_pca${dim}_cls128_mean \
77 | --save-dir $tgt_dir/precompute_pca${dim}_cls128_mean_pooled --split $split
78 | done
79 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/prepare_text.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env zsh
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | lg=$1
8 | text_path=$2
9 | target_dir=$3
10 | min_phones=$4
11 | phonemizer=$5
12 | lid_path=$6
13 |
14 | if [ -z "$lid_path" ]; then
15 | lid_path="lid.187.bin"
16 | fi
17 |
18 | ph_lg=${lg:l}
19 | if test "$lg" = 'fr'; then
20 | ph_lg='fr-fr'
21 | elif test "$lg" = 'en'; then
22 | ph_lg='en-us'
23 | elif test "$lg" = 'pt'; then
24 | ph_lg='pt-br'
25 | fi
26 |
27 | ESPEAK_PATH=''
28 | if test "$phonemizer" = 'espeak'; then
29 | ESPEAK_PATH=$(which espeak)
30 | elif test "$phonemizer" = 'espeak-ng'; then
31 | ESPEAK_PATH=$(which espeak-ng)
32 | elif test "$phonemizer" = 'G2P'; then
33 | ESPEAK_PATH=''
34 | else
35 |   echo "Unknown phonemizer $phonemizer. Valid options are espeak, espeak-ng and G2P"
36 | exit 1
37 | fi
38 |
39 | echo $lg
40 | echo $ph_lg
41 | echo $text_path
42 | echo $target_dir
43 | echo "min phone seen threshold is $min_phones"
44 |
45 | mkdir -p $target_dir
46 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py --lang $lg --fasttext-model $lid_path < $text_path | grep -v '\-\-\-' >! $target_dir/lm.upper.lid.txt
47 | python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/lm.upper.lid.txt --only-source --destdir $target_dir --thresholdsrc 2 --padding-factor 1 --dict-only
48 | cut -f1 -d' ' $target_dir/dict.txt | grep -v -x '[[:punct:]]*' | grep -Pv '\d\d\d\d\d+' >! $target_dir/words.txt
49 |
50 |
51 | if [ -z "$ESPEAK_PATH" ]; then
52 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py --compact < $target_dir/words.txt > $target_dir/phones.txt
53 | else
54 | # echoing 1 into corpus will prevent the mismatch lines between lexicon and phones in case the phonemizer fails
55 | one=$(echo "1" | PHONEMIZER_ESPEAK_PATH=$ESPEAK_PATH phonemize -p ' ' -w '' -l $ph_lg --language-switch remove-flags)
56 | sed 's/$/ 1/' $target_dir/words.txt | PHONEMIZER_ESPEAK_PATH=$ESPEAK_PATH phonemize -o $target_dir/phones.txt -p ' ' -w '' -l $ph_lg -j 70 --language-switch remove-flags
57 | echo "one is ${one}"
58 | sed -i "s/${one}$//" $target_dir/phones.txt
59 | fi
60 |
61 | paste $target_dir/words.txt $target_dir/phones.txt >! $target_dir/lexicon.lst
62 |
63 | python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones.txt --only-source --destdir $target_dir/phones --thresholdsrc $min_phones --padding-factor 1 --dict-only
64 |
65 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/filter_lexicon.py -d $target_dir/phones/dict.txt < $target_dir/lexicon.lst >! $target_dir/lexicon_filtered.lst
66 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py -s 0.25 --surround --lexicon $target_dir/lexicon_filtered.lst < $target_dir/lm.upper.lid.txt >! $target_dir/phones/lm.phones.filtered.txt
67 | cp $target_dir/phones/dict.txt $target_dir/phones/dict.phn.txt
68 | echo "<SIL> 0" >> $target_dir/phones/dict.phn.txt
69 | python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones/lm.phones.filtered.txt --workers 70 --only-source --destdir $target_dir/phones --srcdict $target_dir/phones/dict.phn.txt
70 |
71 | $KENLM_ROOT/lmplz -o 4 < $target_dir/lm.upper.lid.txt --discount_fallback --prune 0 0 0 3 >! $target_dir/kenlm.wrd.o40003.arpa
72 | $KENLM_ROOT/build_binary $target_dir/kenlm.wrd.o40003.arpa $target_dir/kenlm.wrd.o40003.bin
73 |
74 | lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_words_sil lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones in_labels=phn "blank_symbol='<SIL>'"
75 | lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_words lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones in_labels=phn
76 |
77 | $KENLM_ROOT/lmplz -o 4 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! $target_dir/phones/lm.phones.filtered.04.arpa
78 | $KENLM_ROOT/build_binary $target_dir/phones/lm.phones.filtered.04.arpa $target_dir/phones/lm.phones.filtered.04.bin
79 | $KENLM_ROOT/lmplz -o 6 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! $target_dir/phones/lm.phones.filtered.06.arpa
80 | $KENLM_ROOT/build_binary $target_dir/phones/lm.phones.filtered.06.arpa $target_dir/phones/lm.phones.filtered.06.bin
81 |
82 | lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_phn_sil lm_arpa=$target_dir/phones/lm.phones.filtered.06.arpa data_dir=$target_dir/phones in_labels=phn "blank_symbol='<SIL>'"
83 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/prepare_timit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | timit_root=$1 # assume it is the upper-cased version
8 | tgt_dir=$2
9 | model=$3
10 |
11 | set -eu
12 |
13 | setups="matched unmatched"
14 | splits="test valid train train_text"
15 |
16 | tgt_dir=$(realpath $tgt_dir)
17 | sph2wav=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe
18 | wav_dir=$tgt_dir/wav
19 |
20 |
21 | mkdir -p $tgt_dir $wav_dir
22 | find $timit_root/{TRAIN,TEST} -iname "*.WAV" > $tgt_dir/all_sph.flist
23 | cat $tgt_dir/all_sph.flist | sed -e 's#//*#/#g' -e 's#.*/\([^/]*\)/\([^/]*\).WAV#\1_\2#g' > $tgt_dir/all.uid
24 | paste -d' ' $tgt_dir/{all_sph.flist,all.uid} | \
25 | awk -v sph2wav=$sph2wav -v wav_dir=$wav_dir '{print sph2wav " -f wav " $1 " > " wav_dir "/" $2 ".wav"}' \
26 | > $tgt_dir/sph2wav.sh
27 | bash $tgt_dir/sph2wav.sh
28 | cat $tgt_dir/all.uid | awk -v wav_dir=$(pwd)/$wav_dir '{print $1" "wav_dir"/"$1".wav"}' | sort > $tgt_dir/all_wav.scp
29 | cut -d' ' -f2 $tgt_dir/all_wav.scp | xargs -I{} soxi -s {} > $tgt_dir/all.dur
30 | paste -d' ' $tgt_dir/{all_wav.scp,all.dur} > $tgt_dir/all_wav_dur.scp
31 | rm $tgt_dir/{all.uid,all_sph.flist,sph2wav.sh}
32 |
33 | find $timit_root/{TRAIN,TEST} -iname "*.PHN" > $tgt_dir/all_phn60.flist
34 | while read line; do
35 | if [ ! -f $line ]; then
36 | >&2 echo "Cannot find transcription file '$line'" && exit 1;
37 | fi
38 | cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;'
39 | done < $tgt_dir/all_phn60.flist > $tgt_dir/all.phn60
40 | cat $tgt_dir/all_phn60.flist | sed -e 's#//*#/#g' -e 's#.*/\([^/]*\)/\([^/]*\).PHN#\1_\2#g' | \
41 | paste -d' ' - $tgt_dir/all.phn60 | \
42 | $KALDI_ROOT/egs/timit/s5/local/timit_norm_trans.pl -i - -m $KALDI_ROOT/egs/timit/s5/conf/phones.60-48-39.map -to 39 | \
43 | sort > $tgt_dir/all.phn
44 | echo "done preparing wav and 39-phone transcripts"
45 |
46 |
47 | for s in $setups; do
48 | mkdir -p $tgt_dir/$s
49 | for x in $splits; do
50 | uid_path=config/timit_${s}/${x}.uid
51 | grep -w -f $uid_path $tgt_dir/all.phn | cut -d' ' -f2- > $tgt_dir/$s/$x.phn
52 | ln -sf $(realpath $tgt_dir/$s/$x.phn) $tgt_dir/$s/$x.wrd
53 |
54 | echo "/" > $tgt_dir/$s/$x.tsv && grep -w -f $uid_path $tgt_dir/all_wav_dur.scp | cut -d' ' -f2- | sed 's# #\t#' >> $tgt_dir/$s/$x.tsv
55 | done
56 |
57 | for x in $splits; do
58 | cat $tgt_dir/$s/$x.phn
59 | done | tr ' ' '\n' | sort -u | awk '{print $1" "1}' > $tgt_dir/$s/dict.phn.txt
60 | ln -sf $(realpath $tgt_dir/$s/dict.phn.txt) $tgt_dir/$s/dict.wrd.txt
61 | done
62 | echo "done preparing unmatched and matched setups for TIMIT"
63 |
64 |
65 | for s in $setups; do
66 | zsh scripts/prepare_audio.sh $tgt_dir/$s $tgt_dir/$s/feat $model
67 |
68 | lm_dir=$tgt_dir/$s/phones
69 | fst_dir=$tgt_dir/$s/fst/phn_to_phn
70 |
71 | python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $tgt_dir/$s/train_text.phn --workers 10 --only-source --destdir $lm_dir --srcdict $tgt_dir/$s/dict.phn.txt
72 | $KENLM_ROOT/lmplz -o 3 < $tgt_dir/$s/train_text.phn --discount_fallback >$lm_dir/train_text_phn.03.arpa
73 | $KENLM_ROOT/build_binary $lm_dir/train_text_phn.03.arpa $lm_dir/train_text_phn.03.bin
74 | $KENLM_ROOT/lmplz -o 4 < $tgt_dir/$s/train_text.phn --discount_fallback >$lm_dir/train_text_phn.04.arpa
75 | $KENLM_ROOT/build_binary $lm_dir/train_text_phn.04.arpa $lm_dir/train_text_phn.04.bin
76 |
77 | python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$fst_dir lm_arpa=$lm_dir/train_text_phn.03.arpa data_dir=$tgt_dir/$s in_labels=phn
78 | done
79 | echo "done preprocessing audio and text for wav2vec-U"
80 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/remove_silence.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Reads voice-activity intervals from a .vads file, removes the silent regions from each audio file listed in the .tsv manifest, and saves the trimmed audio under --out, e.g.:
9 | tsv=shards/train.tsv
10 | vads=shards/train.vads
11 | python remove_silence.py --tsv $tsv --vads $vads --out <output_dir>
12 | """
13 |
14 | import os
15 | import argparse
16 | import torch
17 | import torchaudio
18 | import tqdm
19 |
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--tsv", default="", type=str)
23 | parser.add_argument("--vads", default="", type=str)
24 | parser.add_argument("--out", type=str)
25 | params = parser.parse_args()
26 |
27 | # load paths
28 | paths = []
29 | with open(params.tsv) as f:
30 | root = next(f).rstrip()
31 | for line in f:
32 | paths.append(os.path.join(root, line.rstrip().split("\t")[0]))
33 |
34 | # load vads
35 | list_intervals = []
36 | with open(params.vads) as f:
37 | for line in f:
38 | interval = [
39 | [int(w.split(":")[0]), int(w.split(":")[1])] for w in line.rstrip().split()
40 | ]
41 | list_intervals.append(interval)
42 |
43 |
44 | # load audio and keep only intervals (i.e. remove silences)
45 | for i in tqdm.trange(len(paths)):
46 | data, _ = torchaudio.load(paths[i])
47 | if len(list_intervals[i]) > 0:
48 | data_filtered = torch.cat(
49 | [data[0][int(it[0]) : int(it[1])] for it in list_intervals[i]]
50 | ).unsqueeze(0)
51 | else:
52 | data_filtered = data
53 |
54 | # YOU MAY NEED TO MODIFY THIS TO GET THE RIGHT SUBPATH
55 | # outpath = params.out + '/'.join(paths[i].split('/')[-1])
56 | outpath = params.out + "/" + "/".join(paths[i].split("/")[-2:])
57 |
58 | if not os.path.isdir("/".join(outpath.split("/")[:-1])):
59 | os.makedirs("/".join(outpath.split("/")[:-1]))
60 | if not os.path.exists(outpath):
61 | torchaudio.save(outpath, data_filtered, sample_rate=16000)
62 | else:
63 | print(outpath, "exists!")
64 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/vads.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import sys
9 |
10 | from copy import deepcopy
11 | from scipy.signal import lfilter
12 |
13 | import numpy as np
14 | from tqdm import tqdm
15 | import soundfile as sf
16 | import os.path as osp
17 |
18 |
19 | def get_parser():
20 | parser = argparse.ArgumentParser(description="compute vad segments")
21 | parser.add_argument(
22 | "--rvad-home",
23 | "-r",
24 | help="path to rvad home (see https://github.com/zhenghuatan/rVADfast)",
25 | required=True,
26 | )
27 |
28 | return parser
29 |
30 |
31 | def rvad(speechproc, path):
32 | winlen, ovrlen, pre_coef, nfilter, nftt = 0.025, 0.01, 0.97, 20, 512
33 | ftThres = 0.5
34 | vadThres = 0.4
35 | opts = 1
36 |
37 | data, fs = sf.read(path)
38 | assert fs == 16_000, "sample rate must be 16khz"
39 | ft, flen, fsh10, nfr10 = speechproc.sflux(data, fs, winlen, ovrlen, nftt)
40 |
41 | # --spectral flatness --
42 | pv01 = np.zeros(ft.shape[0])
43 | pv01[np.less_equal(ft, ftThres)] = 1
44 | pitch = deepcopy(ft)
45 |
46 | pvblk = speechproc.pitchblockdetect(pv01, pitch, nfr10, opts)
47 |
48 | # --filtering--
49 | ENERGYFLOOR = np.exp(-50)
50 | b = np.array([0.9770, -0.9770])
51 | a = np.array([1.0000, -0.9540])
52 | fdata = lfilter(b, a, data, axis=0)
53 |
54 | # --pass 1--
55 | noise_samp, noise_seg, n_noise_samp = speechproc.snre_highenergy(
56 | fdata, nfr10, flen, fsh10, ENERGYFLOOR, pv01, pvblk
57 | )
58 |
59 | # sets noisy segments to zero
60 | for j in range(n_noise_samp):
61 | fdata[range(int(noise_samp[j, 0]), int(noise_samp[j, 1]) + 1)] = 0
62 |
63 | vad_seg = speechproc.snre_vad(
64 | fdata, nfr10, flen, fsh10, ENERGYFLOOR, pv01, pvblk, vadThres
65 | )
66 | return vad_seg, data
67 |
68 |
69 | def main():
70 | parser = get_parser()
71 | args = parser.parse_args()
72 |
73 | sys.path.append(args.rvad_home)
74 | import speechproc
75 |
76 | stride = 160
77 | lines = sys.stdin.readlines()
78 | root = lines[0].rstrip()
79 | for fpath in tqdm(lines[1:]):
80 | path = osp.join(root, fpath.split()[0])
81 | vads, wav = rvad(speechproc, path)
82 |
83 | start = None
84 | vad_segs = []
85 | for i, v in enumerate(vads):
86 | if start is None and v == 1:
87 | start = i * stride
88 | elif start is not None and v == 0:
89 | vad_segs.append((start, i * stride))
90 | start = None
91 | if start is not None:
92 | vad_segs.append((start, len(wav)))
93 |
94 | print(" ".join(f"{v[0]}:{v[1]}" for v in vad_segs))
95 |
96 |
97 | if __name__ == "__main__":
98 | main()
99 |
--------------------------------------------------------------------------------
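Each output line of vads.py holds the voiced segments of one audio file as space-separated `start:end` sample offsets; remove_silence.py above parses them back with the same expression shown here (the line itself is hypothetical).

```python
line = "16000:48000 64000:80000"   # hypothetical .vads line: two voiced segments
intervals = [[int(w.split(":")[0]), int(w.split(":")[1])] for w in line.split()]
print(intervals)  # [[16000, 48000], [64000, 80000]]
```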
/wav2vec 2.0/unsupervised/scripts/wav2vec_apply_cluster_faiss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import os
9 | import os.path as osp
10 | import numpy as np
11 | import tqdm
12 | import torch
13 | import sys
14 |
15 | import faiss
16 | import torch.nn.functional as F
17 |
18 | from wav2vec_cluster_faiss import parse_faiss_specs, Wav2VecFeatureReader
19 |
20 |
21 | def get_parser():
22 | parser = argparse.ArgumentParser(description="apply clusters")
23 | # fmt: off
24 | parser.add_argument('data', help='location of tsv files')
25 | parser.add_argument('--split', help='split to process', required=True)
26 |     parser.add_argument('--labels', help='extension of the label files to read alongside the tsv (e.g. phn)', default="phn")
27 | parser.add_argument('--path', help='path to pca and centroids', required=True)
28 | parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True)
29 | parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14)
30 | parser.add_argument('--max-tsz', type=int, help='batch kmeans up to this much', default=14)
31 | # fmt: on
32 |
33 | return parser
34 |
35 |
36 | def get_iterator(args):
37 | label_path = osp.join(args.data, f"{args.split}.{args.labels}")
38 | if osp.exists(label_path):
39 | lp = open(label_path, "r")
40 | else:
41 | lp = None
42 |
43 | with open(osp.join(args.data, f"{args.split}.tsv"), "r") as fp:
44 | lines = fp.read().split("\n")
45 | root = lines.pop(0).strip()
46 | files = [line.rstrip() for line in lines if len(line) > 0]
47 |
48 | if lp is not None:
49 | lbls = [line.rstrip() for line in lp]
50 | else:
51 | lbls = [None] * len(files)
52 |
53 | num = len(files)
54 | reader = Wav2VecFeatureReader(args.checkpoint, args.layer)
55 |
56 | def iterate():
57 | for fname, lbl in zip(files, lbls):
58 | file = osp.join(root, fname.split("\t")[0])
59 | feats = reader.get_feats(file)
60 | yield feats.data, fname, lbl
61 |
62 | return iterate, num, root
63 |
64 |
65 | def main():
66 | parser = get_parser()
67 | args = parser.parse_args()
68 |
69 | spec = osp.basename(args.path)
70 |
71 | try:
72 | faiss_spec = parse_faiss_specs(spec.rstrip("/"))[0]
73 | except:
74 | print(spec)
75 | raise
76 |
77 | print("Faiss Spec:", faiss_spec, file=sys.stderr)
78 |
79 | if faiss_spec.pca:
80 | A = torch.from_numpy(np.load(osp.join(args.path, "pca_A.npy"))).cuda()
81 | b = torch.from_numpy(np.load(osp.join(args.path, "pca_b.npy"))).cuda()
82 | print("Loaded PCA", file=sys.stderr)
83 |
84 | centroids = np.load(osp.join(args.path, "centroids.npy"))
85 | print("Loaded centroids", centroids.shape, file=sys.stderr)
86 |
87 | res = faiss.StandardGpuResources()
88 | index_flat = (
89 | faiss.IndexFlatL2(centroids.shape[1])
90 | if not faiss_spec.sphere
91 | else faiss.IndexFlatIP(centroids.shape[1])
92 | )
93 | faiss_index = faiss.index_cpu_to_gpu(res, 0, index_flat)
94 | faiss_index.add(centroids)
95 |
96 | generator, num, root = get_iterator(args)
97 | iterator = generator()
98 |
99 | had_labels = False
100 | label_path = osp.join(args.path, f"{args.split}.{args.labels}")
101 |
102 | with torch.no_grad():
103 | with open(osp.join(args.path, f"{args.split}.src"), "w") as fp, open(
104 | osp.join(args.path, f"{args.split}.tsv"), "w"
105 | ) as pp, open(label_path, "w") as lp:
106 | print(root, file=pp)
107 | for f, fname, lbl in tqdm.tqdm(iterator, total=num):
108 | if faiss_spec.pca:
109 | f = torch.mm(f, A) + b
110 | if faiss_spec.norm:
111 | f = F.normalize(f, p=2, dim=-1)
112 |
113 | f = f.cpu().numpy()
114 |
115 | _, z = faiss_index.search(f, 1)
116 |
117 | print(" ".join(str(x.item()) for x in z), file=fp)
118 | print(fname, file=pp)
119 |
120 | if lbl is not None:
121 | print(lbl, file=lp)
122 | had_labels = True
123 | if not had_labels:
124 | os.remove(label_path)
125 |
126 |
127 | if __name__ == "__main__":
128 | main()
129 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/scripts/wav2vec_cluster_faiss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import gc
9 | import os
10 | import os.path as osp
11 | import random
12 | import numpy as np
13 | import tqdm
14 | import torch
15 |
16 | from collections import namedtuple
17 |
18 | import faiss
19 |
20 | import fairseq
21 | import soundfile as sf
22 |
23 |
24 | def get_parser():
25 | parser = argparse.ArgumentParser(
26 |         description="compute a kmeans codebook from wav2vec features"
27 | )
28 | # fmt: off
29 | parser.add_argument('data', help='location of tsv files')
30 | parser.add_argument('--save-dir', help='where to save the output', required=True)
31 | parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True)
32 | parser.add_argument('--sample-pct', '-r', type=float, help='percentage of timesteps to sample', default=0)
33 | parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14)
34 | parser.add_argument('--faiss-specs', '-f', type=str,
35 | help='faiss index specs; separated by space '
36 | 'format is: PCAx_NORM_CLUSx_SPHERICAL -> '
37 | 'PCAx if exists first apply PCA '
38 | 'NORM if exists, normalize the vector by L2 norm '
39 | 'CLUSx must exist, cluster to x clusters '
40 |                              'SPHERICAL if exists, apply spherical kmeans',
41 | default='l2')
42 | # fmt: on
43 |
44 | return parser
45 |
46 |
47 | faiss_spec = namedtuple("faiss_spec", ["pca", "norm", "n_clus", "sphere", "spec_str"])
48 |
49 |
50 | def parse_faiss_specs(specs_str):
51 | specs = []
52 | for ss in specs_str.split():
53 | comps = ss.split("_")
54 | pca = 0
55 | norm = False
56 | n_clus = 0
57 | sphere = False
58 | for c in comps:
59 | if c.startswith("PCA"):
60 | pca = int(c[3:])
61 | elif c == "NORM":
62 | norm = True
63 | elif c.startswith("CLUS"):
64 | n_clus = int(c[4:])
65 | elif c == "SPHERICAL":
66 | sphere = True
67 | assert n_clus > 0
68 | specs.append(
69 | faiss_spec(pca=pca, norm=norm, n_clus=n_clus, sphere=sphere, spec_str=ss)
70 | )
71 | return specs
72 |
73 |
74 | class Wav2VecFeatureReader(object):
75 | def __init__(self, cp_file, layer):
76 | state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(cp_file)
77 |
78 | self.layer = layer
79 |
80 | if "cfg" in state:
81 | w2v_args = state["cfg"]
82 | task = fairseq.tasks.setup_task(w2v_args.task)
83 | model = task.build_model(w2v_args.model)
84 | else:
85 | w2v_args = state["args"]
86 | task = fairseq.tasks.setup_task(w2v_args)
87 | model = task.build_model(w2v_args)
88 | model.load_state_dict(state["model"], strict=True)
89 | model.eval()
90 | model.cuda()
91 | self.model = model
92 |
93 | def read_audio(self, fname):
94 |         """Load an audio file and return the raw PCM samples (16 kHz audio expected)"""
95 | wav, sr = sf.read(fname)
96 | assert sr == 16e3
97 |
98 | return wav
99 |
100 | def get_feats(self, loc):
101 | x = self.read_audio(loc)
102 | with torch.no_grad():
103 | source = torch.from_numpy(x).view(1, -1).float().cuda()
104 | res = self.model(
105 | source=source, mask=False, features_only=True, layer=self.layer
106 | )
107 | return res["layer_results"][self.layer][0].squeeze(1)
108 |
109 |
110 | def get_iterator(args):
111 | with open(args.data, "r") as fp:
112 | lines = fp.read().split("\n")
113 | root = lines.pop(0).strip()
114 | files = [osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0]
115 |
116 | if getattr(args, "sample_pct", 0) > 0:
117 | files = random.sample(files, int(args.sample_pct * len(files)))
118 | num = len(files)
119 | reader = Wav2VecFeatureReader(args.checkpoint, args.layer)
120 |
121 | def iterate():
122 | for fname in files:
123 | feats = reader.get_feats(fname)
124 | yield feats.cpu().numpy()
125 |
126 | return iterate, num
127 |
128 |
129 | def main():
130 | parser = get_parser()
131 | args = parser.parse_args()
132 |
133 | faiss_specs = parse_faiss_specs(args.faiss_specs)
134 | print("Faiss Specs:", faiss_specs)
135 |
136 | feat_path = osp.join(args.save_dir, "features")
137 | if osp.exists(feat_path + ".npy"):
138 | feats = np.load(feat_path + ".npy")
139 | else:
140 | generator, num = get_iterator(args)
141 | iterator = generator()
142 |
143 | feats = []
144 | for f in tqdm.tqdm(iterator, total=num):
145 | feats.append(f)
146 |
147 | del iterator
148 | del generator
149 |
150 | feats = np.concatenate(feats)
151 |
152 | print(feats.shape)
153 |
154 | os.makedirs(args.save_dir, exist_ok=True)
155 | # np.save(feat_path, feats)
156 |
157 | gc.collect()
158 | torch.cuda.empty_cache()
159 |
160 | reload = False
161 | for spec in faiss_specs:
162 | print("Processing spec", spec)
163 |
164 | if reload:
165 | print("Reloading...")
166 | del feats
167 | gc.collect()
168 | feats = np.load(feat_path + ".npy")
169 |
170 | save_path = osp.join(args.save_dir, spec.spec_str)
171 | os.makedirs(save_path, exist_ok=True)
172 | d = feats.shape[-1]
173 | x = feats
174 | if spec.pca > 0:
175 | print("Computing PCA")
176 | pca = faiss.PCAMatrix(d, spec.pca)
177 | pca.train(x)
178 | d = spec.pca
179 | b = faiss.vector_to_array(pca.b)
180 | A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in)
181 | np.save(osp.join(save_path, "pca_A"), A.T)
182 | np.save(osp.join(save_path, "pca_b"), b)
183 | print("Applying PCA")
184 | x = pca.apply_py(x)
185 |
186 | if spec.norm:
187 | reload = spec.pca <= 0
188 | print("Normalizing")
189 | faiss.normalize_L2(x)
190 |
191 | print("Computing kmeans")
192 | kmeans = faiss.Kmeans(
193 | d,
194 | spec.n_clus,
195 | niter=50,
196 | verbose=True,
197 | spherical=spec.sphere,
198 | max_points_per_centroid=feats.shape[0],
199 | gpu=True,
200 | nredo=3,
201 | )
202 | kmeans.train(x)
203 | np.save(osp.join(save_path, "centroids"), kmeans.centroids)
204 | del kmeans
205 | del x
206 | gc.collect()
207 |
208 |
209 | if __name__ == "__main__":
210 | main()
211 |
--------------------------------------------------------------------------------
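For reference, the `--faiss-specs` grammar parsed above: `CLUSx` is mandatory, while `PCAx`, `NORM` and `SPHERICAL` are optional modifiers. A small usage sketch, assuming the script is importable from the working directory (as wav2vec_apply_cluster_faiss.py already does):

```python
from wav2vec_cluster_faiss import parse_faiss_specs

specs = parse_faiss_specs("CLUS128 PCA512_NORM_CLUS128_SPHERICAL")
print(specs[0])  # faiss_spec(pca=0, norm=False, n_clus=128, sphere=False, spec_str='CLUS128')
print(specs[1])  # faiss_spec(pca=512, norm=True, n_clus=128, sphere=True, spec_str='PCA512_NORM_CLUS128_SPHERICAL')
```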
/wav2vec 2.0/unsupervised/scripts/wav2vec_extract_features.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import os
9 | import os.path as osp
10 | import tqdm
11 | import torch
12 | import torch.nn.functional as F
13 | from shutil import copyfile
14 |
15 | from npy_append_array import NpyAppendArray
16 |
17 | import fairseq
18 | import soundfile as sf
19 |
20 |
21 | def get_parser():
22 | parser = argparse.ArgumentParser(
23 |         description="extract wav2vec 2.0 features for a dataset split"
24 | )
25 | # fmt: off
26 | parser.add_argument('data', help='location of tsv files')
27 | parser.add_argument('--split', help='which split to read', required=True)
28 | parser.add_argument('--save-dir', help='where to save the output', required=True)
29 | parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec ctc model', required=True)
30 | parser.add_argument('--layer', type=int, default=14, help='which layer to use')
31 | # fmt: on
32 |
33 | return parser
34 |
35 |
36 | class Wav2VecFeatureReader(object):
37 | def __init__(self, cp_file, layer):
38 | model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
39 | [cp_file]
40 | )
41 | model = model[0]
42 | model.eval()
43 | model.cuda()
44 | self.model = model
45 | self.task = task
46 | self.layer = layer
47 |
48 | def read_audio(self, fname):
49 |         """Load an audio file and return the raw PCM samples (16 kHz audio expected)"""
50 | wav, sr = sf.read(fname)
51 | assert sr == 16e3
52 |
53 | return wav
54 |
55 | def get_feats(self, loc):
56 | x = self.read_audio(loc)
57 | with torch.no_grad():
58 | source = torch.from_numpy(x).float().cuda()
59 | if self.task.cfg.normalize:
60 | assert source.dim() == 1, source.dim()
61 | with torch.no_grad():
62 | source = F.layer_norm(source, source.shape)
63 | source = source.view(1, -1)
64 |
65 | m_res = self.model(source=source, mask=False, features_only=True, layer=self.layer)
66 | return m_res["x"].squeeze(0).cpu()
67 |
68 |
69 | def get_iterator(args):
70 | with open(osp.join(args.data, args.split) + ".tsv", "r") as fp:
71 | lines = fp.read().split("\n")
72 | root = lines.pop(0).strip()
73 | files = [osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0]
74 |
75 | num = len(files)
76 | reader = Wav2VecFeatureReader(args.checkpoint, args.layer)
77 |
78 | def iterate():
79 | for fname in files:
80 | w2v_feats = reader.get_feats(fname)
81 | yield w2v_feats
82 |
83 | return iterate, num
84 |
85 |
86 | def main():
87 | parser = get_parser()
88 | args = parser.parse_args()
89 |
90 | os.makedirs(args.save_dir, exist_ok=True)
91 |
92 | def create_files(dest):
93 | copyfile(osp.join(args.data, args.split) + ".tsv", dest + ".tsv")
94 | if osp.exists(osp.join(args.data, args.split) + ".wrd"):
95 | copyfile(osp.join(args.data, args.split) + ".wrd", dest + ".wrd")
96 | if osp.exists(osp.join(args.data, args.split) + ".phn"):
97 | copyfile(osp.join(args.data, args.split) + ".phn", dest + ".phn")
98 |
99 | if osp.exists(dest + ".npy"):
100 | os.remove(dest + ".npy")
101 | npaa = NpyAppendArray(dest + ".npy")
102 | return npaa
103 |
104 | save_path = osp.join(args.save_dir, args.split)
105 | npaa = create_files(save_path)
106 |
107 | generator, num = get_iterator(args)
108 | iterator = generator()
109 |
110 | with open(save_path + ".lengths", "w") as l_f:
111 | for w2v_feats in tqdm.tqdm(iterator, total=num):
112 | print(len(w2v_feats), file=l_f)
113 |
114 | if len(w2v_feats) > 0:
115 | npaa.append(w2v_feats.numpy())
116 |
117 |
118 | if __name__ == "__main__":
119 | main()
120 |
--------------------------------------------------------------------------------
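wav2vec_extract_features.py writes all frames of a split into one appended `.npy` file plus a `.lengths` file with per-utterance frame counts. A hedged sketch of reading them back by slicing at the cumulative lengths; the paths are placeholders.

```python
import numpy as np

feats = np.load("feat/train.npy", mmap_mode="r")         # (total_frames, dim), placeholder path
lengths = [int(l) for l in open("feat/train.lengths")]   # one frame count per utterance
offsets = np.cumsum([0] + lengths)
utterances = [feats[start:end] for start, end in zip(offsets[:-1], offsets[1:])]
```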
/wav2vec 2.0/unsupervised/scripts/wer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Implement unsupervised metric for decoding hyperparameter selection:
9 | $$ alpha * LM_PPL + ViterbiUER(%) * 100 $$
10 | """
11 | import argparse
12 | import logging
13 | import sys
14 |
15 | import editdistance
16 |
17 | logging.root.setLevel(logging.INFO)
18 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | def get_parser():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("-s", "--hypo", help="hypo transcription", required=True)
25 | parser.add_argument(
26 | "-r", "--reference", help="reference transcription", required=True
27 | )
28 | return parser
29 |
30 |
31 | def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p):
32 | d_cnt = 0
33 | w_cnt = 0
34 | w_cnt_h = 0
35 | for uid in hyp_uid_to_tra:
36 | ref = ref_uid_to_tra[uid].split()
37 | if g2p is not None:
38 | hyp = g2p(hyp_uid_to_tra[uid])
39 | hyp = [p for p in hyp if p != "'" and p != " "]
40 | hyp = [p[:-1] if p[-1].isnumeric() else p for p in hyp]
41 | else:
42 | hyp = hyp_uid_to_tra[uid].split()
43 | d_cnt += editdistance.eval(ref, hyp)
44 | w_cnt += len(ref)
45 | w_cnt_h += len(hyp)
46 | wer = float(d_cnt) / w_cnt
47 | logger.debug(
48 | (
49 | f"wer = {wer * 100:.2f}%; num. of ref words = {w_cnt}; "
50 | f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(ref_uid_to_tra)}"
51 | )
52 | )
53 | return wer
54 |
55 |
56 | def main():
57 | args = get_parser().parse_args()
58 |
59 | errs = 0
60 | count = 0
61 | with open(args.hypo, "r") as hf, open(args.reference, "r") as rf:
62 | for h, r in zip(hf, rf):
63 | h = h.rstrip().split()
64 | r = r.rstrip().split()
65 | errs += editdistance.eval(r, h)
66 | count += len(r)
67 |
68 | logger.info(f"UER: {errs / count * 100:.2f}%")
69 |
70 |
71 | def load_tra(tra_path):
72 |     with open(tra_path, "r") as f:
73 |         uid_to_tra = {}
74 |         for line in f:
75 |             uid, tra = line.split(None, 1)
76 |             uid_to_tra[uid] = tra
77 |     logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}")
78 |     return uid_to_tra
79 |
80 |
81 | if __name__ == "__main__":
82 |     main()
83 |
--------------------------------------------------------------------------------
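A minimal reading of the selection score mentioned in the wer.py docstring, assuming `alpha` weights the phone-LM perplexity and the Viterbi UER is a fraction that the `* 100` turns into a percentage; the numbers below are placeholders, not values from the paper.

```python
def selection_score(lm_ppl: float, viterbi_uer: float, alpha: float = 1.0) -> float:
    """Lower is better: phone-LM perplexity plus the Viterbi unit error rate in percent."""
    return alpha * lm_ppl + viterbi_uer * 100

print(selection_score(lm_ppl=20.0, viterbi_uer=0.4))  # 60.0
```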
/wav2vec 2.0/unsupervised/scripts/wrd_to_ltr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 |
9 |
10 | def main():
11 | for line in sys.stdin:
12 | print(" ".join(list(line.strip().replace(" ", "|"))) + " |")
13 |
14 |
15 | if __name__ == "__main__":
16 | main()
17 |
--------------------------------------------------------------------------------
/wav2vec 2.0/unsupervised/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .unpaired_audio_text import UnpairedAudioText
7 |
8 |
9 | __all__ = [
10 | "UnpairedAudioText",
11 | ]
12 |
--------------------------------------------------------------------------------
/wav2vec 2.0/vq-wav2vec_featurize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset
9 | """
10 |
11 | import argparse
12 | import glob
13 | import os
14 | import os.path as osp
15 | import pprint
16 |
17 | import soundfile as sf
18 | import torch
19 | import fairseq
20 | from torch import nn
21 | from torch.utils.data import DataLoader
22 |
23 |
24 | try:
25 | import tqdm
26 | except ImportError:
27 | print("Install tqdm to use --log-format=tqdm")
28 |
29 |
30 | class FilesDataset:
31 | def __init__(self, files, labels):
32 | self.files = files
33 | if labels and osp.exists(labels):
34 | with open(labels, "r") as lbl_f:
35 | self.labels = [line.rstrip() for line in lbl_f]
36 | else:
37 | self.labels = labels
38 |
39 | def __len__(self):
40 | return len(self.files)
41 |
42 | def __getitem__(self, index):
43 | fname = self.files[index]
44 |
45 | wav, sr = sf.read(fname)
46 | assert sr == 16000
47 |
48 | wav = torch.from_numpy(wav).float()
49 | lbls = None
50 | if self.labels:
51 | if isinstance(self.labels, str):
52 | lbl_file = osp.splitext(fname)[0] + "." + self.labels
53 | with open(lbl_file, "r") as lblf:
54 | lbls = lblf.readline()
55 | assert lbls is not None
56 | else:
57 | lbls = self.labels[index]
58 | return wav, lbls
59 |
60 | def collate(self, batch):
61 | return batch
62 |
63 |
64 | class ArgTypes:
65 | @staticmethod
66 | def existing_path(arg):
67 | arg = str(arg)
68 | assert osp.exists(arg), f"File {arg} does not exist"
69 | return arg
70 |
71 | @staticmethod
72 | def mkdir(arg):
73 | arg = str(arg)
74 | os.makedirs(arg, exist_ok=True)
75 | return arg
76 |
77 |
78 | class DatasetWriter:
79 | def __init__(self):
80 |
81 | self.args = self.load_config()
82 | pprint.pprint(self.args.__dict__)
83 |
84 | self.model = self.load_model()
85 |
86 | def __getattr__(self, attr):
87 | return getattr(self.args, attr)
88 |
89 | def read_manifest(self, fname):
90 |
91 | with open(fname, "r") as fp:
92 | lines = fp.read().split("\n")
93 | root = lines.pop(0).strip()
94 | fnames = [
95 | osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0
96 | ]
97 |
98 | return fnames
99 |
100 | def process_splits(self):
101 |
102 | if self.args.shard is not None or self.args.num_shards is not None:
103 | assert self.args.shard is not None and self.args.num_shards is not None
104 |
105 | for split in self.splits:
106 | print(split)
107 |
108 | if self.extension == "tsv":
109 | datadir = osp.join(self.data_dir, f"{split}.{self.extension}")
110 | print("Reading manifest file: ", datadir)
111 | files = self.read_manifest(datadir)
112 | else:
113 | datadir = osp.join(self.data_dir, split, f"**/*.{self.extension}")
114 | files = glob.glob(datadir, recursive=True)
115 |
116 | assert len(files) > 0
117 |
118 | if self.args.shard is not None:
119 | files = files[self.args.shard :: self.args.num_shards]
120 |
121 | lbls = []
122 | with open(self.data_file(split), "w") as srcf:
123 | for line, lbl in self.iterate(files):
124 | print(line, file=srcf)
125 | if self.args.labels:
126 | lbls.append(lbl + "\n")
127 |
128 | if self.args.labels:
129 | assert all(a is not None for a in lbls)
130 | with open(self.lbl_file(split), "w") as lblf:
131 | lblf.writelines(lbls)
132 |
133 | def iterate(self, files):
134 |
135 | data = self.load_data(files)
136 | for samples in tqdm.tqdm(data, total=len(files) // 32):
137 |
138 | for wav, lbl in samples:
139 | x = wav.unsqueeze(0).float().cuda()
140 |
141 | div = 1
142 | while x.size(-1) // div > self.args.max_size:
143 | div += 1
144 |
145 | xs = x.chunk(div, dim=-1)
146 |
147 | result = []
148 | for x in xs:
149 | torch.cuda.empty_cache()
150 | x = self.model.feature_extractor(x)
151 | if self.quantize_location == "encoder":
152 | with torch.no_grad():
153 | _, idx = self.model.vector_quantizer.forward_idx(x)
154 | idx = idx.squeeze(0).cpu()
155 | else:
156 | with torch.no_grad():
157 | z = self.model.feature_aggregator(x)
158 | _, idx = self.model.vector_quantizer.forward_idx(z)
159 | idx = idx.squeeze(0).cpu()
160 | result.append(idx)
161 |
162 | idx = torch.cat(result, dim=0)
163 | yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl
164 |
165 | def lbl_file(self, name):
166 | shard_part = "" if self.args.shard is None else f".{self.args.shard}"
167 | return osp.join(self.output_dir, f"{name}.lbl{shard_part}")
168 |
169 | def data_file(self, name):
170 | shard_part = "" if self.args.shard is None else f".{self.args.shard}"
171 | return osp.join(self.output_dir, f"{name}.src{shard_part}")
172 |
173 | def var_file(self):
174 | return osp.join(self.output_dir, f"vars.pt")
175 |
176 | def load_config(self):
177 |
178 | parser = argparse.ArgumentParser("Vector Quantized wav2vec features")
179 |
180 | # Model Arguments
181 | parser.add_argument("--checkpoint", type=ArgTypes.existing_path, required=True)
182 | parser.add_argument("--data-parallel", action="store_true")
183 |
184 | # Output Arguments
185 | parser.add_argument("--output-dir", type=ArgTypes.mkdir, required=True)
186 |
187 | # Data Arguments
188 | parser.add_argument("--data-dir", type=ArgTypes.existing_path, required=True)
189 | parser.add_argument("--splits", type=str, nargs="+", required=True)
190 | parser.add_argument("--extension", type=str, required=True)
191 | parser.add_argument("--labels", type=str, required=False)
192 |
193 | parser.add_argument("--shard", type=int, default=None)
194 | parser.add_argument("--num-shards", type=int, default=None)
195 | parser.add_argument("--max-size", type=int, default=1300000)
196 |
197 | # Logger Arguments
198 | parser.add_argument(
199 | "--log-format", type=str, choices=["none", "simple", "tqdm"]
200 | )
201 |
202 | return parser.parse_args()
203 |
204 | def load_data(self, fnames):
205 |
206 | dataset = FilesDataset(fnames, self.args.labels)
207 | loader = DataLoader(
208 | dataset, batch_size=32, collate_fn=dataset.collate, num_workers=8
209 | )
210 | return loader
211 |
212 | def load_model(self):
213 | model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([self.checkpoint])
214 | model = model[0]
215 |
216 | self.quantize_location = getattr(cfg.model, "vq", "encoder")
217 |
218 | model.eval().float()
219 | model.cuda()
220 |
221 | if self.data_parallel:
222 | model = nn.DataParallel(model)
223 |
224 | return model
225 |
226 | def __call__(self):
227 |
228 | self.process_splits()
229 |
230 | if hasattr(self.model.feature_extractor, "vars") and (
231 | self.args.shard is None or self.args.shard == 0
232 | ):
233 | vars = (
234 | self.model.feature_extractor.vars.view(
235 | self.model.feature_extractor.banks,
236 | self.model.feature_extractor.num_vars,
237 | -1,
238 | )
239 | .cpu()
240 | .detach()
241 | )
242 | print("writing learned latent variable embeddings: ", vars.shape)
243 | torch.save(vars, self.var_file())
244 |
245 |
246 | if __name__ == "__main__":
247 | write_data = DatasetWriter()
248 |
249 | write_data()
250 | print("Done.")
251 |
--------------------------------------------------------------------------------
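The writer above emits one line per utterance into the `.src` files: one token per timestep, with the indices of the vector-quantizer groups joined by `-`. A hypothetical line parses back like this:

```python
line = "102-7 102-7 55-311"   # hypothetical .src line with two codebook groups per timestep
codes = [[int(g) for g in tok.split("-")] for tok in line.split()]
print(codes)  # [[102, 7], [102, 7], [55, 311]]
```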
/wav2vec 2.0/wav2vec_featurize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset
9 | """
10 |
11 | import argparse
12 | import glob
13 | import os
14 | from shutil import copy
15 |
16 | import h5py
17 | import numpy as np
18 | import soundfile as sf
19 | import torch
20 | import tqdm
21 | import fairseq
22 | from torch import nn
23 |
24 |
25 | def read_audio(fname):
26 | """ Load an audio file and return PCM along with the sample rate """
27 |
28 | wav, sr = sf.read(fname)
29 | assert sr == 16e3
30 |
31 | return wav, 16e3
32 |
33 |
34 | class PretrainedWav2VecModel(nn.Module):
35 | def __init__(self, fname):
36 | super().__init__()
37 |
38 | model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fname])
39 | model = model[0]
40 | model.eval()
41 |
42 | self.model = model
43 |
44 | def forward(self, x):
45 | with torch.no_grad():
46 | z = self.model.feature_extractor(x)
47 | if isinstance(z, tuple):
48 | z = z[0]
49 | c = self.model.feature_aggregator(z)
50 | return z, c
51 |
52 |
53 | class EmbeddingWriterConfig(argparse.ArgumentParser):
54 | def __init__(self):
55 | super().__init__("Pre-compute embeddings for flashlight datasets")
56 |
57 | kwargs = {"action": "store", "type": str, "required": True}
58 |
59 | self.add_argument("--input", "-i", help="Input Directory", **kwargs)
60 | self.add_argument("--output", "-o", help="Output Directory", **kwargs)
61 | self.add_argument("--model", help="Path to model checkpoint", **kwargs)
62 | self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs)
63 | self.add_argument(
64 | "--ext", default="wav", required=False, help="Audio file extension"
65 | )
66 |
67 | self.add_argument(
68 | "--no-copy-labels",
69 | action="store_true",
70 | help="Do not copy label files. Useful for large datasets, use --targetdir in flashlight then.",
71 | )
72 | self.add_argument(
73 | "--use-feat",
74 | action="store_true",
75 | help="Use the feature vector ('z') instead of context vector ('c') for features",
76 | )
77 | self.add_argument("--gpu", help="GPU to use", default=0, type=int)
78 |
79 |
80 | class Prediction:
81 |     """ Lightweight wrapper around a fairseq wav2vec embedding model """
82 |
83 | def __init__(self, fname, gpu=0):
84 | self.gpu = gpu
85 | self.model = PretrainedWav2VecModel(fname).cuda(gpu)
86 |
87 | def __call__(self, x):
88 | x = torch.from_numpy(x).float().cuda(self.gpu)
89 | with torch.no_grad():
90 | z, c = self.model(x.unsqueeze(0))
91 |
92 | return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy()
93 |
94 |
95 | class H5Writer:
96 | """ Write features as hdf5 file in flashlight compatible format """
97 |
98 | def __init__(self, fname):
99 | self.fname = fname
100 | os.makedirs(os.path.dirname(self.fname), exist_ok=True)
101 |
102 | def write(self, data):
103 | channel, T = data.shape
104 |
105 | with h5py.File(self.fname, "w") as out_ds:
106 | data = data.T.flatten()
107 | out_ds["features"] = data
108 | out_ds["info"] = np.array([16e3 // 160, T, channel])
109 |
110 |
111 | class EmbeddingDatasetWriter(object):
112 | """Given a model and a flashlight dataset, pre-compute and store embeddings
113 |
114 | Args:
115 | input_root, str :
116 | Path to the flashlight dataset
117 | output_root, str :
118 | Desired output directory. Will be created if non-existent
119 | split, str :
120 | Dataset split
121 | """
122 |
123 | def __init__(
124 | self,
125 | input_root,
126 | output_root,
127 | split,
128 | model_fname,
129 | extension="wav",
130 | gpu=0,
131 | verbose=False,
132 | use_feat=False,
133 | ):
134 |
135 | assert os.path.exists(model_fname)
136 |
137 | self.model_fname = model_fname
138 | self.model = Prediction(self.model_fname, gpu)
139 |
140 | self.input_root = input_root
141 | self.output_root = output_root
142 | self.split = split
143 | self.verbose = verbose
144 | self.extension = extension
145 | self.use_feat = use_feat
146 |
147 | assert os.path.exists(self.input_path), "Input path '{}' does not exist".format(
148 | self.input_path
149 | )
150 |
151 | def _progress(self, iterable, **kwargs):
152 | if self.verbose:
153 | return tqdm.tqdm(iterable, **kwargs)
154 | return iterable
155 |
156 | def require_output_path(self, fname=None):
157 | path = self.get_output_path(fname)
158 | os.makedirs(path, exist_ok=True)
159 |
160 | @property
161 | def input_path(self):
162 | return self.get_input_path()
163 |
164 | @property
165 | def output_path(self):
166 | return self.get_output_path()
167 |
168 | def get_input_path(self, fname=None):
169 | if fname is None:
170 | return os.path.join(self.input_root, self.split)
171 | return os.path.join(self.get_input_path(), fname)
172 |
173 | def get_output_path(self, fname=None):
174 | if fname is None:
175 | return os.path.join(self.output_root, self.split)
176 | return os.path.join(self.get_output_path(), fname)
177 |
178 | def copy_labels(self):
179 | self.require_output_path()
180 |
181 | labels = list(
182 | filter(
183 | lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))
184 | )
185 | )
186 | for fname in tqdm.tqdm(labels):
187 | copy(fname, self.output_path)
188 |
189 | @property
190 | def input_fnames(self):
191 | return sorted(glob.glob(self.get_input_path("*.{}".format(self.extension))))
192 |
193 | def __len__(self):
194 | return len(self.input_fnames)
195 |
196 | def write_features(self):
197 |
198 | paths = self.input_fnames
199 |
200 | fnames_context = map(
201 | lambda x: os.path.join(
202 | self.output_path, x.replace("." + self.extension, ".h5context")
203 | ),
204 | map(os.path.basename, paths),
205 | )
206 |
207 | for name, target_fname in self._progress(
208 | zip(paths, fnames_context), total=len(self)
209 | ):
210 | wav, sr = read_audio(name)
211 | z, c = self.model(wav)
212 | feat = z if self.use_feat else c
213 | writer = H5Writer(target_fname)
214 | writer.write(feat)
215 |
216 | def __repr__(self):
217 |
218 | return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format(
219 | n_files=len(self), **self.__dict__
220 | )
221 |
222 |
223 | if __name__ == "__main__":
224 |
225 | args = EmbeddingWriterConfig().parse_args()
226 |
227 | for split in args.split:
228 |
229 | writer = EmbeddingDatasetWriter(
230 | input_root=args.input,
231 | output_root=args.output,
232 | split=split,
233 | model_fname=args.model,
234 | gpu=args.gpu,
235 | extension=args.ext,
236 | use_feat=args.use_feat,
237 | )
238 |
239 | print(writer)
240 | writer.require_output_path()
241 |
242 | print("Writing Features...")
243 | writer.write_features()
244 | print("Done.")
245 |
246 | if not args.no_copy_labels:
247 | print("Copying label data...")
248 | writer.copy_labels()
249 | print("Done.")
250 |
--------------------------------------------------------------------------------
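A hedged sketch of reading an `.h5context` file written by H5Writer above: `features` stores the `(T, channel)` matrix flattened row-major, and `info` stores `[16e3 // 160, T, channel]` as written in the code; the path is a placeholder.

```python
import h5py
import numpy as np

with h5py.File("clips/sample.h5context", "r") as f:     # placeholder path
    fps, T, channel = (int(v) for v in f["info"][()])   # info = [16e3 // 160, T, channel]
    feats = np.array(f["features"]).reshape(T, channel)
```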
/wav2vec 2.0/wav2vec_manifest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Data pre-processing: build train/valid .tsv manifests for a directory of audio files.
8 | """
9 |
10 | import argparse
11 | import glob
12 | import os
13 | import random
14 |
15 | import soundfile
16 |
17 |
18 | def get_parser():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument(
21 | "root", metavar="DIR", help="root directory containing flac files to index"
22 | )
23 | parser.add_argument(
24 | "--valid-percent",
25 | default=0.01,
26 | type=float,
27 | metavar="D",
28 | help="percentage of data to use as validation set (between 0 and 1)",
29 | )
30 | parser.add_argument(
31 | "--dest", default=".", type=str, metavar="DIR", help="output directory"
32 | )
33 | parser.add_argument(
34 | "--ext", default="flac", type=str, metavar="EXT", help="extension to look for"
35 | )
36 | parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed")
37 | parser.add_argument(
38 | "--path-must-contain",
39 | default=None,
40 | type=str,
41 | metavar="FRAG",
42 | help="if set, path must contain this substring for a file to be included in the manifest",
43 | )
44 | return parser
45 |
46 |
47 | def main(args):
48 | assert args.valid_percent >= 0 and args.valid_percent <= 1.0
49 |
50 | if not os.path.exists(args.dest):
51 | os.makedirs(args.dest)
52 |
53 | dir_path = os.path.realpath(args.root)
54 | search_path = os.path.join(dir_path, "**/*." + args.ext)
55 | rand = random.Random(args.seed)
56 |
57 | valid_f = (
58 | open(os.path.join(args.dest, "valid.tsv"), "w")
59 | if args.valid_percent > 0
60 | else None
61 | )
62 |
63 | with open(os.path.join(args.dest, "train.tsv"), "w") as train_f:
64 | print(dir_path, file=train_f)
65 |
66 | if valid_f is not None:
67 | print(dir_path, file=valid_f)
68 |
69 | for fname in glob.iglob(search_path, recursive=True):
70 | file_path = os.path.realpath(fname)
71 |
72 | if args.path_must_contain and args.path_must_contain not in file_path:
73 | continue
74 |
75 | frames = soundfile.info(fname).frames
76 | dest = train_f if rand.random() > args.valid_percent else valid_f
77 | print(
78 | "{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest
79 | )
80 | if valid_f is not None:
81 | valid_f.close()
82 |
83 |
84 | if __name__ == "__main__":
85 | parser = get_parser()
86 | args = parser.parse_args()
87 | main(args)
88 |
--------------------------------------------------------------------------------
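The manifests produced above put the root directory on the first line and one `relative/path<TAB>num_frames` entry per audio file; the other scripts in this repository read them back essentially like this (the path is a placeholder):

```python
import os

with open("manifests/train.tsv") as fp:          # placeholder path
    root = next(fp).strip()                      # first line: the root directory
    entries = []
    for line in fp:
        rel_path, frames = line.rstrip().split("\t")
        entries.append((os.path.join(root, rel_path), int(frames)))
```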