├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── app ├── app.py ├── monitor.sh └── run_app.sh ├── benchmarks ├── run_pipeline_dataloader.py ├── run_pjit.py ├── run_pjit_dataloader.py ├── run_pmap.py └── run_pytorch.py ├── pyproject.toml ├── setup.py ├── whisper-jax-tpu.ipynb └── whisper_jax ├── __init__.py ├── layers.py ├── modeling_flax_whisper.py ├── partitioner.py ├── pipeline.py └── train_state.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/.gitignore 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # logs 12 | logs/ 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | .dmypy.json 116 | dmypy.json 117 | 118 | # Pyre type checker 119 | .pyre/ 120 | 121 | # vscode 122 | .vs 123 | .vscode 124 | 125 | # Pycharm 126 | .idea 127 | 128 | # TF code 129 | tensorflow_code 130 | 131 | # Models 132 | proc_data 133 | 134 | # examples 135 | runs 136 | /runs_old 137 | /wandb 138 | /examples/runs 139 | /examples/**/*.args 140 | /examples/rag/sweep 141 | 142 | # data 143 | /data 144 | serialization_dir 145 | 146 | # emacs 147 | *.*~ 148 | debug.env 149 | 150 | # vim 151 | .*.swp 152 | 153 | #ctags 154 | tags 155 | 156 | # pre-commit 157 | .pre-commit* 158 | 159 | # .lock 160 | *.lock 161 | 162 | # DS_Store (MacOS) 163 | .DS_Store 164 | # RL pipelines may produce mp4 outputs 165 | *.mp4 166 | 167 | # dependencies 168 | /transformers 169 | 170 | # ruff 171 | .ruff_cache 172 | 173 | wandb 174 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023- The HuggingFace Inc. team and The OpenAI Authors. All rights reserved. 
2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 
181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | check_dirs := . 2 | 3 | quality: 4 | black --check $(check_dirs) 5 | ruff $(check_dirs) 6 | 7 | style: 8 | black $(check_dirs) 9 | ruff $(check_dirs) --fix 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Whisper JAX 2 | 3 | This repository contains optimised JAX code for OpenAI's [Whisper Model](https://arxiv.org/abs/2212.04356), largely built 4 | on the 🤗 Hugging Face Transformers Whisper implementation. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x** 5 | faster, making it the fastest Whisper implementation available. 6 | 7 | The JAX code is compatible on CPU, GPU and TPU, and can be run standalone (see [Pipeline Usage](#pipeline-usage)) or 8 | as an inference endpoint (see [Creating an Endpoint](#creating-an-endpoint)). 9 | 10 | For a quick-start guide to running Whisper JAX on a Cloud TPU, refer to the following Kaggle notebook, where we transcribe 30 mins of audio in approx 30 sec: 11 | 12 | [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu) 13 | 14 | The Whisper JAX model is also running as a demo on the Hugging Face Hub: 15 | 16 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) 17 | 18 | ## Installation 19 | 20 | Whisper JAX was tested using Python 3.9 and JAX version 0.4.5. Installation assumes that you already have the latest 21 | version of the JAX package installed on your device. 
You can do so using the official JAX installation guide: https://github.com/google/jax#installation 22 | 23 | Once the appropriate version of JAX has been installed, Whisper JAX can be installed through pip: 24 | ``` 25 | pip install git+https://github.com/sanchit-gandhi/whisper-jax.git 26 | ``` 27 | 28 | To update the Whisper JAX package to the latest version, simply run: 29 | ``` 30 | pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git 31 | ``` 32 | 33 | ## Pipeline Usage 34 | 35 | The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) abstraction class. This class handles all 36 | the necessary pre- and post-processing, as well as wrapping the generate method for data parallelism across accelerator devices. 37 | 38 | Whisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is _Just In Time (JIT)_ 39 | compiled the first time it is called. Thereafter, the function will be _cached_, enabling it to be run in super-fast time: 40 | 41 | ```python 42 | from whisper_jax import FlaxWhisperPipline 43 | 44 | # instantiate pipeline 45 | pipeline = FlaxWhisperPipline("openai/whisper-large-v2") 46 | 47 | # JIT compile the forward call - slow, but we only do once 48 | text = pipeline("audio.mp3") 49 | 50 | # used cached function thereafter - super fast!! 51 | text = pipeline("audio.mp3") 52 | ``` 53 | 54 | ### Half-Precision 55 | 56 | The model computation can be run in half-precision by passing the dtype argument when instantiating the pipeline. This will 57 | speed-up the computation quite considerably by storing intermediate tensors in half-precision. There is no change to the precision 58 | of the model weights. 59 | 60 | For most GPUs, the dtype should be set to `jnp.float16`. For A100 GPUs or TPUs, the dtype should be set to `jnp.bfloat16`: 61 | ```python 62 | from whisper_jax import FlaxWhisperPipline 63 | import jax.numpy as jnp 64 | 65 | # instantiate pipeline in bfloat16 66 | pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16) 67 | ``` 68 | 69 | ### Batching 70 | Whisper JAX also provides the option of _batching_ a single audio input across accelerator devices. The audio is first 71 | chunked into 30 second segments, and then chunks dispatched to the model to be transcribed in parallel. The resulting 72 | transcriptions are stitched back together at the boundaries to give a single, uniform transcription. In practice, batching 73 | provides a 10x speed-up compared to transcribing the audio samples sequentially, with a less than 1% penalty to the WER[^1], provided the batch size is selected large enough. 74 | 75 | To enable batching, pass the `batch_size` parameter when you instantiate the pipeline: 76 | 77 | ```python 78 | from whisper_jax import FlaxWhisperPipline 79 | 80 | # instantiate pipeline with batching 81 | pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16) 82 | ``` 83 | 84 | ### Task 85 | 86 | By default, the pipeline transcribes the audio file in the language it was spoken in. 
For speech translation, set the
87 | `task` argument to `"translate"`:
88 | 
89 | ```python
90 | # translate
91 | text = pipeline("audio.mp3", task="translate")
92 | ```
93 | 
94 | ### Timestamps
95 | 
96 | The [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) also supports timestamp prediction. Note that enabling timestamps will require a second JIT compilation of the
97 | forward call, this time including the timestamp outputs:
98 | 
99 | ```python
100 | # transcribe and return timestamps
101 | outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
102 | text = outputs["text"]  # transcription
103 | chunks = outputs["chunks"]  # transcription + timestamps
104 | ```
105 | 
106 | ### Putting it all together
107 | In the following code snippet, we instantiate the model in bfloat16 precision with batching enabled, and transcribe the audio file,
108 | returning timestamp tokens:
109 | 
110 | ```python
111 | from whisper_jax import FlaxWhisperPipline
112 | import jax.numpy as jnp
113 | 
114 | # instantiate pipeline with bfloat16 and enable batching
115 | pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
116 | 
117 | # transcribe and return timestamps
118 | outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
119 | ```
120 | 
121 | ## Model Usage
122 | 
123 | The Whisper JAX model can be used at a more granular level in much the same way as the original Hugging Face
124 | Transformers implementation. This requires the Whisper processor to be loaded separately from the model to handle the
125 | pre- and post-processing, and the generate function to be wrapped using `pmap` by hand:
126 | 
127 | ```python
128 | import jax.numpy as jnp
129 | from datasets import load_dataset
130 | from flax.jax_utils import replicate
131 | from flax.training.common_utils import shard
132 | from jax import device_get, pmap
133 | from transformers import WhisperProcessor
134 | 
135 | from whisper_jax import FlaxWhisperForConditionalGeneration
136 | 
137 | # load the processor and model
138 | processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
139 | model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
140 |     "openai/whisper-large-v2", dtype=jnp.bfloat16, _do_init=False,
141 | )
142 | 
143 | def generate_fn(input_features):
144 |     pred_ids = model.generate(
145 |         input_features, task="transcribe", return_timestamps=False, max_length=model.config.max_length, params=params,
146 |     )
147 |     return pred_ids.sequences
148 | 
149 | # pmap the generate function for data parallelism
150 | p_generate = pmap(generate_fn, "input_features")
151 | # replicate the parameters across devices
152 | params = replicate(params)
153 | 
154 | # load a dummy sample from the LibriSpeech dataset
155 | ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
156 | sample = ds[0]["audio"]
157 | 
158 | # pre-process: convert the audio array to log-mel input features
159 | input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="np").input_features
160 | # shard the input features across devices for DP
161 | input_features = shard(input_features)
162 | 
163 | # run the forward pass (JIT compiled the first time it is called)
164 | pred_ids = p_generate(input_features)
165 | output_ids = device_get(pred_ids.reshape(-1, model.config.max_length))
166 | 
167 | # post-process: convert token ids to text strings
168 | transcription = processor.batch_decode(output_ids, skip_special_tokens=True)
169 | ```
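One practical detail when driving `pmap` by hand: `shard` splits the leading (batch) dimension of `input_features` across the local devices, so the batch size must be divisible by `jax.local_device_count()`. The helper below is a minimal sketch of one way to pad a batch before sharding; the `pad_to_devices` name and the repeat-last-example padding strategy are illustrative assumptions, not part of Whisper JAX:

```python
import jax
import numpy as np


def pad_to_devices(input_features: np.ndarray) -> np.ndarray:
    """Pad the leading batch dim so it divides evenly across local devices (hypothetical helper)."""
    n_devices = jax.local_device_count()
    remainder = input_features.shape[0] % n_devices
    if remainder:
        pad = n_devices - remainder
        # repeat the last example as padding; drop the corresponding predictions after generation
        padding = np.repeat(input_features[-1:], pad, axis=0)
        input_features = np.concatenate([input_features, padding], axis=0)
    return input_features


# input_features = shard(pad_to_devices(input_features))
```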
170 | 
171 | ## Available Models and Languages
172 | All Whisper models on the Hugging Face Hub with Flax weights are compatible with Whisper JAX. This includes, but is not limited to,
173 | the official OpenAI Whisper checkpoints:
174 | 
175 | | Size     | Parameters | English-only                                          | Multilingual                                         |
176 | |----------|------------|-------------------------------------------------------|------------------------------------------------------|
177 | | tiny     | 39 M       | [✓](https://huggingface.co/openai/whisper-tiny.en)    | [✓](https://huggingface.co/openai/whisper-tiny)      |
178 | | base     | 74 M       | [✓](https://huggingface.co/openai/whisper-base.en)    | [✓](https://huggingface.co/openai/whisper-base)      |
179 | | small    | 244 M      | [✓](https://huggingface.co/openai/whisper-small.en)   | [✓](https://huggingface.co/openai/whisper-small)     |
180 | | medium   | 769 M      | [✓](https://huggingface.co/openai/whisper-medium.en)  | [✓](https://huggingface.co/openai/whisper-medium)    |
181 | | large    | 1550 M     | x                                                      | [✓](https://huggingface.co/openai/whisper-large)     |
182 | | large-v2 | 1550 M     | x                                                      | [✓](https://huggingface.co/openai/whisper-large-v2)  |
183 | 
184 | Should you wish to use a fine-tuned Whisper checkpoint in Whisper JAX, you should first convert the PyTorch weights to Flax.
185 | This is straightforward using the `from_pt` argument, which will convert the PyTorch state dict to a frozen Flax
186 | parameter dictionary on the fly. You can then push the converted Flax weights to the Hub to be used directly in Flax
187 | the next time they are required. Note that converting weights from PyTorch to Flax requires both PyTorch and Flax to be installed.
188 | 
189 | For example, to convert the fine-tuned checkpoint [`sanchit-gandhi/whisper-small-hi`](https://huggingface.co/sanchit-gandhi/whisper-small-hi) from the blog post [Fine-Tuning Whisper](https://huggingface.co/blog/fine-tune-whisper):
190 | ```python
191 | from whisper_jax import FlaxWhisperForConditionalGeneration, FlaxWhisperPipline
192 | import jax.numpy as jnp
193 | 
194 | checkpoint_id = "sanchit-gandhi/whisper-small-hi"
195 | # convert PyTorch weights to Flax
196 | model = FlaxWhisperForConditionalGeneration.from_pretrained(checkpoint_id, from_pt=True)
197 | # push converted weights to the Hub
198 | model.push_to_hub(checkpoint_id)
199 | 
200 | # now we can load the Flax weights directly as required
201 | pipeline = FlaxWhisperPipline(checkpoint_id, dtype=jnp.bfloat16, batch_size=16)
202 | ```
203 | 
204 | ## Advanced Usage
205 | More advanced users may wish to explore different parallelisation techniques. The Whisper JAX code is
206 | built on top of the [T5x codebase](https://github.com/google-research/t5x), meaning it can be run with model, activation, and data parallelism using the T5x
207 | partitioning convention. To use T5x partitioning, the logical axis rules and number of model partitions must be defined.
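As an illustration of what "defining the logical axis rules and number of model partitions" can look like beyond pure data parallelism, the sketch below adapts the data-parallel rule set shown in the next subsection by additionally mapping the width-like logical axes onto the `"model"` mesh axis. The exact mapping is an illustrative assumption only; consult the T5x partitioning guide referenced below for the rules appropriate to your model and mesh:

```python
# illustrative rules for combined data + model parallelism (assumed mapping, not from this repository)
logical_axis_rules_dp_mp = [
    ("batch", "data"),      # shard the batch over the "data" mesh axis
    ("mlp", "model"),       # shard the feed-forward width over the "model" mesh axis
    ("heads", "model"),     # shard attention heads over the "model" mesh axis
    ("vocab", "model"),
    ("embed", None),        # keep embeddings replicated
    ("joined_kv", "model"),
    ("kv", None),
    ("length", None),
    ("num_mel", None),
    ("channels", None),
]

# then request more than one model partition, e.g.
# pipeline.shard_params(num_mp_partitions=2, logical_axis_rules=logical_axis_rules_dp_mp)
```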
208 | For more details, the user is referred to the official T5x partitioning guide: https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md 209 | 210 | ### Pipeline 211 | The following code snippet demonstrates how data parallelism can be achieved using the pipeline `shard_params` method in 212 | an entirely equivalent way to `pmap`: 213 | 214 | ```python 215 | from whisper_jax import FlaxWhisperPipline 216 | import jax.numpy as jnp 217 | 218 | # 2D parameter and activation partitioning for DP 219 | logical_axis_rules_dp = ( 220 | ("batch", "data"), 221 | ("mlp", None), 222 | ("heads", None), 223 | ("vocab", None), 224 | ("embed", None), 225 | ("embed", None), 226 | ("joined_kv", None), 227 | ("kv", None), 228 | ("length", None), 229 | ("num_mel", None), 230 | ("channels", None), 231 | ) 232 | 233 | pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16) 234 | pipeline.shard_params(num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp) 235 | ``` 236 | 237 | ### Model 238 | It is also possible to use the Whisper JAX model with T5x partitioning by defining a T5x inference state and T5x partitioner: 239 | 240 | ```python 241 | import jax 242 | import jax.numpy as jnp 243 | from flax.core.frozen_dict import freeze 244 | from jax.sharding import PartitionSpec as P 245 | 246 | from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner 247 | 248 | 249 | # 2D parameter and activation partitioning for DP 250 | logical_axis_rules_dp = [ 251 | ("batch", "data"), 252 | ("mlp", None), 253 | ("heads", None), 254 | ("vocab", None), 255 | ("embed", None), 256 | ("embed", None), 257 | ("joined_kv", None), 258 | ("kv", None), 259 | ("length", None), 260 | ("num_mel", None), 261 | ("channels", None), 262 | ] 263 | 264 | model, params = FlaxWhisperForConditionalGeneration.from_pretrained( 265 | "openai/whisper-large-v2", 266 | _do_init=False, 267 | dtype=jnp.bfloat16, 268 | ) 269 | 270 | 271 | def init_fn(): 272 | input_shape = (1, 80, 3000) 273 | 274 | input_features = jnp.zeros(input_shape, dtype="f4") 275 | input_features = input_features.at[(..., -1)].set(model.config.eos_token_id) 276 | 277 | decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") 278 | decoder_attention_mask = jnp.ones_like(decoder_input_ids) 279 | 280 | batch_size, sequence_length = decoder_input_ids.shape 281 | decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) 282 | 283 | rng = jax.random.PRNGKey(0) 284 | init_params = model.module.init( 285 | rng, 286 | input_features=input_features, 287 | decoder_input_ids=decoder_input_ids, 288 | decoder_attention_mask=decoder_attention_mask, 289 | decoder_position_ids=decoder_position_ids, 290 | return_dict=False, 291 | ) 292 | return init_params 293 | 294 | 295 | # Axis names metadata 296 | param_axes = jax.eval_shape(init_fn)["params_axes"] 297 | 298 | # Create InferenceState, since the partitioner expects it 299 | state = InferenceState( 300 | step=jnp.array(0), 301 | params=freeze(model.params_shape_tree), 302 | params_axes=freeze(param_axes), 303 | flax_mutables=None, 304 | flax_mutables_axes=param_axes, 305 | ) 306 | 307 | # Define the pjit partitioner with 1 model partition 308 | partitioner = PjitPartitioner( 309 | num_partitions=1, 310 | logical_axis_rules=logical_axis_rules_dp, 311 | ) 312 | 313 | mesh_axes = partitioner.get_mesh_axes(state) 314 | params_spec = mesh_axes.params 315 | 316 | p_shard_params = 
partitioner.partition(model.to_bf16, (params_spec,), params_spec) 317 | 318 | 319 | def generate(params, input_features): 320 | output_ids = model.generate(input_features, params=params, max_length=model.config.max_length).sequences 321 | return output_ids 322 | 323 | 324 | p_generate = partitioner.partition( 325 | generate, 326 | in_axis_resources=(params_spec, P("data")), 327 | out_axis_resources=P("data"), 328 | ) 329 | 330 | # This will auto-magically run in mesh context 331 | params = p_shard_params(freeze(params)) 332 | 333 | # you can now run the forward pass with: 334 | # pred_ids = p_generate(input_features) 335 | ``` 336 | 337 | ## Benchmarks 338 | 339 | We compare Whisper JAX to the official [OpenAI implementation](https://github.com/openai/whisper) and the [🤗 Transformers 340 | implementation](https://huggingface.co/docs/transformers/model_doc/whisper). We benchmark the models on audio samples of 341 | increasing length and report the average inference time in seconds over 10 repeat runs. For all three systems, we pass a 342 | pre-loaded audio file to the model and measure the time for the forward pass. Leaving the task of loading the audio file 343 | to the systems adds an equal offset to all the benchmark times, so the actual time for loading **and** transcribing an 344 | audio file will be higher than the reported numbers. 345 | 346 | OpenAI and Transformers both run in PyTorch on GPU. Whisper JAX runs in JAX on GPU and TPU. OpenAI transcribes the audio 347 | sequentially in the order it is spoken. Both Transformers and Whisper JAX use a batching algorithm, where chunks of audio 348 | are batched together and transcribed in parallel (see section [Batching](#batching)). 349 | 350 | **Table 1:** Average inference time in seconds for audio files of increasing length. GPU device is a single A100 40GB GPU. 351 | TPU device is a single TPU v4-8. 352 | 353 |
354 | 355 | | | OpenAI | Transformers | Whisper JAX | Whisper JAX | 356 | |-----------|---------|--------------|-------------|-------------| 357 | | | | | | | 358 | | Framework | PyTorch | PyTorch | JAX | JAX | 359 | | Backend | GPU | GPU | GPU | TPU | 360 | | | | | | | 361 | | 1 min | 13.8 | 4.54 | 1.72 | 0.45 | 362 | | 10 min | 108.3 | 20.2 | 9.38 | 2.01 | 363 | | 1 hour | 1001.0 | 126.1 | 75.3 | 13.8 | 364 | | | | | | | 365 | 366 |
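The snippet below is a minimal sketch of the timing methodology described above: the audio is pre-loaded once, a warm-up call triggers JIT compilation, and the reported figure is the mean over repeated runs. The dummy 10-minute array and the run count are placeholders, and the `{"array", "sampling_rate"}` dictionary input is assumed to be accepted by the pipeline in the same way as by the 🤗 Transformers pipelines (you can equally pass a file path, at the cost of including the audio-loading offset discussed above):

```python
import time

import jax.numpy as jnp
import numpy as np

from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)

# pre-load the audio once so that file I/O is excluded from the measurement
audio = {"array": np.zeros(16000 * 600, dtype=np.float32), "sampling_rate": 16000}

# warm-up: the first call JIT compiles the forward pass and is not timed
pipeline(audio)

num_runs = 10
start = time.time()
for _ in range(num_runs):
    pipeline(audio)
print(f"average inference time: {(time.time() - start) / num_runs:.2f} s")
```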
367 | 368 | ## Creating an Endpoint 369 | 370 | The Whisper JAX model is running as a demo on the Hugging Face Hub: 371 | 372 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) 373 | 374 | However, at peak times there may be a queue of users that limit how quickly your audio input is transcribed. In this case, 375 | you may benefit from running the model yourself, such that you have unrestricted access to the Whisper JAX model. 376 | 377 | If you are just interested in running the model in a standalone Python script, refer to the Kaggle notebook Whisper JAX TPU: 378 | 379 | [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu) 380 | 381 | Otherwise, we provide all the necessary code for creating an inference endpoint. To obtain this code, first clone the 382 | repository on the GPU/TPU on which you want to host the endpoint: 383 | ``` 384 | git clone https://github.com/sanchit-gandhi/whisper-jax 385 | ``` 386 | 387 | And then install Whisper JAX from source, with the required additional endpoint dependencies: 388 | ``` 389 | cd whisper-jax 390 | pip install -e .["endpoint"] 391 | ``` 392 | 393 | We recommend that you set-up an endpoint in the same zone/region as the one you are based in. This reduces the communication 394 | time between your local machine and the remote one, which can significantly reduce the overall request time. 395 | 396 | ### Gradio App 397 | 398 | The Python script [`app.py`](app/app.py) contains the code to launch a Gradio app with the Whisper large-v2 model. 399 | By default, it uses a batch size of 16 and bfloat16 half-precision. You should update these parameters depending on your 400 | GPU/TPU device (as explained in the sections on [Half-precision](#half-precision) and [Batching](#batching)). 401 | 402 | We can launch the Gradio app on port 7860 (default) on our GPU/TPU device through the following command: 403 | ``` 404 | python app/app.py 405 | ``` 406 | 407 | This will launch a Gradio demo with the same interface as the official Whisper JAX demo. To view the Gradio app remotely, 408 | we have two options: 409 | 410 | 1. Open the port 7860 on the GPU/TPU device to listen to all requests 411 | 2. Start an ngrok server on the GPU/TPU that redirects requests to port 7860 412 | 413 | To open the port 7860 on your GPU/TPU, refer to your hardware provider's firewall instructions (for GCP, these can be 414 | found [here](https://cloud.google.com/firewall/docs/using-firewalls)). Once you have opened port 7860, you should be able 415 | to access the gradio demo through the http address: 416 | ``` 417 | http://DEVICE-IP:7860 418 | ``` 419 | where `DEVICE-IP` is the public IP address of your GPU/TPU. We can verify this address is accessible by opening this 420 | http address in a browser window on our local machine. 421 | 422 | Alternatively, we can direct network requests to the Gradio app using ngrok. By using ngrok, we don't need to open the 423 | port 7860 on our GPU/TPU - ngrok will provide us with a public http address that will automatically redirect requests to 424 | port 7860 on our accelerator. However, in our experience, using ngrok was less reliable than a direct tunnel to port 7860, 425 | thus we recommend option 1 here where possible. 426 | 427 | To set-up ngrok on your GPU/TPU, first install ngrok according to the official [installation guide](https://ngrok.com/download). 
428 | You should authenticate your ngrok account if you have one, otherwise your ngrok server will be time-limited to 2 hours. 429 | Once installed and authenticated, you can launch an ngrok server on port 7860: 430 | ``` 431 | ngrok http 7860 432 | ``` 433 | The ngrok http address will be of the form: 434 | ``` 435 | https://NGROK-ADDRESS.ngrok.io 436 | ``` 437 | which can be used to access the Gradio demo through a web browser. 438 | 439 | ### Sending Requests 440 | 441 | Independent of whether you've chosen to open the port 7860 or use ngrok, we're now ready to send audio file requests to our 442 | endpoint. To do this, we'll make use of the `gradio_client` library. If you already have a recent version of Gradio, 443 | then the `gradio_client` library is included as a dependency. 444 | 445 | Otherwise, the lightweight `gradio_client` package can be installed from pip and is tested to work with Python versions 446 | 3.9 or higher: 447 | ``` 448 | pip install --upgrade gradio_client 449 | ``` 450 | 451 | We can now send json requests to our endpoint using ngrok. The function `transcribe_audio` sends an audio file to our endpoint 452 | and returns the transcription: 453 | 454 | ```python 455 | from gradio_client import Client 456 | 457 | # make sure this URL matches your http web address 458 | API_URL = "http://DEVICE-IP:7860/" # if using port 7860 459 | API_URL = "https://NGROK-ADDRESS.ngrok.io/" # if using ngrok 460 | 461 | # set up the Gradio client 462 | client = Client(API_URL) 463 | 464 | def transcribe_audio(audio_path, task="transcribe", return_timestamps=False): 465 | """Function to transcribe an audio file using our endpoint""" 466 | text, runtime = client.predict( 467 | audio_path, 468 | task, 469 | return_timestamps, 470 | api_name="/predict_1", 471 | ) 472 | return text 473 | 474 | # transcribe an audio file using our endpoint 475 | output = transcribe_audio("audio.mp3") 476 | 477 | # transcribe with timestamps 478 | output_with_timestamps = transcribe_audio("audio.mp3", return_timestamps=True) 479 | ``` 480 | 481 | ## Acknowledgements 482 | 483 | * 🤗 Hugging Face Transformers for the base Whisper implementation, particularly to [andyehrenberg](https://github.com/andyehrenberg) for the [Flax Whisper PR](https://github.com/huggingface/transformers/pull/20479) and [ArthurZucker](https://github.com/ArthurZucker) for the batching algorithm 484 | * Gradio for their easy-to-use package for building ML demos, and [pcuenca](https://github.com/pcuenca) for the help in hooking the demo up to the TPU 485 | * Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) programme for Cloud TPUs 486 | * Google's [t5x Repository](https://github.com/google-research/t5x) for the model partitioning framework 487 | 488 | [^1]: See WER results from Colab: https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor?usp=sharing 489 | -------------------------------------------------------------------------------- /app/app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | import tempfile 5 | import time 6 | from multiprocessing import Pool 7 | 8 | import gradio as gr 9 | import jax.numpy as jnp 10 | import numpy as np 11 | import yt_dlp as youtube_dl 12 | from jax.experimental.compilation_cache import compilation_cache as cc 13 | from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE 14 | from transformers.pipelines.audio_utils import ffmpeg_read 15 | 16 | 
from whisper_jax import FlaxWhisperPipline 17 | 18 | 19 | cc.initialize_cache("./jax_cache") 20 | checkpoint = "openai/whisper-large-v3" 21 | 22 | BATCH_SIZE = 32 23 | CHUNK_LENGTH_S = 30 24 | NUM_PROC = 32 25 | FILE_LIMIT_MB = 1000 26 | YT_LENGTH_LIMIT_S = 7200 # limit to 2 hour YouTube files 27 | 28 | title = "Whisper JAX: The Fastest Whisper API ⚡️" 29 | 30 | description = """Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v3) by OpenAI. It runs on JAX with a TPU v4-8 in the backend. Compared to PyTorch on an A100 GPU, it is over [**70x faster**](https://github.com/sanchit-gandhi/whisper-jax#benchmarks), making it the fastest Whisper API available. 31 | 32 | Note that at peak times, you may find yourself in the queue for this demo. When you submit a request, your queue position will be shown in the top right-hand side of the demo pane. Once you reach the front of the queue, your audio file will be transcribed, with the progress displayed through a progress bar. 33 | 34 | To skip the queue, you may wish to create your own inference endpoint, details for which can be found in the [Whisper JAX repository](https://github.com/sanchit-gandhi/whisper-jax#creating-an-endpoint). 35 | """ 36 | 37 | article = "Whisper large-v3 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face." 38 | 39 | language_names = sorted(TO_LANGUAGE_CODE.keys()) 40 | 41 | logger = logging.getLogger("whisper-jax-app") 42 | logger.setLevel(logging.INFO) 43 | ch = logging.StreamHandler() 44 | ch.setLevel(logging.INFO) 45 | formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S") 46 | ch.setFormatter(formatter) 47 | logger.addHandler(ch) 48 | 49 | 50 | def identity(batch): 51 | return batch 52 | 53 | 54 | # Copied from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50 55 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."): 56 | if seconds is not None: 57 | milliseconds = round(seconds * 1000.0) 58 | 59 | hours = milliseconds // 3_600_000 60 | milliseconds -= hours * 3_600_000 61 | 62 | minutes = milliseconds // 60_000 63 | milliseconds -= minutes * 60_000 64 | 65 | seconds = milliseconds // 1_000 66 | milliseconds -= seconds * 1_000 67 | 68 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 69 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 70 | else: 71 | # we have a malformed timestamp so just return it as is 72 | return seconds 73 | 74 | 75 | if __name__ == "__main__": 76 | pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE) 77 | stride_length_s = CHUNK_LENGTH_S / 6 78 | chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate) 79 | stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate) 80 | step = chunk_len - stride_left - stride_right 81 | pool = Pool(NUM_PROC) 82 | 83 | # do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time 84 | logger.info("compiling forward call...") 85 | start = time.time() 86 | random_inputs = { 87 | "input_features": np.ones( 88 | (BATCH_SIZE, pipeline.model.config.num_mel_bins, 2 * 
pipeline.model.config.max_source_positions) 89 | ) 90 | } 91 | random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True) 92 | compile_time = time.time() - start 93 | logger.info(f"compiled in {compile_time}s") 94 | 95 | def tqdm_generate(inputs: dict, task: str, return_timestamps: bool, progress: gr.Progress): 96 | inputs_len = inputs["array"].shape[0] 97 | all_chunk_start_idx = np.arange(0, inputs_len, step) 98 | num_samples = len(all_chunk_start_idx) 99 | num_batches = math.ceil(num_samples / BATCH_SIZE) 100 | dummy_batches = list( 101 | range(num_batches) 102 | ) # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841 103 | 104 | dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE) 105 | progress(0, desc="Pre-processing audio file...") 106 | logger.info("pre-processing audio file...") 107 | dataloader = pool.map(identity, dataloader) 108 | logger.info("done post-processing") 109 | 110 | model_outputs = [] 111 | start_time = time.time() 112 | logger.info("transcribing...") 113 | # iterate over our chunked audio samples - always predict timestamps to reduce hallucinations 114 | for batch, _ in zip(dataloader, progress.tqdm(dummy_batches, desc="Transcribing...")): 115 | model_outputs.append(pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True)) 116 | runtime = time.time() - start_time 117 | logger.info("done transcription") 118 | 119 | logger.info("post-processing...") 120 | post_processed = pipeline.postprocess(model_outputs, return_timestamps=True) 121 | text = post_processed["text"] 122 | if return_timestamps: 123 | timestamps = post_processed.get("chunks") 124 | timestamps = [ 125 | f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}" 126 | for chunk in timestamps 127 | ] 128 | text = "\n".join(str(feature) for feature in timestamps) 129 | logger.info("done post-processing") 130 | return text, runtime 131 | 132 | def transcribe_chunked_audio(inputs, task, return_timestamps, progress=gr.Progress()): 133 | progress(0, desc="Loading audio file...") 134 | logger.info("loading audio file...") 135 | if inputs is None: 136 | logger.warning("No audio file") 137 | raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.") 138 | file_size_mb = os.stat(inputs).st_size / (1024 * 1024) 139 | if file_size_mb > FILE_LIMIT_MB: 140 | logger.warning("Max file size exceeded") 141 | raise gr.Error( 142 | f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB." 143 | ) 144 | 145 | with open(inputs, "rb") as f: 146 | inputs = f.read() 147 | 148 | inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate) 149 | inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate} 150 | logger.info("done loading") 151 | text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress) 152 | return text, runtime 153 | 154 | def _return_yt_html_embed(yt_url): 155 | video_id = yt_url.split("?v=")[-1] 156 | HTML_str = ( 157 | f'
<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>' 158 | " </center>
" 159 | ) 160 | return HTML_str 161 | 162 | def download_yt_audio(yt_url, filename): 163 | info_loader = youtube_dl.YoutubeDL() 164 | try: 165 | info = info_loader.extract_info(yt_url, download=False) 166 | except youtube_dl.utils.DownloadError as err: 167 | raise gr.Error(str(err)) 168 | 169 | file_length = info["duration_string"] 170 | file_h_m_s = file_length.split(":") 171 | file_h_m_s = [int(sub_length) for sub_length in file_h_m_s] 172 | if len(file_h_m_s) == 1: 173 | file_h_m_s.insert(0, 0) 174 | if len(file_h_m_s) == 2: 175 | file_h_m_s.insert(0, 0) 176 | 177 | file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2] 178 | if file_length_s > YT_LENGTH_LIMIT_S: 179 | yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S)) 180 | file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s)) 181 | raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.") 182 | 183 | ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"} 184 | with youtube_dl.YoutubeDL(ydl_opts) as ydl: 185 | try: 186 | ydl.download([yt_url]) 187 | except youtube_dl.utils.ExtractorError as err: 188 | raise gr.Error(str(err)) 189 | 190 | def transcribe_youtube(yt_url, task, return_timestamps, progress=gr.Progress()): 191 | progress(0, desc="Loading audio file...") 192 | logger.info("loading youtube file...") 193 | html_embed_str = _return_yt_html_embed(yt_url) 194 | with tempfile.TemporaryDirectory() as tmpdirname: 195 | filepath = os.path.join(tmpdirname, "video.mp4") 196 | download_yt_audio(yt_url, filepath) 197 | 198 | with open(filepath, "rb") as f: 199 | inputs = f.read() 200 | 201 | inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate) 202 | inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate} 203 | logger.info("done loading...") 204 | text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress) 205 | return html_embed_str, text, runtime 206 | 207 | microphone_chunked = gr.Interface( 208 | fn=transcribe_chunked_audio, 209 | inputs=[ 210 | gr.Audio(source="microphone", type="filepath"), 211 | gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"), 212 | gr.Checkbox(value=False, label="Return timestamps"), 213 | ], 214 | outputs=[ 215 | gr.Textbox(label="Transcription", show_copy_button=True), 216 | gr.Textbox(label="Transcription Time (s)"), 217 | ], 218 | allow_flagging="never", 219 | title=title, 220 | description=description, 221 | article=article, 222 | ) 223 | 224 | audio_chunked = gr.Interface( 225 | fn=transcribe_chunked_audio, 226 | inputs=[ 227 | gr.Audio(source="upload", label="Audio file", type="filepath"), 228 | gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"), 229 | gr.Checkbox(value=False, label="Return timestamps"), 230 | ], 231 | outputs=[ 232 | gr.Textbox(label="Transcription", show_copy_button=True), 233 | gr.Textbox(label="Transcription Time (s)"), 234 | ], 235 | allow_flagging="never", 236 | title=title, 237 | description=description, 238 | article=article, 239 | ) 240 | 241 | youtube = gr.Interface( 242 | fn=transcribe_youtube, 243 | inputs=[ 244 | gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"), 245 | gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"), 246 | gr.Checkbox(value=False, label="Return timestamps"), 247 | ], 248 | outputs=[ 249 | 
gr.HTML(label="Video"), 250 | gr.Textbox(label="Transcription", show_copy_button=True), 251 | gr.Textbox(label="Transcription Time (s)"), 252 | ], 253 | allow_flagging="never", 254 | title=title, 255 | examples=[["https://www.youtube.com/watch?v=m8u-18Q0s7I", "transcribe", False]], 256 | cache_examples=False, 257 | description=description, 258 | article=article, 259 | ) 260 | 261 | demo = gr.Blocks() 262 | 263 | with demo: 264 | gr.TabbedInterface([microphone_chunked, audio_chunked, youtube], ["Microphone", "Audio File", "YouTube"]) 265 | 266 | demo.queue(max_size=5) 267 | demo.launch(server_name="0.0.0.0", show_api=False) 268 | -------------------------------------------------------------------------------- /app/monitor.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | waiting=0 4 | check_server() { 5 | url=http://localhost:7860 6 | response_code=$(curl -o /dev/null -s -w "%{http_code}" --connect-timeout 2 $url) 7 | [[ $response_code -ne 200 ]] && { 8 | return 0 9 | } 10 | return 1 11 | } 12 | 13 | while [ 1 ] 14 | do 15 | # periodically clear the /tmp directory for files created > 30 mins ago so it doesn't fill up 16 | sudo find /tmp -type f -amin +30 -delete 17 | check_server 18 | if [[ $? -ne 1 ]] 19 | then 20 | if [[ $waiting -eq 0 ]] 21 | then 22 | waiting=1 23 | echo "Restarting" 24 | pkill -9 python 25 | #sudo lsof -t /dev/accel0 | xargs kill -9 26 | sleep 5 27 | mv log.txt log_`date +%Y%m%d%H%M%S` 28 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python ./app.py &> log.txt & 29 | else 30 | echo "Waiting for restart" 31 | fi 32 | else 33 | if [[ $waiting -eq 1 ]] 34 | then 35 | waiting=0 36 | echo "Restarted" 37 | fi 38 | fi 39 | sleep 10 40 | done 41 | -------------------------------------------------------------------------------- /app/run_app.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Sets up the environment, suitable to be run from a cron job, i.e.: 4 | # @reboot /home/sanchitgandhi/whisper-jax/run_app.sh 5 | cd ~/whisper-jax/app 6 | source ~/hf/bin/activate 7 | chmod +x monitor.sh 8 | ./monitor.sh & 9 | -------------------------------------------------------------------------------- /benchmarks/run_pipeline_dataloader.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import numpy as np 5 | from torch.utils.data import DataLoader 6 | from transformers import WhisperProcessor 7 | from transformers.pipelines.base import no_collate_fn 8 | from transformers.pipelines.pt_utils import PipelineChunkIterator, PipelinePackIterator 9 | 10 | 11 | class Pipeline: 12 | """Relies on the data loaders defined in transformers Pipeline""" 13 | 14 | def __init__(self, checkpoint="openai/whisper-tiny.en"): 15 | self.checkpoint = checkpoint 16 | self.processor = WhisperProcessor.from_pretrained(self.checkpoint) 17 | self.feature_extractor = self.processor.feature_extractor 18 | self.tokenizer = self.processor.tokenizer 19 | 20 | @staticmethod 21 | def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right): 22 | inputs_len = inputs.shape[0] 23 | step = chunk_len - stride_left - stride_right 24 | for chunk_start_idx in range(0, inputs_len, step): 25 | chunk_end_idx = chunk_start_idx + chunk_len 26 | chunk = inputs[chunk_start_idx:chunk_end_idx] 27 | processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="np") 28 | _stride_left = 0 if 
chunk_start_idx == 0 else stride_left 29 | # all right strides must be full, otherwise it is the last item 30 | is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len 31 | _stride_right = 0 if is_last else stride_right 32 | 33 | chunk_len = chunk.shape[0] 34 | stride = (chunk_len, _stride_left, _stride_right) 35 | if chunk.shape[0] > _stride_left: 36 | yield {"is_last": is_last, "stride": stride, **processed} 37 | if is_last: 38 | break 39 | 40 | def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None): 41 | array = inputs.get("array") 42 | in_sampling_rate = inputs.get("sampling_rate") 43 | stride = inputs.get("stride", None) 44 | 45 | if in_sampling_rate != self.feature_extractor.sampling_rate: 46 | ratio = self.feature_extractor.sampling_rate / in_sampling_rate 47 | else: 48 | ratio = 1 49 | 50 | if stride is not None: 51 | if stride[0] + stride[1] > inputs.shape[0]: 52 | raise ValueError("Stride is too large for input") 53 | 54 | # Stride needs to get the chunk length here, it's going to get 55 | # swallowed by the `feature_extractor` later, and then batching 56 | # can add extra data in the inputs, so we need to keep track 57 | # of the original length in the stride so we can cut properly. 58 | stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio))) 59 | 60 | if chunk_length_s: 61 | if stride_length_s is None: 62 | stride_length_s = chunk_length_s / 6 63 | 64 | if isinstance(stride_length_s, (int, float)): 65 | stride_length_s = [stride_length_s, stride_length_s] 66 | 67 | chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate) 68 | stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate) 69 | stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate) 70 | 71 | if chunk_len < stride_left + stride_right: 72 | raise ValueError("Chunk length must be superior to stride length") 73 | 74 | for item in self.chunk_iter( 75 | array, 76 | self.feature_extractor, 77 | chunk_len, 78 | stride_left, 79 | stride_right, 80 | ): 81 | yield item 82 | else: 83 | processed = self.feature_extractor( 84 | array, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np" 85 | ) 86 | if stride is not None: 87 | processed["stride"] = stride 88 | yield {"is_last": True, **processed} 89 | 90 | def forward(self, model_inputs, return_timestamps=False, generate_kwargs=None): 91 | print(model_inputs["stride"]) 92 | if generate_kwargs is None: 93 | generate_kwargs = {} 94 | 95 | if return_timestamps: 96 | generate_kwargs["return_timestamps"] = return_timestamps 97 | is_last = model_inputs.pop("is_last") 98 | 99 | out = {} 100 | stride = model_inputs.pop("stride", None) 101 | if stride is not None: 102 | out["stride"] = stride 103 | 104 | return {"is_last": is_last, **out} 105 | 106 | def __call__( 107 | self, 108 | inputs, 109 | chunk_length_s=0, 110 | stride_length_s=None, 111 | return_timestamps=None, 112 | return_language=None, 113 | generate_kwargs=None, 114 | batch_size=4, 115 | num_workers=1, 116 | ): 117 | dataset = PipelineChunkIterator( 118 | [inputs], self.preprocess, {"chunk_length_s": chunk_length_s, "stride_length_s": stride_length_s} 119 | ) 120 | collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn() 121 | dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn) 122 | model_iterator = PipelinePackIterator(dataloader, self.forward, {}, loader_batch_size=batch_size) 123 | 124 | for batch in 
enumerate(model_iterator): 125 | continue 126 | 127 | 128 | def _pad(items, key): 129 | if isinstance(items[0][key], np.ndarray): 130 | if key == "input_features": 131 | # this is probably a mel spectrogram batched 132 | return np.concatenate([item[key] for item in items], axis=0) 133 | else: 134 | return [item[key] for item in items] 135 | 136 | 137 | def pad_collate_fn(): 138 | def inner(items): 139 | keys = set(items[0].keys()) 140 | for item in items: 141 | if set(item.keys()) != keys: 142 | raise ValueError( 143 | f"The elements of the batch contain different keys. Cannot batch them ({set(item.keys())} !=" 144 | f" {keys})" 145 | ) 146 | # input_values, input_pixels, input_ids, ... 147 | padded = {} 148 | for key in keys: 149 | padded[key] = _pad(items, key) 150 | return padded 151 | 152 | return inner 153 | 154 | 155 | class ManualIterator: 156 | """Manual implementation""" 157 | 158 | def __init__(self, checkpoint="openai/whisper-tiny.en"): 159 | self.checkpoint = checkpoint 160 | self.processor = WhisperProcessor.from_pretrained(self.checkpoint) 161 | self.feature_extractor = self.processor.feature_extractor 162 | self.tokenizer = self.processor.tokenizer 163 | 164 | @staticmethod 165 | def chunk_iter_with_batch(inputs, feature_extractor, chunk_len, stride_left, stride_right, batch_size): 166 | inputs_len = inputs.shape[0] 167 | step = chunk_len - stride_left - stride_right 168 | 169 | all_chunk_start_idx = np.arange(0, inputs_len, step) 170 | num_samples = len(all_chunk_start_idx) 171 | 172 | num_batches = math.ceil(num_samples / batch_size) 173 | batch_idx = np.array_split(np.arange(num_samples), num_batches) 174 | 175 | for i, idx in enumerate(batch_idx): 176 | chunk_start_idx = all_chunk_start_idx[idx] 177 | 178 | chunk_end_idx = chunk_start_idx + chunk_len 179 | 180 | chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)] 181 | processed = feature_extractor(chunks, sampling_rate=feature_extractor.sampling_rate, return_tensors="np") 182 | 183 | _stride_left = np.where(chunk_start_idx == 0, 0, stride_left) 184 | is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len) 185 | _stride_right = np.where(is_last, 0, stride_right) 186 | 187 | chunk_lens = [chunk.shape[0] for chunk in chunks] 188 | strides = [ 189 | (chunk_l, _stride_l, _stride_r) 190 | for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right) 191 | ] 192 | 193 | yield {"stride": strides, **processed} 194 | 195 | def preprocess_batch(self, inputs, chunk_length_s=0, stride_length_s=None, batch_size=None): 196 | array = inputs.get("array") 197 | in_sampling_rate = inputs.get("sampling_rate") 198 | stride = inputs.get("stride", None) 199 | 200 | if in_sampling_rate != self.feature_extractor.sampling_rate: 201 | ratio = self.feature_extractor.sampling_rate / in_sampling_rate 202 | else: 203 | ratio = 1 204 | 205 | if stride is not None: 206 | if stride[0] + stride[1] > inputs.shape[0]: 207 | raise ValueError("Stride is too large for input") 208 | 209 | # Stride needs to get the chunk length here, it's going to get 210 | # swallowed by the `feature_extractor` later, and then batching 211 | # can add extra data in the inputs, so we need to keep track 212 | # of the original length in the stride so we can cut properly. 
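# the resulting tuple is (original_length_in_samples, left_stride, right_stride); the left/right strides are
# rescaled from the input sampling rate to the feature extractor's sampling rate so that the overlapping
# regions can be trimmed correctly when the chunk transcriptions are stitched back together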
213 | stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio))) 214 | 215 | if chunk_length_s: 216 | if stride_length_s is None: 217 | stride_length_s = chunk_length_s / 6 218 | 219 | if isinstance(stride_length_s, (int, float)): 220 | stride_length_s = [stride_length_s, stride_length_s] 221 | 222 | chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate) 223 | stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate) 224 | stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate) 225 | 226 | if chunk_len < stride_left + stride_right: 227 | raise ValueError("Chunk length must be superior to stride length") 228 | 229 | for item in self.chunk_iter_with_batch( 230 | array, 231 | self.feature_extractor, 232 | chunk_len, 233 | stride_left, 234 | stride_right, 235 | batch_size, 236 | ): 237 | yield item 238 | else: 239 | processed = self.feature_extractor( 240 | array, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np" 241 | ) 242 | if stride is not None: 243 | processed["stride"] = stride 244 | yield processed 245 | 246 | def forward(self, model_inputs, return_timestamps=False, generate_kwargs=None): 247 | if generate_kwargs is None: 248 | generate_kwargs = {} 249 | 250 | if return_timestamps: 251 | generate_kwargs["return_timestamps"] = return_timestamps 252 | 253 | out = {} 254 | stride = model_inputs.pop("stride", None) 255 | if stride is not None: 256 | out["stride"] = stride 257 | 258 | def __call__( 259 | self, 260 | inputs, 261 | chunk_length_s=0, 262 | stride_length_s=None, 263 | return_timestamps=None, 264 | return_language=None, 265 | generate_kwargs=None, 266 | batch_size=4, 267 | num_workers=1, 268 | ): 269 | dataloader = self.preprocess_batch(inputs, chunk_length_s, stride_length_s, batch_size) 270 | 271 | for batch in dataloader: 272 | print(batch["stride"]) 273 | self.forward(batch) 274 | 275 | 276 | pipeline = Pipeline() 277 | custom_pipeline = ManualIterator() 278 | 279 | powers = np.logspace(0, 6, 7, base=2) 280 | input_lengths = [30, 60, 120, 240] 281 | 282 | for input_len in input_lengths: 283 | print(f"=== Input len {input_len} ===") 284 | inputs = {"array": np.ones(16000 * (int(input_len))), "sampling_rate": 16000} 285 | 286 | # first run our custom one since it doesn't consume the audio input 287 | start = time.time() 288 | custom_pipeline(inputs, chunk_length_s=30) 289 | runtime = time.time() - start 290 | print(f"Manual: {runtime}") 291 | 292 | # now run the transformers based one 293 | start = time.time() 294 | pipeline(inputs, chunk_length_s=30) 295 | runtime = time.time() - start 296 | print(f"Pipeline: {runtime}") 297 | 298 | 299 | """ 300 | Results: 301 | === Input len 30.0 === 302 | Manual: 0.09786629676818848 303 | Pipeline: 0.4056670665740967 304 | 305 | === Input len 60.0 === 306 | Manual: 0.14911437034606934 307 | Pipeline: 0.6139914989471436 308 | 309 | === Input len 120.0 === 310 | Manual: 0.29366159439086914 311 | Pipeline: 0.9421713352203369 312 | 313 | === Input len 240.0 === 314 | Manual: 0.5912315845489502 315 | Pipeline: 1.2646639347076416 316 | 317 | === Input len 480.0 === 318 | Manual: 1.1709322929382324 319 | Pipeline: 1.6115500926971436 320 | 321 | === Input len 960.0 === 322 | Manual: 2.373847246170044 323 | Pipeline: 3.2664968967437744 324 | 325 | === Input len 1920.0 === 326 | Manual: 4.619845151901245 327 | Pipeline: 5.50755500793457 328 | """ 329 | -------------------------------------------------------------------------------- 
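The timing gap above comes largely from how the chunks are produced: the manual iterator computes every overlapping chunk boundary up front with vectorised numpy and feeds whole batches to the feature extractor, whereas the Pipeline path pushes chunks one at a time through a torch DataLoader. Below is a minimal, self-contained sketch of that boundary arithmetic (toy values assumed: 30 s chunks, the default chunk_length_s / 6 strides, and a hypothetical 75 s input at 16 kHz):

import numpy as np

sampling_rate = 16_000
chunk_len = 30 * sampling_rate                   # 30 s chunks
stride_left = stride_right = 5 * sampling_rate   # chunk_length_s / 6 of context on each side
inputs_len = 75 * sampling_rate                  # hypothetical 75 s input

# consecutive chunks overlap by stride_left + stride_right samples
step = chunk_len - stride_left - stride_right
chunk_start_idx = np.arange(0, inputs_len, step)
chunk_end_idx = chunk_start_idx + chunk_len

# the first chunk has no left context; the final chunk(s) have no right context
_stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
_stride_right = np.where(is_last, 0, stride_right)

for start, end, left, right in zip(chunk_start_idx, chunk_end_idx, _stride_left, _stride_right):
    print(f"chunk [{start:>8d}, {end:>8d})  stride=({left}, {right})")

The stride tuples are carried through alongside the features so that the overlapping regions can be cut away when the per-chunk transcriptions are merged, as noted in the preprocessing comments above.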
/benchmarks/run_pjit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import datasets 5 | import jax 6 | import jax.numpy as jnp 7 | from datasets import concatenate_datasets, load_dataset 8 | from flax.core.frozen_dict import freeze 9 | from jax.experimental.compilation_cache import compilation_cache as cc 10 | from jax.sharding import PartitionSpec as P 11 | from transformers import WhisperConfig, WhisperProcessor 12 | 13 | from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner 14 | 15 | 16 | datasets.logging.set_verbosity(datasets.logging.CRITICAL) 17 | 18 | cc.initialize_cache("./jax_cache") 19 | jax.config.update("jax_array", True) 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description="Benchmark Whisper large-v2") 24 | parser.add_argument( 25 | "--model_parallel_submesh", 26 | type=int, 27 | nargs="+", 28 | default=(2, 2, 1, 1), 29 | help="Model parallel submesh.", 30 | ) 31 | args = parser.parse_args() 32 | return args 33 | 34 | 35 | BATCH_SIZES = [4, 8, 16, 32] 36 | NUM_BATCHES = 100 37 | NUM_TOKENS = 25 38 | CHECKPOINT = "large-v2" 39 | 40 | # 2D parameter and activation partitioning for DP 41 | logical_axis_rules_dp = [ 42 | ("batch", "data"), 43 | ("mlp", None), 44 | ("heads", None), 45 | ("vocab", None), 46 | ("embed", None), 47 | ("embed", None), 48 | ("joined_kv", None), 49 | ("kv", None), 50 | ("length", None), 51 | ("num_mel", None), 52 | ("channels", None), 53 | ] 54 | 55 | 56 | def main(): 57 | args = parse_args() 58 | print(args.model_parallel_submesh) 59 | # processors/tokenizers are the same for all models, so just load from tiny and preprocess once 60 | processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") 61 | 62 | def preprocess(batch): 63 | batch["input_features"] = processor( 64 | batch["audio"]["array"], sampling_rate=16000, return_tensors="np" 65 | ).input_features[0] 66 | return batch 67 | 68 | librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") 69 | dataset_processed = librispeech.map(preprocess, remove_columns=librispeech.column_names) 70 | 71 | config = WhisperConfig.from_pretrained(f"openai/whisper-{CHECKPOINT}") 72 | model = FlaxWhisperForConditionalGeneration(config, _do_init=False, dtype=jnp.bfloat16) 73 | # to init the params 74 | params = model.init_weights(model.key, model.input_shape) 75 | 76 | def init_fn(): 77 | input_shape = (1, 80, 3000) 78 | 79 | input_features = jnp.zeros(input_shape, dtype="f4") 80 | input_features = input_features.at[(..., -1)].set(model.config.eos_token_id) 81 | 82 | decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") 83 | decoder_attention_mask = jnp.ones_like(decoder_input_ids) 84 | 85 | batch_size, sequence_length = decoder_input_ids.shape 86 | decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) 87 | 88 | rng = jax.random.PRNGKey(0) 89 | init_params = model.module.init( 90 | rng, 91 | input_features=input_features, 92 | decoder_input_ids=decoder_input_ids, 93 | decoder_attention_mask=decoder_attention_mask, 94 | decoder_position_ids=decoder_position_ids, 95 | return_dict=False, 96 | ) 97 | return init_params 98 | 99 | # Axis names metadata 100 | param_axes = jax.eval_shape(init_fn)["params_axes"] 101 | 102 | # Create InferenceState, since the partitioner expects it 103 | state = InferenceState( 104 | step=jnp.array(0), 105 | 
params=freeze(model.params_shape_tree), 106 | params_axes=freeze(param_axes), 107 | flax_mutables=None, 108 | flax_mutables_axes=param_axes, 109 | ) 110 | 111 | partitioner = PjitPartitioner( 112 | num_partitions=1, 113 | logical_axis_rules=logical_axis_rules_dp, 114 | ) 115 | 116 | mesh_axes = partitioner.get_mesh_axes(state) 117 | params_spec = mesh_axes.params 118 | 119 | p_shard_params = partitioner.partition(model.to_bf16, (params_spec,), params_spec) 120 | 121 | def generate(params, input_features): 122 | output_ids = model.generate(input_features, params=params, max_new_tokens=25).sequences 123 | return output_ids 124 | 125 | p_generate = partitioner.partition( 126 | generate, 127 | in_axis_resources=(params_spec, P("data")), 128 | out_axis_resources=P("data"), 129 | ) 130 | 131 | # This will auto-magically run in mesh context 132 | params = p_shard_params(freeze(params)) 133 | 134 | for batch_size in BATCH_SIZES: 135 | eval_dataset = dataset_processed.select(range(batch_size // 2)) 136 | eval_dataset = concatenate_datasets([eval_dataset for _ in range(2 * NUM_BATCHES)]) 137 | 138 | eval_dataloader = eval_dataset.with_format("numpy").iter(batch_size=batch_size) 139 | 140 | # warm-up step 141 | batch = next(iter(eval_dataloader)) 142 | p_generate(freeze(params), batch["input_features"]) 143 | 144 | start = time.time() 145 | for batch in eval_dataloader: 146 | p_generate(freeze(params), batch["input_features"]) 147 | runtime = time.time() - start 148 | 149 | print(f"{batch_size}: {runtime:.06}") 150 | 151 | 152 | if __name__ == "__main__": 153 | main() 154 | -------------------------------------------------------------------------------- /benchmarks/run_pjit_dataloader.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | from datasets import load_dataset 7 | from flax.core.frozen_dict import freeze 8 | from jax.experimental.compilation_cache import compilation_cache as cc 9 | from jax.sharding import PartitionSpec as P 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from transformers import WhisperProcessor 13 | 14 | from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner 15 | 16 | 17 | cc.initialize_cache("./jax_cache") 18 | jax.config.update("jax_array", True) 19 | 20 | BATCH_SIZE = 128 21 | NUM_TOKENS = 256 22 | NUM_BATCHES = 5 23 | NUM_WORKERS = 4 24 | CHECKPOINT = "large-v2" 25 | STREAMING = True 26 | 27 | # 2D parameter and activation partitioning for DP 28 | logical_axis_rules_dp = [ 29 | ("batch", "data"), 30 | ("mlp", None), 31 | ("heads", None), 32 | ("vocab", None), 33 | ("embed", None), 34 | ("embed", None), 35 | ("joined_kv", None), 36 | ("kv", None), 37 | ("length", None), 38 | ("num_mel", None), 39 | ("channels", None), 40 | ] 41 | 42 | model, params = FlaxWhisperForConditionalGeneration.from_pretrained( 43 | f"openai/whisper-{CHECKPOINT}", 44 | _do_init=False, 45 | dtype=jnp.bfloat16, 46 | ) 47 | 48 | 49 | def init_fn(): 50 | input_shape = (1, 80, 3000) 51 | 52 | input_features = jnp.zeros(input_shape, dtype="f4") 53 | input_features = input_features.at[(..., -1)].set(model.config.eos_token_id) 54 | 55 | decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") 56 | decoder_attention_mask = jnp.ones_like(decoder_input_ids) 57 | 58 | batch_size, sequence_length = decoder_input_ids.shape 59 | decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, 
sequence_length)) 60 | 61 | rng = jax.random.PRNGKey(0) 62 | init_params = model.module.init( 63 | rng, 64 | input_features=input_features, 65 | decoder_input_ids=decoder_input_ids, 66 | decoder_attention_mask=decoder_attention_mask, 67 | decoder_position_ids=decoder_position_ids, 68 | return_dict=False, 69 | ) 70 | return init_params 71 | 72 | 73 | # Axis names metadata 74 | param_axes = jax.eval_shape(init_fn)["params_axes"] 75 | 76 | # Create InferenceState, since the partitioner expects it 77 | state = InferenceState( 78 | step=jnp.array(0), 79 | params=freeze(model.params_shape_tree), 80 | params_axes=freeze(param_axes), 81 | flax_mutables=None, 82 | flax_mutables_axes=param_axes, 83 | ) 84 | 85 | partitioner = PjitPartitioner( 86 | num_partitions=1, 87 | logical_axis_rules=logical_axis_rules_dp, 88 | ) 89 | 90 | mesh_axes = partitioner.get_mesh_axes(state) 91 | params_spec = mesh_axes.params 92 | 93 | p_shard_params = partitioner.partition(model.to_bf16, (params_spec,), params_spec) 94 | 95 | 96 | def generate(params, input_features): 97 | output_ids = model.generate(input_features, params=params, max_length=NUM_TOKENS).sequences 98 | return output_ids 99 | 100 | 101 | p_generate = partitioner.partition( 102 | generate, 103 | in_axis_resources=(params_spec, P("data")), 104 | out_axis_resources=P("data"), 105 | ) 106 | 107 | # This will auto-magically run in mesh context 108 | params = p_shard_params(freeze(params)) 109 | 110 | # warm-up 111 | pred_ids = p_generate(freeze(params), np.ones((BATCH_SIZE, 80, 3000))) 112 | 113 | # processors/tokenizers are the same for all models, so just load from tiny and preprocess once 114 | processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2") 115 | 116 | 117 | def preprocess(batch): 118 | batch["input_features"] = processor( 119 | batch["audio"]["array"], sampling_rate=16000, return_tensors="np" 120 | ).input_features[0] 121 | return batch 122 | 123 | 124 | librispeech = load_dataset("speechcolab/gigaspeech", "l", split="train", streaming=STREAMING, use_auth_token=True) 125 | librispeech_features = librispeech.features.keys() 126 | 127 | librispeech_processed = librispeech.map(preprocess, remove_columns=librispeech_features) 128 | 129 | eval_dataloader = DataLoader( 130 | librispeech_processed, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, persistent_workers=True 131 | ) 132 | 133 | all_load_times = 0 134 | all_runtimes = 0 135 | 136 | load_start = time.time() 137 | for i, batch in tqdm(enumerate(eval_dataloader)): 138 | input_features = np.asarray(batch.pop("input_features")) 139 | load_time = time.time() - load_start 140 | all_load_times += load_time 141 | 142 | input_batch_size = input_features.shape[0] 143 | 144 | if input_batch_size != BATCH_SIZE: 145 | padding = np.zeros([BATCH_SIZE - input_batch_size, *input_features.shape[1:]], input_features.dtype) 146 | input_features = np.concatenate([input_features, padding]) 147 | 148 | generate_start = time.time() 149 | pred_ids = p_generate(freeze(params), input_features)[:input_batch_size] 150 | generate_runtime = time.time() - generate_start 151 | all_runtimes += generate_runtime 152 | 153 | pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True) 154 | 155 | if i + 1 == NUM_BATCHES: 156 | break 157 | 158 | print(all_load_times) 159 | print(all_runtimes) 160 | -------------------------------------------------------------------------------- /benchmarks/run_pmap.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 
import jax 4 | import jax.numpy as jnp 5 | from datasets import concatenate_datasets, load_dataset 6 | from flax import jax_utils 7 | from flax.training.common_utils import shard 8 | from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor 9 | 10 | 11 | BATCH_SIZES = [4, 8, 16, 32, 64, 128] 12 | NUM_BATCHES = 100 13 | NUM_TOKENS = 25 14 | 15 | model, params = FlaxWhisperForConditionalGeneration.from_pretrained( 16 | "openai/whisper-large-v2", 17 | _do_init=False, 18 | dtype=jnp.bfloat16, 19 | ) 20 | 21 | params = jax_utils.replicate(params) 22 | 23 | 24 | def generate_fn(batch): 25 | pred_ids = model.generate(batch, params=params, max_new_tokens=NUM_TOKENS, min_new_tokens=NUM_TOKENS) 26 | return pred_ids.sequences 27 | 28 | 29 | p_generate_fn = jax.pmap(generate_fn, "batch") 30 | 31 | # processors/tokenizers are the same for all models, so just load from tiny and preprocess once 32 | processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") 33 | 34 | 35 | def preprocess(batch): 36 | batch["input_features"] = processor( 37 | batch["audio"]["array"], sampling_rate=16000, return_tensors="np" 38 | ).input_features[0] 39 | return batch 40 | 41 | 42 | # load a dataset of 73 audio samples 43 | librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") 44 | dataset_processed = librispeech.map(preprocess, remove_columns=librispeech.column_names) 45 | 46 | for batch_size in BATCH_SIZES: 47 | eval_dataset = dataset_processed.select(range(batch_size // 2)) 48 | eval_dataset = concatenate_datasets([eval_dataset for _ in range(2 * NUM_BATCHES)]) 49 | 50 | eval_dataloader = eval_dataset.with_format("numpy").iter(batch_size=batch_size) 51 | 52 | # warm-up step 53 | batch = next(iter(eval_dataloader)) 54 | input_features = shard(batch["input_features"]) 55 | pred_ids = p_generate_fn(input_features) 56 | 57 | start = time.time() 58 | for batch in eval_dataloader: 59 | input_features = shard(batch["input_features"]) 60 | pred_ids = p_generate_fn(input_features) 61 | runtime = time.time() - start 62 | 63 | print(f"{batch_size}: {runtime:.06}") 64 | -------------------------------------------------------------------------------- /benchmarks/run_pytorch.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from datasets import concatenate_datasets, load_dataset 4 | from torch.utils.data import DataLoader 5 | from transformers import WhisperForConditionalGeneration, WhisperProcessor 6 | 7 | 8 | BATCH_SIZES = [4, 8, 16, 32, 64, 128] 9 | NUM_BATCHES = 100 10 | NUM_TOKENS = 25 11 | 12 | model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2") 13 | model.to("cuda").half() 14 | 15 | # processors/tokenizers are the same for all models, so just load from tiny and preprocess once 16 | processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") 17 | 18 | 19 | def preprocess(batch): 20 | batch["input_features"] = processor( 21 | batch["audio"]["array"], sampling_rate=16000, return_tensors="pt" 22 | ).input_features[0] 23 | return batch 24 | 25 | 26 | # load a dataset of 73 audio samples 27 | librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") 28 | dataset_processed = librispeech.map(preprocess, remove_columns=librispeech.column_names) 29 | 30 | for batch_size in BATCH_SIZES: 31 | eval_dataset = dataset_processed.select(range(batch_size // 2)) 32 | eval_dataset = concatenate_datasets([eval_dataset for _ in 
range(2 * NUM_BATCHES)]) 33 | 34 | dataloader = DataLoader( 35 | dataset=eval_dataset.with_format("torch"), batch_size=batch_size, num_workers=4, pin_memory=True 36 | ) 37 | 38 | # generate 39 | start = time.time() 40 | for batch in dataloader: 41 | input_features = batch["input_features"].to("cuda").half() 42 | out = model.generate(input_features, max_new_tokens=NUM_TOKENS, min_new_tokens=NUM_TOKENS) 43 | runtime = time.time() - start 44 | print(f"{batch_size}: {runtime:.06}") 45 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | target-version = ['py37'] 4 | 5 | [tool.ruff] 6 | # Never enforce `E501` (line length violations). 7 | ignore = ["C901", "E501", "E741", "W605"] 8 | select = ["C", "E", "F", "I", "W"] 9 | line-length = 119 10 | 11 | # Ignore import violations in all `__init__.py` files. 12 | [tool.ruff.per-file-ignores] 13 | "__init__.py" = ["E402", "F401", "F403", "F811"] 14 | 15 | [tool.ruff.isort] 16 | lines-after-imports = 2 17 | known-first-party = ["distil_whisper"] -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import os 17 | 18 | # To use a consistent encoding 19 | from codecs import open 20 | 21 | from setuptools import find_packages, setup 22 | 23 | 24 | _deps = [ 25 | "transformers>=4.27.4,<4.35.0", 26 | "flax", 27 | "cached-property", 28 | ] 29 | 30 | _extras_dev_deps = [ 31 | "black~=23.1", 32 | "isort>=5.5.4", 33 | "ruff>=0.0.241,<=0.0.259", 34 | ] 35 | 36 | _extras_endpoint_deps = [ 37 | "gradio>=4.9.1", 38 | "requests>=2.28.2", 39 | "yt-dlp>=2023.3.4", 40 | ] 41 | 42 | 43 | here = os.path.abspath(os.path.dirname(__file__)) 44 | 45 | with open(os.path.join(here, "README.md"), encoding="utf-8") as f: 46 | long_description = f.read() 47 | 48 | # read version 49 | with open(os.path.join(here, "whisper_jax", "__init__.py"), encoding="utf-8") as f: 50 | for line in f: 51 | if line.startswith("__version__"): 52 | version = line.split("=")[1].strip().strip('"') 53 | break 54 | else: 55 | raise RuntimeError("Unable to find version string.") 56 | 57 | setup( 58 | name="whisper_jax", 59 | version=version, 60 | description="JAX implementation of OpenAI's Whisper model.", 61 | long_description=long_description, 62 | long_description_content_type="text/markdown", 63 | packages=find_packages(), 64 | install_requires=_deps, 65 | extras_require={ 66 | "dev": [_extras_dev_deps], 67 | "endpoint": [_extras_endpoint_deps], 68 | }, 69 | ) 70 | -------------------------------------------------------------------------------- /whisper_jax/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | __version__ = "0.0.1" 17 | 18 | from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration 19 | from .partitioner import PjitPartitioner 20 | from .pipeline import FlaxWhisperPipline 21 | from .train_state import InferenceState 22 | -------------------------------------------------------------------------------- /whisper_jax/layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The T5X Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Dense attention classes and mask/weighting functions.""" 17 | 18 | # pylint: disable=attribute-defined-outside-init,g-bare-generic 19 | 20 | import dataclasses 21 | import functools 22 | import operator 23 | from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union 24 | 25 | import jax 26 | import jax.numpy as jnp 27 | import numpy as np 28 | from flax import linen as nn 29 | from flax.linen import partitioning as nn_partitioning 30 | from flax.linen.dtypes import promote_dtype 31 | from jax import lax, random 32 | 33 | 34 | # from flax.linen.partitioning import param_with_axes, with_sharding_constraint 35 | param_with_axes = nn_partitioning.param_with_axes 36 | with_sharding_constraint = nn_partitioning.with_sharding_constraint 37 | 38 | 39 | # Type annotations 40 | Array = jnp.ndarray 41 | DType = jnp.dtype 42 | PRNGKey = jnp.ndarray 43 | Shape = Iterable[int] 44 | Activation = Callable[..., Array] 45 | PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]] 46 | DotGeneralT = Callable[..., Array] 47 | ConvGeneralDilatedT = Callable[..., Array] 48 | PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]] 49 | LaxPadding = Union[str, Sequence[Tuple[int, int]]] 50 | 51 | # Parameter initializers. 52 | Initializer = Callable[[PRNGKey, Shape, DType], Array] 53 | InitializerAxis = Union[int, Tuple[int, ...]] 54 | NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array] 55 | 56 | default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0) 57 | 58 | 59 | # ------------------------------------------------------------------------------ 60 | # Temporary inlined JAX N-d initializer code 61 | # TODO(levskaya): remove once new JAX release is out. 
62 | # ------------------------------------------------------------------------------ 63 | def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1): 64 | """Inlined JAX `nn.initializer._compute_fans`.""" 65 | if isinstance(in_axis, int): 66 | in_size = shape[in_axis] 67 | else: 68 | in_size = int(np.prod([shape[i] for i in in_axis])) 69 | if isinstance(out_axis, int): 70 | out_size = shape[out_axis] 71 | else: 72 | out_size = int(np.prod([shape[i] for i in out_axis])) 73 | receptive_field_size = shape.total / in_size / out_size 74 | fan_in = in_size * receptive_field_size 75 | fan_out = out_size * receptive_field_size 76 | return fan_in, fan_out 77 | 78 | 79 | def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_): 80 | """Inlined JAX `nn.initializer.variance_scaling`.""" 81 | 82 | def init(key, shape, dtype=dtype): 83 | return jnp.zeros(shape, dtype=dtype) 84 | dtype = jax.dtypes.canonicalize_dtype(dtype) 85 | shape = jax.core.as_named_shape(shape) 86 | fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) 87 | if mode == "fan_in": 88 | denominator = fan_in 89 | elif mode == "fan_out": 90 | denominator = fan_out 91 | elif mode == "fan_avg": 92 | denominator = (fan_in + fan_out) / 2 93 | else: 94 | raise ValueError("invalid mode for variance scaling initializer: {}".format(mode)) 95 | variance = jnp.array(scale / denominator, dtype=dtype) 96 | 97 | if distribution == "truncated_normal": 98 | # constant is stddev of standard normal truncated to (-2, 2) 99 | stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype) 100 | return random.truncated_normal(key, -2, 2, shape, dtype) * stddev 101 | elif distribution == "normal": 102 | return random.normal(key, shape, dtype) * jnp.sqrt(variance) 103 | elif distribution == "uniform": 104 | return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance) 105 | else: 106 | raise ValueError("invalid distribution for variance scaling " "initializer: {}".format(distribution)) 107 | 108 | return init 109 | 110 | 111 | # ------------------------------------------------------------------------------ 112 | 113 | 114 | def nd_dense_init(scale, mode, distribution): 115 | """Initializer with in_axis, out_axis set at call time.""" 116 | 117 | def init_fn(key, shape, dtype, in_axis, out_axis): 118 | fn = variance_scaling(scale, mode, distribution, in_axis, out_axis) 119 | return fn(key, shape, dtype) 120 | 121 | return init_fn 122 | 123 | 124 | def dot_product_attention( 125 | query: Array, 126 | key: Array, 127 | value: Array, 128 | bias: Optional[Array] = None, 129 | dropout_rng: Optional[PRNGKey] = None, 130 | dropout_rate: float = 0.0, 131 | deterministic: bool = False, 132 | dtype: DType = jnp.float32, 133 | float32_logits: bool = False, 134 | ): 135 | """Computes dot-product attention given query, key, and value. 136 | 137 | This is the core function for applying attention based on 138 | https://arxiv.org/abs/1706.03762. It calculates the attention weights given 139 | query and key and combines the values using the attention weights. 140 | 141 | Args: 142 | query: queries for calculating attention with shape of `[batch, q_length, 143 | num_heads, qk_depth_per_head]`. 144 | key: keys for calculating attention with shape of `[batch, kv_length, 145 | num_heads, qk_depth_per_head]`. 146 | value: values to be used in attention with shape of `[batch, kv_length, 147 | num_heads, v_depth_per_head]`. 148 | bias: bias for the attention weights. 
This should be broadcastable to the 149 | shape `[batch, num_heads, q_length, kv_length]` This can be used for 150 | incorporating causal masks, padding masks, proximity bias, etc. 151 | dropout_rng: JAX PRNGKey: to be used for dropout 152 | dropout_rate: dropout rate 153 | deterministic: bool, deterministic or not (to apply dropout) 154 | dtype: the dtype of the computation (default: float32) 155 | float32_logits: bool, if True then compute logits in float32 to avoid 156 | numerical issues with bfloat16. 157 | 158 | Returns: 159 | Output of shape `[batch, length, num_heads, v_depth_per_head]`. 160 | """ 161 | assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank." 162 | assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match." 163 | assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match." 164 | assert key.shape[-3] == value.shape[-3], "k, v lengths must match." 165 | assert query.shape[-1] == key.shape[-1], "q, k depths must match." 166 | 167 | # Casting logits and softmax computation for float32 for model stability. 168 | if float32_logits: 169 | query = query.astype(jnp.float32) 170 | key = key.astype(jnp.float32) 171 | 172 | # `attn_weights`: [batch, num_heads, q_length, kv_length] 173 | attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key) 174 | 175 | # Apply attention bias: masking, dropout, proximity bias, etc. 176 | if bias is not None: 177 | attn_weights = attn_weights + bias.astype(attn_weights.dtype) 178 | 179 | # Normalize the attention weights across `kv_length` dimension. 180 | attn_weights = jax.nn.softmax(attn_weights).astype(dtype) 181 | 182 | # Apply attention dropout. 183 | if not deterministic and dropout_rate > 0.0: 184 | keep_prob = 1.0 - dropout_rate 185 | # T5 broadcasts along the "length" dim, but unclear which one that 186 | # corresponds to in positional dimensions here, assuming query dim. 187 | dropout_shape = list(attn_weights.shape) 188 | dropout_shape[-2] = 1 189 | keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) 190 | keep = jnp.broadcast_to(keep, attn_weights.shape) 191 | multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype) 192 | attn_weights = attn_weights * multiplier 193 | 194 | # Take the linear combination of `value`. 195 | return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value) 196 | 197 | 198 | dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) 199 | 200 | 201 | class MultiHeadDotProductAttention(nn.Module): 202 | """Multi-head dot-product attention. 203 | 204 | Attributes: 205 | num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) 206 | should be divisible by the number of heads. 207 | head_dim: dimension of each head. 208 | dtype: the dtype of the computation. 209 | dropout_rate: dropout rate 210 | kernel_init: initializer for the kernel of the Dense layers. 211 | float32_logits: bool, if True then compute logits in float32 to avoid 212 | numerical issues with bfloat16. 213 | """ 214 | 215 | num_heads: int 216 | head_dim: int 217 | dtype: DType = jnp.float32 218 | dropout_rate: float = 0.0 219 | kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal") 220 | float32_logits: bool = False # computes logits in float32 for stability. 
221 | 222 | @nn.compact 223 | def __call__( 224 | self, 225 | inputs_q: Array, 226 | inputs_kv: Array, 227 | mask: Optional[Array] = None, 228 | bias: Optional[Array] = None, 229 | *, 230 | decode: bool = False, 231 | deterministic: bool = False, 232 | ) -> Array: 233 | """Applies multi-head dot product attention on the input data. 234 | 235 | Projects the inputs into multi-headed query, key, and value vectors, 236 | applies dot-product attention and project the results to an output vector. 237 | 238 | There are two modes: decoding and non-decoding (e.g., training). The mode is 239 | determined by `decode` argument. For decoding, this method is called twice, 240 | first to initialize the cache and then for an actual decoding process. The 241 | two calls are differentiated by the presence of 'cached_key' in the variable 242 | dict. In the cache initialization stage, the cache variables are initialized 243 | as zeros and will be filled in the subsequent decoding process. 244 | 245 | In the cache initialization call, `inputs_q` has a shape [batch, length, 246 | q_features] and `inputs_kv`: [batch, length, kv_features]. During the 247 | incremental decoding stage, query, key and value all have the shape [batch, 248 | 1, qkv_features] corresponding to a single step. 249 | 250 | Args: 251 | inputs_q: input queries of shape `[batch, q_length, q_features]`. 252 | inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. 253 | mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. 254 | bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. 255 | decode: Whether to prepare and use an autoregressive cache. 256 | deterministic: Disables dropout if set to True. 257 | 258 | Returns: 259 | output of shape `[batch, length, q_features]`. 260 | """ 261 | projection = functools.partial( 262 | DenseGeneral, 263 | axis=-1, 264 | features=(self.num_heads, self.head_dim), 265 | kernel_axes=("embed", "heads", "kv"), 266 | dtype=self.dtype, 267 | ) 268 | 269 | # NOTE: T5 does not explicitly rescale the attention logits by 270 | # 1/sqrt(depth_kq)! This is folded into the initializers of the 271 | # linear transformations, which is equivalent under Adafactor. 272 | depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) 273 | 274 | def query_init(*args): 275 | return self.kernel_init(*args) / depth_scaling 276 | 277 | # Project inputs_q to multi-headed q/k/v 278 | # dimensions are then [batch, length, num_heads, head_dim] 279 | query = projection(kernel_init=query_init, name="query")(inputs_q) 280 | key = projection(kernel_init=self.kernel_init, name="key")(inputs_kv) 281 | value = projection(kernel_init=self.kernel_init, name="value")(inputs_kv) 282 | 283 | query = with_sharding_constraint(query, ("batch", "length", "heads", "kv")) 284 | key = with_sharding_constraint(key, ("batch", "length", "heads", "kv")) 285 | value = with_sharding_constraint(value, ("batch", "length", "heads", "kv")) 286 | 287 | if decode: 288 | # Detect if we're initializing by absence of existing cache data. 289 | is_initialized = self.has_variable("cache", "cached_key") 290 | 291 | # The key and value have dimension [batch, length, num_heads, head_dim], 292 | # but we cache them as [batch, num_heads, head_dim, length] as a TPU 293 | # fusion optimization. This also enables the "scatter via one-hot 294 | # broadcast" trick, which means we do a one-hot broadcast instead of a 295 | # scatter/gather operations, resulting in a 3-4x speedup in practice. 
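            # Roughly, with illustrative shapes: the cache is [batch, heads, head_dim, length]
            # and the current step's key/value is [batch, heads, head_dim, 1] after the axis
            # swap, so
            #     cache = cache + new_kv_step * jax.nn.one_hot(cur_index, length)
            # broadcasts the single-step slice across the length axis and writes it only at
            # position `cur_index`, with no scatter/gather (see the update a few lines below).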
296 | def swap_dims(x): 297 | return x[:-3] + tuple(x[i] for i in [-2, -1, -3]) 298 | 299 | cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype) 300 | cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype) 301 | cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) 302 | if is_initialized: 303 | batch, num_heads, head_dim, length = cached_key.value.shape 304 | # During fast autoregressive decoding, we feed one position at a time, 305 | # and cache the keys and values step by step. 306 | # Sanity shape check of cached key against input query. 307 | expected_shape = (batch, 1, num_heads, head_dim) 308 | if expected_shape != query.shape: 309 | raise ValueError( 310 | "Autoregressive cache shape error, " 311 | "expected query shape %s instead got %s." % (expected_shape, query.shape) 312 | ) 313 | 314 | # Create a OHE of the current index. NOTE: the index is increased below. 315 | cur_index = cache_index.value 316 | one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype) 317 | # In order to update the key, value caches with the current key and 318 | # value, we move the length axis to the back, similar to what we did for 319 | # the cached ones above. 320 | # Note these are currently the key and value of a single position, since 321 | # we feed one position at a time. 322 | one_token_key = jnp.moveaxis(key, -3, -1) 323 | one_token_value = jnp.moveaxis(value, -3, -1) 324 | # Update key, value caches with our new 1d spatial slices. 325 | # We implement an efficient scatter into the cache via one-hot 326 | # broadcast and addition. 327 | key = cached_key.value + one_token_key * one_hot_indices 328 | value = cached_value.value + one_token_value * one_hot_indices 329 | cached_key.value = key 330 | cached_value.value = value 331 | cache_index.value = cache_index.value + 1 332 | # Move the keys and values back to their original shapes. 333 | key = jnp.moveaxis(key, -1, -3) 334 | value = jnp.moveaxis(value, -1, -3) 335 | 336 | # Causal mask for cached decoder self-attention: our single query 337 | # position should only attend to those key positions that have already 338 | # been generated and cached, not the remaining zero elements. 339 | mask = combine_masks( 340 | mask, 341 | jnp.broadcast_to( 342 | jnp.arange(length) <= cur_index, 343 | # (1, 1, length) represent (head dim, query length, key length) 344 | # query length is 1 because during decoding we deal with one 345 | # index. 346 | # The same mask is applied to all batch elements and heads. 347 | (batch, 1, 1, length), 348 | ), 349 | ) 350 | 351 | # Grab the correct relative attention bias during decoding. This is 352 | # only required during single step decoding. 353 | if bias is not None: 354 | # The bias is a full attention matrix, but during decoding we only 355 | # have to take a slice of it. 356 | # This is equivalent to bias[..., cur_index:cur_index+1, :]. 357 | bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2) 358 | 359 | # Convert the boolean attention mask to an attention bias. 360 | if mask is not None: 361 | # attention mask in the form of attention bias 362 | attention_bias = lax.select( 363 | mask > 0, jnp.full(mask.shape, 0.0).astype(self.dtype), jnp.full(mask.shape, -1e10).astype(self.dtype) 364 | ) 365 | else: 366 | attention_bias = None 367 | 368 | # Add provided bias term (e.g. relative position embedding). 
369 | if bias is not None: 370 | attention_bias = combine_biases(attention_bias, bias) 371 | 372 | dropout_rng = None 373 | if not deterministic and self.dropout_rate > 0.0: 374 | dropout_rng = self.make_rng("dropout") 375 | 376 | # Apply attention. 377 | x = dot_product_attention( 378 | query, 379 | key, 380 | value, 381 | bias=attention_bias, 382 | dropout_rng=dropout_rng, 383 | dropout_rate=self.dropout_rate, 384 | deterministic=deterministic, 385 | dtype=self.dtype, 386 | float32_logits=self.float32_logits, 387 | ) 388 | 389 | # Back to the original inputs dimensions. 390 | out = DenseGeneral( 391 | features=inputs_q.shape[-1], # output dim is set to the input dim. 392 | axis=(-2, -1), 393 | kernel_init=self.kernel_init, 394 | kernel_axes=("heads", "kv", "embed"), 395 | dtype=self.dtype, 396 | name="out", 397 | )(x) 398 | return out 399 | 400 | 401 | def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: 402 | # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. 403 | return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) 404 | 405 | 406 | def _canonicalize_tuple(x): 407 | if isinstance(x, Iterable): 408 | return tuple(x) 409 | else: 410 | return (x,) 411 | 412 | 413 | # ------------------------------------------------------------------------------ 414 | # DenseGeneral for attention layers. 415 | # ------------------------------------------------------------------------------ 416 | class DenseGeneral(nn.Module): 417 | """A linear transformation (without bias) with flexible axes. 418 | 419 | Attributes: 420 | features: tuple with numbers of output features. 421 | axis: tuple with axes to apply the transformation on. 422 | dtype: the dtype of the computation (default: float32). 423 | kernel_init: initializer function for the weight matrix. 424 | """ 425 | 426 | features: Union[Iterable[int], int] 427 | axis: Union[Iterable[int], int] = -1 428 | dtype: DType = jnp.float32 429 | params_dtype: DType = jnp.float32 430 | kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal") 431 | kernel_axes: Tuple[str, ...] = () 432 | use_bias: bool = True 433 | bias_init: Any = nn.initializers.zeros 434 | 435 | @nn.compact 436 | def __call__(self, inputs: Array) -> Array: 437 | """Applies a linear transformation to the inputs along multiple dimensions. 438 | 439 | Args: 440 | inputs: The nd-array to be transformed. 441 | 442 | Returns: 443 | The transformed input. 
444 | """ 445 | features = _canonicalize_tuple(self.features) 446 | axis = _canonicalize_tuple(self.axis) 447 | 448 | inputs = jnp.asarray(inputs, self.dtype) 449 | axis = _normalize_axes(axis, inputs.ndim) 450 | 451 | kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features 452 | kernel_in_axis = np.arange(len(axis)) 453 | kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) 454 | kernel = param_with_axes( 455 | "kernel", 456 | self.kernel_init, 457 | kernel_shape, 458 | self.params_dtype, 459 | kernel_in_axis, 460 | kernel_out_axis, 461 | axes=self.kernel_axes, 462 | ) 463 | if self.use_bias: 464 | bias = param_with_axes("bias", self.bias_init, features, self.params_dtype, axes=(self.kernel_axes[-1],)) 465 | kernel = jnp.asarray(kernel, self.dtype) 466 | 467 | contract_ind = tuple(range(0, len(axis))) 468 | y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) 469 | if self.use_bias: 470 | bias = jnp.asarray(bias, self.dtype) 471 | # y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) 472 | y += jnp.reshape(bias, (1,) * (len(features) - y.ndim) + bias.shape[:]) 473 | return y 474 | 475 | 476 | def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable: 477 | """Convert a string to an activation function.""" 478 | if fn_or_string == "linear": 479 | return lambda x: x 480 | elif isinstance(fn_or_string, str): 481 | return getattr(nn, fn_or_string) 482 | elif callable(fn_or_string): 483 | return fn_or_string 484 | else: 485 | raise ValueError("don't know how to convert %s to an activation function" % (fn_or_string,)) 486 | 487 | 488 | class MlpBlock(nn.Module): 489 | """Transformer MLP / feed-forward block. 490 | 491 | Attributes: 492 | intermediate_dim: Shared dimension of hidden layers. 493 | activations: Type of activations for each layer. Each element is either 494 | 'linear', a string function name in flax.linen, or a function. 495 | kernel_init: Kernel function, passed to the dense layers. 496 | deterministic: Whether the dropout layers should be deterministic. 497 | intermediate_dropout_rate: Dropout rate used after the intermediate layers. 498 | dtype: Type for the dense layer. 499 | """ 500 | 501 | intermediate_dim: int = 2048 502 | activations: Sequence[Union[str, Callable]] = ("relu",) 503 | kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal") 504 | intermediate_dropout_rate: float = 0.1 505 | dtype: Any = jnp.float32 506 | 507 | @nn.compact 508 | def __call__(self, inputs, decode: bool = False, deterministic: bool = False): 509 | """Applies Transformer MlpBlock module.""" 510 | # Iterate over specified MLP input activation functions. 511 | # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. 512 | activations = [] 513 | for idx, act_fn in enumerate(self.activations): 514 | dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}" 515 | x = DenseGeneral( 516 | self.intermediate_dim, 517 | dtype=self.dtype, 518 | kernel_init=self.kernel_init, 519 | kernel_axes=("embed", "mlp"), 520 | name=dense_name, 521 | )(inputs) 522 | x = _convert_to_activation_function(act_fn)(x) 523 | activations.append(x) 524 | 525 | # Take elementwise product of above intermediate activations. 526 | x = functools.reduce(operator.mul, activations) 527 | # Apply dropout and final dense output projection. 528 | x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( 529 | x, deterministic=deterministic 530 | ) # Broadcast along length. 
531 | x = with_sharding_constraint(x, ("batch", "length", "mlp")) 532 | output = DenseGeneral( 533 | inputs.shape[-1], dtype=self.dtype, kernel_init=self.kernel_init, kernel_axes=("mlp", "embed"), name="wo" 534 | )(x) 535 | return output 536 | 537 | 538 | class Embed(nn.Module): 539 | """A parameterized function from integers [0, n) to d-dimensional vectors. 540 | 541 | Attributes: 542 | num_embeddings: number of embeddings. 543 | features: number of feature dimensions for each embedding. 544 | dtype: the dtype of the embedding vectors (default: float32). 545 | embedding_init: embedding initializer. 546 | one_hot: performs the gather with a one-hot contraction rather than a true 547 | gather. This is currently needed for SPMD partitioning. 548 | """ 549 | 550 | num_embeddings: int 551 | features: int 552 | cast_input_dtype: Optional[DType] = None 553 | dtype: DType = jnp.float32 554 | params_dtype: DType = jnp.float32 555 | attend_dtype: Optional[DType] = None 556 | embedding_init: Initializer = default_embed_init 557 | one_hot: bool = True 558 | embedding: Array = dataclasses.field(init=False) 559 | 560 | def setup(self): 561 | self.embedding = param_with_axes( 562 | "embedding", 563 | self.embedding_init, 564 | (self.num_embeddings, self.features), 565 | self.params_dtype, 566 | axes=("vocab", "embed"), 567 | ) 568 | 569 | def __call__(self, inputs: Array) -> Array: 570 | """Embeds the inputs along the last dimension. 571 | 572 | Args: 573 | inputs: input data, all dimensions are considered batch dimensions. 574 | 575 | Returns: 576 | Output which is embedded input data. The output shape follows the input, 577 | with an additional `features` dimension appended. 578 | """ 579 | if self.cast_input_dtype: 580 | inputs = inputs.astype(self.cast_input_dtype) 581 | if not jnp.issubdtype(inputs.dtype, jnp.integer): 582 | raise ValueError("Input type must be an integer or unsigned integer.") 583 | if self.one_hot: 584 | iota = lax.iota(jnp.int32, self.num_embeddings) 585 | one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) 586 | output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) 587 | else: 588 | output = jnp.asarray(self.embedding, self.dtype)[inputs] 589 | output = with_sharding_constraint(output, ("batch", "length", "embed")) 590 | return output 591 | 592 | def attend(self, query: Array) -> Array: 593 | """Attend over the embedding using a query array. 594 | 595 | Args: 596 | query: array with last dimension equal the feature depth `features` of the 597 | embedding. 598 | 599 | Returns: 600 | An array with final dim `num_embeddings` corresponding to the batched 601 | inner-product of the array of query vectors against each embedding. 602 | Commonly used for weight-sharing between embeddings and logit transform 603 | in NLP models. 604 | """ 605 | dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype 606 | return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) 607 | 608 | 609 | class RelativePositionBiases(nn.Module): 610 | """Adds T5-style relative positional embeddings to the attention logits. 611 | 612 | Attributes: 613 | num_buckets: Number of buckets to bucket distances between key and query 614 | positions into. 615 | max_distance: Maximum distance before everything is lumped into the last 616 | distance bucket. 617 | num_heads: Number of heads in the attention layer. Each head will get a 618 | different relative position weighting. 619 | dtype: Type of arrays through this module. 
620 | embedding_init: initializer for relative embedding table. 621 | """ 622 | 623 | num_buckets: int 624 | max_distance: int 625 | num_heads: int 626 | dtype: Any 627 | embedding_init: Callable[..., Array] = nn.linear.default_embed_init 628 | 629 | @staticmethod 630 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): 631 | """Translate relative position to a bucket number for relative attention. 632 | 633 | The relative position is defined as memory_position - query_position, i.e. 634 | the distance in tokens from the attending position to the attended-to 635 | position. If bidirectional=False, then positive relative positions are 636 | invalid. 637 | We use smaller buckets for small absolute relative_position and larger 638 | buckets for larger absolute relative_positions. All relative 639 | positions >=max_distance map to the same bucket. All relative 640 | positions <=-max_distance map to the same bucket. This should allow for 641 | more graceful generalization to longer sequences than the model has been 642 | trained on. 643 | 644 | Args: 645 | relative_position: an int32 array 646 | bidirectional: a boolean - whether the attention is bidirectional 647 | num_buckets: an integer 648 | max_distance: an integer 649 | 650 | Returns: 651 | a Tensor with the same shape as relative_position, containing int32 652 | values in the range [0, num_buckets) 653 | """ 654 | ret = 0 655 | n = -relative_position 656 | if bidirectional: 657 | num_buckets //= 2 658 | ret += (n < 0).astype(np.int32) * num_buckets 659 | n = np.abs(n) 660 | else: 661 | n = np.maximum(n, 0) 662 | # now n is in the range [0, inf) 663 | max_exact = num_buckets // 2 664 | is_small = n < max_exact 665 | val_if_large = max_exact + ( 666 | np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) 667 | / np.log(max_distance / max_exact) 668 | * (num_buckets - max_exact) 669 | ).astype(np.int32) 670 | val_if_large = np.minimum(val_if_large, num_buckets - 1) 671 | ret += np.where(is_small, n, val_if_large) 672 | return ret 673 | 674 | @nn.compact 675 | def __call__(self, qlen, klen, bidirectional=True): 676 | """Produce relative position embedding attention biases. 677 | 678 | Args: 679 | qlen: attention query length. 680 | klen: attention key length. 681 | bidirectional: whether to allow positive memory-query relative position 682 | embeddings. 683 | 684 | Returns: 685 | output: `(1, len, q_len, k_len)` attention bias 686 | """ 687 | # TODO(levskaya): should we be computing this w. numpy as a program 688 | # constant? 689 | context_position = np.arange(qlen, dtype=jnp.int32)[:, None] 690 | memory_position = np.arange(klen, dtype=jnp.int32)[None, :] 691 | relative_position = memory_position - context_position # shape (qlen, klen) 692 | rp_bucket = self._relative_position_bucket( 693 | relative_position, 694 | bidirectional=bidirectional, 695 | num_buckets=self.num_buckets, 696 | max_distance=self.max_distance, 697 | ) 698 | relative_attention_bias = param_with_axes( 699 | "rel_embedding", 700 | self.embedding_init, 701 | (self.num_heads, self.num_buckets), 702 | jnp.float32, 703 | axes=("heads", "relpos_buckets"), 704 | ) 705 | 706 | relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) 707 | # Instead of using a slow gather, we create a leading-dimension one-hot 708 | # array from rp_bucket and use it to perform the gather-equivalent via a 709 | # contraction, i.e.: 710 | # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen). 
711 | # This is equivalent to relative_attention_bias[:, rp_bucket] 712 | bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) 713 | rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype) 714 | # --> shape (qlen, klen, num_heads) 715 | values = lax.dot_general( 716 | relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ())) # rhs, lhs contracting dims 717 | ) # no batched dims 718 | # Add a singleton batch dimension. 719 | # --> shape (1, num_heads, qlen, klen) 720 | return values[jnp.newaxis, ...] 721 | 722 | 723 | # ------------------------------------------------------------------------------ 724 | # T5 Layernorm - no subtraction of mean or bias. 725 | # ------------------------------------------------------------------------------ 726 | # class LayerNorm(nn.Module): 727 | # """T5 Layer normalization operating on the last axis of the input data.""" 728 | # epsilon: float = 1e-6 729 | # dtype: Any = jnp.float32 730 | # scale_init: Initializer = nn.initializers.ones 731 | 732 | # @nn.compact 733 | # def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 734 | # """Applies layer normalization on the input.""" 735 | # x = jnp.asarray(x, jnp.float32) 736 | # features = x.shape[-1] 737 | # mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) 738 | # y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) 739 | # scale = param_with_axes( 740 | # 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',)) 741 | 742 | # scale = jnp.asarray(scale, self.dtype) 743 | # return y * scale 744 | 745 | 746 | class LayerNorm(nn.Module): 747 | """Layer normalization (https://arxiv.org/abs/1607.06450). 748 | Operates on the last axis of the input data. 749 | It normalizes the activations of the layer for each given example in a 750 | batch independently, rather than across a batch like Batch Normalization. 751 | i.e. applies a transformation that maintains the mean activation within 752 | each example close to 0 and the activation standard deviation close to 1. 753 | Attributes: 754 | epsilon: A small float added to variance to avoid dividing by zero. 755 | dtype: the dtype of the computation (default: float32). 756 | use_bias: If True, bias (beta) is added. 757 | use_scale: If True, multiply by scale (gamma). When the next layer is linear 758 | (also e.g. nn.relu), this can be disabled since the scaling will be done 759 | by the next layer. 760 | bias_init: Initializer for bias, by default, zero. 761 | scale_init: Initializer for scale, by default, one. 762 | """ 763 | 764 | epsilon: float = 1e-6 765 | dtype: Any = jnp.float32 766 | params_dtype: DType = jnp.float32 767 | use_bias: bool = True 768 | use_scale: bool = True 769 | bias_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.zeros 770 | scale_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.ones 771 | 772 | @nn.compact 773 | def __call__(self, x): 774 | """Applies layer normalization on the input. 775 | Args: 776 | x: the inputs 777 | Returns: 778 | Normalized inputs (the same shape as inputs). 
779 | """ 780 | x = jnp.asarray(x, jnp.float32) 781 | features = x.shape[-1] 782 | mean = jnp.mean(x, axis=-1, keepdims=True) 783 | mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) 784 | var = mean2 - lax.square(mean) 785 | mul = lax.rsqrt(var + self.epsilon) 786 | if self.use_scale: 787 | scale = param_with_axes("scale", self.scale_init, (features,), self.params_dtype, axes=("embed",)) 788 | mul = mul * jnp.asarray(scale, self.dtype) 789 | y = (x - mean) * mul 790 | if self.use_bias: 791 | bias = param_with_axes("bias", self.bias_init, (features,), self.params_dtype, axes=("embed",)) 792 | y = y + jnp.asarray(bias, self.dtype) 793 | return jnp.asarray(y, self.dtype) 794 | 795 | 796 | # ------------------------------------------------------------------------------ 797 | # Mask-making utility functions. 798 | # ------------------------------------------------------------------------------ 799 | def make_attention_mask( 800 | query_input: Array, 801 | key_input: Array, 802 | pairwise_fn: Callable = jnp.multiply, 803 | extra_batch_dims: int = 0, 804 | dtype: DType = jnp.float32, 805 | ) -> Array: 806 | """Mask-making helper for attention weights. 807 | 808 | In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the 809 | attention weights will be `[batch, heads, len_q, len_kv]` and this 810 | function will produce `[batch, 1, len_q, len_kv]`. 811 | 812 | Args: 813 | query_input: a batched, flat input of query_length size 814 | key_input: a batched, flat input of key_length size 815 | pairwise_fn: broadcasting elementwise comparison function 816 | extra_batch_dims: number of extra batch dims to add singleton axes for, none 817 | by default 818 | dtype: mask return dtype 819 | 820 | Returns: 821 | A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention. 822 | """ 823 | # [batch, len_q, len_kv] 824 | mask = pairwise_fn( 825 | # [batch, len_q] -> [batch, len_q, 1] 826 | jnp.expand_dims(query_input, axis=-1), 827 | # [batch, len_q] -> [batch, 1, len_kv] 828 | jnp.expand_dims(key_input, axis=-2), 829 | ) 830 | 831 | # [batch, 1, len_q, len_kv]. This creates the head dim. 832 | mask = jnp.expand_dims(mask, axis=-3) 833 | mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) 834 | return mask.astype(dtype) 835 | 836 | 837 | def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32) -> Array: 838 | """Make a causal mask for self-attention. 839 | 840 | In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights 841 | will be `[batch, heads, len, len]` and this function will produce a 842 | causal mask of shape `[batch, 1, len, len]`. 843 | 844 | Note that a causal mask does not depend on the values of x; it only depends on 845 | the shape. If x has padding elements, they will not be treated in a special 846 | manner. 847 | 848 | Args: 849 | x: input array of shape `[batch, len]` 850 | extra_batch_dims: number of batch dims to add singleton axes for, none by 851 | default 852 | dtype: mask return dtype 853 | 854 | Returns: 855 | A `[batch, 1, len, len]` shaped causal mask for 1d attention. 856 | """ 857 | idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) 858 | return make_attention_mask(idxs, idxs, jnp.greater_equal, extra_batch_dims=extra_batch_dims, dtype=dtype) 859 | 860 | 861 | def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): 862 | """Combine attention masks. 863 | 864 | Args: 865 | *masks: set of attention mask arguments to combine, some can be None. 
866 | dtype: final mask dtype 867 | 868 | Returns: 869 | Combined mask, reduced by logical and, returns None if no masks given. 870 | """ 871 | masks = [m for m in masks if m is not None] 872 | if not masks: 873 | return None 874 | assert all( 875 | (x.ndim == masks[0].ndim for x in masks) 876 | ), f"masks must have same rank: {tuple((x.ndim for x in masks))}" 877 | mask, *other_masks = masks 878 | for other_mask in other_masks: 879 | mask = jnp.logical_and(mask, other_mask) 880 | return mask.astype(dtype) 881 | 882 | 883 | def combine_biases(*masks: Optional[Array]): 884 | """Combine attention biases. 885 | 886 | Args: 887 | *masks: set of attention bias arguments to combine, some can be None. 888 | 889 | Returns: 890 | Combined mask, reduced by summation, returns None if no masks given. 891 | """ 892 | masks = [m for m in masks if m is not None] 893 | if not masks: 894 | return None 895 | assert all( 896 | (x.ndim == masks[0].ndim for x in masks) 897 | ), f"masks must have same rank: {tuple((x.ndim for x in masks))}" 898 | mask, *other_masks = masks 899 | for other_mask in other_masks: 900 | mask = mask + other_mask 901 | return mask 902 | 903 | 904 | def make_decoder_mask( 905 | decoder_target_tokens: Array, 906 | dtype: DType, 907 | decoder_causal_attention: Optional[Array] = None, 908 | decoder_segment_ids: Optional[Array] = None, 909 | ) -> Array: 910 | """Compute the self-attention mask for a decoder. 911 | 912 | Decoder mask is formed by combining a causal mask, a padding mask and an 913 | optional packing mask. If decoder_causal_attention is passed, it makes the 914 | masking non-causal for positions that have value of 1. 915 | 916 | A prefix LM is applied to a dataset which has a notion of "inputs" and 917 | "targets", e.g., a machine translation task. The inputs and targets are 918 | concatenated to form a new target. `decoder_target_tokens` is the concatenated 919 | decoder output tokens. 920 | 921 | The "inputs" portion of the concatenated sequence can attend to other "inputs" 922 | tokens even for those at a later time steps. In order to control this 923 | behavior, `decoder_causal_attention` is necessary. This is a binary mask with 924 | a value of 1 indicating that the position belonged to "inputs" portion of the 925 | original dataset. 926 | 927 | Example: 928 | 929 | Suppose we have a dataset with two examples. 930 | 931 | ds = [{"inputs": [6, 7], "targets": [8]}, 932 | {"inputs": [3, 4], "targets": [5]}] 933 | 934 | After the data preprocessing with packing, the two examples are packed into 935 | one example with the following three fields (some fields are skipped for 936 | simplicity). 937 | 938 | decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]] 939 | decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] 940 | decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]] 941 | 942 | where each array has [batch, length] shape with batch size being 1. Then, 943 | this function computes the following mask. 944 | 945 | mask = [[[[1, 1, 0, 0, 0, 0, 0], 946 | [1, 1, 0, 0, 0, 0, 0], 947 | [1, 1, 1, 0, 0, 0, 0], 948 | [0, 0, 0, 1, 1, 0, 0], 949 | [0, 0, 0, 1, 1, 0, 0], 950 | [0, 0, 0, 1, 1, 1, 0], 951 | [0, 0, 0, 0, 0, 0, 0]]]] 952 | 953 | mask[b, 1, :, :] represents the mask for the example `b` in the batch. 954 | Because mask is for a self-attention layer, the mask's shape is a square of 955 | shape [query length, key length]. 956 | 957 | mask[b, 1, i, j] = 1 means that the query token at position i can attend to 958 | the key token at position j. 
959 | 960 | Args: 961 | decoder_target_tokens: decoder output tokens. [batch, length] 962 | dtype: dtype of the output mask. 963 | decoder_causal_attention: a binary mask indicating which position should 964 | only attend to earlier positions in the sequence. Others will attend 965 | bidirectionally. [batch, length] 966 | decoder_segment_ids: decoder segmentation info for packed examples. [batch, 967 | length] 968 | 969 | Returns: 970 | the combined decoder mask. 971 | """ 972 | masks = [] 973 | # The same mask is applied to all attention heads. So the head dimension is 1, 974 | # i.e., the mask will be broadcast along the heads dim. 975 | # [batch, 1, length, length] 976 | causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype) 977 | 978 | # Positions with value 1 in `decoder_causal_attneition` can attend 979 | # bidirectionally. 980 | if decoder_causal_attention is not None: 981 | # [batch, 1, length, length] 982 | inputs_mask = make_attention_mask( 983 | decoder_causal_attention, decoder_causal_attention, jnp.logical_and, dtype=dtype 984 | ) 985 | masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype)) 986 | else: 987 | masks.append(causal_mask) 988 | 989 | # Padding mask. 990 | masks.append(make_attention_mask(decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype)) 991 | 992 | # Packing mask 993 | if decoder_segment_ids is not None: 994 | masks.append(make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype)) 995 | 996 | return combine_masks(*masks, dtype=dtype) 997 | 998 | 999 | def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: 1000 | """ "Canonicalizes conv padding to a jax.lax supported format.""" 1001 | if isinstance(padding, str): 1002 | return padding 1003 | if isinstance(padding, int): 1004 | return [(padding, padding)] * rank 1005 | if isinstance(padding, Sequence) and len(padding) == rank: 1006 | new_pad = [] 1007 | for p in padding: 1008 | if isinstance(p, int): 1009 | new_pad.append((p, p)) 1010 | elif isinstance(p, tuple) and len(p) == 2: 1011 | new_pad.append(p) 1012 | else: 1013 | break 1014 | if len(new_pad) == rank: 1015 | return new_pad 1016 | raise ValueError( 1017 | f"Invalid padding format: {padding}, should be str, int," 1018 | f" or a sequence of len {rank} where each element is an" 1019 | f" int or pair of ints." 1020 | ) 1021 | 1022 | 1023 | def _conv_dimension_numbers(input_shape): 1024 | """Computes the dimension numbers based on the input shape.""" 1025 | ndim = len(input_shape) 1026 | lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) 1027 | rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) 1028 | out_spec = lhs_spec 1029 | return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) 1030 | 1031 | 1032 | class _Conv(nn.Module): 1033 | """Convolution Module wrapping `lax.conv_general_dilated[_local]`. 1034 | 1035 | Attributes: 1036 | features: number of convolution filters. 1037 | kernel_size: shape of the convolutional kernel. For 1D convolution, 1038 | the kernel size can be passed as an integer. For all other cases, it must 1039 | be a sequence of integers. 1040 | strides: an integer or a sequence of `n` integers, representing the 1041 | inter-window strides (default: 1). 1042 | padding: either the string `'SAME'`, the string `'VALID'`, the string 1043 | `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, 1044 | high)` integer pairs that give the padding to apply before and after each 1045 | spatial dimension. 
A single int is interpeted as applying the same padding 1046 | in all dims and passign a single int in a sequence causes the same padding 1047 | to be used on both sides. `'CAUSAL'` padding for a 1D convolution will 1048 | left-pad the convolution axis, resulting in same-sized output. 1049 | input_dilation: an integer or a sequence of `n` integers, giving the 1050 | dilation factor to apply in each spatial dimension of `inputs` 1051 | (default: 1). Convolution with input dilation `d` is equivalent to 1052 | transposed convolution with stride `d`. 1053 | kernel_dilation: an integer or a sequence of `n` integers, giving the 1054 | dilation factor to apply in each spatial dimension of the convolution 1055 | kernel (default: 1). Convolution with kernel dilation 1056 | is also known as 'atrous convolution'. 1057 | feature_group_count: integer, default 1. If specified divides the input 1058 | features into groups. 1059 | use_bias: whether to add a bias to the output (default: True). 1060 | mask: Optional mask for the weights during masked convolution. The mask must 1061 | be the same shape as the convolution weight matrix. 1062 | dtype: the dtype of the computation (default: infer from input and params). 1063 | params_dtype: the dtype passed to parameter initializers (default: float32). 1064 | precision: numerical precision of the computation see `jax.lax.Precision` 1065 | for details. 1066 | kernel_init: initializer for the convolutional kernel. 1067 | bias_init: initializer for the bias. 1068 | """ 1069 | 1070 | features: int 1071 | kernel_size: Sequence[int] 1072 | strides: Union[None, int, Sequence[int]] = 1 1073 | padding: PaddingLike = "SAME" 1074 | input_dilation: Union[None, int, Sequence[int]] = 1 1075 | kernel_dilation: Union[None, int, Sequence[int]] = 1 1076 | feature_group_count: int = 1 1077 | use_bias: bool = True 1078 | mask: Optional[Array] = None 1079 | dtype: Optional[DType] = None 1080 | params_dtype: DType = jnp.float32 1081 | precision: PrecisionLike = None 1082 | kernel_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.lecun_normal() 1083 | bias_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.zeros 1084 | conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated 1085 | kernel_axes: Tuple[str, ...] = () 1086 | 1087 | @property 1088 | def shared_weights(self) -> bool: # type: ignore 1089 | """Defines whether weights are shared or not between different pixels. 1090 | 1091 | Returns: 1092 | `True` to use shared weights in convolution (regular convolution). 1093 | `False` to use different weights at different pixels, a.k.a. 1094 | "locally connected layer", "unshared convolution", or "local convolution". 1095 | 1096 | """ 1097 | ... 1098 | 1099 | @nn.compact 1100 | def __call__(self, inputs: Array) -> Array: 1101 | """Applies a (potentially unshared) convolution to the inputs. 1102 | 1103 | Args: 1104 | inputs: input data with dimensions (*batch_dims, spatial_dims..., 1105 | features). This is the channels-last convention, i.e. NHWC for a 2d 1106 | convolution and NDHWC for a 3D convolution. Note: this is different from 1107 | the input convention used by `lax.conv_general_dilated`, which puts the 1108 | spatial dimensions last. 1109 | Note: If the input has more than 1 batch dimension, all batch dimensions 1110 | are flattened into a single dimension for the convolution and restored 1111 | before returning. In some cases directly vmap'ing the layer may yield 1112 | better performance than this default flattening approach. 
If the input 1113 | lacks a batch dimension it will be added for the convolution and removed 1114 | n return, an allowance made to enable writing single-example code. 1115 | 1116 | Returns: 1117 | The convolved data. 1118 | """ 1119 | 1120 | if isinstance(self.kernel_size, int): 1121 | raise TypeError( 1122 | "Expected Conv kernel_size to be a" 1123 | " tuple/list of integers (eg.: [3, 3]) but got" 1124 | f" {self.kernel_size}." 1125 | ) 1126 | else: 1127 | kernel_size = tuple(self.kernel_size) 1128 | 1129 | def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> Tuple[int, ...]: 1130 | if x is None: 1131 | # backward compatibility with using None as sentinel for 1132 | # broadcast 1 1133 | x = 1 1134 | if isinstance(x, int): 1135 | return (x,) * len(kernel_size) 1136 | return tuple(x) 1137 | 1138 | # Combine all input batch dimensions into a single leading batch axis. 1139 | num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1) 1140 | if num_batch_dimensions != 1: 1141 | input_batch_shape = inputs.shape[:num_batch_dimensions] 1142 | total_batch_size = int(np.prod(input_batch_shape)) 1143 | flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:] 1144 | inputs = jnp.reshape(inputs, flat_input_shape) 1145 | 1146 | # self.strides or (1,) * (inputs.ndim - 2) 1147 | strides = maybe_broadcast(self.strides) 1148 | input_dilation = maybe_broadcast(self.input_dilation) 1149 | kernel_dilation = maybe_broadcast(self.kernel_dilation) 1150 | 1151 | padding_lax = canonicalize_padding(self.padding, len(kernel_size)) 1152 | if padding_lax == "CIRCULAR": 1153 | kernel_size_dilated = [(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)] 1154 | zero_pad: List[Tuple[int, int]] = [(0, 0)] 1155 | pads = zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)] 1156 | inputs = jnp.pad(inputs, pads, mode="wrap") 1157 | padding_lax = "VALID" 1158 | elif padding_lax == "CAUSAL": 1159 | if len(kernel_size) != 1: 1160 | raise ValueError("Causal padding is only implemented for 1D convolutions.") 1161 | left_pad = kernel_dilation[0] * (kernel_size[0] - 1) 1162 | pads = [(0, 0), (left_pad, 0), (0, 0)] 1163 | inputs = jnp.pad(inputs, pads) 1164 | padding_lax = "VALID" 1165 | 1166 | dimension_numbers = _conv_dimension_numbers(inputs.shape) 1167 | in_features = jnp.shape(inputs)[-1] 1168 | 1169 | if self.shared_weights: 1170 | # One shared convolutional kernel for all pixels in the output. 1171 | assert in_features % self.feature_group_count == 0 1172 | kernel_shape = kernel_size + (in_features // self.feature_group_count, self.features) 1173 | 1174 | else: 1175 | if self.feature_group_count != 1: 1176 | raise NotImplementedError( 1177 | f"`lax.conv_general_dilated_local` does not support " 1178 | f"`feature_group_count != 1`, got `{self.feature_group_count}`." 1179 | ) 1180 | 1181 | # Need to know the spatial output shape of a standard convolution to 1182 | # create the unshared convolution kernel. 1183 | conv_output_shape = jax.eval_shape( 1184 | lambda lhs, rhs: self.conv_general_dilated( # pylint: disable=g-long-lambda 1185 | lhs=lhs, 1186 | rhs=rhs, 1187 | window_strides=strides, 1188 | padding=padding_lax, 1189 | dimension_numbers=dimension_numbers, 1190 | lhs_dilation=input_dilation, 1191 | rhs_dilation=kernel_dilation, 1192 | ), 1193 | inputs, 1194 | jax.ShapedArray(kernel_size + (in_features, self.features), inputs.dtype), 1195 | ).shape 1196 | 1197 | # One (unshared) convolutional kernel per each pixel in the output. 
1198 | kernel_shape = conv_output_shape[1:-1] + (np.prod(kernel_size) * in_features, self.features) 1199 | 1200 | if self.mask is not None and self.mask.shape != kernel_shape: 1201 | raise ValueError( 1202 | "Mask needs to have the same shape as weights. " f"Shapes are: {self.mask.shape}, {kernel_shape}" 1203 | ) 1204 | 1205 | kernel = param_with_axes( 1206 | "kernel", 1207 | self.kernel_init, 1208 | kernel_shape, 1209 | self.params_dtype, 1210 | axes=self.kernel_axes, 1211 | ) 1212 | 1213 | if self.mask is not None: 1214 | kernel *= self.mask 1215 | 1216 | if self.use_bias: 1217 | if self.shared_weights: 1218 | # One bias weight per output channel, shared between pixels. 1219 | bias_shape = (self.features,) 1220 | else: 1221 | # One bias weight per output entry, unshared betwen pixels. 1222 | bias_shape = conv_output_shape[1:] 1223 | 1224 | bias = param_with_axes( 1225 | "bias", 1226 | self.bias_init, 1227 | bias_shape, 1228 | self.params_dtype, 1229 | axes=(self.kernel_axes[-1],), 1230 | ) 1231 | else: 1232 | bias = None 1233 | 1234 | inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype) 1235 | if self.shared_weights: 1236 | y = self.conv_general_dilated( 1237 | inputs, 1238 | kernel, 1239 | strides, 1240 | padding_lax, 1241 | lhs_dilation=input_dilation, 1242 | rhs_dilation=kernel_dilation, 1243 | dimension_numbers=dimension_numbers, 1244 | feature_group_count=self.feature_group_count, 1245 | precision=self.precision, 1246 | ) 1247 | else: 1248 | y = lax.conv_general_dilated_local( 1249 | lhs=inputs, 1250 | rhs=kernel, 1251 | window_strides=strides, 1252 | padding=padding_lax, 1253 | filter_shape=kernel_size, 1254 | lhs_dilation=input_dilation, 1255 | rhs_dilation=kernel_dilation, 1256 | dimension_numbers=dimension_numbers, 1257 | precision=self.precision, 1258 | ) 1259 | 1260 | if self.use_bias: 1261 | bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape) 1262 | y += bias 1263 | 1264 | if num_batch_dimensions != 1: 1265 | output_shape = input_batch_shape + y.shape[1:] 1266 | y = jnp.reshape(y, output_shape) 1267 | return y 1268 | 1269 | 1270 | class Conv(_Conv): 1271 | """Convolution Module wrapping `lax.conv_general_dilated`. 1272 | 1273 | Attributes: 1274 | features: number of convolution filters. 1275 | kernel_size: shape of the convolutional kernel. For 1D convolution, 1276 | the kernel size can be passed as an integer. For all other cases, it must 1277 | be a sequence of integers. 1278 | strides: an integer or a sequence of `n` integers, representing the 1279 | inter-window strides (default: 1). 1280 | padding: either the string `'SAME'`, the string `'VALID'`, the string 1281 | `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, 1282 | high)` integer pairs that give the padding to apply before and after each 1283 | spatial dimension. A single int is interpeted as applying the same padding 1284 | in all dims and passign a single int in a sequence causes the same padding 1285 | to be used on both sides. `'CAUSAL'` padding for a 1D convolution will 1286 | left-pad the convolution axis, resulting in same-sized output. 1287 | input_dilation: an integer or a sequence of `n` integers, giving the 1288 | dilation factor to apply in each spatial dimension of `inputs` 1289 | (default: 1). Convolution with input dilation `d` is equivalent to 1290 | transposed convolution with stride `d`. 
1291 | kernel_dilation: an integer or a sequence of `n` integers, giving the 1292 | dilation factor to apply in each spatial dimension of the convolution 1293 | kernel (default: 1). Convolution with kernel dilation 1294 | is also known as 'atrous convolution'. 1295 | feature_group_count: integer, default 1. If specified divides the input 1296 | features into groups. 1297 | use_bias: whether to add a bias to the output (default: True). 1298 | mask: Optional mask for the weights during masked convolution. The mask must 1299 | be the same shape as the convolution weight matrix. 1300 | dtype: the dtype of the computation (default: infer from input and params). 1301 | params_dtype: the dtype passed to parameter initializers (default: float32). 1302 | precision: numerical precision of the computation see `jax.lax.Precision` 1303 | for details. 1304 | kernel_init: initializer for the convolutional kernel. 1305 | bias_init: initializer for the bias. 1306 | """ 1307 | 1308 | @property 1309 | def shared_weights(self) -> bool: 1310 | return True 1311 | -------------------------------------------------------------------------------- /whisper_jax/partitioner.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The T5X Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities for partitioning.""" 17 | 18 | import abc 19 | import collections 20 | import dataclasses 21 | import typing 22 | from typing import Any, Callable, Optional, Sequence, Tuple, Union 23 | 24 | import cached_property 25 | import jax 26 | import numpy as np 27 | from absl import logging 28 | from flax import traverse_util 29 | from flax.linen import partitioning as flax_partitioning 30 | from jax import numpy as jnp 31 | from jax import random 32 | from jax.experimental import multihost_utils 33 | from jax.experimental.mesh_utils import create_hybrid_device_mesh 34 | from jax.experimental.pjit import pjit as jax_pjit 35 | from jax.sharding import Mesh, PartitionSpec 36 | 37 | 38 | JaxDevice = Any 39 | TpuMesh = Tuple[int, int, int, int] # (x, y, z, num_cores). 40 | OtherMesh = Tuple[int, int] 41 | HardwareMesh = Union[TpuMesh, OtherMesh] 42 | PyTreeDef = type(jax.tree_util.tree_structure(None)) 43 | TrainState = Any 44 | LogicalAxisRules = Sequence[Tuple[str, Optional[str]]] 45 | 46 | if typing.TYPE_CHECKING: # See b/163639353 47 | cached_property = property # pylint: disable=invalid-name 48 | else: 49 | cached_property = cached_property.cached_property 50 | 51 | 52 | class AxisNames(tuple): 53 | """Tuple of strings specifying name for each axis. 54 | 55 | We create a separate class for this so JAX's pytree utilities can distinguish 56 | it from a tuple that should be treated as a pytree, instead treating it as a 57 | leaf. 
58 | """ 59 | 60 | def __new__(cls, *names): 61 | return tuple.__new__(AxisNames, names) 62 | 63 | def __repr__(self): 64 | return "AxisNames%s" % tuple.__repr__(self) 65 | 66 | 67 | # pjit wrappers for cpu fallback. 68 | # ---------------------------------------------------------------------------- 69 | # TODO(levskaya): This function is now no different than jax_pjit, but callers 70 | # currently depend on `backend` argument 71 | def pjit( 72 | fun: Callable, # pylint: disable=g-bare-generic 73 | in_axis_resources, 74 | out_axis_resources, 75 | static_argnums: Union[int, Sequence[int]] = (), 76 | donate_argnums: Union[int, Sequence[int]] = (), 77 | backend: Optional[str] = None, 78 | ): 79 | """Wrapper for pjit.""" 80 | del backend 81 | return jax_pjit( 82 | fun, in_axis_resources, out_axis_resources, static_argnums=static_argnums, donate_argnums=donate_argnums 83 | ) 84 | 85 | 86 | # pjit wrappers for cpu fallback. 87 | # ----------------------------------------------------------------------------- 88 | # TODO(levskaya): upstream this fallback behavior to jax pjit. 89 | def pjit_with_cpu_fallback( 90 | fun: Callable, # pylint: disable=g-bare-generic 91 | in_axis_resources, 92 | out_axis_resources, 93 | static_argnums: Union[int, Sequence[int]] = (), 94 | donate_argnums: Union[int, Sequence[int]] = (), 95 | backend: Optional[str] = None, 96 | ): 97 | """Wrapper for pjit that calls normal jit on cpu.""" 98 | if jax.devices(backend)[0].platform == "cpu": 99 | return jax.jit(fun, static_argnums=static_argnums, donate_argnums=donate_argnums) 100 | else: 101 | return jax_pjit( 102 | fun, in_axis_resources, out_axis_resources, static_argnums=static_argnums, donate_argnums=donate_argnums 103 | ) 104 | 105 | 106 | def with_sharding_constraint(x, axis_resources): 107 | """Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit.""" 108 | if jax.devices()[0].platform == "cpu" or not global_mesh_defined(): 109 | return x 110 | else: 111 | return jax.experimental.pjit.with_sharding_constraint(x, axis_resources) 112 | 113 | 114 | # pjit Mesh creation functions. 115 | # ----------------------------------------------------------------------------- 116 | def bounds_from_last_device(last_device: JaxDevice) -> HardwareMesh: 117 | """Get the bound from the given last device.""" 118 | # Must be passed the device at the highest-coordinate corner of the 119 | # relevant mesh, which is a requirement we know is satisfied by the last 120 | # device in jax.devices(). 121 | if hasattr(last_device, "coords"): 122 | x, y, z = last_device.coords 123 | return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 124 | else: 125 | # On non-TPU platforms, the "mesh" is hosts x devices per host in order 126 | # to take advantage of faster within-host interconnect. 
127 | return jax.host_count(), jax.local_device_count() 128 | 129 | 130 | def get_coords(device: JaxDevice) -> HardwareMesh: 131 | """Returns the coordinates of the given device.""" 132 | if hasattr(device, "coords"): 133 | return (*device.coords, device.core_on_chip) 134 | return (device.process_index, device.id % jax.local_device_count()) 135 | 136 | 137 | def global_mesh_defined(): 138 | """Checks if global xmap/pjit mesh resource environment is defined.""" 139 | maps_env = jax.experimental.maps.thread_resources.env 140 | return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison 141 | 142 | 143 | def get_mesh( 144 | model_parallel_submesh: HardwareMesh, 145 | input_devices: Sequence[JaxDevice] = (), 146 | input_local_devices: Sequence[JaxDevice] = (), 147 | tile_by_host_if_needed: bool = True, 148 | backend: Optional[str] = None, 149 | ) -> Mesh: 150 | """Construct an xmap/pjit Mesh for the given model-parallel submesh. 151 | 152 | The resulting mesh has two resource axes: 'model', with the provided submesh 153 | shape, and 'data', which covers the rest of the mesh. 154 | 155 | Args: 156 | model_parallel_submesh: a HardwareMesh spec, namely (x,y,z,core) on TPU for 157 | a single model-parallel replica's "tile" in the physical device mesh. The 158 | first three elements (`x`, `y`, and `z`) should be factors of the pod 159 | slice; e.g., if you are using df_4x8, then `x` should be a factor of 4 160 | (one of 1, 2, 4), `y` should be a factor of 8 (one of 1, 2, 4, 8), and `z` 161 | must be 1, because TPU v3 slices are only 2D. `z` can be >1 for TPU v4 162 | (and maybe later TPUs) that allow 3D slices. `core` is the number of cores 163 | to use from each TPU node. As communication is usually fastest inside the 164 | same node, if you need a tile of more than 1 core, then 165 | you should first increase `core`: e.g., for TPU v3, (1,1,1,2) is better 166 | than (2,1,1,1). To pick a good spec, try a few possible values until you 167 | get high TPU utilization. 168 | input_devices: the devices to use, will use jax.devices() if this is not 169 | set. 170 | input_local_devices: the local devices to use, will use jax.local_devices() 171 | if this is not set. 172 | tile_by_host_if_needed: JAX currently requires that the parts of any sharded 173 | array that are located on one host's local devices form a single 174 | contiguous slice. A best effort will be made to achieve this without 175 | "tiling" the device assignment over hosts (which can reduce XLA collective 176 | performance). If this flag is True, then the device assignment will be 177 | tiled over hosts if necessary to satisfy this constraint and create a 178 | buildable mesh; if false, mesh construction will fail instead. 179 | backend: get devices from the pinned backend, if specified. This is 180 | useful for explicitly specifying the devices other than relying on 181 | jax_platform_name. 182 | 183 | Returns: 184 | A xmap / pjit Mesh containing the virtual device mesh with data, model axes. 185 | """ 186 | input_devices = input_devices or jax.devices(backend) 187 | input_local_devices = input_local_devices or jax.local_devices(0, backend) 188 | # Sort input_devices based on coords, as backends might not return devices 189 | # in order. 
190 | last_device = sorted(input_devices, key=get_coords)[-1] 191 | last_input_local_devices = sorted(input_local_devices, key=get_coords)[-1] 192 | logging.info( 193 | "last device coords : %r\nlast local device coords: %r", 194 | get_coords(last_device), 195 | get_coords(last_input_local_devices), 196 | ) 197 | global_hardware_mesh = bounds_from_last_device(last_device) 198 | mesh_ndim = len(global_hardware_mesh) 199 | local_hardware_mesh = bounds_from_last_device(last_input_local_devices) 200 | mesh_err = ( 201 | f"each dimension of the model parallel submesh {model_parallel_submesh} " 202 | "must be a factor of the corresponding dimension of the global device " 203 | f"mesh {global_hardware_mesh}" 204 | ) 205 | assert not any(g % m for g, m in zip(global_hardware_mesh, model_parallel_submesh)), mesh_err 206 | assert not any(g % l for g, l in zip(global_hardware_mesh, local_hardware_mesh)) 207 | devices = np.empty(global_hardware_mesh, dtype=object) 208 | for device in input_devices: 209 | device_coords = get_coords(device) 210 | devices[device_coords] = device 211 | tile_by_host = tile_by_host_if_needed 212 | if len(global_hardware_mesh) == 4: 213 | # enable contiguous local chunks without host tiling by making Z major 214 | global_hardware_mesh = typing.cast(Tuple[int, int, int, int], global_hardware_mesh) 215 | model_parallel_submesh = typing.cast(Tuple[int, int, int, int], model_parallel_submesh) 216 | gx, gy, gz, gc = global_hardware_mesh 217 | mx, my, mz, mc = model_parallel_submesh 218 | if (mx == gx > 1 and my == mz == 1) or (mx == 1 and my == gy > 1 and mz == gz > 1): 219 | logging.info("ensuring YZ plane has a Z-major device order") 220 | # YZ should be ZY 221 | assert mc == gc, (mc, gc) 222 | global_hardware_mesh = gx, gz, gy, gc 223 | model_parallel_submesh = mx, mz, my, mc 224 | devices = devices.swapaxes(1, 2) 225 | tile_by_host = False 226 | if (my == gy > 1 and mx == mz == 1) or (my == 1 and mx == gx > 1 and mz == gz > 1): 227 | logging.info("ensuring XZ plane has a Z-major device order") 228 | # XZ should be ZX 229 | assert mc == gc, (mc, gc) 230 | global_hardware_mesh = gz, gy, gx, gc 231 | model_parallel_submesh = mz, my, mx, mc 232 | devices = devices.swapaxes(0, 2) 233 | tile_by_host = False 234 | if tile_by_host: 235 | logging.warning( 236 | "Tiling device assignment mesh by hosts, which may lead to " 237 | "reduced XLA collective performance. To avoid this, modify " 238 | "the model parallel submesh or run with more tasks per host." 239 | ) 240 | tile_err = ( 241 | "to tile the mesh by hosts, each dimension of the model parallel " 242 | "submesh must be either a factor or a multiple of the corresponding " 243 | "dimension of the per-host submesh" 244 | ) 245 | 246 | def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]: 247 | """Split a global mesh dimension into four tiling components. 248 | 249 | Args: 250 | g: global mesh bounds dimension size 251 | m: model-parallel submesh bounds dimension size 252 | l: local submesh bounds dimension size 253 | 254 | Returns: 255 | The resulting tuple divides the dimension into the hosts component of 256 | the data-parallel submesh, the devices component of the data-parallel 257 | submesh, the hosts component of the model-parallel submesh, and the 258 | devices component of the model-parallel submesh. 
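      For example (values chosen purely for illustration): with g=8, m=2 and l=4,
      the result is (2, 2, 1, 2), i.e. two data-parallel host groups, two
      data-parallel devices per host, one model-parallel host group and two
      model-parallel devices. The model tile then fits within a single host's
      local devices, and the components multiply back to the global bound:
      2 * 2 * 1 * 2 = 8 = g.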
259 | """ 260 | d = g // m 261 | if m >= l: 262 | assert not m % l, tile_err 263 | return (d, 1, m // l, l) 264 | else: 265 | assert not l % m, tile_err 266 | return (d // (l // m), l // m, 1, m) 267 | 268 | # e.g. [(x_data_hosts, x_data_devs, x_model_hosts, x_model_devs), ...] 269 | dh_dd_mh_md_tups = map(dh_dd_mh_md, global_hardware_mesh, model_parallel_submesh, local_hardware_mesh) 270 | # reshape to e.g. (x_dh, x_dd, x_mh, x_md, y_dh, ...) 271 | devices = devices.reshape(*(s for t in dh_dd_mh_md_tups for s in t)) # pylint: disable=g-complex-comprehension 272 | # TODO(jekbradbury): reorder local subgroups for ring locality 273 | # Transpose to [data_host], [data_device], [model_host], [model_device] 274 | # block ordering e.g. (x_dh, y_dh, ..., x_dd, y_dd, ...) 275 | devices = devices.transpose( 276 | *(4 * i for i in range(mesh_ndim)), 277 | *(4 * i + 1 for i in range(mesh_ndim)), 278 | *(4 * i + 2 for i in range(mesh_ndim)), 279 | *(4 * i + 3 for i in range(mesh_ndim)), 280 | ) 281 | else: 282 | # e.g. [(x_data, x_model), (y_data, y_model), ...] 283 | model_data_tups = [(g // m, m) for g, m in zip(global_hardware_mesh, model_parallel_submesh)] 284 | # reshape to e.g. (x_data, x_model, y_data, y_model...) 285 | devices = devices.reshape(*(s for t in model_data_tups for s in t)) # pylint: disable=g-complex-comprehension 286 | # TODO(jekbradbury): reorder small subgroups for ring locality 287 | # transpose to e.g. (x_data, y_data, ..., x_model, ...) 288 | devices = devices.transpose(*(2 * i for i in range(mesh_ndim)), *(2 * i + 1 for i in range(mesh_ndim))) 289 | # reshape to (data, model) 290 | devices = devices.reshape(-1, np.prod(model_parallel_submesh)) 291 | global_mesh = Mesh(devices, ["data", "model"]) 292 | logging.info("global_mesh axis_names: %s", global_mesh.axis_names) 293 | logging.info("global_mesh devices: %s", global_mesh.devices) 294 | logging.info("global_mesh devices shape: %s", global_mesh.devices.shape) 295 | return global_mesh 296 | 297 | 298 | def get_cpu_mesh() -> Mesh: 299 | """Trivial mesh for CPU Testing.""" 300 | devices = np.empty((jax.host_count(), jax.local_device_count()), dtype=object) 301 | for device in jax.devices(): 302 | devices[device.process_index, device.id % jax.local_device_count()] = device 303 | return Mesh(devices, ["data", "model"]) 304 | 305 | 306 | def get_gpu_mesh(num_partitions: int) -> Mesh: 307 | """Mesh for GPUs that preferentially places 'model' on NVLink.""" 308 | nvlink_size = jax.local_device_count() 309 | dcn_size = jax.process_count() 310 | nvlink_mp = min(num_partitions, nvlink_size) 311 | nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp) 312 | dcn_mp, extra2 = divmod(num_partitions, nvlink_mp) 313 | assert not (extra1 or extra2), ( 314 | "number of partitions on GPU must be a factor" " or multiple of the number of local devices" 315 | ) 316 | dcn_dp = dcn_size // dcn_mp 317 | 318 | devices = create_hybrid_device_mesh( 319 | mesh_shape=[nvlink_dp, nvlink_mp], dcn_mesh_shape=[dcn_dp, dcn_mp], process_is_granule=True 320 | ) 321 | 322 | global_mesh = Mesh(devices, ["data", "model"]) 323 | logging.info("global_mesh axis_names: %s", global_mesh.axis_names) 324 | logging.info("global_mesh devices: %s", global_mesh.devices) 325 | return global_mesh 326 | 327 | 328 | def default_mesh( 329 | num_partitions: int, model_parallel_submesh: Optional[HardwareMesh] = None, backend: Optional[str] = None 330 | ) -> Mesh: 331 | """Attempt to return a default mesh for simple cases. 
332 | 333 | Args: 334 | num_partitions: number of partitions to use, will be ignored if 335 | model_parallel_submesh is provided. 336 | model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use as 337 | the model-parallel device tile. 338 | backend: get devices from the pinned backend, if specified. This is useful 339 | for explicitly specifying the devices other than relying on 340 | jax_platform_name. 341 | 342 | Returns: 343 | xmap/pjit 2D Mesh with 'data', 'model' mesh axes. 344 | """ 345 | last_device = jax.devices(backend)[-1] 346 | platform = last_device.platform 347 | device_kind = last_device.device_kind 348 | bounds = bounds_from_last_device(last_device) 349 | 350 | if model_parallel_submesh: 351 | return get_mesh(model_parallel_submesh, backend=backend) 352 | 353 | if platform == "cpu": 354 | return get_cpu_mesh() 355 | elif platform == "gpu": 356 | return get_gpu_mesh(num_partitions) 357 | 358 | mps = None 359 | if device_kind in ("TPU v2", "TPU v3"): 360 | if num_partitions == 1: 361 | mps = (1, 1, 1, 1) 362 | elif num_partitions == 2: 363 | mps = (1, 1, 1, 2) 364 | elif num_partitions == 4: 365 | mps = (2, 1, 1, 2) 366 | elif num_partitions == 8: 367 | mps = (2, 2, 1, 2) 368 | elif num_partitions == 16: 369 | mps = (4, 2, 1, 2) 370 | # assume the use of megacore on TPU v4 371 | elif (device_kind == "TPU v4" or device_kind == "TPU v4 lite") and bounds[3] == 1: 372 | if num_partitions == 1: 373 | mps = (1, 1, 1, 1) 374 | elif num_partitions == 2: 375 | mps = (1, 2, 1, 1) 376 | elif num_partitions == 4: 377 | if bounds[0] >= 4: 378 | mps = (4, 1, 1, 1) 379 | else: 380 | mps = (2, 2, 1, 1) 381 | elif num_partitions == 8: 382 | if bounds[2] >= 8: 383 | mps = (1, 1, 8, 1) 384 | else: 385 | mps = (4, 2, 1, 1) 386 | elif num_partitions == 16: 387 | if bounds[2] >= 16: 388 | mps = (1, 1, 16, 1) 389 | elif bounds[0] >= 8: 390 | mps = (8, 2, 1, 1) 391 | elif bounds[0] >= 4: 392 | mps = (4, 4, 1, 1) 393 | else: 394 | mps = (2, 2, 4, 1) 395 | 396 | if mps is None: 397 | raise ValueError( 398 | "No default mesh for this configuration: specify " "config.model_parallel_submesh explicitly." 399 | ) 400 | return get_mesh(mps, backend=backend) 401 | 402 | 403 | # Data chunking helper. 404 | # ----------------------------------------------------------------------------- 405 | @dataclasses.dataclass 406 | class LocalChunkInfo: 407 | # The logical slice of an array located on this host's local devices. 408 | slice: Tuple[slice, ...] 409 | # A unique index for this host/local chunk among chunks with the same slice. 
410 | replica_id: int 411 | 412 | 413 | class LocalChunker: 414 | """Utility class to aid chunking of sharded arrays in multihost settings.""" 415 | 416 | def __init__(self, global_mesh: Mesh): 417 | self.global_mesh = global_mesh 418 | local_mesh = global_mesh.local_mesh 419 | first_local_device = local_mesh.devices.reshape(-1)[0] 420 | host_location = collections.OrderedDict( 421 | zip(global_mesh.shape.keys(), list(zip(*np.nonzero(global_mesh.devices == first_local_device)))[0]) 422 | ) 423 | self.num_chunks = collections.OrderedDict() 424 | self.chunk_ids = collections.OrderedDict() 425 | self.mesh_axes = list(global_mesh.shape.keys()) 426 | for mesh_axis in self.mesh_axes: 427 | num_devices_per_chunk = local_mesh.shape[mesh_axis] 428 | self.num_chunks[mesh_axis] = global_mesh.shape[mesh_axis] // num_devices_per_chunk 429 | self.chunk_ids[mesh_axis] = host_location[mesh_axis] // num_devices_per_chunk 430 | 431 | def get_local_chunk_info( 432 | self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]] 433 | ) -> LocalChunkInfo: 434 | """Get the local chunk info for a given array shape and sharded axes. 435 | 436 | Args: 437 | global_shape: the global, unsharded shape of the array to chunk. 438 | mesh_axes: a sequence of names (or None) of equal rank to `global_shape` 439 | that specifies which mesh dimensions the array is sharded along. 440 | 441 | Returns: 442 | LocalChunkInfo containing the logical slices of the array found on this 443 | host's local devices, as well as the replica index for this chunk among 444 | chunks with the same slice. The latter is used to determine which 445 | host should write this chunk during checkpointing. 446 | """ 447 | local_slice = [slice(None) for dim in global_shape] 448 | sharded_mesh_axes = set() 449 | for i, (mesh_axis, size) in enumerate(zip(mesh_axes, global_shape)): 450 | if not mesh_axis: 451 | continue 452 | sharded_mesh_axes.add(mesh_axis) 453 | if not isinstance(mesh_axis, str): 454 | raise NotImplementedError("TODO(jekbradbury)") 455 | chunk_id = self.chunk_ids[mesh_axis] 456 | chunk_size = size // self.num_chunks[mesh_axis] 457 | local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size) 458 | 459 | replicated_mesh_axes = [mesh_axis for mesh_axis in self.mesh_axes if mesh_axis not in sharded_mesh_axes] 460 | replica_id = 0 461 | for mesh_axis in replicated_mesh_axes: 462 | chunk_id = self.chunk_ids[mesh_axis] 463 | replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id 464 | 465 | return LocalChunkInfo(tuple(local_slice), replica_id) 466 | 467 | 468 | def standard_logical_axis_rules( 469 | activation_partitioning_dims: int = 1, 470 | parameter_partitioning_dims: int = 1, 471 | additional_rules: Optional[LogicalAxisRules] = None, 472 | ) -> LogicalAxisRules: 473 | """Default sharding rules for T5X model in terms of logical axis names. 474 | 475 | Args: 476 | activation_partitioning_dims: enables 2-D activation sharding when set to 2. 477 | parameter_partitioning_dims: enables 2-D parameter sharding when set to 2. 478 | additional_rules: additional rules (a sequence of tuples) that will be 479 | appended to the standard rules. 
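    For instance (a minimal sketch using the default 1-D rules defined below),
    mapping the logical axes of a (batch, embed) activation yields a spec that
    shards the batch over the 'data' mesh axis and replicates the embedding:

        rules = standard_logical_axis_rules()
        flax_partitioning.logical_to_mesh_axes(["batch", "embed"], rules)
        # -> PartitionSpec('data', None)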
480 | 481 | Returns: 482 | Sequence of logical axis rules 483 | """ 484 | logging.info( 485 | "`activation_partitioning_dims` = %d, `parameter_partitioning_dims` = %d", 486 | activation_partitioning_dims, 487 | parameter_partitioning_dims, 488 | ) 489 | 490 | if activation_partitioning_dims == 1 and parameter_partitioning_dims == 1: 491 | rules = [ 492 | ("batch", "data"), 493 | ("vocab", "model"), 494 | ("embed", None), 495 | ("mlp", "model"), 496 | ("heads", "model"), 497 | ("kv", None), 498 | ("joined_kv", "model"), # joined heads+kv dim in 2D attn param layouts 499 | ] 500 | elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 1: 501 | rules = [ 502 | ("batch", "data"), 503 | ("vocab", "model"), 504 | ("mlp", "model"), 505 | ("heads", "model"), 506 | ("kv", None), 507 | ("joined_kv", "model"), 508 | ("embed", "model"), 509 | ] 510 | elif activation_partitioning_dims == 1 and parameter_partitioning_dims == 2: 511 | rules = [ 512 | ("batch", "data"), 513 | ("vocab", "model"), 514 | ("mlp", "model"), 515 | ("heads", "model"), 516 | ("kv", None), 517 | ("joined_kv", "model"), 518 | ("embed", "data"), 519 | ] 520 | elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 2: 521 | rules = [ 522 | ("batch", "data"), 523 | ("vocab", "model"), 524 | ("mlp", "model"), 525 | ("heads", "model"), 526 | ("kv", None), 527 | ("joined_kv", "model"), 528 | ("embed", "model"), 529 | ("embed", "data"), 530 | ] 531 | else: 532 | raise ValueError( 533 | f"`activation_partitioning_dims` = {activation_partitioning_dims} " 534 | f"`parameter_partitioning_dims` = {parameter_partitioning_dims} " 535 | "is not supported." 536 | ) 537 | 538 | # Add the common rules for the replicated logical axes names. 539 | replicated_rules = [ 540 | ("relpos_buckets", None), 541 | ("abspos_buckets", None), 542 | ("length", None), 543 | ("layers", None), 544 | ("stack", None), 545 | ("mlp_activations", None), 546 | ] 547 | rules.extend(replicated_rules) 548 | 549 | if additional_rules: 550 | rules.extend(additional_rules) 551 | 552 | return rules 553 | 554 | 555 | # NB: This needs to be top-level for the jax compilation cache. 556 | def _id_fn(x, ix): 557 | """Identity function for copying parameters to the devices, sharded.""" 558 | # A pure identity such as `lambda x, *: x` can get optimized away, so we 559 | # include a random.split as a cheap function that cannot be optimized away. 560 | y = random.split(random.PRNGKey(jnp.array(ix, dtype=jnp.uint32))) 561 | return x, y 562 | 563 | 564 | @dataclasses.dataclass 565 | class DataLayout: 566 | """Represents data layout for the partitioned model.""" 567 | 568 | batch_size: int 569 | shard_id: int 570 | num_shards: int 571 | is_first_host_in_replica_set: bool 572 | 573 | 574 | PartitionedCallable = Callable[..., Any] 575 | CompiledPartitionedCallable = Callable[..., Any] 576 | 577 | 578 | class BasePartitioner(metaclass=abc.ABCMeta): 579 | """Interface for partitioning computations across hardware devices.""" 580 | 581 | def __init__( 582 | self, 583 | num_partitions: Optional[int] = None, 584 | model_parallel_submesh: Optional[HardwareMesh] = None, 585 | params_on_devices: bool = True, 586 | backend: Optional[str] = None, 587 | ): 588 | """Configures the partitioner. 589 | 590 | Args: 591 | num_partitions: the number of partitions to use. Ignored if 592 | `model_parallel_submesh` is provided. 593 | model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use 594 | as the model-parallel device tile. 
This submesh is used for the larger 595 | of the two parameter dimensions, and, if 2-D activation sharding is 596 | enabled, for the model dimension of activations. The rest of the mesh is 597 | used for data parallelism and, if 2-D parameter sharding is enabled, the 598 | other parameter dimension. 599 | params_on_devices: whether to keep the params on devices, if False - 600 | params stay in the host memory. Note that some partitioners might ignore 601 | this setting, for example if they don't support storing all params on 602 | device memory. 603 | backend: get devices from the pinned backend, if specified. This is useful 604 | for explicitly specifying the devices other than relying on 605 | jax_platform_name. 606 | """ 607 | 608 | if not num_partitions and not model_parallel_submesh: 609 | raise ValueError("At least one of `num_partitions` or " "`model_parallel_submesh` must be set.") 610 | 611 | if model_parallel_submesh is not None and len(model_parallel_submesh) != 4: 612 | logging.error( 613 | ( 614 | "`model_parallel_submesh` must be either None or a 4-tuple. Got" 615 | " `model_parallel_submesh`=%s. A ValueError will be raised" 616 | " beginning March 1, 2022." 617 | ), 618 | model_parallel_submesh, 619 | ) 620 | 621 | if bool(num_partitions) and bool(model_parallel_submesh): 622 | logging.error( 623 | "At most one of `num_partitions` or `model_parallel_submesh` can be " 624 | "set. Got `num_partitions=%s` and `model_parallel_submesh`=%s. A " 625 | "ValueError will be raised beginning March 21, 2022.", 626 | num_partitions, 627 | model_parallel_submesh, 628 | ) 629 | 630 | self._num_partitions = num_partitions 631 | self._model_parallel_submesh = model_parallel_submesh 632 | self._params_on_devices = params_on_devices 633 | self._data_axis = "data" 634 | self._backend = backend 635 | 636 | @property 637 | def mesh(self) -> Mesh: 638 | raise NotImplementedError 639 | 640 | @property 641 | def data_partition_spec(self) -> PartitionSpec: 642 | return PartitionSpec(self._data_axis) 643 | 644 | def get_data_layout(self, batch_size: Optional[int] = None, host_index: Optional[int] = None) -> DataLayout: 645 | """Returns filled `DataLayout` based on the partitioned model layout. 646 | 647 | Args: 648 | batch_size: if set, indicates the requested batch size. The exception will 649 | be raised if this batch size is not compatible with the layout. If not 650 | set, the batch size is inferred from the layout. 651 | host_index: indicates the host index to use for the calculations, if not 652 | set - use JAX-provided one. Should be in [0, num_hosts) interval and the 653 | order should match the order of corresponding CPU devices in 654 | `jax.devices()`. 655 | 656 | Returns: 657 | Filled `DataLayout` structure. 658 | """ 659 | if host_index is not None: 660 | raise NotImplementedError("Explicit host_index is not yet implemented.") 661 | if self._data_axis is None: 662 | return DataLayout( 663 | batch_size=batch_size, 664 | shard_id=0, 665 | num_shards=1, 666 | is_first_host_in_replica_set=(jax.process_index() == 0), 667 | ) 668 | mesh_size = self._local_chunker.global_mesh.shape[self._data_axis] 669 | batch_size = batch_size or mesh_size 670 | if batch_size % mesh_size: 671 | raise ValueError( 672 | f"Batch size ({batch_size}) must be divisible by corresponding " f"mesh size ({mesh_size})." 
673 | ) 674 | num_shards = self._local_chunker.num_chunks[self._data_axis] 675 | if batch_size % num_shards: 676 | raise ValueError(f"Batch size ({batch_size}) must be divisible by number of " f"replicas ({num_shards}).") 677 | replica_id = self._local_chunker.get_local_chunk_info((batch_size,), [self._data_axis]).replica_id 678 | return DataLayout( 679 | batch_size=int(batch_size), 680 | shard_id=int(self._local_chunker.chunk_ids[self._data_axis]), 681 | num_shards=int(num_shards), 682 | is_first_host_in_replica_set=(replica_id == 0), 683 | ) 684 | 685 | def get_local_chunk_info( 686 | self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]] 687 | ) -> LocalChunkInfo: 688 | """Returns the local chunk info for a given array shape and sharded axes.""" 689 | return self._local_chunker.get_local_chunk_info(global_shape, mesh_axes) 690 | 691 | @property 692 | def params_on_devices(self): 693 | return self._params_on_devices 694 | 695 | def move_params_to_devices(self, train_state: TrainState, train_state_axes: TrainState) -> TrainState: 696 | """Moves the optimizer parameters to devices.""" 697 | p_id_fn = self.partition( 698 | _id_fn, 699 | in_axis_resources=(train_state_axes, None), 700 | out_axis_resources=(train_state_axes, None), 701 | donate_argnums=(0,), 702 | ) 703 | if jax.config.jax_array and jax.process_count() > 1: 704 | train_state = multihost_utils.host_local_array_to_global_array(train_state, self.mesh, train_state_axes) 705 | train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32)) 706 | return train_state 707 | 708 | @property 709 | @abc.abstractmethod 710 | def _local_chunker(self): 711 | """Returns the chunker that matches the parameters of this partitioner.""" 712 | raise NotImplementedError 713 | 714 | def get_logical_axes(self, train_state: TrainState) -> TrainState: 715 | """Returns a copy of TrainState with Optional[AxisNames] as leaves.""" 716 | # By default, return None for the logical axes. 717 | return train_state.restore_state(jax.tree_map(lambda x: None, train_state.state_dict())) 718 | 719 | def get_mesh_axes(self, train_state: TrainState) -> TrainState: 720 | """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves.""" 721 | raise NotImplementedError 722 | 723 | @abc.abstractmethod 724 | def partition( 725 | self, 726 | fn: Callable, # pylint: disable=g-bare-generic 727 | in_axis_resources, 728 | out_axis_resources, 729 | static_argnums: Union[int, Sequence[int]] = (), 730 | donate_argnums: Union[int, Sequence[int]] = (), 731 | ) -> PartitionedCallable: 732 | """Partitions the computation using partitioner-specific implementation. 733 | 734 | Args: 735 | fn: the function to partition. 736 | in_axis_resources: Pytree of structure matching that of arguments to `fn`, 737 | with all actual arguments replaced by resource assignment 738 | specifications. It is also valid to specify a pytree prefix (e.g. one 739 | value in place of a whole subtree), in which case the leaves get 740 | broadcast to all values in that subtree. 741 | The valid resource assignment specifications are: 742 | `None`: in which case the value will be replicated on all devices 743 | `PartitionSpec`: a tuple of length at most equal to the rank of the 744 | partitioned value. Each element can be a `None`, a mesh axis or a 745 | tuple of mesh axes, and specifies the set of resources assigned to 746 | partition the value's dimension matching its position in the spec. 
747 | out_axis_resources: Like `in_axis_resources`, but specifies resource 748 | assignment for function outputs. 749 | static_argnums: an optional int or collection of ints that specify which 750 | positional arguments to treat as static (compile-time constant) in the 751 | partitioned function. 752 | donate_argnums: an optional int or collection of ints that specify which 753 | argument buffers are "donated" to the computation. It is safe to donate 754 | argument buffers if you no longer need them once the computation has 755 | finished. 756 | 757 | Returns: 758 | A partitioned version of the input function. 759 | """ 760 | raise NotImplementedError 761 | 762 | @abc.abstractmethod 763 | def compile(self, partitioned_fn: PartitionedCallable, *args) -> CompiledPartitionedCallable: 764 | """Compiles and returns the partitioned function, or the original. 765 | 766 | Args: 767 | partitioned_fn: The partitioned function. 768 | *args: Sample arguments to the partitioned function matching the input 769 | shapes that will be passed to the compiled function. 770 | 771 | Returns: 772 | The compiled function, or the original if this partitioner does not 773 | support compilation. 774 | """ 775 | raise NotImplementedError 776 | 777 | 778 | class PjittedFnWithContext(PartitionedCallable): 779 | """Wraps pjitted function to apply the appropriate contexts.""" 780 | 781 | def __init__(self, pjitted_fn, partition_mesh: Mesh, logical_axis_rules: flax_partitioning.LogicalRules = ()): 782 | self._pjitted_fn = pjitted_fn 783 | self._mesh = partition_mesh 784 | self._logical_axis_rules = logical_axis_rules 785 | 786 | def __call__(self, *args): 787 | with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules): 788 | return self._pjitted_fn(*args) 789 | 790 | def lower(self, *args): 791 | with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules): 792 | return self._pjitted_fn.lower(*args) 793 | 794 | 795 | class BasePjitPartitioner(BasePartitioner): 796 | """Partitioner that uses T5X version of jax.pjit.""" 797 | 798 | @cached_property 799 | def _local_chunker(self) -> LocalChunker: 800 | return LocalChunker(self.mesh) 801 | 802 | @cached_property 803 | def mesh(self) -> Mesh: 804 | return default_mesh(self._num_partitions, self._model_parallel_submesh, self._backend) 805 | 806 | def partition( 807 | self, 808 | fn: Callable, # pylint: disable=g-bare-generic 809 | in_axis_resources, 810 | out_axis_resources, 811 | static_argnums: Union[int, Sequence[int]] = (), 812 | donate_argnums: Union[int, Sequence[int]] = (), 813 | ) -> PjittedFnWithContext: 814 | pjitted = pjit( 815 | fn, 816 | in_axis_resources=in_axis_resources, 817 | out_axis_resources=out_axis_resources, 818 | static_argnums=static_argnums, 819 | donate_argnums=donate_argnums, 820 | backend=self._backend, 821 | ) 822 | 823 | return PjittedFnWithContext(pjitted, self.mesh) 824 | 825 | def compile(self, partitioned_fn: PjittedFnWithContext, *args) -> CompiledPartitionedCallable: 826 | return partitioned_fn.lower(*args).compile() 827 | 828 | 829 | class PjitPartitioner(BasePjitPartitioner): 830 | """Partitioner that uses named axes and jax.pjit.""" 831 | 832 | def __init__( 833 | self, 834 | num_partitions: Optional[int] = None, 835 | model_parallel_submesh: Optional[HardwareMesh] = None, 836 | params_on_devices: bool = True, 837 | backend: Optional[str] = None, 838 | logical_axis_rules: Optional[LogicalAxisRules] = None, 839 | use_cpu_pjit: Optional[bool] 
= False, 840 | ): 841 | """PjitPartitioner constructor. 842 | 843 | See https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/partitioning for details. 844 | 845 | Args: 846 | num_partitions: an integer that specifies the size of the model parallel 847 | submesh to be automatically selected for the current topology. See 848 | `model_parallel_submesh` for details on how this submesh is used. 849 | Mutually exclusive with `model_parallel_submesh`. 850 | model_parallel_submesh: is a 4-tuple that specifies the `(x, y, z, c)` 851 | submesh model-parallel device tile, an axis of accelerator parallelism 852 | orthogonal to data parallelism. Array axes in a model's parameters or 853 | activations can be sharded over this submesh using axis rules (see 854 | `logical_axis_rules`) that map them to 'model'. The effective number of 855 | model sub-partitions is equal to `np.prod(model_parallel_submesh)` and 856 | must evenly divide the total number of devices (i.e., 857 | `jax.device_count() % np.prod(model_parallel_submesh) == 0`). The rest 858 | of the TPU mesh is the data parallel submesh, providing 859 | `jax.device_count() // np.prod(model_parallel_submesh)` partitions. It 860 | is used for data (batch) parallelism and to shard other array axes that 861 | are mapped to 'data'. This argument is mutually exclusive with 862 | `num_partitions`. 863 | params_on_devices: whether to keep the params on devices, if False - 864 | params stay in the host memory. Note that some partitioners might ignore 865 | this setting, for example if they don't support storing all params on 866 | device memory. 867 | backend: get devices from the pinned backend, if specified. This is 868 | useful for explicitly specifying the devices other than relying on 869 | jax_platform_name. 870 | logical_axis_rules: a priority-ordered sequence of KV tuples that maps 871 | logical axis names to either `None` (not sharded), 'model' (to shard 872 | across the model-parallel submesh), or 'data' (to shard across the 873 | data-parallel submesh). 874 | use_cpu_pjit: enables wrapper function for pjit which just jits the 875 | function if using CPU backend.
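    Example (a minimal sketch; the partition count, the doubling function and the
    array shape are illustrative assumptions):

      partitioner = PjitPartitioner(num_partitions=1)
      p_double = partitioner.partition(
          lambda x: x * 2,
          in_axis_resources=(PartitionSpec("data"),),
          out_axis_resources=PartitionSpec("data"),
      )
      y = p_double(jnp.ones((8, 16)))  # runs under the partitioner's mesh and logical axis rules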
876 | """ 877 | super().__init__( 878 | num_partitions=num_partitions, 879 | model_parallel_submesh=model_parallel_submesh, 880 | params_on_devices=params_on_devices, 881 | backend=backend, 882 | ) 883 | if logical_axis_rules is None: 884 | logical_axis_rules = standard_logical_axis_rules() 885 | self._logical_axis_rules = tuple(logical_axis_rules) 886 | (self._data_axis,) = flax_partitioning.logical_to_mesh_axes(["batch"], logical_axis_rules) 887 | self._use_cpu_pjit = use_cpu_pjit 888 | 889 | def partition( 890 | self, 891 | fn: Callable, # pylint: disable=g-bare-generic 892 | in_axis_resources, 893 | out_axis_resources, 894 | static_argnums: Union[int, Sequence[int]] = (), 895 | donate_argnums: Union[int, Sequence[int]] = (), 896 | ) -> PjittedFnWithContext: 897 | """Partitions the function using jax.pjit.""" 898 | if self._use_cpu_pjit: 899 | pjit_fn = pjit_with_cpu_fallback 900 | else: 901 | pjit_fn = pjit 902 | pjitted = pjit_fn( 903 | fn, 904 | in_axis_resources=in_axis_resources, 905 | out_axis_resources=out_axis_resources, 906 | static_argnums=static_argnums, 907 | donate_argnums=donate_argnums, 908 | backend=self._backend, 909 | ) 910 | 911 | return PjittedFnWithContext(pjitted, self.mesh, self._logical_axis_rules) 912 | 913 | @property 914 | def logical_axis_rules(self): 915 | """Returns the logical axis rules.""" 916 | return self._logical_axis_rules 917 | 918 | def get_logical_axes(self, train_state: TrainState) -> TrainState: 919 | """Returns a copy of TrainState with Optional[AxisNames] as leaves.""" 920 | return train_state.as_logical_axes() 921 | 922 | def get_mesh_axes(self, train_state: TrainState) -> TrainState: 923 | """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves.""" 924 | logical_axes = self.get_logical_axes(train_state) 925 | 926 | def _logical_to_mesh_axes(param_name, logical_axes): 927 | if logical_axes is None: 928 | return None 929 | elif logical_axes is traverse_util.empty_node: 930 | return traverse_util.empty_node 931 | try: 932 | return flax_partitioning.logical_to_mesh_axes(logical_axes, self._logical_axis_rules) 933 | except ValueError as e: 934 | raise ValueError(f"Failed to map logical axes for {param_name}") from e 935 | 936 | flat_logical_axes = traverse_util.flatten_dict(logical_axes.state_dict(), keep_empty_nodes=True, sep="/") 937 | flat_mesh_axes = {k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()} 938 | 939 | return logical_axes.restore_state(traverse_util.unflatten_dict(flat_mesh_axes, sep="/")) 940 | -------------------------------------------------------------------------------- /whisper_jax/pipeline.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | import math 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | import requests 23 | from flax import jax_utils 24 | from flax.core.frozen_dict import freeze 25 | from flax.training.common_utils import shard 26 | from jax.sharding import PartitionSpec as P 27 | from transformers import WhisperProcessor, is_tokenizers_available, WhisperFeatureExtractor, WhisperTokenizerFast 28 | from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, WhisperTokenizer 29 | from transformers.pipelines.audio_utils import ffmpeg_read 30 | from transformers.utils import logging 31 | 32 | from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration 33 | from .partitioner import PjitPartitioner 34 | from .train_state import InferenceState 35 | 36 | 37 | logger = logging.get_logger(__name__) 38 | 39 | # 2D parameter and activation partitioning for DP 40 | logical_axis_rules_dp = ( 41 | ("batch", "data"), 42 | ("mlp", None), 43 | ("heads", None), 44 | ("vocab", None), 45 | ("embed", None), 46 | ("embed", None), 47 | ("joined_kv", None), 48 | ("kv", None), 49 | ("length", None), 50 | ("num_mel", None), 51 | ("channels", None), 52 | ) 53 | 54 | 55 | class FlaxWhisperPipline: 56 | def __init__( 57 | self, 58 | checkpoint="openai/whisper-large-v2", 59 | dtype=jnp.float32, 60 | batch_size=None, 61 | max_length=None, 62 | ): 63 | """ 64 | Args: 65 | checkpoint (`str`, *optional*, defaults to `"openai/whisper-large-v2"`): 66 | The Whisper checkpoint to use with the pipeline. Must be an available checkpoint on the Hugging Face Hub 67 | with Flax weights. 68 | dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): 69 | The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and 70 | `jax.numpy.bfloat16` (on TPUs). This can be used to enable half-precision inference on GPUs or TPUs. 71 | If specified, all the computation will be performed with the given `dtype`. **Note that this only 72 | specifies the dtype of the computation and does not influence the dtype of model parameters.** 73 | batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`): 74 | The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing 75 | a batch size in the `__init__` method will be superseded by any batch size passed to the `__call__` method. 76 | max_length (`int`, *optional*): 77 | The maximum number of tokens to generate. Defaults to `model.config.max_length`.
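Example (illustrative; the bfloat16 dtype and batch size below are arbitrary choices rather than defaults):

    import jax.numpy as jnp

    pipeline = FlaxWhisperPipline(
        "openai/whisper-large-v2",
        dtype=jnp.bfloat16,  # half-precision, suitable on TPU
        batch_size=16,       # must be a multiple of jax.local_device_count()
    )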
78 | """ 79 | self.checkpoint = checkpoint 80 | self.dtype = dtype 81 | 82 | self.processor = WhisperProcessor.from_pretrained(self.checkpoint) 83 | self.feature_extractor = self.processor.feature_extractor 84 | # potentially load fast tokenizer if available 85 | tokenizer_cls = WhisperTokenizerFast if is_tokenizers_available() else WhisperTokenizer 86 | self.tokenizer = tokenizer_cls.from_pretrained(checkpoint) 87 | 88 | self.model, self.params = FlaxWhisperForConditionalGeneration.from_pretrained( 89 | self.checkpoint, 90 | _do_init=False, 91 | dtype=self.dtype, 92 | ) 93 | 94 | self.max_length = max_length if max_length is not None else self.model.generation_config.max_length 95 | self.min_batch_size = jax.local_device_count() 96 | self.batch_size = ( 97 | batch_size if batch_size is not None else self.min_batch_size 98 | ) # we need a minimum of 1 batch per-device 99 | 100 | def generate(params, input_features, forced_decoder_ids, return_timestamps): 101 | output_ids = self.model.pipeline_generate( 102 | input_features, 103 | params=params, 104 | forced_decoder_ids=forced_decoder_ids, 105 | return_timestamps=return_timestamps, 106 | max_length=self.max_length, 107 | ) 108 | return output_ids 109 | 110 | # use pmap for DP by default - this is compatible on a Colab TPU v2 111 | self.params = jax_utils.replicate(self.params) 112 | self.p_generate = jax.pmap( 113 | generate, "input_features", in_axes=(0, 0, None), out_axes=0, static_broadcasted_argnums=(3,) 114 | ) 115 | self.is_sharded = False 116 | 117 | def shard_params(self, num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp): 118 | def init_fn(): 119 | input_shape = (1, self.model.config.num_mel_bins, 2 * self.model.config.max_source_positions) 120 | 121 | input_features = jnp.zeros(input_shape, dtype="f4") 122 | input_features = input_features.at[(..., -1)].set(self.model.config.eos_token_id) 123 | 124 | decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") 125 | decoder_attention_mask = jnp.ones_like(decoder_input_ids) 126 | 127 | batch_size, sequence_length = decoder_input_ids.shape 128 | decoder_position_ids = jnp.broadcast_to( 129 | jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) 130 | ) 131 | 132 | rng = jax.random.PRNGKey(0) 133 | init_params = self.model.module.init( 134 | rng, 135 | input_features=input_features, 136 | decoder_input_ids=decoder_input_ids, 137 | decoder_attention_mask=decoder_attention_mask, 138 | decoder_position_ids=decoder_position_ids, 139 | return_dict=False, 140 | ) 141 | return init_params 142 | 143 | # Axis names metadata 144 | param_axes = jax.eval_shape(init_fn)["params_axes"] 145 | 146 | # Create InferenceState, since the partitioner expects it 147 | state = InferenceState( 148 | step=jnp.array(0), 149 | params=freeze(self.model.params_shape_tree), 150 | params_axes=freeze(param_axes), 151 | flax_mutables=None, 152 | flax_mutables_axes=param_axes, 153 | ) 154 | 155 | partitioner = PjitPartitioner(num_partitions=num_mp_partitions, logical_axis_rules=logical_axis_rules) 156 | 157 | mesh_axes = partitioner.get_mesh_axes(state) 158 | params_spec = mesh_axes.params 159 | 160 | p_shard_params = partitioner.partition(self.model.to_bf16, (params_spec,), params_spec) 161 | 162 | # This will auto-magically run in mesh context 163 | self.params = p_shard_params(freeze(jax_utils.unreplicate(self.params))) 164 | self.is_sharded = True 165 | 166 | def generate(params, input_features, forced_decoder_ids, return_timestamps): 167 | output_ids = self.model.pipeline_generate( 
168 | input_features, 169 | params=params, 170 | forced_decoder_ids=forced_decoder_ids, 171 | return_timestamps=return_timestamps, 172 | max_length=self.max_length, 173 | ) 174 | return output_ids 175 | 176 | # Use pjit for generate only once we've sharded the params 177 | self.p_generate = partitioner.partition( 178 | generate, 179 | in_axis_resources=(params_spec, P("data"), None), 180 | out_axis_resources=P("data"), 181 | static_argnums=(3,), 182 | ) 183 | 184 | def generate(self, input_features, language=None, task=None, return_timestamps=False): 185 | forced_decoder_ids = self.get_forced_decoder_ids( 186 | language=language, task=task, return_timestamps=return_timestamps 187 | ) 188 | if not self.is_sharded: 189 | # if we're using pmap we need to manually replicate the input data across devices and gather the output tokens 190 | output_ids = self.p_generate( 191 | freeze(self.params), shard(input_features), forced_decoder_ids, return_timestamps 192 | ).sequences 193 | output_ids = jax.device_get(output_ids.reshape(-1, self.max_length)) 194 | else: 195 | # pjit handles replication / gathering for us auto-magically 196 | output_ids = self.p_generate( 197 | freeze(self.params), input_features, forced_decoder_ids, return_timestamps 198 | ).sequences 199 | return output_ids 200 | 201 | def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False): 202 | if generation_config is None: 203 | generation_config = self.model.generation_config 204 | 205 | if hasattr(generation_config, "is_multilingual"): 206 | is_multilingual = generation_config.is_multilingual 207 | else: 208 | is_multilingual = None 209 | 210 | forced_decoder_ids = [] 211 | 212 | if is_multilingual: 213 | if language is not None: 214 | language = language.lower() 215 | if language in generation_config.lang_to_id.keys(): 216 | language_token = language 217 | elif language in TO_LANGUAGE_CODE.values(): 218 | language_token = f"<|{language}|>" 219 | elif language in TO_LANGUAGE_CODE.keys(): 220 | language_token = f"<|{TO_LANGUAGE_CODE[language]}|>" 221 | else: 222 | if len(language) == 2: 223 | # ISO 639-1 language code 224 | acceptable_languages = list(TO_LANGUAGE_CODE.values()) 225 | elif "<" in language or "|" in language or ">" in language: 226 | # generation config language code 227 | acceptable_languages = list(generation_config.lang_to_id.keys()) 228 | else: 229 | # language passed as a string 230 | acceptable_languages = list(TO_LANGUAGE_CODE.keys()) 231 | raise ValueError( 232 | f"Unsupported language: {language}. Language should be one of:" f" {acceptable_languages}." 
233 | ) 234 | forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) 235 | 236 | if task is not None: 237 | forced_decoder_ids.append((2, generation_config.task_to_id[task])) 238 | else: 239 | forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) 240 | 241 | if not return_timestamps: 242 | if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id: 243 | idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 244 | forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) 245 | 246 | return forced_decoder_ids 247 | 248 | def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size): 249 | inputs_len = inputs.shape[0] 250 | step = chunk_len - stride_left - stride_right 251 | 252 | all_chunk_start_idx = np.arange(0, inputs_len, step) 253 | num_samples = len(all_chunk_start_idx) 254 | 255 | num_batches = math.ceil(num_samples / batch_size) 256 | batch_idx = np.array_split(np.arange(num_samples), num_batches) 257 | 258 | for idx in batch_idx: 259 | chunk_start_idx = all_chunk_start_idx[idx] 260 | chunk_end_idx = chunk_start_idx + chunk_len 261 | 262 | chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)] 263 | processed = self.feature_extractor( 264 | chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np" 265 | ) 266 | 267 | _stride_left = np.where(chunk_start_idx == 0, 0, stride_left) 268 | is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len) 269 | _stride_right = np.where(is_last, 0, stride_right) 270 | 271 | chunk_lens = [chunk.shape[0] for chunk in chunks] 272 | strides = [ 273 | (chunk_l, _stride_l, _stride_r) 274 | for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right) 275 | ] 276 | 277 | yield {"stride": strides, **processed} 278 | 279 | def preprocess_batch(self, inputs, chunk_length_s=30.0, stride_length_s=None, batch_size=None): 280 | if isinstance(inputs, np.ndarray): 281 | logger.warning( 282 | "Numpy array passed as input - no sampling rate checks will be performed." 283 | "It is strongly recommended to pass the input as a dictionary with an 'array' key " 284 | "containing the numpy array representing the audio, and a 'sampling_rate' key " 285 | "containing the sampling rate associated with the audio array." 286 | "Failing to do so can result in silent errors that might be hard to debug." 287 | ) 288 | 289 | if isinstance(inputs, str): 290 | if inputs.startswith("http://") or inputs.startswith("https://"): 291 | # We need to actually check for a real protocol, otherwise it's impossible to use a local file 292 | # like http_huggingface_co.png 293 | inputs = requests.get(inputs).content 294 | else: 295 | with open(inputs, "rb") as f: 296 | inputs = f.read() 297 | 298 | if isinstance(inputs, bytes): 299 | inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate) 300 | 301 | stride = None 302 | if isinstance(inputs, dict): 303 | stride = inputs.get("stride", None) 304 | # Accepting `"array"` which is the key defined in `datasets` for 305 | # better integration 306 | if not ("sampling_rate" in inputs and "array" in inputs): 307 | raise ValueError( 308 | "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain an 'array' key " 309 | "containing the numpy array representing the audio, and a 'sampling_rate' key " 310 | "containing the sampling rate associated with the audio array." 
311 | ) 312 | 313 | in_sampling_rate = inputs.get("sampling_rate") 314 | inputs = inputs.get("array", None) 315 | 316 | if in_sampling_rate != self.feature_extractor.sampling_rate: 317 | try: 318 | import librosa 319 | except ImportError as err: 320 | raise ImportError( 321 | "To support resampling audio files, please install 'librosa' and 'soundfile'." 322 | ) from err 323 | 324 | inputs = librosa.resample( 325 | inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate 326 | ) 327 | ratio = self.feature_extractor.sampling_rate / in_sampling_rate 328 | else: 329 | ratio = 1 330 | 331 | if not isinstance(inputs, np.ndarray): 332 | raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`") 333 | if len(inputs.shape) != 1: 334 | raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline") 335 | 336 | if stride is not None: 337 | if stride[0] + stride[1] > inputs.shape[0]: 338 | raise ValueError("Stride is too large for input") 339 | 340 | # Stride needs to get the chunk length here, it's going to get 341 | # swallowed by the `feature_extractor` later, and then batching 342 | # can add extra data in the inputs, so we need to keep track 343 | # of the original length in the stride so we can cut properly. 344 | stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio))) 345 | 346 | if chunk_length_s: 347 | if stride_length_s is None: 348 | stride_length_s = chunk_length_s / 6 349 | 350 | if isinstance(stride_length_s, (int, float)): 351 | stride_length_s = [stride_length_s, stride_length_s] 352 | 353 | chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate) 354 | stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate) 355 | stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate) 356 | 357 | if chunk_len < stride_left + stride_right: 358 | raise ValueError("Chunk length must be superior to stride length") 359 | 360 | for item in self.chunk_iter_with_batch( 361 | inputs, 362 | chunk_len, 363 | stride_left, 364 | stride_right, 365 | batch_size, 366 | ): 367 | yield item 368 | else: 369 | processed = self.feature_extractor( 370 | inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np" 371 | ) 372 | if stride is not None: 373 | processed["stride"] = stride 374 | yield processed 375 | 376 | def postprocess(self, model_outputs, return_timestamps=None, return_language=None): 377 | # unpack the outputs from list(dict(list)) to list(dict) 378 | model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())] 379 | 380 | time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions 381 | # Send the chunking back to seconds, it's easier to handle in whisper 382 | sampling_rate = self.feature_extractor.sampling_rate 383 | for output in model_outputs: 384 | if "stride" in output: 385 | chunk_len, stride_left, stride_right = output["stride"] 386 | # Go back in seconds 387 | chunk_len /= sampling_rate 388 | stride_left /= sampling_rate 389 | stride_right /= sampling_rate 390 | output["stride"] = chunk_len, stride_left, stride_right 391 | 392 | text, optional = self.tokenizer._decode_asr( 393 | model_outputs, 394 | return_timestamps=return_timestamps, 395 | return_language=return_language, 396 | time_precision=time_precision, 397 | ) 398 | return {"text": text, **optional} 399 | 400 | def forward(self, model_inputs, batch_size=None, language=None, 
task=None, return_timestamps=False): 401 | # We need to keep track of some additional input arguments for post-processing, so we forward them on after running generation 402 | input_features = model_inputs.pop("input_features") 403 | input_batch_size = input_features.shape[0] 404 | 405 | if input_batch_size != batch_size: 406 | padding = np.zeros([batch_size - input_batch_size, *input_features.shape[1:]], input_features.dtype) 407 | input_features = np.concatenate([input_features, padding]) 408 | 409 | pred_ids = self.generate(input_features, language=language, task=task, return_timestamps=return_timestamps)[ 410 | :input_batch_size 411 | ] 412 | 413 | # tokenizer's decode method expects an extra dim - we insert it here for convenience 414 | out = {"tokens": pred_ids[:, None, :]} 415 | 416 | stride = model_inputs.pop("stride", None) 417 | if stride is not None: 418 | out["stride"] = stride 419 | 420 | return out 421 | 422 | def __call__( 423 | self, 424 | inputs, 425 | chunk_length_s=30.0, 426 | stride_length_s=None, 427 | batch_size=None, 428 | language=None, 429 | task=None, 430 | return_timestamps=None, 431 | generate_kwargs=None, 432 | ): 433 | """ 434 | Transcribe an audio input sequence to a text transcription, optionally with timestamps. 435 | 436 | Args: 437 | inputs (`np.ndarray` or `bytes` or `str` or `dict`): 438 | The input is either: 439 | - `str` that is the filename of the audio file; the file will be read at the correct sampling rate 440 | to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system. 441 | - `bytes` is the byte content of an audio file and is interpreted by *ffmpeg* in the 442 | same way. 443 | - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`) 444 | Raw audio assumed to be at the correct sampling rate (16kHz). Note that no further sampling 445 | rate check will be done. 446 | - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this 447 | pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "array": 448 | np.array}`. Optionally an additional argument `"stride": (left: int, right: int)` can be used to 449 | ask the pipeline to ignore the first `left` samples and last `right` samples in 450 | decoding (but use them at inference to provide more context to the model). In general, this additional 451 | stride argument is not required. 452 | chunk_length_s (`float`, *optional*, defaults to 30.0): 453 | The input length for each chunk. If `chunk_length_s = 0`, chunking is disabled. By default, the chunk 454 | length is set to 30.0s, equal to Whisper's context window. 455 | stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`): 456 | The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables 457 | the model to *see* more context and infer letters better than without this context, but the pipeline 458 | discards the stride bits at the end to make the final reconstitution as perfect as possible. 459 | 460 | 461 | 462 | For more information on how to effectively use `stride_length_s`, refer to the [ASR chunking 463 | blog post](https://huggingface.co/blog/asr-chunking). 464 | 465 | 466 | batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`): 467 | The batch size to be used in chunking transcription. Beneficial for transcribing long audio files.
Passing 468 | a batch size in the `__call__` method will supersede any batch size passed to the `__init__`. 469 | task (`str`, *optional*): 470 | Task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`. 471 | language (`str`, *optional*): 472 | Language token to use for generation; can be passed either in the form of `"<|en|>"`, `"en"` or `"english"`. 473 | Defaults to `None`, meaning the language is automatically inferred from the audio input. 474 | return_timestamps (`bool`, *optional*): 475 | Whether to return timestamps in the prediction. Defaults to `False`. If set to `True`, the pipeline 476 | will return two keys in the output dictionary: `"text"` containing the text transcription, and `"chunks"` 477 | containing the transcription segments chunked by their utterance-level timestamps. 478 | 479 | Return: 480 | `Dict`: A dictionary with the following keys: 481 | - **text** (`str`) -- The recognised text. 482 | - **chunks** (*optional*, `List[Dict]`) 483 | When using `return_timestamps`, the `chunks` will become a list containing all the various text 484 | chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5, 0.9)}, {"text": 485 | "there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing 486 | `"".join(chunk["text"] for chunk in output["chunks"])`. 487 | """ 488 | batch_size = batch_size if batch_size is not None else self.batch_size 489 | if batch_size % self.min_batch_size != 0: 490 | raise ValueError( 491 | f"Batch size must be a multiple of the number of JAX devices, but got batch size {batch_size} and num devices {self.min_batch_size}." 492 | ) 493 | 494 | dataloader = self.preprocess_batch( 495 | inputs, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, batch_size=batch_size 496 | ) 497 | model_outputs = [] 498 | # iterate over our chunked audio samples 499 | for batch in dataloader: 500 | model_outputs.append( 501 | self.forward( 502 | batch, batch_size=batch_size, language=language, task=task, return_timestamps=return_timestamps 503 | ) 504 | ) 505 | post_processed = self.postprocess(model_outputs, return_timestamps=return_timestamps) 506 | return post_processed 507 | -------------------------------------------------------------------------------- /whisper_jax/train_state.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The T5X Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
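# Context (an illustrative recap, mirroring whisper_jax/pipeline.py): the InferenceState defined in
# this module is built by the pipeline so that a PjitPartitioner can derive PartitionSpecs for the
# model parameters, roughly:
#
#     state = InferenceState(
#         step=jnp.array(0),
#         params=freeze(model.params_shape_tree),
#         params_axes=freeze(param_axes),
#         flax_mutables=None,
#         flax_mutables_axes=param_axes,
#     )
#     mesh_axes = partitioner.get_mesh_axes(state)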
15 | 16 | """Utilities for partitioning.""" 17 | 18 | from typing import Any, Mapping, MutableMapping, Optional, Tuple 19 | 20 | import flax.core 21 | import flax.serialization 22 | import flax.struct 23 | import jax.numpy as jnp 24 | from flax import traverse_util 25 | from flax.core import scope as flax_scope 26 | from flax.linen import partitioning as flax_partitioning 27 | 28 | 29 | EMPTY_DICT = flax.core.freeze({}) 30 | FrozenDict = flax_scope.FrozenDict 31 | FrozenVariableDict = flax_scope.FrozenVariableDict 32 | MutableVariableDict = flax_scope.MutableVariableDict 33 | VariableDict = flax_scope.VariableDict 34 | 35 | 36 | def _validate_params_axes(params_axes, params): 37 | axis_names = flax_partitioning.get_axis_names(params_axes) 38 | missing_params_axes = set(traverse_util.flatten_dict(params, sep="/")) - set( 39 | traverse_util.flatten_dict(axis_names, sep="/") 40 | ) 41 | if missing_params_axes: 42 | raise ValueError(f"Missing axis names for parameters: {missing_params_axes}") 43 | 44 | 45 | def _split_variables_and_axes(variables_and_axes: FrozenVariableDict) -> Tuple[FrozenVariableDict, FrozenVariableDict]: 46 | """Splits `variables_and_axes` into two separate dicts with the same keys.""" 47 | # For each `key`, `key_axes` (if any) are its axes in `variables_and_axes`. 48 | variables = {} 49 | axes = {} 50 | for k, v in variables_and_axes.items(): 51 | if k.endswith("_axes"): 52 | axes[k[:-5]] = v # k without "_axes". 53 | _validate_params_axes(v, variables_and_axes[k[:-5]]) # k without "_axes". 54 | else: 55 | variables[k] = v 56 | return flax.core.freeze(variables), flax.core.freeze(axes) 57 | 58 | 59 | class InferenceState(flax.struct.PyTreeNode): 60 | """State compatible with FlaxOptimTrainState without optimizer state.""" 61 | 62 | step: jnp.ndarray 63 | params: flax_scope.FrozenVariableDict 64 | params_axes: Optional[flax_scope.FrozenVariableDict] = None 65 | flax_mutables: flax_scope.FrozenDict = EMPTY_DICT 66 | flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None 67 | 68 | @classmethod 69 | def create(cls, model_variables: FrozenVariableDict) -> "InferenceState": 70 | other_variables, params = model_variables.pop("params") 71 | if "params_axes" in other_variables: 72 | other_variables, params_axes = other_variables.pop("params_axes") 73 | _validate_params_axes(params_axes, params) 74 | else: 75 | params_axes = None 76 | 77 | # Split other_variables into mutables and their corresponding axes. 
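# For example (hypothetical keys): variables {"cache": ..., "cache_axes": ...} would be split
# into flax_mutables={"cache": ...} and flax_mutables_axes={"cache": ...}.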
78 | flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables) 79 | flax_mutables_axes = flax_mutables_axes or None 80 | return InferenceState( 81 | step=jnp.array(0), 82 | params=params, 83 | params_axes=params_axes, 84 | flax_mutables=flax_mutables, 85 | flax_mutables_axes=flax_mutables_axes, 86 | ) 87 | 88 | @property 89 | def param_states(self) -> FrozenVariableDict: 90 | """The optimizer states of the parameters as a PyTree.""" 91 | raise NotImplementedError("InferenceState has no optimizer states.") 92 | 93 | def apply_gradient(self, *args, **kwargs) -> "InferenceState": 94 | raise NotImplementedError("InferenceState does not support `apply_gradient`.") 95 | 96 | def state_dict(self) -> MutableMapping[str, Any]: 97 | state_dict = {"target": flax.core.unfreeze(self.params), "state": {"step": self.step}} 98 | if self.flax_mutables: 99 | state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables) 100 | return state_dict 101 | 102 | def replace_step(self, step: jnp.ndarray) -> "InferenceState": 103 | return self.replace(step=step) 104 | 105 | def replace_params(self, params: FrozenVariableDict) -> "InferenceState": 106 | return self.replace(params=params) 107 | 108 | def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState": 109 | return self.replace(flax_mutables=flax_mutables) 110 | 111 | def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState": 112 | return self.replace( 113 | params=flax.core.freeze(state_dict["target"]), 114 | step=state_dict["state"]["step"], 115 | flax_mutables=flax.core.freeze(state_dict["flax_mutables"]) 116 | if "flax_mutables" in state_dict 117 | else EMPTY_DICT, 118 | ) 119 | 120 | def as_logical_axes(self) -> "InferenceState": 121 | # Set step to None so that when the logical axes are processed by the 122 | # flax.partitioning.logical_to_mesh_axes function, it will be skipped 123 | # because jax.tree_map will short circuit and never call the function on the 124 | # step. 125 | flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT 126 | return InferenceState( 127 | step=None, 128 | params=flax_partitioning.get_axis_names(self.params_axes), 129 | flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes), 130 | ) 131 | --------------------------------------------------------------------------------
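Putting the pieces above together, a minimal end-to-end usage sketch (illustrative only: the import path assumes `FlaxWhisperPipline` is re-exported from `whisper_jax/__init__.py`, and the audio filename, dtype, partition count and batch size are placeholder choices):

    import jax.numpy as jnp
    from whisper_jax import FlaxWhisperPipline

    # Instantiate once: weights are loaded, replicated across local devices, and generation is pmapped (data parallel).
    pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)

    # Optionally switch from pmap to pjit-based parameter sharding.
    pipeline.shard_params(num_mp_partitions=1)

    # Transcribe a local audio file (decoded via ffmpeg); returns {"text": ...} plus "chunks" when timestamps are requested.
    outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
    print(outputs["text"])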