├── DeepSpeech └── README.md ├── README.md ├── logo.jpeg ├── speechbrain ├── ASR │ └── CTC │ │ ├── README.md │ │ ├── dvoice_prepare.py │ │ ├── extra_requirements.txt │ │ ├── hparams │ │ ├── train_dar_with_wav2vec.yaml │ │ └── train_sw_with_wav2vec.yaml │ │ └── train_with_wav2vec.py └── dvoice_prepare.py └── wav2vec 2.0 ├── README.md ├── __init__.py ├── config ├── finetuning │ ├── base_100h.yaml │ ├── base_10h.yaml │ ├── base_10m.yaml │ ├── base_1h.yaml │ ├── base_960h.yaml │ ├── vox_100h.yaml │ ├── vox_10h.yaml │ ├── vox_10m.yaml │ ├── vox_1h.yaml │ └── vox_960h.yaml └── pretraining │ ├── wav2vec2_base_librispeech.yaml │ ├── wav2vec2_large_librivox.yaml │ ├── wav2vec2_large_librivox_tpu-pod.yaml │ └── wav2vec2_large_librivox_tpu.yaml ├── libri_labels.py ├── scripts └── binarize_manifest.sh ├── unsupervised ├── README.md ├── __init__.py ├── config │ ├── finetuning │ │ └── w2v_finetune.yaml │ ├── gan │ │ └── w2vu.yaml │ ├── generate │ │ └── viterbi.yaml │ ├── timit_matched │ │ ├── test.uid │ │ ├── train.uid │ │ ├── train_text.uid │ │ └── valid.uid │ └── timit_unmatched │ │ ├── test.uid │ │ ├── train.uid │ │ ├── train_text.uid │ │ └── valid.uid ├── data │ ├── __init__.py │ ├── extracted_features_dataset.py │ └── random_input_dataset.py ├── kaldi_self_train │ ├── README.md │ └── st │ │ ├── cmd.sh │ │ ├── decode_phone.sh │ │ ├── decode_word_step1.sh │ │ ├── decode_word_step2.sh │ │ ├── local │ │ ├── copy_aligned_text.py │ │ ├── decode.sh │ │ ├── prepare_data_from_w2v.py │ │ ├── prepare_lang.sh │ │ ├── prepare_lang_word.sh │ │ ├── prepare_lm.sh │ │ ├── score.sh │ │ ├── show_wer.sh │ │ ├── train_subset_lgbeam.sh │ │ ├── unsup_select.py │ │ ├── unsup_select_decode.sh │ │ └── unsup_select_decode_word.sh │ │ ├── path.sh │ │ ├── steps │ │ ├── steps_gan │ │ ├── train_deltas.sh │ │ ├── train_lda_mllt.sh │ │ └── train_sat.sh │ │ ├── train.sh │ │ └── utils ├── models │ ├── __init__.py │ └── wav2vec_u.py ├── scripts │ ├── apply_pca.py │ ├── copy_labels.py │ ├── filter_lexicon.py │ ├── filter_tsv.py │ ├── g2p_wrd_to_phn.py │ ├── ltr_to_wrd.py │ ├── mean_pool.py │ ├── merge_clusters.py │ ├── normalize_and_filter_text.py │ ├── normalize_text.py │ ├── pca.py │ ├── phonemize_with_sil.py │ ├── prepare_audio.sh │ ├── prepare_text.sh │ ├── prepare_timit.sh │ ├── remove_silence.py │ ├── vads.py │ ├── wav2vec_apply_cluster_faiss.py │ ├── wav2vec_cluster_faiss.py │ ├── wav2vec_extract_features.py │ ├── wer.py │ └── wrd_to_ltr.py ├── tasks │ ├── __init__.py │ └── unpaired_audio_text.py └── w2vu_generate.py ├── vq-wav2vec_featurize.py ├── wav2vec_featurize.py ├── wav2vec_manifest.py └── xlsr_wav2vec2_darija_finetuning.ipynb /DeepSpeech/README.md: -------------------------------------------------------------------------------- 1 | # 1. Description du modèle 2 | [DeepSpeech](https://arxiv.org/abs/1412.5567) est un modèle de reconnaissance vocale open source développé par Mozilla. DeepSpeech est très flexible, s'adapte bien sur les données à l'entraînement grâce à sa capacité à identifier les bruits de fond et au modèle de langage inclut au moment de l'entraînement. 3 | 4 | # 2. Entraîner son propre modèle 5 | Pour entraîner le modèle, il faut avoir les pré-requis suivants: 6 | - Python 3.6. 7 | - Mac or Linux environment. 8 | - CUDA 10.0 / CuDNN v7.6 par Dockerfile. 9 | 10 | Il faut ensuite suivre les étapes suivantes: 11 | 12 | ## 2.1. 
Préparation de l'environnement 13 | ```shell script 14 | # Importer le code du repository 15 | git clone --branch v0.9.3 https://github.com/mozilla/DeepSpeech 16 | 17 | # Créer un environnement virtuel 18 | $ python3 -m venv $HOME/tmp/deepspeech-train-venv/ 19 | $ source $HOME/tmp/deepspeech-train-venv/bin/activate 20 | 21 | # Installer le code pour l'entraînement et toutes ses dépendances 22 | cd DeepSpeech 23 | pip3 install --upgrade pip==20.2.2 wheel==0.34.2 setuptools==49.6.0 24 | pip3 install --upgrade -e . 25 | sudo apt-get install python3-dev 26 | 27 | # Il est recommandé d'installer Tensorflow sur GPU, lorsqu'on dispose d'une machine NVIDIA avec au moins 8G de RAM 28 | pip3 uninstall tensorflow 29 | pip3 install 'tensorflow-gpu==1.15.4' 30 | 31 | # Pour ceux qui veulent entraîner l'environnement sous Docker 32 | make Dockerfile.train 33 | make Dockerfile.train DEEPSPEECH_REPO=git://your/fork DEEPSPEECH_SHA=origin/your-branch 34 | 35 | ``` 36 | 37 | ## 2.2. Charger la base de données 38 | On suppose ici qu'on utilise la base de données de Mozilla Common Voice. 39 | 40 | ```shell script 41 | bin/import_cv2.py --filter_alphabet path/to/some/alphabet.txt /path/to/extracted/language/archive 42 | python3 DeepSpeech.py --train_files ../data/CV/en/clips/train.csv --dev_files ../data/CV/ar/clips/dev.csv --test_files ../data/CV/ar/clips/test.csv 43 | ``` 44 | 45 | ## 2.3. Entraîner le modèle 46 | L'entraînement du modèle est tout simple: 47 | ```shell script 48 | python3 DeepSpeech.py --helpfull 49 | ./bin/run-ldc93s1.sh 50 | ``` 51 | Pour plus de personnalisation à l'entraînement du modèle, on vous recommande de suivre les étapes décrites [ici](https://deepspeech.readthedocs.io/en/r0.9/TRAINING.html). 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | ![logo](logo.jpeg) 3 |

4 | 5 | # 1. A propos 6 | Basée à Rabat, Londres et Paris, AIOX-Labs mobilise les technologies d’intelligence artificielle pour répondre aux besoins métiers et projets data des entreprises. 7 | - Au service de la croissance des groupes, de l’optimisation des processus ou de l’amélioration de l'expérience client. 8 | - AIOX-Labs est multisecteur, de la fintech à l’industrie en passant par le retail et les biens de consommation. 9 | - Des data products business ready avec un socle algorithmique solide et une adaptabilité pour les besoins spécifiques de chaque client. 10 | - Une équipe complémentaire composée de docteurs en IA et d’experts métiers avec une assise scientifique solide et des publications internationales. 11 | 12 | Site web : https://www.aiox-labs.com/ 13 | 14 | Dialectal Voice est un projet communautaire initié par AIOX-Labs pour faciliter la reconnaissance de la voix par les Systèmes Intelligents. Aujourd'hui, le besoin en Systèmes IA capable de reconnaître la voix humaine s'exprime de plus en plus au sein des communautés. On remarque cependant que pour certaines langues comme le Darija, il n'existe pas assez de solutions de technologie vocale. Pour répondre à ce besoin, on s'est proposé alors d'établir ce programme de construction itérative et intéractive d'une base de données dialectale et ouverte à tous afin d'aider à améliorer les modèles de reconnaissance et de génération de la voix. 15 | 16 | Sites web : https://dvoice.ma/, https://dvoice.sn/ 17 | 18 | # 2. Contribuer au projet DVoice 19 | Il existe deux manières de contribuer au projet: 20 | ## 2.1. Soumettre des enregistrements "هضر" 21 | Pour aider à améliorer les technologies vocales sur l’arabe dialectal marocain, il est nécessaire d’avoir une quantité importante de données pour l’entraînement des modèles. Les données se présentent sous forme de « voix + transcription textuelle ». La contribution à l’enrichissement de la dataset se fait par la lecture d’un texte Darija affiché à l’écran suivie de la soumission de l’enregistrement. 22 | 23 | N.B : Au moment de la lecture des textes, veuillez vous assurer que votre micro fonctionne parfaitement et qu’il n’y a pas beaucoup de bruits de fond. 24 | 25 | ## 2.2. Évaluer des enregistrements "سمع" 26 | Une autre manière de contribuer à cette initiative, c’est d’évaluer des enregistrements. Il s’agit ici d’écouter un échantillon « voix + transcription textuelle » et de l’évaluer en fonction de la proximité entre la voix enregistrée et le texte correspondant. Lorsqu’on juge les deux assez proches, on clique sur « Oui » sinon ça sera « Non ». 27 | 28 | # 3. Obtenir la dataset 29 | ## 3.1. DVoice-v1.0 30 | La base de données construite est ouverte au publique. Son acquisition se fait sous deux différentes manières: 31 | - Après souscription sur ce site. On vous l'enverra par suite par mail dans un bref délai. 32 | - En allant sur le répertoire Zenodo depuis ce lien : [Télécharger ici](https://zenodo.org/record/5482551). 33 | 34 | ## 3.2. DVoice-v2.0 35 | Cette deuxième version de la dataset, disponible [ici](https://zenodo.org/record/6342622), intègre les données augmentées, facilement repérables et une dataset swahilie obtenue par transfer learning depuis la dataset multilingue [VoxLingua107](http://bark.phon.ioc.ee/voxlingua107/). 36 | 37 | N.B : Pour être informé lorsqu'une nouvelle mise à jour de la base de données sera disponible, vous pouvez vous souscrire depuis la page Souscription. 38 | 39 | ## 3.2. 
Dataset alternative 40 | L'entraînement d'un modèle de reconnaissance vocale nécessite souvent d'avoir d'une part un ensemble d'enregistrements vocales et d'une autre part leurs transcriptions en texte. L'initiative DVoice vise justement à pourvoir à la communauté ces prérequis afin de mieux faciliter l'adoption des technologies vocales et la recherche autour d'elles. 41 | 42 | Une méthode alternative consiste à faire du transfert learning sur des enregistrements en transcrivant ces dernières en texte. On procède comme suit: 43 | - Récupérer du contenu vidéo public sur Youtube, Facebook,... 44 | - Convertir les vidéos en audios. 45 | - Segmenter les audios obtenus (faire une segmentation par silence). 46 | - Effectuer un pre-processing sur ces données (réduire les bruits, ajouter du silence si nécessaire,...). 47 | - Labelliser les audios avec la bibliothèque [SpeechRecognition](https://pypi.org/project/SpeechRecognition/) 48 | 49 | # 4. Notre vision pour DVoice Africa 50 | La plus grande nouveauté de cette première mise à jour du projet DVoice est certainement l'inclusion d'autres langues africaines, un pas important vers l'objectif ultime de ce projet. Aujourd'hui, le projet comprend, au niveau de la collection de données, le Darija puis, via DVoice Senegal, supporte six langues et dialectes (Wolof, Serere, Diola, Mandingue, Pular, Soninke) parlées au Sénégal et dans plusieurs pays d'Afrique de l'Est. Au niveau modèle, il supporte le Darija et le Swahili. 51 | 52 | Les premières versions des modèles testés ont de très bons scores. Nous encourageons vivement tout le monde à participer à ce programme communautaire en visitant notre site [dvoice.ma](https://dvoice.ma/) mais aussi le tout nouveau [dvoice.sn](https://dvoice.sn/) pour proposer ou valider quelques enregistrements. Celà aidera considérablement à la construction d'une grande base de données vocale africaine. 53 | 54 | De notre part on essaie encore d'aller plus loin en accentuant sur des potentiels paramètres/approches à embarquer et sur la possibilité d'intégrer un modèle de langage qui permettra de mieux affiner les transcriptions. L'idée ensuite est de mettre au point et de facilitier les technologies vocales pour les langues africaines. 55 | -------------------------------------------------------------------------------- /logo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIOXLABS/DVoice/19b760ca93ed9016d406bd9baf077e49bab384ec/logo.jpeg -------------------------------------------------------------------------------- /speechbrain/ASR/CTC/README.md: -------------------------------------------------------------------------------- 1 | # DVoice ASR with CTC based Seq2Seq models. 2 | This folder contains scripts necessary to run an ASR experiment with the DVoice datasets : [Link](https://zenodo.org/record/6342622) 3 | 4 | # Data preparation 5 | [DVoice](https://dvoice.ma) attempts to provide automatic voice processing solutions for African languages and dialects. We use preprocessing techniques including voice augmentation to fill the data gap for each language. 
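As a purely illustrative sketch (the training script normally runs this step for you), the snippet below shows how the recipe's `dvoice_prepare.py` can be invoked on its own to build the train/dev/test CSV files; every `/path/to/...` value is a placeholder to adapt to your local copy of the dataset.

``` python
# Minimal sketch -- all paths below are placeholders.
from dvoice_prepare import prepare_dvoice  # provided alongside this recipe

prepare_dvoice(
    data_folder="/path/to/DVoice/darija",      # must contain a wavs/ sub-folder
    save_folder="results/prepared",            # train.csv / dev.csv / test.csv are written here
    train_csv_file="/path/to/DVoice/darija/texts/train.csv",
    dev_csv_file="/path/to/DVoice/darija/texts/dev.csv",
    test_csv_file="/path/to/DVoice/darija/texts/test.csv",
    accented_letters=True,
    language="dar",   # "sw" for the Swahili recipe
    skip_prep=False,
)
```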
6 | 7 | # How to run 8 | - First, get Speechbrain 9 | 10 | ``` bash 11 | git clone https://github.com/speechbrain/speechbrain 12 | ``` 13 | - Place the Dvoice/speechbrain folder inside speechbrain/recipes folder 14 | 15 | - Go to speechbrain/recipes/DVoice/ASR/CTC the run: 16 | ``` bash 17 | python train.py hparams/{hparam_file}.py 18 | ``` 19 | 20 | 21 | # Languages 22 | Here is a list of the different languages and dialects that we tested within the DVoice dataset and CTC: 23 | - Darija 24 | - Swahili (upcoming soon) 25 | 26 | # Results 27 | 28 | | Language | DVoice Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | HuggingFace link | 29 | | ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:| 30 | | Darija (Moroccan Arabic) | v2.0 | train_dar_with_wav2vec.yaml | No | 5.51 | 18.46 | 5.85 | 18.28 | [Link](https://huggingface.co/nairaxo/dvoice-darija) | 31 | | Swahili | v2.0 | train_sw_with_wav2vec.yaml | No | 8.83 | 22.78 | 9.46 | 23.16 | [Link](https://huggingface.co/nairaxo/dvoice-swahili) | 32 | 33 | 34 | 35 | 36 | ## How to simply use pretrained models to transcribe my audio file? 37 | 38 | SpeechBrain provides a simple interface to transcribe audio files with pretrained models. All the necessary information can be found on the different HuggingFace repositories (see the results table above) corresponding to our different models for DVoice. 39 | 40 | # **About SpeechBrain** 41 | - Website: https://speechbrain.github.io/ 42 | - Code: https://github.com/speechbrain/speechbrain/ 43 | - HuggingFace: https://huggingface.co/speechbrain/ 44 | 45 | 46 | # **Citing SpeechBrain** 47 | Please, cite SpeechBrain if you use it for your research or business. 48 | 49 | ```bibtex 50 | @misc{speechbrain, 51 | title={{SpeechBrain}: A General-Purpose Speech Toolkit}, 52 | author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio}, 53 | year={2021}, 54 | eprint={2106.04624}, 55 | archivePrefix={arXiv}, 56 | primaryClass={eess.AS}, 57 | note={arXiv:2106.04624} 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /speechbrain/ASR/CTC/dvoice_prepare.py: -------------------------------------------------------------------------------- 1 | ../../dvoice_prepare.py 2 | -------------------------------------------------------------------------------- /speechbrain/ASR/CTC/extra_requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.13 2 | -------------------------------------------------------------------------------- /speechbrain/ASR/CTC/hparams/train_dar_with_wav2vec.yaml: -------------------------------------------------------------------------------- 1 | # ################################ 2 | # Model: wav2vec2 + DNN + CTC 3 | # Augmentation: SpecAugment 4 | # Authors: Titouan Parcollet 2021 5 | # ################################ 6 | 7 | # Seed needs to be set at top of yaml, before objects with parameters are made 8 | seed: 1234 9 | __set_seed: !!python/object/apply:torch.manual_seed [!ref ] 10 | output_folder: !ref results/wav2vec2_ctc_DAR/ 11 | wer_file: !ref /wer.txt 12 | 
save_folder: !ref /save 13 | train_log: !ref /train_log.txt 14 | 15 | # URL for the biggest LeBenchmark wav2vec french. 16 | wav2vec2_hub: facebook/wav2vec2-large-xlsr-53 17 | 18 | # Data files 19 | data_folder: #!PLACEHOLDER # e.g, /dataset/ 20 | train_csv_file: !ref /texts/train2.csv # Standard CommonVoice .tsv files 21 | dev_csv_file: !ref /texts/dev2.csv # Standard CommonVoice .tsv files 22 | test_csv_file: !ref /texts/test2.csv # Standard CommonVoice .tsv files 23 | accented_letters: True 24 | language: dar # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english 25 | train_csv: !ref /train.csv 26 | valid_csv: !ref /dev.csv 27 | test_csv: !ref /test.csv 28 | skip_prep: False # Skip data preparation 29 | 30 | # We remove utterance slonger than 10s in the train/dev/test sets as 31 | # longer sentences certainly correspond to "open microphones". 32 | avoid_if_longer_than: 15.0 33 | 34 | # Training parameters 35 | number_of_epochs: 30 36 | number_of_ctc_epochs: 15 37 | lr: 1.0 38 | lr_wav2vec: 0.0001 39 | ctc_weight: 0.3 40 | sorting: ascending 41 | auto_mix_prec: False 42 | sample_rate: 16000 43 | ckpt_interval_minutes: 30 # save checkpoint every N min 44 | 45 | # With data_parallel batch_size is split into N jobs 46 | # With DDP batch_size is multiplied by N jobs 47 | # Must be 6 per GPU to fit 16GB of VRAM 48 | batch_size: 4 49 | test_batch_size: 4 50 | 51 | dataloader_options: 52 | batch_size: !ref 53 | num_workers: 2 54 | test_dataloader_options: 55 | batch_size: !ref 56 | num_workers: 2 57 | 58 | # BPE parameters 59 | token_type: char # ["unigram", "bpe", "char"] 60 | character_coverage: 1.0 61 | 62 | # Model parameters 63 | activation: !name:torch.nn.LeakyReLU 64 | wav2vec_output_dim: 1024 65 | dnn_neurons: 1024 66 | freeze_wav2vec: False 67 | 68 | # Outputs 69 | output_neurons: 36 # BPE size, index(blank/eos/bos) = 0 70 | 71 | # Decoding parameters 72 | # Be sure that the bos and eos index match with the BPEs ones 73 | blank_index: 0 74 | bos_index: 1 75 | eos_index: 2 76 | min_decode_ratio: 0.0 77 | max_decode_ratio: 1.0 78 | beam_size: 80 79 | eos_threshold: 1.5 80 | using_max_attn_shift: True 81 | max_attn_shift: 140 82 | ctc_weight_decode: 0.0 83 | temperature: 1.50 84 | 85 | # 86 | # Functions and classes 87 | # 88 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 89 | limit: !ref 90 | 91 | augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 92 | sample_rate: !ref 93 | speeds: [95, 100, 105] 94 | 95 | enc: !new:speechbrain.nnet.containers.Sequential 96 | input_shape: [null, null, !ref ] 97 | linear1: !name:speechbrain.nnet.linear.Linear 98 | n_neurons: 1024 99 | bias: True 100 | bn1: !name:speechbrain.nnet.normalization.BatchNorm1d 101 | activation: !new:torch.nn.LeakyReLU 102 | drop: !new:torch.nn.Dropout 103 | p: 0.15 104 | linear2: !name:speechbrain.nnet.linear.Linear 105 | n_neurons: 1024 106 | bias: True 107 | bn2: !name:speechbrain.nnet.normalization.BatchNorm1d 108 | activation2: !new:torch.nn.LeakyReLU 109 | drop2: !new:torch.nn.Dropout 110 | p: 0.15 111 | linear3: !name:speechbrain.nnet.linear.Linear 112 | n_neurons: 1024 113 | bias: True 114 | bn3: !name:speechbrain.nnet.normalization.BatchNorm1d 115 | activation3: !new:torch.nn.LeakyReLU 116 | 117 | wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2 118 | source: !ref 119 | output_norm: True 120 | freeze: !ref 121 | save_path: !ref /wav2vec2_checkpoint 122 | 123 | ##### 124 | # Uncomment this block if you prefer to use a Fairseq pretrained model 
instead 125 | # of a HuggingFace one. Here, we provide an URL that is obtained from the 126 | # Fairseq github for the multilingual XLSR. 127 | # 128 | #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt 129 | #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2 130 | # pretrained_path: !ref 131 | # output_norm: True 132 | # freeze: False 133 | # save_path: !ref /wav2vec2_checkpoint/model.pt 134 | ##### 135 | 136 | 137 | ctc_lin: !new:speechbrain.nnet.linear.Linear 138 | input_size: !ref 139 | n_neurons: !ref 140 | 141 | log_softmax: !new:speechbrain.nnet.activations.Softmax 142 | apply_log: True 143 | 144 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 145 | blank_index: !ref 146 | 147 | modules: 148 | wav2vec2: !ref 149 | enc: !ref 150 | ctc_lin: !ref 151 | 152 | model: !new:torch.nn.ModuleList 153 | - [!ref , !ref ] 154 | 155 | model_opt_class: !name:torch.optim.Adadelta 156 | lr: !ref 157 | rho: 0.95 158 | eps: 1.e-8 159 | 160 | wav2vec_opt_class: !name:torch.optim.Adam 161 | lr: !ref 162 | 163 | lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler 164 | initial_value: !ref 165 | improvement_threshold: 0.0025 166 | annealing_factor: 0.8 167 | patient: 0 168 | 169 | lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler 170 | initial_value: !ref 171 | improvement_threshold: 0.0025 172 | annealing_factor: 0.9 173 | patient: 0 174 | 175 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 176 | checkpoints_dir: !ref 177 | recoverables: 178 | wav2vec2: !ref 179 | model: !ref 180 | scheduler_model: !ref 181 | scheduler_wav2vec: !ref 182 | counter: !ref 183 | 184 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 185 | save_file: !ref 186 | 187 | error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 188 | 189 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 190 | split_tokens: True 191 | -------------------------------------------------------------------------------- /speechbrain/ASR/CTC/hparams/train_sw_with_wav2vec.yaml: -------------------------------------------------------------------------------- 1 | # ################################ 2 | # Model: wav2vec2 + DNN + CTC 3 | # Augmentation: SpecAugment 4 | # Authors: Titouan Parcollet 2021 5 | # ################################ 6 | 7 | # Seed needs to be set at top of yaml, before objects with parameters are made 8 | seed: 1234 9 | __set_seed: !!python/object/apply:torch.manual_seed [!ref ] 10 | output_folder: !ref results/wav2vec2_ctc_SW/ 11 | wer_file: !ref /wer.txt 12 | save_folder: !ref /save 13 | train_log: !ref /train_log.txt 14 | 15 | # URL for the biggest LeBenchmark wav2vec french. 16 | wav2vec2_hub: facebook/wav2vec2-large-xlsr-53 17 | 18 | # Data files 19 | data_folder: #!PLACEHOLDER # e.g, /dataset/ 20 | train_csv_file: !ref /texts/train2.csv # Standard CommonVoice .tsv files 21 | dev_csv_file: !ref /texts/dev2.csv # Standard CommonVoice .tsv files 22 | test_csv_file: !ref /texts/test2.csv # Standard CommonVoice .tsv files 23 | accented_letters: True 24 | language: sw # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english 25 | train_csv: !ref /train.csv 26 | valid_csv: !ref /dev.csv 27 | test_csv: !ref /test.csv 28 | skip_prep: False # Skip data preparation 29 | 30 | # We remove utterance slonger than 10s in the train/dev/test sets as 31 | # longer sentences certainly correspond to "open microphones". 
32 | avoid_if_longer_than: 15.0 33 | 34 | # Training parameters 35 | number_of_epochs: 30 36 | number_of_ctc_epochs: 15 37 | lr: 1.0 38 | lr_wav2vec: 0.0001 39 | ctc_weight: 0.3 40 | sorting: ascending 41 | auto_mix_prec: False 42 | sample_rate: 16000 43 | ckpt_interval_minutes: 30 # save checkpoint every N min 44 | 45 | # With data_parallel batch_size is split into N jobs 46 | # With DDP batch_size is multiplied by N jobs 47 | # Must be 6 per GPU to fit 16GB of VRAM 48 | batch_size: 4 49 | test_batch_size: 4 50 | 51 | dataloader_options: 52 | batch_size: !ref 53 | num_workers: 2 54 | test_dataloader_options: 55 | batch_size: !ref 56 | num_workers: 2 57 | 58 | # BPE parameters 59 | token_type: char # ["unigram", "bpe", "char"] 60 | character_coverage: 1.0 61 | 62 | # Model parameters 63 | activation: !name:torch.nn.LeakyReLU 64 | wav2vec_output_dim: 1024 65 | dnn_neurons: 1024 66 | freeze_wav2vec: False 67 | 68 | # Outputs 69 | output_neurons: 38 # BPE size, index(blank/eos/bos) = 0 70 | 71 | # Decoding parameters 72 | # Be sure that the bos and eos index match with the BPEs ones 73 | blank_index: 0 74 | bos_index: 1 75 | eos_index: 2 76 | min_decode_ratio: 0.0 77 | max_decode_ratio: 1.0 78 | beam_size: 80 79 | eos_threshold: 1.5 80 | using_max_attn_shift: True 81 | max_attn_shift: 140 82 | ctc_weight_decode: 0.0 83 | temperature: 1.50 84 | 85 | # 86 | # Functions and classes 87 | # 88 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 89 | limit: !ref 90 | 91 | augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 92 | sample_rate: !ref 93 | speeds: [95, 100, 105] 94 | 95 | enc: !new:speechbrain.nnet.containers.Sequential 96 | input_shape: [null, null, !ref ] 97 | linear1: !name:speechbrain.nnet.linear.Linear 98 | n_neurons: 1024 99 | bias: True 100 | bn1: !name:speechbrain.nnet.normalization.BatchNorm1d 101 | activation: !new:torch.nn.LeakyReLU 102 | drop: !new:torch.nn.Dropout 103 | p: 0.15 104 | linear2: !name:speechbrain.nnet.linear.Linear 105 | n_neurons: 1024 106 | bias: True 107 | bn2: !name:speechbrain.nnet.normalization.BatchNorm1d 108 | activation2: !new:torch.nn.LeakyReLU 109 | drop2: !new:torch.nn.Dropout 110 | p: 0.15 111 | linear3: !name:speechbrain.nnet.linear.Linear 112 | n_neurons: 1024 113 | bias: True 114 | bn3: !name:speechbrain.nnet.normalization.BatchNorm1d 115 | activation3: !new:torch.nn.LeakyReLU 116 | 117 | wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2 118 | source: !ref 119 | output_norm: True 120 | freeze: !ref 121 | save_path: !ref /wav2vec2_checkpoint 122 | 123 | ##### 124 | # Uncomment this block if you prefer to use a Fairseq pretrained model instead 125 | # of a HuggingFace one. Here, we provide an URL that is obtained from the 126 | # Fairseq github for the multilingual XLSR. 
127 | # 128 | #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt 129 | #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2 130 | # pretrained_path: !ref 131 | # output_norm: True 132 | # freeze: False 133 | # save_path: !ref /wav2vec2_checkpoint/model.pt 134 | ##### 135 | 136 | 137 | ctc_lin: !new:speechbrain.nnet.linear.Linear 138 | input_size: !ref 139 | n_neurons: !ref 140 | 141 | log_softmax: !new:speechbrain.nnet.activations.Softmax 142 | apply_log: True 143 | 144 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 145 | blank_index: !ref 146 | 147 | modules: 148 | wav2vec2: !ref 149 | enc: !ref 150 | ctc_lin: !ref 151 | 152 | model: !new:torch.nn.ModuleList 153 | - [!ref , !ref ] 154 | 155 | model_opt_class: !name:torch.optim.Adadelta 156 | lr: !ref 157 | rho: 0.95 158 | eps: 1.e-8 159 | 160 | wav2vec_opt_class: !name:torch.optim.Adam 161 | lr: !ref 162 | 163 | lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler 164 | initial_value: !ref 165 | improvement_threshold: 0.0025 166 | annealing_factor: 0.8 167 | patient: 0 168 | 169 | lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler 170 | initial_value: !ref 171 | improvement_threshold: 0.0025 172 | annealing_factor: 0.9 173 | patient: 0 174 | 175 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 176 | checkpoints_dir: !ref 177 | recoverables: 178 | wav2vec2: !ref 179 | model: !ref 180 | scheduler_model: !ref 181 | scheduler_wav2vec: !ref 182 | counter: !ref 183 | 184 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 185 | save_file: !ref 186 | 187 | error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 188 | 189 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 190 | split_tokens: True 191 | -------------------------------------------------------------------------------- /speechbrain/dvoice_prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data preparation. 
3 | Download: https://dvoice.ma/ 4 | Author 5 | ------ 6 | Abdou Mohamed Naira 7 | """ 8 | 9 | import os 10 | import csv 11 | import re 12 | import logging 13 | import torchaudio 14 | import unicodedata 15 | from tqdm.contrib import tzip 16 | 17 | import torch 18 | import random 19 | import pandas as pd 20 | from tqdm import tqdm 21 | import numpy as np 22 | from speechbrain.dataio.dataio import read_audio 23 | from speechbrain.processing.speech_augmentation import SpeedPerturb 24 | from speechbrain.processing.speech_augmentation import DropChunk 25 | from speechbrain.processing.speech_augmentation import DropFreq 26 | from speechbrain.processing.speech_augmentation import DoClip 27 | from speechbrain.lobes.augment import TimeDomainSpecAugment 28 | 29 | 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | def prepare_dvoice( 34 | data_folder, 35 | save_folder, 36 | train_csv_file=None, 37 | dev_csv_file=None, 38 | test_csv_file=None, 39 | accented_letters=False, 40 | language="dar", 41 | skip_prep=False, 42 | 43 | ): 44 | 45 | if skip_prep: 46 | return 47 | 48 | # If not specified point toward standard location w.r.t CommonVoice tree 49 | if train_csv_file is None: 50 | train_csv_file = data_folder + "texts/train.csv" 51 | else: 52 | train_csv_file = train_csv_file 53 | 54 | if dev_csv_file is None: 55 | dev_csv_file = data_folder + "texts/dev.csv" 56 | else: 57 | dev_csv_file = dev_csv_file 58 | 59 | if test_csv_file is None: 60 | test_csv_file = data_folder + "texts/test.csv" 61 | else: 62 | test_csv_file = test_csv_file 63 | 64 | # Setting the save folder 65 | if not os.path.exists(save_folder): 66 | os.makedirs(save_folder) 67 | 68 | # Setting ouput files 69 | save_csv_train = save_folder + "/train.csv" 70 | save_csv_dev = save_folder + "/dev.csv" 71 | save_csv_test = save_folder + "/test.csv" 72 | 73 | # If csv already exists, we skip the data preparation 74 | if skip(save_csv_train, save_csv_dev, save_csv_test): 75 | 76 | msg = "%s already exists, skipping data preparation!" % (save_csv_train) 77 | logger.info(msg) 78 | 79 | msg = "%s already exists, skipping data preparation!" % (save_csv_dev) 80 | logger.info(msg) 81 | 82 | msg = "%s already exists, skipping data preparation!" 
% (save_csv_test) 83 | logger.info(msg) 84 | 85 | return 86 | 87 | # Additional checks to make sure the data folder contains Common Voice 88 | check_commonvoice_folders(data_folder) 89 | 90 | # Creating csv file for training data 91 | if train_csv_file is not None: 92 | 93 | create_csv( 94 | train_csv_file, 95 | save_csv_train, 96 | data_folder, 97 | accented_letters, 98 | language, 99 | ) 100 | 101 | # Creating csv file for dev data 102 | if dev_csv_file is not None: 103 | 104 | create_csv( 105 | dev_csv_file, 106 | save_csv_dev, 107 | data_folder, 108 | accented_letters, 109 | language, 110 | ) 111 | 112 | # Creating csv file for test data 113 | if test_csv_file is not None: 114 | 115 | create_csv( 116 | test_csv_file, 117 | save_csv_test, 118 | data_folder, 119 | accented_letters, 120 | language, 121 | ) 122 | 123 | 124 | def train_validate_test_split( 125 | df, train_percent=0.6, validate_percent=0.2, seed=None 126 | ): 127 | np.random.seed(seed) 128 | perm = np.random.permutation(df.index) 129 | m = len(df.index) 130 | train_end = int(train_percent * m) 131 | validate_end = int(validate_percent * m) + train_end 132 | train = df.iloc[perm[:train_end]] 133 | validate = df.iloc[perm[train_end:validate_end]] 134 | test = df.iloc[perm[validate_end:]] 135 | return train, validate, test 136 | 137 | 138 | def skip(save_csv_train, save_csv_dev, save_csv_test): 139 | """ 140 | Detects if the DVoice data preparation has been already done. 141 | If the preparation has been done, we can skip it. 142 | Returns 143 | ------- 144 | bool 145 | if True, the preparation phase can be skipped. 146 | if False, it must be done. 147 | """ 148 | 149 | # Checking folders and save options 150 | skip = False 151 | 152 | if ( 153 | os.path.isfile(save_csv_train) 154 | and os.path.isfile(save_csv_dev) 155 | and os.path.isfile(save_csv_test) 156 | ): 157 | skip = True 158 | 159 | return skip 160 | 161 | 162 | def create_csv( 163 | orig_csv_file, csv_file, data_folder, accented_letters=False, language="dar" 164 | ): 165 | """ 166 | Creates the csv file given a list of wav files. 167 | Arguments 168 | --------- 169 | orig_csv_file : str 170 | Path to the DVoice csv file (standard file). 171 | data_folder : str 172 | Path of the DVoice dataset. 173 | accented_letters : bool, optional 174 | Defines if accented letters will be kept as individual letters or 175 | transformed to the closest non-accented letters. 176 | Returns 177 | ------- 178 | None 179 | """ 180 | 181 | # Check if the given files exists 182 | if not os.path.isfile(orig_csv_file): 183 | msg = "\t%s doesn't exist, verify your dataset!" % (orig_csv_file) 184 | logger.info(msg) 185 | raise FileNotFoundError(msg) 186 | 187 | # We load and skip the header 188 | loaded_csv = open(orig_csv_file, "r").readlines()[1:] 189 | nb_samples = str(len(loaded_csv)) 190 | msg = "Preparing CSV files for %s samples ..." % (str(nb_samples)) 191 | logger.info(msg) 192 | 193 | # Adding some Prints 194 | msg = "Creating csv lists in %s ..." % (csv_file) 195 | logger.info(msg) 196 | 197 | csv_lines = [["ID", "duration", "wav", "spk_id", "wrd"]] 198 | 199 | # Start processing lines 200 | total_duration = 0.0 201 | for line in tzip(loaded_csv): 202 | 203 | line = line[0] 204 | # Path is at indice 1 in DVoice csv files. 
And .mp3 files 205 | # are located in datasets/lang/clips/ 206 | 207 | mp3_path = data_folder + "/wavs/" + line.split("\t")[0] 208 | file_name = line.split("\t")[0] 209 | spk_id = line.split("\t")[0].replace(".wav", "") 210 | snt_id = file_name 211 | 212 | # Setting torchaudio backend to sox-io (needed to read mp3 files) 213 | if torchaudio.get_audio_backend() != "sox_io": 214 | logger.warning("This recipe needs the sox-io backend of torchaudio") 215 | logger.warning("The torchaudio backend is changed to sox_io") 216 | torchaudio.set_audio_backend("sox_io") 217 | 218 | duration = float(line.split("\t")[2]) 219 | total_duration += duration 220 | 221 | # Getting transcript 222 | words = line.split("\t")[1] 223 | 224 | # Unicode Normalization 225 | # words = unicode_normalisation(words) 226 | 227 | # !! Language specific cleaning !! 228 | # Important: feel free to specify the text normalization 229 | # corresponding to your alphabet. 230 | 231 | if language == "dar": 232 | HAMZA = "\u0621" 233 | ALEF_MADDA = "\u0622" 234 | ALEF_HAMZA_ABOVE = "\u0623" 235 | letters = ( 236 | "ابتةثجحخدذرزسشصضطظعغفقكلمنهويءآأؤإئ" 237 | + HAMZA 238 | + ALEF_MADDA 239 | + ALEF_HAMZA_ABOVE 240 | ) 241 | words = re.sub("[^" + letters + "]+", " ", words).upper() 242 | 243 | # # Remove accents if specified 244 | # if not accented_letters: 245 | # words = strip_accents(words) 246 | # words = words.replace("'", " ") 247 | # words = words.replace("’", " ") 248 | 249 | # # Remove multiple spaces 250 | # words = re.sub(" +", " ", words) 251 | 252 | # # Remove spaces at the beginning and the end of the sentence 253 | # words = words.lstrip().rstrip() 254 | 255 | # # Getting chars 256 | # chars = words.replace(" ", "_") 257 | # chars = " ".join([char for char in chars][:]) 258 | 259 | # Remove too short sentences (or empty): 260 | # if len(words.split(" ")) < 3: 261 | # continue 262 | 263 | # Composition of the csv_line 264 | csv_line = [snt_id, str(duration), mp3_path, spk_id, str(words)] 265 | 266 | # Adding this line to the csv_lines list 267 | csv_lines.append(csv_line) 268 | 269 | # Writing the csv lines 270 | with open(csv_file, mode="w", encoding="utf-8") as csv_f: 271 | csv_writer = csv.writer( 272 | csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL 273 | ) 274 | 275 | for line in csv_lines: 276 | csv_writer.writerow(line) 277 | 278 | # Final prints 279 | msg = "%s successfully created!" % (csv_file) 280 | logger.info(msg) 281 | msg = "Number of samples: %s " % (str(len(loaded_csv))) 282 | logger.info(msg) 283 | msg = "Total duration: %s Hours" % (str(round(total_duration / 3600, 2))) 284 | logger.info(msg) 285 | 286 | 287 | def check_commonvoice_folders(data_folder): 288 | """ 289 | Check if the data folder actually contains the DVoice dataset. 290 | If not, raises an error. 291 | Returns 292 | ------- 293 | None 294 | Raises 295 | ------ 296 | FileNotFoundError 297 | If data folder doesn't contain DVoice dataset. 
298 | """ 299 | 300 | files_str = "/wavs" 301 | 302 | # Checking clips 303 | if not os.path.exists(data_folder + files_str): 304 | 305 | err_msg = ( 306 | "the folder %s does not exist (it is expected in " 307 | "the DVoice dataset)" % (data_folder + files_str) 308 | ) 309 | raise FileNotFoundError(err_msg) 310 | 311 | 312 | def unicode_normalisation(text): 313 | 314 | try: 315 | text = unicode(text, "utf-8") 316 | except NameError: # unicode is a default on python 3 317 | pass 318 | return str(text) 319 | 320 | 321 | def strip_accents(text): 322 | 323 | text = ( 324 | unicodedata.normalize("NFD", text) 325 | .encode("ascii", "ignore") 326 | .decode("utf-8") 327 | ) 328 | 329 | return str(text) 330 | 331 | -------------------------------------------------------------------------------- /wav2vec 2.0/README.md: -------------------------------------------------------------------------------- 1 | # 1. Description du modèle 2 | Nous proposons dans ce projet les étapes du fine-tuning du modèle XLSR-wav2vec2.0. XLSR-53, pour "Unsupervised Cross-Lingual Representation Learning For Speech Recognition" est une version multilingue du modèle [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et Al., 2020)](https://arxiv.org/abs/2006.11477). Il a été entraîné sur trois bases de données reparties sur 53 langues pour un total de 56.000 heures d'audios. 3 | 4 | Modèle | Architecture | Nombre d'heures | Langues | Datasets | Lien 5 | |---|---|---|---|---|--- 6 | XLSR-53 | Large | 56k | 53 | MLS, CommonVoice, BABEL | [Télécharger](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt) 7 | 8 | XLSR-53 détecte les paramètres sur des données non annotées comme il est décrit sur [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et Al. 2020)](https://arxiv.org/abs/2006.13979). Pour effectuer le fine-tuning sur le Darija et l'Arabe on utilise une Classification Temporelle Connexionniste (CTC) sur des données annotées. CTC est un algorithme utilisé sur des réseaux de neurones pour des problèmes de séquence à séquence, dans notre cas, pour la reconnaissance automatique de la parole. Il permet d'identifier des labels sur les échantillons d'audios d'entraînement. 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 |
Version | Source | Taille | Langue | Speech Augmentation | WER | Durée | Epochs | Demo
|---|---|---|---|---|---|---|---|---
Bêta | Mozilla Common Voice | 1500 | Arabe | Non | 0.7 | 4h | 30 |
Bêta 2 | Facebook / Youtube | 2400 | Darija | Non | 0.9 | 5h | 30 | --------
Version 1.0 | Facebook / Youtube / Dvoice | 13000 | Darija | Oui | 0.3 | 12h | 10 |
63 | 64 | - WER : Word Error Rate 65 | 66 | # 2. Entraînement du modèle "from scratch" 67 | Pour entraîner un modèle XLSR-53 "from scratch", nous recommandons de suivre les étapes décrit [ici](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec). On y retrouve l'ensemble des modèles du projet wav2vec y compris sa version multilingue XLSR-53. 68 | 69 | - PyTorch version >= 1.5.0 70 | - Python version >= 3.6 71 | - Pour entraîner de nouveaux modèles on a besoin d'avoir un GPU NVIDIA et [NCCL](https://github.com/NVIDIA/nccl). 72 | - Installation pour un dévéloppement en local: 73 | 74 | ``` bash 75 | git clone https://github.com/pytorch/fairseq 76 | cd fairseq 77 | pip install --editable ./ 78 | 79 | # on MacOS: 80 | # CFLAGS="-stdlib=libc++" pip install --editable ./ 81 | 82 | # to install the latest stable release (0.10.x) 83 | # pip install fairseq 84 | ``` 85 | Pour un entraînement rapide, installer la librairie [apex](https://github.com/NVIDIA/apex) de NVIDIA: 86 | 87 | ``` bash 88 | git clone https://github.com/NVIDIA/apex 89 | cd apex 90 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \ 91 | --global-option="--deprecated_fused_adam" --global-option="--xentropy" \ 92 | --global-option="--fast_multihead_attn" ./ 93 | ``` 94 | Lorsque la dataset est trop volumineuse, installer [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip) : `pip install pyarrow`. 95 | 96 | #### a. Entrainement depuis une ligne de commandes 97 | Le modèle prend en entrée des fichiers. Il est recommandé de fragmenter les audios en morceaux de 10 à 30 secondes. 98 | - Préparation 99 | 100 | Installer la librairie `soundfile`: 101 | ```shell script 102 | pip install soundfile 103 | ``` 104 | 105 | Ensuite: 106 | 107 | ```shell script 108 | $ python wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid 109 | ``` 110 | $ext : peut être n'importe quel format audio (mp3, flav, wav,...) tant que sounfile peut lire. 111 | 112 | $valid : prendre une petite portion (10% par exemple) des données d'entraînement pour la validation 113 | 114 | #### b. Entraînement du modèle wav2vec2.0 de base 115 | Les audios entrée doivent être à un canal avec une fréquence d'échantillonnage de 16kHz. Celle-ci est la configuration utilisée dans l'article de wav2vec2.0. 116 | 117 | ```shell script 118 | $ fairseq-hydra-train \ 119 | task.data=/path/to/data \ 120 | --config-dir wav2vec/config/pretraining \ 121 | --config-name wav2vec2_base_librispeech 122 | ``` 123 | # 3. Fine-tuning 124 | Pour faire un fine-tuning du modèle, nous recommandons de suivre les étapes listées sur ce Notebook : [Fine-tuning de XLSR-53 sur le Darija](https://github.com/nairaxo/dialectal-voice-clone/blob/main/wav2vec%202.0/xlsr_wav2vec2_darija_finetuning.ipynb). 
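À titre purement indicatif, voici l'allure typique d'une commande de fine-tuning CTC avec fairseq en réutilisant les configurations fournies dans `config/finetuning` (les chemins `/path/to/...` sont des exemples à adapter, et les transcriptions doivent au préalable être converties aux formats `.ltr`/`.wrd`, par exemple avec `libri_labels.py`) :

```shell script
# Esquisse indicative : chemins et nom de configuration à adapter.
# task.data pointe vers les manifestes .tsv et les labels .ltr/.wrd,
# model.w2v_path vers le checkpoint pré-entraîné (ex. xlsr_53_56k.pt).
$ fairseq-hydra-train \
    task.data=/path/to/manifest \
    model.w2v_path=/path/to/xlsr_53_56k.pt \
    distributed_training.distributed_world_size=1 \
    --config-dir wav2vec/config/finetuning \
    --config-name base_100h
```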
125 | -------------------------------------------------------------------------------- /wav2vec 2.0/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIOXLABS/DVoice/19b760ca93ed9016d406bd9baf077e49bab384ec/wav2vec 2.0/__init__.py -------------------------------------------------------------------------------- /wav2vec 2.0/config/finetuning/base_100h.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | no_epoch_checkpoints: true 10 | best_checkpoint_metric: wer 11 | 12 | task: 13 | _name: audio_finetuning 14 | data: ??? 15 | normalize: false 16 | labels: ltr 17 | 18 | dataset: 19 | num_workers: 6 20 | max_tokens: 3200000 21 | skip_invalid_size_inputs_valid_test: true 22 | valid_subset: dev_other 23 | 24 | distributed_training: 25 | ddp_backend: legacy_ddp 26 | distributed_world_size: 2 27 | 28 | criterion: 29 | _name: ctc 30 | zero_infinity: true 31 | 32 | optimization: 33 | max_update: 80000 34 | lr: [0.00003] 35 | sentence_avg: true 36 | update_freq: [4] 37 | 38 | optimizer: 39 | _name: adam 40 | adam_betas: (0.9,0.98) 41 | adam_eps: 1e-08 42 | 43 | lr_scheduler: 44 | _name: tri_stage 45 | phase_ratio: [0.1, 0.4, 0.5] 46 | final_lr_scale: 0.05 47 | 48 | model: 49 | _name: wav2vec_ctc 50 | w2v_path: ??? 51 | apply_mask: true 52 | mask_prob: 0.65 53 | mask_channel_prob: 0.5 54 | mask_channel_length: 64 55 | layerdrop: 0.1 56 | activation_dropout: 0.1 57 | feature_grad_mult: 0.0 58 | freeze_finetune_updates: 0 59 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/finetuning/base_10h.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | save_interval: 50 10 | save_interval_updates: 10000 11 | keep_interval_updates: 1 12 | no_epoch_checkpoints: true 13 | best_checkpoint_metric: wer 14 | 15 | task: 16 | _name: audio_finetuning 17 | data: ??? 18 | normalize: false 19 | labels: ltr 20 | 21 | dataset: 22 | num_workers: 6 23 | max_tokens: 3200000 24 | skip_invalid_size_inputs_valid_test: true 25 | validate_after_updates: 10000 26 | validate_interval: 50 27 | valid_subset: dev_other 28 | 29 | distributed_training: 30 | ddp_backend: legacy_ddp 31 | distributed_world_size: 2 32 | 33 | criterion: 34 | _name: ctc 35 | zero_infinity: true 36 | 37 | optimization: 38 | max_update: 20000 39 | lr: [0.00005] 40 | sentence_avg: true 41 | update_freq: [4] 42 | 43 | optimizer: 44 | _name: adam 45 | adam_betas: (0.9,0.98) 46 | adam_eps: 1e-08 47 | 48 | lr_scheduler: 49 | _name: tri_stage 50 | phase_ratio: [0.1, 0.4, 0.5] 51 | final_lr_scale: 0.05 52 | 53 | model: 54 | _name: wav2vec_ctc 55 | w2v_path: ??? 
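  # Note: "???" marks a mandatory value (OmegaConf/Hydra convention); fields such as task.data
  # and model.w2v_path are typically supplied as command-line overrides to fairseq-hydra-train.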
56 | apply_mask: true 57 | mask_prob: 0.65 58 | mask_channel_prob: 0.5 59 | mask_channel_length: 64 60 | layerdrop: 0.05 61 | activation_dropout: 0.1 62 | feature_grad_mult: 0.0 63 | freeze_finetune_updates: 10000 64 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/finetuning/base_10m.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | save_interval: 1000 10 | save_interval_updates: 50 11 | keep_interval_updates: 1 12 | no_epoch_checkpoints: true 13 | best_checkpoint_metric: wer 14 | 15 | task: 16 | _name: audio_finetuning 17 | data: ??? 18 | normalize: false 19 | labels: ltr 20 | 21 | dataset: 22 | num_workers: 6 23 | max_tokens: 3200000 24 | skip_invalid_size_inputs_valid_test: true 25 | validate_after_updates: 10000 26 | validate_interval: 1000 27 | valid_subset: dev_other 28 | 29 | distributed_training: 30 | ddp_backend: legacy_ddp 31 | distributed_world_size: 2 32 | 33 | criterion: 34 | _name: ctc 35 | zero_infinity: true 36 | 37 | optimization: 38 | max_update: 13000 39 | lr: [0.00005] 40 | sentence_avg: true 41 | update_freq: [4] 42 | 43 | optimizer: 44 | _name: adam 45 | adam_betas: (0.9,0.98) 46 | adam_eps: 1e-08 47 | 48 | lr_scheduler: 49 | _name: tri_stage 50 | phase_ratio: [0.1, 0.4, 0.5] 51 | final_lr_scale: 0.05 52 | 53 | model: 54 | _name: wav2vec_ctc 55 | w2v_path: ??? 56 | apply_mask: true 57 | mask_prob: 0.65 58 | mask_channel_prob: 0.25 59 | mask_channel_length: 64 60 | layerdrop: 0.1 61 | activation_dropout: 0.1 62 | feature_grad_mult: 0.0 63 | freeze_finetune_updates: 10000 64 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/finetuning/base_1h.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | save_interval: 1000 10 | save_interval_updates: 50 11 | keep_interval_updates: 1 12 | no_epoch_checkpoints: true 13 | best_checkpoint_metric: wer 14 | 15 | task: 16 | _name: audio_finetuning 17 | data: ??? 18 | normalize: false 19 | labels: ltr 20 | 21 | dataset: 22 | num_workers: 6 23 | max_tokens: 3200000 24 | skip_invalid_size_inputs_valid_test: true 25 | validate_after_updates: 10000 26 | validate_interval: 1000 27 | valid_subset: dev_other 28 | 29 | distributed_training: 30 | ddp_backend: legacy_ddp 31 | distributed_world_size: 2 32 | 33 | criterion: 34 | _name: ctc 35 | zero_infinity: true 36 | 37 | optimization: 38 | max_update: 13000 39 | lr: [0.00005] 40 | sentence_avg: true 41 | update_freq: [4] 42 | 43 | optimizer: 44 | _name: adam 45 | adam_betas: (0.9,0.98) 46 | adam_eps: 1e-08 47 | 48 | lr_scheduler: 49 | _name: tri_stage 50 | phase_ratio: [0.1, 0.4, 0.5] 51 | final_lr_scale: 0.05 52 | 53 | model: 54 | _name: wav2vec_ctc 55 | w2v_path: ??? 
56 | apply_mask: true 57 | mask_prob: 0.65 58 | mask_channel_prob: 0.25 59 | mask_channel_length: 64 60 | layerdrop: 0.1 61 | activation_dropout: 0.1 62 | feature_grad_mult: 0.0 63 | freeze_finetune_updates: 10000 64 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/finetuning/base_960h.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | no_epoch_checkpoints: true 10 | best_checkpoint_metric: wer 11 | 12 | task: 13 | _name: audio_finetuning 14 | data: ??? 15 | normalize: false 16 | labels: ltr 17 | 18 | dataset: 19 | num_workers: 6 20 | max_tokens: 3200000 21 | skip_invalid_size_inputs_valid_test: true 22 | valid_subset: dev_other 23 | 24 | distributed_training: 25 | ddp_backend: legacy_ddp 26 | distributed_world_size: 8 27 | 28 | criterion: 29 | _name: ctc 30 | zero_infinity: true 31 | 32 | optimization: 33 | max_update: 320000 34 | lr: [0.0001] 35 | sentence_avg: true 36 | 37 | optimizer: 38 | _name: adam 39 | adam_betas: (0.9,0.98) 40 | adam_eps: 1e-08 41 | 42 | lr_scheduler: 43 | _name: tri_stage 44 | phase_ratio: [0.1, 0.4, 0.5] 45 | final_lr_scale: 0.05 46 | 47 | model: 48 | _name: wav2vec_ctc 49 | w2v_path: ??? 50 | apply_mask: true 51 | mask_prob: 0.5 52 | mask_channel_prob: 0.1 53 | mask_channel_length: 64 54 | layerdrop: 0.1 55 | activation_dropout: 0.1 56 | feature_grad_mult: 0.0 57 | freeze_finetune_updates: 0 58 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/finetuning/vox_100h.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | no_epoch_checkpoints: true 10 | best_checkpoint_metric: wer 11 | 12 | task: 13 | _name: audio_finetuning 14 | data: ??? 15 | normalize: true 16 | labels: ltr 17 | 18 | dataset: 19 | num_workers: 6 20 | max_tokens: 1280000 21 | skip_invalid_size_inputs_valid_test: true 22 | valid_subset: dev_other 23 | 24 | distributed_training: 25 | ddp_backend: legacy_ddp 26 | distributed_world_size: 4 27 | 28 | criterion: 29 | _name: ctc 30 | zero_infinity: true 31 | 32 | optimization: 33 | max_update: 80000 34 | lr: [0.00003] 35 | sentence_avg: true 36 | update_freq: [5] 37 | 38 | optimizer: 39 | _name: adam 40 | adam_betas: (0.9,0.98) 41 | adam_eps: 1e-08 42 | 43 | lr_scheduler: 44 | _name: tri_stage 45 | phase_ratio: [0.1, 0.4, 0.5] 46 | final_lr_scale: 0.05 47 | 48 | model: 49 | _name: wav2vec_ctc 50 | w2v_path: ??? 51 | apply_mask: true 52 | mask_prob: 0.5 53 | mask_channel_prob: 0.5 54 | mask_channel_length: 64 55 | layerdrop: 0.1 56 | activation_dropout: 0.1 57 | feature_grad_mult: 0.0 58 | freeze_finetune_updates: 10000 59 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/finetuning/vox_10h.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | save_interval: 50 10 | save_interval_updates: 10000 11 | keep_interval_updates: 1 12 | no_epoch_checkpoints: true 13 | best_checkpoint_metric: wer 14 | 15 | task: 16 | _name: audio_finetuning 17 | data: ??? 
18 | normalize: true 19 | labels: ltr 20 | 21 | dataset: 22 | num_workers: 6 23 | max_tokens: 1280000 24 | skip_invalid_size_inputs_valid_test: true 25 | validate_after_updates: 10000 26 | validate_interval: 50 27 | valid_subset: dev_other 28 | 29 | distributed_training: 30 | ddp_backend: legacy_ddp 31 | distributed_world_size: 4 32 | 33 | criterion: 34 | _name: ctc 35 | zero_infinity: true 36 | 37 | optimization: 38 | max_update: 20000 39 | lr: [0.0001] 40 | sentence_avg: true 41 | update_freq: [5] 42 | 43 | optimizer: 44 | _name: adam 45 | adam_betas: (0.9,0.98) 46 | adam_eps: 1e-08 47 | 48 | lr_scheduler: 49 | _name: tri_stage 50 | phase_ratio: [0.1, 0.4, 0.5] 51 | final_lr_scale: 0.05 52 | 53 | model: 54 | _name: wav2vec_ctc 55 | w2v_path: ??? 56 | apply_mask: true 57 | mask_prob: 0.75 58 | mask_channel_prob: 0.25 59 | mask_channel_length: 64 60 | layerdrop: 0.1 61 | activation_dropout: 0.1 62 | feature_grad_mult: 0.0 63 | freeze_finetune_updates: 10000 64 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/finetuning/vox_10m.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | save_interval: 1000 10 | save_interval_updates: 50 11 | keep_interval_updates: 1 12 | no_epoch_checkpoints: true 13 | best_checkpoint_metric: wer 14 | 15 | task: 16 | _name: audio_finetuning 17 | data: ??? 18 | normalize: true 19 | labels: ltr 20 | 21 | dataset: 22 | num_workers: 6 23 | max_tokens: 1280000 24 | skip_invalid_size_inputs_valid_test: true 25 | validate_after_updates: 10000 26 | validate_interval: 1000 27 | valid_subset: dev_other 28 | 29 | distributed_training: 30 | ddp_backend: legacy_ddp 31 | distributed_world_size: 4 32 | 33 | criterion: 34 | _name: ctc 35 | zero_infinity: true 36 | 37 | optimization: 38 | max_update: 13000 39 | lr: [0.0001] 40 | sentence_avg: true 41 | update_freq: [5] 42 | 43 | optimizer: 44 | _name: adam 45 | adam_betas: (0.9,0.98) 46 | adam_eps: 1e-08 47 | 48 | lr_scheduler: 49 | _name: tri_stage 50 | phase_ratio: [0.1, 0.4, 0.5] 51 | final_lr_scale: 0.05 52 | 53 | model: 54 | _name: wav2vec_ctc 55 | w2v_path: ??? 56 | apply_mask: true 57 | mask_prob: 0.65 58 | mask_channel_prob: 0.25 59 | mask_channel_length: 64 60 | layerdrop: 0.1 61 | activation_dropout: 0.1 62 | feature_grad_mult: 0.0 63 | freeze_finetune_updates: 10000 64 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/finetuning/vox_1h.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | save_interval: 1000 10 | save_interval_updates: 50 11 | keep_interval_updates: 1 12 | no_epoch_checkpoints: true 13 | best_checkpoint_metric: wer 14 | 15 | task: 16 | _name: audio_finetuning 17 | data: ??? 
18 | normalize: true 19 | labels: ltr 20 | 21 | dataset: 22 | num_workers: 6 23 | max_tokens: 1280000 24 | skip_invalid_size_inputs_valid_test: true 25 | validate_after_updates: 10000 26 | validate_interval: 1000 27 | valid_subset: dev_other 28 | 29 | distributed_training: 30 | ddp_backend: legacy_ddp 31 | distributed_world_size: 4 32 | 33 | criterion: 34 | _name: ctc 35 | zero_infinity: true 36 | 37 | optimization: 38 | max_update: 13000 39 | lr: [0.0003] 40 | sentence_avg: true 41 | update_freq: [5] 42 | 43 | optimizer: 44 | _name: adam 45 | adam_betas: (0.9,0.98) 46 | adam_eps: 1e-08 47 | 48 | lr_scheduler: 49 | _name: tri_stage 50 | phase_ratio: [0.1, 0.4, 0.5] 51 | final_lr_scale: 0.05 52 | 53 | model: 54 | _name: wav2vec_ctc 55 | w2v_path: ??? 56 | apply_mask: true 57 | mask_prob: 0.75 58 | mask_channel_prob: 0.25 59 | mask_channel_length: 64 60 | layerdrop: 0.1 61 | activation_dropout: 0.1 62 | feature_grad_mult: 0.0 63 | freeze_finetune_updates: 10000 64 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/finetuning/vox_960h.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | no_epoch_checkpoints: true 10 | best_checkpoint_metric: wer 11 | 12 | task: 13 | _name: audio_finetuning 14 | data: ??? 15 | normalize: true 16 | labels: ltr 17 | 18 | dataset: 19 | num_workers: 6 20 | max_tokens: 1280000 21 | skip_invalid_size_inputs_valid_test: true 22 | valid_subset: dev_other 23 | 24 | distributed_training: 25 | ddp_backend: legacy_ddp 26 | distributed_world_size: 24 27 | 28 | criterion: 29 | _name: ctc 30 | zero_infinity: true 31 | 32 | optimization: 33 | max_update: 320000 34 | lr: [0.00003] 35 | sentence_avg: true 36 | 37 | optimizer: 38 | _name: adam 39 | adam_betas: (0.9,0.98) 40 | adam_eps: 1e-08 41 | 42 | lr_scheduler: 43 | _name: tri_stage 44 | phase_ratio: [0.1, 0.4, 0.5] 45 | final_lr_scale: 0.05 46 | 47 | model: 48 | _name: wav2vec_ctc 49 | w2v_path: ??? 50 | apply_mask: true 51 | mask_prob: 0.5 52 | mask_channel_prob: 0.25 53 | mask_channel_length: 64 54 | layerdrop: 0.1 55 | activation_dropout: 0.1 56 | feature_grad_mult: 0.0 57 | freeze_finetune_updates: 10000 58 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/pretraining/wav2vec2_base_librispeech.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | save_interval_updates: 25000 10 | keep_interval_updates: 1 11 | no_epoch_checkpoints: true 12 | 13 | task: 14 | _name: audio_pretraining 15 | data: ??? 
16 | max_sample_size: 250000 17 | min_sample_size: 32000 18 | normalize: false 19 | 20 | dataset: 21 | num_workers: 6 22 | max_tokens: 1400000 23 | skip_invalid_size_inputs_valid_test: true 24 | 25 | distributed_training: 26 | distributed_world_size: 64 27 | ddp_backend: legacy_ddp 28 | 29 | criterion: 30 | _name: wav2vec 31 | infonce: true 32 | log_keys: ["prob_perplexity","code_perplexity","temp"] 33 | loss_weights: [0.1, 10] 34 | 35 | optimization: 36 | max_update: 400000 37 | lr: [0.0005] 38 | 39 | optimizer: 40 | _name: adam 41 | adam_betas: (0.9,0.98) 42 | adam_eps: 1e-06 43 | weight_decay: 0.01 44 | 45 | lr_scheduler: 46 | _name: polynomial_decay 47 | warmup_updates: 32000 48 | 49 | model: 50 | _name: wav2vec2 51 | quantize_targets: true 52 | final_dim: 256 53 | encoder_layerdrop: 0.05 54 | dropout_input: 0.1 55 | dropout_features: 0.1 56 | feature_grad_mult: 0.1 57 | encoder_embed_dim: 768 58 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/pretraining/wav2vec2_large_librivox.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | save_interval_updates: 25000 10 | keep_interval_updates: 1 11 | no_epoch_checkpoints: true 12 | 13 | task: 14 | _name: audio_pretraining 15 | data: ??? 16 | max_sample_size: 320000 17 | min_sample_size: 32000 18 | normalize: true 19 | 20 | dataset: 21 | batch_size: 4 22 | num_workers: 6 23 | max_tokens: 1200000 24 | skip_invalid_size_inputs_valid_test: true 25 | 26 | distributed_training: 27 | distributed_world_size: 128 28 | ddp_backend: legacy_ddp 29 | 30 | criterion: 31 | _name: wav2vec 32 | infonce: true 33 | log_keys: ["prob_perplexity","code_perplexity","temp"] 34 | loss_weights: [0.1, 0] 35 | 36 | optimization: 37 | max_update: 1000000 38 | lr: [0.005] 39 | 40 | optimizer: 41 | _name: adam 42 | adam_betas: (0.9,0.98) 43 | adam_eps: 1e-06 44 | weight_decay: 0.01 45 | 46 | lr_scheduler: 47 | _name: polynomial_decay 48 | warmup_updates: 32000 49 | 50 | model: 51 | _name: wav2vec2 52 | quantize_targets: true 53 | extractor_mode: layer_norm 54 | layer_norm_first: true 55 | final_dim: 768 56 | latent_temp: [2.0,0.1,0.999995] 57 | encoder_layerdrop: 0.00 58 | dropout_input: 0.0 59 | dropout_features: 0.0 60 | dropout: 0.0 61 | attention_dropout: 0.0 62 | conv_bias: true 63 | 64 | encoder_layers: 24 65 | encoder_embed_dim: 1024 66 | encoder_ffn_embed_dim: 4096 67 | encoder_attention_heads: 16 68 | 69 | feature_grad_mult: 1.0 70 | 71 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | tpu: true 5 | fp16: false 6 | log_format: json 7 | log_interval: 10 8 | 9 | checkpoint: 10 | save_interval_updates: 25000 11 | keep_interval_updates: 1 12 | no_epoch_checkpoints: true 13 | 14 | task: 15 | _name: audio_pretraining 16 | data: ??? 
17 | max_sample_size: 250000 18 | min_sample_size: 32000 19 | normalize: true 20 | num_batch_buckets: 3 21 | precompute_mask_indices: true 22 | enable_padding: true 23 | 24 | dataset: 25 | num_workers: 6 26 | max_tokens: 1200000 27 | skip_invalid_size_inputs_valid_test: true 28 | 29 | distributed_training: 30 | distributed_world_size: 128 31 | ddp_backend: legacy_ddp 32 | 33 | criterion: 34 | _name: wav2vec 35 | infonce: true 36 | log_keys: ["prob_perplexity","code_perplexity","temp"] 37 | loss_weights: [0.1, 0] 38 | 39 | optimization: 40 | max_update: 1000000 41 | lr: [0.005] 42 | 43 | optimizer: 44 | _name: adam 45 | adam_betas: (0.9,0.98) 46 | adam_eps: 1e-06 47 | weight_decay: 0.01 48 | 49 | lr_scheduler: 50 | _name: polynomial_decay 51 | warmup_updates: 32000 52 | 53 | model: 54 | _name: wav2vec2 55 | quantize_targets: true 56 | extractor_mode: layer_norm 57 | layer_norm_first: true 58 | final_dim: 768 59 | latent_temp: [2.0,0.1,0.999995] 60 | encoder_layerdrop: 0.00 61 | dropout_input: 0.0 62 | dropout_features: 0.0 63 | dropout: 0.0 64 | attention_dropout: 0.0 65 | conv_bias: true 66 | 67 | encoder_layers: 24 68 | encoder_embed_dim: 1024 69 | encoder_ffn_embed_dim: 4096 70 | encoder_attention_heads: 16 71 | 72 | feature_grad_mult: 1.0 73 | -------------------------------------------------------------------------------- /wav2vec 2.0/config/pretraining/wav2vec2_large_librivox_tpu.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | tpu: true 5 | fp16: false 6 | log_format: json 7 | log_interval: 10 8 | 9 | checkpoint: 10 | save_interval_updates: 25000 11 | keep_interval_updates: 1 12 | no_epoch_checkpoints: true 13 | 14 | task: 15 | _name: audio_pretraining 16 | data: ??? 17 | max_sample_size: 250000 18 | min_sample_size: 32000 19 | normalize: true 20 | num_batch_buckets: 3 21 | precompute_mask_indices: true 22 | enable_padding: true 23 | inferred_w2v_config: 24 | mask_prob: 0.65 25 | mask_selection: 'static' 26 | mask_other: 0 27 | mask_channel_prob: 0.1 28 | 29 | dataset: 30 | num_workers: 6 31 | max_tokens: 1200000 32 | skip_invalid_size_inputs_valid_test: true 33 | 34 | distributed_training: 35 | distributed_world_size: 8 36 | ddp_backend: legacy_ddp 37 | 38 | criterion: 39 | _name: wav2vec 40 | infonce: true 41 | log_keys: ["prob_perplexity","code_perplexity","temp"] 42 | loss_weights: [0.1, 0] 43 | 44 | optimization: 45 | max_update: 1000000 46 | lr: [0.005] 47 | 48 | optimizer: 49 | _name: adam 50 | adam_betas: (0.9,0.98) 51 | adam_eps: 1e-06 52 | weight_decay: 0.01 53 | 54 | lr_scheduler: 55 | _name: polynomial_decay 56 | warmup_updates: 32000 57 | 58 | model: 59 | _name: wav2vec2 60 | quantize_targets: true 61 | extractor_mode: layer_norm 62 | layer_norm_first: true 63 | final_dim: 768 64 | latent_temp: [2.0,0.1,0.999995] 65 | encoder_layerdrop: 0.00 66 | dropout_input: 0.0 67 | dropout_features: 0.0 68 | dropout: 0.0 69 | attention_dropout: 0.0 70 | conv_bias: true 71 | 72 | encoder_layers: 24 73 | encoder_embed_dim: 1024 74 | encoder_ffn_embed_dim: 4096 75 | encoder_attention_heads: 16 76 | 77 | feature_grad_mult: 1.0 78 | -------------------------------------------------------------------------------- /wav2vec 2.0/libri_labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset 9 | """ 10 | 11 | import argparse 12 | import os 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("tsv") 18 | parser.add_argument("--output-dir", required=True) 19 | parser.add_argument("--output-name", required=True) 20 | args = parser.parse_args() 21 | 22 | os.makedirs(args.output_dir, exist_ok=True) 23 | 24 | transcriptions = {} 25 | 26 | with open(args.tsv, "r") as tsv, open( 27 | os.path.join(args.output_dir, args.output_name + ".ltr"), "w" 28 | ) as ltr_out, open( 29 | os.path.join(args.output_dir, args.output_name + ".wrd"), "w" 30 | ) as wrd_out: 31 | root = next(tsv).strip() 32 | for line in tsv: 33 | line = line.strip() 34 | dir = os.path.dirname(line) 35 | if dir not in transcriptions: 36 | parts = dir.split(os.path.sep) 37 | trans_path = f"{parts[-2]}-{parts[-1]}.trans.txt" 38 | path = os.path.join(root, dir, trans_path) 39 | assert os.path.exists(path) 40 | texts = {} 41 | with open(path, "r") as trans_f: 42 | for tline in trans_f: 43 | items = tline.strip().split() 44 | texts[items[0]] = " ".join(items[1:]) 45 | transcriptions[dir] = texts 46 | part = os.path.basename(line).split(".")[0] 47 | assert part in transcriptions[dir] 48 | print(transcriptions[dir][part], file=wrd_out) 49 | print( 50 | " ".join(list(transcriptions[dir][part].replace(" ", "|"))) + " |", 51 | file=ltr_out, 52 | ) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /wav2vec 2.0/scripts/binarize_manifest.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # usage: bash binarize_manifest 4 | 5 | DEST_DIR=$1 6 | TRAIN_SPLIT=$2 7 | VALID_SPLIT=$3 8 | FAIRSEQ_ROOT=$4 9 | 10 | mkdir -p $DEST_DIR 11 | 12 | # split file path and lengths into separate files 13 | cut -f1 $TRAIN_SPLIT.tsv > $DEST_DIR/train_fnames.txt 14 | cut -f1 $VALID_SPLIT.tsv > $DEST_DIR/valid_fnames.txt 15 | cut -f2 $TRAIN_SPLIT.tsv > $DEST_DIR/train.lengths 16 | cut -f2 $VALID_SPLIT.tsv > $DEST_DIR/valid.lengths 17 | 18 | # copy root directory 19 | head -1 $TRAIN_SPLIT.tsv > $DEST_DIR/train.root 20 | head -1 $VALID_SPLIT.tsv > $DEST_DIR/valid.root 21 | 22 | # remove root directory 23 | sed -i '1d' $DEST_DIR/train_fnames.txt 24 | sed -i '1d' $DEST_DIR/valid_fnames.txt 25 | sed -i '1d' $DEST_DIR/train.lengths 26 | sed -i '1d' $DEST_DIR/valid.lengths 27 | 28 | # insert spaces between characters 29 | sed -i -e 's/\(.\)/\1 /g' $DEST_DIR/train_fnames.txt 30 | sed -i -e 's/\(.\)/\1 /g' $DEST_DIR/valid_fnames.txt 31 | 32 | # run preprocessor 33 | PYTHONPATH=$FAIRSEQ_ROOT python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $DEST_DIR/train_fnames.txt --validpref $DEST_DIR/valid_fnames.txt --workers 60 --only-source --destdir $DEST_DIR 34 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/README.md: -------------------------------------------------------------------------------- 1 | # wav2vec Unsupervised (wav2vec-U) 2 | 3 | Wav2vec Unsupervised (wav2vec-U) is a framework for building speech recognition systems without any labeled training data as described in [Unsupervised Speech Recognition (Baevski et al., 
2021)](https://ai.facebook.com/research/publications/unsupervised-speech-recognition). The model takes as input wav2vec 2.0 or XLSR representations (see [pretrained models](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec)) as well as unlabeled speech and text data. 4 | 5 | The wav2vec-U training procedure consists of three consecutive main steps: 6 | * Preparation of speech representations and text data 7 | * Generative adversarial training (GAN) 8 | * Iterative self-training + Kaldi LM-decoding 9 | 10 | ## Preparation of speech and text data 11 | Similar to [wav2vec 2.0](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md), data folders contain {train,valid,test}.{tsv,wrd,phn} files, where audio paths are stored in tsv files, and word, letter or phoneme transcriptions are stored in .{wrd,ltr,phn}. 12 | 13 | In **/path/to/data/with_silence** you need a *train.tsv* file as well as (optionally) *{valid,test}.{tsv,wrd,phn}*. It is also useful to have *10h.{tsv,phn}* files there for reproducing the ablation study on layer selection. In **/path/to/data/without_silence** you have the same files, except that the *.tsv* files contain audio with silences removed using rVAD. 14 | 15 | Pre-requisites: 16 | * set the FAIRSEQ_ROOT environment variable to your fairseq installation 17 | * set the RVAD_ROOT environment variable to a checkout of [rVADfast](https://github.com/zhenghuatan/rVADfast) 18 | * set the KENLM_ROOT environment variable to the location of the [KenLM](https://github.com/kpu/kenlm) binaries 19 | * install [PyKaldi](https://github.com/pykaldi/pykaldi) and set the KALDI_ROOT environment variable to the location of your kaldi installation. To use the version bundled with PyKaldi, you can use /path/to/pykaldi/tools/kaldi 20 | 21 | Create new audio files without silences: 22 | ```shell 23 | # create a manifest file for the original set of audio files 24 | python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0 25 | 26 | python scripts/vads.py -r $RVAD_ROOT < /path/to/train.tsv > train.vads 27 | 28 | python scripts/remove_silence.py --tsv /path/to/train.tsv --vads train.vads --out /dir/to/save/audio/files 29 | 30 | python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0.01 31 | ``` 32 | 33 | Next, we need to preprocess the audio data to better match the phonemized text data: 34 | 35 | ```shell 36 | zsh scripts/prepare_audio.sh /dir/with/{train,test,valid}.tsv /output/dir /path/to/wav2vec2/model.pt 512 14 37 | ``` 38 | Note that if you have splits other than train/valid/test, you will need to modify this script. The last two arguments are the PCA dimensionality and the 0-based index of the layer from which to extract representations. 39 | 40 | Now we need to prepare the text data: 41 | ```shell 42 | zsh scripts/prepare_text.sh language /path/to/text/file /output/dir 1000 espeak /path/to/fasttext/lid/model 43 | ``` 44 | 45 | The fourth argument is the minimum number of phone observations to keep. If your text corpus is small, you might want to reduce this number. 46 | 47 | The fifth argument is which phonemizer to use. Supported values are [espeak](http://espeak.sourceforge.net/), [espeak-ng](https://github.com/espeak-ng/espeak-ng), and [G2P](https://github.com/Kyubyong/g2p) (English only). 48 | 49 | Pre-trained fasttext LID models can be downloaded [here](https://fasttext.cc/docs/en/language-identification.html). 
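As a quick sanity check on the manifests used throughout this section, the short sketch below (an illustrative helper, not a script shipped in this repository) parses a *.tsv* manifest, assuming the usual wav2vec layout of a root directory on the first line followed by `relative_path<TAB>num_samples` rows, and reports how many clips it lists and roughly how many hours of audio they cover at an assumed 16 kHz sample rate.

```python
# check_manifest.py -- illustrative sketch, not part of this repository.
# Assumes the wav2vec manifest layout: first line = root directory,
# following lines = "<relative_path>\t<num_samples>".
import sys

SAMPLE_RATE = 16_000  # assumed; adjust if your audio uses a different rate


def summarize_manifest(tsv_path: str) -> None:
    with open(tsv_path) as f:
        root = next(f).strip()
        n_clips = 0
        total_samples = 0
        for line in f:
            path, num_samples = line.rstrip().split("\t")
            n_clips += 1
            total_samples += int(num_samples)
    hours = total_samples / SAMPLE_RATE / 3600
    print(f"root: {root}")
    print(f"{n_clips} clips, ~{hours:.2f} hours of audio")


if __name__ == "__main__":
    summarize_manifest(sys.argv[1])
```

Running it on *train.tsv* before and after the silence-removal step is a cheap way to confirm that the regenerated manifest still covers the expected amount of audio.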
50 | 51 | ### Prepare TIMIT data 52 | TIMIT transcripts include silence. Therefore VAD is not used for audio preprocessing, and we do not wrap transcripts with silences or insert random silence in between words. 53 | 54 | To prepare TIMIT data for both the matched and unmatched setups: 55 | ```shell 56 | bash scripts/prepare_timit.sh /dir/to/timit/raw/data /output/dir /path/to/wav2vec2/model.pt 57 | ``` 58 | 59 | Note that we assume the TIMIT distribution with capitalized directories and filenames is used (e.g., `TRAIN/DR1/FCJF0/SA1.PHN`). 60 | 61 | ## Generative adversarial training (GAN) 62 | 63 | We then use a GAN model to build a first unsupervised ASR model. The data preparation above, covering both speech features and text data, is necessary because it enables the generator to match speech to text in an unsupervised way. 64 | 65 | Launching GAN training on top of the preprocessed features, with default hyperparameters, can be done with: 66 | 67 | ``` 68 | PREFIX=w2v_unsup_gan_xp 69 | TASK_DATA=/path/to/features/precompute_unfiltered_pca512_cls128_mean_pooled 70 | TEXT_DATA=/path/to/data/phones # path to fairseq-preprocessed GAN data (phones dir) 71 | KENLM_PATH=/path/to/data/phones/kenlm.phn.o4.bin # KenLM 4-gram phoneme language model (LM data = GAN data here) 72 | 73 | PYTHONPATH=$FAIRSEQ_ROOT PREFIX=$PREFIX fairseq-hydra-train \ 74 | -m --config-dir config/gan \ 75 | --config-name w2vu \ 76 | task.data=${TASK_DATA} \ 77 | task.text_data=${TEXT_DATA} \ 78 | task.kenlm_path=${KENLM_PATH} \ 79 | common.user_dir=${FAIRSEQ_ROOT}/examples/wav2vec/unsupervised \ 80 | model.code_penalty=2,4 model.gradient_penalty=1.5,2.0 \ 81 | model.smoothness_weight=0.5,0.75,1.0 'common.seed=range(0,5)' 82 | ``` 83 | 84 | 85 | Once we find the best checkpoint (chosen using an unsupervised metric that combines language model perplexity and vocabulary usage), we can use it to generate phone labels (or word labels with an appropriate kaldi WFST): 86 | 87 | ```shell 88 | python w2vu_generate.py --config-dir config/generate --config-name viterbi \ 89 | fairseq.common.user_dir=${FAIRSEQ_ROOT}/examples/wav2vec/unsupervised \ 90 | fairseq.task.data=/path/to/dir/with/features \ 91 | fairseq.common_eval.path=/path/to/gan/checkpoint \ 92 | fairseq.dataset.gen_subset=valid results_path=/where/to/save/transcriptions 93 | ``` 94 | 95 | Decoding without an LM works best on the same adjacent-mean-pooled features that the GAN was trained on, while decoding with an LM works better on the features before the adjacent-timestep mean-pooling step (without the "_pooled" suffix). A minimal sketch for spot-checking the generated transcriptions against a reference is included at the end of this README. 96 | 97 | ## Iterative self-training + Kaldi LM-decoding 98 | After the GAN training provides a first unsupervised model, we can progressively refine the quality of the transcriptions over several iterations of semi-supervised learning. We perform two iterations: first, we pseudo-label the training data with the unsupervised GAN model and train an HMM on the pseudo-labels. Second, we relabel the training data with the HMM and then fine-tune the original wav2vec 2.0 model on the HMM pseudo-labels with a CTC loss. Note that the HMM models output phonemes, while wav2vec 2.0 outputs letters. Both are decoded into words using WFST decoders. 99 | 100 | 101 | Please see [this README](kaldi_self_train/README.md) for more instructions on how to do iterative self-training + Kaldi LM-decoding. 
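Before investing in the self-training iterations above, it can be useful to spot-check the GAN-generated transcriptions against a reference phone transcript. The sketch below is a minimal, hypothetical example built on the `editdistance` package (already a dependency of `unsup_select.py`); it assumes two parallel text files with one space-separated phone sequence per line, which may differ from the exact layout written by `w2vu_generate.py`.

```python
# per.py -- minimal sketch for phone error rate, assuming two parallel
# text files with one space-separated phone sequence per line.
import sys

import editdistance  # pip install editdistance


def phone_error_rate(ref_path: str, hyp_path: str) -> float:
    errors = 0
    ref_len = 0
    with open(ref_path) as ref_f, open(hyp_path) as hyp_f:
        for ref_line, hyp_line in zip(ref_f, hyp_f):
            ref = ref_line.split()
            hyp = hyp_line.split()
            errors += editdistance.eval(ref, hyp)  # edit distance over phone tokens
            ref_len += len(ref)
    return 100.0 * errors / max(ref_len, 1)


if __name__ == "__main__":
    print(f"PER: {phone_error_rate(sys.argv[1], sys.argv[2]):.2f}%")
```

A high phone error rate here usually means it is worth re-running the GAN sweep or revisiting the feature preparation before moving on to HMM self-training.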
102 | 103 | *** Note: these instructions are a work in progress and will be updated over the next few days 104 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIOXLABS/DVoice/19b760ca93ed9016d406bd9baf077e49bab384ec/wav2vec 2.0/unsupervised/__init__.py -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/config/finetuning/w2v_finetune.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | tensorboard_logdir: tb 8 | 9 | checkpoint: 10 | no_epoch_checkpoints: true 11 | save_interval_updates: 20000 12 | 13 | task: 14 | _name: audio_finetuning 15 | data: ??? 16 | normalize: true 17 | labels: ltr 18 | 19 | dataset: 20 | num_workers: 6 21 | max_tokens: 800000 22 | skip_invalid_size_inputs_valid_test: true 23 | train_subset: train 24 | valid_subset: valid 25 | 26 | distributed_training: 27 | ddp_backend: legacy_ddp 28 | distributed_world_size: 8 29 | find_unused_parameters: True 30 | 31 | criterion: 32 | _name: ctc 33 | zero_infinity: true 34 | post_process: letter 35 | 36 | optimization: 37 | max_update: 80000 38 | lr: [0.00003] 39 | sentence_avg: true 40 | update_freq: [1] 41 | 42 | optimizer: 43 | _name: adam 44 | adam_betas: (0.9,0.98) 45 | adam_eps: 1e-08 46 | 47 | lr_scheduler: 48 | _name: tri_stage 49 | phase_ratio: [0.1, 0.4, 0.5] 50 | final_lr_scale: 0.05 51 | 52 | model: 53 | _name: wav2vec_ctc 54 | w2v_path: ??? 55 | apply_mask: true 56 | mask_prob: 0.25 57 | mask_channel_prob: 0.1 58 | mask_channel_length: 64 59 | layerdrop: 0.1 60 | activation_dropout: 0.1 61 | feature_grad_mult: 0.0 62 | freeze_finetune_updates: 0 63 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/config/gan/w2vu.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: false 5 | fp16_no_flatten_grads: true 6 | log_format: json 7 | log_interval: 100 8 | tensorboard_logdir: tb 9 | reset_logging: false 10 | suppress_crashes: false 11 | 12 | checkpoint: 13 | save_interval: 1000 14 | save_interval_updates: 1000 15 | no_epoch_checkpoints: true 16 | best_checkpoint_metric: weighted_lm_ppl 17 | save_dir: . 18 | 19 | distributed_training: 20 | distributed_world_size: 1 21 | 22 | task: 23 | _name: unpaired_audio_text 24 | data: ??? 25 | text_data: ??? 26 | labels: phn 27 | sort_by_length: false 28 | unfiltered: false 29 | max_length: null 30 | append_eos: false 31 | kenlm_path: ??? 
32 | 33 | dataset: 34 | num_workers: 6 35 | batch_size: 160 36 | skip_invalid_size_inputs_valid_test: true 37 | valid_subset: valid 38 | validate_interval: 1000 39 | validate_interval_updates: 1000 40 | 41 | criterion: 42 | _name: model 43 | log_keys: 44 | - accuracy_dense 45 | - accuracy_token 46 | - temp 47 | - code_ppl 48 | 49 | optimization: 50 | max_update: 150000 51 | clip_norm: 5.0 52 | lr: [0] 53 | 54 | optimizer: 55 | _name: composite 56 | groups: 57 | generator: 58 | lr: [0.0004] 59 | lr_float: null 60 | optimizer: 61 | _name: adam 62 | adam_betas: [0.5,0.98] 63 | adam_eps: 1e-06 64 | weight_decay: 0 65 | amsgrad: false 66 | lr_scheduler: 67 | _name: fixed 68 | warmup_updates: 0 69 | discriminator: 70 | lr: [ 0.0005 ] 71 | lr_float: null 72 | optimizer: 73 | _name: adam 74 | adam_betas: [0.5,0.98] 75 | adam_eps: 1e-06 76 | weight_decay: 0.0001 77 | amsgrad: false 78 | lr_scheduler: 79 | _name: fixed 80 | warmup_updates: 0 81 | 82 | lr_scheduler: pass_through 83 | 84 | model: 85 | _name: wav2vec_u 86 | 87 | discriminator_dim: 384 88 | discriminator_depth: 2 89 | discriminator_kernel: 6 90 | discriminator_linear_emb: false 91 | discriminator_causal: true 92 | discriminator_max_pool: false 93 | discriminator_act_after_linear: false 94 | discriminator_dropout: 0.0 95 | discriminator_weight_norm: false 96 | 97 | generator_stride: 1 98 | generator_kernel: 4 99 | generator_bias: false 100 | generator_dropout: 0.1 101 | 102 | smoothness_weight: 0.5 103 | smoothing: 0 104 | smoothing_one_sided: false 105 | gumbel: false 106 | hard_gumbel: false 107 | gradient_penalty: 1.5 108 | code_penalty: 4.0 109 | temp: [ 2,0.1,0.99995 ] 110 | input_dim: 512 111 | 112 | segmentation: 113 | type: JOIN 114 | mean_pool_join: false 115 | remove_zeros: false 116 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/config/generate/viterbi.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | fairseq: 4 | task: 5 | _name: unpaired_audio_text 6 | labels: phn 7 | data: ??? 8 | sort_by_length: false 9 | shuffle: false 10 | text_data: '' 11 | 12 | common_eval: 13 | path: ??? 
14 | quiet: true 15 | 16 | dataset: 17 | gen_subset: valid 18 | batch_size: 1 19 | 20 | w2l_decoder: VITERBI 21 | post_process: silence 22 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/config/timit_matched/test.uid: -------------------------------------------------------------------------------- 1 | FDHC0_SI1559 2 | FDHC0_SI2189 3 | FDHC0_SI929 4 | FDHC0_SX119 5 | FDHC0_SX209 6 | FDHC0_SX29 7 | FDHC0_SX299 8 | FDHC0_SX389 9 | FELC0_SI1386 10 | FELC0_SI2016 11 | FELC0_SI756 12 | FELC0_SX126 13 | FELC0_SX216 14 | FELC0_SX306 15 | FELC0_SX36 16 | FELC0_SX396 17 | FJLM0_SI1043 18 | FJLM0_SI1673 19 | FJLM0_SI2303 20 | FJLM0_SX143 21 | FJLM0_SX233 22 | FJLM0_SX323 23 | FJLM0_SX413 24 | FJLM0_SX53 25 | FMGD0_SI1564 26 | FMGD0_SI2194 27 | FMGD0_SI934 28 | FMGD0_SX124 29 | FMGD0_SX214 30 | FMGD0_SX304 31 | FMGD0_SX34 32 | FMGD0_SX394 33 | FMLD0_SI2185 34 | FMLD0_SI822 35 | FMLD0_SI925 36 | FMLD0_SX115 37 | FMLD0_SX205 38 | FMLD0_SX25 39 | FMLD0_SX295 40 | FMLD0_SX385 41 | FNLP0_SI1308 42 | FNLP0_SI1938 43 | FNLP0_SI678 44 | FNLP0_SX138 45 | FNLP0_SX228 46 | FNLP0_SX318 47 | FNLP0_SX408 48 | FNLP0_SX48 49 | FPAS0_SI1272 50 | FPAS0_SI2204 51 | FPAS0_SI944 52 | FPAS0_SX134 53 | FPAS0_SX224 54 | FPAS0_SX314 55 | FPAS0_SX404 56 | FPAS0_SX44 57 | FPKT0_SI1538 58 | FPKT0_SI2168 59 | FPKT0_SI908 60 | FPKT0_SX188 61 | FPKT0_SX278 62 | FPKT0_SX368 63 | FPKT0_SX8 64 | FPKT0_SX98 65 | MBPM0_SI1577 66 | MBPM0_SI1584 67 | MBPM0_SI947 68 | MBPM0_SX137 69 | MBPM0_SX227 70 | MBPM0_SX317 71 | MBPM0_SX407 72 | MBPM0_SX47 73 | MCMJ0_SI1094 74 | MCMJ0_SI464 75 | MCMJ0_SI602 76 | MCMJ0_SX104 77 | MCMJ0_SX14 78 | MCMJ0_SX194 79 | MCMJ0_SX284 80 | MCMJ0_SX374 81 | MDAB0_SI1039 82 | MDAB0_SI1669 83 | MDAB0_SI2299 84 | MDAB0_SX139 85 | MDAB0_SX229 86 | MDAB0_SX319 87 | MDAB0_SX409 88 | MDAB0_SX49 89 | MGRT0_SI1450 90 | MGRT0_SI2080 91 | MGRT0_SI820 92 | MGRT0_SX10 93 | MGRT0_SX100 94 | MGRT0_SX190 95 | MGRT0_SX280 96 | MGRT0_SX370 97 | MJDH0_SI1354 98 | MJDH0_SI1984 99 | MJDH0_SI724 100 | MJDH0_SX184 101 | MJDH0_SX274 102 | MJDH0_SX364 103 | MJDH0_SX4 104 | MJDH0_SX94 105 | MJLN0_SI1449 106 | MJLN0_SI2079 107 | MJLN0_SI819 108 | MJLN0_SX189 109 | MJLN0_SX279 110 | MJLN0_SX369 111 | MJLN0_SX9 112 | MJLN0_SX99 113 | MJMP0_SI1535 114 | MJMP0_SI1791 115 | MJMP0_SI905 116 | MJMP0_SX185 117 | MJMP0_SX275 118 | MJMP0_SX365 119 | MJMP0_SX5 120 | MJMP0_SX95 121 | MKLT0_SI1213 122 | MKLT0_SI1843 123 | MKLT0_SI583 124 | MKLT0_SX133 125 | MKLT0_SX223 126 | MKLT0_SX313 127 | MKLT0_SX403 128 | MKLT0_SX43 129 | MLLL0_SI1363 130 | MLLL0_SI1993 131 | MLLL0_SI733 132 | MLLL0_SX103 133 | MLLL0_SX13 134 | MLLL0_SX193 135 | MLLL0_SX283 136 | MLLL0_SX373 137 | MLNT0_SI1574 138 | MLNT0_SI1902 139 | MLNT0_SI642 140 | MLNT0_SX102 141 | MLNT0_SX12 142 | MLNT0_SX192 143 | MLNT0_SX282 144 | MLNT0_SX372 145 | MNJM0_SI1580 146 | MNJM0_SI2210 147 | MNJM0_SI950 148 | MNJM0_SX140 149 | MNJM0_SX230 150 | MNJM0_SX320 151 | MNJM0_SX410 152 | MNJM0_SX50 153 | MPAM0_SI1189 154 | MPAM0_SI1819 155 | MPAM0_SI1961 156 | MPAM0_SX109 157 | MPAM0_SX19 158 | MPAM0_SX199 159 | MPAM0_SX289 160 | MPAM0_SX379 161 | MTAS1_SI1473 162 | MTAS1_SI2098 163 | MTAS1_SI838 164 | MTAS1_SX118 165 | MTAS1_SX208 166 | MTAS1_SX28 167 | MTAS1_SX298 168 | MTAS1_SX388 169 | MTLS0_SI1370 170 | MTLS0_SI2000 171 | MTLS0_SI740 172 | MTLS0_SX110 173 | MTLS0_SX20 174 | MTLS0_SX200 175 | MTLS0_SX290 176 | MTLS0_SX380 177 | MWBT0_SI1553 178 | MWBT0_SI2183 179 | MWBT0_SI923 180 | MWBT0_SX113 181 | MWBT0_SX203 182 | MWBT0_SX23 183 | 
MWBT0_SX293 184 | MWBT0_SX383 185 | MWEW0_SI1361 186 | MWEW0_SI1991 187 | MWEW0_SI731 188 | MWEW0_SX101 189 | MWEW0_SX11 190 | MWEW0_SX191 191 | MWEW0_SX281 192 | MWEW0_SX371 193 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/config/timit_matched/valid.uid: -------------------------------------------------------------------------------- 1 | FADG0_SI1279 2 | FADG0_SI1909 3 | FADG0_SI649 4 | FADG0_SX109 5 | FADG0_SX19 6 | FADG0_SX199 7 | FADG0_SX289 8 | FADG0_SX379 9 | FAKS0_SI1573 10 | FAKS0_SI2203 11 | FAKS0_SI943 12 | FAKS0_SX133 13 | FAKS0_SX223 14 | FAKS0_SX313 15 | FAKS0_SX403 16 | FAKS0_SX43 17 | FCAL1_SI1403 18 | FCAL1_SI2033 19 | FCAL1_SI773 20 | FCAL1_SX143 21 | FCAL1_SX233 22 | FCAL1_SX323 23 | FCAL1_SX413 24 | FCAL1_SX53 25 | FCMH0_SI1454 26 | FCMH0_SI2084 27 | FCMH0_SI824 28 | FCMH0_SX104 29 | FCMH0_SX14 30 | FCMH0_SX194 31 | FCMH0_SX284 32 | FCMH0_SX374 33 | FDAC1_SI1474 34 | FDAC1_SI2104 35 | FDAC1_SI844 36 | FDAC1_SX124 37 | FDAC1_SX214 38 | FDAC1_SX304 39 | FDAC1_SX34 40 | FDAC1_SX394 41 | FDMS0_SI1218 42 | FDMS0_SI1502 43 | FDMS0_SI1848 44 | FDMS0_SX138 45 | FDMS0_SX228 46 | FDMS0_SX318 47 | FDMS0_SX408 48 | FDMS0_SX48 49 | FDRW0_SI1283 50 | FDRW0_SI1423 51 | FDRW0_SI653 52 | FDRW0_SX113 53 | FDRW0_SX203 54 | FDRW0_SX23 55 | FDRW0_SX293 56 | FDRW0_SX383 57 | FEDW0_SI1084 58 | FEDW0_SI1653 59 | FEDW0_SI1714 60 | FEDW0_SX184 61 | FEDW0_SX274 62 | FEDW0_SX364 63 | FEDW0_SX4 64 | FEDW0_SX94 65 | FGJD0_SI1179 66 | FGJD0_SI549 67 | FGJD0_SI818 68 | FGJD0_SX189 69 | FGJD0_SX279 70 | FGJD0_SX369 71 | FGJD0_SX9 72 | FGJD0_SX99 73 | FJEM0_SI1264 74 | FJEM0_SI1894 75 | FJEM0_SI634 76 | FJEM0_SX184 77 | FJEM0_SX274 78 | FJEM0_SX364 79 | FJEM0_SX4 80 | FJEM0_SX94 81 | FJMG0_SI1181 82 | FJMG0_SI1811 83 | FJMG0_SI551 84 | FJMG0_SX101 85 | FJMG0_SX11 86 | FJMG0_SX191 87 | FJMG0_SX281 88 | FJMG0_SX371 89 | FJSJ0_SI1484 90 | FJSJ0_SI2114 91 | FJSJ0_SI854 92 | FJSJ0_SX134 93 | FJSJ0_SX224 94 | FJSJ0_SX314 95 | FJSJ0_SX404 96 | FJSJ0_SX44 97 | FKMS0_SI1490 98 | FKMS0_SI2120 99 | FKMS0_SI860 100 | FKMS0_SX140 101 | FKMS0_SX230 102 | FKMS0_SX320 103 | FKMS0_SX410 104 | FKMS0_SX50 105 | FMAH0_SI1289 106 | FMAH0_SI1919 107 | FMAH0_SI659 108 | FMAH0_SX119 109 | FMAH0_SX209 110 | FMAH0_SX29 111 | FMAH0_SX299 112 | FMAH0_SX389 113 | FMML0_SI1040 114 | FMML0_SI1670 115 | FMML0_SI2300 116 | FMML0_SX140 117 | FMML0_SX230 118 | FMML0_SX320 119 | FMML0_SX410 120 | FMML0_SX50 121 | FNMR0_SI1399 122 | FNMR0_SI2029 123 | FNMR0_SI769 124 | FNMR0_SX139 125 | FNMR0_SX229 126 | FNMR0_SX319 127 | FNMR0_SX409 128 | FNMR0_SX49 129 | FREW0_SI1030 130 | FREW0_SI1280 131 | FREW0_SI1910 132 | FREW0_SX110 133 | FREW0_SX20 134 | FREW0_SX200 135 | FREW0_SX290 136 | FREW0_SX380 137 | FSEM0_SI1198 138 | FSEM0_SI1828 139 | FSEM0_SI568 140 | FSEM0_SX118 141 | FSEM0_SX208 142 | FSEM0_SX28 143 | FSEM0_SX298 144 | FSEM0_SX388 145 | MAJC0_SI1946 146 | MAJC0_SI2095 147 | MAJC0_SI835 148 | MAJC0_SX115 149 | MAJC0_SX205 150 | MAJC0_SX25 151 | MAJC0_SX295 152 | MAJC0_SX385 153 | MBDG0_SI1463 154 | MBDG0_SI2093 155 | MBDG0_SI833 156 | MBDG0_SX113 157 | MBDG0_SX203 158 | MBDG0_SX23 159 | MBDG0_SX293 160 | MBDG0_SX383 161 | MBNS0_SI1220 162 | MBNS0_SI1850 163 | MBNS0_SI590 164 | MBNS0_SX140 165 | MBNS0_SX230 166 | MBNS0_SX320 167 | MBNS0_SX410 168 | MBNS0_SX50 169 | MBWM0_SI1304 170 | MBWM0_SI1934 171 | MBWM0_SI674 172 | MBWM0_SX134 173 | MBWM0_SX224 174 | MBWM0_SX314 175 | MBWM0_SX404 176 | MBWM0_SX44 177 | MCSH0_SI1549 178 | MCSH0_SI2179 179 | MCSH0_SI919 180 | MCSH0_SX109 
181 | MCSH0_SX19 182 | MCSH0_SX199 183 | MCSH0_SX289 184 | MCSH0_SX379 185 | MDLF0_SI1583 186 | MDLF0_SI2213 187 | MDLF0_SI953 188 | MDLF0_SX143 189 | MDLF0_SX233 190 | MDLF0_SX323 191 | MDLF0_SX413 192 | MDLF0_SX53 193 | MDLS0_SI1628 194 | MDLS0_SI2258 195 | MDLS0_SI998 196 | MDLS0_SX188 197 | MDLS0_SX278 198 | MDLS0_SX368 199 | MDLS0_SX8 200 | MDLS0_SX98 201 | MDVC0_SI2174 202 | MDVC0_SI2196 203 | MDVC0_SI936 204 | MDVC0_SX126 205 | MDVC0_SX216 206 | MDVC0_SX306 207 | MDVC0_SX36 208 | MDVC0_SX396 209 | MERS0_SI1019 210 | MERS0_SI1649 211 | MERS0_SI497 212 | MERS0_SX119 213 | MERS0_SX209 214 | MERS0_SX29 215 | MERS0_SX299 216 | MERS0_SX389 217 | MGJF0_SI1901 218 | MGJF0_SI641 219 | MGJF0_SI776 220 | MGJF0_SX101 221 | MGJF0_SX11 222 | MGJF0_SX191 223 | MGJF0_SX281 224 | MGJF0_SX371 225 | MGLB0_SI1534 226 | MGLB0_SI2164 227 | MGLB0_SI904 228 | MGLB0_SX184 229 | MGLB0_SX274 230 | MGLB0_SX364 231 | MGLB0_SX4 232 | MGLB0_SX94 233 | MGWT0_SI1539 234 | MGWT0_SI2169 235 | MGWT0_SI909 236 | MGWT0_SX189 237 | MGWT0_SX279 238 | MGWT0_SX369 239 | MGWT0_SX9 240 | MGWT0_SX99 241 | MJAR0_SI1988 242 | MJAR0_SI2247 243 | MJAR0_SI728 244 | MJAR0_SX188 245 | MJAR0_SX278 246 | MJAR0_SX368 247 | MJAR0_SX8 248 | MJAR0_SX98 249 | MJFC0_SI1033 250 | MJFC0_SI1663 251 | MJFC0_SI2293 252 | MJFC0_SX133 253 | MJFC0_SX223 254 | MJFC0_SX313 255 | MJFC0_SX403 256 | MJFC0_SX43 257 | MJSW0_SI1010 258 | MJSW0_SI1640 259 | MJSW0_SI2270 260 | MJSW0_SX110 261 | MJSW0_SX20 262 | MJSW0_SX200 263 | MJSW0_SX290 264 | MJSW0_SX380 265 | MMDB1_SI1625 266 | MMDB1_SI2255 267 | MMDB1_SI995 268 | MMDB1_SX185 269 | MMDB1_SX275 270 | MMDB1_SX365 271 | MMDB1_SX5 272 | MMDB1_SX95 273 | MMDM2_SI1452 274 | MMDM2_SI1555 275 | MMDM2_SI2082 276 | MMDM2_SX102 277 | MMDM2_SX12 278 | MMDM2_SX192 279 | MMDM2_SX282 280 | MMDM2_SX372 281 | MMJR0_SI1648 282 | MMJR0_SI2166 283 | MMJR0_SI2278 284 | MMJR0_SX118 285 | MMJR0_SX208 286 | MMJR0_SX28 287 | MMJR0_SX298 288 | MMJR0_SX388 289 | MMWH0_SI1089 290 | MMWH0_SI1301 291 | MMWH0_SI459 292 | MMWH0_SX189 293 | MMWH0_SX279 294 | MMWH0_SX369 295 | MMWH0_SX9 296 | MMWH0_SX99 297 | MPDF0_SI1542 298 | MPDF0_SI2172 299 | MPDF0_SI912 300 | MPDF0_SX102 301 | MPDF0_SX12 302 | MPDF0_SX192 303 | MPDF0_SX282 304 | MPDF0_SX372 305 | MRCS0_SI1223 306 | MRCS0_SI1853 307 | MRCS0_SI593 308 | MRCS0_SX143 309 | MRCS0_SX233 310 | MRCS0_SX323 311 | MRCS0_SX413 312 | MRCS0_SX53 313 | MREB0_SI1375 314 | MREB0_SI2005 315 | MREB0_SI745 316 | MREB0_SX115 317 | MREB0_SX205 318 | MREB0_SX25 319 | MREB0_SX295 320 | MREB0_SX385 321 | MRJM4_SI1489 322 | MRJM4_SI2119 323 | MRJM4_SI859 324 | MRJM4_SX139 325 | MRJM4_SX229 326 | MRJM4_SX319 327 | MRJM4_SX409 328 | MRJM4_SX49 329 | MRJR0_SI1182 330 | MRJR0_SI1812 331 | MRJR0_SI2313 332 | MRJR0_SX102 333 | MRJR0_SX12 334 | MRJR0_SX192 335 | MRJR0_SX282 336 | MRJR0_SX372 337 | MROA0_SI1307 338 | MROA0_SI1970 339 | MROA0_SI677 340 | MROA0_SX137 341 | MROA0_SX227 342 | MROA0_SX317 343 | MROA0_SX407 344 | MROA0_SX47 345 | MRTK0_SI1093 346 | MRTK0_SI1723 347 | MRTK0_SI1750 348 | MRTK0_SX103 349 | MRTK0_SX13 350 | MRTK0_SX193 351 | MRTK0_SX283 352 | MRTK0_SX373 353 | MRWS1_SI1130 354 | MRWS1_SI1496 355 | MRWS1_SI500 356 | MRWS1_SX140 357 | MRWS1_SX230 358 | MRWS1_SX320 359 | MRWS1_SX410 360 | MRWS1_SX50 361 | MTAA0_SI1285 362 | MTAA0_SI1915 363 | MTAA0_SI596 364 | MTAA0_SX115 365 | MTAA0_SX205 366 | MTAA0_SX25 367 | MTAA0_SX295 368 | MTAA0_SX385 369 | MTDT0_SI1994 370 | MTDT0_SI2254 371 | MTDT0_SI994 372 | MTDT0_SX184 373 | MTDT0_SX274 374 | MTDT0_SX364 375 | MTDT0_SX4 376 | MTDT0_SX94 377 | 
MTEB0_SI1133 378 | MTEB0_SI2064 379 | MTEB0_SI503 380 | MTEB0_SX143 381 | MTEB0_SX233 382 | MTEB0_SX323 383 | MTEB0_SX413 384 | MTEB0_SX53 385 | MTHC0_SI1015 386 | MTHC0_SI1645 387 | MTHC0_SI2275 388 | MTHC0_SX115 389 | MTHC0_SX205 390 | MTHC0_SX25 391 | MTHC0_SX295 392 | MTHC0_SX385 393 | MWJG0_SI1124 394 | MWJG0_SI1754 395 | MWJG0_SI494 396 | MWJG0_SX134 397 | MWJG0_SX224 398 | MWJG0_SX314 399 | MWJG0_SX404 400 | MWJG0_SX44 401 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .extracted_features_dataset import ExtractedFeaturesDataset 7 | from .random_input_dataset import RandomInputDataset 8 | 9 | 10 | __all__ = [ 11 | "ExtractedFeaturesDataset", 12 | "RandomInputDataset", 13 | ] 14 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/data/extracted_features_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import logging 8 | import os 9 | import contextlib 10 | 11 | import numpy as np 12 | import torch 13 | 14 | from fairseq.data import FairseqDataset, data_utils 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class ExtractedFeaturesDataset(FairseqDataset): 21 | def __init__( 22 | self, 23 | path, 24 | split, 25 | min_length=3, 26 | max_length=None, 27 | labels=None, 28 | label_dict=None, 29 | shuffle=True, 30 | sort_by_length=True, 31 | ): 32 | super().__init__() 33 | 34 | self.min_length = min_length 35 | self.max_length = max_length 36 | self.shuffle = shuffle 37 | self.sort_by_length = sort_by_length 38 | self.label_dict = label_dict 39 | 40 | if labels is not None: 41 | assert label_dict is not None 42 | 43 | self.sizes = [] 44 | self.offsets = [] 45 | self.labels = [] 46 | 47 | path = os.path.join(path, split) 48 | data_path = path 49 | self.data = np.load(data_path + ".npy", mmap_mode="r") 50 | 51 | offset = 0 52 | skipped = 0 53 | 54 | if not os.path.exists(path + f".{labels}"): 55 | labels = None 56 | 57 | with open(data_path + ".lengths", "r") as len_f, open( 58 | path + f".{labels}", "r" 59 | ) if labels is not None else contextlib.ExitStack() as lbl_f: 60 | for line in len_f: 61 | length = int(line.rstrip()) 62 | lbl = None if labels is None else next(lbl_f).rstrip().split() 63 | if length >= min_length and ( 64 | max_length is None or length <= max_length 65 | ): 66 | self.sizes.append(length) 67 | self.offsets.append(offset) 68 | if lbl is not None: 69 | self.labels.append(lbl) 70 | offset += length 71 | 72 | self.sizes = np.asarray(self.sizes) 73 | self.offsets = np.asarray(self.offsets) 74 | 75 | logger.info(f"loaded {len(self.offsets)}, skipped {skipped} samples") 76 | 77 | def __getitem__(self, index): 78 | offset = self.offsets[index] 79 | end = self.sizes[index] + offset 80 | feats = torch.from_numpy(self.data[offset:end].copy()).float() 81 | 82 | res = {"id": index, "features": feats} 83 | if len(self.labels) > 0: 84 | res["target"] = self.label_dict.encode_line( 85 | self.labels[index], 86 
| line_tokenizer=lambda x: x, 87 | append_eos=False, 88 | ) 89 | 90 | return res 91 | 92 | def __len__(self): 93 | return len(self.sizes) 94 | 95 | def collater(self, samples): 96 | if len(samples) == 0: 97 | return {} 98 | 99 | features = [s["features"] for s in samples] 100 | sizes = [len(s) for s in features] 101 | 102 | target_size = max(sizes) 103 | 104 | collated_features = features[0].new_zeros( 105 | len(features), target_size, features[0].size(-1) 106 | ) 107 | padding_mask = torch.BoolTensor(collated_features.shape[:-1]).fill_(False) 108 | for i, (f, size) in enumerate(zip(features, sizes)): 109 | collated_features[i, :size] = f 110 | padding_mask[i, size:] = True 111 | 112 | res = { 113 | "id": torch.LongTensor([s["id"] for s in samples]), 114 | "net_input": {"features": collated_features, "padding_mask": padding_mask}, 115 | } 116 | 117 | if len(self.labels) > 0: 118 | target = data_utils.collate_tokens( 119 | [s["target"] for s in samples], 120 | pad_idx=self.label_dict.pad(), 121 | left_pad=False, 122 | ) 123 | res["target"] = target 124 | return res 125 | 126 | def num_tokens(self, index): 127 | return self.size(index) 128 | 129 | def size(self, index): 130 | return self.sizes[index] 131 | 132 | def ordered_indices(self): 133 | """Return an ordered list of indices. Batches will be constructed based 134 | on this order.""" 135 | if self.shuffle: 136 | order = [np.random.permutation(len(self))] 137 | else: 138 | order = [np.arange(len(self))] 139 | 140 | if self.sort_by_length: 141 | order.append(self.sizes) 142 | return np.lexsort(order)[::-1] 143 | else: 144 | return order[0] 145 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/data/random_input_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import random 7 | from typing import List 8 | 9 | from fairseq.data import BaseWrapperDataset, data_utils 10 | 11 | 12 | class RandomInputDataset(BaseWrapperDataset): 13 | def __init__( 14 | self, 15 | dataset, 16 | random_input_dataset, 17 | input_key_path: List[str], 18 | add_to_input, 19 | pad_idx, 20 | ): 21 | super().__init__(dataset) 22 | self.random_input_dataset = random_input_dataset 23 | if isinstance(input_key_path, str): 24 | input_key_path = [input_key_path] 25 | assert len(input_key_path) > 0 26 | self.input_key_path = input_key_path 27 | self.add_to_input = add_to_input 28 | self.pad_idx = pad_idx 29 | 30 | def get_target(self, item): 31 | target_loc = item 32 | for p in self.input_key_path[:-1]: 33 | target_loc = target_loc[p] 34 | return self.input_key_path[-1], target_loc 35 | 36 | def get_target_value(self, item): 37 | k, target_loc = self.get_target(item) 38 | return target_loc[k] 39 | 40 | def __getitem__(self, index): 41 | item = self.dataset[index] 42 | k, target_loc = self.get_target(item) 43 | target_loc[k] = random.choice(self.random_input_dataset) 44 | return item 45 | 46 | def collater(self, samples): 47 | collated = self.dataset.collater(samples) 48 | if len(collated) == 0: 49 | return collated 50 | indices = set(collated["id"].tolist()) 51 | 52 | random_inputs = data_utils.collate_tokens( 53 | [self.get_target_value(s) for s in samples if s["id"] in indices], 54 | pad_idx=self.pad_idx, 55 | left_pad=False, 56 | ) 57 | k, target_loc = self.get_target( 58 | collated if not self.add_to_input else collated["net_input"] 59 | ) 60 | target_loc[k] = random_inputs 61 | 62 | return collated 63 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/README.md: -------------------------------------------------------------------------------- 1 | # Self-Training with Kaldi HMM Models 2 | This folder contains recipes for self-training on pseudo phone transcripts and 3 | decoding into phones or words with [kaldi](https://github.com/kaldi-asr/kaldi). 4 | 5 | To start, download and install kaldi following its instructions, and place this 6 | folder in `path/to/kaldi/egs`. 7 | 8 | ## Training 9 | Assuming the following has been prepared: 10 | - `w2v_dir`: contains features `{train,valid}.{npy,lengths}`, real transcripts `{train,valid}.${label}`, and dict `dict.${label}.txt` 11 | - `lab_dir`: contains pseudo labels `{train,valid}.txt` 12 | - `arpa_lm`: Arpa-format n-gram phone LM for decoding 13 | - `arpa_lm_bin`: Arpa-format n-gram phone LM for unsupervised model selection to be used with KenLM 14 | 15 | Set these variables in `train.sh`, as well as `out_dir`, the output directory, 16 | and then run it. 17 | 18 | The output will be: 19 | ``` 20 | ==== WER w.r.t. real transcript (select based on unsupervised metric) 21 | INFO:root:./out/exp/mono/decode_valid/scoring/14.0.0.tra.txt: score 0.9178 wer 28.71% lm_ppl 24.4500 gt_wer 25.57% 22 | INFO:root:./out/exp/tri1/decode_valid/scoring/17.1.0.tra.txt: score 0.9257 wer 26.99% lm_ppl 30.8494 gt_wer 21.90% 23 | INFO:root:./out/exp/tri2b/decode_valid/scoring/8.0.0.tra.txt: score 0.7506 wer 23.15% lm_ppl 25.5944 gt_wer 15.78% 24 | ``` 25 | where `wer` is the word error rate with respect to the pseudo label, `gt_wer` to 26 | the ground truth label, `lm_ppl` the language model perplexity of the HMM-predicted 27 | transcripts, and `score` is the unsupervised metric for model selection. We 28 | choose the model and the LM parameter of the one with the lowest score. 
In the 29 | example above, it is `tri2b`, `8.0.0`. 30 | 31 | 32 | ## Decoding into Phones 33 | In `decode_phone.sh`, set `out_dir` the same as used in `train.sh`, set 34 | `dec_exp` and `dec_lmparam` to the selected model and LM parameter (e.g. 35 | `tri2b` and `8.0.0` in the above example). `dec_script` needs to be set 36 | according to `dec_exp`: for mono/tri1/tri2b, use `decode.sh`; for tri3b, use 37 | `decode_fmllr.sh`. 38 | 39 | The output will be saved at `out_dir/dec_data` 40 | 41 | 42 | ## Decoding into Words 43 | `decode_word_step1.sh` prepares WFSTs for word decoding. Besides the variables 44 | mentioned above, set 45 | - `wrd_arpa_lm`: Arpa-format n-gram word LM for decoding 46 | - `wrd_arpa_lm_bin`: Arpa-format n-gram word LM for unsupervised model selection 47 | 48 | `decode_word_step1.sh` decodes the `train` and `valid` split into word and runs 49 | unsupervised model selection using the `valid` split. The output is like: 50 | ``` 51 | INFO:root:./out/exp/tri2b/decodeword_valid/scoring/17.0.0.tra.txt: score 1.8693 wer 24.97% lm_ppl 1785.5333 gt_wer 31.45% 52 | ``` 53 | 54 | After determining the LM parameter (`17.0.0` in the example above), set it in 55 | `decode_word_step2.sh` and run it. The output will be saved at 56 | `out_dir/dec_data_word`. 57 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/cmd.sh: -------------------------------------------------------------------------------- 1 | # you can change cmd.sh depending on what type of queue you are using. 2 | # If you have no queueing system and want to run on a local machine, you 3 | # can change all instances 'queue.pl' to run.pl (but be careful and run 4 | # commands one by one: most recipes will exhaust the memory on your 5 | # machine). queue.pl works with GridEngine (qsub). slurm.pl works 6 | # with slurm. Different queues are configured differently, with different 7 | # queue names and different ways of specifying things like memory; 8 | # to account for these differences you can create and edit the file 9 | # conf/queue.conf to match your queue's configuration. Search for 10 | # conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, 11 | # or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. 12 | 13 | export train_cmd="run.pl --mem 2G" 14 | export decode_cmd="run.pl --mem 4G" 15 | export mkgraph_cmd="run.pl --mem 8G" 16 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/decode_phone.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # decode into phones (and prepare a new data directory for HMM outputs) 4 | 5 | . ./path.sh 6 | 7 | set -eu 8 | 9 | out_dir= # same as in train.sh 10 | dec_lmparam= # LM hyperparameters (e.g., 7.0.0) 11 | dec_exp= 12 | dec_script= 13 | dec_splits="train valid" 14 | dec_data_dir=$out_dir/dec_data # where to write HMM output 15 | 16 | data_dir=${out_dir}/data 17 | 18 | local/decode.sh --nj 40 --graph_name graph \ 19 | --val_sets "$dec_splits" --decode_script $dec_script \ 20 | $out_dir/exp/$dec_exp $data_dir $data_dir/lang_test 21 | 22 | if [ ! 
-z $dec_lmparam ]; then 23 | for x in $dec_splits; do 24 | mkdir -p $dec_data_dir/$x 25 | cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $dec_data_dir/$x/ 26 | 27 | tra=$out_dir/exp/$dec_exp/decode_${x}/scoring/${dec_lmparam}.tra 28 | cat $tra | utils/int2sym.pl -f 2- $data_dir/lang/words.txt | \ 29 | sed 's:::g' | sed 's:::g' > $dec_data_dir/${x}/text 30 | utils/fix_data_dir.sh $dec_data_dir/${x} 31 | echo "WER on ${x} is" $(compute-wer ark:$data_dir/${x}_gt/text ark:$dec_data_dir/$x/text | cut -d" " -f2-) 32 | done 33 | fi 34 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/decode_word_step1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # prepare word WFSTs, reference data, and decode 4 | 5 | set -eu 6 | 7 | w2v_dir= # same as in train.sh 8 | out_dir= # same as in train.sh 9 | lexicon= # word to phone mapping 10 | wrd_arpa_lm= # word LM 11 | wrd_arpa_lm_bin= # word LM for KenLM, used in unsupervised selection 12 | 13 | dec_exp= # what HMM stage to decode (e.g., tri3b) 14 | dec_script= # what decoding script to use (e.g., steps/decode_fmllr.sh) 15 | phn_label=phnc 16 | wrd_label=wrd 17 | dec_suffix=word 18 | dec_splits="train valid" 19 | valid_split="valid" 20 | 21 | data_dir=$out_dir/data 22 | wrd_data_dir=$out_dir/data_word 23 | 24 | lexicon_clean=$(mktemp) 25 | cat $lexicon | sort | uniq > $lexicon_clean 26 | local/prepare_lang_word.sh $w2v_dir/dict.${phn_label}.txt $data_dir $lexicon_clean && rm $lexicon_clean 27 | local/prepare_lm.sh --langdir $data_dir/lang_word --lmdir $data_dir/lang_test_word $wrd_arpa_lm $data_dir 28 | 29 | for x in $dec_splits; do 30 | x_gt=${x}_gt 31 | mkdir -p $wrd_data_dir/$x_gt 32 | cp $data_dir/$x_gt/{feats.scp,cmvn.scp,utt2spk,spk2utt} $wrd_data_dir/$x_gt/ 33 | python local/copy_aligned_text.py < $w2v_dir/$x.$wrd_label > $wrd_data_dir/$x_gt/text 34 | done 35 | 36 | local/decode.sh --nj 40 --graph_name graph${dec_suffix} --decode_suffix $dec_suffix \ 37 | --val_sets "$dec_splits" --decode_script $dec_script \ 38 | $out_dir/exp/$dec_exp $data_dir $data_dir/lang_test_word 39 | 40 | local/unsup_select_decode_word.sh \ 41 | --split $valid_split --kenlm_path $wrd_arpa_lm_bin \ 42 | --ref_txt $wrd_data_dir/${valid_split}_gt/text \ 43 | --psd_txt $data_dir/${valid_split}/text \ 44 | --dec_name decode${dec_suffix} --graph_name graph${dec_suffix} \ 45 | --phonemize_lexicon $data_dir/local/dict_word/lexicon.txt \ 46 | $out_dir/exp 47 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/decode_word_step2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # prepare a new data directory of HMM word output 4 | 5 | . 
./path.sh 6 | 7 | set -eu 8 | 9 | out_dir= # same as in train.sh 10 | dec_lmparam= # LM hyperparameters (e.g., 7.0.0) 11 | 12 | dec_exp=tri3b # what HMM stage to decode (e.g., tri3b) 13 | dec_suffix=word 14 | dec_splits="train valid" 15 | dec_data_dir=$out_dir/dec_data_word # where to write HMM output 16 | 17 | data_dir=$out_dir/data 18 | wrd_data_dir=$out_dir/data_word 19 | 20 | for x in $dec_splits; do 21 | mkdir -p $dec_data_dir/$x 22 | cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $dec_data_dir/$x/ 23 | 24 | tra=$out_dir/exp/$dec_exp/decode${dec_suffix}_${x}/scoring/${dec_lmparam}.tra 25 | cat $tra | utils/int2sym.pl -f 2- $data_dir/lang_word/words.txt | \ 26 | sed 's:::g' | sed 's:::g' > $dec_data_dir/$x/text 27 | utils/fix_data_dir.sh $dec_data_dir/$x 28 | echo "WER on $x is" $(compute-wer ark:$wrd_data_dir/${x}_gt/text ark:$dec_data_dir/$x/text | cut -d" " -f2-) 29 | done 30 | 31 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/copy_aligned_text.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | for idx, line in enumerate(sys.stdin): 4 | print(f"utt{idx:010d} {line}", end='') -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/decode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -u 4 | 5 | val_sets="dev_other" 6 | graph_name=graph 7 | decode_suffix="" 8 | decode_script="steps/decode_fmllr.sh" 9 | decode_args="" 10 | nj=60 11 | 12 | . ./cmd.sh 13 | . ./path.sh 14 | . parse_options.sh 15 | 16 | set -x 17 | exp_dir=$1 18 | data_root=$2 19 | lang_test=$3 20 | 21 | graph=$exp_dir/$graph_name 22 | 23 | if [ ! -d $graph ]; then 24 | utils/mkgraph.sh $lang_test $exp_dir $graph 25 | fi 26 | 27 | for part in $val_sets; do 28 | dec_dir=$exp_dir/decode${decode_suffix}_${part} 29 | if [ ! -d $dec_dir ]; then 30 | echo "decoding $part for $exp_dir" 31 | $decode_script --nj $nj --cmd "$decode_cmd" $decode_args \ 32 | $graph $data_root/$part $dec_dir & 33 | else 34 | echo "$dec_dir exists. 
skip" 35 | fi 36 | done 37 | 38 | wait 39 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/prepare_data_from_w2v.py: -------------------------------------------------------------------------------- 1 | import kaldi_io 2 | import numpy as np 3 | import os 4 | 5 | 6 | def get_parser(): 7 | import argparse 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("w2v_dir", help="wav2vec feature and text directory") 10 | parser.add_argument("tar_root", help="output data directory in kaldi's format") 11 | parser.add_argument("split", help="name of the subset") 12 | parser.add_argument("--label", default="", help="if specified, copy labels too") 13 | return parser 14 | 15 | def main(): 16 | parser = get_parser() 17 | args = parser.parse_args() 18 | 19 | tar_dir = os.path.join(args.tar_root, args.split) 20 | os.makedirs(tar_dir, exist_ok=True) 21 | 22 | lengths_path = os.path.join(args.w2v_dir, f"{args.split}.lengths") 23 | with open(lengths_path) as f: 24 | lengths = [int(line.rstrip()) for line in f] 25 | offsets = [0] + np.cumsum(lengths[:-1]).tolist() 26 | feats = np.load( 27 | os.path.join(args.w2v_dir, f"{args.split}.npy"), 28 | mmap_mode="r" 29 | ) 30 | assert feats.shape[0] == sum(lengths), \ 31 | f"lengths mismatch {feats.shape[0]} != {sum(lengths)}" 32 | 33 | ark_path = os.path.join(tar_dir, "feats.ark") 34 | scp_path = os.path.join(tar_dir, "feats.scp") 35 | wspec = f"ark:| copy-feats --compress=true ark:- ark,scp:{ark_path},{scp_path}" 36 | with kaldi_io.open_or_fd(wspec, "wb") as f: 37 | for idx, (offset, length) in enumerate(zip(offsets, lengths)): 38 | feat = feats[offset:offset+length] 39 | kaldi_io.write_mat(f, feat, key=f"utt{idx:010d}") 40 | 41 | u2s_path = os.path.join(tar_dir, "utt2spk") 42 | s2u_path = os.path.join(tar_dir, "spk2utt") 43 | with open(u2s_path, "w") as f_u2s, open(s2u_path, "w") as f_s2u: 44 | for idx in range(len(lengths)): 45 | f_u2s.write(f"utt{idx:010d} utt{idx:010d}\n") 46 | f_s2u.write(f"utt{idx:010d} utt{idx:010d}\n") 47 | 48 | if bool(args.label): 49 | lab_path = os.path.join(args.w2v_dir, f"{args.split}.{args.label}") 50 | txt_path = os.path.join(tar_dir, "text") 51 | with open(lab_path) as f_lab, open(txt_path, "w") as f_txt: 52 | for idx, line in enumerate(f_lab): 53 | f_txt.write(f"utt{idx:010d} {line}") 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/prepare_lang.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sil_prob=0.5 4 | num_sil_states=3 5 | num_nonsil_states=1 6 | 7 | . ./cmd.sh 8 | . ./path.sh 9 | . 
parse_options.sh 10 | 11 | set -eux 12 | 13 | dict=$1 14 | data_dir=$2 15 | 16 | dict_dir=$data_dir/local/dict 17 | tmplm_dir=$data_dir/local/lang_tmp 18 | lm_dir=$data_dir/lang 19 | 20 | mkdir -p $dict_dir $tmplm_dir $lm_dir 21 | 22 | # prepare dict 23 | echo "SIL" > $dict_dir/silence_phones.txt 24 | echo "SIL" > $dict_dir/optional_silence.txt 25 | awk '{print $1}' $dict > $dict_dir/nonsilence_phones.txt 26 | 27 | echo "SIL SIL" > $dict_dir/lexicon.txt 28 | echo " SIL" >> $dict_dir/lexicon.txt 29 | awk '{print $1" "$1}' $dict >> $dict_dir/lexicon.txt 30 | 31 | echo "SIL" > $dict_dir/extra_questions.txt 32 | awk '{printf $1" "} END {printf "\n"}' $dict >> $dict_dir/extra_questions.txt 33 | 34 | # prepare lang 35 | utils/prepare_lang.sh --sil-prob $sil_prob --position-dependent-phones false \ 36 | --num_sil_states $num_sil_states --num_nonsil_states $num_nonsil_states \ 37 | $dict_dir "" $tmplm_dir $lm_dir 38 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/prepare_lang_word.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | num_sil_states=3 4 | num_nonsil_states=1 5 | 6 | . ./cmd.sh 7 | . ./path.sh 8 | . parse_options.sh 9 | 10 | set -eux 11 | 12 | dict=$1 13 | data_dir=$2 14 | lexicon=$3 15 | 16 | dict_dir=$data_dir/local/dict_word 17 | tmplm_dir=$data_dir/local/lang_tmp_word 18 | lm_dir=$data_dir/lang_word 19 | 20 | mkdir -p $dict_dir $tmplm_dir $lm_dir 21 | 22 | # prepare dict 23 | echo "SIL" > $dict_dir/silence_phones.txt 24 | echo "SIL" > $dict_dir/optional_silence.txt 25 | awk '{print $1}' $dict > $dict_dir/nonsilence_phones.txt 26 | 27 | (echo "!SIL SIL"; echo " SIL";) | cat - $lexicon > $dict_dir/lexicon.txt 28 | 29 | echo "SIL" > $dict_dir/extra_questions.txt 30 | awk '{printf $1" "} END {printf "\n"}' $dict >> $dict_dir/extra_questions.txt 31 | 32 | # prepare lang 33 | utils/prepare_lang.sh --position-dependent-phones false \ 34 | --num_sil_states $num_sil_states --num_nonsil_states $num_nonsil_states \ 35 | $dict_dir "" $tmplm_dir $lm_dir 36 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/prepare_lm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | langdir="" 4 | lmdir="" 5 | 6 | . ./cmd.sh 7 | . ./path.sh 8 | . parse_options.sh 9 | 10 | arpa_lm=$1 11 | data=$2 12 | 13 | if [ -z $langdir ]; then 14 | langdir=$data/lang 15 | fi 16 | if [ -z $lmdir ]; then 17 | lmdir=$data/lang_test 18 | fi 19 | 20 | if [ ! -d $langdir ]; then 21 | echo "$langdir not found. 
run local/prepare_lang.sh first" && exit 1 22 | fi 23 | 24 | mkdir -p $lmdir 25 | cp -r $langdir/* $lmdir 26 | 27 | if [[ "$arpa_lm" == *.gz ]]; then 28 | gunzip -c $arpa_lm | arpa2fst --disambig-symbol=#0 --read-symbol-table=$lmdir/words.txt - $lmdir/G.fst 29 | else 30 | arpa2fst --disambig-symbol=#0 --read-symbol-table=$lmdir/words.txt $arpa_lm $lmdir/G.fst 31 | fi 32 | fstisstochastic $lmdir/G.fst 33 | utils/validate_lang.pl $lmdir || exit 1 34 | 35 | echo "done preparing lm ($lmdir)" 36 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/score.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey) 3 | # 2014 Guoguo Chen 4 | # Apache 2.0 5 | 6 | [ -f ./path.sh ] && . ./path.sh 7 | 8 | # begin configuration section. 9 | cmd=run.pl 10 | stage=0 11 | decode_mbr=true 12 | word_ins_penalty=0.0,0.5,1.0 13 | min_lmwt=7 14 | max_lmwt=17 15 | iter=final 16 | #end configuration section. 17 | 18 | [ -f ./path.sh ] && . ./path.sh 19 | . parse_options.sh || exit 1; 20 | 21 | if [ $# -ne 3 ]; then 22 | echo "Usage: local/score.sh [--cmd (run.pl|queue.pl...)] " 23 | echo " Options:" 24 | echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." 25 | echo " --stage (0|1|2) # start scoring script from part-way through." 26 | echo " --decode_mbr (true/false) # maximum bayes risk decoding (confusion network)." 27 | echo " --min_lmwt # minumum LM-weight for lattice rescoring " 28 | echo " --max_lmwt # maximum LM-weight for lattice rescoring " 29 | exit 1; 30 | fi 31 | 32 | data=$1 33 | lang_or_graph=$2 34 | dir=$3 35 | 36 | symtab=$lang_or_graph/words.txt 37 | 38 | for f in $symtab $dir/lat.1.gz $data/text; do 39 | [ ! -f $f ] && echo "score.sh: no such file $f" && exit 1; 40 | done 41 | 42 | mkdir -p $dir/scoring/log 43 | 44 | cat $data/text | sed 's:::g' | sed 's:::g' > $dir/scoring/test_filt.txt 45 | 46 | for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do 47 | $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/best_path.LMWT.$wip.log \ 48 | lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ 49 | lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ 50 | lattice-best-path --word-symbol-table=$symtab \ 51 | ark:- ark,t:$dir/scoring/LMWT.$wip.tra || exit 1; 52 | done 53 | 54 | # Note: the double level of quoting for the sed command 55 | for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do 56 | $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.$wip.log \ 57 | cat $dir/scoring/LMWT.$wip.tra \| \ 58 | utils/int2sym.pl -f 2- $symtab \| sed 's:\::g' \| \ 59 | compute-wer --text --mode=present \ 60 | ark:$dir/scoring/test_filt.txt ark,p:- ">&" $dir/wer_LMWT_$wip || exit 1; 61 | done 62 | 63 | exit 0; 64 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/show_wer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | split="dev_other" 4 | ref_data="" 5 | get_best_wer=true 6 | dec_name="decode" 7 | graph_name="graph" 8 | 9 | . ./cmd.sh 10 | . ./path.sh 11 | . parse_options.sh 12 | 13 | exp_root=$1 14 | 15 | set -eu 16 | 17 | echo "==== WER w.r.t. pseudo transcript" 18 | for x in $exp_root/*/${dec_name}_${split}*; do grep WER $x/wer_* 2>/dev/null | utils/best_wer.sh; done 19 | 20 | 21 | if [ ! 
-z $ref_data ]; then 22 | echo "==== WER w.r.t. real transcript (select based on pseudo WER)" 23 | ref_txt=$ref_data/$split/text 24 | for x in $exp_root/*/${dec_name}_${split}*; do 25 | lang=$(dirname $x)/$graph_name 26 | 27 | lmwt=$( 28 | grep WER $x/wer_* 2>/dev/null | utils/best_wer.sh | 29 | sed 's/.*wer_\(.*\)$/\1/g' | sed 's/_/./g' 30 | ) 31 | tra=$x/scoring/$lmwt.tra 32 | cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:::g' | sed 's:::g' | \ 33 | compute-wer --text --mode=present \ 34 | ark:$ref_txt ark,p:- 2> /dev/null | grep WER | xargs -I{} echo {} $tra 35 | done 36 | fi 37 | 38 | if [ ! -z $ref_data ] && $get_best_wer; then 39 | echo "==== WER w.r.t. real transcript (select based on true WER)" 40 | ref_txt=$ref_data/$split/text 41 | for x in $exp_root/*/${dec_name}_${split}*; do 42 | lang=$(dirname $x)/$graph_name 43 | 44 | for tra in $x/scoring/*.tra; do 45 | cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:::g' | sed 's:::g' | \ 46 | compute-wer --text --mode=present \ 47 | ark:$ref_txt ark,p:- 2> /dev/null | grep WER | xargs -I{} echo {} $tra 48 | done | sort -k2n | head -n1 49 | done 50 | fi 51 | 52 | exit 0; 53 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | out_root=/tmp 4 | out_name=train_${RANDOM} 5 | num_nonsil_states=1 6 | 7 | valid="dev_other" 8 | train="train" 9 | mono_size="-1" # 2000 10 | tri1_size="-1" # 5000 11 | tri2b_size="-1" # 10000 12 | tri3b_size="-1" # 10000 13 | 14 | # Acoustic model parameters 15 | numLeavesTri1=2000 16 | numGaussTri1=10000 17 | numLeavesMLLT=2500 18 | numGaussMLLT=15000 19 | numLeavesSAT=2500 20 | numGaussSAT=15000 21 | 22 | stage=1 23 | max_stage=1 24 | 25 | . ./cmd.sh 26 | . ./path.sh 27 | . parse_options.sh 28 | 29 | data=$1 30 | lang=$2 31 | lang_test=$3 32 | 33 | exp_root=$out_root/$out_name 34 | 35 | # you might not want to do this for interactive shells. 36 | set -e 37 | 38 | 39 | if [ $stage -le 1 ] && [ $max_stage -ge 1 ]; then 40 | # train a monophone system 41 | if [ ! $mono_size -eq -1 ]; then 42 | utils/subset_data_dir.sh $data/$train $mono_size $data/${train}_${mono_size} 43 | mono_train=${train}_${mono_size} 44 | else 45 | mono_train=${train} 46 | fi 47 | 48 | steps/train_mono.sh --boost-silence 1.25 --nj 20 --cmd "$train_cmd" \ 49 | --initial-beam 40 --regular-beam 60 --retry-beam 120 \ 50 | $data/$mono_train $lang $exp_root/mono 51 | 52 | utils/mkgraph.sh $lang_test $exp_root/mono $exp_root/mono/graph 53 | steps/decode.sh --nj 20 --cmd "$decode_cmd" \ 54 | $exp_root/mono/graph $data/$valid $exp_root/mono/decode_$valid & 55 | fi 56 | 57 | 58 | if [ $stage -le 2 ] && [ $max_stage -ge 2 ]; then 59 | # train a first delta + delta-delta triphone system on a subset of 5000 utterances 60 | if [ ! 
$tri1_size -eq -1 ]; then 61 | utils/subset_data_dir.sh $data/$train $tri1_size $data/${train}_${tri1_size} 62 | tri1_train=${train}_${tri1_size} 63 | else 64 | tri1_train=${train} 65 | fi 66 | 67 | steps/align_si.sh --boost-silence 1.25 --nj 10 --cmd "$train_cmd" \ 68 | $data/$tri1_train $lang \ 69 | $exp_root/mono $exp_root/mono_ali_${tri1_train} 70 | 71 | steps_gan/train_deltas.sh --boost-silence 1.25 --cmd "$train_cmd" \ 72 | --num_nonsil_states $num_nonsil_states $numLeavesTri1 $numGaussTri1 \ 73 | $data/$tri1_train $lang \ 74 | $exp_root/mono_ali_${tri1_train} $exp_root/tri1 75 | 76 | utils/mkgraph.sh $lang_test $exp_root/tri1 $exp_root/tri1/graph 77 | steps/decode.sh --nj 20 --cmd "$decode_cmd" \ 78 | $exp_root/tri1/graph $data/$valid $exp_root/tri1/decode_$valid & 79 | fi 80 | 81 | if [ $stage -le 3 ] && [ $max_stage -ge 3 ]; then 82 | # train an LDA+MLLT system. 83 | if [ ! $tri2b_size -eq -1 ]; then 84 | utils/subset_data_dir.sh $data/$train $tri2b_size $data/${train}_${tri2b_size} 85 | tri2b_train=${train}_${tri2b_size} 86 | else 87 | tri2b_train=${train} 88 | fi 89 | 90 | steps/align_si.sh --nj 10 --cmd "$train_cmd" \ 91 | $data/$tri2b_train $lang \ 92 | $exp_root/tri1 $exp_root/tri1_ali_${tri2b_train} 93 | 94 | steps_gan/train_lda_mllt.sh --cmd "$train_cmd" \ 95 | --num_nonsil_states $num_nonsil_states \ 96 | --splice-opts "--left-context=3 --right-context=3" $numLeavesMLLT $numGaussMLLT \ 97 | $data/$tri2b_train $lang \ 98 | $exp_root/tri1_ali_${tri2b_train} $exp_root/tri2b 99 | 100 | utils/mkgraph.sh $lang_test $exp_root/tri2b $exp_root/tri2b/graph 101 | steps/decode.sh --nj 20 --cmd "$decode_cmd" \ 102 | $exp_root/tri2b/graph $data/$valid $exp_root/tri2b/decode_$valid & 103 | fi 104 | 105 | 106 | if [ $stage -le 4 ] && [ $max_stage -ge 4 ]; then 107 | # Train tri3b, which is LDA+MLLT+SAT on 10k utts 108 | if [ ! 
$tri3b_size -eq -1 ]; then 109 | utils/subset_data_dir.sh $data/$train $tri3b_size $data/${train}_${tri3b_size} 110 | tri3b_train=${train}_${tri3b_size} 111 | else 112 | tri3b_train=${train} 113 | fi 114 | 115 | steps/align_si.sh --nj 10 --cmd "$train_cmd" --use-graphs true \ 116 | $data/$tri3b_train $lang \ 117 | $exp_root/tri2b $exp_root/tri2b_ali_${tri2b_train} 118 | 119 | steps_gan/train_sat.sh --cmd "$train_cmd" \ 120 | --num_nonsil_states $num_nonsil_states $numLeavesSAT $numGaussSAT \ 121 | $data/$tri3b_train $lang \ 122 | $exp_root/tri2b_ali_${tri2b_train} $exp_root/tri3b 123 | 124 | utils/mkgraph.sh $lang_test $exp_root/tri3b $exp_root/tri3b/graph 125 | steps/decode_fmllr.sh --nj 20 --cmd "$decode_cmd" \ 126 | $exp_root/tri3b/graph $data/$valid $exp_root/tri3b/decode_$valid & 127 | fi 128 | 129 | wait 130 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/unsup_select.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implement unsupervised metric for decoding hyperparameter selection: 3 | $$ alpha * LM_PPL + ViterbitUER(%) * 100 $$ 4 | """ 5 | import argparse 6 | import logging 7 | import math 8 | import sys 9 | 10 | import kenlm 11 | import editdistance 12 | from g2p_en import G2p 13 | 14 | logging.root.setLevel(logging.INFO) 15 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def get_parser(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("ref_tra", help="reference pseudo labels") 22 | parser.add_argument("hyp_tra", help="decoded pseudo labels to be assess") 23 | parser.add_argument("--kenlm_path", default="/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o5.bin", help="") 24 | parser.add_argument("--uppercase", action="store_true", help="") 25 | parser.add_argument("--skipwords", default="", help="") 26 | parser.add_argument("--gt_tra", default="", help="ground truth pseudo labels for computing oracle WER") 27 | parser.add_argument("--min_vt_uer", default=0.0, type=float) 28 | parser.add_argument("--phonemize", action="store_true", help="phonemize word hypotheses, used when reference is phone transcript") 29 | parser.add_argument("--phonemize_lexicon", default="", type=str, help="use a lexicon for phonemizing") 30 | return parser 31 | 32 | def load_tra(tra_path): 33 | with open(tra_path, "r") as f: 34 | uid_to_tra = {} 35 | for line in f: 36 | toks = line.rstrip().split() 37 | uid, tra = toks[0], " ".join(toks[1:]) 38 | uid_to_tra[uid] = tra 39 | logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}") 40 | return uid_to_tra 41 | 42 | def load_lex(lex_path): 43 | with open(lex_path, "r") as f: 44 | w2p = {} 45 | for line in f: 46 | w, p = line.rstrip().split(None, 1) 47 | w2p[w] = p.split() 48 | return w2p 49 | 50 | def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p, g2p_dict): 51 | d_cnt = 0 52 | w_cnt = 0 53 | w_cnt_h = 0 54 | for uid in hyp_uid_to_tra: 55 | ref = ref_uid_to_tra[uid].split() 56 | if g2p_dict is not None: 57 | hyp = [] 58 | for word in hyp_uid_to_tra[uid].split(): 59 | if word in g2p_dict: 60 | hyp = hyp + g2p_dict[word] 61 | else: 62 | logger.warning(f"{word} not in g2p_dict") 63 | elif g2p is not None: 64 | hyp = g2p(hyp_uid_to_tra[uid]) 65 | hyp = [p for p in hyp if p != "'" and p != " "] 66 | hyp = [p[:-1] if p[-1].isnumeric() else p for p in hyp] 67 | else: 68 | hyp = hyp_uid_to_tra[uid].split() 69 | 
logger.debug(( 70 | f"======================\n" 71 | f"HYP: {' '.join(hyp)}\n" 72 | f"REF: {' '.join(ref)}" 73 | )) 74 | d_cnt += editdistance.eval(ref, hyp) 75 | w_cnt += len(ref) 76 | w_cnt_h += len(hyp) 77 | wer = float(d_cnt) / w_cnt 78 | logger.debug(( 79 | f"wer = {wer*100:.2f}%; num. of ref words = {w_cnt}; " 80 | f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(ref_uid_to_tra)}" 81 | )) 82 | return wer 83 | 84 | def compute_lm_ppl(hyp_uid_to_tra, score_fn): 85 | lm_score = 0. 86 | w_cnt = 0 87 | for hyp in hyp_uid_to_tra.values(): 88 | cur_score = score_fn(hyp) 89 | cur_cnt = len(hyp.split()) + 1 # plus one for 90 | lm_score += cur_score 91 | w_cnt += cur_cnt 92 | logger.debug(( 93 | f"======================\n" 94 | f"score sum/avg = {cur_score:.2f}/{cur_score/cur_cnt:.2f}\n" 95 | f"hyp = {hyp}" 96 | )) 97 | lm_ppl = math.pow(10, -lm_score / w_cnt) 98 | logger.debug(f"lm ppl = {lm_ppl:.2f}; num. of words = {w_cnt}") 99 | return lm_ppl 100 | 101 | def main(): 102 | args = get_parser().parse_args() 103 | logger.debug(f"Args: {args}") 104 | 105 | ref_uid_to_tra = load_tra(args.ref_tra) 106 | hyp_uid_to_tra = load_tra(args.hyp_tra) 107 | assert not bool(set(hyp_uid_to_tra.keys()) - set(ref_uid_to_tra.keys())) 108 | 109 | lm = kenlm.Model(args.kenlm_path) 110 | skipwords = set(args.skipwords.split(",")) 111 | def compute_lm_score(s): 112 | s = " ".join(w for w in s.split() if w not in skipwords) 113 | s = s.upper() if args.uppercase else s 114 | return lm.score(s) 115 | 116 | g2p, g2p_dict = None, None 117 | if args.phonemize: 118 | if args.phonemize_lexicon: 119 | g2p_dict = load_lex(args.phonemize_lexicon) 120 | else: 121 | g2p = G2p() 122 | 123 | wer = compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p, g2p_dict) 124 | lm_ppl = compute_lm_ppl(hyp_uid_to_tra, compute_lm_score) 125 | 126 | gt_wer = -math.inf 127 | if args.gt_tra: 128 | gt_uid_to_tra = load_tra(args.gt_tra) 129 | gt_wer = compute_wer(gt_uid_to_tra, hyp_uid_to_tra, None, None) 130 | 131 | score = math.log(lm_ppl) * max(wer, args.min_vt_uer) 132 | logging.info(f"{args.hyp_tra}: score={score:.4f}; wer={wer*100:.2f}%; lm_ppl={lm_ppl:.4f}; gt_wer={gt_wer*100:.2f}%") 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | split="dev_other" 4 | ref_txt="" # ground truth transcript path 5 | psd_txt="" # pseudo transcript path 6 | get_best_wer=true 7 | dec_name="decode" 8 | graph_name="graph" 9 | kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin 10 | 11 | . ./cmd.sh 12 | . ./path.sh 13 | . parse_options.sh 14 | 15 | exp_root=$1 16 | unsup_args="" 17 | if [ $# -ge 2 ]; then 18 | unsup_args=$2 19 | fi 20 | 21 | set -eu 22 | 23 | if [ ! -z $ref_txt ] && $get_best_wer; then 24 | echo "==== WER w.r.t. 
real transcript (select based on unsupervised metric)" 25 | for x in $exp_root/*/${dec_name}_${split}*; do 26 | lang=$(dirname $x)/$graph_name 27 | 28 | ( 29 | for tra in $x/scoring/*.tra; do 30 | cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:::g' | sed 's:::g' > $tra.txt 31 | python local/unsup_select.py $psd_txt $tra.txt --kenlm_path $kenlm_path --gt_tra $ref_txt $unsup_args 32 | done 2>/dev/null | grep "score=" | sed 's/=/ /g' | sed 's/;//g' | sort -k3n | head -n1 33 | ) & 34 | done 35 | fi 36 | wait 37 | 38 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | split="dev_other" 4 | ref_txt="" # ground truth transcript path 5 | psd_txt="" # pseudo transcript path 6 | get_best_wer=true 7 | dec_name="decode" 8 | graph_name="graph" 9 | kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin 10 | phonemize_lexicon="" 11 | 12 | . ./cmd.sh 13 | . ./path.sh 14 | . parse_options.sh 15 | . /private/home/wnhsu/unsup_asr/fairseq-py-unsup/env.sh 16 | 17 | exp_root=$1 18 | 19 | set -eu 20 | 21 | if [ ! -z $ref_txt ] && $get_best_wer; then 22 | echo "==== WER w.r.t. real transcript (select based on unsupervised metric)" 23 | for x in $exp_root/*/${dec_name}_${split}*; do 24 | lang=$(dirname $x)/$graph_name 25 | 26 | for tra in $x/scoring/*.tra; do 27 | cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:\::g' > $tra.txt 28 | python local/unsup_select.py $psd_txt $tra.txt \ 29 | --kenlm_path $kenlm_path --gt_tra $ref_txt --phonemize \ 30 | --phonemize_lexicon "$phonemize_lexicon" 31 | done | grep "score=" | sed 's/=/ /g' | sed 's/;//g' | sort -k3n | head -n1 32 | done 33 | fi 34 | 35 | 36 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/path.sh: -------------------------------------------------------------------------------- 1 | export KALDI_ROOT=`pwd`/../../.. 2 | export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH 3 | [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 4 | . $KALDI_ROOT/tools/config/common_path.sh 5 | export LC_ALL=C 6 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/steps: -------------------------------------------------------------------------------- 1 | ../../wsj/s5/steps -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey) 4 | # Apache 2.0 5 | 6 | # Begin configuration. 7 | stage=-4 # This allows restarting after partway, when something when wrong. 8 | config= 9 | cmd=run.pl 10 | scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" 11 | realign_iters="10 20 30"; 12 | num_iters=35 # Number of iterations of training 13 | max_iter_inc=25 # Last iter to increase #Gauss on. 
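# Rough sketch of the mixture-growing schedule driven by the settings above: training
# starts with one Gaussian per tree leaf (numgauss = numleaves) and adds
# incgauss = (totgauss - numleaves) / max_iter_inc Gaussians after each pass, up to
# pass $max_iter_inc. For example, with the tri1 settings passed by
# train_subset_lgbeam.sh (numLeavesTri1=2000, numGaussTri1=10000) this gives
# (10000 - 2000) / 25 = 320 extra Gaussians per pass, reaching the full budget of
# 10000 by pass 25; the remaining passes only re-estimate parameters.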
14 | beam=10 15 | careful=false 16 | retry_beam=40 17 | boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment 18 | power=0.25 # Exponent for number of gaussians according to occurrence counts 19 | cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves 20 | norm_vars=false # deprecated. Prefer --cmvn-opts "--norm-vars=true" 21 | # use the option --cmvn-opts "--norm-means=false" 22 | cmvn_opts= 23 | delta_opts= 24 | context_opts= # use"--context-width=5 --central-position=2" for quinphone 25 | num_nonsil_states=3 26 | # End configuration. 27 | 28 | echo "$0 $@" # Print the command line for logging 29 | 30 | [ -f path.sh ] && . ./path.sh; 31 | . parse_options.sh || exit 1; 32 | 33 | if [ $# != 6 ]; then 34 | echo "Usage: steps/train_deltas.sh " 35 | echo "e.g.: steps/train_deltas.sh 2000 10000 data/train_si84_half data/lang exp/mono_ali exp/tri1" 36 | echo "main options (for others, see top of script file)" 37 | echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." 38 | echo " --config # config containing options" 39 | echo " --stage # stage to do partial re-run from." 40 | exit 1; 41 | fi 42 | 43 | numleaves=$1 44 | totgauss=$2 45 | data=$3 46 | lang=$4 47 | alidir=$5 48 | dir=$6 49 | 50 | for f in $alidir/final.mdl $alidir/ali.1.gz $data/feats.scp $lang/phones.txt; do 51 | [ ! -f $f ] && echo "train_deltas.sh: no such file $f" && exit 1; 52 | done 53 | 54 | numgauss=$numleaves 55 | incgauss=$[($totgauss-$numgauss)/$max_iter_inc] # per-iter increment for #Gauss 56 | oov=`cat $lang/oov.int` || exit 1; 57 | ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; 58 | nj=`cat $alidir/num_jobs` || exit 1; 59 | mkdir -p $dir/log 60 | echo $nj > $dir/num_jobs 61 | 62 | utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1; 63 | cp $lang/phones.txt $dir || exit 1; 64 | 65 | sdata=$data/split$nj; 66 | split_data.sh $data $nj || exit 1; 67 | 68 | 69 | [ $(cat $alidir/cmvn_opts 2>/dev/null | wc -c) -gt 1 ] && [ -z "$cmvn_opts" ] && \ 70 | echo "$0: warning: ignoring CMVN options from source directory $alidir" 71 | $norm_vars && cmvn_opts="--norm-vars=true $cmvn_opts" 72 | echo $cmvn_opts > $dir/cmvn_opts # keep track of options to CMVN. 73 | [ ! -z $delta_opts ] && echo $delta_opts > $dir/delta_opts 74 | 75 | feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |" 76 | 77 | rm $dir/.error 2>/dev/null 78 | 79 | if [ $stage -le -3 ]; then 80 | echo "$0: accumulating tree stats" 81 | $cmd JOB=1:$nj $dir/log/acc_tree.JOB.log \ 82 | acc-tree-stats $context_opts \ 83 | --ci-phones=$ciphonelist $alidir/final.mdl "$feats" \ 84 | "ark:gunzip -c $alidir/ali.JOB.gz|" $dir/JOB.treeacc || exit 1; 85 | sum-tree-stats $dir/treeacc $dir/*.treeacc 2>$dir/log/sum_tree_acc.log || exit 1; 86 | rm $dir/*.treeacc 87 | fi 88 | 89 | if [ $stage -le -2 ]; then 90 | echo "$0: getting questions for tree-building, via clustering" 91 | # preparing questions, roots file... 
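# What the next three commands do, roughly: cluster-phones clusters the accumulated
# tree stats for HMM state num_nonsil_states/2 (state 0 here, since the GAN recipe
# invokes this script with --num_nonsil_states 1) into phone sets that serve as
# automatically generated "questions"; compile-questions combines them with
# extra_questions.int into the binary form build-tree expects; build-tree then grows
# the phonetic decision tree up to $numleaves leaves, with --cluster-thresh
# controlling the optional bottom-up merging of leaves.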
92 | cluster-phones --pdf-class-list=$(($num_nonsil_states / 2)) $context_opts \ 93 | $dir/treeacc $lang/phones/sets.int \ 94 | $dir/questions.int 2> $dir/log/questions.log || exit 1; 95 | cat $lang/phones/extra_questions.int >> $dir/questions.int 96 | compile-questions $context_opts $lang/topo $dir/questions.int \ 97 | $dir/questions.qst 2>$dir/log/compile_questions.log || exit 1; 98 | 99 | echo "$0: building the tree" 100 | $cmd $dir/log/build_tree.log \ 101 | build-tree $context_opts --verbose=1 --max-leaves=$numleaves \ 102 | --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \ 103 | $dir/questions.qst $lang/topo $dir/tree || exit 1; 104 | 105 | $cmd $dir/log/init_model.log \ 106 | gmm-init-model --write-occs=$dir/1.occs \ 107 | $dir/tree $dir/treeacc $lang/topo $dir/1.mdl || exit 1; 108 | if grep 'no stats' $dir/log/init_model.log; then 109 | echo "** The warnings above about 'no stats' generally mean you have phones **" 110 | echo "** (or groups of phones) in your phone set that had no corresponding data. **" 111 | echo "** You should probably figure out whether something went wrong, **" 112 | echo "** or whether your data just doesn't happen to have examples of those **" 113 | echo "** phones. **" 114 | fi 115 | 116 | gmm-mixup --mix-up=$numgauss $dir/1.mdl $dir/1.occs $dir/1.mdl 2>$dir/log/mixup.log || exit 1; 117 | rm $dir/treeacc 118 | fi 119 | 120 | if [ $stage -le -1 ]; then 121 | # Convert the alignments. 122 | echo "$0: converting alignments from $alidir to use current tree" 123 | $cmd JOB=1:$nj $dir/log/convert.JOB.log \ 124 | convert-ali $alidir/final.mdl $dir/1.mdl $dir/tree \ 125 | "ark:gunzip -c $alidir/ali.JOB.gz|" "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; 126 | fi 127 | 128 | if [ $stage -le 0 ]; then 129 | echo "$0: compiling graphs of transcripts" 130 | $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \ 131 | compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $dir/1.mdl $lang/L.fst \ 132 | "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $sdata/JOB/text |" \ 133 | "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1; 134 | fi 135 | 136 | x=1 137 | while [ $x -lt $num_iters ]; do 138 | echo "$0: training pass $x" 139 | if [ $stage -le $x ]; then 140 | if echo $realign_iters | grep -w $x >/dev/null; then 141 | echo "$0: aligning data" 142 | mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |" 143 | $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \ 144 | gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \ 145 | "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \ 146 | "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; 147 | fi 148 | $cmd JOB=1:$nj $dir/log/acc.$x.JOB.log \ 149 | gmm-acc-stats-ali $dir/$x.mdl "$feats" \ 150 | "ark,s,cs:gunzip -c $dir/ali.JOB.gz|" $dir/$x.JOB.acc || exit 1; 151 | $cmd $dir/log/update.$x.log \ 152 | gmm-est --mix-up=$numgauss --power=$power \ 153 | --write-occs=$dir/$[$x+1].occs $dir/$x.mdl \ 154 | "gmm-sum-accs - $dir/$x.*.acc |" $dir/$[$x+1].mdl || exit 1; 155 | rm $dir/$x.mdl $dir/$x.*.acc 156 | rm $dir/$x.occs 157 | fi 158 | [ $x -le $max_iter_inc ] && numgauss=$[$numgauss+$incgauss]; 159 | x=$[$x+1]; 160 | done 161 | 162 | rm $dir/final.mdl $dir/final.occs 2>/dev/null 163 | ln -s $x.mdl $dir/final.mdl 164 | ln -s $x.occs $dir/final.occs 165 | 166 | steps/diagnostic/analyze_alignments.sh --cmd "$cmd" $lang $dir 167 | 168 | # Summarize warning messages... 
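# By this point the loop above has alternated Gaussian re-estimation
# (gmm-acc-stats-ali + gmm-est) with realignment on the passes listed in
# realign_iters, growing the number of Gaussians each pass, and final.mdl /
# final.occs are symlinks to the last iteration. The two helpers below are purely
# diagnostic: summarize_warnings.pl collects warnings from the log directory and
# gmm_dir_info.pl prints a short summary of the trained model directory.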
169 | utils/summarize_warnings.pl $dir/log 170 | 171 | steps/info/gmm_dir_info.pl $dir 172 | 173 | echo "$0: Done training system with delta+delta-delta features in $dir" 174 | 175 | exit 0 176 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey) 4 | # 5 | # LDA+MLLT refers to the way we transform the features after computing 6 | # the MFCCs: we splice across several frames, reduce the dimension (to 40 7 | # by default) using Linear Discriminant Analysis), and then later estimate, 8 | # over multiple iterations, a diagonalizing transform known as MLLT or STC. 9 | # See http://kaldi-asr.org/doc/transform.html for more explanation. 10 | # 11 | # Apache 2.0. 12 | 13 | # Begin configuration. 14 | cmd=run.pl 15 | config= 16 | stage=-5 17 | scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" 18 | realign_iters="10 20 30"; 19 | mllt_iters="2 4 6 12"; 20 | num_iters=35 # Number of iterations of training 21 | max_iter_inc=25 # Last iter to increase #Gauss on. 22 | dim=40 23 | beam=10 24 | retry_beam=40 25 | careful=false 26 | boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment 27 | power=0.25 # Exponent for number of gaussians according to occurrence counts 28 | randprune=4.0 # This is approximately the ratio by which we will speed up the 29 | # LDA and MLLT calculations via randomized pruning. 30 | splice_opts= 31 | cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves 32 | norm_vars=false # deprecated. Prefer --cmvn-opts "--norm-vars=false" 33 | cmvn_opts= 34 | context_opts= # use "--context-width=5 --central-position=2" for quinphone. 35 | # End configuration. 36 | train_tree=true # if false, don't actually train the tree. 37 | use_lda_mat= # If supplied, use this LDA[+MLLT] matrix. 38 | num_nonsil_states=3 39 | 40 | echo "$0 $@" # Print the command line for logging 41 | 42 | [ -f path.sh ] && . ./path.sh 43 | . parse_options.sh || exit 1; 44 | 45 | if [ $# != 6 ]; then 46 | echo "Usage: steps/train_lda_mllt.sh [options] <#leaves> <#gauss> " 47 | echo " e.g.: steps/train_lda_mllt.sh 2500 15000 data/train_si84 data/lang exp/tri1_ali_si84 exp/tri2b" 48 | echo "Main options (for others, see top of script file)" 49 | echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." 50 | echo " --config # config containing options" 51 | echo " --stage # stage to do partial re-run from." 52 | exit 1; 53 | fi 54 | 55 | numleaves=$1 56 | totgauss=$2 57 | data=$3 58 | lang=$4 59 | alidir=$5 60 | dir=$6 61 | 62 | for f in $alidir/final.mdl $alidir/ali.1.gz $data/feats.scp $lang/phones.txt; do 63 | [ ! 
-f $f ] && echo "train_lda_mllt.sh: no such file $f" && exit 1; 64 | done 65 | 66 | numgauss=$numleaves 67 | incgauss=$[($totgauss-$numgauss)/$max_iter_inc] # per-iter #gauss increment 68 | oov=`cat $lang/oov.int` || exit 1; 69 | nj=`cat $alidir/num_jobs` || exit 1; 70 | silphonelist=`cat $lang/phones/silence.csl` || exit 1; 71 | ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; 72 | 73 | mkdir -p $dir/log 74 | 75 | utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1; 76 | cp $lang/phones.txt $dir || exit 1; 77 | 78 | echo $nj >$dir/num_jobs 79 | echo "$splice_opts" >$dir/splice_opts # keep track of frame-splicing options 80 | # so that later stages of system building can know what they were. 81 | 82 | 83 | [ $(cat $alidir/cmvn_opts 2>/dev/null | wc -c) -gt 1 ] && [ -z "$cmvn_opts" ] && \ 84 | echo "$0: warning: ignoring CMVN options from source directory $alidir" 85 | $norm_vars && cmvn_opts="--norm-vars=true $cmvn_opts" 86 | echo $cmvn_opts > $dir/cmvn_opts # keep track of options to CMVN. 87 | 88 | sdata=$data/split$nj; 89 | split_data.sh $data $nj || exit 1; 90 | 91 | splicedfeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- |" 92 | # Note: $feats gets overwritten later in the script. 93 | feats="$splicedfeats transform-feats $dir/0.mat ark:- ark:- |" 94 | 95 | 96 | 97 | if [ $stage -le -5 ]; then 98 | if [ -z "$use_lda_mat" ]; then 99 | echo "$0: Accumulating LDA statistics." 100 | rm $dir/lda.*.acc 2>/dev/null 101 | $cmd JOB=1:$nj $dir/log/lda_acc.JOB.log \ 102 | ali-to-post "ark:gunzip -c $alidir/ali.JOB.gz|" ark:- \| \ 103 | weight-silence-post 0.0 $silphonelist $alidir/final.mdl ark:- ark:- \| \ 104 | acc-lda --rand-prune=$randprune $alidir/final.mdl "$splicedfeats" ark,s,cs:- \ 105 | $dir/lda.JOB.acc || exit 1; 106 | est-lda --write-full-matrix=$dir/full.mat --dim=$dim $dir/0.mat $dir/lda.*.acc \ 107 | 2>$dir/log/lda_est.log || exit 1; 108 | rm $dir/lda.*.acc 109 | else 110 | echo "$0: Using supplied LDA matrix $use_lda_mat" 111 | cp $use_lda_mat $dir/0.mat || exit 1; 112 | [ ! -z "$mllt_iters" ] && \ 113 | echo "$0: Warning: using supplied LDA matrix $use_lda_mat but we will do MLLT," && \ 114 | echo " which you might not want; to disable MLLT, specify --mllt-iters ''" && \ 115 | sleep 5 116 | fi 117 | fi 118 | 119 | cur_lda_iter=0 120 | 121 | if [ $stage -le -4 ] && $train_tree; then 122 | echo "$0: Accumulating tree stats" 123 | $cmd JOB=1:$nj $dir/log/acc_tree.JOB.log \ 124 | acc-tree-stats $context_opts \ 125 | --ci-phones=$ciphonelist $alidir/final.mdl "$feats" \ 126 | "ark:gunzip -c $alidir/ali.JOB.gz|" $dir/JOB.treeacc || exit 1; 127 | [ `ls $dir/*.treeacc | wc -w` -ne "$nj" ] && echo "$0: Wrong #tree-accs" && exit 1; 128 | $cmd $dir/log/sum_tree_acc.log \ 129 | sum-tree-stats $dir/treeacc $dir/*.treeacc || exit 1; 130 | rm $dir/*.treeacc 131 | fi 132 | 133 | 134 | if [ $stage -le -3 ] && $train_tree; then 135 | echo "$0: Getting questions for tree clustering." 136 | # preparing questions, roots file... 
137 | cluster-phones --pdf-class-list=$(($num_nonsil_states / 2)) $context_opts $dir/treeacc $lang/phones/sets.int \ 138 | $dir/questions.int 2> $dir/log/questions.log || exit 1; 139 | cat $lang/phones/extra_questions.int >> $dir/questions.int 140 | compile-questions $context_opts $lang/topo $dir/questions.int \ 141 | $dir/questions.qst 2>$dir/log/compile_questions.log || exit 1; 142 | 143 | echo "$0: Building the tree" 144 | $cmd $dir/log/build_tree.log \ 145 | build-tree $context_opts --verbose=1 --max-leaves=$numleaves \ 146 | --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \ 147 | $dir/questions.qst $lang/topo $dir/tree || exit 1; 148 | fi 149 | 150 | if [ $stage -le -2 ]; then 151 | echo "$0: Initializing the model" 152 | if $train_tree; then 153 | gmm-init-model --write-occs=$dir/1.occs \ 154 | $dir/tree $dir/treeacc $lang/topo $dir/1.mdl 2> $dir/log/init_model.log || exit 1; 155 | grep 'no stats' $dir/log/init_model.log && echo "This is a bad warning."; 156 | rm $dir/treeacc 157 | else 158 | cp $alidir/tree $dir/ || exit 1; 159 | $cmd JOB=1 $dir/log/init_model.log \ 160 | gmm-init-model-flat $dir/tree $lang/topo $dir/1.mdl \ 161 | "$feats subset-feats ark:- ark:-|" || exit 1; 162 | fi 163 | fi 164 | 165 | 166 | if [ $stage -le -1 ]; then 167 | # Convert the alignments. 168 | echo "$0: Converting alignments from $alidir to use current tree" 169 | $cmd JOB=1:$nj $dir/log/convert.JOB.log \ 170 | convert-ali $alidir/final.mdl $dir/1.mdl $dir/tree \ 171 | "ark:gunzip -c $alidir/ali.JOB.gz|" "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; 172 | fi 173 | 174 | if [ $stage -le 0 ] && [ "$realign_iters" != "" ]; then 175 | echo "$0: Compiling graphs of transcripts" 176 | $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \ 177 | compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $dir/1.mdl $lang/L.fst \ 178 | "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $data/split$nj/JOB/text |" \ 179 | "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1; 180 | fi 181 | 182 | 183 | x=1 184 | while [ $x -lt $num_iters ]; do 185 | echo Training pass $x 186 | if echo $realign_iters | grep -w $x >/dev/null && [ $stage -le $x ]; then 187 | echo Aligning data 188 | mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |" 189 | $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \ 190 | gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \ 191 | "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \ 192 | "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; 193 | fi 194 | if echo $mllt_iters | grep -w $x >/dev/null; then 195 | if [ $stage -le $x ]; then 196 | echo "$0: Estimating MLLT" 197 | $cmd JOB=1:$nj $dir/log/macc.$x.JOB.log \ 198 | ali-to-post "ark:gunzip -c $dir/ali.JOB.gz|" ark:- \| \ 199 | weight-silence-post 0.0 $silphonelist $dir/$x.mdl ark:- ark:- \| \ 200 | gmm-acc-mllt --rand-prune=$randprune $dir/$x.mdl "$feats" ark:- $dir/$x.JOB.macc \ 201 | || exit 1; 202 | est-mllt $dir/$x.mat.new $dir/$x.*.macc 2> $dir/log/mupdate.$x.log || exit 1; 203 | gmm-transform-means $dir/$x.mat.new $dir/$x.mdl $dir/$x.mdl \ 204 | 2> $dir/log/transform_means.$x.log || exit 1; 205 | compose-transforms --print-args=false $dir/$x.mat.new $dir/$cur_lda_iter.mat $dir/$x.mat || exit 1; 206 | rm $dir/$x.*.macc 207 | fi 208 | feats="$splicedfeats transform-feats $dir/$x.mat ark:- ark:- |" 209 | cur_lda_iter=$x 210 | fi 211 | 212 | if [ $stage -le $x ]; then 213 | $cmd JOB=1:$nj $dir/log/acc.$x.JOB.log \ 214 | gmm-acc-stats-ali 
$dir/$x.mdl "$feats" \ 215 | "ark,s,cs:gunzip -c $dir/ali.JOB.gz|" $dir/$x.JOB.acc || exit 1; 216 | $cmd $dir/log/update.$x.log \ 217 | gmm-est --write-occs=$dir/$[$x+1].occs --mix-up=$numgauss --power=$power \ 218 | $dir/$x.mdl "gmm-sum-accs - $dir/$x.*.acc |" $dir/$[$x+1].mdl || exit 1; 219 | rm $dir/$x.mdl $dir/$x.*.acc $dir/$x.occs 220 | fi 221 | [ $x -le $max_iter_inc ] && numgauss=$[$numgauss+$incgauss]; 222 | x=$[$x+1]; 223 | done 224 | 225 | rm $dir/final.{mdl,mat,occs} 2>/dev/null 226 | ln -s $x.mdl $dir/final.mdl 227 | ln -s $x.occs $dir/final.occs 228 | ln -s $cur_lda_iter.mat $dir/final.mat 229 | 230 | steps/diagnostic/analyze_alignments.sh --cmd "$cmd" $lang $dir 231 | 232 | # Summarize warning messages... 233 | utils/summarize_warnings.pl $dir/log 234 | 235 | steps/info/gmm_dir_info.pl $dir 236 | 237 | echo "$0: Done training system with LDA+MLLT features in $dir" 238 | 239 | exit 0 240 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | w2v_dir= # contains features `{train,valid}.{npy,lengths}`, real transcripts `{train,valid}.${label}`, and dict `dict.${label}.txt` 6 | lab_dir= # contains pseudo labels `{train,valid}.txt` 7 | out_dir= # output root 8 | arpa_lm= # phone LM 9 | arpa_lm_bin= # (binary) phone LM for KenLM, used in unsupervised selection 10 | 11 | label=phnc 12 | train_name="train" 13 | valid_name="valid" 14 | data_dir=${out_dir}/data 15 | 16 | mkdir -p ${out_dir}/exp 17 | local/prepare_lang.sh $w2v_dir/dict.${label}.txt $data_dir 18 | local/prepare_lm.sh $arpa_lm $data_dir 19 | 20 | for x in $train_name $valid_name; do 21 | x_gt=${x}_gt 22 | 23 | # prepare pseudo data 24 | python local/prepare_data_from_w2v.py $w2v_dir $data_dir $x 25 | steps/compute_cmvn_stats.sh $data_dir/$x $out_dir/exp/make_feat/$x $out_dir/feats/$x 26 | python local/copy_aligned_text.py < $lab_dir/$x.txt > $data_dir/$x/text 27 | 28 | # prepare ground truth data 29 | mkdir $data_dir/$x_gt 30 | cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $data_dir/$x_gt/ 31 | python local/copy_aligned_text.py < $w2v_dir/$x.$label > $data_dir/$x_gt/text 32 | done 33 | 34 | local/train_subset_lgbeam.sh \ 35 | --out_root ${out_dir} --out_name exp --train $train_name --valid $valid_name \ 36 | --mono_size 2000 --tri1_size 5000 --tri2b_size -1 --tri3b_size -1 \ 37 | --stage 1 --max_stage 3 $data_dir $data_dir/lang $data_dir/lang_test 38 | 39 | local/unsup_select_decode.sh \ 40 | --split $valid_name --kenlm_path $arpa_lm_bin \ 41 | --ref_txt $data_dir/${valid_name}_gt/text \ 42 | --psd_txt $data_dir/${valid_name}/text \ 43 | $out_dir/exp 44 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/kaldi_self_train/st/utils: -------------------------------------------------------------------------------- 1 | ../../wsj/s5/utils -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .wav2vec_u import Wav2vec_U 7 | 8 | 9 | __all__ = [ 10 | "Wav2vec_U", 11 | ] 12 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/apply_pca.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import os 9 | import os.path as osp 10 | import math 11 | import numpy as np 12 | import tqdm 13 | import torch 14 | from shutil import copyfile 15 | 16 | from npy_append_array import NpyAppendArray 17 | 18 | 19 | def get_parser(): 20 | parser = argparse.ArgumentParser( 21 | description="transforms features via a given pca and stored them in target dir" 22 | ) 23 | # fmt: off 24 | parser.add_argument('source', help='directory with features') 25 | parser.add_argument('--split', help='which split to read', required=True) 26 | parser.add_argument('--save-dir', help='where to save the output', required=True) 27 | parser.add_argument('--pca-path', type=str, help='pca location. will append _A.npy and _b.npy', required=True) 28 | parser.add_argument('--batch-size', type=int, default=2048000, help='batch size') 29 | parser.add_argument('--unfiltered', action='store_true', help='process the unfiltered version') 30 | # fmt: on 31 | 32 | return parser 33 | 34 | 35 | def main(): 36 | parser = get_parser() 37 | args = parser.parse_args() 38 | 39 | source_path = osp.join(args.source, args.split) 40 | data_poth = source_path + "_unfiltered" if args.unfiltered else source_path 41 | 42 | print(f"data path: {data_poth}") 43 | 44 | features = np.load(data_poth + ".npy", mmap_mode="r") 45 | pca_A = torch.from_numpy(np.load(args.pca_path + "_A.npy")).cuda() 46 | pca_b = torch.from_numpy(np.load(args.pca_path + "_b.npy")).cuda() 47 | 48 | os.makedirs(args.save_dir, exist_ok=True) 49 | save_path = osp.join(args.save_dir, args.split) 50 | 51 | copyfile(source_path + ".tsv", save_path + ".tsv") 52 | copyfile(data_poth + ".lengths", save_path + ".lengths") 53 | 54 | if osp.exists(source_path + ".phn"): 55 | copyfile(source_path + ".phn", save_path + ".phn") 56 | 57 | if osp.exists(source_path + ".wrd"): 58 | copyfile(source_path + ".wrd", save_path + ".wrd") 59 | 60 | if osp.exists(save_path + ".npy"): 61 | os.remove(save_path + ".npy") 62 | npaa = NpyAppendArray(save_path + ".npy") 63 | 64 | batches = math.ceil(features.shape[0] / args.batch_size) 65 | 66 | with torch.no_grad(): 67 | for b in tqdm.trange(batches): 68 | start = b * args.batch_size 69 | end = start + args.batch_size 70 | x = torch.from_numpy(features[start:end]).cuda() 71 | x = torch.matmul(x, pca_A) + pca_b 72 | npaa.append(x.cpu().numpy()) 73 | 74 | 75 | if __name__ == "__main__": 76 | main() 77 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/copy_labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
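# Reads label lines from stdin and writes them back prefixed with a synthetic,
# zero-padded utterance id (utt0000000000, utt0000000001, ...), i.e. the
# "<utt-id> <transcript>" format Kaldi data-dir text files expect. A rough usage
# sketch (file names are illustrative):
#
#   python scripts/copy_labels.py < train.phones > data/train/text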
6 | 7 | import sys 8 | 9 | for idx, line in enumerate(sys.stdin): 10 | print(f"utt{idx:010d} {line}", end="") 11 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/filter_lexicon.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import sys 9 | 10 | from fairseq.data import Dictionary 11 | 12 | 13 | def get_parser(): 14 | parser = argparse.ArgumentParser( 15 | description="filters a lexicon given a unit dictionary" 16 | ) 17 | parser.add_argument("-d", "--unit-dict", help="unit dictionary", required=True) 18 | return parser 19 | 20 | 21 | def main(): 22 | parser = get_parser() 23 | args = parser.parse_args() 24 | 25 | d = Dictionary.load(args.unit_dict) 26 | symbols = set(d.symbols) 27 | 28 | for line in sys.stdin: 29 | items = line.rstrip().split() 30 | skip = len(items) < 2 31 | for x in items[1:]: 32 | if x not in symbols: 33 | skip = True 34 | break 35 | if not skip: 36 | print(line, end="") 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/filter_tsv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import argparse 9 | import sys 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--tsv", required=True, type=str) 14 | parser.add_argument("--no-skip", action="store_true") 15 | parser.add_argument("--keep", action="store_true") 16 | params = parser.parse_args() 17 | 18 | 19 | def get_fname(line): 20 | p = os.path.basename(line.split("\t")[0]) 21 | p = os.path.splitext(p)[0] 22 | return p 23 | 24 | 25 | # filenames to exclude 26 | seen = set() 27 | with open(params.tsv) as f: 28 | if not params.no_skip: 29 | root = next(f).rstrip() 30 | for line in f: 31 | seen.add(get_fname(line)) 32 | 33 | for i, line in enumerate(sys.stdin): 34 | exists = get_fname(line) in seen 35 | keep = (exists and params.keep) or (not exists and not params.keep) 36 | if i == 0 or keep: 37 | print(line, end="") 38 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/g2p_wrd_to_phn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
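# Converts word transcripts read from stdin into phone sequences with the g2p_en
# grapheme-to-phoneme model, caching each word's pronunciation; --compact strips
# the trailing stress digits. An illustrative example, assuming g2p_en's
# CMU-style output:
#
#   input           : HELLO WORLD
#   default output  : HH AH0 L OW1 W ER1 L D
#   --compact output: HH AH L OW W ER L D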
6 | 7 | import argparse 8 | import sys 9 | 10 | from g2p_en import G2p 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--compact", 17 | action="store_true", 18 | help="if set, compacts phones", 19 | ) 20 | args = parser.parse_args() 21 | 22 | compact = args.compact 23 | 24 | wrd_to_phn = {} 25 | g2p = G2p() 26 | for line in sys.stdin: 27 | words = line.strip().split() 28 | phones = [] 29 | for w in words: 30 | if w not in wrd_to_phn: 31 | wrd_to_phn[w] = g2p(w) 32 | if compact: 33 | wrd_to_phn[w] = [ 34 | p[:-1] if p[-1].isnumeric() else p for p in wrd_to_phn[w] 35 | ] 36 | phones.extend(wrd_to_phn[w]) 37 | try: 38 | print(" ".join(phones)) 39 | except: 40 | print(wrd_to_phn, words, phones, file=sys.stderr) 41 | raise 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/ltr_to_wrd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | 9 | 10 | def main(): 11 | for line in sys.stdin: 12 | print(line.replace(" ", "").replace("|", " ").strip()) 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/mean_pool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
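# Shrinks each utterance's frame sequence to roughly --subsample-rate of its
# length by averaging fixed-size groups of adjacent frames. Worked example with
# the default rate of 0.5: an utterance of 10 frames gives
# target_num = ceil(10 * 0.5) = 5 output frames, each the mean of 2 consecutive
# inputs. When the length does not divide evenly, the tail is either trimmed
# (--remove-extra) or padded by repeating the last real frame.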
6 | 7 | import argparse 8 | import os 9 | import os.path as osp 10 | import math 11 | import numpy as np 12 | import tqdm 13 | import torch 14 | import torch.nn.functional as F 15 | from shutil import copyfile 16 | 17 | from npy_append_array import NpyAppendArray 18 | 19 | 20 | def get_parser(): 21 | parser = argparse.ArgumentParser( 22 | description="mean pools representations by compressing uniform splits of the data" 23 | ) 24 | # fmt: off 25 | parser.add_argument('source', help='directory with features') 26 | parser.add_argument('--split', help='which split to read', required=True) 27 | parser.add_argument('--save-dir', help='where to save the output', required=True) 28 | parser.add_argument('--subsample-rate', type=float, default=0.5, help='size to subsample data to') 29 | 30 | parser.add_argument('--remove-extra', action='store_true', help='if true, removes extra states that cant be pooled, otherwise pads with 0s') 31 | # fmt: on 32 | 33 | return parser 34 | 35 | 36 | def main(): 37 | parser = get_parser() 38 | args = parser.parse_args() 39 | 40 | source_path = osp.join(args.source, args.split) 41 | 42 | print(f"data path: {source_path}") 43 | 44 | features = np.load(source_path + ".npy", mmap_mode="r") 45 | 46 | os.makedirs(args.save_dir, exist_ok=True) 47 | save_path = osp.join(args.save_dir, args.split) 48 | 49 | copyfile(source_path + ".tsv", save_path + ".tsv") 50 | 51 | if os.path.exists(source_path + ".phn"): 52 | copyfile(source_path + ".phn", save_path + ".phn") 53 | if os.path.exists(source_path + ".wrd"): 54 | copyfile(source_path + ".wrd", save_path + ".wrd") 55 | 56 | if os.path.exists(osp.join(args.source, "dict.phn.txt")): 57 | copyfile( 58 | osp.join(args.source, "dict.phn.txt"), 59 | osp.join(args.save_dir, "dict.phn.txt"), 60 | ) 61 | 62 | if osp.exists(save_path + ".npy"): 63 | os.remove(save_path + ".npy") 64 | npaa = NpyAppendArray(save_path + ".npy") 65 | 66 | with open(source_path + ".lengths", "r") as lf: 67 | lengths = lf.readlines() 68 | 69 | fsz = features.shape[-1] 70 | start = 0 71 | with torch.no_grad(): 72 | with open(save_path + ".lengths", "w") as lengths_out: 73 | for length in tqdm.tqdm(lengths): 74 | length = int(length) 75 | end = start + length 76 | feats = features[start:end] 77 | start += length 78 | x = torch.from_numpy(feats).cuda() 79 | target_num = math.ceil(length * args.subsample_rate) 80 | rem = length % target_num 81 | 82 | if rem > 0: 83 | if args.remove_extra: 84 | to_rem = target_num - rem 85 | target_num -= 1 86 | x = x[:-to_rem] 87 | else: 88 | to_add = target_num - rem 89 | x = F.pad(x, [0, 0, 0, to_add]) 90 | x[-to_add:] = x[-to_add - 1] 91 | 92 | x = x.view(target_num, -1, fsz) 93 | x = x.mean(dim=-2) 94 | print(target_num, file=lengths_out) 95 | npaa.append(x.cpu().numpy()) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/merge_clusters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
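# Collapses runs of consecutive frames assigned the same faiss cluster id into a
# single vector (the mean of the run by default, or one randomly sampled frame
# with --pooling sample). Illustrative example: a 6-frame utterance with cluster
# ids "7 7 7 3 3 9" yields 3 vectors -- the mean of frames 1-3, the mean of
# frames 4-5, and frame 6 -- and "3" is written to the new .lengths file.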
6 | 7 | import argparse 8 | import os 9 | import os.path as osp 10 | import numpy as np 11 | import tqdm 12 | import torch 13 | import random 14 | from shutil import copyfile 15 | 16 | from npy_append_array import NpyAppendArray 17 | 18 | 19 | def get_parser(): 20 | parser = argparse.ArgumentParser( 21 | description="transforms features via a given pca and stored them in target dir" 22 | ) 23 | # fmt: off 24 | parser.add_argument('source', help='directory with features') 25 | parser.add_argument('--split', help='which split to read', required=True) 26 | parser.add_argument('--save-dir', help='where to save the output', required=True) 27 | parser.add_argument('--cluster-dir', help='where the clusters are') 28 | parser.add_argument('--pooling', type=str, default='mean', choices=['mean', 'sample'], help='how to pool') 29 | # fmt: on 30 | 31 | return parser 32 | 33 | 34 | def main(): 35 | parser = get_parser() 36 | args = parser.parse_args() 37 | 38 | source_path = osp.join(args.source, args.split) 39 | cluster_path = osp.join(args.cluster_dir, args.split + ".src") 40 | print(f"data path: {source_path}") 41 | 42 | features = np.load(source_path + ".npy", mmap_mode="r") 43 | sizes = [] 44 | offsets = [] 45 | offset = 0 46 | with open(source_path + ".lengths", "r") as len_f: 47 | for line in len_f: 48 | length = int(line.rstrip()) 49 | sizes.append(length) 50 | offsets.append(offset) 51 | offset += length 52 | 53 | clusters = [] 54 | with open(cluster_path, "r") as cf: 55 | for line in cf: 56 | line = line.rstrip() 57 | items = line.split() 58 | items = list(map(int, items)) 59 | clusters.append(items) 60 | 61 | os.makedirs(args.save_dir, exist_ok=True) 62 | save_path = osp.join(args.save_dir, args.split) 63 | 64 | copyfile(source_path + ".tsv", save_path + ".tsv") 65 | 66 | if os.path.exists(source_path + ".phn"): 67 | copyfile(source_path + ".phn", save_path + ".phn") 68 | if os.path.exists(osp.join(args.source, "dict.phn.txt")): 69 | copyfile( 70 | osp.join(args.source, "dict.phn.txt"), 71 | osp.join(args.save_dir, "dict.phn.txt"), 72 | ) 73 | if os.path.exists(source_path + ".wrd"): 74 | copyfile(source_path + ".wrd", save_path + ".wrd") 75 | 76 | if osp.exists(save_path + ".npy"): 77 | os.remove(save_path + ".npy") 78 | npaa = NpyAppendArray(save_path + ".npy") 79 | 80 | def merge(feats, clust): 81 | feats = torch.from_numpy(feats.copy()) 82 | clust = torch.LongTensor(clust) 83 | _, counts = clust.unique_consecutive(return_counts=True) 84 | curr = 0 85 | 86 | merged = [] 87 | for c in counts: 88 | c = c.item() 89 | start = curr 90 | end = curr + c 91 | curr += c 92 | if args.pooling == "mean": 93 | new_x = feats[start:end].mean(dim=0) 94 | elif args.pooling == "sample": 95 | new_x = feats[start + int(random.random() * c)] 96 | else: 97 | raise NotImplementedError() 98 | merged.append(new_x) 99 | 100 | return torch.stack(merged, dim=0).numpy() 101 | 102 | with open(save_path + ".lengths", "w") as l_f: 103 | for size, offset, clust in tqdm.tqdm( 104 | zip(sizes, offsets, clusters), total=len(sizes) 105 | ): 106 | end = size + offset 107 | feats = features[offset:end] 108 | feats = merge(feats, clust) 109 | print(len(feats), file=l_f) 110 | npaa.append(feats) 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/normalize_and_filter_text.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 
Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import fasttext as ft 9 | import os 10 | import regex 11 | import sys 12 | 13 | 14 | def get_parser(): 15 | parser = argparse.ArgumentParser( 16 | description="reads text from stdin and outputs normalized, lid-filtered version to stdout" 17 | ) 18 | parser.add_argument( 19 | "--fasttext-model", 20 | help="path to fasttext model", 21 | default="lid.187.bin", 22 | ) 23 | parser.add_argument("--lang", help="language id", required=True) 24 | parser.add_argument( 25 | "--lid-threshold", 26 | type=float, 27 | help="threshold for this lang id probability", 28 | default=0.4, 29 | ) 30 | 31 | return parser 32 | 33 | 34 | def main(): 35 | parser = get_parser() 36 | args = parser.parse_args() 37 | filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]") 38 | 39 | lg = args.lang.lower() 40 | lg_label = f"__label__{lg}" 41 | thresh = args.lid_threshold 42 | 43 | if os.path.exists(args.fasttext_model): 44 | model = ft.load_model(args.fasttext_model) 45 | else: 46 | print( 47 | f"fasttext language id model {args.fasttext_model} not found. Proceeding without language filtering. " 48 | f"To enable language filtering, please download the latest language id model " 49 | f"from https://fasttext.cc/docs/en/language-identification.html", 50 | file=sys.stderr, 51 | ) 52 | model = None 53 | 54 | for line in sys.stdin: 55 | line = line.strip() 56 | line = filter_r.sub(" ", line) 57 | line = " ".join(line.split()) 58 | 59 | if model is not None: 60 | lid, prob = model.predict(line, k=100) 61 | try: 62 | target_idx = lid.index(lg_label) 63 | except ValueError: 64 | continue 65 | if target_idx == 0 or prob[target_idx] >= thresh: 66 | print(line) 67 | else: 68 | print(line) 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/normalize_text.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import regex 8 | import sys 9 | 10 | 11 | def main(): 12 | filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]") 13 | 14 | for line in sys.stdin: 15 | line = line.strip() 16 | line = filter_r.sub(" ", line) 17 | line = " ".join(line.split()) 18 | print(line) 19 | 20 | 21 | if __name__ == "__main__": 22 | main() 23 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/pca.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
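# Fits a PCA projection on the pooled features via faiss.PCAMatrix and saves it
# as <dim>_pca_A.npy (transposed projection matrix) and <dim>_pca_b.npy (bias),
# which apply_pca.py later applies as x @ A + b; --eigen-power -0.5 additionally
# whitens the features. A rough invocation mirroring prepare_audio.sh (paths are
# illustrative):
#
#   python scripts/pca.py feats/train.npy --output feats/pca --dim 512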
6 | 7 | import argparse 8 | import os 9 | import os.path as osp 10 | import numpy as np 11 | 12 | import faiss 13 | 14 | 15 | 16 | def get_parser(): 17 | parser = argparse.ArgumentParser( 18 | description="compute a pca matrix given an array of numpy features" 19 | ) 20 | # fmt: off 21 | parser.add_argument('data', help='numpy file containing features') 22 | parser.add_argument('--output', help='where to save the pca matrix', required=True) 23 | parser.add_argument('--dim', type=int, help='dim for pca reduction', required=True) 24 | parser.add_argument('--eigen-power', type=float, default=0, help='eigen power, -0.5 for whitening') 25 | 26 | return parser 27 | 28 | 29 | def main(): 30 | parser = get_parser() 31 | args = parser.parse_args() 32 | 33 | print("Reading features") 34 | x = np.load(args.data, mmap_mode="r") 35 | 36 | print("Computing PCA") 37 | pca = faiss.PCAMatrix(x.shape[-1], args.dim, args.eigen_power) 38 | pca.train(x) 39 | b = faiss.vector_to_array(pca.b) 40 | A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in) 41 | 42 | os.makedirs(args.output, exist_ok=True) 43 | 44 | prefix = str(args.dim) 45 | if args.eigen_power != 0: 46 | prefix += f"_{args.eigen_power}" 47 | 48 | np.save(osp.join(args.output, f"{prefix}_pca_A"), A.T) 49 | np.save(osp.join(args.output, f"{prefix}_pca_b"), b) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/phonemize_with_sil.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
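# Maps word transcripts from stdin to phone sequences using the filtered lexicon,
# skipping lines with out-of-lexicon words, and randomly inserts a silence token
# between words (--sil-prob) and around each line (--surround). prepare_text.sh
# calls it with "-s 0.25 --surround", so -- assuming the silence symbol is <SIL>,
# as in the upstream wav2vec-U setup -- a line like "THE CAT" may come out as
# "<SIL> DH AH0 <SIL> K AE1 T <SIL>", the inner <SIL> appearing with
# probability 0.25.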
6 | 7 | import argparse 8 | import numpy as np 9 | import sys 10 | 11 | 12 | def get_parser(): 13 | parser = argparse.ArgumentParser( 14 | description="converts words to phones adding optional silences around in between words" 15 | ) 16 | parser.add_argument( 17 | "--sil-prob", 18 | "-s", 19 | type=float, 20 | default=0, 21 | help="probability of inserting silence between each word", 22 | ) 23 | parser.add_argument( 24 | "--surround", 25 | action="store_true", 26 | help="if set, surrounds each example with silence", 27 | ) 28 | parser.add_argument( 29 | "--lexicon", 30 | help="lexicon to convert to phones", 31 | required=True, 32 | ) 33 | 34 | return parser 35 | 36 | 37 | def main(): 38 | parser = get_parser() 39 | args = parser.parse_args() 40 | 41 | sil_prob = args.sil_prob 42 | surround = args.surround 43 | sil = "" 44 | 45 | wrd_to_phn = {} 46 | 47 | with open(args.lexicon, "r") as lf: 48 | for line in lf: 49 | items = line.rstrip().split() 50 | assert len(items) > 1, line 51 | assert items[0] not in wrd_to_phn, items 52 | wrd_to_phn[items[0]] = items[1:] 53 | 54 | for line in sys.stdin: 55 | words = line.strip().split() 56 | 57 | if not all(w in wrd_to_phn for w in words): 58 | continue 59 | 60 | phones = [] 61 | if surround: 62 | phones.append(sil) 63 | 64 | sample_sil_probs = None 65 | if sil_prob > 0 and len(words) > 1: 66 | sample_sil_probs = np.random.random(len(words) - 1) 67 | 68 | for i, w in enumerate(words): 69 | phones.extend(wrd_to_phn[w]) 70 | if ( 71 | sample_sil_probs is not None 72 | and i < len(sample_sil_probs) 73 | and sample_sil_probs[i] < sil_prob 74 | ): 75 | phones.append(sil) 76 | 77 | if surround: 78 | phones.append(sil) 79 | print(" ".join(phones)) 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/prepare_audio.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env zsh 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
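# End-to-end audio feature pipeline: extract wav2vec 2.0 representations for each
# split in the manifest dir, cluster the training features with faiss (CLUS128),
# reduce them with PCA, merge adjacent frames sharing a cluster, and mean-pool,
# producing the precompute_pca<dim>_cls128_mean_pooled directory used for GAN
# training. A rough invocation sketch, assuming FAIRSEQ_ROOT is exported and with
# illustrative paths (dim and layer default to 512 and 14):
#
#   zsh scripts/prepare_audio.sh /data/manifests /data/feats /models/wav2vec_vox_new.pt 512 14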
6 | 7 | source_dir=$1 8 | tgt_dir=$2 9 | model=$3 10 | 11 | if [ -z "$4" ] 12 | then 13 | dim=512 14 | else 15 | dim=$4 16 | fi 17 | 18 | echo "using $dim dim for PCA" 19 | 20 | if [ -z "$5" ] 21 | then 22 | layer=14 23 | else 24 | layer=$5 25 | fi 26 | 27 | echo "extracting from layer $layer" 28 | 29 | train_split=train 30 | valid_split=valid 31 | test_split=test 32 | 33 | all_splits=($train_split) 34 | 35 | if [[ -f "$source_dir/valid.tsv" ]]; then 36 | all_splits+=('valid') 37 | fi 38 | 39 | if [[ -f "$source_dir/test.tsv" ]]; then 40 | all_splits+=('test') 41 | fi 42 | 43 | echo "processing splits: $all_splits" 44 | 45 | mkdir -p $tgt_dir 46 | 47 | cp $source_dir/*.tsv $tgt_dir 48 | cp $source_dir/*.wrd $tgt_dir 49 | cp $source_dir/*.ltr $tgt_dir 50 | cp $source_dir/*.phn $tgt_dir 51 | cp $source_dir/dict* $tgt_dir 52 | 53 | setopt shwordsplit 54 | 55 | for split in $all_splits; do 56 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py $source_dir --split $split \ 57 | --save-dir $tgt_dir --checkpoint $model --layer $layer 58 | done 59 | 60 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py $tgt_dir/${train_split}.tsv \ 61 | --checkpoint $model --save-dir $tgt_dir -f "CLUS128" --sample-pct 1.0 62 | 63 | for split in $all_splits; do 64 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py $tgt_dir \ 65 | --checkpoint $model --path $tgt_dir/CLUS128 --split $split 66 | done 67 | 68 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/pca.py $tgt_dir/${train_split}.npy --output $tgt_dir/pca --dim $dim 69 | 70 | for split in $all_splits; do 71 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/apply_pca.py $tgt_dir --split $split --save-dir $tgt_dir/precompute_pca$dim --pca-path $tgt_dir/pca/${dim}_pca --batch-size 1048000 72 | 73 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/merge_clusters.py $tgt_dir/precompute_pca$dim --cluster-dir $tgt_dir/CLUS128 \ 74 | --split $split --save-dir $tgt_dir/precompute_pca${dim}_cls128_mean --pooling mean 75 | 76 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/mean_pool.py $tgt_dir/precompute_pca${dim}_cls128_mean \ 77 | --save-dir $tgt_dir/precompute_pca${dim}_cls128_mean_pooled --split $split 78 | done 79 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/prepare_text.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env zsh 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | lg=$1 8 | text_path=$2 9 | target_dir=$3 10 | min_phones=$4 11 | phonemizer=$5 12 | lid_path=$6 13 | 14 | if [ -z "$lid_path" ]; then 15 | lid_path="lid.187.bin" 16 | fi 17 | 18 | ph_lg=${lg:l} 19 | if test "$lg" = 'fr'; then 20 | ph_lg='fr-fr' 21 | elif test "$lg" = 'en'; then 22 | ph_lg='en-us' 23 | elif test "$lg" = 'pt'; then 24 | ph_lg='pt-br' 25 | fi 26 | 27 | ESPEAK_PATH='' 28 | if test "$phonemizer" = 'espeak'; then 29 | ESPEAK_PATH=$(which espeak) 30 | elif test "$phonemizer" = 'espeak-ng'; then 31 | ESPEAK_PATH=$(which espeak-ng) 32 | elif test "$phonemizer" = 'G2P'; then 33 | ESPEAK_PATH='' 34 | else 35 | echo "Unknown phonemizer $phonemizer. 
Valid options are espeak, espeak-ng and G2P" 36 | exit 1 37 | fi 38 | 39 | echo $lg 40 | echo $ph_lg 41 | echo $text_path 42 | echo $target_dir 43 | echo "min phone seen threshold is $min_phones" 44 | 45 | mkdir -p $target_dir 46 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py --lang $lg --fasttext-model $lid_path < $text_path | grep -v '\-\-\-' >! $target_dir/lm.upper.lid.txt 47 | python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/lm.upper.lid.txt --only-source --destdir $target_dir --thresholdsrc 2 --padding-factor 1 --dict-only 48 | cut -f1 -d' ' $target_dir/dict.txt | grep -v -x '[[:punct:]]*' | grep -Pv '\d\d\d\d\d+' >! $target_dir/words.txt 49 | 50 | 51 | if [ -z "$ESPEAK_PATH" ]; then 52 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py --compact < $target_dir/words.txt > $target_dir/phones.txt 53 | else 54 | # echoing 1 into corpus will prevent the mismatch lines between lexicon and phones in case the phonemizer fails 55 | one=$(echo "1" | PHONEMIZER_ESPEAK_PATH=$ESPEAK_PATH phonemize -p ' ' -w '' -l $ph_lg --language-switch remove-flags) 56 | sed 's/$/ 1/' $target_dir/words.txt | PHONEMIZER_ESPEAK_PATH=$ESPEAK_PATH phonemize -o $target_dir/phones.txt -p ' ' -w '' -l $ph_lg -j 70 --language-switch remove-flags 57 | echo "one is ${one}" 58 | sed -i "s/${one}$//" $target_dir/phones.txt 59 | fi 60 | 61 | paste $target_dir/words.txt $target_dir/phones.txt >! $target_dir/lexicon.lst 62 | 63 | python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones.txt --only-source --destdir $target_dir/phones --thresholdsrc $min_phones --padding-factor 1 --dict-only 64 | 65 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/filter_lexicon.py -d $target_dir/phones/dict.txt < $target_dir/lexicon.lst >! $target_dir/lexicon_filtered.lst 66 | python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py -s 0.25 --surround --lexicon $target_dir/lexicon_filtered.lst < $target_dir/lm.upper.lid.txt >! $target_dir/phones/lm.phones.filtered.txt 67 | cp $target_dir/phones/dict.txt $target_dir/phones/dict.phn.txt 68 | echo "<SIL> 0" >> $target_dir/phones/dict.phn.txt 69 | python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones/lm.phones.filtered.txt --workers 70 --only-source --destdir $target_dir/phones --srcdict $target_dir/phones/dict.phn.txt 70 | 71 | $KENLM_ROOT/lmplz -o 4 < $target_dir/lm.upper.lid.txt --discount_fallback --prune 0 0 0 3 >! $target_dir/kenlm.wrd.o40003.arpa 72 | $KENLM_ROOT/build_binary $target_dir/kenlm.wrd.o40003.arpa $target_dir/kenlm.wrd.o40003.bin 73 | 74 | lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_words_sil lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones in_labels=phn "blank_symbol='<SIL>'" 75 | lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_words lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones in_labels=phn 76 | 77 | $KENLM_ROOT/lmplz -o 4 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >!
$target_dir/phones/lm.phones.filtered.04.arpa 78 | $KENLM_ROOT/build_binary $target_dir/phones/lm.phones.filtered.04.arpa $target_dir/phones/lm.phones.filtered.04.bin 79 | $KENLM_ROOT/lmplz -o 6 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! $target_dir/phones/lm.phones.filtered.06.arpa 80 | $KENLM_ROOT/build_binary $target_dir/phones/lm.phones.filtered.06.arpa $target_dir/phones/lm.phones.filtered.06.bin 81 | 82 | lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_phn_sil lm_arpa=$target_dir/phones/lm.phones.filtered.06.arpa data_dir=$target_dir/phones in_labels=phn "blank_symbol='<SIL>'" 83 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/prepare_timit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | timit_root=$1 # assume it is the upper-cased version 8 | tgt_dir=$2 9 | model=$3 10 | 11 | set -eu 12 | 13 | setups="matched unmatched" 14 | splits="test valid train train_text" 15 | 16 | tgt_dir=$(realpath $tgt_dir) 17 | sph2wav=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe 18 | wav_dir=$tgt_dir/wav 19 | 20 | 21 | mkdir -p $tgt_dir $wav_dir 22 | find $timit_root/{TRAIN,TEST} -iname "*.WAV" > $tgt_dir/all_sph.flist 23 | cat $tgt_dir/all_sph.flist | sed -e 's#//*#/#g' -e 's#.*/\([^/]*\)/\([^/]*\).WAV#\1_\2#g' > $tgt_dir/all.uid 24 | paste -d' ' $tgt_dir/{all_sph.flist,all.uid} | \ 25 | awk -v sph2wav=$sph2wav -v wav_dir=$wav_dir '{print sph2wav " -f wav " $1 " > " wav_dir "/" $2 ".wav"}' \ 26 | > $tgt_dir/sph2wav.sh 27 | bash $tgt_dir/sph2wav.sh 28 | cat $tgt_dir/all.uid | awk -v wav_dir=$(pwd)/$wav_dir '{print $1" "wav_dir"/"$1".wav"}' | sort > $tgt_dir/all_wav.scp 29 | cut -d' ' -f2 $tgt_dir/all_wav.scp | xargs -I{} soxi -s {} > $tgt_dir/all.dur 30 | paste -d' ' $tgt_dir/{all_wav.scp,all.dur} > $tgt_dir/all_wav_dur.scp 31 | rm $tgt_dir/{all.uid,all_sph.flist,sph2wav.sh} 32 | 33 | find $timit_root/{TRAIN,TEST} -iname "*.PHN" > $tgt_dir/all_phn60.flist 34 | while read line; do 35 | if [ !
-f $line ]; then 36 | >&2 echo "Cannot find transcription file '$line'" && exit 1; 37 | fi 38 | cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;' 39 | done < $tgt_dir/all_phn60.flist > $tgt_dir/all.phn60 40 | cat $tgt_dir/all_phn60.flist | sed -e 's#//*#/#g' -e 's#.*/\([^/]*\)/\([^/]*\).PHN#\1_\2#g' | \ 41 | paste -d' ' - $tgt_dir/all.phn60 | \ 42 | $KALDI_ROOT/egs/timit/s5/local/timit_norm_trans.pl -i - -m $KALDI_ROOT/egs/timit/s5/conf/phones.60-48-39.map -to 39 | \ 43 | sort > $tgt_dir/all.phn 44 | echo "done preparing wav and 39-phone transcripts" 45 | 46 | 47 | for s in $setups; do 48 | mkdir -p $tgt_dir/$s 49 | for x in $splits; do 50 | uid_path=config/timit_${s}/${x}.uid 51 | grep -w -f $uid_path $tgt_dir/all.phn | cut -d' ' -f2- > $tgt_dir/$s/$x.phn 52 | ln -sf $(realpath $tgt_dir/$s/$x.phn) $tgt_dir/$s/$x.wrd 53 | 54 | echo "/" > $tgt_dir/$s/$x.tsv && grep -w -f $uid_path $tgt_dir/all_wav_dur.scp | cut -d' ' -f2- | sed 's# #\t#' >> $tgt_dir/$s/$x.tsv 55 | done 56 | 57 | for x in $splits; do 58 | cat $tgt_dir/$s/$x.phn 59 | done | tr ' ' '\n' | sort -u | awk '{print $1" "1}' > $tgt_dir/$s/dict.phn.txt 60 | ln -sf $(realpath $tgt_dir/$s/dict.phn.txt) $tgt_dir/$s/dict.wrd.txt 61 | done 62 | echo "done preparing unmatched and matched setups for TIMIT" 63 | 64 | 65 | for s in $setups; do 66 | zsh scripts/prepare_audio.sh $tgt_dir/$s $tgt_dir/$s/feat $model 67 | 68 | lm_dir=$tgt_dir/$s/phones 69 | fst_dir=$tgt_dir/$s/fst/phn_to_phn 70 | 71 | python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $tgt_dir/$s/train_text.phn --workers 10 --only-source --destdir $lm_dir --srcdict $tgt_dir/$s/dict.phn.txt 72 | $KENLM_ROOT/lmplz -o 3 < $tgt_dir/$s/train_text.phn --discount_fallback >$lm_dir/train_text_phn.03.arpa 73 | $KENLM_ROOT/build_binary $lm_dir/train_text_phn.03.arpa $lm_dir/train_text_phn.03.bin 74 | $KENLM_ROOT/lmplz -o 4 < $tgt_dir/$s/train_text.phn --discount_fallback >$lm_dir/train_text_phn.04.arpa 75 | $KENLM_ROOT/build_binary $lm_dir/train_text_phn.04.arpa $lm_dir/train_text_phn.04.bin 76 | 77 | python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$fst_dir lm_arpa=$lm_dir/train_text_phn.03.arpa data_dir=$tgt_dir/$s in_labels=phn 78 | done 79 | echo "done preprocessing audio and text for wav2vec-U" 80 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/remove_silence.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
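# Note: the usage example in the docstring below refers to --paths, but the argument
# parser defined further down expects --tsv (manifest), --vads (output of vads.py)
# and --out (directory where the silence-stripped audio is written).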
6 | 7 | """ 8 | get intervals from .vads file, specify output data, and this script removes silences and saves the audio data in out path folder 9 | paths=shards/train.tsv 10 | vads=shards/train.vads 11 | python remove_silence.py --paths $paths --vads $vads 12 | """ 13 | 14 | import os 15 | import argparse 16 | import torch 17 | import torchaudio 18 | import tqdm 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--tsv", default="", type=str) 23 | parser.add_argument("--vads", default="", type=str) 24 | parser.add_argument("--out", type=str) 25 | params = parser.parse_args() 26 | 27 | # load paths 28 | paths = [] 29 | with open(params.tsv) as f: 30 | root = next(f).rstrip() 31 | for line in f: 32 | paths.append(os.path.join(root, line.rstrip().split("\t")[0])) 33 | 34 | # load vads 35 | list_intervals = [] 36 | with open(params.vads) as f: 37 | for line in f: 38 | interval = [ 39 | [int(w.split(":")[0]), int(w.split(":")[1])] for w in line.rstrip().split() 40 | ] 41 | list_intervals.append(interval) 42 | 43 | 44 | # load audio and keep only intervals (i.e. remove silences) 45 | for i in tqdm.trange(len(paths)): 46 | data, _ = torchaudio.load(paths[i]) 47 | if len(list_intervals[i]) > 0: 48 | data_filtered = torch.cat( 49 | [data[0][int(it[0]) : int(it[1])] for it in list_intervals[i]] 50 | ).unsqueeze(0) 51 | else: 52 | data_filtered = data 53 | 54 | # YOU MAY NEED TO MODIFY THIS TO GET THE RIGHT SUBPATH 55 | # outpath = params.out + '/'.join(paths[i].split('/')[-1]) 56 | outpath = params.out + "/" + "/".join(paths[i].split("/")[-2:]) 57 | 58 | if not os.path.isdir("/".join(outpath.split("/")[:-1])): 59 | os.makedirs("/".join(outpath.split("/")[:-1])) 60 | if not os.path.exists(outpath): 61 | torchaudio.save(outpath, data_filtered, sample_rate=16000) 62 | else: 63 | print(outpath, "exists!") 64 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/vads.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
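# Computes voice-activity segments with rVADfast (https://github.com/zhenghuatan/rVADfast).
# Reads a .tsv manifest on stdin (first line is the audio root) and prints, per file, one line
# of space-separated start:end sample offsets, which remove_silence.py consumes via --vads.
# Example invocation (the rVADfast path is a placeholder):
#   python vads.py -r /path/to/rVADfast < train.tsv > train.vads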
6 | 7 | import argparse 8 | import sys 9 | 10 | from copy import deepcopy 11 | from scipy.signal import lfilter 12 | 13 | import numpy as np 14 | from tqdm import tqdm 15 | import soundfile as sf 16 | import os.path as osp 17 | 18 | 19 | def get_parser(): 20 | parser = argparse.ArgumentParser(description="compute vad segments") 21 | parser.add_argument( 22 | "--rvad-home", 23 | "-r", 24 | help="path to rvad home (see https://github.com/zhenghuatan/rVADfast)", 25 | required=True, 26 | ) 27 | 28 | return parser 29 | 30 | 31 | def rvad(speechproc, path): 32 | winlen, ovrlen, pre_coef, nfilter, nftt = 0.025, 0.01, 0.97, 20, 512 33 | ftThres = 0.5 34 | vadThres = 0.4 35 | opts = 1 36 | 37 | data, fs = sf.read(path) 38 | assert fs == 16_000, "sample rate must be 16khz" 39 | ft, flen, fsh10, nfr10 = speechproc.sflux(data, fs, winlen, ovrlen, nftt) 40 | 41 | # --spectral flatness -- 42 | pv01 = np.zeros(ft.shape[0]) 43 | pv01[np.less_equal(ft, ftThres)] = 1 44 | pitch = deepcopy(ft) 45 | 46 | pvblk = speechproc.pitchblockdetect(pv01, pitch, nfr10, opts) 47 | 48 | # --filtering-- 49 | ENERGYFLOOR = np.exp(-50) 50 | b = np.array([0.9770, -0.9770]) 51 | a = np.array([1.0000, -0.9540]) 52 | fdata = lfilter(b, a, data, axis=0) 53 | 54 | # --pass 1-- 55 | noise_samp, noise_seg, n_noise_samp = speechproc.snre_highenergy( 56 | fdata, nfr10, flen, fsh10, ENERGYFLOOR, pv01, pvblk 57 | ) 58 | 59 | # sets noisy segments to zero 60 | for j in range(n_noise_samp): 61 | fdata[range(int(noise_samp[j, 0]), int(noise_samp[j, 1]) + 1)] = 0 62 | 63 | vad_seg = speechproc.snre_vad( 64 | fdata, nfr10, flen, fsh10, ENERGYFLOOR, pv01, pvblk, vadThres 65 | ) 66 | return vad_seg, data 67 | 68 | 69 | def main(): 70 | parser = get_parser() 71 | args = parser.parse_args() 72 | 73 | sys.path.append(args.rvad_home) 74 | import speechproc 75 | 76 | stride = 160 77 | lines = sys.stdin.readlines() 78 | root = lines[0].rstrip() 79 | for fpath in tqdm(lines[1:]): 80 | path = osp.join(root, fpath.split()[0]) 81 | vads, wav = rvad(speechproc, path) 82 | 83 | start = None 84 | vad_segs = [] 85 | for i, v in enumerate(vads): 86 | if start is None and v == 1: 87 | start = i * stride 88 | elif start is not None and v == 0: 89 | vad_segs.append((start, i * stride)) 90 | start = None 91 | if start is not None: 92 | vad_segs.append((start, len(wav))) 93 | 94 | print(" ".join(f"{v[0]}:{v[1]}" for v in vad_segs)) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/wav2vec_apply_cluster_faiss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
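# Assigns each utterance listed in DATA/SPLIT.tsv to its nearest centroid from the FAISS
# index trained by wav2vec_cluster_faiss.py; --path points at the cluster directory
# (e.g. CLUS128). Writes SPLIT.src (one line of cluster ids per utterance) and SPLIT.tsv
# into that directory, plus a copy of the labels file when one exists.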
6 | 7 | import argparse 8 | import os 9 | import os.path as osp 10 | import numpy as np 11 | import tqdm 12 | import torch 13 | import sys 14 | 15 | import faiss 16 | import torch.nn.functional as F 17 | 18 | from wav2vec_cluster_faiss import parse_faiss_specs, Wav2VecFeatureReader 19 | 20 | 21 | def get_parser(): 22 | parser = argparse.ArgumentParser(description="apply clusters") 23 | # fmt: off 24 | parser.add_argument('data', help='location of tsv files') 25 | parser.add_argument('--split', help='split to process', required=True) 26 | parser.add_argument('--labels', help='split to process', default="phn") 27 | parser.add_argument('--path', help='path to pca and centroids', required=True) 28 | parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True) 29 | parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14) 30 | parser.add_argument('--max-tsz', type=int, help='batch kmeans up to this much', default=14) 31 | # fmt: on 32 | 33 | return parser 34 | 35 | 36 | def get_iterator(args): 37 | label_path = osp.join(args.data, f"{args.split}.{args.labels}") 38 | if osp.exists(label_path): 39 | lp = open(label_path, "r") 40 | else: 41 | lp = None 42 | 43 | with open(osp.join(args.data, f"{args.split}.tsv"), "r") as fp: 44 | lines = fp.read().split("\n") 45 | root = lines.pop(0).strip() 46 | files = [line.rstrip() for line in lines if len(line) > 0] 47 | 48 | if lp is not None: 49 | lbls = [line.rstrip() for line in lp] 50 | else: 51 | lbls = [None] * len(files) 52 | 53 | num = len(files) 54 | reader = Wav2VecFeatureReader(args.checkpoint, args.layer) 55 | 56 | def iterate(): 57 | for fname, lbl in zip(files, lbls): 58 | file = osp.join(root, fname.split("\t")[0]) 59 | feats = reader.get_feats(file) 60 | yield feats.data, fname, lbl 61 | 62 | return iterate, num, root 63 | 64 | 65 | def main(): 66 | parser = get_parser() 67 | args = parser.parse_args() 68 | 69 | spec = osp.basename(args.path) 70 | 71 | try: 72 | faiss_spec = parse_faiss_specs(spec.rstrip("/"))[0] 73 | except: 74 | print(spec) 75 | raise 76 | 77 | print("Faiss Spec:", faiss_spec, file=sys.stderr) 78 | 79 | if faiss_spec.pca: 80 | A = torch.from_numpy(np.load(osp.join(args.path, "pca_A.npy"))).cuda() 81 | b = torch.from_numpy(np.load(osp.join(args.path, "pca_b.npy"))).cuda() 82 | print("Loaded PCA", file=sys.stderr) 83 | 84 | centroids = np.load(osp.join(args.path, "centroids.npy")) 85 | print("Loaded centroids", centroids.shape, file=sys.stderr) 86 | 87 | res = faiss.StandardGpuResources() 88 | index_flat = ( 89 | faiss.IndexFlatL2(centroids.shape[1]) 90 | if not faiss_spec.sphere 91 | else faiss.IndexFlatIP(centroids.shape[1]) 92 | ) 93 | faiss_index = faiss.index_cpu_to_gpu(res, 0, index_flat) 94 | faiss_index.add(centroids) 95 | 96 | generator, num, root = get_iterator(args) 97 | iterator = generator() 98 | 99 | had_labels = False 100 | label_path = osp.join(args.path, f"{args.split}.{args.labels}") 101 | 102 | with torch.no_grad(): 103 | with open(osp.join(args.path, f"{args.split}.src"), "w") as fp, open( 104 | osp.join(args.path, f"{args.split}.tsv"), "w" 105 | ) as pp, open(label_path, "w") as lp: 106 | print(root, file=pp) 107 | for f, fname, lbl in tqdm.tqdm(iterator, total=num): 108 | if faiss_spec.pca: 109 | f = torch.mm(f, A) + b 110 | if faiss_spec.norm: 111 | f = F.normalize(f, p=2, dim=-1) 112 | 113 | f = f.cpu().numpy() 114 | 115 | _, z = faiss_index.search(f, 1) 116 | 117 | print(" ".join(str(x.item()) for x in z), 
file=fp) 118 | print(fname, file=pp) 119 | 120 | if lbl is not None: 121 | print(lbl, file=lp) 122 | had_labels = True 123 | if not had_labels: 124 | os.remove(label_path) 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/wav2vec_cluster_faiss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import gc 9 | import os 10 | import os.path as osp 11 | import random 12 | import numpy as np 13 | import tqdm 14 | import torch 15 | 16 | from collections import namedtuple 17 | 18 | import faiss 19 | 20 | import fairseq 21 | import soundfile as sf 22 | 23 | 24 | def get_parser(): 25 | parser = argparse.ArgumentParser( 26 | description="compute kmeans codebook from kaldi-computed feats" 27 | ) 28 | # fmt: off 29 | parser.add_argument('data', help='location of tsv files') 30 | parser.add_argument('--save-dir', help='where to save the output', required=True) 31 | parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True) 32 | parser.add_argument('--sample-pct', '-r', type=float, help='percentage of timesteps to sample', default=0) 33 | parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14) 34 | parser.add_argument('--faiss-specs', '-f', type=str, 35 | help='faiss index specs; separated by space ' 36 | 'format is: PCAx_NORM_CLUSx_SPHERICAL -> ' 37 | 'PCAx if exists first apply PCA ' 38 | 'NORM if exists, normalize the vector by L2 norm ' 39 | 'CLUSx must exist, cluster to x clusters ' 40 | 'SPEHRICAL if exists, apply spherical kmeans', 41 | default='l2') 42 | # fmt: on 43 | 44 | return parser 45 | 46 | 47 | faiss_spec = namedtuple("faiss_spec", ["pca", "norm", "n_clus", "sphere", "spec_str"]) 48 | 49 | 50 | def parse_faiss_specs(specs_str): 51 | specs = [] 52 | for ss in specs_str.split(): 53 | comps = ss.split("_") 54 | pca = 0 55 | norm = False 56 | n_clus = 0 57 | sphere = False 58 | for c in comps: 59 | if c.startswith("PCA"): 60 | pca = int(c[3:]) 61 | elif c == "NORM": 62 | norm = True 63 | elif c.startswith("CLUS"): 64 | n_clus = int(c[4:]) 65 | elif c == "SPHERICAL": 66 | sphere = True 67 | assert n_clus > 0 68 | specs.append( 69 | faiss_spec(pca=pca, norm=norm, n_clus=n_clus, sphere=sphere, spec_str=ss) 70 | ) 71 | return specs 72 | 73 | 74 | class Wav2VecFeatureReader(object): 75 | def __init__(self, cp_file, layer): 76 | state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(cp_file) 77 | 78 | self.layer = layer 79 | 80 | if "cfg" in state: 81 | w2v_args = state["cfg"] 82 | task = fairseq.tasks.setup_task(w2v_args.task) 83 | model = task.build_model(w2v_args.model) 84 | else: 85 | w2v_args = state["args"] 86 | task = fairseq.tasks.setup_task(w2v_args) 87 | model = task.build_model(w2v_args) 88 | model.load_state_dict(state["model"], strict=True) 89 | model.eval() 90 | model.cuda() 91 | self.model = model 92 | 93 | def read_audio(self, fname): 94 | """Load an audio file and return PCM along with the sample rate""" 95 | wav, sr = sf.read(fname) 96 | assert sr == 16e3 97 | 98 | return wav 99 | 100 | def get_feats(self, loc): 101 | x = self.read_audio(loc) 102 | with torch.no_grad(): 103 | 
source = torch.from_numpy(x).view(1, -1).float().cuda() 104 | res = self.model( 105 | source=source, mask=False, features_only=True, layer=self.layer 106 | ) 107 | return res["layer_results"][self.layer][0].squeeze(1) 108 | 109 | 110 | def get_iterator(args): 111 | with open(args.data, "r") as fp: 112 | lines = fp.read().split("\n") 113 | root = lines.pop(0).strip() 114 | files = [osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0] 115 | 116 | if getattr(args, "sample_pct", 0) > 0: 117 | files = random.sample(files, int(args.sample_pct * len(files))) 118 | num = len(files) 119 | reader = Wav2VecFeatureReader(args.checkpoint, args.layer) 120 | 121 | def iterate(): 122 | for fname in files: 123 | feats = reader.get_feats(fname) 124 | yield feats.cpu().numpy() 125 | 126 | return iterate, num 127 | 128 | 129 | def main(): 130 | parser = get_parser() 131 | args = parser.parse_args() 132 | 133 | faiss_specs = parse_faiss_specs(args.faiss_specs) 134 | print("Faiss Specs:", faiss_specs) 135 | 136 | feat_path = osp.join(args.save_dir, "features") 137 | if osp.exists(feat_path + ".npy"): 138 | feats = np.load(feat_path + ".npy") 139 | else: 140 | generator, num = get_iterator(args) 141 | iterator = generator() 142 | 143 | feats = [] 144 | for f in tqdm.tqdm(iterator, total=num): 145 | feats.append(f) 146 | 147 | del iterator 148 | del generator 149 | 150 | feats = np.concatenate(feats) 151 | 152 | print(feats.shape) 153 | 154 | os.makedirs(args.save_dir, exist_ok=True) 155 | # np.save(feat_path, feats) 156 | 157 | gc.collect() 158 | torch.cuda.empty_cache() 159 | 160 | reload = False 161 | for spec in faiss_specs: 162 | print("Processing spec", spec) 163 | 164 | if reload: 165 | print("Reloading...") 166 | del feats 167 | gc.collect() 168 | feats = np.load(feat_path + ".npy") 169 | 170 | save_path = osp.join(args.save_dir, spec.spec_str) 171 | os.makedirs(save_path, exist_ok=True) 172 | d = feats.shape[-1] 173 | x = feats 174 | if spec.pca > 0: 175 | print("Computing PCA") 176 | pca = faiss.PCAMatrix(d, spec.pca) 177 | pca.train(x) 178 | d = spec.pca 179 | b = faiss.vector_to_array(pca.b) 180 | A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in) 181 | np.save(osp.join(save_path, "pca_A"), A.T) 182 | np.save(osp.join(save_path, "pca_b"), b) 183 | print("Applying PCA") 184 | x = pca.apply_py(x) 185 | 186 | if spec.norm: 187 | reload = spec.pca <= 0 188 | print("Normalizing") 189 | faiss.normalize_L2(x) 190 | 191 | print("Computing kmeans") 192 | kmeans = faiss.Kmeans( 193 | d, 194 | spec.n_clus, 195 | niter=50, 196 | verbose=True, 197 | spherical=spec.sphere, 198 | max_points_per_centroid=feats.shape[0], 199 | gpu=True, 200 | nredo=3, 201 | ) 202 | kmeans.train(x) 203 | np.save(osp.join(save_path, "centroids"), kmeans.centroids) 204 | del kmeans 205 | del x 206 | gc.collect() 207 | 208 | 209 | if __name__ == "__main__": 210 | main() 211 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/wav2vec_extract_features.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
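# Dumps frame-level representations from the transformer layer selected by --layer of a
# wav2vec 2.0 checkpoint for every utterance in DATA/SPLIT.tsv, appending them to
# SAVE_DIR/SPLIT.npy (with a matching .lengths file) and copying the .tsv/.wrd/.phn files
# alongside. prepare_audio.sh runs this step before clustering and PCA.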
6 | 7 | import argparse 8 | import os 9 | import os.path as osp 10 | import tqdm 11 | import torch 12 | import torch.nn.functional as F 13 | from shutil import copyfile 14 | 15 | from npy_append_array import NpyAppendArray 16 | 17 | import fairseq 18 | import soundfile as sf 19 | 20 | 21 | def get_parser(): 22 | parser = argparse.ArgumentParser( 23 | description="compute kmeans codebook from kaldi-computed feats" 24 | ) 25 | # fmt: off 26 | parser.add_argument('data', help='location of tsv files') 27 | parser.add_argument('--split', help='which split to read', required=True) 28 | parser.add_argument('--save-dir', help='where to save the output', required=True) 29 | parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec ctc model', required=True) 30 | parser.add_argument('--layer', type=int, default=14, help='which layer to use') 31 | # fmt: on 32 | 33 | return parser 34 | 35 | 36 | class Wav2VecFeatureReader(object): 37 | def __init__(self, cp_file, layer): 38 | model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( 39 | [cp_file] 40 | ) 41 | model = model[0] 42 | model.eval() 43 | model.cuda() 44 | self.model = model 45 | self.task = task 46 | self.layer = layer 47 | 48 | def read_audio(self, fname): 49 | """Load an audio file and return PCM along with the sample rate""" 50 | wav, sr = sf.read(fname) 51 | assert sr == 16e3 52 | 53 | return wav 54 | 55 | def get_feats(self, loc): 56 | x = self.read_audio(loc) 57 | with torch.no_grad(): 58 | source = torch.from_numpy(x).float().cuda() 59 | if self.task.cfg.normalize: 60 | assert source.dim() == 1, source.dim() 61 | with torch.no_grad(): 62 | source = F.layer_norm(source, source.shape) 63 | source = source.view(1, -1) 64 | 65 | m_res = self.model(source=source, mask=False, features_only=True, layer=self.layer) 66 | return m_res["x"].squeeze(0).cpu() 67 | 68 | 69 | def get_iterator(args): 70 | with open(osp.join(args.data, args.split) + ".tsv", "r") as fp: 71 | lines = fp.read().split("\n") 72 | root = lines.pop(0).strip() 73 | files = [osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0] 74 | 75 | num = len(files) 76 | reader = Wav2VecFeatureReader(args.checkpoint, args.layer) 77 | 78 | def iterate(): 79 | for fname in files: 80 | w2v_feats = reader.get_feats(fname) 81 | yield w2v_feats 82 | 83 | return iterate, num 84 | 85 | 86 | def main(): 87 | parser = get_parser() 88 | args = parser.parse_args() 89 | 90 | os.makedirs(args.save_dir, exist_ok=True) 91 | 92 | def create_files(dest): 93 | copyfile(osp.join(args.data, args.split) + ".tsv", dest + ".tsv") 94 | if osp.exists(osp.join(args.data, args.split) + ".wrd"): 95 | copyfile(osp.join(args.data, args.split) + ".wrd", dest + ".wrd") 96 | if osp.exists(osp.join(args.data, args.split) + ".phn"): 97 | copyfile(osp.join(args.data, args.split) + ".phn", dest + ".phn") 98 | 99 | if osp.exists(dest + ".npy"): 100 | os.remove(dest + ".npy") 101 | npaa = NpyAppendArray(dest + ".npy") 102 | return npaa 103 | 104 | save_path = osp.join(args.save_dir, args.split) 105 | npaa = create_files(save_path) 106 | 107 | generator, num = get_iterator(args) 108 | iterator = generator() 109 | 110 | with open(save_path + ".lengths", "w") as l_f: 111 | for w2v_feats in tqdm.tqdm(iterator, total=num): 112 | print(len(w2v_feats), file=l_f) 113 | 114 | if len(w2v_feats) > 0: 115 | npaa.append(w2v_feats.numpy()) 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- 
/wav2vec 2.0/unsupervised/scripts/wer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Implement unsupervised metric for decoding hyperparameter selection: 9 | $$ alpha * LM_PPL + ViterbitUER(%) * 100 $$ 10 | """ 11 | import argparse 12 | import logging 13 | import sys 14 | 15 | import editdistance 16 | 17 | logging.root.setLevel(logging.INFO) 18 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def get_parser(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("-s", "--hypo", help="hypo transcription", required=True) 25 | parser.add_argument( 26 | "-r", "--reference", help="reference transcription", required=True 27 | ) 28 | return parser 29 | 30 | 31 | def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p): 32 | d_cnt = 0 33 | w_cnt = 0 34 | w_cnt_h = 0 35 | for uid in hyp_uid_to_tra: 36 | ref = ref_uid_to_tra[uid].split() 37 | if g2p is not None: 38 | hyp = g2p(hyp_uid_to_tra[uid]) 39 | hyp = [p for p in hyp if p != "'" and p != " "] 40 | hyp = [p[:-1] if p[-1].isnumeric() else p for p in hyp] 41 | else: 42 | hyp = hyp_uid_to_tra[uid].split() 43 | d_cnt += editdistance.eval(ref, hyp) 44 | w_cnt += len(ref) 45 | w_cnt_h += len(hyp) 46 | wer = float(d_cnt) / w_cnt 47 | logger.debug( 48 | ( 49 | f"wer = {wer * 100:.2f}%; num. of ref words = {w_cnt}; " 50 | f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(ref_uid_to_tra)}" 51 | ) 52 | ) 53 | return wer 54 | 55 | 56 | def main(): 57 | args = get_parser().parse_args() 58 | 59 | errs = 0 60 | count = 0 61 | with open(args.hypo, "r") as hf, open(args.reference, "r") as rf: 62 | for h, r in zip(hf, rf): 63 | h = h.rstrip().split() 64 | r = r.rstrip().split() 65 | errs += editdistance.eval(r, h) 66 | count += len(r) 67 | 68 | logger.info(f"UER: {errs / count * 100:.2f}%") 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | 74 | 75 | def load_tra(tra_path): 76 | with open(tra_path, "r") as f: 77 | uid_to_tra = {} 78 | for line in f: 79 | uid, tra = line.split(None, 1) 80 | uid_to_tra[uid] = tra 81 | logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}") 82 | return uid_to_tra 83 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/scripts/wrd_to_ltr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | 9 | 10 | def main(): 11 | for line in sys.stdin: 12 | print(" ".join(list(line.strip().replace(" ", "|"))) + " |") 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /wav2vec 2.0/unsupervised/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
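# Re-exports the UnpairedAudioText task at the package level.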
5 | 6 | from .unpaired_audio_text import UnpairedAudioText 7 | 8 | 9 | __all__ = [ 10 | "UnpairedAudioText", 11 | ] 12 | -------------------------------------------------------------------------------- /wav2vec 2.0/vq-wav2vec_featurize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset 9 | """ 10 | 11 | import argparse 12 | import glob 13 | import os 14 | import os.path as osp 15 | import pprint 16 | 17 | import soundfile as sf 18 | import torch 19 | import fairseq 20 | from torch import nn 21 | from torch.utils.data import DataLoader 22 | 23 | 24 | try: 25 | import tqdm 26 | except: 27 | print("Install tqdm to use --log-format=tqdm") 28 | 29 | 30 | class FilesDataset: 31 | def __init__(self, files, labels): 32 | self.files = files 33 | if labels and osp.exists(labels): 34 | with open(labels, "r") as lbl_f: 35 | self.labels = [line.rstrip() for line in lbl_f] 36 | else: 37 | self.labels = labels 38 | 39 | def __len__(self): 40 | return len(self.files) 41 | 42 | def __getitem__(self, index): 43 | fname = self.files[index] 44 | 45 | wav, sr = sf.read(fname) 46 | assert sr == 16000 47 | 48 | wav = torch.from_numpy(wav).float() 49 | lbls = None 50 | if self.labels: 51 | if isinstance(self.labels, str): 52 | lbl_file = osp.splitext(fname)[0] + "." + self.labels 53 | with open(lbl_file, "r") as lblf: 54 | lbls = lblf.readline() 55 | assert lbls is not None 56 | else: 57 | lbls = self.labels[index] 58 | return wav, lbls 59 | 60 | def collate(self, batch): 61 | return batch 62 | 63 | 64 | class ArgTypes: 65 | @staticmethod 66 | def existing_path(arg): 67 | arg = str(arg) 68 | assert osp.exists(arg), f"File {arg} does not exist" 69 | return arg 70 | 71 | @staticmethod 72 | def mkdir(arg): 73 | arg = str(arg) 74 | os.makedirs(arg, exist_ok=True) 75 | return arg 76 | 77 | 78 | class DatasetWriter: 79 | def __init__(self): 80 | 81 | self.args = self.load_config() 82 | pprint.pprint(self.args.__dict__) 83 | 84 | self.model = self.load_model() 85 | 86 | def __getattr__(self, attr): 87 | return getattr(self.args, attr) 88 | 89 | def read_manifest(self, fname): 90 | 91 | with open(fname, "r") as fp: 92 | lines = fp.read().split("\n") 93 | root = lines.pop(0).strip() 94 | fnames = [ 95 | osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0 96 | ] 97 | 98 | return fnames 99 | 100 | def process_splits(self): 101 | 102 | if self.args.shard is not None or self.args.num_shards is not None: 103 | assert self.args.shard is not None and self.args.num_shards is not None 104 | 105 | for split in self.splits: 106 | print(split) 107 | 108 | if self.extension == "tsv": 109 | datadir = osp.join(self.data_dir, f"{split}.{self.extension}") 110 | print("Reading manifest file: ", datadir) 111 | files = self.read_manifest(datadir) 112 | else: 113 | datadir = osp.join(self.data_dir, split, f"**/*.{self.extension}") 114 | files = glob.glob(datadir, recursive=True) 115 | 116 | assert len(files) > 0 117 | 118 | if self.args.shard is not None: 119 | files = files[self.args.shard :: self.args.num_shards] 120 | 121 | lbls = [] 122 | with open(self.data_file(split), "w") as srcf: 123 | for line, lbl in self.iterate(files): 124 | print(line, file=srcf) 
125 | if self.args.labels: 126 | lbls.append(lbl + "\n") 127 | 128 | if self.args.labels: 129 | assert all(a is not None for a in lbls) 130 | with open(self.lbl_file(split), "w") as lblf: 131 | lblf.writelines(lbls) 132 | 133 | def iterate(self, files): 134 | 135 | data = self.load_data(files) 136 | for samples in tqdm.tqdm(data, total=len(files) // 32): 137 | 138 | for wav, lbl in samples: 139 | x = wav.unsqueeze(0).float().cuda() 140 | 141 | div = 1 142 | while x.size(-1) // div > self.args.max_size: 143 | div += 1 144 | 145 | xs = x.chunk(div, dim=-1) 146 | 147 | result = [] 148 | for x in xs: 149 | torch.cuda.empty_cache() 150 | x = self.model.feature_extractor(x) 151 | if self.quantize_location == "encoder": 152 | with torch.no_grad(): 153 | _, idx = self.model.vector_quantizer.forward_idx(x) 154 | idx = idx.squeeze(0).cpu() 155 | else: 156 | with torch.no_grad(): 157 | z = self.model.feature_aggregator(x) 158 | _, idx = self.model.vector_quantizer.forward_idx(z) 159 | idx = idx.squeeze(0).cpu() 160 | result.append(idx) 161 | 162 | idx = torch.cat(result, dim=0) 163 | yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl 164 | 165 | def lbl_file(self, name): 166 | shard_part = "" if self.args.shard is None else f".{self.args.shard}" 167 | return osp.join(self.output_dir, f"{name}.lbl{shard_part}") 168 | 169 | def data_file(self, name): 170 | shard_part = "" if self.args.shard is None else f".{self.args.shard}" 171 | return osp.join(self.output_dir, f"{name}.src{shard_part}") 172 | 173 | def var_file(self): 174 | return osp.join(self.output_dir, f"vars.pt") 175 | 176 | def load_config(self): 177 | 178 | parser = argparse.ArgumentParser("Vector Quantized wav2vec features") 179 | 180 | # Model Arguments 181 | parser.add_argument("--checkpoint", type=ArgTypes.existing_path, required=True) 182 | parser.add_argument("--data-parallel", action="store_true") 183 | 184 | # Output Arguments 185 | parser.add_argument("--output-dir", type=ArgTypes.mkdir, required=True) 186 | 187 | # Data Arguments 188 | parser.add_argument("--data-dir", type=ArgTypes.existing_path, required=True) 189 | parser.add_argument("--splits", type=str, nargs="+", required=True) 190 | parser.add_argument("--extension", type=str, required=True) 191 | parser.add_argument("--labels", type=str, required=False) 192 | 193 | parser.add_argument("--shard", type=int, default=None) 194 | parser.add_argument("--num-shards", type=int, default=None) 195 | parser.add_argument("--max-size", type=int, default=1300000) 196 | 197 | # Logger Arguments 198 | parser.add_argument( 199 | "--log-format", type=str, choices=["none", "simple", "tqdm"] 200 | ) 201 | 202 | return parser.parse_args() 203 | 204 | def load_data(self, fnames): 205 | 206 | dataset = FilesDataset(fnames, self.args.labels) 207 | loader = DataLoader( 208 | dataset, batch_size=32, collate_fn=dataset.collate, num_workers=8 209 | ) 210 | return loader 211 | 212 | def load_model(self): 213 | model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([self.checkpoint]) 214 | model = model[0] 215 | 216 | self.quantize_location = getattr(cfg.model, "vq", "encoder") 217 | 218 | model.eval().float() 219 | model.cuda() 220 | 221 | if self.data_parallel: 222 | model = nn.DataParallel(model) 223 | 224 | return model 225 | 226 | def __call__(self): 227 | 228 | self.process_splits() 229 | 230 | if hasattr(self.model.feature_extractor, "vars") and ( 231 | self.args.shard is None or self.args.shard == 0 232 | ): 233 | vars = ( 234 | 
self.model.feature_extractor.vars.view( 235 | self.model.feature_extractor.banks, 236 | self.model.feature_extractor.num_vars, 237 | -1, 238 | ) 239 | .cpu() 240 | .detach() 241 | ) 242 | print("writing learned latent variable embeddings: ", vars.shape) 243 | torch.save(vars, self.var_file()) 244 | 245 | 246 | if __name__ == "__main__": 247 | write_data = DatasetWriter() 248 | 249 | write_data() 250 | print("Done.") 251 | -------------------------------------------------------------------------------- /wav2vec 2.0/wav2vec_featurize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset 9 | """ 10 | 11 | import argparse 12 | import glob 13 | import os 14 | from shutil import copy 15 | 16 | import h5py 17 | import numpy as np 18 | import soundfile as sf 19 | import torch 20 | import tqdm 21 | import fairseq 22 | from torch import nn 23 | 24 | 25 | def read_audio(fname): 26 | """ Load an audio file and return PCM along with the sample rate """ 27 | 28 | wav, sr = sf.read(fname) 29 | assert sr == 16e3 30 | 31 | return wav, 16e3 32 | 33 | 34 | class PretrainedWav2VecModel(nn.Module): 35 | def __init__(self, fname): 36 | super().__init__() 37 | 38 | model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fname]) 39 | model = model[0] 40 | model.eval() 41 | 42 | self.model = model 43 | 44 | def forward(self, x): 45 | with torch.no_grad(): 46 | z = self.model.feature_extractor(x) 47 | if isinstance(z, tuple): 48 | z = z[0] 49 | c = self.model.feature_aggregator(z) 50 | return z, c 51 | 52 | 53 | class EmbeddingWriterConfig(argparse.ArgumentParser): 54 | def __init__(self): 55 | super().__init__("Pre-compute embeddings for flashlight datasets") 56 | 57 | kwargs = {"action": "store", "type": str, "required": True} 58 | 59 | self.add_argument("--input", "-i", help="Input Directory", **kwargs) 60 | self.add_argument("--output", "-o", help="Output Directory", **kwargs) 61 | self.add_argument("--model", help="Path to model checkpoint", **kwargs) 62 | self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs) 63 | self.add_argument( 64 | "--ext", default="wav", required=False, help="Audio file extension" 65 | ) 66 | 67 | self.add_argument( 68 | "--no-copy-labels", 69 | action="store_true", 70 | help="Do not copy label files. 
Useful for large datasets, use --targetdir in flashlight then.", 71 | ) 72 | self.add_argument( 73 | "--use-feat", 74 | action="store_true", 75 | help="Use the feature vector ('z') instead of context vector ('c') for features", 76 | ) 77 | self.add_argument("--gpu", help="GPU to use", default=0, type=int) 78 | 79 | 80 | class Prediction: 81 | """ Lightweight wrapper around a fairspeech embedding model """ 82 | 83 | def __init__(self, fname, gpu=0): 84 | self.gpu = gpu 85 | self.model = PretrainedWav2VecModel(fname).cuda(gpu) 86 | 87 | def __call__(self, x): 88 | x = torch.from_numpy(x).float().cuda(self.gpu) 89 | with torch.no_grad(): 90 | z, c = self.model(x.unsqueeze(0)) 91 | 92 | return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy() 93 | 94 | 95 | class H5Writer: 96 | """ Write features as hdf5 file in flashlight compatible format """ 97 | 98 | def __init__(self, fname): 99 | self.fname = fname 100 | os.makedirs(os.path.dirname(self.fname), exist_ok=True) 101 | 102 | def write(self, data): 103 | channel, T = data.shape 104 | 105 | with h5py.File(self.fname, "w") as out_ds: 106 | data = data.T.flatten() 107 | out_ds["features"] = data 108 | out_ds["info"] = np.array([16e3 // 160, T, channel]) 109 | 110 | 111 | class EmbeddingDatasetWriter(object): 112 | """Given a model and a flashlight dataset, pre-compute and store embeddings 113 | 114 | Args: 115 | input_root, str : 116 | Path to the flashlight dataset 117 | output_root, str : 118 | Desired output directory. Will be created if non-existent 119 | split, str : 120 | Dataset split 121 | """ 122 | 123 | def __init__( 124 | self, 125 | input_root, 126 | output_root, 127 | split, 128 | model_fname, 129 | extension="wav", 130 | gpu=0, 131 | verbose=False, 132 | use_feat=False, 133 | ): 134 | 135 | assert os.path.exists(model_fname) 136 | 137 | self.model_fname = model_fname 138 | self.model = Prediction(self.model_fname, gpu) 139 | 140 | self.input_root = input_root 141 | self.output_root = output_root 142 | self.split = split 143 | self.verbose = verbose 144 | self.extension = extension 145 | self.use_feat = use_feat 146 | 147 | assert os.path.exists(self.input_path), "Input path '{}' does not exist".format( 148 | self.input_path 149 | ) 150 | 151 | def _progress(self, iterable, **kwargs): 152 | if self.verbose: 153 | return tqdm.tqdm(iterable, **kwargs) 154 | return iterable 155 | 156 | def require_output_path(self, fname=None): 157 | path = self.get_output_path(fname) 158 | os.makedirs(path, exist_ok=True) 159 | 160 | @property 161 | def input_path(self): 162 | return self.get_input_path() 163 | 164 | @property 165 | def output_path(self): 166 | return self.get_output_path() 167 | 168 | def get_input_path(self, fname=None): 169 | if fname is None: 170 | return os.path.join(self.input_root, self.split) 171 | return os.path.join(self.get_input_path(), fname) 172 | 173 | def get_output_path(self, fname=None): 174 | if fname is None: 175 | return os.path.join(self.output_root, self.split) 176 | return os.path.join(self.get_output_path(), fname) 177 | 178 | def copy_labels(self): 179 | self.require_output_path() 180 | 181 | labels = list( 182 | filter( 183 | lambda x: self.extension not in x, glob.glob(self.get_input_path("*")) 184 | ) 185 | ) 186 | for fname in tqdm.tqdm(labels): 187 | copy(fname, self.output_path) 188 | 189 | @property 190 | def input_fnames(self): 191 | return sorted(glob.glob(self.get_input_path("*.{}".format(self.extension)))) 192 | 193 | def __len__(self): 194 | return len(self.input_fnames) 195 | 196 | def 
write_features(self): 197 | 198 | paths = self.input_fnames 199 | 200 | fnames_context = map( 201 | lambda x: os.path.join( 202 | self.output_path, x.replace("." + self.extension, ".h5context") 203 | ), 204 | map(os.path.basename, paths), 205 | ) 206 | 207 | for name, target_fname in self._progress( 208 | zip(paths, fnames_context), total=len(self) 209 | ): 210 | wav, sr = read_audio(name) 211 | z, c = self.model(wav) 212 | feat = z if self.use_feat else c 213 | writer = H5Writer(target_fname) 214 | writer.write(feat) 215 | 216 | def __repr__(self): 217 | 218 | return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format( 219 | n_files=len(self), **self.__dict__ 220 | ) 221 | 222 | 223 | if __name__ == "__main__": 224 | 225 | args = EmbeddingWriterConfig().parse_args() 226 | 227 | for split in args.split: 228 | 229 | writer = EmbeddingDatasetWriter( 230 | input_root=args.input, 231 | output_root=args.output, 232 | split=split, 233 | model_fname=args.model, 234 | gpu=args.gpu, 235 | extension=args.ext, 236 | use_feat=args.use_feat, 237 | ) 238 | 239 | print(writer) 240 | writer.require_output_path() 241 | 242 | print("Writing Features...") 243 | writer.write_features() 244 | print("Done.") 245 | 246 | if not args.no_copy_labels: 247 | print("Copying label data...") 248 | writer.copy_labels() 249 | print("Done.") 250 | -------------------------------------------------------------------------------- /wav2vec 2.0/wav2vec_manifest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Data pre-processing: build vocabularies and binarize training data. 8 | """ 9 | 10 | import argparse 11 | import glob 12 | import os 13 | import random 14 | 15 | import soundfile 16 | 17 | 18 | def get_parser(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "root", metavar="DIR", help="root directory containing flac files to index" 22 | ) 23 | parser.add_argument( 24 | "--valid-percent", 25 | default=0.01, 26 | type=float, 27 | metavar="D", 28 | help="percentage of data to use as validation set (between 0 and 1)", 29 | ) 30 | parser.add_argument( 31 | "--dest", default=".", type=str, metavar="DIR", help="output directory" 32 | ) 33 | parser.add_argument( 34 | "--ext", default="flac", type=str, metavar="EXT", help="extension to look for" 35 | ) 36 | parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed") 37 | parser.add_argument( 38 | "--path-must-contain", 39 | default=None, 40 | type=str, 41 | metavar="FRAG", 42 | help="if set, path must contain this substring for a file to be included in the manifest", 43 | ) 44 | return parser 45 | 46 | 47 | def main(args): 48 | assert args.valid_percent >= 0 and args.valid_percent <= 1.0 49 | 50 | if not os.path.exists(args.dest): 51 | os.makedirs(args.dest) 52 | 53 | dir_path = os.path.realpath(args.root) 54 | search_path = os.path.join(dir_path, "**/*." 
+ args.ext) 55 | rand = random.Random(args.seed) 56 | 57 | valid_f = ( 58 | open(os.path.join(args.dest, "valid.tsv"), "w") 59 | if args.valid_percent > 0 60 | else None 61 | ) 62 | 63 | with open(os.path.join(args.dest, "train.tsv"), "w") as train_f: 64 | print(dir_path, file=train_f) 65 | 66 | if valid_f is not None: 67 | print(dir_path, file=valid_f) 68 | 69 | for fname in glob.iglob(search_path, recursive=True): 70 | file_path = os.path.realpath(fname) 71 | 72 | if args.path_must_contain and args.path_must_contain not in file_path: 73 | continue 74 | 75 | frames = soundfile.info(fname).frames 76 | dest = train_f if rand.random() > args.valid_percent else valid_f 77 | print( 78 | "{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest 79 | ) 80 | if valid_f is not None: 81 | valid_f.close() 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = get_parser() 86 | args = parser.parse_args() 87 | main(args) 88 | --------------------------------------------------------------------------------