├── Distil_Whisper.pdf ├── LICENSE ├── README.md └── training ├── Makefile ├── README.md ├── create_student_model.py ├── flax ├── LICENSE ├── Makefile ├── README.md ├── conversion_scripts │ └── run_convert_distilled_train_state_to_hf.sh ├── convert_train_state_to_hf.py ├── create_student_model.py ├── distil_whisper │ ├── __init__.py │ ├── layers.py │ ├── modeling_flax_whisper.py │ ├── partitioner.py │ ├── pipeline.py │ └── train_state.py ├── distillation_scripts │ ├── run_32_2_pt.sh │ ├── run_bs_sweep.yaml │ ├── run_dataset_sweep.yaml │ ├── run_decoder_sweep.yaml │ ├── run_distillation_12_2_timestamped.sh │ ├── run_distillation_15s_context.sh │ ├── run_distillation_16_2.sh │ ├── run_distillation_24_2.sh │ ├── run_distillation_24_2_timestamped.sh │ ├── run_distillation_32_2.sh │ ├── run_distillation_32_2_by_samples.sh │ ├── run_distillation_32_2_gpu.sh │ ├── run_distillation_32_2_timestamped.sh │ ├── run_distillation_large_32_2_gpu_timestamped.sh │ ├── run_distillation_objective.yaml │ ├── run_dropout_sweep.yaml │ ├── run_librispeech.sh │ ├── run_librispeech_dummy_pt.sh │ ├── run_librispeech_streaming_dummy.sh │ ├── run_lr_sweep.yaml │ ├── run_mse_sweep.yaml │ ├── run_timestamp_sweep.yaml │ └── run_wer_sweep.yaml ├── evaluation_scripts │ ├── run_baselines.sh │ ├── run_distilled.sh │ ├── run_distilled_16_2.sh │ ├── run_librispeech_eval_dummy.sh │ └── test │ │ ├── run_baselines.sh │ │ ├── run_baselines_pt.sh │ │ └── run_distilled.sh ├── finetuning_scripts │ ├── run_librispeech.sh │ ├── run_librispeech_dummy.sh │ ├── run_librispeech_eval.sh │ ├── run_librispeech_eval_dummy.sh │ └── run_librispeech_sweep.yaml ├── initialisation_scripts │ ├── run_large_32_2_init.sh │ ├── run_medium_24_2_init.sh │ ├── run_small_12_2_init.sh │ ├── run_tiny_2_1_init.sh │ └── run_tiny_2_1_init_pt.sh ├── latency_scripts │ ├── run_speculative.sh │ ├── run_speed.sh │ ├── run_speed_longform.sh │ └── run_trial.sh ├── long_form_transcription_scripts │ ├── run_chunk_length_s_sweep.yaml │ ├── 
run_eval_with_pipeline.sh │ ├── run_length_penalty_sweep.yaml │ ├── run_tedlium_long_form.sh │ ├── run_tedlium_long_form_dummy.sh │ ├── run_tedlium_long_form_timestamps.sh │ ├── run_top_k_temperature_sweep.yaml │ └── test │ │ ├── run_baselines.sh │ │ ├── run_baselines_pt.sh │ │ └── run_distilled.sh ├── noise_evaluation_scripts │ ├── run_baselines.sh │ ├── run_baselines_pt.sh │ ├── run_distilled.sh │ └── test │ │ ├── run_baselines.sh │ │ └── run_distilled.sh ├── pseudo_labelling_scripts │ ├── run_librispeech_pseudo_labelling.sh │ ├── run_librispeech_pseudo_labelling_dummy.sh │ ├── run_pseudo_labelling.sh │ ├── run_pseudo_labelling_2.sh │ ├── run_pseudo_labelling_dummy_pt.sh │ ├── run_pseudo_labelling_token_ids.sh │ └── run_pseudo_labelling_token_ids_2.sh ├── pyproject.toml ├── requirements.txt ├── run_distillation.py ├── run_eval.py ├── run_finetuning.py ├── run_long_form_transcription.py ├── run_orig_longform.sh ├── run_pseudo_labelling_pt.py ├── run_pt_long_form_transcription.py ├── run_speculative_decoding.py ├── run_speed.sh ├── run_speed_pt.py ├── setup.py └── tpu_connect.sh ├── pyproject.toml ├── run_distillation.py ├── run_eval.py ├── run_pseudo_labelling.py └── setup.py /Distil_Whisper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/distil-whisper/cc96130f6e4cc74cab4545f3c6e7e5c204ced871/Distil_Whisper.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 2023 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /training/Makefile: -------------------------------------------------------------------------------- 1 | check_dirs := . 2 | 3 | quality: 4 | black --check $(check_dirs) 5 | ruff $(check_dirs) 6 | 7 | style: 8 | black $(check_dirs) 9 | ruff $(check_dirs) --fix 10 | -------------------------------------------------------------------------------- /training/create_student_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Initialise a student Whisper model from a pre-trained teacher model for 18 | teacher-student distillation. 19 | """ 20 | 21 | import argparse 22 | import copy 23 | import logging 24 | 25 | import numpy as np 26 | import torch 27 | from transformers import GenerationConfig, WhisperForConditionalGeneration, WhisperProcessor 28 | 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser( 35 | description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary." 36 | ) 37 | parser.add_argument( 38 | "--teacher_checkpoint", 39 | type=str, 40 | required=True, 41 | help="The HF Hub ID of the teacher checkpoint.", 42 | ) 43 | parser.add_argument( 44 | "--subfolder", 45 | type=str, 46 | default="", 47 | help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you " 48 | "can specify the folder name here.", 49 | ) 50 | parser.add_argument( 51 | "--encoder_layers", 52 | type=int, 53 | default=None, 54 | help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.", 55 | ) 56 | parser.add_argument( 57 | "--decoder_layers", 58 | type=int, 59 | default=2, 60 | help="Number of decoder layers to use in the student model. 
Defaults to 2 layers.", 61 | ) 62 | parser.add_argument( 63 | "--decoder_layers_numbers", 64 | type=int, 65 | nargs="*", 66 | help="Layers numbers of the decoder teacher to use in the student model. Defaults to None, equivalent to taking first and last layer (and equivalent to `--decoder_layers_numbers 0 -1`).", 67 | ) 68 | parser.add_argument( 69 | "--save_dir", 70 | type=str, 71 | required=True, 72 | help="Where to save the student weights and processor.", 73 | ) 74 | parser.add_argument( 75 | "--push_to_hub", 76 | type=bool, 77 | required=False, 78 | default=False, 79 | help="Whether to push the student weights and processor to the Hub.", 80 | ) 81 | parser.add_argument( 82 | "--cache_dir", 83 | type=str, 84 | default=None, 85 | help="Where to store the pretrained models downloaded from huggingface.co", 86 | ) 87 | 88 | args = parser.parse_args() 89 | return args 90 | 91 | 92 | def init_student_model_from_teacher( 93 | teacher_checkpoint, 94 | encoder_layers=None, 95 | decoder_layers=2, 96 | decoder_layers_numbers=None, 97 | save_dir=None, 98 | push_to_hub=None, 99 | cache_dir=None, 100 | subfolder="", 101 | ): 102 | if decoder_layers_numbers is not None and len(decoder_layers_numbers) != decoder_layers: 103 | raise ValueError( 104 | f"Got {len(decoder_layers_numbers)} layers number for {decoder_layers} decoder layers." 
105 | ) 106 | 107 | teacher_model = WhisperForConditionalGeneration.from_pretrained( 108 | teacher_checkpoint, 109 | cache_dir=cache_dir, 110 | subfolder=subfolder, 111 | low_cpu_mem_usage=True, 112 | ) 113 | processor = WhisperProcessor.from_pretrained(teacher_checkpoint) 114 | generation_config = GenerationConfig.from_pretrained(teacher_checkpoint) 115 | generation_config.forced_decoder_ids = None 116 | 117 | teacher_config = teacher_model.config 118 | teacher_encoder_layers = teacher_config.encoder_layers 119 | teacher_decoder_layers = teacher_config.decoder_layers 120 | 121 | student_config = copy.deepcopy(teacher_config) 122 | student_config.update( 123 | { 124 | "encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers, 125 | "decoder_layers": decoder_layers, 126 | } 127 | ) 128 | 129 | encoder_mapping = np.linspace(0, teacher_encoder_layers - 1, student_config.encoder_layers, dtype=int) 130 | encoder_mapping[-1] = teacher_encoder_layers - 1 131 | 132 | encoder_map = {} 133 | for student_layer, teacher_layer in enumerate(encoder_mapping): 134 | encoder_map[teacher_layer] = student_layer 135 | 136 | if decoder_layers_numbers is None: 137 | decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_config.decoder_layers, dtype=int) 138 | decoder_mapping[-1] = teacher_decoder_layers - 1 139 | else: 140 | decoder_mapping = decoder_layers_numbers 141 | 142 | decoder_map = {} 143 | for student_layer, teacher_layer in enumerate(decoder_mapping): 144 | decoder_map[teacher_layer] = student_layer 145 | 146 | # init the student params from the teacher model 147 | student_model = WhisperForConditionalGeneration(student_config) 148 | missing_keys, unexpected_keys = student_model.load_state_dict(teacher_model.state_dict(), strict=False) 149 | if len(missing_keys) > 0: 150 | raise RuntimeError( 151 | "Error(s) in loading state_dict for WhisperForConditionalGeneration. 
\n" 152 | f"Missing key(s) in state_dict: {missing_keys}" 153 | ) 154 | if decoder_layers == teacher_decoder_layers: 155 | decoder_keys = [key for key in unexpected_keys if "model.decoder.layers" in key] 156 | if len(decoder_keys) > 0: 157 | raise RuntimeError( 158 | "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n" 159 | f"Unexpected key(s) in state_dict: {decoder_keys}" 160 | ) 161 | if encoder_layers == teacher_encoder_layers: 162 | encoder_keys = [key for key in unexpected_keys if "model.encoder.layers" in key] 163 | if len(encoder_keys) > 0: 164 | raise RuntimeError( 165 | "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n" 166 | f"Unexpected key(s) in state_dict: {encoder_keys}" 167 | ) 168 | 169 | for layer in range(teacher_decoder_layers): 170 | if layer in decoder_map: 171 | # re-introduce pre-defined layers from the teacher 172 | student_model.model.decoder.layers[decoder_map[layer]].load_state_dict( 173 | teacher_model.model.decoder.layers[layer].state_dict() 174 | ) 175 | 176 | if encoder_layers is not None: 177 | for layer in range(teacher_encoder_layers): 178 | if layer in encoder_map: 179 | # re-introduce pre-defined layers from the teacher 180 | student_model.model.encoder.layers[encoder_map[layer]].load_state_dict( 181 | teacher_model.model.encoder.layers[layer].state_dict() 182 | ) 183 | 184 | # remove the teacher params and model 185 | del teacher_model 186 | 187 | # save the converted weights and model 188 | if save_dir is not None: 189 | student_model.save_pretrained(save_dir) 190 | # we also need to correctly save the processor and generation config 191 | processor.save_pretrained(save_dir) 192 | generation_config.save_pretrained(save_dir) 193 | 194 | # check we can do a forward pass with the saved model - first load the weights and processor 195 | logger.info("Checking we can load the saved model...") 196 | student_model = WhisperForConditionalGeneration.from_pretrained( 197 | save_dir, 198 | 
low_cpu_mem_usage=True, 199 | ) 200 | processor = WhisperProcessor.from_pretrained(save_dir) 201 | 202 | # define some random inputs 203 | input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="pt").input_features 204 | decoder_start_token_id = student_model.config.decoder_start_token_id 205 | decoder_input_ids = torch.ones((input_features.shape[0], 1), dtype=torch.long) * decoder_start_token_id 206 | 207 | # do a forward pass - outputs will be gibberish for the initialised model so we can't check them 208 | # but we make can sure the model runs as expected 209 | logger.info("Checking we can run the converted model forward...") 210 | _ = student_model(input_features, decoder_input_ids=decoder_input_ids).logits 211 | logger.info("Conversion successful!") 212 | 213 | if push_to_hub: 214 | student_model.push_to_hub(save_dir) 215 | processor.push_to_hub(save_dir) 216 | generation_config.push_to_hub(save_dir) 217 | 218 | 219 | if __name__ == "__main__": 220 | args = parse_args() 221 | 222 | init_student_model_from_teacher( 223 | teacher_checkpoint=args.teacher_checkpoint, 224 | encoder_layers=args.encoder_layers, 225 | decoder_layers=args.decoder_layers, 226 | decoder_layers_numbers=args.decoder_layers_numbers, 227 | save_dir=args.save_dir, 228 | push_to_hub=args.push_to_hub, 229 | cache_dir=args.cache_dir, 230 | subfolder=args.subfolder, 231 | ) 232 | -------------------------------------------------------------------------------- /training/flax/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /training/flax/Makefile: -------------------------------------------------------------------------------- 1 | check_dirs := . 2 | 3 | quality: 4 | black --check $(check_dirs) 5 | ruff $(check_dirs) 6 | 7 | style: 8 | black $(check_dirs) 9 | ruff $(check_dirs) --fix 10 | -------------------------------------------------------------------------------- /training/flax/README.md: -------------------------------------------------------------------------------- 1 | ## Reproducing Distil-Whisper 2 | 3 | This sub-folder contains all the training and inference scripts to reproduce the Distil-Whisper project. Distil-Whisper 4 | is written in JAX to leverage the fast training and inference speed offered by TPU v4 hardware. However, it also works 5 | efficiently on GPU hardware without any additional code changes. 6 | 7 | Reproducing the Distil-Whisper project requires four stages to be completed in successive order: 8 | 9 | 1. [Pseudo-labelling](#pseudo-labelling) 10 | 2. [Initialisation](#initialisation) 11 | 3. [Training](#training) 12 | 4. 
[Evaluation](#evaluation) 13 | 14 | This README is partitioned according to the four stages. Each section provides a minimal example for running the 15 | scripts used in the project. The final scripts used to train the model are referenced in-line. 16 | 17 | It is worth noting that the experiments performed in JAX/Flax have been on English ASR only. For multilingual training code, 18 | the [PyTorch Training Code](../README.md) can easily be used, facilitating anyone to run Whisper distillation on a language of their choice. 19 | 20 | ## Requirements 21 | 22 | Distil-Whisper is written in Python, JAX and Flax, and heavily leverages the Flax Whisper implementation in 23 | [🤗 Transformers](https://github.com/huggingface/transformers). The instructions for installing the package are as follows: 24 | 1. Install JAX from the [official instructions](https://github.com/google/jax#installation), ensuring you install the correct version for your hardware (GPU or TPU). 25 | 2. Install the `distil_whisper` package by cloning the repository and performing an editable installation: 26 | 27 | ```bash 28 | git clone https://github.com/huggingface/distil-whisper.git 29 | cd distil-whisper/training/flax 30 | pip install -e . 31 | ``` 32 | 33 | ## Pseudo-Labelling 34 | 35 | Pseudo-labelling is the process of generating target text predictions for the input audio data using the teacher model. 36 | The generated text labels then replace the ground truth text labels when performing distillation. The rationale for 37 | using pseudo-labels instead of ground truth labels is to circumvent the issue of inconsistent transcription formatting 38 | across datasets. 39 | 40 | The python script [`run_pseudo_labelling.py`](run_pseudo_labelling.py) is a flexible inference script that can be used 41 | to generate pseudo-labels under a range of settings, including using both greedy and beam-search. 
It is also compatible 42 | with [🤗 Datasets](https://github.com/huggingface/datasets) *streaming mode*, allowing users to load massive audio 43 | datasets with **no disk space requirements**. For more information on streaming mode, the reader is referred to the 44 | blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#streaming-mode-the-silver-bullet). 45 | 46 | The following script demonstrates how to pseudo-label the [LibriSpeech 960h](https://huggingface.co/datasets/librispeech_asr) 47 | dataset with greedy sampling and streaming mode: 48 | 49 | ```bash 50 | #!/usr/bin/env bash 51 | 52 | python run_pseudo_labelling.py \ 53 | --model_name_or_path "openai/whisper-large-v2" \ 54 | --dataset_name "librispeech_asr" \ 55 | --dataset_config_name "all" \ 56 | --data_split_name "train.clean.100+train.clean.360+train.other.500" \ 57 | --text_column_name "text" \ 58 | --output_dir "./transcriptions" \ 59 | --per_device_eval_batch_size 16 \ 60 | --max_label_length 256 \ 61 | --dtype "bfloat16" \ 62 | --report_to "wandb" \ 63 | --dataloader_num_workers 16 \ 64 | --streaming \ 65 | --push_to_hub \ 66 | --generation_num_beams 1 # for greedy, set >1 for beam 67 | 68 | ``` 69 | 70 | The script will save the generated pseudo-labels alongside the file ids to the output directory `output_dir`. Adding the 71 | `--push_to_hub` argument uploads the generated pseudo-labels to the Hugging Face Hub on save. 72 | 73 | The directory [`pseudo_labelling_scripts`](pseudo_labelling_scripts) contains a collection of bash scripts for 74 | pseudo-labelling all 10 audio datasets used in the project. The datasets with the Whisper generated transcriptions 75 | can be found on the Hugging Face Hub under the [Distil Whisper organisation](https://huggingface.co/datasets?sort=trending&search=distil-whisper%2F). 76 | They can be re-used should you wish to bypass the data labelling stage of the reproduction. 
77 | 78 | 79 | 80 | ## Initialisation 81 | 82 | The script [`create_student_model.py`](create_student_model.py) can be used to initialise a small student model 83 | from a large teacher model. When initialising a student model with fewer layers than the teacher model, the student is 84 | initialised by copying maximally spaced layers from the teacher, as per the [DistilBart](https://arxiv.org/abs/2010.13002) 85 | recommendations. 86 | 87 | The following command demonstrates how to initialise a student model from the [large-v2](https://huggingface.co/openai/whisper-large-v2) 88 | checkpoint, with all 32 encoder layer and 2 decoder layers. The 2 student decoder layers are copied from teacher layers 89 | 1 and 32 respectively, as the maximally spaced layers. 90 | 91 | ```bash 92 | #!/usr/bin/env bash 93 | 94 | python create_student_model.py \ 95 | --teacher_checkpoint "openai/whisper-large-v2" \ 96 | --encoder_layers 32 \ 97 | --decoder_layers 2 \ 98 | --save_dir "./large-32-2" \ 99 | --push_to_hub 100 | ``` 101 | 102 | 103 | ## Training 104 | 105 | The script [`run_distillation.py`](run_distillation.py) is an end-to-end script for loading multiple 106 | datasets, a student model, a teacher model, and performing teacher-student distillation. It uses the loss formulation 107 | from [DistilBart](https://arxiv.org/abs/2010.13002), which is a combination of a cross-entropy, KL-divergence and 108 | mean-square error (MSE) loss: 109 | 110 | https://github.com/huggingface/distil-whisper/blob/4dd831543e6c40b1159f1ec951db7f4fe0e86850/run_distillation.py#L1725 111 | 112 | The weight assigned to the MSE loss is configurable. The others are fixed to the values from the DistilBART paper. 113 | 114 | The following command takes the LibriSpeech 960h dataset that was pseudo-labelled in the first stage and trains the 115 | 2-layer decoder model intialised in the previous step. 
Note that multiple training datasets and splits can be loaded 116 | by separating the dataset arguments by `+` symbols. Thus, the script generalises to any number of training datasets. 117 | 118 | ```bash 119 | #!/usr/bin/env bash 120 | 121 | python3 run_distillation.py \ 122 | --model_name_or_path "./large-32-2" \ 123 | --teacher_model_name_or_path "openai/whisper-large-v2" \ 124 | --train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr" \ 125 | --train_dataset_config_name "all+all+all" \ 126 | --train_split_name "train.clean.100+train.clean.360+train.other.500" \ 127 | --train_dataset_samples "100+360+500" \ 128 | --eval_dataset_name "librispeech_asr" \ 129 | --eval_dataset_config_name "all" \ 130 | --eval_split_name "validation.clean" \ 131 | --eval_steps 5000 \ 132 | --save_steps 5000 \ 133 | --warmup_steps 500 \ 134 | --learning_rate 0.0001 \ 135 | --lr_scheduler_type "constant_with_warmup" \ 136 | --logging_steps 25 \ 137 | --save_total_limit 1 \ 138 | --max_steps 20000 \ 139 | --wer_threshold 10 \ 140 | --per_device_train_batch_size 64 \ 141 | --per_device_eval_batch_size 64 \ 142 | --dataloader_num_workers 16 \ 143 | --dtype "bfloat16" \ 144 | --output_dir "./" \ 145 | --do_train \ 146 | --do_eval \ 147 | --use_scan \ 148 | --gradient_checkpointing \ 149 | --overwrite_output_dir \ 150 | --predict_with_generate \ 151 | --freeze_encoder \ 152 | --streaming \ 153 | --use_auth_token \ 154 | --push_to_hub 155 | 156 | ``` 157 | 158 | The above training script will take approximately 20 hours to complete on a TPU v4-8 and yield a final WER of 2.3%. 159 | 160 | Training logs will be reported to TensorBoard and WandB, provided the relevant packages are available. An example of a 161 | saved checkpoint pushed to the Hugging Face Hub can be found here: [large-32-2](https://huggingface.co/distil-whisper/large-32-2). 
162 | 163 | There are a few noteworthy arguments that can be configured to give optimal training performance: 164 | * `train_dataset_samples`: defines the number of training samples in each dataset. Used to calculate the sampling probabilities in the dataloader. A good starting point is setting the samples to the number of hours of audio data in each split. A more refined strategy is setting it to the number of training samples in each split, however this might require downloading the dataset offline to compute these statistics. 165 | * `wer_threshold`: sets the WER threshold between the normalised pseudo-labels and normalised ground truth labels. Any samples with WER > `wer_threshold` are discarded from the training data. This is beneficial to avoid training the student model on pseudo-labels where Whisper hallucinated or got the predictions grossly wrong. 166 | * `freeze_encoder`: whether to freeze the entire encoder of the student model during training. Beneficial when the student encoder is copied exactly from the teacher encoder. In this case, the encoder hidden-states from the teacher model are re-used for the student model. Stopping the gradient computation through the encoder and sharing the encoder hidden-states provides a significant memory saving, and can enable up to 2x batch sizes. 167 | * `dtype`: data type (dtype) in which the model computation should be performed. Note that this only controls the dtype of the computations (forward and backward pass), and not the dtype of the parameters or optimiser states. 168 | 169 | The Distil Whisper project extends the above script to train on a combined dataset formed from 12 open-source ASR datasets, 170 | totalling 22k hours and over 50k speakers. Template scripts to run training on this composite dataset can be found 171 | in the directory [`distillation_scripts`](distillation_scripts). 172 | 173 | ## Evaluation 174 | 175 | There are two types of evaluation performed in Distil-Whisper: 176 | 1. 
Short form: evaluation on audio samples less than 30s in duration. Examples include typical ASR test sets, such as the LibriSpeech validation set. 177 | 2. Long form: evaluation on audio samples longer than 30s in duration. Examples include entire TED talks or earnings calls. 178 | 179 | Both forms of evaluation are performed using the *word-error rate (WER)* metric. 180 | 181 | ### Short Form 182 | 183 | The script [`run_eval.py`](run_eval.py) can be used to evaluate a trained student model over multiple validation sets. 184 | The following example demonstrates how to evaluate the student model trained in the previous step on the LibriSpeech 185 | `validation.clean` and `validation.other` dev sets. Again, it leverages streaming mode to bypass the need to download 186 | the data offline: 187 | 188 | ```bash 189 | #!/usr/bin/env bash 190 | 191 | python run_eval.py \ 192 | --model_name_or_path "./large-32-2" \ 193 | --dataset_name "librispeech_asr+librispeech_asr" \ 194 | --dataset_config_name "all+all" \ 195 | --dataset_split_name "validation.clean+validation.other" \ 196 | --output_dir "./large-32-2" \ 197 | --per_device_eval_batch_size 64 \ 198 | --dtype "bfloat16" \ 199 | --dataloader_num_workers 16 \ 200 | --report_to "wandb" \ 201 | --streaming \ 202 | --predict_with_generate 203 | 204 | ``` 205 | 206 | ### Long Form 207 | 208 | Long form evaluation runs on the premise that a single long audio file can be *chunked* into smaller segments and 209 | inferred in parallel. The resulting transcriptions are then joined at the boundaries to give the final text prediction. 210 | A small overlap (or *stride*) is used between adjacent segments to ensure a continuous transcription across chunks. 
211 | 212 | This style of chunked inference is performed using the [`FlaxWhisperPipeline`](https://github.com/huggingface/distil-whisper/blob/6426022e3b3a0a498b4150a636b54e2e3898bf1a/distil_whisper/pipeline.py#L61) 213 | class, which is heavily inspired from [Whisper JAX](https://github.com/sanchit-gandhi/whisper-jax/tree/main#pipeline-usage). 214 | 215 | The script [`run_long_form_transcription.py`](run_long_form_transcription.py) can be used to evaluate the trained 216 | student model on an arbitrary number of long-form evaluation sets. The following script demonstrates how to evaluate 217 | the example student model on two such test sets, [Earnings 21](https://huggingface.co/datasets/distil-whisper/earnings21) 218 | and [Earnings 22](https://huggingface.co/datasets/distil-whisper/earnings22): 219 | 220 | ```bash 221 | #!/usr/bin/env bash 222 | 223 | python run_long_form_transcription.py \ 224 | --model_name_or_path "./large-32-2" \ 225 | --dataset_name "distil-whisper/earnings21+distil-whisper/earnings22" \ 226 | --dataset_config_name "default+default" \ 227 | --dataset_split_name "test+test" \ 228 | --text_column_name "transcription+transcription" \ 229 | --output_dir "./large-32-2" \ 230 | --per_device_eval_batch_size 64 \ 231 | --chunk_length_s 15 \ 232 | --dtype "bfloat16" \ 233 | --report_to "wandb" \ 234 | --streaming 235 | 236 | ``` 237 | 238 | The argument `chunk_length_s` controls the length of the chunked audio samples. It should be set to match the typical 239 | length of audio the student model was trained on. If unsure about what value of `chunk_length_s` is optimal for your case, 240 | it is recommended to run a *sweep* over all possible values. A template script for running a [WandB sweep](https://docs.wandb.ai/guides/sweeps) 241 | can be found under [`run_chunk_length_s_sweep.yaml`](long_form_transcription_scripts/run_chunk_length_s_sweep.yaml). 242 | 243 | ### 1.
Pseudo Labelling 244 | 245 | #### Greedy vs Beam 246 | 247 | We found there to be little-to-no difference in the downstream performance of the distilled model after pseudo labelling 248 | using either greedy or beam-search. We attribute this to the minimal difference in performance of the pre-trained Whisper 249 | model under greedy and beam-search decoding, giving pseudo-labelled transcriptions of similar quality. We encourage 250 | users to generate pseudo-labels using greedy decoding given it runs significantly faster. Beam search is only advised if 251 | the pre-trained model is hallucinating significantly on the audio inputs, in which case it helps reduce the frequency and 252 | severity of hallucinations. If using beam search, the number of beams can be kept low: even 2 beams helps reduce the 253 | amount of hallucinations significantly. 254 | 255 | #### Timestamps 256 | 257 | Whisper is trained on a timestamp prediction task as part of the pre-training set-up. Here, a fixed proportion of the 258 | pre-training data includes sequence-level *timestamps* as part of the transcription labels: 259 | 260 | ```bash 261 | <|0.00|> Hey, this is a test transcription. <|3.42|> 262 | ``` 263 | 264 | Timestamp prediction is useful for enriching the transcriptions with timing information for downstream tasks, such as 265 | aligning the Whisper transcription with the output of a speaker diarization system, and also reduces the frequency of 266 | hallucinations. 267 | 268 | The pseudo-labelling script [`run_pseudo_labelling.py`](run_pseudo_labelling.py) can be extended to predict timestamp 269 | information in the audio data by appending the `--return_timestamps` flag to the launch command. The timestamped labelled 270 | data can be passed to the training script in exactly the same way as the non-timestamped version, and the pre-processing 271 | function will take care of encoding the timestamps and appending the required task tokens.
272 | 273 | #### Previous Context 274 | 275 | Whisper is also pre-trained on a prompting task, where the transcription for the preceding utterance is fed as context 276 | to the current one: 277 | 278 | ```bash 279 | <|startofprev|> This is the previous context from the preceding utterance.<|startoftranscript|> And this is the current utterance.<|endoftranscript|> 280 | ``` 281 | 282 | Annotating the transcriptions with previous context labels is only possible for datasets where we have consecutive files 283 | and unique speaker ids, since we need to ensure segment `i` directly follows on from segment `i-1` if we use it as the 284 | prompt. 285 | 286 | As per the Whisper paper, we mask out the loss over the previous context tokens. At inference time, we can replace the 287 | previous context with a “prompt” to encourage the model to generate text in the style of the prompt (i.e. for specific 288 | named entities, or styles of transcription) 289 | 290 | ## Acknowledgements 291 | 292 | * 🤗 Hugging Face Transformers for the base Whisper implementation 293 | * Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) programme for their generous provision of Cloud TPUs 294 | -------------------------------------------------------------------------------- /training/flax/conversion_scripts/run_convert_distilled_train_state_to_hf.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python convert_train_state_to_hf.py \ 4 | --model_name_or_path "distil-whisper/large-32-2" \ 5 | --output_dir "./" \ 6 | --resume_from_checkpoint "checkpoint-15000" \ 7 | --cache_dir "/home/sanchitgandhi/.cache" \ 8 | --use_scan 9 | -------------------------------------------------------------------------------- /training/flax/convert_train_state_to_hf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 
| # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Convert a Flax training state to HF Transformers Whisper weights. 18 | """ 19 | 20 | import logging 21 | import os 22 | import sys 23 | from dataclasses import field 24 | from pathlib import Path 25 | from typing import Callable, Optional 26 | 27 | import flax 28 | import jax 29 | import jax.numpy as jnp 30 | import optax 31 | from flax import jax_utils, traverse_util 32 | from flax.serialization import from_bytes 33 | from flax.training import train_state 34 | from flax.training.common_utils import shard_prng_key 35 | from huggingface_hub import Repository, create_repo 36 | from optax._src import linear_algebra 37 | from transformers import ( 38 | AutoConfig, 39 | HfArgumentParser, 40 | Seq2SeqTrainingArguments, 41 | ) 42 | from transformers.file_utils import get_full_repo_name 43 | from transformers.utils import check_min_version 44 | from transformers.utils.versions import require_version 45 | 46 | from distil_whisper import FlaxWhisperForConditionalGeneration 47 | 48 | 49 | # initialise JAX for multi-host set-up on TPU 50 | jax.distributed.initialize() 51 | 52 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
# Fail fast if the installed Transformers version is older than required.
check_min_version("4.27.0.dev0")

# NOTE(review): "speech-recogintion" typo below is in the user-facing hint string;
# left untouched here since changing runtime strings is out of scope for a doc pass.
require_version(
    "datasets>=1.18.0",
    "To fix: pip install -r examples/flax/speech-recogintion/requirements.txt",
)

logger = logging.getLogger(__name__)


@flax.struct.dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": ("Path to pretrained student model or model identifier from huggingface.co/models")}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": ("Where to store the pretrained models downloaded from huggingface.co")},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": ("Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.")},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": ("The specific model version to use (can be a branch name, tag name or commit id).")},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `transformers-cli login`"
                " (necessary to use this script with private models)."
            )
        },
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized"
                " and trained. Choose one of `[float32, float16, bfloat16]`."
            )
        },
    )
    # NOTE: `load_with_scan_weights` describes the *stored* checkpoint format, while
    # `use_scan` below controls whether the model is *run* with scanned layers —
    # the two are independent and both consumed in main().
    load_with_scan_weights: bool = field(
        default=False,
        metadata={
            "help": "Whether the pre-trained checkpoint has its weights stored in scan format. Set to True for scanned "
            "weights, defaults to False for non-scan (unrolled) weights."
        },
    )
    use_scan: bool = field(
        default=True,
        metadata={"help": ("Whether or not to use `scan_with_axes` over the encoder and decoder blocks.")},
    )


def create_learning_rate_fn(
    num_train_steps: int, lr_scheduler_type: str, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """Return a learning-rate schedule: linear warmup from 0 to `learning_rate`, then
    either linear decay to 0 ("linear") or a constant rate ("constant_with_warmup").

    Raises:
        ValueError: if `lr_scheduler_type` is not one of the supported types.
    """
    lr_scheduler_types = ("linear", "constant_with_warmup")

    if lr_scheduler_type not in lr_scheduler_types:
        raise ValueError(
            f"lr_scheduler_type of type {lr_scheduler_type} not supported, choose from {lr_scheduler_types}."
        )

    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    # For "constant_with_warmup" the "decay" segment starts and ends at `learning_rate`,
    # i.e. it is a constant schedule expressed through the same linear_schedule helper.
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0 if lr_scheduler_type == "linear" else learning_rate,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn
class TrainState(train_state.TrainState):
    """Flax train state extended with a dropout PRNG key and global-norm gradient clipping."""

    # PRNG key threaded through for dropout; sharded per device by `replicate`.
    dropout_rng: jnp.ndarray
    # Maximum global L2 norm for the gradients; applied in `apply_gradients`.
    max_grad_norm: float

    def apply_gradients(self, *, grads, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in return value, clipping the
        gradients by the maximum grad norm.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
            grads: Gradients that have the same pytree structure as `.params`.
            **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
            An updated instance of `self` with `step` incremented by one, `params`
            and `opt_state` updated by applying `grads`, and additional attributes
            replaced as specified by `kwargs`.
        """
        # clip gradients by global l2 norm
        g_norm = linear_algebra.global_norm(grads)
        # Standard clipping trick: scale by max_grad_norm / max(g_norm, max_grad_norm),
        # which is a no-op when g_norm <= max_grad_norm and rescales the gradients to
        # norm max_grad_norm otherwise.
        g_norm = jnp.maximum(self.max_grad_norm, g_norm)
        # NOTE(review): jax.tree_map is deprecated in newer JAX releases in favour of
        # jax.tree_util.tree_map — confirm the pinned JAX version before changing.
        grads = jax.tree_map(lambda t: (t / g_norm) * self.max_grad_norm, grads)

        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )

    def replicate(self):
        # Replicate the state across local devices; the dropout key is sharded (one
        # distinct key per device) rather than replicated.
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))

    def unreplicate(self):
        return jax_utils.unreplicate(self)


def main():
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (
            ModelArguments,
            Seq2SeqTrainingArguments,
        )
    )

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, training_args = parser.parse_args_into_dataclasses()

    # 2. Handle the repository creation (clone the target Hub repo into output_dir)
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token,
            )
        else:
            repo_name = training_args.hub_model_id
        create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
        repo = Repository(
            training_args.output_dir,
            clone_from=repo_name,
            token=training_args.hub_token,
        )

    # 3. Load pretrained config, model and processor
    config = AutoConfig.from_pretrained(
        (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=getattr(jnp, model_args.dtype),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        _do_init=False,
        use_scan=model_args.load_with_scan_weights,
    )

    # enable scan / gradient checkpointing if necessary in the student model
    if model_args.use_scan:
        student_model.enable_scan()  # to enable scan in the nn.Module
        student_params = student_model.convert_unroll_to_scan(student_params)  # to convert the unrolled params to scan

    # Initialize our student state
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    total_train_steps = int(training_args.max_steps)

    # Create learning rate schedule. The optimizer below must mirror the one used at
    # training time so that the serialized optimizer state can be deserialized into it.
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        total_train_steps,
        training_args.lr_scheduler_type,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = [
            "layer_norm",
            "self_attn_layer_norm",
            "final_layer_norm",
            "encoder_attn_layer_norm",
        ]
        # NOTE: `"".join(layer)` concatenates the whole flattened path, so a candidate
        # matches anywhere in the joined path string, not just in the leaf name.
        layer_norm_named_params = {
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        }
        flat_mask = {path: path[-1] != "bias" and path[-2:] not in layer_norm_named_params for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    student_state = TrainState.create(
        apply_fn=student_model.__call__,
        params=student_params,
        tx=adamw,
        dropout_rng=dropout_rng,
        max_grad_norm=training_args.max_grad_norm,
    )

    # 4. Restore the saved train state (parameters, optimizer state and step counter)
    if training_args.resume_from_checkpoint is not None:
        if os.path.isfile(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")):
            logger.info(
                f"Checkpoint detected, resuming training at {training_args.resume_from_checkpoint}. To avoid "
                "this behavior, omit the resume_from_checkpoint argument."
            )
            with Path(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")).open("rb") as f:
                student_state = from_bytes(student_state, f.read())
        else:
            logger.warning(
                f"Checkpoint {training_args.resume_from_checkpoint} not detected, training from scratch. Ensure "
                f"you pass the path to a folder with a valid checkpoint for your model."
            )

    cur_step = int(jax.device_get(student_state.step))

    # 5. Save weights in HF Transformers format (only on the first process of the
    # multi-host set-up initialised by jax.distributed.initialize at import time)
    if jax.process_index() == 0:
        student_model.disable_scan()
        student_state_params = student_model.convert_scan_to_unroll(student_state.params)
        student_params = jax.device_get(student_state_params)
        student_model.save_pretrained(
            os.path.join(training_args.output_dir, f"checkpoint-{cur_step}"), params=student_params
        )
        if training_args.push_to_hub:
            repo.push_to_hub(
                commit_message=f"Saving weights of step {cur_step}",
                blocking=False,
            )


if __name__ == "__main__":
    main()
def _str_to_bool(value):
    """Parse a command-line boolean.

    Accepts true/false spellings ("true"/"t"/"yes"/"y"/"1" and their negatives,
    case-insensitive); raises `argparse.ArgumentTypeError` for anything else.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def parse_args():
    """Parse the CLI arguments for initialising a student Whisper model.

    Returns:
        argparse.Namespace with the teacher checkpoint, student layer counts,
        save directory and Hub/caching options.
    """
    parser = argparse.ArgumentParser(
        description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
    )
    parser.add_argument(
        "--teacher_checkpoint",
        type=str,
        required=True,
        help="The HF Hub ID of the teacher checkpoint.",
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        default="",
        help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you "
        "can specify the folder name here.",
    )
    parser.add_argument(
        "--encoder_layers",
        type=int,
        default=None,
        help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
    )
    parser.add_argument(
        "--decoder_layers",
        type=int,
        default=2,
        help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
    )
    parser.add_argument(
        "--max_source_positions",
        type=int,
        default=None,
        help="The maximum sequence length of log-mel filter-bank features that this model might ever be used with. Can "
        "be used to create a student model with a shorter context length than the teacher model. Defaults to the number "
        "of source positions in the teacher model (1500).",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Where to save the student weights and processor.",
    )
    # Fix: the previous `type=bool` rejected the bare `--push_to_hub` flag (as used in
    # the README) and treated any non-empty value — including "False" — as True, since
    # `bool("False") is True`. A str->bool converter with `nargs="?", const=True`
    # supports both the bare flag and an explicit `--push_to_hub False`.
    parser.add_argument(
        "--push_to_hub",
        type=_str_to_bool,
        nargs="?",
        const=True,
        required=False,
        default=False,
        help="Whether to push the student weights and processor to the Hub.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Where to store the pretrained models downloaded from huggingface.co",
    )

    args = parser.parse_args()
    return args
def init_student_model_from_teacher(
    teacher_checkpoint,
    encoder_layers=None,
    decoder_layers=2,
    max_source_positions=None,
    save_dir=None,
    push_to_hub=None,
    cache_dir=None,
    subfolder="",
):
    """Initialise a student Whisper model by copying maximally spaced layers from a teacher.

    Loads the teacher checkpoint, builds a student config with the requested number of
    encoder/decoder layers, copies the teacher layers selected by an evenly spaced
    (np.linspace) mapping into the student, optionally truncates the encoder position
    embeddings to `max_source_positions`, saves the result to `save_dir`, sanity-checks
    it with a forward pass, and optionally pushes it to the Hub.

    Args:
        teacher_checkpoint: HF Hub ID (or local path) of the teacher checkpoint.
        encoder_layers: number of student encoder layers; None keeps all teacher layers.
        decoder_layers: number of student decoder layers (default 2).
        max_source_positions: optional shorter encoder context length for the student.
        save_dir: directory where the student weights/processor/generation config are saved.
        push_to_hub: if truthy, push the saved artefacts to the Hub under `save_dir`.
        cache_dir: cache directory for the teacher download.
        subfolder: subfolder of the teacher repo holding the weights.
    """
    teacher_model, teacher_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        teacher_checkpoint,
        _do_init=False,
        cache_dir=cache_dir,
        subfolder=subfolder,
    )
    processor = WhisperProcessor.from_pretrained(teacher_checkpoint)
    generation_config = GenerationConfig.from_pretrained(teacher_checkpoint)

    teacher_config = teacher_model.config
    teacher_encoder_layers = teacher_config.encoder_layers
    teacher_decoder_layers = teacher_config.decoder_layers

    student_config = copy.deepcopy(teacher_config)
    student_config.update(
        {
            "encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers,
            "decoder_layers": decoder_layers,
            "max_source_positions": (
                max_source_positions if max_source_positions is not None else student_config.max_source_positions
            ),
        }
    )

    # Evenly spaced teacher layer indices; the explicit [-1] assignment pins the last
    # student layer to the last teacher layer (linspace already ends there, so this is
    # a belt-and-braces guard against integer truncation).
    encoder_mapping = np.linspace(0, teacher_encoder_layers - 1, student_config.encoder_layers, dtype=int)
    encoder_mapping[-1] = teacher_encoder_layers - 1

    # Maps teacher layer name (str index) -> student layer name (str index).
    encoder_map = {}
    for student_layer, teacher_layer in enumerate(encoder_mapping):
        encoder_map[str(teacher_layer)] = str(student_layer)

    decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_config.decoder_layers, dtype=int)
    decoder_mapping[-1] = teacher_decoder_layers - 1

    decoder_map = {}
    for student_layer, teacher_layer in enumerate(decoder_mapping):
        decoder_map[str(teacher_layer)] = str(student_layer)

    # init the student params from the teacher model
    student_params = unfreeze(teacher_params)
    student_params["model"]["decoder"]["layers"] = {}

    for layer in teacher_params["model"]["decoder"]["layers"]:
        if layer in decoder_map:
            # re-introduce pre-defined layers from the teacher
            student_params["model"]["decoder"]["layers"][decoder_map[layer]] = teacher_params["model"]["decoder"][
                "layers"
            ][layer]

    # The encoder is only re-mapped when a smaller encoder was requested; otherwise the
    # full teacher encoder params are kept as-is (copied wholesale via unfreeze above).
    if encoder_layers is not None:
        student_params["model"]["encoder"]["layers"] = {}
        for layer in teacher_params["model"]["encoder"]["layers"]:
            if layer in encoder_map:
                # re-introduce pre-defined layers from the teacher
                student_params["model"]["encoder"]["layers"][encoder_map[layer]] = teacher_params["model"]["encoder"][
                    "layers"
                ][layer]

    if max_source_positions is not None:
        # slice the first MAX_SOURCE_POSITIONS embedding weights
        student_params["model"]["encoder"]["embed_positions"]["embedding"] = teacher_params["model"]["encoder"][
            "embed_positions"
        ]["embedding"][: student_config.max_source_positions, :]
        # update the feature extractor to handle the new input length
        # (pure arithmetic inverse of positions-per-second: 1500 positions -> 30 s chunk)
        chunk_length = int(student_config.max_source_positions * 2 / 100)
        processor.feature_extractor = WhisperFeatureExtractor(chunk_length=chunk_length)

    # remove the teacher params and model
    del teacher_params, teacher_model

    # save the converted weights and model
    student_params = freeze(student_params)
    student_model = FlaxWhisperForConditionalGeneration(student_config, _do_init=False)

    if save_dir is not None:
        student_model.save_pretrained(save_dir, params=student_params)
        # we also need to correctly save the processor and generation config
        processor.save_pretrained(save_dir)
        generation_config.save_pretrained(save_dir)

    # check we can do a forward pass with the saved model - first load the weights and processor
    # NOTE(review): this reload runs unconditionally and assumes save_dir is not None;
    # calling with save_dir=None would fail here — confirm whether that path is intended.
    logger.info("Checking we can load the saved model...")
    student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        save_dir,
        _do_init=False,
    )
    processor = WhisperProcessor.from_pretrained(save_dir)

    # define some random inputs
    input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="np").input_features
    decoder_start_token_id = student_model.config.decoder_start_token_id
    decoder_input_ids = np.ones((input_features.shape[0], 1)) * decoder_start_token_id

    # do a forward pass - outputs will be gibberish for the initialised model so we can't check them
    logger.info("Checking we can run the converted model forward...")
    _ = student_model(input_features, decoder_input_ids=decoder_input_ids, params=student_params).logits
    logger.info("Conversion successful!")

    if push_to_hub:
        student_model.push_to_hub(save_dir, params=student_params)
        processor.push_to_hub(save_dir)
        generation_config.push_to_hub(save_dir)


if __name__ == "__main__":
    args = parse_args()

    # Set the verbosity to info of the logger - we only want one process per machine to log things on the screen
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)

    init_student_model_from_teacher(
        teacher_checkpoint=args.teacher_checkpoint,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.decoder_layers,
        max_source_positions=args.max_source_positions,
        save_dir=args.save_dir,
        push_to_hub=args.push_to_hub,
        cache_dir=args.cache_dir,
        subfolder=args.subfolder,
    )
15 | 16 | __version__ = "0.0.1" 17 | 18 | from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration 19 | from .partitioner import PjitPartitioner 20 | from .pipeline import FlaxWhisperPipeline 21 | from .train_state import InferenceState 22 | -------------------------------------------------------------------------------- /training/flax/distil_whisper/train_state.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, MutableMapping, Optional, Tuple 2 | 3 | import flax.core 4 | import flax.serialization 5 | import flax.struct 6 | import jax.numpy as jnp 7 | from flax import traverse_util 8 | from flax.core import scope as flax_scope 9 | from flax.linen import partitioning as flax_partitioning 10 | 11 | 12 | EMPTY_DICT = flax.core.freeze({}) 13 | FrozenDict = flax_scope.FrozenDict 14 | FrozenVariableDict = flax_scope.FrozenVariableDict 15 | MutableVariableDict = flax_scope.MutableVariableDict 16 | VariableDict = flax_scope.VariableDict 17 | 18 | 19 | def _validate_params_axes(params_axes, params): 20 | axis_names = flax_partitioning.get_axis_names(params_axes) 21 | missing_params_axes = set(traverse_util.flatten_dict(params, sep="/")) - set( 22 | traverse_util.flatten_dict(axis_names, sep="/") 23 | ) 24 | if missing_params_axes: 25 | raise ValueError(f"Missing axis names for parameters: {missing_params_axes}") 26 | 27 | 28 | def _split_variables_and_axes( 29 | variables_and_axes: FrozenVariableDict, 30 | ) -> Tuple[FrozenVariableDict, FrozenVariableDict]: 31 | """Splits `variables_and_axes` into two separate dicts with the same keys.""" 32 | # For each `key`, `key_axes` (if any) are its axes in `variables_and_axes`. 33 | variables = {} 34 | axes = {} 35 | for k, v in variables_and_axes.items(): 36 | if k.endswith("_axes"): 37 | axes[k[:-5]] = v # k without "_axes". 38 | _validate_params_axes(v, variables_and_axes[k[:-5]]) # k without "_axes". 
39 | else: 40 | variables[k] = v 41 | return flax.core.freeze(variables), flax.core.freeze(axes) 42 | 43 | 44 | class InferenceState(flax.struct.PyTreeNode): 45 | """State compatible with FlaxOptimTrainState without optimizer state.""" 46 | 47 | step: jnp.ndarray 48 | params: flax_scope.FrozenVariableDict 49 | params_axes: Optional[flax_scope.FrozenVariableDict] = None 50 | flax_mutables: flax_scope.FrozenDict = EMPTY_DICT 51 | flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None 52 | 53 | @classmethod 54 | def create(cls, model_variables: FrozenVariableDict) -> "InferenceState": 55 | other_variables, params = model_variables.pop("params") 56 | if "params_axes" in other_variables: 57 | other_variables, params_axes = other_variables.pop("params_axes") 58 | _validate_params_axes(params_axes, params) 59 | else: 60 | params_axes = None 61 | 62 | # Split other_variables into mutables and their corresponding axes. 63 | flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables) 64 | flax_mutables_axes = flax_mutables_axes or None 65 | return InferenceState( 66 | step=jnp.array(0), 67 | params=params, 68 | params_axes=params_axes, 69 | flax_mutables=flax_mutables, 70 | flax_mutables_axes=flax_mutables_axes, 71 | ) 72 | 73 | @property 74 | def param_states(self) -> FrozenVariableDict: 75 | """The optimizer states of the parameters as a PyTree.""" 76 | raise NotImplementedError("InferenceState has no optimizer states.") 77 | 78 | def apply_gradient(self, *args, **kwargs) -> "InferenceState": 79 | raise NotImplementedError("InferenceState does not support `apply_gradient`.") 80 | 81 | def state_dict(self) -> MutableMapping[str, Any]: 82 | state_dict = { 83 | "target": flax.core.unfreeze(self.params), 84 | "state": {"step": self.step}, 85 | } 86 | if self.flax_mutables: 87 | state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables) 88 | return state_dict 89 | 90 | def replace_step(self, step: jnp.ndarray) -> "InferenceState": 91 | 
return self.replace(step=step) 92 | 93 | def replace_params(self, params: FrozenVariableDict) -> "InferenceState": 94 | return self.replace(params=params) 95 | 96 | def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState": 97 | return self.replace(flax_mutables=flax_mutables) 98 | 99 | def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState": 100 | return self.replace( 101 | params=flax.core.freeze(state_dict["target"]), 102 | step=state_dict["state"]["step"], 103 | flax_mutables=( 104 | flax.core.freeze(state_dict["flax_mutables"]) if "flax_mutables" in state_dict else EMPTY_DICT 105 | ), 106 | ) 107 | 108 | def as_logical_axes(self) -> "InferenceState": 109 | # Set step to None so that when the logical axes are processed by the 110 | # flax.partitioning.logical_to_mesh_axes function, it will be skipped 111 | # because jax.tree_map will short circut and never call the function on the 112 | # step. 113 | flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT 114 | return InferenceState( 115 | step=None, 116 | params=flax_partitioning.get_axis_names(self.params_axes), 117 | flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes), 118 | ) 119 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_32_2_pt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | accelerate launch --multi_gpu --mixed_precision=bf16 --num_processes=2 run_distillation_pt.py \ 4 | --model_name_or_path distil-whisper/large-32-2 \ 5 | --teacher_model_name_or_path openai/whisper-large-v2 \ 6 | --train_dataset_config_name all+all+all+l \ 7 | --train_dataset_samples 2.9+10.4+14.9+226.6 \ 8 | --train_dataset_name librispeech_asr+librispeech_asr+librispeech_asr+gigaspeech-l \ 9 | --train_split_name train.clean.100+train.clean.360+train.other.500+train \ 10 | --eval_dataset_name librispeech_asr+librispeech_asr+gigaspeech-l \ 
11 | --eval_dataset_config_name all+all+l \ 12 | --eval_split_name validation.clean+validation.other+validation \ 13 | --eval_text_column_name text+text+text \ 14 | --eval_steps 2500 \ 15 | --save_steps 2500 \ 16 | --warmup_steps 50 \ 17 | --learning_rate 0.0001 \ 18 | --lr_scheduler_type constant_with_warmup \ 19 | --logging_steps 25 \ 20 | --save_total_limit 1 \ 21 | --max_steps 10000 \ 22 | --wer_threshold 10 \ 23 | --per_device_train_batch_size 64 \ 24 | --gradient_accumulation_steps 2 \ 25 | --per_device_eval_batch_size 64 \ 26 | --dataloader_num_workers 16 \ 27 | --cache_dir /fsx/sanchit/cache \ 28 | --dataset_cache_dir /fsx/sanchit/cache \ 29 | --dtype bfloat16 \ 30 | --output_dir ./ \ 31 | --wandb_project distil-whisper-training \ 32 | --do_train \ 33 | --do_eval \ 34 | --gradient_checkpointing \ 35 | --overwrite_output_dir \ 36 | --predict_with_generate \ 37 | --freeze_encoder \ 38 | --streaming 39 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_bs_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --do_train 5 | - --use_scan 6 | - --gradient_checkpointing 7 | - --overwrite_output_dir 8 | - --predict_with_generate 9 | - --freeze_encoder 10 | - --streaming 11 | - --use_auth_token 12 | - --compilation_cache 13 | - ${args} 14 | method: grid 15 | metric: 16 | goal: minimize 17 | name: train/loss 18 | parameters: 19 | model_name_or_path: 20 | value: distil-whisper/large-32-2 21 | teacher_model_name_or_path: 22 | value: openai/whisper-large-v2 23 | train_dataset_name: 24 | value: librispeech_asr 25 | train_dataset_config_name: 26 | value: all 27 | train_split_name: 28 | value: train.other.500 29 | train_dataset_samples: 30 | value: 100 31 | cache_dir: 32 | value: /fsx/sanchitgandhi/cache 33 | dataset_cache_dir: 34 | value: /fsx/sanchitgandhi/cache 35 | output_dir: 36 | value: ./ 37 | 
per_device_train_batch_size: 38 | values: 39 | - 128 40 | - 256 41 | - 512 42 | precision: 43 | values: 44 | - "full_mixed" 45 | - "half_mixed" 46 | dtype: 47 | value: bfloat16 48 | do_eval: 49 | value: false 50 | learning_rate: 51 | value: 3e-4 52 | lr_scheduler_type: 53 | value: constant_with_warmup 54 | warmup_steps: 55 | value: 30 56 | max_steps: 57 | value: 30 58 | save_steps: 59 | value: 51 # don't save checkpoints during sweep 60 | dataloader_num_workers: 61 | value: 48 62 | logging_steps: 63 | value: 5 64 | wer_threshold: 65 | value: 100 66 | program: run_distillation.py 67 | project: distil-whisper-sweeps 68 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_dataset_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --do_train 5 | - --do_eval 6 | - --use_scan 7 | - --gradient_checkpointing 8 | - --overwrite_output_dir 9 | - --predict_with_generate 10 | - --freeze_encoder 11 | - --streaming 12 | - --use_auth_token 13 | - ${args} 14 | method: grid 15 | metric: 16 | goal: minimize 17 | name: gigaspeech-l/validation/wer 18 | parameters: 19 | model_name_or_path: 20 | value: distil-whisper/large-32-2 21 | teacher_model_name_or_path: 22 | value: openai/whisper-large-v2 23 | max_train_samples: 24 | values: 25 | - 109876 26 | - 219752 27 | - 439504 28 | - 879008 29 | - 1758015 30 | - 3516030 31 | - 7032061 32 | train_dataset_name: 33 | value: librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted 34 | train_dataset_config_name: 35 | value: 
all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3 36 | train_split_name: 37 | value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train 38 | train_dataset_samples: 39 | value: 2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8 40 | eval_dataset_name: 41 | value: librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs 42 | eval_dataset_config_name: 43 | value: all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us 44 | eval_split_name: 45 | value: validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation 46 | eval_text_column_name: 47 | value: text+text+text+text+text+text+text+text+text+text+text+text+transcription 48 | cache_dir: 49 | value: /home/sanchitgandhi/.cache 50 | dataset_cache_dir: 51 | value: /home/sanchitgandhi/.cache 52 | output_dir: 53 | value: ./ 54 | per_device_train_batch_size: 55 | value: 64 56 | per_device_eval_batch_size: 57 | value: 64 58 | dtype: 59 | value: bfloat16 60 | learning_rate: 61 | value: 1e-4 62 | lr_scheduler_type: 63 | value: constant_with_warmup 64 | warmup_steps: 65 | value: 50 66 | max_steps: 67 | value: 10000 68 | save_steps: 69 | value: 10001 # don't save checkpoints during sweep 70 | dataloader_num_workers: 71 | value: 48 72 | logging_steps: 73 | value: 25 74 | wer_threshold: 75 | value: 10 76 | program: run_distillation.py 77 | project: distil-whisper-sweeps 78 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_decoder_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --do_train 5 | - --do_eval 6 | - --use_scan 7 | 
- --gradient_checkpointing 8 | - --overwrite_output_dir 9 | - --predict_with_generate 10 | - --freeze_encoder 11 | - --streaming 12 | - --use_auth_token 13 | - ${args} 14 | method: grid 15 | metric: 16 | goal: minimize 17 | name: gigaspeech-l/validation/wer 18 | parameters: 19 | model_name_or_path: 20 | values: 21 | - distil-whisper/large-32-16 22 | - distil-whisper/large-32-8 23 | - distil-whisper/large-32-4 24 | - distil-whisper/large-32-2 25 | teacher_model_name_or_path: 26 | value: openai/whisper-large-v2 27 | train_dataset_name: 28 | value: librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted 29 | train_dataset_config_name: 30 | value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3 31 | train_split_name: 32 | value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train 33 | train_dataset_samples: 34 | value: 2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8 35 | eval_dataset_name: 36 | value: librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs 37 | eval_dataset_config_name: 38 | value: all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us 39 | eval_split_name: 40 | value: validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation 41 | eval_text_column_name: 42 | value: text+text+text+text+text+text+text+text+text+text+text+text+transcription 43 | cache_dir: 44 | value: /home/sanchitgandhi/.cache 45 | 
dataset_cache_dir: 46 | value: /home/sanchitgandhi/.cache 47 | output_dir: 48 | value: ./ 49 | per_device_train_batch_size: 50 | value: 64 51 | per_device_eval_batch_size: 52 | value: 64 53 | dtype: 54 | value: bfloat16 55 | learning_rate: 56 | value: 1e-4 57 | lr_scheduler_type: 58 | value: constant_with_warmup 59 | warmup_steps: 60 | value: 50 61 | max_steps: 62 | value: 10000 63 | save_steps: 64 | value: 10001 # don't save checkpoints during sweep 65 | dataloader_num_workers: 66 | value: 48 67 | logging_steps: 68 | value: 25 69 | wer_threshold: 70 | value: 10 71 | program: run_distillation.py 72 | project: distil-whisper-sweeps 73 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_12_2_timestamped.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \ 4 | --model_name_or_path "distil-whisper/small-12-2" \ 5 | --teacher_model_name_or_path "openai/whisper-medium.en" \ 6 | --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3" \ 7 | --train_dataset_samples "2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8" \ 8 | --train_dataset_name "librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted" \ 9 | --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train" \ 10 | --eval_dataset_name 
"distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \ 11 | --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \ 12 | --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \ 13 | --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \ 14 | --eval_steps 5000 \ 15 | --save_steps 5000 \ 16 | --warmup_steps 500 \ 17 | --learning_rate 0.0001 \ 18 | --logging_steps 25 \ 19 | --save_total_limit 1 \ 20 | --max_steps 80000 \ 21 | --wer_threshold 10 \ 22 | --per_device_train_batch_size 64 \ 23 | --per_device_eval_batch_size 64 \ 24 | --dtype "bfloat16" \ 25 | --dataloader_num_workers 16 \ 26 | --cache_dir "/home/sanchitgandhi/.cache" \ 27 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 28 | --output_dir "./" \ 29 | --timestamp_probability 0.2 \ 30 | --wandb_name "small-12-2-tpu-timestamped-prob-0.2" \ 31 | --wandb_dir "/home/sanchitgandhi/.cache" \ 32 | --wandb_project "distil-whisper" \ 33 | --do_train \ 34 | --do_eval \ 35 | --use_scan \ 36 | --gradient_checkpointing \ 37 | --overwrite_output_dir \ 38 | --predict_with_generate \ 39 | --freeze_encoder \ 40 | --streaming \ 41 | --use_auth_token \ 42 | --push_to_hub 43 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_15s_context.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 
TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \ 4 | --model_name_or_path "distil-whisper/large-32-2-15s-context" \ 5 | --teacher_model_name_or_path "openai/whisper-large-v2" \ 6 | --feature_extractor_name "openai/whisper-large-v2" \ 7 | --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+L" \ 8 | --train_dataset_samples "100+360+500+2300+450+90+90+12000+450+3600+2500+5000" \ 9 | --train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech" \ 10 | --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train" \ 11 | --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \ 12 | --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \ 13 | --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \ 14 | --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \ 15 | --eval_steps 5000 \ 16 | --save_steps 5000 \ 17 | --warmup_steps 500 \ 18 | --learning_rate 0.0001 \ 19 | --lr_scheduler_type "linear" \ 20 | --logging_steps 25 \ 21 | --save_total_limit 1 \ 22 | --max_steps 80000 \ 23 | --wer_threshold 10 \ 24 | --per_device_train_batch_size 64 \ 25 | --per_device_eval_batch_size 64 \ 26 | --max_duration_in_seconds 15 \ 27 | --dataloader_num_workers 16 \ 28 | 
--cache_dir "/home/sanchitgandhi/.cache" \ 29 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 30 | --dtype "bfloat16" \ 31 | --output_dir "./" \ 32 | --wandb_name "large-32-2-ts-28k-wer-10-context-15s" \ 33 | --wandb_dir "/home/sanchitgandhi/.cache" \ 34 | --wandb_project "distil-whisper" \ 35 | --do_train \ 36 | --do_eval \ 37 | --use_scan \ 38 | --gradient_checkpointing \ 39 | --overwrite_output_dir \ 40 | --predict_with_generate \ 41 | --streaming \ 42 | --use_auth_token \ 43 | --push_to_hub 44 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_16_2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \ 4 | --model_name_or_path "distil-whisper/large-16-2" \ 5 | --teacher_model_name_or_path "openai/whisper-large-v2" \ 6 | --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+L" \ 7 | --train_dataset_samples "100+360+500+2300+450+90+90+12000+450+3600+2500+5000" \ 8 | --train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech" \ 9 | --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train" \ 10 | --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \ 11 | --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \ 12 | --eval_split_name 
"validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \ 13 | --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \ 14 | --eval_steps 5000 \ 15 | --save_steps 5000 \ 16 | --warmup_steps 500 \ 17 | --learning_rate 0.0001 \ 18 | --lr_scheduler_type "linear" \ 19 | --logging_steps 25 \ 20 | --save_total_limit 1 \ 21 | --max_steps 80000 \ 22 | --wer_threshold 10 \ 23 | --per_device_eval_batch_size 64 \ 24 | --per_device_train_batch_size 64 \ 25 | --dataloader_num_workers 16 \ 26 | --cache_dir "/home/sanchitgandhi/.cache" \ 27 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 28 | --dtype "bfloat16" \ 29 | --output_dir "./" \ 30 | --wandb_name "large-16-2-ts-28k-wer-10" \ 31 | --wandb_dir "/home/sanchitgandhi/.cache" \ 32 | --wandb_project "distil-whisper" \ 33 | --do_train \ 34 | --do_eval \ 35 | --use_scan \ 36 | --gradient_checkpointing \ 37 | --overwrite_output_dir \ 38 | --predict_with_generate \ 39 | --streaming \ 40 | --use_auth_token \ 41 | --push_to_hub 42 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_24_2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \ 4 | --model_name_or_path "distil-whisper/medium-24-2" \ 5 | --teacher_model_name_or_path "openai/whisper-medium.en" \ 6 | --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+L" \ 7 | --train_dataset_samples "100+360+500+2300+450+90+90+12000+450+3600+2500+5000" \ 8 | --train_dataset_name 
"librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech" \ 9 | --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train" \ 10 | --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \ 11 | --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \ 12 | --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \ 13 | --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \ 14 | --eval_steps 5000 \ 15 | --save_steps 5000 \ 16 | --warmup_steps 500 \ 17 | --learning_rate 0.0001 \ 18 | --lr_scheduler_type "linear" \ 19 | --logging_steps 25 \ 20 | --save_total_limit 1 \ 21 | --max_steps 80000 \ 22 | --wer_threshold 10 \ 23 | --per_device_eval_batch_size 64 \ 24 | --per_device_train_batch_size 64 \ 25 | --dataloader_num_workers 16 \ 26 | --cache_dir "/home/sanchitgandhi/.cache" \ 27 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 28 | --dtype "bfloat16" \ 29 | --output_dir "./" \ 30 | --wandb_name "medium-24-2-ts-freeze-28k-wer-10" \ 31 | --wandb_dir "/home/sanchitgandhi/.cache" \ 32 | --wandb_project "distil-whisper" \ 33 | --do_train \ 34 | --do_eval \ 35 | --use_scan \ 36 | --gradient_checkpointing \ 37 | --overwrite_output_dir \ 38 | --predict_with_generate \ 39 | --streaming \ 40 | --freeze_encoder \ 41 | 
--use_auth_token \ 42 | --push_to_hub 43 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_24_2_timestamped.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \ 4 | --model_name_or_path "distil-whisper/medium-24-2" \ 5 | --teacher_model_name_or_path "openai/whisper-medium.en" \ 6 | --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3" \ 7 | --train_dataset_samples "2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8" \ 8 | --train_dataset_name "librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted" \ 9 | --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train" \ 10 | --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \ 11 | --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \ 12 | --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \ 13 | --eval_text_column_name 
"text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \ 14 | --eval_steps 5000 \ 15 | --save_steps 5000 \ 16 | --warmup_steps 500 \ 17 | --learning_rate 0.0001 \ 18 | --logging_steps 25 \ 19 | --save_total_limit 1 \ 20 | --max_steps 80000 \ 21 | --wer_threshold 10 \ 22 | --per_device_train_batch_size 64 \ 23 | --per_device_eval_batch_size 64 \ 24 | --dtype "bfloat16" \ 25 | --dataloader_num_workers 16 \ 26 | --cache_dir "/home/sanchitgandhi/.cache" \ 27 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 28 | --output_dir "./" \ 29 | --timestamp_probability 0.2 \ 30 | --wandb_name "medium-24-2-tpu-timestamped-prob-0.2" \ 31 | --wandb_dir "/home/sanchitgandhi/.cache" \ 32 | --wandb_project "distil-whisper" \ 33 | --do_train \ 34 | --do_eval \ 35 | --use_scan \ 36 | --gradient_checkpointing \ 37 | --overwrite_output_dir \ 38 | --predict_with_generate \ 39 | --freeze_encoder \ 40 | --streaming \ 41 | --use_auth_token \ 42 | --push_to_hub 43 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_32_2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \ 4 | --model_name_or_path "distil-whisper/large-32-2" \ 5 | --teacher_model_name_or_path "openai/whisper-large-v2" \ 6 | --train_dataset_config_name "all+all+all+l" \ 7 | --train_dataset_samples "100+360+500+2500" \ 8 | --train_dataset_name "librispeech_asr-token-ids+librispeech_asr-token-ids+librispeech_asr-token-ids+gigaspeech-l-token-ids" \ 9 | --train_split_name "train.clean.100+train.clean.360+train.other.500+train" \ 10 | --eval_dataset_name "librispeech_asr+librispeech_asr+gigaspeech-l" \ 11 | --eval_dataset_config_name "all+all+l" \ 12 | 
--eval_split_name "validation.clean+validation.other+validation" \ 13 | --eval_text_column_name "text+text+text" \ 14 | --eval_steps 5000 \ 15 | --save_steps 5000 \ 16 | --warmup_steps 50 \ 17 | --learning_rate 0.0001 \ 18 | --lr_scheduler_type "constant_with_warmup" \ 19 | --logging_steps 25 \ 20 | --save_total_limit 1 \ 21 | --max_steps 10000 \ 22 | --wer_threshold 10 \ 23 | --per_device_train_batch_size 64 \ 24 | --per_device_eval_batch_size 64 \ 25 | --dataloader_num_workers 16 \ 26 | --cache_dir "/home/sanchitgandhi/.cache" \ 27 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 28 | --dtype "bfloat16" \ 29 | --output_dir "./" \ 30 | --wandb_name "large-32-2-ls-gs-token-ids" \ 31 | --wandb_dir "/home/sanchitgandhi/.cache" \ 32 | --wandb_project "distil-whisper" \ 33 | --do_train \ 34 | --do_eval \ 35 | --use_scan \ 36 | --gradient_checkpointing \ 37 | --overwrite_output_dir \ 38 | --predict_with_generate \ 39 | --freeze_encoder \ 40 | --streaming \ 41 | --use_auth_token \ 42 | --push_to_hub 43 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_32_2_by_samples.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \ 4 | --model_name_or_path "distil-whisper/large-32-2" \ 5 | --teacher_model_name_or_path "openai/whisper-large-v2" \ 6 | --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+L" \ 7 | --train_dataset_samples "2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+192.7" \ 8 | --train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech" \ 9 | --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train" \ 10 | 
--eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \ 11 | --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \ 12 | --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \ 13 | --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \ 14 | --eval_steps 5000 \ 15 | --save_steps 5000 \ 16 | --warmup_steps 500 \ 17 | --learning_rate 0.0001 \ 18 | --lr_scheduler_type "linear" \ 19 | --logging_steps 25 \ 20 | --save_total_limit 1 \ 21 | --max_steps 80000 \ 22 | --wer_threshold 10 \ 23 | --per_device_train_batch_size 64 \ 24 | --per_device_eval_batch_size 64 \ 25 | --dataloader_num_workers 16 \ 26 | --cache_dir "/home/sanchitgandhi/.cache" \ 27 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 28 | --dtype "bfloat16" \ 29 | --output_dir "./" \ 30 | --wandb_name "large-32-2-ts-freeze-28k-wer-10-probs-by-samples" \ 31 | --wandb_dir "/home/sanchitgandhi/.cache" \ 32 | --wandb_project "distil-whisper" \ 33 | --do_train \ 34 | --do_eval \ 35 | --use_scan \ 36 | --gradient_checkpointing \ 37 | --overwrite_output_dir \ 38 | --predict_with_generate \ 39 | --freeze_encoder \ 40 | --streaming \ 41 | --use_auth_token \ 42 | --push_to_hub 43 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_32_2_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 
| 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \ 4 | --model_name_or_path "distil-whisper/large-32-2" \ 5 | --teacher_model_name_or_path "openai/whisper-large-v2" \ 6 | --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+all+L" \ 7 | --train_dataset_samples "2.9+10.4+14.9+89+18.2+10.9+10.9+288+371.2+192.7" \ 8 | --train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+switchboard-data+spgispeech" \ 9 | --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train" \ 10 | --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \ 11 | --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \ 12 | --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \ 13 | --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \ 14 | --eval_steps 1250 \ 15 | --save_steps 1250 \ 16 | --warmup_steps 250 \ 17 | --learning_rate 0.0001 \ 18 | --lr_scheduler_type "constant_with_warmup" \ 19 | --logging_steps 25 \ 20 | --save_total_limit 1 \ 21 | --max_steps 20000 \ 22 | --wer_threshold 10 \ 23 | --per_device_train_batch_size 128 \ 24 | --per_device_eval_batch_size 128 \ 25 | --dtype "bfloat16" \ 26 | --precision "full_mixed" \ 27 | --dataloader_num_workers 16 \ 28 | --cache_dir "/fsx/sanchit/.cache" \ 29 | --dataset_cache_dir 
"/fsx/sanchit/.cache" \ 30 | --output_dir "./" \ 31 | --wandb_name "large-32-2-gpu-flat-lr" \ 32 | --wandb_dir "/fsx/sanchit/.cache" \ 33 | --wandb_project "distil-whisper" \ 34 | --do_train \ 35 | --do_eval \ 36 | --use_scan \ 37 | --gradient_checkpointing \ 38 | --overwrite_output_dir \ 39 | --predict_with_generate \ 40 | --freeze_encoder \ 41 | --streaming \ 42 | --use_auth_token \ 43 | --push_to_hub 44 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_32_2_timestamped.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \ 4 | --model_name_or_path "distil-whisper/large-32-2" \ 5 | --teacher_model_name_or_path "openai/whisper-large-v2" \ 6 | --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3" \ 7 | --train_dataset_samples "2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8" \ 8 | --train_dataset_name "librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted" \ 9 | --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train" \ 10 | --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \ 11 | 
--eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \ 12 | --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \ 13 | --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \ 14 | --eval_steps 5000 \ 15 | --save_steps 5000 \ 16 | --warmup_steps 500 \ 17 | --learning_rate 0.0001 \ 18 | --logging_steps 25 \ 19 | --save_total_limit 1 \ 20 | --max_steps 80000 \ 21 | --wer_threshold 10 \ 22 | --per_device_train_batch_size 64 \ 23 | --per_device_eval_batch_size 64 \ 24 | --dtype "bfloat16" \ 25 | --dataloader_num_workers 16 \ 26 | --cache_dir "/home/sanchitgandhi/.cache" \ 27 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 28 | --output_dir "./" \ 29 | --wandb_name "large-32-2-tpu-timestamped" \ 30 | --wandb_dir "/home/sanchitgandhi/.cache" \ 31 | --wandb_project "distil-whisper" \ 32 | --do_train \ 33 | --do_eval \ 34 | --use_scan \ 35 | --gradient_checkpointing \ 36 | --overwrite_output_dir \ 37 | --predict_with_generate \ 38 | --freeze_encoder \ 39 | --streaming \ 40 | --use_auth_token \ 41 | --push_to_hub 42 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_large_32_2_gpu_timestamped.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \ 4 | --model_name_or_path "distil-whisper/large-32-2" \ 5 | --teacher_model_name_or_path "openai/whisper-large-v2" \ 6 | --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3" \ 7 | --train_dataset_samples 
"2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8" \ 8 | --train_dataset_name "librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted" \ 9 | --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train" \ 10 | --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \ 11 | --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \ 12 | --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \ 13 | --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \ 14 | --eval_steps 5000 \ 15 | --save_steps 5000 \ 16 | --warmup_steps 500 \ 17 | --learning_rate 0.0001 \ 18 | --logging_steps 25 \ 19 | --save_total_limit 1 \ 20 | --max_steps 80000 \ 21 | --wer_threshold 10 \ 22 | --per_device_train_batch_size 64 \ 23 | --per_device_eval_batch_size 64 \ 24 | --dtype "bfloat16" \ 25 | --dataloader_num_workers 16 \ 26 | --cache_dir "/fsx/sanchit/.cache" \ 27 | --dataset_cache_dir "/fsx/sanchit/.cache" \ 28 | --output_dir "./" \ 29 | --wandb_name "large-32-2-gpu-timestamped" \ 
30 | --wandb_dir "/fsx/sanchit/.cache" \ 31 | --wandb_project "distil-whisper" \ 32 | --do_train \ 33 | --do_eval \ 34 | --use_scan \ 35 | --gradient_checkpointing \ 36 | --overwrite_output_dir \ 37 | --predict_with_generate \ 38 | --freeze_encoder \ 39 | --streaming \ 40 | --use_auth_token \ 41 | --push_to_hub 42 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_distillation_objective.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --do_train 5 | - --do_eval 6 | - --use_scan 7 | - --gradient_checkpointing 8 | - --overwrite_output_dir 9 | - --predict_with_generate 10 | - --freeze_encoder 11 | - --streaming 12 | - --use_auth_token 13 | - ${args} 14 | method: grid 15 | metric: 16 | goal: minimize 17 | name: gigaspeech-l/validation/wer 18 | parameters: 19 | model_name_or_path: 20 | value: distil-whisper/large-32-2 21 | teacher_model_name_or_path: 22 | value: openai/whisper-large-v2 23 | train_dataset_name: 24 | value: librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted 25 | train_dataset_config_name: 26 | value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3 27 | train_split_name: 28 | value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train 29 | train_dataset_samples: 30 | value: 2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8 31 | eval_dataset_name: 32 | value: 
librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs 33 | eval_dataset_config_name: 34 | value: all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us 35 | eval_split_name: 36 | value: validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation 37 | eval_text_column_name: 38 | value: text+text+text+text+text+text+text+text+text+text+text+text+transcription 39 | cache_dir: 40 | value: /home/sanchitgandhi/.cache 41 | dataset_cache_dir: 42 | value: /home/sanchitgandhi/.cache 43 | output_dir: 44 | value: ./ 45 | per_device_train_batch_size: 46 | value: 64 47 | per_device_eval_batch_size: 48 | value: 64 49 | dtype: 50 | value: bfloat16 51 | learning_rate: 52 | value: 1e-4 53 | lr_scheduler_type: 54 | value: constant_with_warmup 55 | warmup_steps: 56 | value: 50 57 | max_steps: 58 | value: 10000 59 | save_steps: 60 | value: 10001 # don't save checkpoints during sweep 61 | dataloader_num_workers: 62 | value: 48 63 | logging_steps: 64 | value: 25 65 | wer_threshold: 66 | value: 10 67 | kl_weight: 68 | values: 69 | - 0.0 70 | - 1.0 71 | program: run_distillation.py 72 | project: distil-whisper-sweeps 73 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_dropout_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --do_train 5 | - --do_eval 6 | - --use_scan 7 | - --gradient_checkpointing 8 | - --overwrite_output_dir 9 | - --predict_with_generate 10 | - --streaming 11 | - --use_auth_token 12 | - ${args} 13 | method: random 14 | metric: 15 | goal: minimize 16 | name: eval/wer 17 | parameters: 18 | model_name_or_path: 19 | value: distil-whisper/large-32-2 20 | teacher_model_name_or_path: 21 | value: 
openai/whisper-large-v2 22 | train_dataset_name: 23 | value: librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech 24 | train_dataset_config_name: 25 | value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+L 26 | train_split_name: 27 | value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train 28 | train_dataset_samples: 29 | value: 100+360+500+2300+450+90+90+12000+450+3600+2500+5000 30 | eval_dataset_name: 31 | value: "distil-whisper/gigaspeech-l" 32 | eval_dataset_config_name: 33 | value: "l" 34 | cache_dir: 35 | value: /home/sanchitgandhi/cache 36 | dataset_cache_dir: 37 | value: /home/sanchitgandhi/cache 38 | output_dir: 39 | value: ./ 40 | per_device_train_batch_size: 41 | value: 32 42 | per_device_eval_batch_size: 43 | value: 64 44 | dtype: 45 | value: bfloat16 46 | learning_rate: 47 | value: 1e-4 48 | lr_scheduler_type: 49 | value: constant_with_warmup 50 | warmup_steps: 51 | value: 50 52 | max_steps: 53 | value: 1000 54 | eval_steps: 55 | value: 1000 56 | save_steps: 57 | value: 1000 58 | dataloader_num_workers: 59 | value: 16 60 | logging_steps: 61 | value: 5 62 | wer_threshold: 63 | value: 10 64 | activation_dropout: 65 | values: 66 | - 0 67 | - 0.05 68 | - 0.1 69 | attention_dropout: 70 | values: 71 | - 0 72 | - 0.05 73 | - 0.1 74 | dropout: 75 | values: 76 | - 0 77 | - 0.05 78 | - 0.1 79 | freeze_encoder: 80 | values: 81 | - true 82 | - false 83 | program: run_distillation.py 84 | project: distil-whisper-sweeps 85 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_librispeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python run_distillation.py \ 4 | --model_name_or_path "distil-whisper/large-32-2" \ 5 | 
--teacher_model_name_or_path "openai/whisper-large-v2" \ 6 | --dataset_name "distil-whisper/librispeech_asr" \ 7 | --dataset_config_name "all" \ 8 | --train_split_name "train.clean.100+train.clean.360+train.other.500" \ 9 | --eval_split_name "validation.clean" \ 10 | --text_column_name "whisper_transcript" \ 11 | --cache_dir "/home/sanchitgandhi/cache" \ 12 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 13 | --output_dir "./" \ 14 | --wandb_name "large-32-2-ts-librispeech" \ 15 | --wandb_dir "/home/sanchitgandhi/.cache" \ 16 | --wandb_project "distil-whisper-librispeech" \ 17 | --per_device_train_batch_size 32 \ 18 | --per_device_eval_batch_size 16 \ 19 | --dtype "bfloat16" \ 20 | --learning_rate 1e-4 \ 21 | --warmup_steps 500 \ 22 | --temperature 2.0 \ 23 | --do_train \ 24 | --do_eval \ 25 | --num_train_epochs 10 \ 26 | --preprocessing_num_workers 16 \ 27 | --dataloader_num_workers 8 \ 28 | --logging_steps 25 \ 29 | --use_scan \ 30 | --gradient_checkpointing \ 31 | --overwrite_output_dir \ 32 | --predict_with_generate \ 33 | --push_to_hub 34 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_librispeech_dummy_pt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | accelerate launch --mixed_precision=bf16 --num_processes=1 run_distillation_pt.py \ 4 | --model_name_or_path "distil-whisper/tiny-random-whisper-2-1" \ 5 | --teacher_model_name_or_path "distil-whisper/tiny-random-whisper" \ 6 | --train_dataset_name "distil-whisper/librispeech_asr_dummy" \ 7 | --train_dataset_config_name "clean" \ 8 | --train_dataset_samples "100" \ 9 | --train_split_name "validation" \ 10 | --eval_dataset_name "distil-whisper/librispeech_asr_dummy" \ 11 | --eval_dataset_config_name "clean" \ 12 | --eval_split_name "validation" \ 13 | --eval_text_column_name "text" \ 14 | --cache_dir "/home/sanchit/.cache" \ 15 | --dataset_cache_dir 
"/home/sanchit/.cache" \ 16 | --wandb_project "distil-whisper-debug" \ 17 | --output_dir "./" \ 18 | --do_train \ 19 | --do_eval \ 20 | --learning_rate 1e-4 \ 21 | --warmup_steps 25 \ 22 | --per_device_train_batch_size 8 \ 23 | --per_device_eval_batch_size 8 \ 24 | --gradient_checkpointing \ 25 | --max_steps 100 \ 26 | --eval_steps 50 \ 27 | --save_steps 50 \ 28 | --dataloader_num_workers 14 \ 29 | --wer_threshold 10 \ 30 | --logging_steps 5 \ 31 | --overwrite_output_dir \ 32 | --dtype bfloat16 \ 33 | --predict_with_generate \ 34 | --freeze_encoder \ 35 | --streaming False 36 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_librispeech_streaming_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_distillation.py \ 4 | --model_name_or_path "distil-whisper/tiny-random-whisper-2-1" \ 5 | --teacher_model_name_or_path "distil-whisper/tiny-random-whisper" \ 6 | --train_dataset_name "distil-whisper/librispeech_asr+distil-whisper/librispeech_asr-timestamped" \ 7 | --train_dataset_config_name "all+all" \ 8 | --train_dataset_samples "100+360" \ 9 | --train_split_name "train.clean.100+train.clean.360" \ 10 | --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset" \ 11 | --eval_dataset_config_name "l+librispeech" \ 12 | --eval_split_name "validation+clean" \ 13 | --eval_text_column_name "text+ortho_transcript" \ 14 | --max_train_samples 1024 \ 15 | --max_eval_samples 32 \ 16 | --cache_dir "/home/sanchitgandhi/.cache" \ 17 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 18 | --wandb_dir "/home/sanchitgandhi/.cache" \ 19 | --output_dir "./" \ 20 | --do_train \ 21 | --do_eval \ 22 | --per_device_train_batch_size 2 \ 23 | --per_device_eval_batch_size 2 \ 24 | --max_steps 10 \ 25 | --eval_steps 5 \ 26 | --dataloader_num_workers 14 \ 27 | --save_steps 5 \ 28 | --wer_threshold 10 \ 29 | --wandb_project 
"distil-whisper-debug" \ 30 | --logging_steps 1 \ 31 | --use_scan \ 32 | --gradient_checkpointing \ 33 | --overwrite_output_dir \ 34 | --predict_with_generate \ 35 | --return_timestamps \ 36 | --timestamp_probability 1 \ 37 | --freeze_encoder 38 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_lr_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --do_train 5 | - --do_eval 6 | - --use_scan 7 | - --gradient_checkpointing 8 | - --overwrite_output_dir 9 | - --predict_with_generate 10 | - --freeze_encoder 11 | - --streaming 12 | - --use_auth_token 13 | - --compilation_cache 14 | - --load_with_scan_weights # checkpoint is saved with scan weights 15 | - ${args} 16 | method: grid 17 | metric: 18 | goal: minimize 19 | name: eval/wer 20 | parameters: 21 | model_name_or_path: 22 | value: distil-whisper/large-32-2-ts-freeze-librispeech # resume from a partially trained checkpoint 23 | teacher_model_name_or_path: 24 | value: openai/whisper-large-v2 25 | train_dataset_name: 26 | value: librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech 27 | train_dataset_config_name: 28 | value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+L 29 | train_split_name: 30 | value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train 31 | train_dataset_samples: 32 | value: 100+360+500+2300+450+90+90+12000+450+3600+2500+5000 33 | eval_dataset_name: 34 | value: "distil-whisper/gigaspeech-l" 35 | eval_dataset_config_name: 36 | value: "l" 37 | cache_dir: 38 | value: /fsx/sanchit/cache 39 | dataset_cache_dir: 40 | value: /fsx/sanchit/cache 41 | output_dir: 42 | value: ./ 43 | per_device_train_batch_size: 44 | value: 128 45 | per_device_eval_batch_size: 46 | value: 128 47 | dtype: 48 
| value: bfloat16 49 | learning_rate: 50 | values: 51 | - 1e-3 52 | - 3e-4 53 | - 1e-4 54 | - 3e-5 55 | - 1e-5 56 | lr_scheduler_type: 57 | value: constant_with_warmup 58 | warmup_steps: 59 | value: 50 60 | max_steps: 61 | value: 500 62 | eval_steps: 63 | value: 500 64 | save_steps: 65 | value: 501 # don't save checkpoints during sweep 66 | dataloader_num_workers: 67 | value: 16 68 | logging_steps: 69 | value: 5 70 | wer_threshold: 71 | value: 10 72 | program: run_distillation.py 73 | project: distil-whisper-sweeps 74 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_mse_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --do_train 5 | - --do_eval 6 | - --gradient_checkpointing 7 | - --overwrite_output_dir 8 | - --predict_with_generate 9 | - --streaming 10 | - --use_auth_token 11 | - --use_scan 12 | - ${args} 13 | method: grid 14 | metric: 15 | goal: minimize 16 | name: eval/wer 17 | parameters: 18 | model_name_or_path: 19 | value: distil-whisper/large-16-2 20 | teacher_model_name_or_path: 21 | value: openai/whisper-large-v2 22 | train_dataset_name: 23 | value: librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech 24 | train_dataset_config_name: 25 | value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+L 26 | train_split_name: 27 | value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train 28 | train_dataset_samples: 29 | value: 100+360+500+2300+450+90+90+12000+450+3600+2500+5000 30 | eval_dataset_name: 31 | value: "distil-whisper/gigaspeech-l" 32 | eval_dataset_config_name: 33 | value: "l" 34 | cache_dir: 35 | value: /home/sanchitgandhi/cache 36 | dataset_cache_dir: 37 | value: /home/sanchitgandhi/cache 38 | output_dir: 39 | value: ./ 40 | 
per_device_train_batch_size: 41 | value: 32 42 | per_device_eval_batch_size: 43 | value: 64 44 | dtype: 45 | value: bfloat16 46 | learning_rate: 47 | value: 0.0001 48 | lr_scheduler_type: 49 | value: constant_with_warmup 50 | warmup_steps: 51 | value: 50 52 | max_steps: 53 | value: 2500 54 | eval_steps: 55 | value: 2500 56 | save_steps: 57 | value: 2001 # don't save checkpoints during sweep 58 | dataloader_num_workers: 59 | value: 16 60 | logging_steps: 61 | value: 5 62 | wer_threshold: 63 | value: 10 64 | mse_weight: 65 | values: 66 | - 0.0 67 | - 0.3 68 | - 1 69 | - 3 70 | program: run_distillation.py 71 | project: distil-whisper-sweeps 72 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_timestamp_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --do_train 5 | - --do_eval 6 | - --use_scan 7 | - --gradient_checkpointing 8 | - --overwrite_output_dir 9 | - --predict_with_generate 10 | - --freeze_encoder 11 | - --streaming 12 | - --use_auth_token 13 | - --compilation_cache 14 | - --return_timestamps 15 | - ${args} 16 | method: grid 17 | metric: 18 | goal: minimize 19 | name: eval/wer 20 | parameters: 21 | model_name_or_path: 22 | value: distil-whisper/large-32-2 23 | teacher_model_name_or_path: 24 | value: openai/whisper-large-v2 25 | train_dataset_name: 26 | value: librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+spgispeech-timestamped 27 | train_dataset_config_name: 28 | value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+L 29 | train_split_name: 30 | value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train 31 | 
train_dataset_samples: 32 | value: 2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+192.7 33 | timestamp_probability: 34 | values: 35 | - 0.0 36 | - 0.2 37 | - 0.4 38 | - 0.6 39 | - 0.8 40 | - 1.0 41 | round_timestamps: 42 | values: 43 | - True 44 | - False 45 | eval_dataset_name: 46 | value: "distil-whisper/gigaspeech-l" 47 | eval_dataset_config_name: 48 | value: "l" 49 | cache_dir: 50 | value: /home/sanchitgandhi/.cache 51 | dataset_cache_dir: 52 | value: /home/sanchitgandhi/.cache 53 | output_dir: 54 | value: ./ 55 | per_device_train_batch_size: 56 | value: 64 57 | dtype: 58 | value: bfloat16 59 | learning_rate: 60 | value: 1e-4 61 | lr_scheduler_type: 62 | value: constant_with_warmup 63 | warmup_steps: 64 | value: 50 65 | max_steps: 66 | value: 2500 67 | save_steps: 68 | value: 2501 # don't save checkpoints during sweep 69 | dataloader_num_workers: 70 | value: 48 71 | logging_steps: 72 | value: 25 73 | wer_threshold: 74 | value: 10 75 | program: run_distillation.py 76 | project: distil-whisper-sweeps 77 | -------------------------------------------------------------------------------- /training/flax/distillation_scripts/run_wer_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --do_train 5 | - --do_eval 6 | - --use_scan 7 | - --gradient_checkpointing 8 | - --overwrite_output_dir 9 | - --predict_with_generate 10 | - --freeze_encoder 11 | - --streaming 12 | - --use_auth_token 13 | - ${args} 14 | method: grid 15 | metric: 16 | goal: minimize 17 | name: gigaspeech-l/validation/wer 18 | parameters: 19 | model_name_or_path: 20 | value: distil-whisper/large-32-2 21 | teacher_model_name_or_path: 22 | value: openai/whisper-large-v2 23 | train_dataset_name: 24 | value: 
librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted 25 | train_dataset_config_name: 26 | value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3 27 | train_split_name: 28 | value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train 29 | train_dataset_samples: 30 | value: 2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8 31 | eval_dataset_name: 32 | value: librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs 33 | eval_dataset_config_name: 34 | value: all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us 35 | eval_split_name: 36 | value: validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation 37 | eval_text_column_name: 38 | value: text+text+text+text+text+text+text+text+text+text+text+text+transcription 39 | cache_dir: 40 | value: /home/sanchitgandhi/.cache 41 | dataset_cache_dir: 42 | value: /home/sanchitgandhi/.cache 43 | output_dir: 44 | value: ./ 45 | per_device_train_batch_size: 46 | value: 64 47 | per_device_eval_batch_size: 48 | value: 64 49 | dtype: 50 | value: bfloat16 51 | learning_rate: 52 | value: 1e-4 53 | lr_scheduler_type: 54 | value: constant_with_warmup 55 | warmup_steps: 56 | value: 50 57 | max_steps: 58 | value: 10000 59 | save_steps: 60 | value: 10001 # don't save checkpoints during sweep 61 | dataloader_num_workers: 62 | value: 48 63 | logging_steps: 64 | value: 25 65 | wer_threshold: 66 
| values: 67 | - 100 68 | - 20 69 | - 15 70 | - 10 71 | - 5 72 | program: run_distillation.py 73 | project: distil-whisper-sweeps 74 | -------------------------------------------------------------------------------- /training/flax/evaluation_scripts/run_baselines.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_eval.py \ 4 | --model_name_or_path "openai/whisper-tiny.en" \ 5 | --dataset_name "librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs+sanchit-gandhi/earnings22_split_resampled" \ 6 | --dataset_config_name "all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us+default" \ 7 | --dataset_split_name "validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation" \ 8 | --text_column_name "text+text+text+text+text+text+text+text+text+text+text+text+transcription+sentence" \ 9 | --cache_dir "/home/sanchitgandhi/.cache" \ 10 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 11 | --output_dir "./" \ 12 | --wandb_dir "/home/sanchitgandhi/.cache" \ 13 | --wandb_project "distil-whisper-eval" \ 14 | --wandb_name "tiny.en" \ 15 | --per_device_eval_batch_size 32 \ 16 | --dtype "bfloat16" \ 17 | --dataloader_num_workers 0 \ 18 | --report_to "wandb" \ 19 | --streaming \ 20 | --predict_with_generate 21 | 22 | python run_eval.py \ 23 | --model_name_or_path "openai/whisper-base.en" \ 24 | --dataset_name "librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs+sanchit-gandhi/earnings22_split_resampled" \ 25 | --dataset_config_name "all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us+default" \ 26 | --dataset_split_name 
"validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation" \ 27 | --text_column_name "text+text+text+text+text+text+text+text+text+text+text+text+transcription+sentence" \ 28 | --cache_dir "/home/sanchitgandhi/.cache" \ 29 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 30 | --output_dir "./" \ 31 | --wandb_dir "/home/sanchitgandhi/.cache" \ 32 | --wandb_project "distil-whisper-eval" \ 33 | --wandb_name "base.en" \ 34 | --per_device_eval_batch_size 32 \ 35 | --dtype "bfloat16" \ 36 | --dataloader_num_workers 0 \ 37 | --report_to "wandb" \ 38 | --streaming \ 39 | --predict_with_generate 40 | 41 | python run_eval.py \ 42 | --model_name_or_path "openai/whisper-small.en" \ 43 | --dataset_name "librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs+sanchit-gandhi/earnings22_split_resampled" \ 44 | --dataset_config_name "all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us+default" \ 45 | --dataset_split_name "validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation" \ 46 | --text_column_name "text+text+text+text+text+text+text+text+text+text+text+text+transcription+sentence" \ 47 | --cache_dir "/home/sanchitgandhi/.cache" \ 48 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 49 | --output_dir "./" \ 50 | --wandb_dir "/home/sanchitgandhi/.cache" \ 51 | --wandb_project "distil-whisper-eval" \ 52 | --wandb_name "small.en" \ 53 | --per_device_eval_batch_size 32 \ 54 | --dtype "bfloat16" \ 55 | --dataloader_num_workers 0 \ 56 | --report_to "wandb" \ 57 | --streaming \ 58 | --predict_with_generate 59 | 60 | python run_eval.py \ 61 | --model_name_or_path "openai/whisper-medium.en" \ 62 | --dataset_name 
"librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs+sanchit-gandhi/earnings22_split_resampled" \ 63 | --dataset_config_name "all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us+default" \ 64 | --dataset_split_name "validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation" \ 65 | --text_column_name "text+text+text+text+text+text+text+text+text+text+text+text+transcription+sentence" \ 66 | --cache_dir "/home/sanchitgandhi/.cache" \ 67 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 68 | --output_dir "./" \ 69 | --wandb_dir "/home/sanchitgandhi/.cache" \ 70 | --wandb_project "distil-whisper-eval" \ 71 | --wandb_name "medium.en" \ 72 | --per_device_eval_batch_size 32 \ 73 | --dtype "bfloat16" \ 74 | --dataloader_num_workers 0 \ 75 | --report_to "wandb" \ 76 | --streaming \ 77 | --predict_with_generate 78 | 79 | python run_eval.py \ 80 | --model_name_or_path "openai/whisper-large-v2" \ 81 | --dataset_name "librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs+sanchit-gandhi/earnings22_split_resampled" \ 82 | --dataset_config_name "all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us+default" \ 83 | --dataset_split_name "validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation" \ 84 | --text_column_name "text+text+text+text+text+text+text+text+text+text+text+text+transcription+sentence" \ 85 | --cache_dir "/home/sanchitgandhi/.cache" \ 86 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 87 | --output_dir "./" \ 88 | --wandb_dir "/home/sanchitgandhi/.cache" \ 89 | --wandb_project "distil-whisper-eval" \ 90 | 
--wandb_name "large-v2" \ 91 | --per_device_eval_batch_size 16 \ 92 | --dtype "bfloat16" \ 93 | --dataloader_num_workers 0 \ 94 | --report_to "wandb" \ 95 | --streaming \ 96 | --predict_with_generate -------------------------------------------------------------------------------- /training/flax/evaluation_scripts/run_distilled.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_eval.py \ 4 | --model_name_or_path "sanchit-gandhi/large-32-2-ts-freeze-28k-wer-10" \ 5 | --subfolder "checkpoint-15000" \ 6 | --dataset_name "librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs+sanchit-gandhi/earnings22_split_resampled" \ 7 | --dataset_config_name "all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us+default" \ 8 | --dataset_split_name "validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation" \ 9 | --text_column_name "text+text+text+text+text+text+text+text+text+text+text+text+transcription+sentence" \ 10 | --cache_dir "/home/sanchitgandhi/.cache" \ 11 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 12 | --output_dir "./" \ 13 | --wandb_dir "/home/sanchitgandhi/.cache" \ 14 | --wandb_project "distil-whisper-eval" \ 15 | --wandb_name "large-32-2-ts-freeze-28k-wer-10-30k-steps" \ 16 | --per_device_eval_batch_size 64 \ 17 | --dtype "bfloat16" \ 18 | --dataloader_num_workers 0 \ 19 | --report_to "wandb" \ 20 | --streaming \ 21 | --predict_with_generate 22 | -------------------------------------------------------------------------------- /training/flax/evaluation_scripts/run_distilled_16_2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_eval.py \ 4 | --model_name_or_path 
"sanchit-gandhi/large-16-2-ts-28k-wer-10" \ 5 | --subfolder "checkpoint-10000" \ 6 | --dataset_name "librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs" \ 7 | --dataset_config_name "all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us" \ 8 | --dataset_split_name "validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation" \ 9 | --text_column_name "text+text+text+text+text+text+text+text+text+text+text+text+transcription" \ 10 | --cache_dir "/home/sanchitgandhi/.cache" \ 11 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 12 | --output_dir "./" \ 13 | --wandb_dir "/home/sanchitgandhi/.cache" \ 14 | --wandb_project "distil-whisper-eval" \ 15 | --wandb_name "large-16-2-eval" \ 16 | --per_device_eval_batch_size 64 \ 17 | --dtype "bfloat16" \ 18 | --dataloader_num_workers 0 \ 19 | --report_to "wandb" \ 20 | --streaming \ 21 | --predict_with_generate 22 | -------------------------------------------------------------------------------- /training/flax/evaluation_scripts/run_librispeech_eval_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_eval.py \ 4 | --model_name_or_path "openai/whisper-large-v2" \ 5 | --dataset_name "gigaspeech-l+gigaspeech-l" \ 6 | --dataset_config_name "l+l" \ 7 | --dataset_split_name "train+validation" \ 8 | --text_column_name "text" \ 9 | --cache_dir "/home/sanchitgandhi/.cache" \ 10 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 11 | --output_dir "./" \ 12 | --wandb_dir "/home/sanchitgandhi/.cache" \ 13 | --wandb_project "distil-whisper-label" \ 14 | --wandb_name "whisper-large-v2-gigaspeech-l-with-audio" \ 15 | --per_device_eval_batch_size 64 \ 16 | --dtype "bfloat16" \ 17 | --dataloader_num_workers 0 \ 18 | --report_to "wandb" \ 19 | 
--streaming \ 20 | --max_eval_samples 1024 \ 21 | --predict_with_generate \ 22 | --log_audio 23 | -------------------------------------------------------------------------------- /training/flax/evaluation_scripts/test/run_baselines.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_NAMES="librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs+earnings22" 4 | DATASET_CONFIG_NAMES="all+all+en+en+ihm+sdm+clean+release3+all+all+l+L+1-channel+en_us+chunked" 5 | DATASET_SPLIT_NAMES="test.clean+test.other+test+test+test+test+test+test+test.switchboard+test.callhome+test+test+test+test+test" 6 | TEXT_COLUMN_NAMES="text+text+text+text+text+text+text+text+text+text+text+text+text+transcription+transcription" 7 | 8 | python run_eval.py \ 9 | --model_name_or_path "openai/whisper-tiny.en" \ 10 | --dataset_name $DATASET_NAMES \ 11 | --dataset_config_name $DATASET_CONFIG_NAMES \ 12 | --dataset_split_name $DATASET_SPLIT_NAMES \ 13 | --text_column_name $TEXT_COLUMN_NAMES \ 14 | --cache_dir "/home/sanchitgandhi/.cache" \ 15 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 16 | --output_dir "./" \ 17 | --wandb_dir "/home/sanchitgandhi/.cache" \ 18 | --wandb_project "distil-whisper-test" \ 19 | --wandb_name "tiny.en" \ 20 | --per_device_eval_batch_size 32 \ 21 | --dtype "bfloat16" \ 22 | --dataloader_num_workers 0 \ 23 | --report_to "wandb" \ 24 | --streaming \ 25 | --predict_with_generate 26 | 27 | python run_eval.py \ 28 | --model_name_or_path "openai/whisper-base.en" \ 29 | --dataset_name $DATASET_NAMES \ 30 | --dataset_config_name $DATASET_CONFIG_NAMES \ 31 | --dataset_split_name $DATASET_SPLIT_NAMES \ 32 | --text_column_name $TEXT_COLUMN_NAMES \ 33 | --cache_dir "/home/sanchitgandhi/.cache" \ 34 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 35 | --output_dir "./" \ 36 | 
--wandb_dir "/home/sanchitgandhi/.cache" \ 37 | --wandb_project "distil-whisper-test" \ 38 | --wandb_name "base.en" \ 39 | --per_device_eval_batch_size 32 \ 40 | --dtype "bfloat16" \ 41 | --dataloader_num_workers 0 \ 42 | --report_to "wandb" \ 43 | --streaming \ 44 | --predict_with_generate 45 | 46 | python run_eval.py \ 47 | --model_name_or_path "openai/whisper-small.en" \ 48 | --dataset_name $DATASET_NAMES \ 49 | --dataset_config_name $DATASET_CONFIG_NAMES \ 50 | --dataset_split_name $DATASET_SPLIT_NAMES \ 51 | --text_column_name $TEXT_COLUMN_NAMES \ 52 | --cache_dir "/home/sanchitgandhi/.cache" \ 53 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 54 | --output_dir "./" \ 55 | --wandb_dir "/home/sanchitgandhi/.cache" \ 56 | --wandb_project "distil-whisper-test" \ 57 | --wandb_name "small.en" \ 58 | --per_device_eval_batch_size 32 \ 59 | --dtype "bfloat16" \ 60 | --dataloader_num_workers 0 \ 61 | --report_to "wandb" \ 62 | --streaming \ 63 | --predict_with_generate 64 | 65 | python run_eval.py \ 66 | --model_name_or_path "openai/whisper-medium.en" \ 67 | --dataset_name $DATASET_NAMES \ 68 | --dataset_config_name $DATASET_CONFIG_NAMES \ 69 | --dataset_split_name $DATASET_SPLIT_NAMES \ 70 | --text_column_name $TEXT_COLUMN_NAMES \ 71 | --cache_dir "/home/sanchitgandhi/.cache" \ 72 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 73 | --output_dir "./" \ 74 | --wandb_dir "/home/sanchitgandhi/.cache" \ 75 | --wandb_project "distil-whisper-test" \ 76 | --wandb_name "medium.en" \ 77 | --per_device_eval_batch_size 32 \ 78 | --dtype "bfloat16" \ 79 | --dataloader_num_workers 0 \ 80 | --report_to "wandb" \ 81 | --streaming \ 82 | --predict_with_generate 83 | 84 | python run_eval.py \ 85 | --model_name_or_path "openai/whisper-large-v2" \ 86 | --dataset_name $DATASET_NAMES \ 87 | --dataset_config_name $DATASET_CONFIG_NAMES \ 88 | --dataset_split_name $DATASET_SPLIT_NAMES \ 89 | --text_column_name $TEXT_COLUMN_NAMES \ 90 | --cache_dir "/home/sanchitgandhi/.cache" \ 91 
| --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 92 | --output_dir "./" \ 93 | --wandb_dir "/home/sanchitgandhi/.cache" \ 94 | --wandb_project "distil-whisper-test" \ 95 | --wandb_name "large-v2" \ 96 | --per_device_eval_batch_size 16 \ 97 | --dtype "bfloat16" \ 98 | --dataloader_num_workers 0 \ 99 | --report_to "wandb" \ 100 | --streaming \ 101 | --predict_with_generate 102 | -------------------------------------------------------------------------------- /training/flax/evaluation_scripts/test/run_baselines_pt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_NAMES="librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs+earnings22" 4 | DATASET_CONFIG_NAMES="all+all+en+en+ihm+sdm+clean+release3+all+all+l+L+1-channel+en_us+chunked" 5 | DATASET_SPLIT_NAMES="test.clean+test.other+test+test+test+test+test+test+test.switchboard+test.callhome+test+test+test+test+test" 6 | TEXT_COLUMN_NAMES="text+text+text+text+text+text+text+text+text+text+text+text+text+transcription+transcription" 7 | 8 | python run_pt_long_form_transcription.py \ 9 | --model_name_or_path "facebook/wav2vec2-large-960h" \ 10 | --wandb_name "facebook/wav2vec2-large-960h" \ 11 | --dataset_name $DATASET_NAMES \ 12 | --dataset_config_name $DATASET_CONFIG_NAMES \ 13 | --dataset_split_name $DATASET_SPLIT_NAMES \ 14 | --text_column_name $TEXT_COLUMN_NAMES \ 15 | --output_dir "./" \ 16 | --wandb_project "distil-whisper-test" \ 17 | --per_device_eval_batch_size 32 \ 18 | --dtype "float16" \ 19 | --dataloader_num_workers 0 \ 20 | --report_to "wandb" \ 21 | --streaming \ 22 | --predict_with_generate 23 | -------------------------------------------------------------------------------- /training/flax/evaluation_scripts/test/run_distilled.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_NAMES="librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs+earnings22" 4 | DATASET_CONFIG_NAMES="all+all+en+en+ihm+sdm+clean+release3+all+all+l+L+1-channel+en_us+chunked" 5 | DATASET_SPLIT_NAMES="test.clean+test.other+test+test+test+test+test+test+test.switchboard+test.callhome+test+test+test+test+test" 6 | TEXT_COLUMN_NAMES="text+text+text+text+text+text+text+text+text+text+text+text+text+transcription+transcription" 7 | 8 | python run_eval.py \ 9 | --model_name_or_path "sanchit-gandhi/large-32-2-tpu-timestamped-resumed" \ 10 | --wandb_name "large-32-2-tpu-timestamped" \ 11 | --dataset_name $DATASET_NAMES \ 12 | --dataset_config_name $DATASET_CONFIG_NAMES \ 13 | --dataset_split_name $DATASET_SPLIT_NAMES \ 14 | --text_column_name $TEXT_COLUMN_NAMES \ 15 | --cache_dir "/home/sanchitgandhi/.cache" \ 16 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 17 | --output_dir "./" \ 18 | --wandb_dir "/home/sanchitgandhi/.cache" \ 19 | --wandb_project "distil-whisper-test" \ 20 | --per_device_eval_batch_size 32 \ 21 | --dtype "bfloat16" \ 22 | --dataloader_num_workers 0 \ 23 | --report_to "wandb" \ 24 | --streaming \ 25 | --predict_with_generate 26 | 27 | 28 | python run_eval.py \ 29 | --model_name_or_path "sanchit-gandhi/medium-24-2-tpu-timestamped-prob-0.2" \ 30 | --subfolder "checkpoint-45000" \ 31 | --wandb_name "medium-24-2-tpu-timestamped-prob-0.2" \ 32 | --dataset_name $DATASET_NAMES \ 33 | --dataset_config_name $DATASET_CONFIG_NAMES \ 34 | --dataset_split_name $DATASET_SPLIT_NAMES \ 35 | --text_column_name $TEXT_COLUMN_NAMES \ 36 | --cache_dir "/home/sanchitgandhi/.cache" \ 37 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 38 | --output_dir "./" \ 39 | --wandb_dir "/home/sanchitgandhi/.cache" \ 40 | 
--wandb_project "distil-whisper-test" \ 41 | --per_device_eval_batch_size 32 \ 42 | --dtype "bfloat16" \ 43 | --dataloader_num_workers 0 \ 44 | --report_to "wandb" \ 45 | --streaming \ 46 | --predict_with_generate 47 | -------------------------------------------------------------------------------- /training/flax/finetuning_scripts/run_librispeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_finetuning.py \ 4 | --model_name_or_path "distil-whisper/large-32-2" \ 5 | --dataset_name "distil-whisper/librispeech_asr" \ 6 | --dataset_config_name "all" \ 7 | --train_split_name "train.clean.100+train.clean.360+train.other.500" \ 8 | --eval_split_name "validation.clean" \ 9 | --text_column_name "whisper_transcript" \ 10 | --cache_dir "/home/sanchitgandhi/cache" \ 11 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 12 | --output_dir "./" \ 13 | --wandb_name "large-32-2-pl-librispeech" \ 14 | --wandb_dir "/home/sanchitgandhi/.cache" \ 15 | --wandb_project "distil-whisper-librispeech" \ 16 | --per_device_train_batch_size 32 \ 17 | --per_device_eval_batch_size 16 \ 18 | --dtype "bfloat16" \ 19 | --learning_rate 1e-4 \ 20 | --warmup_steps 500 \ 21 | --do_train \ 22 | --do_eval \ 23 | --num_train_epochs 10 \ 24 | --preprocessing_num_workers 16 \ 25 | --dataloader_num_workers 8 \ 26 | --logging_steps 25 \ 27 | --use_scan \ 28 | --gradient_checkpointing \ 29 | --overwrite_output_dir \ 30 | --predict_with_generate \ 31 | --push_to_hub 32 | -------------------------------------------------------------------------------- /training/flax/finetuning_scripts/run_librispeech_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_finetuning.py \ 4 | --model_name_or_path "distil-whisper/tiny-random-whisper" \ 5 | --dataset_name "distil-whisper/librispeech_asr" \ 6 | --dataset_config_name "all" \ 7 | --train_split_name 
"train.clean.100[:1024]" \ 8 | --eval_split_name "validation.clean[:1024]" \ 9 | --cache_dir "/home/sanchitgandhi/cache" \ 10 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 11 | --wandb_dir "/home/sanchitgandhi/.cache" \ 12 | --text_column_name "text" \ 13 | --output_dir "./" \ 14 | --do_train \ 15 | --do_eval \ 16 | --per_device_train_batch_size 8 \ 17 | --per_device_eval_batch_size 4 \ 18 | --dtype "bfloat16" \ 19 | --num_train_epochs 2 \ 20 | --dataloader_num_workers 16 \ 21 | --freeze_encoder \ 22 | --wandb_project "distil-whisper-debug" \ 23 | --logging_steps 2 \ 24 | --use_scan \ 25 | --gradient_checkpointing \ 26 | --overwrite_output_dir \ 27 | --predict_with_generate 28 | -------------------------------------------------------------------------------- /training/flax/finetuning_scripts/run_librispeech_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_eval.py \ 4 | --model_name_or_path "./" \ 5 | --dataset_name "distil-whisper/librispeech_asr" \ 6 | --dataset_config_name "all" \ 7 | --test_split_name "validation.clean+validation.other+test.clean+test.other" \ 8 | --text_column_name "text" \ 9 | --cache_dir "/home/sanchitgandhi/cache" \ 10 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 11 | --output_dir "./" \ 12 | --wandb_name "large-32-2-pl-freeze-librispeech-eval" \ 13 | --wandb_dir "/home/sanchitgandhi/.cache" \ 14 | --wandb_project "distil-whisper-librispeech" \ 15 | --per_device_eval_batch_size 128 \ 16 | --dtype "bfloat16" \ 17 | --do_predict \ 18 | --preprocessing_num_workers 16 \ 19 | --dataloader_num_workers 8 \ 20 | --load_with_scan \ 21 | --predict_with_generate \ 22 | --report_to "wandb" 23 | -------------------------------------------------------------------------------- /training/flax/finetuning_scripts/run_librispeech_eval_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python 
run_eval.py \ 4 | --model_name_or_path "./" \ 5 | --dataset_name "distil-whisper/librispeech_asr" \ 6 | --dataset_config_name "all" \ 7 | --test_split_name "validation.clean[:32]+validation.other[:32]" \ 8 | --text_column_name "text" \ 9 | --cache_dir "/home/sanchitgandhi/cache" \ 10 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 11 | --output_dir "./" \ 12 | --wandb_dir "/home/sanchitgandhi/.cache" \ 13 | --wandb_project "distil-whisper-debug" \ 14 | --per_device_eval_batch_size 4 \ 15 | --dtype "bfloat16" \ 16 | --do_predict \ 17 | --preprocessing_num_workers 16 \ 18 | --dataloader_num_workers 8 \ 19 | --load_with_scan \ 20 | --predict_with_generate \ 21 | --report_to "wandb" 22 | -------------------------------------------------------------------------------- /training/flax/finetuning_scripts/run_librispeech_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --do_train 5 | - --do_eval 6 | - --use_scan 7 | - --gradient_checkpointing 8 | - --overwrite_output_dir 9 | - --predict_with_generate 10 | - ${args} 11 | method: random 12 | metric: 13 | goal: minimize 14 | name: eval/wer 15 | parameters: 16 | model_name_or_path: 17 | value: distil-whisper/large-32-2 18 | dataset_name: 19 | value: distil-whisper/librispeech_asr 20 | dataset_config_name: 21 | value: all 22 | train_split_name: 23 | value: train.clean.100+train.clean.360+train.other.500 24 | eval_split_name: 25 | value: validation.clean 26 | text_column_name: 27 | value: whisper_transcript 28 | cache_dir: 29 | value: /home/sanchitgandhi/cache 30 | dataset_cache_dir: 31 | value: /home/sanchitgandhi/cache 32 | output_dir: 33 | value: ./ 34 | per_device_train_batch_size: 35 | value: 32 36 | per_device_eval_batch_size: 37 | value: 16 38 | dtype: 39 | value: bfloat16 40 | learning_rate: 41 | distribution: log_uniform 42 | max: -6.91 43 | min: -11.51 44 | warmup_steps: 45 | value: 500 46 | num_train_epochs: 47 | value: 1
48 | preprocessing_num_workers: 49 | value: 16 50 | dataloader_num_workers: 51 | value: 16 52 | logging_steps: 53 | value: 25 54 | freeze_encoder: 55 | values: 56 | - True 57 | - False 58 | 59 | program: run_finetuning.py 60 | project: distil-whisper 61 | -------------------------------------------------------------------------------- /training/flax/initialisation_scripts/run_large_32_2_init.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python create_student_model.py \ 4 | --teacher_checkpoint "openai/whisper-large-v2" \ 5 | --decoder_layers 2 \ 6 | --save_dir "./" -------------------------------------------------------------------------------- /training/flax/initialisation_scripts/run_medium_24_2_init.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python create_student_model.py \ 4 | --teacher_checkpoint "openai/whisper-medium.en" \ 5 | --decoder_layers 2 \ 6 | --save_dir "./" 7 | -------------------------------------------------------------------------------- /training/flax/initialisation_scripts/run_small_12_2_init.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python create_student_model.py \ 4 | --teacher_checkpoint "openai/whisper-small.en" \ 5 | --decoder_layers 2 \ 6 | --save_dir "./" 7 | -------------------------------------------------------------------------------- /training/flax/initialisation_scripts/run_tiny_2_1_init.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python create_student_model.py \ 4 | --teacher_checkpoint "distil-whisper/tiny-random-whisper" \ 5 | --decoder_layers 
1 \ 6 | --save_dir "./" 7 | -------------------------------------------------------------------------------- /training/flax/initialisation_scripts/run_tiny_2_1_init_pt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python create_student_model_pt.py \ 4 | --teacher_checkpoint "distil-whisper/tiny-random-whisper" \ 5 | --decoder_layers 1 \ 6 | --save_dir "./" 7 | -------------------------------------------------------------------------------- /training/flax/latency_scripts/run_speculative.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # batch_sizes=(1 4) 3 | batch_sizes=(1) 4 | names=("openai/whisper-large-v2" "openai/whisper-large-v2" "openai/whisper-medium.en" "openai/whisper-medium.en") 5 | assistant_names=("patrickvonplaten/whisper-large-v2-32-2" "openai/whisper-small" "patrickvonplaten/whisper-medium-24-2" "openai/whisper-base.en") 6 | 7 | # --assistant_model_name_or_path "patrickvonplaten/whisper-large-v2-32-2" \ 8 | # --use_pipeline \ 9 | 10 | # Double loop 11 | 12 | for (( i=0; i<${#names[*]}; ++i)); do 13 | name=${names[$i]} 14 | assistant_name=${assistant_names[$i]} 15 | 16 | for batch_size in "${batch_sizes[@]}"; do 17 | CUDA_VISIBLE_DEVICES="0" python ./run_speed_pt.py \ 18 | --dataset_name "distil-whisper/chime4+distil-whisper/earnings22+google/fleurs+kensho/spgispeech" \ 19 | --wandb_name "FP16-RTX-4090-bsz${batch_size}-${name}-${assistant_name}" \ 20 | --model_name_or_path ${name} \ 21 | --wandb_project "distil-whisper-speed-bench-check-spec-dec-final" \ 22 | --dataset_config_name "1-channel+chunked+en_us+test" \ 23 | --dataset_split_name "test+test+test+test" \ 24 | --text_column_name "text+transcription+transcription+transcript" \ 25 | --attn_type "flash2" \ 26 | --assistant_model_name_or_path ${assistant_name} \ 27 | --samples_per_dataset "10" \ 28 | --batch_size ${batch_size} 29 | done 30 | done 31 | 
-------------------------------------------------------------------------------- /training/flax/latency_scripts/run_speed.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # --assistant_model_name_or_path "patrickvonplaten/whisper-large-v2-32-2" \ 3 | # --attn_type "flash2" \ 4 | names=("openai/whisper-large-v2" "openai/whisper-medium.en" "openai/whisper-small.en" "openai/whisper-base.en" "openai/whisper-tiny.en" "patrickvonplaten/whisper-large-v2-32-2" "patrickvonplaten/whisper-medium-24-2") 5 | batch_sizes=(1 4 16) 6 | 7 | # Double loop 8 | for name in "${names[@]}"; do 9 | for batch_size in "${batch_sizes[@]}"; do 10 | CUDA_VISIBLE_DEVICES="1" python ./run_speed_pt.py \ 11 | --dataset_name "google/fleurs+distil-whisper/chime4+distil-whisper/earnings22+kensho/spgispeech" \ 12 | --wandb_name "A100-bsz${batch_size}-${name}" \ 13 | --model_name_or_path ${name} \ 14 | --wandb_project "distil-whisper-speed-bench-256-no-timestamps" \ 15 | --dataset_config_name "en_us+1-channel+chunked+test" \ 16 | --dataset_split_name "test+test+test+test" \ 17 | --text_column_name "transcription+text+transcription+transcript" \ 18 | --samples_per_dataset "256" \ 19 | --batch_size ${batch_size} 20 | done 21 | done 22 | -------------------------------------------------------------------------------- /training/flax/latency_scripts/run_speed_longform.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | names=("openai/whisper-large-v2" "openai/whisper-medium.en" "openai/whisper-small.en" "openai/whisper-base.en" "openai/whisper-tiny.en" "patrickvonplaten/whisper-large-v2-32-2" "patrickvonplaten/whisper-medium-24-2") 3 | 4 | # chunk_lengths=("15.0" "30.0") 5 | # --assistant_model_name_or_path "patrickvonplaten/whisper-large-v2-32-2" \ 6 | # --attn_type "flash" \ 7 | 8 | # Double loop 9 | for name in "${names[@]}"; do 10 | if [[ ${name:0:6} == "openai" ]]; then 11 |
chunk_length_s=30.0 12 | else 13 | chunk_length_s=15.0 14 | fi 15 | 16 | CUDA_VISIBLE_DEVICES="1" python ./run_speed_pt.py \ 17 | --dataset_name "distil-whisper/earnings21+distil-whisper/earnings22+distil-whisper/meanwhile+distil-whisper/rev16" \ 18 | --wandb_name "T4-${name}-Longform" \ 19 | --model_name_or_path ${name} \ 20 | --wandb_project "distil-whisper-speed-bench-long-form-32" \ 21 | --dataset_config_name "full+full+default+whisper_subset" \ 22 | --dataset_split_name "test+test+test+test" \ 23 | --text_column_name "transcription+transcription+text+transcription" \ 24 | --chunk_length_s "$chunk_length_s" \ 25 | --use_pipeline \ 26 | --return_timestamps \ 27 | --max_label_length "1000000" \ 28 | --samples_per_dataset "32" \ 29 | --batch_size "1" 30 | done 31 | -------------------------------------------------------------------------------- /training/flax/latency_scripts/run_trial.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES="0" python ./run_speed_pt.py \ 3 | --dataset_name "distil-whisper/earnings22" \ 4 | --wandb_name "[Earnings] RTX 4090 - large-v2-32-2" \ 5 | --model_name_or_path "patrickvonplaten/whisper-large-v2-32-2" \ 6 | --wandb_project "distil-whisper-speed-benchmark" \ 7 | --dataset_config_name "chunked" \ 8 | --dataset_split_name "test" \ 9 | --text_column_name "transcription" \ 10 | --batch_size 1 11 | -------------------------------------------------------------------------------- /training/flax/long_form_transcription_scripts/run_chunk_length_s_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --streaming 5 | - ${args} 6 | method: grid 7 | metric: 8 | goal: minimize 9 | name: tedlium-long-form/validation/wer 10 | parameters: 11 | model_name_or_path: 12 | value: sanchit-gandhi/large-32-2-ts-freeze-28k-wer-10 13 | subfolder: 14 | value: checkpoint-15000 15 |
dataset_name: 16 | value: distil-whisper/tedlium-long-form 17 | dataset_config_name: 18 | value: all 19 | dataset_split_name: 20 | value: validation 21 | cache_dir: 22 | value: /home/sanchitgandhi/.cache 23 | dataset_cache_dir: 24 | value: /home/sanchitgandhi/.cache 25 | compilation_cache: 26 | value: /home/sanchitgandhi/.cache 27 | output_dir: 28 | value: ./ 29 | wandb_dir: 30 | value: /home/sanchitgandhi/.cache 31 | per_device_eval_batch_size: 32 | value: 8 33 | dtype: 34 | value: bfloat16 35 | report_to: 36 | value: wandb 37 | chunk_length_s: 38 | values: 39 | - 10 40 | - 15 41 | - 20 42 | - 25 43 | - 30 44 | generation_max_length: 45 | value: 128 46 | program: run_long_form_transcription.py 47 | project: distil-whisper-long-form -------------------------------------------------------------------------------- /training/flax/long_form_transcription_scripts/run_eval_with_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_NAMES="librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs+sanchit-gandhi/earnings22_split_resampled" 4 | DATASET_CONFIG_NAMES="all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us+default" 5 | DATASET_SPLIT_NAMES="validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation" 6 | TEXT_COLUMN_NAMES="text+text+text+text+text+text+text+text+text+text+text+text+transcription+sentence" 7 | 8 | python run_long_form_transcription.py \ 9 | --model_name_or_path "sanchit-gandhi/large-32-2-ts-28k-wer-10-converted-context-20s" \ 10 | --dataset_name $DATASET_NAMES \ 11 | --dataset_config_name $DATASET_CONFIG_NAMES \ 12 | --dataset_split_name $DATASET_SPLIT_NAMES \ 13 | --text_column_name $TEXT_COLUMN_NAMES \ 14 | --cache_dir "/home/sanchitgandhi/.cache" \ 
15 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 16 | --output_dir "./" \ 17 | --wandb_dir "/home/sanchitgandhi/.cache" \ 18 | --wandb_project "distil-whisper-eval" \ 19 | --wandb_name "large-32-2-ts-freeze-28k-wer-10-30k-steps-chunk-length-15-context-20" \ 20 | --per_device_eval_batch_size 1 \ 21 | --chunk_length_s 15 \ 22 | --dtype "bfloat16" \ 23 | --report_to "wandb" \ 24 | --streaming 25 | -------------------------------------------------------------------------------- /training/flax/long_form_transcription_scripts/run_length_penalty_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --streaming 5 | - ${args} 6 | method: grid 7 | metric: 8 | goal: minimize 9 | name: tedlium-long-form/validation/wer 10 | parameters: 11 | model_name_or_path: 12 | value: sanchit-gandhi/large-32-2-ts-freeze-28k-wer-10 13 | subfolder: 14 | value: checkpoint-15000 15 | dataset_name: 16 | value: distil-whisper/tedlium-long-form 17 | dataset_config_name: 18 | value: all 19 | dataset_split_name: 20 | value: validation 21 | cache_dir: 22 | value: /home/sanchitgandhi/.cache 23 | dataset_cache_dir: 24 | value: /home/sanchitgandhi/.cache 25 | output_dir: 26 | value: ./ 27 | wandb_dir: 28 | value: /home/sanchitgandhi/.cache 29 | per_device_eval_batch_size: 30 | value: 32 31 | dtype: 32 | value: bfloat16 33 | report_to: 34 | value: wandb 35 | generation_num_beams: 36 | value: 5 37 | generation_max_length: 38 | value: 256 39 | length_penalty: 40 | values: 41 | - 0.6 42 | - 0.8 43 | - 1.0 44 | - 1.2 45 | - 1.4 46 | program: run_long_form_transcription.py 47 | project: distil-whisper-long-form -------------------------------------------------------------------------------- /training/flax/long_form_transcription_scripts/run_tedlium_long_form.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_long_form_transcription.py \ 4 
| --model_name_or_path "sanchit-gandhi/large-32-2-ts-freeze-28k-wer-10" \ 5 | --subfolder "checkpoint-15000" \ 6 | --dataset_name "distil-whisper/tedlium-long-form" \ 7 | --dataset_config_name "all" \ 8 | --dataset_split_name "validation" \ 9 | --cache_dir "/home/sanchitgandhi/.cache" \ 10 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 11 | --output_dir "./" \ 12 | --wandb_dir "/home/sanchitgandhi/.cache" \ 13 | --wandb_project "distil-whisper-long-form" \ 14 | --wandb_name "large-32-2-ts-freeze-28k-wer-10-30k-steps" \ 15 | --per_device_eval_batch_size 32 \ 16 | --chunk_length_s 20 \ 17 | --dtype "bfloat16" \ 18 | --report_to "wandb" \ 19 | --streaming 20 | -------------------------------------------------------------------------------- /training/flax/long_form_transcription_scripts/run_tedlium_long_form_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_long_form_transcription.py \ 4 | --model_name_or_path "openai/whisper-tiny" \ 5 | --dataset_name "distil-whisper/tedlium-long-form" \ 6 | --dataset_config_name "all" \ 7 | --dataset_split_name "validation" \ 8 | --cache_dir "/home/sanchitgandhi/.cache" \ 9 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 10 | --output_dir "./" \ 11 | --wandb_dir "/home/sanchitgandhi/.cache" \ 12 | --wandb_project "distil-whisper-debug" \ 13 | --wandb_name "whisper-tiny-tedlium-long-form" \ 14 | --per_device_eval_batch_size 64 \ 15 | --max_eval_samples 1 \ 16 | --dtype "bfloat16" \ 17 | --report_to "wandb" \ 18 | --streaming 19 | -------------------------------------------------------------------------------- /training/flax/long_form_transcription_scripts/run_tedlium_long_form_timestamps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_long_form_transcription.py \ 4 | --model_name_or_path "sanchit-gandhi/large-32-2-ts-freeze-28k-wer-10-v4-8-10k-steps" \ 5 | 
--dataset_name "distil-whisper/tedlium-long-form+distil-whisper/tedlium-long-form" \ 6 | --dataset_config_name "all+all" \ 7 | --dataset_split_name "validation+test" \ 8 | --cache_dir "/home/sanchitgandhi/.cache" \ 9 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 10 | --output_dir "./" \ 11 | --wandb_dir "/home/sanchitgandhi/.cache" \ 12 | --wandb_project "distil-whisper-long-form" \ 13 | --wandb_name "large-32-2-ts-freeze-28k-wer-10-v4-8-10k-steps-tedlium-timestamps" \ 14 | --per_device_eval_batch_size 32 \ 15 | --dtype "bfloat16" \ 16 | --report_to "wandb" \ 17 | --streaming \ 18 | --return_timestamps -------------------------------------------------------------------------------- /training/flax/long_form_transcription_scripts/run_top_k_temperature_sweep.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3 3 | - ${program} 4 | - --streaming 5 | - --do_sample 6 | - ${args} 7 | method: grid 8 | metric: 9 | goal: minimize 10 | name: tedlium-long-form/validation/wer 11 | parameters: 12 | model_name_or_path: 13 | value: sanchit-gandhi/large-32-2-ts-freeze-28k-wer-10 14 | subfolder: 15 | value: checkpoint-15000 16 | dataset_name: 17 | value: distil-whisper/tedlium-long-form 18 | dataset_config_name: 19 | value: all 20 | dataset_split_name: 21 | value: validation 22 | cache_dir: 23 | value: /home/sanchitgandhi/.cache 24 | dataset_cache_dir: 25 | value: /home/sanchitgandhi/.cache 26 | output_dir: 27 | value: ./ 28 | wandb_dir: 29 | value: /home/sanchitgandhi/.cache 30 | per_device_eval_batch_size: 31 | value: 32 32 | dtype: 33 | value: bfloat16 34 | report_to: 35 | value: wandb 36 | generation_num_beams: 37 | value: 1 38 | generation_max_length: 39 | value: 256 40 | temperature: 41 | values: 42 | - 0.2 43 | - 0.4 44 | - 0.6 45 | - 0.8 46 | - 1.0 47 | - 1.2 48 | chunk_length_s: 49 | value: 20 50 | program: run_long_form_transcription.py 51 | project: distil-whisper-long-form 
-------------------------------------------------------------------------------- /training/flax/long_form_transcription_scripts/test/run_baselines.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_NAMES="distil-whisper/tedlium-long-form+distil-whisper/earnings21+distil-whisper/earnings22+distil-whisper/meanwhile+distil-whisper/rev16" 4 | DATASET_CONFIG_NAMES="all+full+full+default+whisper_subset" 5 | DATASET_SPLIT_NAMES="test+test+test+test+test" 6 | TEXT_COLUMN_NAMES="text+transcription+transcription+text+transcription" 7 | 8 | python run_long_form_transcription.py \ 9 | --model_name_or_path "openai/whisper-tiny.en" \ 10 | --dataset_name $DATASET_NAMES \ 11 | --dataset_config_name $DATASET_CONFIG_NAMES \ 12 | --dataset_split_name $DATASET_SPLIT_NAMES \ 13 | --text_column_name $TEXT_COLUMN_NAMES \ 14 | --cache_dir "/home/sanchitgandhi/.cache" \ 15 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 16 | --output_dir "./" \ 17 | --wandb_dir "/home/sanchitgandhi/.cache" \ 18 | --wandb_project "distil-whisper-long-form-test" \ 19 | --wandb_name "tiny.en" \ 20 | --per_device_eval_batch_size 16 \ 21 | --chunk_length_s 30 \ 22 | --generation_max_length 128 \ 23 | --dtype "bfloat16" \ 24 | --report_to "wandb" \ 25 | --streaming \ 26 | --return_timestamps 27 | 28 | python run_long_form_transcription.py \ 29 | --model_name_or_path "openai/whisper-base.en" \ 30 | --dataset_name $DATASET_NAMES \ 31 | --dataset_config_name $DATASET_CONFIG_NAMES \ 32 | --dataset_split_name $DATASET_SPLIT_NAMES \ 33 | --text_column_name $TEXT_COLUMN_NAMES \ 34 | --cache_dir "/home/sanchitgandhi/.cache" \ 35 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 36 | --output_dir "./" \ 37 | --wandb_dir "/home/sanchitgandhi/.cache" \ 38 | --wandb_project "distil-whisper-long-form-test" \ 39 | --wandb_name "base.en" \ 40 | --per_device_eval_batch_size 16 \ 41 | --chunk_length_s 30 \ 42 | --generation_max_length 128 \ 43 | 
--dtype "bfloat16" \ 44 | --report_to "wandb" \ 45 | --streaming \ 46 | --return_timestamps 47 | 48 | python run_long_form_transcription.py \ 49 | --model_name_or_path "openai/whisper-small.en" \ 50 | --dataset_name $DATASET_NAMES \ 51 | --dataset_config_name $DATASET_CONFIG_NAMES \ 52 | --dataset_split_name $DATASET_SPLIT_NAMES \ 53 | --text_column_name $TEXT_COLUMN_NAMES \ 54 | --cache_dir "/home/sanchitgandhi/.cache" \ 55 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 56 | --output_dir "./" \ 57 | --wandb_dir "/home/sanchitgandhi/.cache" \ 58 | --wandb_project "distil-whisper-long-form-test" \ 59 | --wandb_name "small.en" \ 60 | --per_device_eval_batch_size 16 \ 61 | --chunk_length_s 30 \ 62 | --generation_max_length 128 \ 63 | --dtype "bfloat16" \ 64 | --report_to "wandb" \ 65 | --streaming \ 66 | --return_timestamps 67 | 68 | python run_long_form_transcription.py \ 69 | --model_name_or_path "openai/whisper-medium.en" \ 70 | --dataset_name $DATASET_NAMES \ 71 | --dataset_config_name $DATASET_CONFIG_NAMES \ 72 | --dataset_split_name $DATASET_SPLIT_NAMES \ 73 | --text_column_name $TEXT_COLUMN_NAMES \ 74 | --cache_dir "/home/sanchitgandhi/.cache" \ 75 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 76 | --output_dir "./" \ 77 | --wandb_dir "/home/sanchitgandhi/.cache" \ 78 | --wandb_project "distil-whisper-long-form-test" \ 79 | --wandb_name "medium.en" \ 80 | --per_device_eval_batch_size 16 \ 81 | --chunk_length_s 30 \ 82 | --generation_max_length 128 \ 83 | --dtype "bfloat16" \ 84 | --report_to "wandb" \ 85 | --streaming \ 86 | --return_timestamps 87 | 88 | python run_long_form_transcription.py \ 89 | --model_name_or_path "openai/whisper-large-v2" \ 90 | --dataset_name $DATASET_NAMES \ 91 | --dataset_config_name $DATASET_CONFIG_NAMES \ 92 | --dataset_split_name $DATASET_SPLIT_NAMES \ 93 | --text_column_name $TEXT_COLUMN_NAMES \ 94 | --cache_dir "/home/sanchitgandhi/.cache" \ 95 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 96 | --output_dir 
"./" \ 97 | --wandb_dir "/home/sanchitgandhi/.cache" \ 98 | --wandb_project "distil-whisper-long-form-test" \ 99 | --wandb_name "large-v2" \ 100 | --per_device_eval_batch_size 16 \ 101 | --chunk_length_s 30 \ 102 | --generation_max_length 128 \ 103 | --dtype "bfloat16" \ 104 | --report_to "wandb" \ 105 | --streaming \ 106 | --return_timestamps 107 | -------------------------------------------------------------------------------- /training/flax/long_form_transcription_scripts/test/run_baselines_pt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_NAMES="distil-whisper/tedlium-long-form+distil-whisper/earnings21+distil-whisper/earnings22+distil-whisper/meanwhile+distil-whisper/rev16" 4 | DATASET_CONFIG_NAMES="all+full+full+default+whisper_subset" 5 | DATASET_SPLIT_NAMES="test+test+test+test+test" 6 | TEXT_COLUMN_NAMES="text+transcription+transcription+text+transcription" 7 | 8 | python run_pt_long_form_transcription.py \ 9 | --model_name_or_path "facebook/wav2vec2-large-960h" \ 10 | --dataset_name $DATASET_NAMES \ 11 | --dataset_config_name $DATASET_CONFIG_NAMES \ 12 | --dataset_split_name $DATASET_SPLIT_NAMES \ 13 | --text_column_name $TEXT_COLUMN_NAMES \ 14 | --output_dir "./" \ 15 | --wandb_project "distil-whisper-long-form-test" \ 16 | --wandb_name "wav2vec2-large-960h" \ 17 | --per_device_eval_batch_size 32 \ 18 | --chunk_length_s 20 \ 19 | --dtype "float16" \ 20 | --report_to "wandb" \ 21 | --streaming 22 | -------------------------------------------------------------------------------- /training/flax/long_form_transcription_scripts/test/run_distilled.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_NAMES="distil-whisper/tedlium-long-form+distil-whisper/earnings21+distil-whisper/earnings22+distil-whisper/meanwhile+distil-whisper/rev16" 4 | DATASET_CONFIG_NAMES="all+full+full+default+whisper_subset" 5 | 
DATASET_SPLIT_NAMES="test+test+test+test+test" 6 | TEXT_COLUMN_NAMES="text+transcription+transcription+text+transcription" 7 | 8 | python run_long_form_transcription.py \ 9 | --model_name_or_path "sanchit-gandhi/large-32-2-tpu-timestamped-resumed" \ 10 | --dataset_name $DATASET_NAMES \ 11 | --dataset_config_name $DATASET_CONFIG_NAMES \ 12 | --dataset_split_name $DATASET_SPLIT_NAMES \ 13 | --text_column_name $TEXT_COLUMN_NAMES \ 14 | --cache_dir "/home/sanchitgandhi/.cache" \ 15 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 16 | --output_dir "./" \ 17 | --wandb_dir "/home/sanchitgandhi/.cache" \ 18 | --wandb_project "distil-whisper-long-form-test" \ 19 | --wandb_name "large-32-2" \ 20 | --per_device_eval_batch_size 16 \ 21 | --chunk_length_s 15 \ 22 | --generation_max_length 128 \ 23 | --dtype "bfloat16" \ 24 | --report_to "wandb" \ 25 | --streaming 26 | 27 | python run_long_form_transcription.py \ 28 | --model_name_or_path "sanchit-gandhi/medium-24-2-tpu-timestamped-prob-0.2" \ 29 | --subfolder "checkpoint-45000" \ 30 | --dataset_name $DATASET_NAMES \ 31 | --dataset_config_name $DATASET_CONFIG_NAMES \ 32 | --dataset_split_name $DATASET_SPLIT_NAMES \ 33 | --text_column_name $TEXT_COLUMN_NAMES \ 34 | --cache_dir "/home/sanchitgandhi/.cache" \ 35 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 36 | --output_dir "./" \ 37 | --wandb_dir "/home/sanchitgandhi/.cache" \ 38 | --wandb_project "distil-whisper-long-form-test" \ 39 | --wandb_name "medium-24-2" \ 40 | --per_device_eval_batch_size 16 \ 41 | --chunk_length_s 20 \ 42 | --generation_max_length 128 \ 43 | --dtype "bfloat16" \ 44 | --report_to "wandb" \ 45 | --streaming 46 | -------------------------------------------------------------------------------- /training/flax/noise_evaluation_scripts/run_baselines.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 
DATASET_NAME="librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise" 4 | DATASET_CONFIG_NAME=("validation-white-noise" "validation-pub-noise") 5 | DATASET_SPLIT_NAME="40+35+30+25+20+15+10+5+0+minus5+minus10" 6 | 7 | for i in "${!DATASET_CONFIG_NAME[@]}"; do 8 | python run_eval.py \ 9 | --model_name_or_path "openai/whisper-tiny.en" \ 10 | --dataset_name $DATASET_NAME \ 11 | --dataset_config_name "${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 12 | --dataset_split_name $DATASET_SPLIT_NAME \ 13 | --cache_dir "/home/sanchitgandhi/cache" \ 14 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 15 | --output_dir "./" \ 16 | --wandb_dir "/home/sanchitgandhi/cache" \ 17 | --wandb_project "distil-whisper-noise-eval" \ 18 | --wandb_name "tiny.en-${DATASET_CONFIG_NAME[i]}" \ 19 | --per_device_eval_batch_size 64 \ 20 | --dtype "bfloat16" \ 21 | --dataloader_num_workers 16 \ 22 | --report_to "wandb" \ 23 | --streaming \ 24 | --predict_with_generate 25 | 26 | python run_eval.py \ 27 | --model_name_or_path "openai/whisper-base.en" \ 28 | --dataset_name $DATASET_NAME \ 29 | --dataset_config_name "${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 30 | --dataset_split_name $DATASET_SPLIT_NAME \ 31 | --cache_dir "/home/sanchitgandhi/cache" \ 32 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 33 | --output_dir "./" \ 34 | 
--wandb_dir "/home/sanchitgandhi/cache" \ 35 | --wandb_project "distil-whisper-noise-eval" \ 36 | --wandb_name "base.en-${DATASET_CONFIG_NAME[i]}" \ 37 | --per_device_eval_batch_size 64 \ 38 | --dtype "bfloat16" \ 39 | --dataloader_num_workers 16 \ 40 | --report_to "wandb" \ 41 | --streaming \ 42 | --predict_with_generate 43 | 44 | python run_eval.py \ 45 | --model_name_or_path "openai/whisper-small.en" \ 46 | --dataset_name $DATASET_NAME \ 47 | --dataset_config_name "${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 48 | --dataset_split_name $DATASET_SPLIT_NAME \ 49 | --cache_dir "/home/sanchitgandhi/cache" \ 50 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 51 | --output_dir "./" \ 52 | --wandb_dir "/home/sanchitgandhi/cache" \ 53 | --wandb_project "distil-whisper-noise-eval" \ 54 | --wandb_name "small.en-${DATASET_CONFIG_NAME[i]}" \ 55 | --per_device_eval_batch_size 64 \ 56 | --dtype "bfloat16" \ 57 | --dataloader_num_workers 16 \ 58 | --report_to "wandb" \ 59 | --streaming \ 60 | --predict_with_generate 61 | 62 | python run_eval.py \ 63 | --model_name_or_path "openai/whisper-medium.en" \ 64 | --dataset_name $DATASET_NAME \ 65 | --dataset_config_name "${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 66 | --dataset_split_name $DATASET_SPLIT_NAME \ 67 | --cache_dir "/home/sanchitgandhi/cache" \ 68 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 69 | --output_dir "./" \ 70 | --wandb_dir "/home/sanchitgandhi/cache" \ 71 | --wandb_project "distil-whisper-noise-eval" \ 72 | --wandb_name 
"medium.en-${DATASET_CONFIG_NAME[i]}" \ 73 | --per_device_eval_batch_size 64 \ 74 | --dtype "bfloat16" \ 75 | --dataloader_num_workers 16 \ 76 | --report_to "wandb" \ 77 | --streaming \ 78 | --predict_with_generate 79 | 80 | python run_eval.py \ 81 | --model_name_or_path "openai/whisper-large-v2" \ 82 | --dataset_name $DATASET_NAME \ 83 | --dataset_config_name "${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 84 | --dataset_split_name $DATASET_SPLIT_NAME \ 85 | --cache_dir "/home/sanchitgandhi/cache" \ 86 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 87 | --output_dir "./" \ 88 | --wandb_dir "/home/sanchitgandhi/cache" \ 89 | --wandb_project "distil-whisper-noise-eval" \ 90 | --wandb_name "large-v2-${DATASET_CONFIG_NAME[i]}" \ 91 | --per_device_eval_batch_size 32 \ 92 | --dtype "bfloat16" \ 93 | --dataloader_num_workers 16 \ 94 | --report_to "wandb" \ 95 | --streaming \ 96 | --predict_with_generate 97 | 98 | done -------------------------------------------------------------------------------- /training/flax/noise_evaluation_scripts/run_baselines_pt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_IDs=("facebook/wav2vec2-base-960h" "facebook/wav2vec2-large-960h" "facebook/wav2vec2-large-960h-lv60-self" "facebook/wav2vec2-large-robust-ft-libri-960h" "facebook/wav2vec2-conformer-rel-pos-large-960h-ft" "facebook/wav2vec2-conformer-rope-large-960h-ft" "facebook/hubert-large-ls960-ft" "facebook/hubert-xlarge-ls960-ft" "facebook/mms-1b-all" "facebook/mms-1b-fl102" "facebook/data2vec-audio-large-960h" "facebook/data2vec-audio-base-960h") 4 | 
DATASET_NAME="librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise" 5 | DATASET_CONFIG_NAME=("test-white-noise" "test-pub-noise") 6 | DATASET_SPLIT_NAME="40+35+30+25+20+15+10+5+0+minus5+minus10" 7 | 8 | for i in "${!MODEL_IDs[@]}"; do 9 | for j in "${!DATASET_CONFIG_NAME[@]}"; do 10 | python run_pt_long_form_transcription.py \ 11 | --model_name_or_path "${MODEL_IDs[i]}" \ 12 | --dataset_name $DATASET_NAME \ 13 | --dataset_config_name "${DATASET_CONFIG_NAME[j]}+${DATASET_CONFIG_NAME[j]}+${DATASET_CONFIG_NAME[j]}+${DATASET_CONFIG_NAME[j]}+${DATASET_CONFIG_NAME[j]}+${DATASET_CONFIG_NAME[j]}+${DATASET_CONFIG_NAME[j]}+${DATASET_CONFIG_NAME[j]}+${DATASET_CONFIG_NAME[j]}+${DATASET_CONFIG_NAME[j]}+${DATASET_CONFIG_NAME[j]}" \ 14 | --dataset_split_name $DATASET_SPLIT_NAME \ 15 | --cache_dir "/home/sanchit/.cache" \ 16 | --dataset_cache_dir "/home/sanchit/.cache" \ 17 | --output_dir "./" \ 18 | --wandb_dir "/home/sanchit/.cache" \ 19 | --wandb_project "distil-whisper-noise-test" \ 20 | --wandb_name "${MODEL_IDs[i]}-${DATASET_CONFIG_NAME[j]}" \ 21 | --per_device_eval_batch_size 16 \ 22 | --dtype "float16" \ 23 | --report_to "wandb" \ 24 | --streaming \ 25 | --predict_with_generate 26 | done 27 | done 28 | -------------------------------------------------------------------------------- /training/flax/noise_evaluation_scripts/run_distilled.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_NAME="librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise" 4 | DATASET_CONFIG_NAME=("validation-white-noise" "validation-pub-noise") 5 | 
DATASET_SPLIT_NAME="40+35+30+25+20+15+10+5+0+minus5+minus10" 6 | 7 | for i in "${!DATASET_CONFIG_NAME[@]}"; do 8 | python run_eval.py \ 9 | --model_name_or_path "sanchit-gandhi/large-32-2-gpu-flat-lr" \ 10 | --dataset_name $DATASET_NAME \ 11 | --dataset_config_name "${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 12 | --dataset_split_name $DATASET_SPLIT_NAME \ 13 | --cache_dir "/home/sanchitgandhi/cache" \ 14 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 15 | --output_dir "./" \ 16 | --wandb_dir "/home/sanchitgandhi/cache" \ 17 | --wandb_project "distil-whisper-noise-eval" \ 18 | --wandb_name "large-32-2-gpu-flat-lr-${DATASET_CONFIG_NAME[i]}" \ 19 | --per_device_eval_batch_size 64 \ 20 | --dtype "bfloat16" \ 21 | --dataloader_num_workers 16 \ 22 | --report_to "wandb" \ 23 | --streaming \ 24 | --predict_with_generate 25 | done 26 | -------------------------------------------------------------------------------- /training/flax/noise_evaluation_scripts/test/run_baselines.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_NAME="librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise" 4 | DATASET_CONFIG_NAME=("test-white-noise" "test-pub-noise") 5 | DATASET_SPLIT_NAME="40+35+30+25+20+15+10+5+0+minus5+minus10" 6 | 7 | for i in "${!DATASET_CONFIG_NAME[@]}"; do 8 | python run_eval.py \ 9 | --model_name_or_path "openai/whisper-tiny.en" \ 10 | --dataset_name $DATASET_NAME \ 11 | --dataset_config_name 
"${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 12 | --dataset_split_name $DATASET_SPLIT_NAME \ 13 | --cache_dir "/home/sanchitgandhi/cache" \ 14 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 15 | --output_dir "./" \ 16 | --wandb_dir "/home/sanchitgandhi/cache" \ 17 | --wandb_project "distil-whisper-noise-test" \ 18 | --wandb_name "tiny.en-${DATASET_CONFIG_NAME[i]}" \ 19 | --per_device_eval_batch_size 32 \ 20 | --dtype "bfloat16" \ 21 | --dataloader_num_workers 16 \ 22 | --report_to "wandb" \ 23 | --streaming \ 24 | --predict_with_generate 25 | 26 | python run_eval.py \ 27 | --model_name_or_path "openai/whisper-base.en" \ 28 | --dataset_name $DATASET_NAME \ 29 | --dataset_config_name "${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 30 | --dataset_split_name $DATASET_SPLIT_NAME \ 31 | --cache_dir "/home/sanchitgandhi/cache" \ 32 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 33 | --output_dir "./" \ 34 | --wandb_dir "/home/sanchitgandhi/cache" \ 35 | --wandb_project "distil-whisper-noise-test" \ 36 | --wandb_name "base.en-${DATASET_CONFIG_NAME[i]}" \ 37 | --per_device_eval_batch_size 32 \ 38 | --dtype "bfloat16" \ 39 | --dataloader_num_workers 16 \ 40 | --report_to "wandb" \ 41 | --streaming \ 42 | --predict_with_generate 43 | 44 | python run_eval.py \ 45 | --model_name_or_path "openai/whisper-small.en" \ 46 | --dataset_name $DATASET_NAME \ 47 | --dataset_config_name 
"${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 48 | --dataset_split_name $DATASET_SPLIT_NAME \ 49 | --cache_dir "/home/sanchitgandhi/cache" \ 50 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 51 | --output_dir "./" \ 52 | --wandb_dir "/home/sanchitgandhi/cache" \ 53 | --wandb_project "distil-whisper-noise-test" \ 54 | --wandb_name "small.en-${DATASET_CONFIG_NAME[i]}" \ 55 | --per_device_eval_batch_size 32 \ 56 | --dtype "bfloat16" \ 57 | --dataloader_num_workers 16 \ 58 | --report_to "wandb" \ 59 | --streaming \ 60 | --predict_with_generate 61 | 62 | python run_eval.py \ 63 | --model_name_or_path "openai/whisper-medium.en" \ 64 | --dataset_name $DATASET_NAME \ 65 | --dataset_config_name "${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 66 | --dataset_split_name $DATASET_SPLIT_NAME \ 67 | --cache_dir "/home/sanchitgandhi/cache" \ 68 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 69 | --output_dir "./" \ 70 | --wandb_dir "/home/sanchitgandhi/cache" \ 71 | --wandb_project "distil-whisper-noise-test" \ 72 | --wandb_name "medium.en-${DATASET_CONFIG_NAME[i]}" \ 73 | --per_device_eval_batch_size 32 \ 74 | --dtype "bfloat16" \ 75 | --dataloader_num_workers 16 \ 76 | --report_to "wandb" \ 77 | --streaming \ 78 | --predict_with_generate 79 | 80 | python run_eval.py \ 81 | --model_name_or_path "openai/whisper-large-v2" \ 82 | --dataset_name $DATASET_NAME \ 83 | --dataset_config_name 
"${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 84 | --dataset_split_name $DATASET_SPLIT_NAME \ 85 | --cache_dir "/home/sanchitgandhi/cache" \ 86 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 87 | --output_dir "./" \ 88 | --wandb_dir "/home/sanchitgandhi/cache" \ 89 | --wandb_project "distil-whisper-noise-test" \ 90 | --wandb_name "large-v2-${DATASET_CONFIG_NAME[i]}" \ 91 | --per_device_eval_batch_size 16 \ 92 | --dtype "bfloat16" \ 93 | --dataloader_num_workers 16 \ 94 | --report_to "wandb" \ 95 | --streaming \ 96 | --predict_with_generate 97 | done -------------------------------------------------------------------------------- /training/flax/noise_evaluation_scripts/test/run_distilled.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_NAME="librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise+librispeech_asr-noise" 4 | DATASET_CONFIG_NAME=("test-white-noise" "test-pub-noise") 5 | DATASET_SPLIT_NAME="40+35+30+25+20+15+10+5+0+minus5+minus10" 6 | 7 | for i in "${!DATASET_CONFIG_NAME[@]}"; do 8 | python run_eval.py \ 9 | --model_name_or_path "sanchit-gandhi/large-32-2-tpu-timestamped-resumed" \ 10 | --dataset_name $DATASET_NAME \ 11 | --dataset_config_name "${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 12 | --dataset_split_name $DATASET_SPLIT_NAME 
\ 13 | --cache_dir "/home/sanchitgandhi/cache" \ 14 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 15 | --output_dir "./" \ 16 | --wandb_dir "/home/sanchitgandhi/cache" \ 17 | --wandb_project "distil-whisper-noise-test" \ 18 | --wandb_name "large-32-2-tpu-timestamped-${DATASET_CONFIG_NAME[i]}" \ 19 | --per_device_eval_batch_size 64 \ 20 | --dtype "bfloat16" \ 21 | --dataloader_num_workers 16 \ 22 | --report_to "wandb" \ 23 | --streaming \ 24 | --predict_with_generate 25 | 26 | python run_eval.py \ 27 | --model_name_or_path "sanchit-gandhi/medium-24-2-tpu-timestamped-prob-0.2" \ 28 | --subfolder "checkpoint-45000" \ 29 | --dataset_name $DATASET_NAME \ 30 | --dataset_config_name "${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}+${DATASET_CONFIG_NAME[i]}" \ 31 | --dataset_split_name $DATASET_SPLIT_NAME \ 32 | --cache_dir "/home/sanchitgandhi/cache" \ 33 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 34 | --output_dir "./" \ 35 | --wandb_dir "/home/sanchitgandhi/cache" \ 36 | --wandb_project "distil-whisper-noise-test" \ 37 | --wandb_name "medium-24-2-tpu-timestamped-prob-0.2-${DATASET_CONFIG_NAME[i]}" \ 38 | --per_device_eval_batch_size 64 \ 39 | --dtype "bfloat16" \ 40 | --dataloader_num_workers 16 \ 41 | --report_to "wandb" \ 42 | --streaming \ 43 | --predict_with_generate 44 | done 45 | -------------------------------------------------------------------------------- /training/flax/pseudo_labelling_scripts/run_librispeech_pseudo_labelling.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_pseudo_labelling.py \ 4 | --model_name_or_path "openai/whisper-large-v2" \ 5 | --dataset_name "sanchit-gandhi/librispeech_asr_clean" \ 6 | --dataset_config_name "clean" \ 7 | 
--data_split_name "train.100" \ 8 | --text_column_name "text" \ 9 | --cache_dir "/home/sanchitgandhi/cache" \ 10 | --dataset_cache_dir "/home/sanchitgandhi/cache" \ 11 | --output_dir "./transcriptions-streaming" \ 12 | --wandb_dir "/home/sanchitgandhi/.cache" \ 13 | --wandb_project "distil-whisper-debug" \ 14 | --wandb_name "whisper-large-v2-beam-libri-train.clean.100" \ 15 | --per_device_eval_batch_size 16 \ 16 | --max_label_length 256 \ 17 | --dtype "bfloat16" \ 18 | --preprocessing_num_workers 16 \ 19 | --report_to "wandb" \ 20 | --dataloader_num_workers 16 \ 21 | --streaming False \ 22 | --generation_num_beams 1 23 | -------------------------------------------------------------------------------- /training/flax/pseudo_labelling_scripts/run_librispeech_pseudo_labelling_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python run_pseudo_labelling.py \ 4 | --model_name_or_path "openai/whisper-tiny" \ 5 | --dataset_name "distil-whisper/librispeech_asr" \ 6 | --dataset_config_name "all" \ 7 | --data_split_name "validation.clean+validation.other" \ 8 | --text_column_name "text" \ 9 | --cache_dir "/home/sanchitgandhi/.cache" \ 10 | --dataset_cache_dir "/home/sanchitgandhi/.cache" \ 11 | --output_dir "./transcriptions-streaming" \ 12 | --wandb_dir "/home/sanchitgandhi/.cache" \ 13 | --wandb_project "distil-whisper-debug" \ 14 | --per_device_eval_batch_size 1 \ 15 | --dtype "bfloat16" \ 16 | --dataloader_num_workers 16 \ 17 | --logging_steps 2 \ 18 | --report_to "wandb" \ 19 | --streaming \ 20 | --max_samples_per_split 256 \ 21 | --max_label_length 256 \ 22 | --return_timestamps \ 23 | --decode_token_ids False 24 | -------------------------------------------------------------------------------- /training/flax/pseudo_labelling_scripts/run_pseudo_labelling.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 
MODEL_NAME="openai/whisper-large-v3" 4 | CACHE_DIR="/home/sanchitgandhi/.cache" 5 | OUTPUT_DIR="./transcriptions-streaming" 6 | WANDB_DIR="/home/sanchitgandhi/.cache" 7 | WANDB_PROJECT="distil-whisper-label" 8 | BATCH_SIZE=64 9 | NUM_BEAMS=1 10 | MAX_LABEL_LENGTH=256 11 | LOGGING_STEPS=500 12 | NUM_WORKERS=64 13 | RETURN_TIMESTAMPS=False 14 | 15 | python run_pseudo_labelling.py \ 16 | --model_name_or_path $MODEL_NAME \ 17 | --dataset_name "distil-whisper/librispeech_asr" \ 18 | --dataset_config_name "all" \ 19 | --data_split_name "train.other.500+validation.clean+validation.other+test.clean+test.other" \ 20 | --wandb_name "whisper-large-v2-librispeech_asr" \ 21 | --cache_dir $CACHE_DIR \ 22 | --dataset_cache_dir $CACHE_DIR \ 23 | --output_dir $OUTPUT_DIR \ 24 | --wandb_dir $WANDB_DIR \ 25 | --wandb_project $WANDB_PROJECT \ 26 | --per_device_eval_batch_size $BATCH_SIZE \ 27 | --generation_num_beams $NUM_BEAMS \ 28 | --max_label_length $MAX_LABEL_LENGTH \ 29 | --logging_steps $LOGGING_STEPS \ 30 | --dataloader_num_workers $NUM_WORKERS \ 31 | --dtype "bfloat16" \ 32 | --report_to "wandb" \ 33 | --streaming True \ 34 | --push_to_hub \ 35 | --return_timestamps $RETURN_TIMESTAMPS \ 36 | --compilation_cache $CACHE_DIR 37 | 38 | python run_pseudo_labelling.py \ 39 | --model_name_or_path $MODEL_NAME \ 40 | --dataset_name "distil-whisper/peoples_speech-clean" \ 41 | --dataset_config_name "clean" \ 42 | --data_split_name "train+validation+test" \ 43 | --wandb_name "whisper-large-v2-peoples_speech-clean" \ 44 | --cache_dir $CACHE_DIR \ 45 | --dataset_cache_dir $CACHE_DIR \ 46 | --output_dir $OUTPUT_DIR \ 47 | --wandb_dir $WANDB_DIR \ 48 | --wandb_project $WANDB_PROJECT \ 49 | --per_device_eval_batch_size $BATCH_SIZE \ 50 | --generation_num_beams $NUM_BEAMS \ 51 | --max_label_length $MAX_LABEL_LENGTH \ 52 | --logging_steps $LOGGING_STEPS \ 53 | --dataloader_num_workers $NUM_WORKERS \ 54 | --dtype "bfloat16" \ 55 | --report_to "wandb" \ 56 | --streaming True \ 57 | --push_to_hub 
\ 58 | --return_timestamps $RETURN_TIMESTAMPS \ 59 | --compilation_cache $CACHE_DIR 60 | -------------------------------------------------------------------------------- /training/flax/pseudo_labelling_scripts/run_pseudo_labelling_2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_NAME="openai/whisper-large-v3" 4 | CACHE_DIR="/home/sanchitgandhi/.cache" 5 | OUTPUT_DIR="./transcriptions-streaming" 6 | WANDB_DIR="/home/sanchitgandhi/.cache" 7 | WANDB_PROJECT="distil-whisper-label" 8 | SPLITS="train+validation+test" 9 | BATCH_SIZE=64 10 | NUM_BEAMS=1 11 | MAX_LABEL_LENGTH=256 12 | LOGGING_STEPS=500 13 | NUM_WORKERS=64 14 | RETURN_TIMESTAMPS=False 15 | 16 | DATASET_NAMES=("distil-whisper/common_voice_13_0" "distil-whisper/voxpopuli" "distil-whisper/tedlium" "distil-whisper/ami-ihm" "distil-whisper/ami-sdm" "distil-whisper/spgispeech" "distil-whisper/gigaspeech-l") 17 | CONFIGS=("en" "en" "release3" "ihm" "sdm" "L" "l") 18 | 19 | for i in "${!DATASET_NAMES[@]}"; do 20 | python run_pseudo_labelling.py \ 21 | --model_name_or_path $MODEL_NAME \ 22 | --dataset_name "${DATASET_NAMES[i]}" \ 23 | --dataset_config_name "${CONFIGS[i]}" \ 24 | --data_split_name "$SPLITS" \ 25 | --wandb_name "whisper-large-v2-${DATASET_NAMES[i]}" \ 26 | --cache_dir $CACHE_DIR \ 27 | --dataset_cache_dir $CACHE_DIR \ 28 | --output_dir $OUTPUT_DIR \ 29 | --wandb_dir $WANDB_DIR \ 30 | --wandb_project $WANDB_PROJECT \ 31 | --per_device_eval_batch_size $BATCH_SIZE \ 32 | --generation_num_beams $NUM_BEAMS \ 33 | --max_label_length $MAX_LABEL_LENGTH \ 34 | --logging_steps $LOGGING_STEPS \ 35 | --dataloader_num_workers $NUM_WORKERS \ 36 | --dtype "bfloat16" \ 37 | --report_to "wandb" \ 38 | --streaming True \ 39 | --push_to_hub \ 40 | --return_timestamps $RETURN_TIMESTAMPS \ 41 | --compilation_cache $CACHE_DIR 42 | done 43 | -------------------------------------------------------------------------------- 
/training/flax/pseudo_labelling_scripts/run_pseudo_labelling_dummy_pt.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Smoke-test of the PyTorch pseudo-labelling pipeline: whisper-tiny on the
# LibriSpeech validation splits, capped at 256 samples per split so it runs
# quickly. Launched via accelerate on a single process in bf16.

accelerate launch --mixed_precision=bf16 --num_processes=1 run_pseudo_labelling_pt.py \
  --model_name_or_path "openai/whisper-tiny" \
  --dataset_name "distil-whisper/librispeech_asr" \
  --dataset_config_name "all" \
  --data_split_name "validation.clean+validation.other" \
  --text_column_name "text" \
  --cache_dir "/home/sanchit/.cache" \
  --dataset_cache_dir "/home/sanchit/.cache" \
  --output_dir "./transcriptions-streaming" \
  --wandb_project "distil-whisper-debug" \
  --per_device_eval_batch_size 8 \
  --dtype "bfloat16" \
  --dataloader_num_workers 16 \
  --logging_steps 2 \
  --report_to "wandb" \
  --streaming \
  --max_samples_per_split 256 \
  --max_label_length 256 \
  --return_timestamps \
  --decode_token_ids False
--------------------------------------------------------------------------------
/training/flax/pseudo_labelling_scripts/run_pseudo_labelling_token_ids.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Pseudo-labels LibriSpeech and People's Speech with Whisper large-v3, keeping
# the raw token ids (--decode_token_ids False) instead of decoded text.
# Smaller batch size (16) than the text variant of this script.
# NOTE(review): wandb_name says "whisper-large-v2" while MODEL_NAME is large-v3 —
# likely a stale run name carried over from an earlier experiment.

MODEL_NAME="openai/whisper-large-v3"
CACHE_DIR="/home/sanchitgandhi/.cache"
OUTPUT_DIR="./transcriptions-streaming"
WANDB_DIR="/home/sanchitgandhi/.cache"
WANDB_PROJECT="distil-whisper-label"
BATCH_SIZE=16
NUM_BEAMS=1
MAX_LABEL_LENGTH=256
LOGGING_STEPS=500
NUM_WORKERS=64
RETURN_TIMESTAMPS=False
DECODE_TOKEN_IDS=False

python run_pseudo_labelling.py \
  --model_name_or_path $MODEL_NAME \
  --dataset_name "distil-whisper/librispeech_asr" \
  --dataset_config_name "all" \
  --data_split_name "train.other.500+validation.clean+validation.other+test.clean+test.other" \
  --wandb_name "whisper-large-v2-librispeech_asr-token-ids" \
  --cache_dir $CACHE_DIR \
  --dataset_cache_dir $CACHE_DIR \
  --output_dir $OUTPUT_DIR \
  --wandb_dir $WANDB_DIR \
  --wandb_project $WANDB_PROJECT \
  --per_device_eval_batch_size $BATCH_SIZE \
  --generation_num_beams $NUM_BEAMS \
  --max_label_length $MAX_LABEL_LENGTH \
  --logging_steps $LOGGING_STEPS \
  --dataloader_num_workers $NUM_WORKERS \
  --dtype "bfloat16" \
  --report_to "wandb" \
  --streaming True \
  --push_to_hub \
  --return_timestamps $RETURN_TIMESTAMPS \
  --compilation_cache $CACHE_DIR \
  --decode_token_ids $DECODE_TOKEN_IDS

python run_pseudo_labelling.py \
  --model_name_or_path $MODEL_NAME \
  --dataset_name "distil-whisper/peoples_speech-clean" \
  --dataset_config_name "clean" \
  --data_split_name "train+validation+test" \
  --wandb_name "whisper-large-v2-peoples_speech-clean-token-ids" \
  --cache_dir $CACHE_DIR \
  --dataset_cache_dir $CACHE_DIR \
  --output_dir $OUTPUT_DIR \
  --wandb_dir $WANDB_DIR \
  --wandb_project $WANDB_PROJECT \
  --per_device_eval_batch_size $BATCH_SIZE \
  --generation_num_beams $NUM_BEAMS \
  --max_label_length $MAX_LABEL_LENGTH \
  --logging_steps $LOGGING_STEPS \
  --dataloader_num_workers $NUM_WORKERS \
  --dtype "bfloat16" \
  --report_to "wandb" \
  --streaming True \
  --push_to_hub \
  --return_timestamps $RETURN_TIMESTAMPS \
  --compilation_cache $CACHE_DIR \
  --decode_token_ids $DECODE_TOKEN_IDS
--------------------------------------------------------------------------------
/training/flax/pseudo_labelling_scripts/run_pseudo_labelling_token_ids_2.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Token-id variant of run_pseudo_labelling_2.sh: loops over the remaining
# distillation corpora with --decode_token_ids False and batch size 16.

MODEL_NAME="openai/whisper-large-v3"
CACHE_DIR="/home/sanchitgandhi/.cache"
OUTPUT_DIR="./transcriptions-streaming"
WANDB_DIR="/home/sanchitgandhi/.cache"
WANDB_PROJECT="distil-whisper-label"
SPLITS="train+validation+test"
BATCH_SIZE=16
NUM_BEAMS=1
MAX_LABEL_LENGTH=256
LOGGING_STEPS=500
NUM_WORKERS=64
RETURN_TIMESTAMPS=False
DECODE_TOKEN_IDS=False

# Parallel arrays: CONFIGS[i] is the dataset config for DATASET_NAMES[i].
DATASET_NAMES=("distil-whisper/common_voice_13_0" "distil-whisper/voxpopuli" "distil-whisper/tedlium" "distil-whisper/ami-ihm" "distil-whisper/ami-sdm" "distil-whisper/spgispeech" "distil-whisper/gigaspeech-l")
CONFIGS=("en" "en" "release3" "ihm" "sdm" "L" "l")

for i in "${!DATASET_NAMES[@]}"; do
  python run_pseudo_labelling.py \
    --model_name_or_path $MODEL_NAME \
    --dataset_name "${DATASET_NAMES[i]}" \
    --dataset_config_name "${CONFIGS[i]}" \
    --data_split_name "$SPLITS" \
    --wandb_name "whisper-large-v2-${DATASET_NAMES[i]}-token-ids" \
    --cache_dir $CACHE_DIR \
    --dataset_cache_dir $CACHE_DIR \
    --output_dir $OUTPUT_DIR \
    --wandb_dir $WANDB_DIR \
    --wandb_project $WANDB_PROJECT \
    --per_device_eval_batch_size $BATCH_SIZE \
    --generation_num_beams $NUM_BEAMS \
    --max_label_length $MAX_LABEL_LENGTH \
    --logging_steps $LOGGING_STEPS \
    --dataloader_num_workers $NUM_WORKERS \
    --dtype "bfloat16" \
    --report_to "wandb" \
    --streaming True \
    --push_to_hub \
    --return_timestamps $RETURN_TIMESTAMPS \
    --compilation_cache $CACHE_DIR \
    --decode_token_ids $DECODE_TOKEN_IDS
done
--------------------------------------------------------------------------------
/training/flax/pyproject.toml:
--------------------------------------------------------------------------------
# Formatter/linter configuration for the Flax training sub-package.
[tool.black]
line-length = 119
target-version = ['py37']

[tool.ruff]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
line-length = 119

# Ignore import violations in all `__init__.py` files.
[tool.ruff.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]

[tool.ruff.isort]
lines-after-imports = 2
known-first-party = ["distil_whisper"]
--------------------------------------------------------------------------------
/training/flax/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.7
transformers
datasets[audio]
jiwer
evaluate>=0.3.0
--------------------------------------------------------------------------------
/training/flax/run_orig_longform.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Long-form transcription speed benchmark using the original OpenAI Whisper
# implementation (--use_orig_whisper) on four long-form test sets.
# NOTE(review): the second `names=` assignment below overrides the first,
# narrowing the sweep to the three smallest checkpoints — presumably an
# intentional mid-experiment edit; confirm which list is wanted.
names=("openai/whisper-large-v2" "openai/whisper-medium.en" "openai/whisper-small.en" "openai/whisper-base.en" "openai/whisper-tiny.en")
names=("openai/whisper-small.en" "openai/whisper-base.en" "openai/whisper-tiny.en")
# names=("patrickvonplaten/whisper-large-v2-32-2" "patrickvonplaten/whisper-medium-24-2")

# chunk_lengths=("15.0" "30.0")
# --return_timestamps \
# --assistant_model_name_or_path "patrickvonplaten/whisper-large-v2-32-2" \
# --attn_type "flash2" \

# Double loop
for name in "${names[@]}"; do
  CUDA_VISIBLE_DEVICES="1" python ./run_speed_pt.py \
    --dataset_name "distil-whisper/earnings21+distil-whisper/earnings22+distil-whisper/meanwhile+distil-whisper/rev16" \
    --wandb_name "A100-${name}-Longform-Orig" \
    --model_name_or_path ${name} \
    --wandb_project "distil-whisper-speed-bench-long-form-orig-32" \
    --dataset_config_name "full+full+default+whisper_subset" \
    --dataset_split_name "test+test+test+test" \
    --text_column_name "transcription+transcription+text+transcription" \
    --use_orig_whisper \
    --max_label_length "1000000" \
    --samples_per_dataset "32" \
    --batch_size "1"
done
--------------------------------------------------------------------------------
/training/flax/run_speculative_decoding.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""Benchmark speculative (assisted) decoding against plain greedy decoding.

Loads a teacher Whisper model and a distilled student from local paths, runs
both plain `generate` and assisted `generate` (student as assistant model,
sharing the teacher's encoder outputs) over the dummy LibriSpeech validation
set, and prints per-sample wall-clock times plus the overall speed-up.

Requires CUDA and local fp16 checkpoints; make sure to use the branch from
https://github.com/huggingface/transformers/pull/26701 (provides the
`assistant_encoder_outputs` argument used below).
"""
import copy
import time

import torch
from datasets import load_dataset
from transformers import (
    AutoProcessor,
    WhisperForConditionalGeneration,
)


DEVICE = "cuda"
DTYPE = torch.float16
SAMPLING_RATE = 16_000
BATCH_SIZE = 1
USE_FLASH_ATTN_2 = True

# TO DEBUG
GAMMAS = [5, 7, 6, 5, 4, 3, 5]  # NOTE(review): unused in the active code path below
COUNT = 0  # number of samples where assisted output diverges from the baseline

# local loading is faster
teacher = WhisperForConditionalGeneration.from_pretrained(
    "/home/patrick/distil_whisper/",
    torch_dtype=DTYPE,
    variant="fp16",
    low_cpu_mem_usage=True,
    use_flash_attention_2=USE_FLASH_ATTN_2,
)
student = WhisperForConditionalGeneration.from_pretrained(
    "/home/patrick/distil_whisper_student/",
    torch_dtype=DTYPE,
    variant="fp16",
    low_cpu_mem_usage=True,
    use_flash_attention_2=USE_FLASH_ATTN_2,
)
# student = WhisperForCausalLM.from_pretrained("/home/patrick/distil_whisper_student", torch_dtype=DTYPE, variant="fp16", low_cpu_mem_usage=True, use_flash_attention_2=USE_FLASH_ATTN_2)

# The student must share the teacher's generation config (forced/suppressed
# tokens etc.); use a constant number of assistant tokens per step.
student.generation_config = copy.deepcopy(teacher.generation_config)
student.generation_config.num_assistant_tokens_schedule = "constant"

# teacher = WhisperForConditionalGeneration.from_pretrained(
#     "openai/whisper-large-v2", torch_dtype=DTYPE, variant="fp16", low_cpu_mem_usage=True
# )
# student = WhisperForConditionalGeneration.from_pretrained(
#     "sanchit-gandhi/large-32-2-gpu-flat-lr", torch_dtype=DTYPE, variant="fp16", low_cpu_mem_usage=True
# )
#
teacher.to(DEVICE)
student.to(DEVICE)

processor = AutoProcessor.from_pretrained("sanchit-gandhi/large-32-2-gpu-flat-lr")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

total_time_default = 0
total_time_spec = 0
total_time_spec_2 = 0

# Warm-up: one untimed generate so compilation/caching doesn't skew timings.
input_values = ds[0]["audio"]["array"]
inputs = processor(input_values, return_tensors="pt", sampling_rate=SAMPLING_RATE)
input_features = inputs.input_features.to(device=DEVICE, dtype=DTYPE)

_ = teacher.generate(input_features, max_length=100)

end_idx = ds.shape[0]
for audio_idx in range(0, end_idx, BATCH_SIZE):
    # Slicing a datasets.Dataset returns a dict of columns; extract raw audio arrays.
    input_values = ds[audio_idx : audio_idx + BATCH_SIZE]
    input_values = [i["array"] for i in input_values["audio"]]

    inputs = processor(input_values, return_tensors="pt", sampling_rate=SAMPLING_RATE)
    input_features = inputs.input_features.to(device=DEVICE, dtype=DTYPE)

    # Baseline: plain greedy decoding with the teacher.
    start_time = time.time()
    out = teacher.generate(input_features, max_length=100)
    run_time = time.time() - start_time
    print(f"Normal Decoding: {run_time}")
    total_time_default += run_time

    default_out = processor.batch_decode(out, skip_special_tokens=True)
    # print("Output", default_out)

    # start_time = time.time()
    # with torch.no_grad():
    #     encoder_outputs = teacher.get_encoder()(input_features).last_hidden_state

    # out, ratio = speculative_decoding(teacher, student, encoder_outputs, max_length=100, gamma=5)
    # run_time = time.time() - start_time
    # print(20 * "=")
    # print(f"Speculative Decoding: {run_time}")
    # total_time_spec += run_time

    # spec_out = processor.batch_decode(out)

    # Assisted decoding: encoder runs once and its outputs are shared between
    # teacher and student (both timed together with generation).
    start_time = time.time()
    with torch.no_grad():
        encoder_outputs = teacher.get_encoder()(input_features)

    out = teacher.generate(
        assistant_model=student,
        assistant_encoder_outputs=encoder_outputs,
        encoder_outputs=encoder_outputs,
        max_length=100,
    )
    run_time = time.time() - start_time

    spec_out_2 = processor.batch_decode(out, skip_special_tokens=True)

    print(f"Speculative Decoding 2: {run_time}")
    total_time_spec_2 += run_time

    # Assisted decoding should be lossless for greedy search: flag mismatches.
    if spec_out_2 != default_out:
        COUNT += 1
        print(f"Audio {audio_idx} does not match. Spec: {spec_out_2}, True: {default_out}")


print(20 * "=")
print("Total time", total_time_default)
print(f"Overall speed-up spec 2 {total_time_default / total_time_spec_2}")
# print(f"Overall speed-up {total_time_default / total_time_spec}")
--------------------------------------------------------------------------------
/training/flax/run_speed.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Short-form transcription speed benchmark: sweeps checkpoints x batch sizes
# over four short-form test sets with flash attention 2 and 5-beam search,
# 256 samples per dataset.
# --wandb_project "distil-whisper-speed-bench-1024-no-timestamps" \
batch_sizes=(1 16)
names=("openai/whisper-large-v2" "openai/whisper-medium.en" "openai/whisper-small.en" "openai/whisper-base.en" "openai/whisper-tiny.en" "patrickvonplaten/whisper-large-v2-32-2" "patrickvonplaten/whisper-medium-24-2")

# Double loop
for name in "${names[@]}"; do
  for batch_size in "${batch_sizes[@]}"; do
    CUDA_VISIBLE_DEVICES="1" python ./run_speed_pt.py \
      --dataset_name "google/fleurs+distil-whisper/chime4+distil-whisper/earnings22+kensho/spgispeech" \
      --wandb_name "T4-bsz${batch_size}-${name}" \
      --model_name_or_path ${name} \
      --wandb_project "beam-search-distil-whisper-speed-bench-256-no-timestamps" \
      --dataset_config_name "en_us+1-channel+chunked+test" \
      --dataset_split_name "test+test+test+test" \
      --text_column_name "transcription+text+transcription+transcript" \
      --samples_per_dataset "256" \
      --attn_type "flash2" \
      --num_beams 5 \
      --batch_size ${batch_size}
  done
done
--------------------------------------------------------------------------------
/training/flax/setup.py:
-------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | 18 | import setuptools 19 | 20 | 21 | _deps = [ 22 | "transformers>=4.34.0", 23 | "datasets[audio]>=2.14.5", 24 | "jax>=0.4.13", 25 | "flax>=0.7.2", 26 | "optax", 27 | "evaluate", 28 | "jiwer", 29 | "torch", 30 | "torchdata", 31 | "tokenizers", 32 | ] 33 | 34 | _extras_dev_deps = [ 35 | "black~=23.1", 36 | "isort>=5.5.4", 37 | "ruff>=0.0.241,<=0.0.259", 38 | ] 39 | 40 | here = os.path.abspath(os.path.dirname(__file__)) 41 | 42 | with open(os.path.join(here, "README.md"), encoding="utf-8") as f: 43 | long_description = f.read() 44 | 45 | # read version 46 | with open(os.path.join(here, "distil_whisper", "__init__.py"), encoding="utf-8") as f: 47 | for line in f: 48 | if line.startswith("__version__"): 49 | version = line.split("=")[1].strip().strip('"') 50 | break 51 | else: 52 | raise RuntimeError("Unable to find version string.") 53 | 54 | setuptools.setup( 55 | name="distil_whisper", 56 | version=version, 57 | description="Toolkit for distilling OpenAI's Whisper model.", 58 | long_description=long_description, 59 | long_description_content_type="text/markdown", 60 | packages=setuptools.find_packages(), 61 | install_requires=_deps, 62 | extras_require={ 63 | "dev": [_extras_dev_deps], 64 | }, 65 | ) 66 | 
--------------------------------------------------------------------------------
/training/flax/tpu_connect.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Open a tmux session with one synchronized SSH pane per worker host of a
# Cloud TPU v4 pod slice, so a command typed once runs on every worker.

# This script is adapted from https://github.com/peregilk/ttconnect#ttconnect

zone="us-central2-b"  # TPU v4's always are in us-central2-b. Update if using TPU v2/v3's
name=$1               # TPU resource name, passed as the first CLI argument

echo "Connecting to $name";

## Some basic checks if the input is valid
output=$(gcloud compute tpus describe $name --zone $zone 2>/dev/null)
if [ $? != 0 ]; then
    echo "Could not find a tpu-v4 with this name in the zone $zone. Exiting."
    exit 1
fi

# Parse the accelerator type (e.g. "v4-32") from the describe output, strip the
# "v4-" prefix to get the chip count, and divide by 8 to get the number of
# worker hosts (presumably 8 chips per v4 host — confirm for other topologies).
tputype=$(echo $output | awk '{print $2}')
tpusize=$(echo $tputype| cut -c4-)
size="$(($tpusize / 8))"

if (( $size < 1 )); then
    echo "This is reported as a $tputype with $size tpu(s). This is not a valid tpu-v4 resource. Exiting."
    exit 1
fi


# Check if the session exists, if not create it
# If there already is a session with this name, it will just attach

tmux has-session -t $name 2>/dev/null


if [ $? != 0 ]; then
    tmux new-session -d -s $name
    tmux select-layout main-vertical

    # Create one additional pane per extra worker.
    for i in $(seq $(($size-1))); do
        tmux split-window -v -d -t $name
        # Making sure there is space to split
        tmux select-layout main-horizontal
    done

    # SSH into each worker host in its own pane (worker indices start at 0).
    for i in $(seq $(($size))); do
        worker=$(($i -1))
        command="gcloud alpha compute tpus tpu-vm ssh $name --zone $zone --worker $worker"
        tmux select-pane -t $name:0.$worker
        tmux send-keys -t $name "$command" Enter

    done

    # Select the final layout
    # NOTE(review): both branches start with `select-layout tiled`; only the
    # else-branch additionally switches to main-vertical — looks intentional
    # but the first `tiled` in the else-branch may be redundant.
    if ((size >= 16));then
        tmux select-layout tiled
    else
        tmux select-layout tiled
        tmux select-layout main-vertical
    fi

    # Enable mouse control - for changing pane size
    # Disabled for now since it makes copying more difficult
    # tmux set-mouse on

    # Move cursor to worker 0
    tmux select-pane -t $name:0.0

    # Resize the left window
    tmux resize-pane -L 50



    # Set pane synchronization so keystrokes go to every worker at once.
    tmux set-window-option -t $name:0 synchronize-panes on

    # Set pane border format
    tmux set-option -t $name pane-border-status top
    tmux set-option -t $name pane-border-format "worker #{pane_index} "


fi

# Attach to the session
tmux attach -t $name
--------------------------------------------------------------------------------
/training/pyproject.toml:
--------------------------------------------------------------------------------
# Formatter/linter configuration for the PyTorch training sub-package.
[tool.black]
line-length = 119
target-version = ['py37']

[tool.ruff]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
line-length = 119

# Ignore import violations in all `__init__.py` files.
12 | [tool.ruff.per-file-ignores] 13 | "__init__.py" = ["E402", "F401", "F403", "F811"] 14 | 15 | [tool.ruff.isort] 16 | lines-after-imports = 2 17 | -------------------------------------------------------------------------------- /training/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | 18 | import setuptools 19 | 20 | _deps = [ 21 | "torch>=1.10", 22 | "transformers>=4.35.1", 23 | "datasets[audio]>=2.14.7", 24 | "accelerate>=0.24.1", 25 | "jiwer", 26 | "evaluate>=0.4.1", 27 | "wandb", 28 | "tensorboard", 29 | "nltk", 30 | ] 31 | 32 | _extras_dev_deps = [ 33 | "ruff==0.1.5", 34 | ] 35 | 36 | here = os.path.abspath(os.path.dirname(__file__)) 37 | 38 | with open(os.path.join(here, "README.md"), encoding="utf-8") as f: 39 | long_description = f.read() 40 | 41 | setuptools.setup( 42 | name="distil_whisper", 43 | description="Toolkit for distilling OpenAI's Whisper model.", 44 | long_description=long_description, 45 | long_description_content_type="text/markdown", 46 | packages=setuptools.find_packages(), 47 | install_requires=_deps, 48 | extras_require={ 49 | "dev": [_extras_dev_deps], 50 | }, 51 | ) 52 | 53 | --------------------------------------------------------------------------------