├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── app
│   ├── app.py
│   ├── monitor.sh
│   └── run_app.sh
├── benchmarks
│   ├── run_pipeline_dataloader.py
│   ├── run_pjit.py
│   ├── run_pjit_dataloader.py
│   ├── run_pmap.py
│   └── run_pytorch.py
├── pyproject.toml
├── setup.py
├── whisper-jax-tpu.ipynb
└── whisper_jax
    ├── __init__.py
    ├── layers.py
    ├── modeling_flax_whisper.py
    ├── partitioner.py
    ├── pipeline.py
    └── train_state.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/.gitignore
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # logs
12 | logs/
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # Environments
95 | .env
96 | .venv
97 | env/
98 | venv/
99 | ENV/
100 | env.bak/
101 | venv.bak/
102 |
103 | # Spyder project settings
104 | .spyderproject
105 | .spyproject
106 |
107 | # Rope project settings
108 | .ropeproject
109 |
110 | # mkdocs documentation
111 | /site
112 |
113 | # mypy
114 | .mypy_cache/
115 | .dmypy.json
116 | dmypy.json
117 |
118 | # Pyre type checker
119 | .pyre/
120 |
121 | # vscode
122 | .vs
123 | .vscode
124 |
125 | # Pycharm
126 | .idea
127 |
128 | # TF code
129 | tensorflow_code
130 |
131 | # Models
132 | proc_data
133 |
134 | # examples
135 | runs
136 | /runs_old
137 | /wandb
138 | /examples/runs
139 | /examples/**/*.args
140 | /examples/rag/sweep
141 |
142 | # data
143 | /data
144 | serialization_dir
145 |
146 | # emacs
147 | *.*~
148 | debug.env
149 |
150 | # vim
151 | .*.swp
152 |
153 | #ctags
154 | tags
155 |
156 | # pre-commit
157 | .pre-commit*
158 |
159 | # .lock
160 | *.lock
161 |
162 | # DS_Store (MacOS)
163 | .DS_Store
164 | # RL pipelines may produce mp4 outputs
165 | *.mp4
166 |
167 | # dependencies
168 | /transformers
169 |
170 | # ruff
171 | .ruff_cache
172 |
173 | wandb
174 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2023- The HuggingFace Inc. team and The OpenAI Authors. All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | check_dirs := .
2 |
3 | quality:
4 | black --check $(check_dirs)
5 | ruff $(check_dirs)
6 |
7 | style:
8 | black $(check_dirs)
9 | ruff $(check_dirs) --fix
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Whisper JAX
2 |
3 | This repository contains optimised JAX code for OpenAI's [Whisper Model](https://arxiv.org/abs/2212.04356), largely built
4 | on the 🤗 Hugging Face Transformers Whisper implementation. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x**
5 | faster, making it the fastest Whisper implementation available.
6 |
7 | The JAX code is compatible with CPU, GPU and TPU, and can be run standalone (see [Pipeline Usage](#pipeline-usage)) or
8 | as an inference endpoint (see [Creating an Endpoint](#creating-an-endpoint)).
9 |
10 | For a quick-start guide to running Whisper JAX on a Cloud TPU, refer to the following Kaggle notebook, where we transcribe 30 mins of audio in approx 30 sec:
11 |
12 | [Open in Kaggle](https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu)
13 |
14 | The Whisper JAX model is also running as a demo on the Hugging Face Hub:
15 |
16 | [Open in Hugging Face Spaces](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax)
17 |
18 | ## Installation
19 |
20 | Whisper JAX was tested using Python 3.9 and JAX version 0.4.5. Installation assumes that you already have the latest
21 | version of the JAX package installed on your device. You can do so using the official JAX installation guide: https://github.com/google/jax#installation
22 |
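For reference, a CPU-only JAX install can be done with the `jax[cpu]` extra; the GPU and TPU wheels are hardware-specific, so defer to the official guide for those:
```
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
```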
23 | Once the appropriate version of JAX has been installed, Whisper JAX can be installed through pip:
24 | ```
25 | pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
26 | ```
27 |
28 | To update the Whisper JAX package to the latest version, simply run:
29 | ```
30 | pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
31 | ```
32 |
33 | ## Pipeline Usage
34 |
35 | The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) abstraction class. This class handles all
36 | the necessary pre- and post-processing, as well as wrapping the generate method for data parallelism across accelerator devices.
37 |
38 | Whisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is _Just In Time (JIT)_
39 | compiled the first time it is called. Thereafter, the compiled function is _cached_, so subsequent calls run extremely fast:
40 |
41 | ```python
42 | from whisper_jax import FlaxWhisperPipline
43 |
44 | # instantiate pipeline
45 | pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
46 |
47 | # JIT compile the forward call - slow, but we only do once
48 | text = pipeline("audio.mp3")
49 |
50 | # use the cached function thereafter - super fast!!
51 | text = pipeline("audio.mp3")
52 | ```
53 |
54 | ### Half-Precision
55 |
56 | The model computation can be run in half-precision by passing the dtype argument when instantiating the pipeline. This will
57 | speed up the computation quite considerably by storing intermediate tensors in half-precision. There is no change to the precision
58 | of the model weights.
59 |
60 | For most GPUs, the dtype should be set to `jnp.float16`. For A100 GPUs or TPUs, the dtype should be set to `jnp.bfloat16`:
61 | ```python
62 | from whisper_jax import FlaxWhisperPipline
63 | import jax.numpy as jnp
64 |
65 | # instantiate pipeline in bfloat16
66 | pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
67 | ```
68 |
69 | ### Batching
70 | Whisper JAX also provides the option of _batching_ a single audio input across accelerator devices. The audio is first
71 | chunked into 30-second segments, and the chunks are then dispatched to the model to be transcribed in parallel. The resulting
72 | transcriptions are stitched back together at the chunk boundaries to give a single, uniform transcription. In practice, batching
73 | provides a 10x speed-up compared to transcribing the audio samples sequentially, with a less than 1% penalty to the WER[^1], provided a large enough batch size is selected.
74 |
75 | To enable batching, pass the `batch_size` parameter when you instantiate the pipeline:
76 |
77 | ```python
78 | from whisper_jax import FlaxWhisperPipline
79 |
80 | # instantiate pipeline with batching
81 | pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16)
82 | ```
83 |
84 | ### Task
85 |
86 | By default, the pipeline transcribes the audio file in the language it was spoken in. For speech translation, set the
87 | `task` argument to `"translate"`:
88 |
89 | ```python
90 | # translate
91 | text = pipeline("audio.mp3", task="translate")
92 | ```
93 |
94 | ### Timestamps
95 |
96 | The [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) also supports timestamp prediction. Note that enabling timestamps will require a second JIT compilation of the
97 | forward call, this time including the timestamp outputs:
98 |
99 | ```python
100 | # transcribe and return timestamps
101 | outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
102 | text = outputs["text"] # transcription
103 | chunks = outputs["chunks"] # transcription + timestamps
104 | ```
105 |
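Each element of `chunks` pairs a segment of the transcription with its start and end time in seconds. As a purely illustrative sketch of the structure (the values below are made up):

```python
# illustrative structure of the `chunks` output - the values are made up
chunks = [
    {"text": " Hello world.", "timestamp": (0.0, 3.2)},
    {"text": " This is the second segment.", "timestamp": (3.2, 7.5)},
]

# for example, print each segment alongside its timestamps
for chunk in chunks:
    start, end = chunk["timestamp"]
    print(f"[{start:.2f}s -> {end:.2f}s] {chunk['text']}")
```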
106 | ### Putting it all together
107 | In the following code snippet, we instantiate the model in bfloat16 precision with batching enabled, and transcribe the audio file
108 | returning timestamp tokens:
109 |
110 | ```python
111 | from whisper_jax import FlaxWhisperPipline
112 | import jax.numpy as jnp
113 |
114 | # instantiate pipeline with bfloat16 and enable batching
115 | pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
116 |
117 | # transcribe and return timestamps
118 | outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
119 | ```
120 |
121 | ## Model Usage
122 |
123 | The Whisper JAX model can be used at a more granular level in much the same way as the original Hugging Face
124 | Transformers implementation. This requires the Whisper processor to be loaded separately from the model to handle the
125 | pre- and post-processing, and the generate function to be wrapped using `pmap` by hand:
126 |
127 | ```python
128 | import jax.numpy as jnp
129 | from datasets import load_dataset
130 | from flax.jax_utils import replicate
131 | from flax.training.common_utils import shard
132 | from jax import device_get, pmap
133 | from transformers import WhisperProcessor
134 |
135 | from whisper_jax import FlaxWhisperForConditionalGeneration
136 |
137 | # load the processor and model
138 | processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
139 | model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
140 | "openai/whisper-large-v2", dtype=jnp.bfloat16, _do_init=False,
141 | )
142 |
143 | def generate_fn(input_features):
144 | pred_ids = model.generate(
145 | input_features, task="transcribe", return_timestamps=False, max_length=model.config.max_length, params=params,
146 | )
147 | return pred_ids.sequences
148 |
149 | # pmap the generate function for data parallelism
150 | p_generate = pmap(generate_fn, "input_features")
151 | # replicate the parameters across devices
152 | params = replicate(params)
153 |
154 | # load a dummy sample from the LibriSpeech dataset
155 | ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
156 | sample = ds[0]["audio"]
157 |
158 | # pre-process: convert the audio array to log-mel input features
159 | input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="np").input_features
160 | # shard the input features across devices for data parallelism
161 | input_features = shard(input_features)
162 |
163 | # run the forward pass (JIT compiled the first time it is called)
164 | pred_ids = p_generate(input_features)
165 | output_ids = device_get(pred_ids.reshape(-1, model.config.max_length))
166 |
167 | # post-process: convert token ids to text string
168 | transcription = processor.batch_decode(output_ids, skip_special_tokens=True)
169 | ```
170 |
171 | ## Available Models and Languages
172 | All Whisper models on the Hugging Face Hub with Flax weights are compatible with Whisper JAX. This includes, but is not limited to,
173 | the official OpenAI Whisper checkpoints:
174 |
175 | | Size | Parameters | English-only | Multilingual |
176 | |----------|------------|------------------------------------------------------|-----------------------------------------------------|
177 | | tiny | 39 M | [✓](https://huggingface.co/openai/whisper-tiny.en) | [✓](https://huggingface.co/openai/whisper-tiny) |
178 | | base | 74 M | [✓](https://huggingface.co/openai/whisper-base.en) | [✓](https://huggingface.co/openai/whisper-base) |
179 | | small | 244 M | [✓](https://huggingface.co/openai/whisper-small.en) | [✓](https://huggingface.co/openai/whisper-small) |
180 | | medium | 769 M | [✓](https://huggingface.co/openai/whisper-medium.en) | [✓](https://huggingface.co/openai/whisper-medium) |
181 | | large | 1550 M | x | [✓](https://huggingface.co/openai/whisper-large) |
182 | | large-v2 | 1550 M | x | [✓](https://huggingface.co/openai/whisper-large-v2) |
183 |
184 | Should you wish to use a fine-tuned Whisper checkpoint in Whisper JAX, you should first convert the PyTorch weights to Flax.
185 | This is straightforward through use of the `from_pt` argument, which will convert the PyTorch state dict to a frozen Flax
186 | parameter dictionary on the fly. You can then push the converted Flax weights to the Hub to be used directly in Flax
187 | the next time they are required. Note that converting weights from PyTorch to Flax requires both PyTorch and Flax to be installed.
188 |
189 | For example, to convert the fine-tuned checkpoint [`sanchit-gandhi/whisper-small-hi`](https://huggingface.co/sanchit-gandhi/whisper-small-hi) from the blog post [Fine-Tuning Whisper](https://huggingface.co/blog/fine-tune-whisper):
190 | ```python
191 | from whisper_jax import FlaxWhisperForConditionalGeneration, FlaxWhisperPipline
192 | import jax.numpy as jnp
193 |
194 | checkpoint_id = "sanchit-gandhi/whisper-small-hi"
195 | # convert PyTorch weights to Flax
196 | model = FlaxWhisperForConditionalGeneration.from_pretrained(checkpoint_id, from_pt=True)
197 | # push converted weights to the Hub
198 | model.push_to_hub(checkpoint_id)
199 |
200 | # now we can load the Flax weights directly as required
201 | pipeline = FlaxWhisperPipline(checkpoint_id, dtype=jnp.bfloat16, batch_size=16)
202 | ```
203 |
204 | ## Advanced Usage
205 | More advanced users may wish to explore different parallelisation techniques. The Whisper JAX code is
206 | built on top of the [T5x codebase](https://github.com/google-research/t5x), meaning it can be run with model, activation, and data parallelism using the T5x
207 | partitioning convention. To use T5x partitioning, the logical axis rules and number of model partitions must be defined.
208 | For more details, the user is referred to the official T5x partitioning guide: https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md
209 |
210 | ### Pipeline
211 | The following code snippet demonstrates how data parallelism can be achieved using the pipeline `shard_params` method in
212 | an entirely equivalent way to `pmap`:
213 |
214 | ```python
215 | from whisper_jax import FlaxWhisperPipline
216 | import jax.numpy as jnp
217 |
218 | # 2D parameter and activation partitioning for DP
219 | logical_axis_rules_dp = (
220 | ("batch", "data"),
221 | ("mlp", None),
222 | ("heads", None),
223 | ("vocab", None),
224 | ("embed", None),
225 | ("embed", None),
226 | ("joined_kv", None),
227 | ("kv", None),
228 | ("length", None),
229 | ("num_mel", None),
230 | ("channels", None),
231 | )
232 |
233 | pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
234 | pipeline.shard_params(num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp)
235 | ```
236 |
237 | ### Model
238 | It is also possible to use the Whisper JAX model with T5x partitioning by defining a T5x inference state and T5x partitioner:
239 |
240 | ```python
241 | import jax
242 | import jax.numpy as jnp
243 | from flax.core.frozen_dict import freeze
244 | from jax.sharding import PartitionSpec as P
245 |
246 | from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner
247 |
248 |
249 | # 2D parameter and activation partitioning for DP
250 | logical_axis_rules_dp = [
251 | ("batch", "data"),
252 | ("mlp", None),
253 | ("heads", None),
254 | ("vocab", None),
255 | ("embed", None),
256 | ("embed", None),
257 | ("joined_kv", None),
258 | ("kv", None),
259 | ("length", None),
260 | ("num_mel", None),
261 | ("channels", None),
262 | ]
263 |
264 | model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
265 | "openai/whisper-large-v2",
266 | _do_init=False,
267 | dtype=jnp.bfloat16,
268 | )
269 |
270 |
271 | def init_fn():
272 | input_shape = (1, 80, 3000)
273 |
274 | input_features = jnp.zeros(input_shape, dtype="f4")
275 | input_features = input_features.at[(..., -1)].set(model.config.eos_token_id)
276 |
277 | decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
278 | decoder_attention_mask = jnp.ones_like(decoder_input_ids)
279 |
280 | batch_size, sequence_length = decoder_input_ids.shape
281 | decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
282 |
283 | rng = jax.random.PRNGKey(0)
284 | init_params = model.module.init(
285 | rng,
286 | input_features=input_features,
287 | decoder_input_ids=decoder_input_ids,
288 | decoder_attention_mask=decoder_attention_mask,
289 | decoder_position_ids=decoder_position_ids,
290 | return_dict=False,
291 | )
292 | return init_params
293 |
294 |
295 | # Axis names metadata
296 | param_axes = jax.eval_shape(init_fn)["params_axes"]
297 |
298 | # Create InferenceState, since the partitioner expects it
299 | state = InferenceState(
300 | step=jnp.array(0),
301 | params=freeze(model.params_shape_tree),
302 | params_axes=freeze(param_axes),
303 | flax_mutables=None,
304 | flax_mutables_axes=param_axes,
305 | )
306 |
307 | # Define the pjit partitioner with 1 model partition
308 | partitioner = PjitPartitioner(
309 | num_partitions=1,
310 | logical_axis_rules=logical_axis_rules_dp,
311 | )
312 |
313 | mesh_axes = partitioner.get_mesh_axes(state)
314 | params_spec = mesh_axes.params
315 |
316 | p_shard_params = partitioner.partition(model.to_bf16, (params_spec,), params_spec)
317 |
318 |
319 | def generate(params, input_features):
320 | output_ids = model.generate(input_features, params=params, max_length=model.config.max_length).sequences
321 | return output_ids
322 |
323 |
324 | p_generate = partitioner.partition(
325 | generate,
326 | in_axis_resources=(params_spec, P("data")),
327 | out_axis_resources=P("data"),
328 | )
329 |
330 | # This will auto-magically run in mesh context
331 | params = p_shard_params(freeze(params))
332 |
333 | # you can now run the forward pass with:
334 | # pred_ids = p_generate(params, input_features)
335 | ```
336 |
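As a rough usage sketch (reusing the `WhisperProcessor` pre-processing from the [Model Usage](#model-usage) section above, and assuming the `p_generate` function and sharded `params` from the previous snippet), the partitioned forward pass can then be run on a batch of log-mel features:

```python
from datasets import load_dataset
from jax import device_get
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")

# load a dummy sample and convert it to log-mel input features
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="np").input_features

# run the partitioned forward pass (JIT compiled on the first call) and decode to text
pred_ids = p_generate(params, input_features)
transcription = processor.batch_decode(device_get(pred_ids), skip_special_tokens=True)
```

With more than one device along the `data` axis, you would batch several 30-second chunks together so that the batch dimension divides evenly across devices.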
337 | ## Benchmarks
338 |
339 | We compare Whisper JAX to the official [OpenAI implementation](https://github.com/openai/whisper) and the [🤗 Transformers
340 | implementation](https://huggingface.co/docs/transformers/model_doc/whisper). We benchmark the models on audio samples of
341 | increasing length and report the average inference time in seconds over 10 repeat runs. For all three systems, we pass a
342 | pre-loaded audio file to the model and measure the time for the forward pass. Leaving the task of loading the audio file
343 | to the systems adds an equal offset to all the benchmark times, so the actual time for loading **and** transcribing an
344 | audio file will be higher than the reported numbers.
345 |
346 | OpenAI and Transformers both run in PyTorch on GPU. Whisper JAX runs in JAX on GPU and TPU. OpenAI transcribes the audio
347 | sequentially in the order it is spoken. Both Transformers and Whisper JAX use a batching algorithm, where chunks of audio
348 | are batched together and transcribed in parallel (see section [Batching](#batching)).
349 |
350 | **Table 1:** Average inference time in seconds for audio files of increasing length. GPU device is a single A100 40GB GPU.
351 | TPU device is a single TPU v4-8.
352 |
353 |
367 |
368 | ## Creating an Endpoint
369 |
370 | The Whisper JAX model is running as a demo on the Hugging Face Hub:
371 |
372 | [Open in Hugging Face Spaces](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax)
373 |
374 | However, at peak times there may be a queue of users that limits how quickly your audio input is transcribed. In this case,
375 | you may benefit from running the model yourself, such that you have unrestricted access to the Whisper JAX model.
376 |
377 | If you are just interested in running the model in a standalone Python script, refer to the Kaggle notebook Whisper JAX TPU:
378 |
379 | [Open in Kaggle](https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu)
380 |
381 | Otherwise, we provide all the necessary code for creating an inference endpoint. To obtain this code, first clone the
382 | repository on the GPU/TPU on which you want to host the endpoint:
383 | ```
384 | git clone https://github.com/sanchit-gandhi/whisper-jax
385 | ```
386 |
387 | And then install Whisper JAX from source, with the required additional endpoint dependencies:
388 | ```
389 | cd whisper-jax
390 | pip install -e .["endpoint"]
391 | ```
392 |
393 | We recommend that you set up an endpoint in the same zone/region as the one you are based in. This reduces the communication
394 | time between your local machine and the remote one, which can significantly reduce the overall request time.
395 |
396 | ### Gradio App
397 |
398 | The Python script [`app.py`](app/app.py) contains the code to launch a Gradio app with the Whisper large-v2 model.
399 | By default, it uses a batch size of 16 and bfloat16 half-precision. You should update these parameters depending on your
400 | GPU/TPU device (as explained in the sections on [Half-precision](#half-precision) and [Batching](#batching)).
401 |
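For example, a hypothetical edit for a single smaller GPU might halve the batch size and switch to float16 where the pipeline is instantiated in `app.py` (illustrative values only):

```python
import jax.numpy as jnp

from whisper_jax import FlaxWhisperPipline

# illustrative values - tune these for your own GPU/TPU
checkpoint = "openai/whisper-large-v2"
BATCH_SIZE = 8

pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.float16, batch_size=BATCH_SIZE)
```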
402 | We can launch the Gradio app on port 7860 (default) on our GPU/TPU device through the following command:
403 | ```
404 | python app/app.py
405 | ```
406 |
407 | This will launch a Gradio demo with the same interface as the official Whisper JAX demo. To view the Gradio app remotely,
408 | we have two options:
409 |
410 | 1. Open the port 7860 on the GPU/TPU device to listen to all requests
411 | 2. Start an ngrok server on the GPU/TPU that redirects requests to port 7860
412 |
413 | To open the port 7860 on your GPU/TPU, refer to your hardware provider's firewall instructions (for GCP, these can be
414 | found [here](https://cloud.google.com/firewall/docs/using-firewalls)). Once you have opened port 7860, you should be able
415 | to access the gradio demo through the http address:
416 | ```
417 | http://DEVICE-IP:7860
418 | ```
419 | where `DEVICE-IP` is the public IP address of your GPU/TPU. We can verify this address is accessible by opening it
420 | in a browser window on our local machine.
421 |
422 | Alternatively, we can direct network requests to the Gradio app using ngrok. By using ngrok, we don't need to open the
423 | port 7860 on our GPU/TPU - ngrok will provide us with a public http address that will automatically redirect requests to
424 | port 7860 on our accelerator. However, in our experience, using ngrok was less reliable than a direct tunnel to port 7860,
425 | so we recommend option 1 where possible.
426 |
427 | To set up ngrok on your GPU/TPU, first install ngrok according to the official [installation guide](https://ngrok.com/download).
428 | You should authenticate your ngrok account if you have one, otherwise your ngrok server will be time-limited to 2 hours.
429 | Once installed and authenticated, you can launch an ngrok server on port 7860:
430 | ```
431 | ngrok http 7860
432 | ```
433 | The ngrok http address will be of the form:
434 | ```
435 | https://NGROK-ADDRESS.ngrok.io
436 | ```
437 | which can be used to access the Gradio demo through a web browser.
438 |
439 | ### Sending Requests
440 |
441 | Regardless of whether you've chosen to open port 7860 or use ngrok, we're now ready to send audio file requests to our
442 | endpoint. To do this, we'll make use of the `gradio_client` library. If you already have a recent version of Gradio,
443 | then the `gradio_client` library is included as a dependency.
444 |
445 | Otherwise, the lightweight `gradio_client` package can be installed from pip and is tested to work with Python versions
446 | 3.9 or higher:
447 | ```
448 | pip install --upgrade gradio_client
449 | ```
450 |
451 | We can now send JSON requests to our endpoint. The function `transcribe_audio` sends an audio file to our endpoint
452 | and returns the transcription:
453 |
454 | ```python
455 | from gradio_client import Client
456 |
457 | # make sure this URL matches your http web address
458 | API_URL = "http://DEVICE-IP:7860/" # if using port 7860
459 | API_URL = "https://NGROK-ADDRESS.ngrok.io/" # if using ngrok
460 |
461 | # set up the Gradio client
462 | client = Client(API_URL)
463 |
464 | def transcribe_audio(audio_path, task="transcribe", return_timestamps=False):
465 | """Function to transcribe an audio file using our endpoint"""
466 | text, runtime = client.predict(
467 | audio_path,
468 | task,
469 | return_timestamps,
470 | api_name="/predict_1",
471 | )
472 | return text
473 |
474 | # transcribe an audio file using our endpoint
475 | output = transcribe_audio("audio.mp3")
476 |
477 | # transcribe with timestamps
478 | output_with_timestamps = transcribe_audio("audio.mp3", return_timestamps=True)
479 | ```
480 |
481 | ## Acknowledgements
482 |
483 | * 🤗 Hugging Face Transformers for the base Whisper implementation, particularly to [andyehrenberg](https://github.com/andyehrenberg) for the [Flax Whisper PR](https://github.com/huggingface/transformers/pull/20479) and [ArthurZucker](https://github.com/ArthurZucker) for the batching algorithm
484 | * Gradio for their easy-to-use package for building ML demos, and [pcuenca](https://github.com/pcuenca) for the help in hooking the demo up to the TPU
485 | * Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) programme for Cloud TPUs
486 | * Google's [t5x Repository](https://github.com/google-research/t5x) for the model partitioning framework
487 |
488 | [^1]: See WER results from Colab: https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor?usp=sharing
489 |
--------------------------------------------------------------------------------
/app/app.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import os
4 | import tempfile
5 | import time
6 | from multiprocessing import Pool
7 |
8 | import gradio as gr
9 | import jax.numpy as jnp
10 | import numpy as np
11 | import yt_dlp as youtube_dl
12 | from jax.experimental.compilation_cache import compilation_cache as cc
13 | from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
14 | from transformers.pipelines.audio_utils import ffmpeg_read
15 |
16 | from whisper_jax import FlaxWhisperPipline
17 |
18 |
19 | cc.initialize_cache("./jax_cache")
20 | checkpoint = "openai/whisper-large-v3"
21 |
22 | BATCH_SIZE = 32
23 | CHUNK_LENGTH_S = 30
24 | NUM_PROC = 32
25 | FILE_LIMIT_MB = 1000
26 | YT_LENGTH_LIMIT_S = 7200 # limit to 2 hour YouTube files
27 |
28 | title = "Whisper JAX: The Fastest Whisper API ⚡️"
29 |
30 | description = """Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v3) by OpenAI. It runs on JAX with a TPU v4-8 in the backend. Compared to PyTorch on an A100 GPU, it is over [**70x faster**](https://github.com/sanchit-gandhi/whisper-jax#benchmarks), making it the fastest Whisper API available.
31 |
32 | Note that at peak times, you may find yourself in the queue for this demo. When you submit a request, your queue position will be shown in the top right-hand side of the demo pane. Once you reach the front of the queue, your audio file will be transcribed, with the progress displayed through a progress bar.
33 |
34 | To skip the queue, you may wish to create your own inference endpoint, details for which can be found in the [Whisper JAX repository](https://github.com/sanchit-gandhi/whisper-jax#creating-an-endpoint).
35 | """
36 |
37 | article = "Whisper large-v3 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."
38 |
39 | language_names = sorted(TO_LANGUAGE_CODE.keys())
40 |
41 | logger = logging.getLogger("whisper-jax-app")
42 | logger.setLevel(logging.INFO)
43 | ch = logging.StreamHandler()
44 | ch.setLevel(logging.INFO)
45 | formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
46 | ch.setFormatter(formatter)
47 | logger.addHandler(ch)
48 |
49 |
50 | def identity(batch):
51 | return batch
52 |
53 |
54 | # Copied from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
55 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
56 | if seconds is not None:
57 | milliseconds = round(seconds * 1000.0)
58 |
59 | hours = milliseconds // 3_600_000
60 | milliseconds -= hours * 3_600_000
61 |
62 | minutes = milliseconds // 60_000
63 | milliseconds -= minutes * 60_000
64 |
65 | seconds = milliseconds // 1_000
66 | milliseconds -= seconds * 1_000
67 |
68 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
69 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
70 | else:
71 | # we have a malformed timestamp so just return it as is
72 | return seconds
73 |
74 |
75 | if __name__ == "__main__":
76 | pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
77 | stride_length_s = CHUNK_LENGTH_S / 6
78 | chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
79 | stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
80 | step = chunk_len - stride_left - stride_right
81 | pool = Pool(NUM_PROC)
82 |
83 | # do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time
84 | logger.info("compiling forward call...")
85 | start = time.time()
86 | random_inputs = {
87 | "input_features": np.ones(
88 | (BATCH_SIZE, pipeline.model.config.num_mel_bins, 2 * pipeline.model.config.max_source_positions)
89 | )
90 | }
91 | random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
92 | compile_time = time.time() - start
93 | logger.info(f"compiled in {compile_time}s")
94 |
95 | def tqdm_generate(inputs: dict, task: str, return_timestamps: bool, progress: gr.Progress):
96 | inputs_len = inputs["array"].shape[0]
97 | all_chunk_start_idx = np.arange(0, inputs_len, step)
98 | num_samples = len(all_chunk_start_idx)
99 | num_batches = math.ceil(num_samples / BATCH_SIZE)
100 | dummy_batches = list(
101 | range(num_batches)
102 | ) # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841
103 |
104 | dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
105 | progress(0, desc="Pre-processing audio file...")
106 | logger.info("pre-processing audio file...")
107 | dataloader = pool.map(identity, dataloader)
108 |         logger.info("done pre-processing")
109 |
110 | model_outputs = []
111 | start_time = time.time()
112 | logger.info("transcribing...")
113 | # iterate over our chunked audio samples - always predict timestamps to reduce hallucinations
114 | for batch, _ in zip(dataloader, progress.tqdm(dummy_batches, desc="Transcribing...")):
115 | model_outputs.append(pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True))
116 | runtime = time.time() - start_time
117 | logger.info("done transcription")
118 |
119 | logger.info("post-processing...")
120 | post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
121 | text = post_processed["text"]
122 | if return_timestamps:
123 | timestamps = post_processed.get("chunks")
124 | timestamps = [
125 | f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
126 | for chunk in timestamps
127 | ]
128 | text = "\n".join(str(feature) for feature in timestamps)
129 | logger.info("done post-processing")
130 | return text, runtime
131 |
132 | def transcribe_chunked_audio(inputs, task, return_timestamps, progress=gr.Progress()):
133 | progress(0, desc="Loading audio file...")
134 | logger.info("loading audio file...")
135 | if inputs is None:
136 | logger.warning("No audio file")
137 | raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
138 | file_size_mb = os.stat(inputs).st_size / (1024 * 1024)
139 | if file_size_mb > FILE_LIMIT_MB:
140 | logger.warning("Max file size exceeded")
141 | raise gr.Error(
142 | f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB."
143 | )
144 |
145 | with open(inputs, "rb") as f:
146 | inputs = f.read()
147 |
148 | inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
149 | inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
150 | logger.info("done loading")
151 | text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
152 | return text, runtime
153 |
154 | def _return_yt_html_embed(yt_url):
155 | video_id = yt_url.split("?v=")[-1]
156 | HTML_str = (
157 |             f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
158 |             " </center>"
159 | )
160 | return HTML_str
161 |
162 | def download_yt_audio(yt_url, filename):
163 | info_loader = youtube_dl.YoutubeDL()
164 | try:
165 | info = info_loader.extract_info(yt_url, download=False)
166 | except youtube_dl.utils.DownloadError as err:
167 | raise gr.Error(str(err))
168 |
169 | file_length = info["duration_string"]
170 | file_h_m_s = file_length.split(":")
171 | file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
172 | if len(file_h_m_s) == 1:
173 | file_h_m_s.insert(0, 0)
174 | if len(file_h_m_s) == 2:
175 | file_h_m_s.insert(0, 0)
176 |
177 | file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
178 | if file_length_s > YT_LENGTH_LIMIT_S:
179 | yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
180 | file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
181 | raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")
182 |
183 | ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
184 | with youtube_dl.YoutubeDL(ydl_opts) as ydl:
185 | try:
186 | ydl.download([yt_url])
187 | except youtube_dl.utils.ExtractorError as err:
188 | raise gr.Error(str(err))
189 |
190 | def transcribe_youtube(yt_url, task, return_timestamps, progress=gr.Progress()):
191 | progress(0, desc="Loading audio file...")
192 | logger.info("loading youtube file...")
193 | html_embed_str = _return_yt_html_embed(yt_url)
194 | with tempfile.TemporaryDirectory() as tmpdirname:
195 | filepath = os.path.join(tmpdirname, "video.mp4")
196 | download_yt_audio(yt_url, filepath)
197 |
198 | with open(filepath, "rb") as f:
199 | inputs = f.read()
200 |
201 | inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
202 | inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
203 | logger.info("done loading...")
204 | text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
205 | return html_embed_str, text, runtime
206 |
207 | microphone_chunked = gr.Interface(
208 | fn=transcribe_chunked_audio,
209 | inputs=[
210 | gr.Audio(source="microphone", type="filepath"),
211 | gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
212 | gr.Checkbox(value=False, label="Return timestamps"),
213 | ],
214 | outputs=[
215 | gr.Textbox(label="Transcription", show_copy_button=True),
216 | gr.Textbox(label="Transcription Time (s)"),
217 | ],
218 | allow_flagging="never",
219 | title=title,
220 | description=description,
221 | article=article,
222 | )
223 |
224 | audio_chunked = gr.Interface(
225 | fn=transcribe_chunked_audio,
226 | inputs=[
227 | gr.Audio(source="upload", label="Audio file", type="filepath"),
228 | gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
229 | gr.Checkbox(value=False, label="Return timestamps"),
230 | ],
231 | outputs=[
232 | gr.Textbox(label="Transcription", show_copy_button=True),
233 | gr.Textbox(label="Transcription Time (s)"),
234 | ],
235 | allow_flagging="never",
236 | title=title,
237 | description=description,
238 | article=article,
239 | )
240 |
241 | youtube = gr.Interface(
242 | fn=transcribe_youtube,
243 | inputs=[
244 | gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
245 | gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
246 | gr.Checkbox(value=False, label="Return timestamps"),
247 | ],
248 | outputs=[
249 | gr.HTML(label="Video"),
250 | gr.Textbox(label="Transcription", show_copy_button=True),
251 | gr.Textbox(label="Transcription Time (s)"),
252 | ],
253 | allow_flagging="never",
254 | title=title,
255 | examples=[["https://www.youtube.com/watch?v=m8u-18Q0s7I", "transcribe", False]],
256 | cache_examples=False,
257 | description=description,
258 | article=article,
259 | )
260 |
261 | demo = gr.Blocks()
262 |
263 | with demo:
264 | gr.TabbedInterface([microphone_chunked, audio_chunked, youtube], ["Microphone", "Audio File", "YouTube"])
265 |
266 | demo.queue(max_size=5)
267 | demo.launch(server_name="0.0.0.0", show_api=False)
268 |
--------------------------------------------------------------------------------
/app/monitor.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | waiting=0
4 | check_server() {
5 | url=http://localhost:7860
6 | response_code=$(curl -o /dev/null -s -w "%{http_code}" --connect-timeout 2 $url)
7 | [[ $response_code -ne 200 ]] && {
8 | return 0
9 | }
10 | return 1
11 | }
12 |
13 | while [ 1 ]
14 | do
15 | # periodically clear the /tmp directory for files created > 30 mins ago so it doesn't fill up
16 | sudo find /tmp -type f -amin +30 -delete
17 | check_server
18 | if [[ $? -ne 1 ]]
19 | then
20 | if [[ $waiting -eq 0 ]]
21 | then
22 | waiting=1
23 | echo "Restarting"
24 | pkill -9 python
25 | #sudo lsof -t /dev/accel0 | xargs kill -9
26 | sleep 5
27 | mv log.txt log_`date +%Y%m%d%H%M%S`
28 | TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python ./app.py &> log.txt &
29 | else
30 | echo "Waiting for restart"
31 | fi
32 | else
33 | if [[ $waiting -eq 1 ]]
34 | then
35 | waiting=0
36 | echo "Restarted"
37 | fi
38 | fi
39 | sleep 10
40 | done
41 |
--------------------------------------------------------------------------------
/app/run_app.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Sets up the environment, suitable to be run from a cron job, i.e.:
4 | # @reboot /home/sanchitgandhi/whisper-jax/run_app.sh
5 | cd ~/whisper-jax/app
6 | source ~/hf/bin/activate
7 | chmod +x monitor.sh
8 | ./monitor.sh &
9 |
--------------------------------------------------------------------------------
/benchmarks/run_pipeline_dataloader.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import numpy as np
5 | from torch.utils.data import DataLoader
6 | from transformers import WhisperProcessor
7 | from transformers.pipelines.base import no_collate_fn
8 | from transformers.pipelines.pt_utils import PipelineChunkIterator, PipelinePackIterator
9 |
10 |
11 | class Pipeline:
12 | """Relies on the data loaders defined in transformers Pipeline"""
13 |
14 | def __init__(self, checkpoint="openai/whisper-tiny.en"):
15 | self.checkpoint = checkpoint
16 | self.processor = WhisperProcessor.from_pretrained(self.checkpoint)
17 | self.feature_extractor = self.processor.feature_extractor
18 | self.tokenizer = self.processor.tokenizer
19 |
20 | @staticmethod
21 | def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
22 | inputs_len = inputs.shape[0]
23 | step = chunk_len - stride_left - stride_right
24 | for chunk_start_idx in range(0, inputs_len, step):
25 | chunk_end_idx = chunk_start_idx + chunk_len
26 | chunk = inputs[chunk_start_idx:chunk_end_idx]
27 | processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="np")
28 | _stride_left = 0 if chunk_start_idx == 0 else stride_left
29 | # all right strides must be full, otherwise it is the last item
30 | is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len
31 | _stride_right = 0 if is_last else stride_right
32 |
33 | chunk_len = chunk.shape[0]
34 | stride = (chunk_len, _stride_left, _stride_right)
35 | if chunk.shape[0] > _stride_left:
36 | yield {"is_last": is_last, "stride": stride, **processed}
37 | if is_last:
38 | break
39 |
40 | def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
41 | array = inputs.get("array")
42 | in_sampling_rate = inputs.get("sampling_rate")
43 | stride = inputs.get("stride", None)
44 |
45 | if in_sampling_rate != self.feature_extractor.sampling_rate:
46 | ratio = self.feature_extractor.sampling_rate / in_sampling_rate
47 | else:
48 | ratio = 1
49 |
50 | if stride is not None:
51 |             if stride[0] + stride[1] > array.shape[0]:
52 | raise ValueError("Stride is too large for input")
53 |
54 | # Stride needs to get the chunk length here, it's going to get
55 | # swallowed by the `feature_extractor` later, and then batching
56 | # can add extra data in the inputs, so we need to keep track
57 | # of the original length in the stride so we can cut properly.
58 |             stride = (array.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
59 |
60 | if chunk_length_s:
61 | if stride_length_s is None:
62 | stride_length_s = chunk_length_s / 6
63 |
64 | if isinstance(stride_length_s, (int, float)):
65 | stride_length_s = [stride_length_s, stride_length_s]
66 |
67 | chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
68 | stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
69 | stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)
70 |
71 | if chunk_len < stride_left + stride_right:
72 | raise ValueError("Chunk length must be superior to stride length")
73 |
74 | for item in self.chunk_iter(
75 | array,
76 | self.feature_extractor,
77 | chunk_len,
78 | stride_left,
79 | stride_right,
80 | ):
81 | yield item
82 | else:
83 | processed = self.feature_extractor(
84 | array, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
85 | )
86 | if stride is not None:
87 | processed["stride"] = stride
88 | yield {"is_last": True, **processed}
89 |
90 | def forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
91 | print(model_inputs["stride"])
92 | if generate_kwargs is None:
93 | generate_kwargs = {}
94 |
95 | if return_timestamps:
96 | generate_kwargs["return_timestamps"] = return_timestamps
97 | is_last = model_inputs.pop("is_last")
98 |
99 | out = {}
100 | stride = model_inputs.pop("stride", None)
101 | if stride is not None:
102 | out["stride"] = stride
103 |
104 | return {"is_last": is_last, **out}
105 |
106 | def __call__(
107 | self,
108 | inputs,
109 | chunk_length_s=0,
110 | stride_length_s=None,
111 | return_timestamps=None,
112 | return_language=None,
113 | generate_kwargs=None,
114 | batch_size=4,
115 | num_workers=1,
116 | ):
117 | dataset = PipelineChunkIterator(
118 | [inputs], self.preprocess, {"chunk_length_s": chunk_length_s, "stride_length_s": stride_length_s}
119 | )
120 | collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn()
121 | dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)
122 | model_iterator = PipelinePackIterator(dataloader, self.forward, {}, loader_batch_size=batch_size)
123 |
124 | for batch in enumerate(model_iterator):
125 | continue
126 |
127 |
128 | def _pad(items, key):
129 | if isinstance(items[0][key], np.ndarray):
130 | if key == "input_features":
131 | # this is probably a mel spectrogram batched
132 | return np.concatenate([item[key] for item in items], axis=0)
133 | else:
134 | return [item[key] for item in items]
135 |
136 |
137 | def pad_collate_fn():
138 | def inner(items):
139 | keys = set(items[0].keys())
140 | for item in items:
141 | if set(item.keys()) != keys:
142 | raise ValueError(
143 | f"The elements of the batch contain different keys. Cannot batch them ({set(item.keys())} !="
144 | f" {keys})"
145 | )
146 | # input_values, input_pixels, input_ids, ...
147 | padded = {}
148 | for key in keys:
149 | padded[key] = _pad(items, key)
150 | return padded
151 |
152 | return inner
153 |
154 |
155 | class ManualIterator:
156 | """Manual implementation"""
157 |
158 | def __init__(self, checkpoint="openai/whisper-tiny.en"):
159 | self.checkpoint = checkpoint
160 | self.processor = WhisperProcessor.from_pretrained(self.checkpoint)
161 | self.feature_extractor = self.processor.feature_extractor
162 | self.tokenizer = self.processor.tokenizer
163 |
164 | @staticmethod
165 | def chunk_iter_with_batch(inputs, feature_extractor, chunk_len, stride_left, stride_right, batch_size):
166 | inputs_len = inputs.shape[0]
167 | step = chunk_len - stride_left - stride_right
168 |
169 | all_chunk_start_idx = np.arange(0, inputs_len, step)
170 | num_samples = len(all_chunk_start_idx)
171 |
172 | num_batches = math.ceil(num_samples / batch_size)
173 | batch_idx = np.array_split(np.arange(num_samples), num_batches)
174 |
175 | for i, idx in enumerate(batch_idx):
176 | chunk_start_idx = all_chunk_start_idx[idx]
177 |
178 | chunk_end_idx = chunk_start_idx + chunk_len
179 |
180 | chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
181 | processed = feature_extractor(chunks, sampling_rate=feature_extractor.sampling_rate, return_tensors="np")
182 |
183 | _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
184 | is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
185 | _stride_right = np.where(is_last, 0, stride_right)
186 |
187 | chunk_lens = [chunk.shape[0] for chunk in chunks]
188 | strides = [
189 | (chunk_l, _stride_l, _stride_r)
190 | for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
191 | ]
192 |
193 | yield {"stride": strides, **processed}
194 |
195 | def preprocess_batch(self, inputs, chunk_length_s=0, stride_length_s=None, batch_size=None):
196 | array = inputs.get("array")
197 | in_sampling_rate = inputs.get("sampling_rate")
198 | stride = inputs.get("stride", None)
199 |
200 | if in_sampling_rate != self.feature_extractor.sampling_rate:
201 | ratio = self.feature_extractor.sampling_rate / in_sampling_rate
202 | else:
203 | ratio = 1
204 |
205 | if stride is not None:
206 |             if stride[0] + stride[1] > array.shape[0]:
207 |                 raise ValueError("Stride is too large for input")
208 | 
209 |             # The stride needs to record the chunk length here: it gets swallowed
210 |             # by the `feature_extractor` later, and batching can add extra data to
211 |             # the inputs, so we keep the original length in the stride in order to
212 |             # cut the output properly.
213 |             stride = (array.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
214 |
215 | if chunk_length_s:
216 | if stride_length_s is None:
217 | stride_length_s = chunk_length_s / 6
218 |
219 | if isinstance(stride_length_s, (int, float)):
220 | stride_length_s = [stride_length_s, stride_length_s]
221 |
222 | chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
223 | stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
224 | stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)
225 |
226 | if chunk_len < stride_left + stride_right:
227 |                 raise ValueError("Chunk length must be greater than stride length")
228 |
229 | for item in self.chunk_iter_with_batch(
230 | array,
231 | self.feature_extractor,
232 | chunk_len,
233 | stride_left,
234 | stride_right,
235 | batch_size,
236 | ):
237 | yield item
238 | else:
239 | processed = self.feature_extractor(
240 | array, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
241 | )
242 | if stride is not None:
243 | processed["stride"] = stride
244 | yield processed
245 |
246 | def forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
247 | if generate_kwargs is None:
248 | generate_kwargs = {}
249 |
250 | if return_timestamps:
251 | generate_kwargs["return_timestamps"] = return_timestamps
252 |
253 | out = {}
254 | stride = model_inputs.pop("stride", None)
255 | if stride is not None:
256 | out["stride"] = stride
257 |         return out
258 | def __call__(
259 | self,
260 | inputs,
261 | chunk_length_s=0,
262 | stride_length_s=None,
263 | return_timestamps=None,
264 | return_language=None,
265 | generate_kwargs=None,
266 | batch_size=4,
267 | num_workers=1,
268 | ):
269 | dataloader = self.preprocess_batch(inputs, chunk_length_s, stride_length_s, batch_size)
270 |
271 | for batch in dataloader:
272 | print(batch["stride"])
273 | self.forward(batch)
274 |
275 |
276 | pipeline = Pipeline()
277 | custom_pipeline = ManualIterator()
278 |
279 | powers = np.logspace(0, 6, 7, base=2)
280 | input_lengths = 30 * powers
281 |
282 | for input_len in input_lengths:
283 | print(f"=== Input len {input_len} ===")
284 | inputs = {"array": np.ones(16000 * (int(input_len))), "sampling_rate": 16000}
285 |
286 | # first run our custom one since it doesn't consume the audio input
287 | start = time.time()
288 | custom_pipeline(inputs, chunk_length_s=30)
289 | runtime = time.time() - start
290 | print(f"Manual: {runtime}")
291 |
292 | # now run the transformers based one
293 | start = time.time()
294 | pipeline(inputs, chunk_length_s=30)
295 | runtime = time.time() - start
296 | print(f"Pipeline: {runtime}")
297 |
298 |
299 | """
300 | Results:
301 | === Input len 30.0 ===
302 | Manual: 0.09786629676818848
303 | Pipeline: 0.4056670665740967
304 |
305 | === Input len 60.0 ===
306 | Manual: 0.14911437034606934
307 | Pipeline: 0.6139914989471436
308 |
309 | === Input len 120.0 ===
310 | Manual: 0.29366159439086914
311 | Pipeline: 0.9421713352203369
312 |
313 | === Input len 240.0 ===
314 | Manual: 0.5912315845489502
315 | Pipeline: 1.2646639347076416
316 |
317 | === Input len 480.0 ===
318 | Manual: 1.1709322929382324
319 | Pipeline: 1.6115500926971436
320 |
321 | === Input len 960.0 ===
322 | Manual: 2.373847246170044
323 | Pipeline: 3.2664968967437744
324 |
325 | === Input len 1920.0 ===
326 | Manual: 4.619845151901245
327 | Pipeline: 5.50755500793457
328 | """
329 |
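For orientation, the chunking arithmetic that both iterators above rely on can be sketched in isolation: chunks of `chunk_len` samples are taken every `chunk_len - stride_left - stride_right` samples, and each chunk carries a `(chunk length, left overlap, right overlap)` stride so the overlapping regions can be trimmed after decoding. A minimal stand-alone sketch with made-up toy sizes (not part of the benchmark above):

    import numpy as np

    # toy sizes: 10 s of 16 kHz audio, 4 s chunks, 1 s of overlap on either side
    audio = np.zeros(16000 * 10)
    chunk_len, stride_left, stride_right = 16000 * 4, 16000, 16000
    step = chunk_len - stride_left - stride_right  # hop between chunk starts

    for start in range(0, len(audio), step):
        chunk = audio[start : start + chunk_len]
        stride = (
            len(chunk),
            0 if start == 0 else stride_left,                        # no left overlap on the first chunk
            0 if start + chunk_len >= len(audio) else stride_right,  # no right overlap on the last chunk
        )
        # each (chunk, stride) pair is what gets handed to the feature extractor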
--------------------------------------------------------------------------------
/benchmarks/run_pjit.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 |
4 | import datasets
5 | import jax
6 | import jax.numpy as jnp
7 | from datasets import concatenate_datasets, load_dataset
8 | from flax.core.frozen_dict import freeze
9 | from jax.experimental.compilation_cache import compilation_cache as cc
10 | from jax.sharding import PartitionSpec as P
11 | from transformers import WhisperConfig, WhisperProcessor
12 |
13 | from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner
14 |
15 |
16 | datasets.logging.set_verbosity(datasets.logging.CRITICAL)
17 |
18 | cc.initialize_cache("./jax_cache")
19 | jax.config.update("jax_array", True)
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(description="Benchmark Whisper large-v2")
24 | parser.add_argument(
25 | "--model_parallel_submesh",
26 | type=int,
27 | nargs="+",
28 | default=(2, 2, 1, 1),
29 | help="Model parallel submesh.",
30 | )
31 | args = parser.parse_args()
32 | return args
33 |
34 |
35 | BATCH_SIZES = [4, 8, 16, 32]
36 | NUM_BATCHES = 100
37 | NUM_TOKENS = 25
38 | CHECKPOINT = "large-v2"
39 |
40 | # 2D parameter and activation partitioning for DP
41 | logical_axis_rules_dp = [
42 | ("batch", "data"),
43 | ("mlp", None),
44 | ("heads", None),
45 | ("vocab", None),
46 | ("embed", None),
47 | ("embed", None),
48 | ("joined_kv", None),
49 | ("kv", None),
50 | ("length", None),
51 | ("num_mel", None),
52 | ("channels", None),
53 | ]
54 |
55 |
56 | def main():
57 | args = parse_args()
58 | print(args.model_parallel_submesh)
59 | # processors/tokenizers are the same for all models, so just load from tiny and preprocess once
60 | processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
61 |
62 | def preprocess(batch):
63 | batch["input_features"] = processor(
64 | batch["audio"]["array"], sampling_rate=16000, return_tensors="np"
65 | ).input_features[0]
66 | return batch
67 |
68 | librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
69 | dataset_processed = librispeech.map(preprocess, remove_columns=librispeech.column_names)
70 |
71 | config = WhisperConfig.from_pretrained(f"openai/whisper-{CHECKPOINT}")
72 | model = FlaxWhisperForConditionalGeneration(config, _do_init=False, dtype=jnp.bfloat16)
73 | # to init the params
74 | params = model.init_weights(model.key, model.input_shape)
75 |
76 | def init_fn():
77 | input_shape = (1, 80, 3000)
78 |
79 | input_features = jnp.zeros(input_shape, dtype="f4")
80 | input_features = input_features.at[(..., -1)].set(model.config.eos_token_id)
81 |
82 | decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
83 | decoder_attention_mask = jnp.ones_like(decoder_input_ids)
84 |
85 | batch_size, sequence_length = decoder_input_ids.shape
86 | decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
87 |
88 | rng = jax.random.PRNGKey(0)
89 | init_params = model.module.init(
90 | rng,
91 | input_features=input_features,
92 | decoder_input_ids=decoder_input_ids,
93 | decoder_attention_mask=decoder_attention_mask,
94 | decoder_position_ids=decoder_position_ids,
95 | return_dict=False,
96 | )
97 | return init_params
98 |
99 | # Axis names metadata
100 | param_axes = jax.eval_shape(init_fn)["params_axes"]
101 |
102 | # Create InferenceState, since the partitioner expects it
103 | state = InferenceState(
104 | step=jnp.array(0),
105 | params=freeze(model.params_shape_tree),
106 | params_axes=freeze(param_axes),
107 | flax_mutables=None,
108 | flax_mutables_axes=param_axes,
109 | )
110 |
111 | partitioner = PjitPartitioner(
112 | num_partitions=1,
113 | logical_axis_rules=logical_axis_rules_dp,
114 | )
115 |
116 | mesh_axes = partitioner.get_mesh_axes(state)
117 | params_spec = mesh_axes.params
118 |
119 | p_shard_params = partitioner.partition(model.to_bf16, (params_spec,), params_spec)
120 |
121 | def generate(params, input_features):
122 |         output_ids = model.generate(input_features, params=params, max_new_tokens=NUM_TOKENS).sequences
123 | return output_ids
124 |
125 | p_generate = partitioner.partition(
126 | generate,
127 | in_axis_resources=(params_spec, P("data")),
128 | out_axis_resources=P("data"),
129 | )
130 |
131 | # This will auto-magically run in mesh context
132 | params = p_shard_params(freeze(params))
133 |
134 | for batch_size in BATCH_SIZES:
135 | eval_dataset = dataset_processed.select(range(batch_size // 2))
136 | eval_dataset = concatenate_datasets([eval_dataset for _ in range(2 * NUM_BATCHES)])
137 |
138 | eval_dataloader = eval_dataset.with_format("numpy").iter(batch_size=batch_size)
139 |
140 | # warm-up step
141 | batch = next(iter(eval_dataloader))
142 | p_generate(freeze(params), batch["input_features"])
143 |
144 | start = time.time()
145 | for batch in eval_dataloader:
146 | p_generate(freeze(params), batch["input_features"])
147 | runtime = time.time() - start
148 |
149 | print(f"{batch_size}: {runtime:.06}")
150 |
151 |
152 | if __name__ == "__main__":
153 | main()
154 |
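Note that the only logical axis mapped to a mesh axis above is ("batch", "data"), i.e. pure data parallelism: the batch dimension of the activations is split across the "data" mesh axis while every parameter stays replicated. A plain-JAX sketch of what that placement means, using `jax.sharding` directly rather than the `PjitPartitioner` API (illustrative only; assumes the batch size divides the device count):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # a 1-D mesh over all local devices, named "data"
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

    # the effect of the ("batch", "data") rule on a batch of log-mel features:
    # dim 0 (batch) is sharded over "data", the remaining dims are replicated
    batch = np.ones((8, 80, 3000), dtype=np.float32)
    sharded_batch = jax.device_put(batch, NamedSharding(mesh, P("data")))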
--------------------------------------------------------------------------------
/benchmarks/run_pjit_dataloader.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 | from datasets import load_dataset
7 | from flax.core.frozen_dict import freeze
8 | from jax.experimental.compilation_cache import compilation_cache as cc
9 | from jax.sharding import PartitionSpec as P
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 | from transformers import WhisperProcessor
13 |
14 | from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner
15 |
16 |
17 | cc.initialize_cache("./jax_cache")
18 | jax.config.update("jax_array", True)
19 |
20 | BATCH_SIZE = 128
21 | NUM_TOKENS = 256
22 | NUM_BATCHES = 5
23 | NUM_WORKERS = 4
24 | CHECKPOINT = "large-v2"
25 | STREAMING = True
26 |
27 | # 2D parameter and activation partitioning for DP
28 | logical_axis_rules_dp = [
29 | ("batch", "data"),
30 | ("mlp", None),
31 | ("heads", None),
32 | ("vocab", None),
33 | ("embed", None),
34 | ("embed", None),
35 | ("joined_kv", None),
36 | ("kv", None),
37 | ("length", None),
38 | ("num_mel", None),
39 | ("channels", None),
40 | ]
41 |
42 | model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
43 | f"openai/whisper-{CHECKPOINT}",
44 | _do_init=False,
45 | dtype=jnp.bfloat16,
46 | )
47 |
48 |
49 | def init_fn():
50 | input_shape = (1, 80, 3000)
51 |
52 | input_features = jnp.zeros(input_shape, dtype="f4")
53 | input_features = input_features.at[(..., -1)].set(model.config.eos_token_id)
54 |
55 | decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
56 | decoder_attention_mask = jnp.ones_like(decoder_input_ids)
57 |
58 | batch_size, sequence_length = decoder_input_ids.shape
59 | decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
60 |
61 | rng = jax.random.PRNGKey(0)
62 | init_params = model.module.init(
63 | rng,
64 | input_features=input_features,
65 | decoder_input_ids=decoder_input_ids,
66 | decoder_attention_mask=decoder_attention_mask,
67 | decoder_position_ids=decoder_position_ids,
68 | return_dict=False,
69 | )
70 | return init_params
71 |
72 |
73 | # Axis names metadata
74 | param_axes = jax.eval_shape(init_fn)["params_axes"]
75 |
76 | # Create InferenceState, since the partitioner expects it
77 | state = InferenceState(
78 | step=jnp.array(0),
79 | params=freeze(model.params_shape_tree),
80 | params_axes=freeze(param_axes),
81 | flax_mutables=None,
82 | flax_mutables_axes=param_axes,
83 | )
84 |
85 | partitioner = PjitPartitioner(
86 | num_partitions=1,
87 | logical_axis_rules=logical_axis_rules_dp,
88 | )
89 |
90 | mesh_axes = partitioner.get_mesh_axes(state)
91 | params_spec = mesh_axes.params
92 |
93 | p_shard_params = partitioner.partition(model.to_bf16, (params_spec,), params_spec)
94 |
95 |
96 | def generate(params, input_features):
97 | output_ids = model.generate(input_features, params=params, max_length=NUM_TOKENS).sequences
98 | return output_ids
99 |
100 |
101 | p_generate = partitioner.partition(
102 | generate,
103 | in_axis_resources=(params_spec, P("data")),
104 | out_axis_resources=P("data"),
105 | )
106 |
107 | # This will auto-magically run in mesh context
108 | params = p_shard_params(freeze(params))
109 |
110 | # warm-up
111 | pred_ids = p_generate(freeze(params), np.ones((BATCH_SIZE, 80, 3000)))
112 |
113 | # the processor/tokenizer is the same for all Whisper checkpoints; load it once for preprocessing and decoding
114 | processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
115 |
116 |
117 | def preprocess(batch):
118 | batch["input_features"] = processor(
119 | batch["audio"]["array"], sampling_rate=16000, return_tensors="np"
120 | ).input_features[0]
121 | return batch
122 |
123 |
124 | librispeech = load_dataset("speechcolab/gigaspeech", "l", split="train", streaming=STREAMING, use_auth_token=True)
125 | librispeech_features = librispeech.features.keys()
126 |
127 | librispeech_processed = librispeech.map(preprocess, remove_columns=librispeech_features)
128 |
129 | eval_dataloader = DataLoader(
130 | librispeech_processed, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, persistent_workers=True
131 | )
132 |
133 | all_load_times = 0
134 | all_runtimes = 0
135 |
136 | load_start = time.time()
137 | for i, batch in tqdm(enumerate(eval_dataloader)):
138 | input_features = np.asarray(batch.pop("input_features"))
139 | load_time = time.time() - load_start
140 | all_load_times += load_time
141 |
142 | input_batch_size = input_features.shape[0]
143 |
144 | if input_batch_size != BATCH_SIZE:
145 | padding = np.zeros([BATCH_SIZE - input_batch_size, *input_features.shape[1:]], input_features.dtype)
146 | input_features = np.concatenate([input_features, padding])
147 |
148 | generate_start = time.time()
149 | pred_ids = p_generate(freeze(params), input_features)[:input_batch_size]
150 | generate_runtime = time.time() - generate_start
151 | all_runtimes += generate_runtime
152 |
153 | pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
154 |
155 | if i + 1 == NUM_BATCHES:
156 | break
157 |
158 | print(all_load_times)
159 | print(all_runtimes)
160 |
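The padding branch above keeps every batch at exactly BATCH_SIZE examples so that `p_generate` is traced and compiled only once; a smaller final batch would otherwise trigger a fresh XLA compilation for the new shape. The same pad-then-trim pattern in isolation (hypothetical helper name):

    import numpy as np

    def pad_to_batch(features, batch_size):
        """Zero-pad the leading (batch) dimension up to `batch_size`; return the
        padded array and the number of real rows to keep after generation."""
        real = features.shape[0]
        if real < batch_size:
            pad = np.zeros((batch_size - real, *features.shape[1:]), features.dtype)
            features = np.concatenate([features, pad])
        return features, real

    # padded, real = pad_to_batch(input_features, BATCH_SIZE)
    # pred_ids = p_generate(freeze(params), padded)[:real]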
--------------------------------------------------------------------------------
/benchmarks/run_pmap.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from datasets import concatenate_datasets, load_dataset
6 | from flax import jax_utils
7 | from flax.training.common_utils import shard
8 | from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor
9 |
10 |
11 | BATCH_SIZES = [4, 8, 16, 32, 64, 128]
12 | NUM_BATCHES = 100
13 | NUM_TOKENS = 25
14 |
15 | model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
16 | "openai/whisper-large-v2",
17 | _do_init=False,
18 | dtype=jnp.bfloat16,
19 | )
20 |
21 | params = jax_utils.replicate(params)
22 |
23 |
24 | def generate_fn(batch):
25 | pred_ids = model.generate(batch, params=params, max_new_tokens=NUM_TOKENS, min_new_tokens=NUM_TOKENS)
26 | return pred_ids.sequences
27 |
28 |
29 | p_generate_fn = jax.pmap(generate_fn, "batch")
30 |
31 | # processors/tokenizers are the same for all models, so just load from tiny and preprocess once
32 | processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
33 |
34 |
35 | def preprocess(batch):
36 | batch["input_features"] = processor(
37 | batch["audio"]["array"], sampling_rate=16000, return_tensors="np"
38 | ).input_features[0]
39 | return batch
40 |
41 |
42 | # load a dataset of 73 audio samples
43 | librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
44 | dataset_processed = librispeech.map(preprocess, remove_columns=librispeech.column_names)
45 |
46 | for batch_size in BATCH_SIZES:
47 | eval_dataset = dataset_processed.select(range(batch_size // 2))
48 | eval_dataset = concatenate_datasets([eval_dataset for _ in range(2 * NUM_BATCHES)])
49 |
50 | eval_dataloader = eval_dataset.with_format("numpy").iter(batch_size=batch_size)
51 |
52 | # warm-up step
53 | batch = next(iter(eval_dataloader))
54 | input_features = shard(batch["input_features"])
55 | pred_ids = p_generate_fn(input_features)
56 |
57 | start = time.time()
58 | for batch in eval_dataloader:
59 | input_features = shard(batch["input_features"])
60 | pred_ids = p_generate_fn(input_features)
61 | runtime = time.time() - start
62 |
63 | print(f"{batch_size}: {runtime:.06}")
64 |
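For reference, flax's `shard` utility used above just reshapes the leading batch dimension into [num_devices, batch // num_devices, ...] so that `jax.pmap` can hand one slice to each device; the batch size therefore has to be divisible by the local device count. A rough sketch of that behaviour (not the flax source):

    import jax
    import numpy as np

    def shard_like(batch):
        # [batch, ...] -> [num_devices, batch // num_devices, ...]
        n = jax.local_device_count()
        return batch.reshape((n, batch.shape[0] // n) + batch.shape[1:])

    # e.g. with 8 devices, an input_features array of shape (32, 80, 3000)
    # becomes (8, 4, 80, 3000) before being passed to p_generate_fn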
--------------------------------------------------------------------------------
/benchmarks/run_pytorch.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from datasets import concatenate_datasets, load_dataset
4 | from torch.utils.data import DataLoader
5 | from transformers import WhisperForConditionalGeneration, WhisperProcessor
6 |
7 |
8 | BATCH_SIZES = [4, 8, 16, 32, 64, 128]
9 | NUM_BATCHES = 100
10 | NUM_TOKENS = 25
11 |
12 | model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
13 | model.to("cuda").half()
14 |
15 | # processors/tokenizers are the same for all models, so just load from tiny and preprocess once
16 | processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
17 |
18 |
19 | def preprocess(batch):
20 | batch["input_features"] = processor(
21 | batch["audio"]["array"], sampling_rate=16000, return_tensors="pt"
22 | ).input_features[0]
23 | return batch
24 |
25 |
26 | # load a dataset of 73 audio samples
27 | librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
28 | dataset_processed = librispeech.map(preprocess, remove_columns=librispeech.column_names)
29 |
30 | for batch_size in BATCH_SIZES:
31 | eval_dataset = dataset_processed.select(range(batch_size // 2))
32 | eval_dataset = concatenate_datasets([eval_dataset for _ in range(2 * NUM_BATCHES)])
33 |
34 | dataloader = DataLoader(
35 | dataset=eval_dataset.with_format("torch"), batch_size=batch_size, num_workers=4, pin_memory=True
36 | )
37 |
38 | # generate
39 | start = time.time()
40 | for batch in dataloader:
41 | input_features = batch["input_features"].to("cuda").half()
42 | out = model.generate(input_features, max_new_tokens=NUM_TOKENS, min_new_tokens=NUM_TOKENS)
43 | runtime = time.time() - start
44 | print(f"{batch_size}: {runtime:.06}")
45 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 119
3 | target-version = ['py37']
4 |
5 | [tool.ruff]
6 | # Never enforce `E501` (line length violations).
7 | ignore = ["C901", "E501", "E741", "W605"]
8 | select = ["C", "E", "F", "I", "W"]
9 | line-length = 119
10 |
11 | # Ignore import violations in all `__init__.py` files.
12 | [tool.ruff.per-file-ignores]
13 | "__init__.py" = ["E402", "F401", "F403", "F811"]
14 |
15 | [tool.ruff.isort]
16 | lines-after-imports = 2
17 | known-first-party = ["whisper_jax"]
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 |
18 | # To use a consistent encoding
19 | from codecs import open
20 |
21 | from setuptools import find_packages, setup
22 |
23 |
24 | _deps = [
25 | "transformers>=4.27.4,<4.35.0",
26 | "flax",
27 | "cached-property",
28 | ]
29 |
30 | _extras_dev_deps = [
31 | "black~=23.1",
32 | "isort>=5.5.4",
33 | "ruff>=0.0.241,<=0.0.259",
34 | ]
35 |
36 | _extras_endpoint_deps = [
37 | "gradio>=4.9.1",
38 | "requests>=2.28.2",
39 | "yt-dlp>=2023.3.4",
40 | ]
41 |
42 |
43 | here = os.path.abspath(os.path.dirname(__file__))
44 |
45 | with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
46 | long_description = f.read()
47 |
48 | # read version
49 | with open(os.path.join(here, "whisper_jax", "__init__.py"), encoding="utf-8") as f:
50 | for line in f:
51 | if line.startswith("__version__"):
52 | version = line.split("=")[1].strip().strip('"')
53 | break
54 | else:
55 | raise RuntimeError("Unable to find version string.")
56 |
57 | setup(
58 | name="whisper_jax",
59 | version=version,
60 | description="JAX implementation of OpenAI's Whisper model.",
61 | long_description=long_description,
62 | long_description_content_type="text/markdown",
63 | packages=find_packages(),
64 | install_requires=_deps,
65 | extras_require={
66 |         "dev": _extras_dev_deps,
67 |         "endpoint": _extras_endpoint_deps,
68 | },
69 | )
70 |
--------------------------------------------------------------------------------
/whisper_jax/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | __version__ = "0.0.1"
17 |
18 | from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
19 | from .partitioner import PjitPartitioner
20 | from .pipeline import FlaxWhisperPipline
21 | from .train_state import InferenceState
22 |
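For orientation, these exports are typically combined along the following lines. This is a hypothetical usage sketch: the pipeline's call signature is assumed for illustration and is not defined in this file (note the exported class name really is spelled `FlaxWhisperPipline`):

    import jax.numpy as jnp
    from whisper_jax import FlaxWhisperPipline

    # assumed signature: checkpoint id plus compute dtype
    pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
    text = pipeline("audio.mp3")   # assumed to accept a path to an audio file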
--------------------------------------------------------------------------------
/whisper_jax/layers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The T5X Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Dense attention classes and mask/weighting functions."""
17 |
18 | # pylint: disable=attribute-defined-outside-init,g-bare-generic
19 |
20 | import dataclasses
21 | import functools
22 | import operator
23 | from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union
24 |
25 | import jax
26 | import jax.numpy as jnp
27 | import numpy as np
28 | from flax import linen as nn
29 | from flax.linen import partitioning as nn_partitioning
30 | from flax.linen.dtypes import promote_dtype
31 | from jax import lax, random
32 |
33 |
34 | # from flax.linen.partitioning import param_with_axes, with_sharding_constraint
35 | param_with_axes = nn_partitioning.param_with_axes
36 | with_sharding_constraint = nn_partitioning.with_sharding_constraint
37 |
38 |
39 | # Type annotations
40 | Array = jnp.ndarray
41 | DType = jnp.dtype
42 | PRNGKey = jnp.ndarray
43 | Shape = Iterable[int]
44 | Activation = Callable[..., Array]
45 | PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]]
46 | DotGeneralT = Callable[..., Array]
47 | ConvGeneralDilatedT = Callable[..., Array]
48 | PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
49 | LaxPadding = Union[str, Sequence[Tuple[int, int]]]
50 |
51 | # Parameter initializers.
52 | Initializer = Callable[[PRNGKey, Shape, DType], Array]
53 | InitializerAxis = Union[int, Tuple[int, ...]]
54 | NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array]
55 |
56 | default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
57 |
58 |
59 | # ------------------------------------------------------------------------------
60 | # Temporary inlined JAX N-d initializer code
61 | # TODO(levskaya): remove once new JAX release is out.
62 | # ------------------------------------------------------------------------------
63 | def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
64 | """Inlined JAX `nn.initializer._compute_fans`."""
65 | if isinstance(in_axis, int):
66 | in_size = shape[in_axis]
67 | else:
68 | in_size = int(np.prod([shape[i] for i in in_axis]))
69 | if isinstance(out_axis, int):
70 | out_size = shape[out_axis]
71 | else:
72 | out_size = int(np.prod([shape[i] for i in out_axis]))
73 | receptive_field_size = shape.total / in_size / out_size
74 | fan_in = in_size * receptive_field_size
75 | fan_out = out_size * receptive_field_size
76 | return fan_in, fan_out
77 |
78 |
79 | def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_):
80 | """Inlined JAX `nn.initializer.variance_scaling`."""
81 |
82 | def init(key, shape, dtype=dtype):
83 | return jnp.zeros(shape, dtype=dtype)
84 |         dtype = jax.dtypes.canonicalize_dtype(dtype)
85 | shape = jax.core.as_named_shape(shape)
86 | fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
87 | if mode == "fan_in":
88 | denominator = fan_in
89 | elif mode == "fan_out":
90 | denominator = fan_out
91 | elif mode == "fan_avg":
92 | denominator = (fan_in + fan_out) / 2
93 | else:
94 | raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
95 | variance = jnp.array(scale / denominator, dtype=dtype)
96 |
97 | if distribution == "truncated_normal":
98 | # constant is stddev of standard normal truncated to (-2, 2)
99 | stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype)
100 | return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
101 | elif distribution == "normal":
102 | return random.normal(key, shape, dtype) * jnp.sqrt(variance)
103 | elif distribution == "uniform":
104 | return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
105 | else:
106 | raise ValueError("invalid distribution for variance scaling " "initializer: {}".format(distribution))
107 |
108 | return init
109 |
110 |
111 | # ------------------------------------------------------------------------------
112 |
113 |
114 | def nd_dense_init(scale, mode, distribution):
115 | """Initializer with in_axis, out_axis set at call time."""
116 |
117 | def init_fn(key, shape, dtype, in_axis, out_axis):
118 | fn = variance_scaling(scale, mode, distribution, in_axis, out_axis)
119 | return fn(key, shape, dtype)
120 |
121 | return init_fn
122 |
123 |
124 | def dot_product_attention(
125 | query: Array,
126 | key: Array,
127 | value: Array,
128 | bias: Optional[Array] = None,
129 | dropout_rng: Optional[PRNGKey] = None,
130 | dropout_rate: float = 0.0,
131 | deterministic: bool = False,
132 | dtype: DType = jnp.float32,
133 | float32_logits: bool = False,
134 | ):
135 | """Computes dot-product attention given query, key, and value.
136 |
137 | This is the core function for applying attention based on
138 | https://arxiv.org/abs/1706.03762. It calculates the attention weights given
139 | query and key and combines the values using the attention weights.
140 |
141 | Args:
142 | query: queries for calculating attention with shape of `[batch, q_length,
143 | num_heads, qk_depth_per_head]`.
144 | key: keys for calculating attention with shape of `[batch, kv_length,
145 | num_heads, qk_depth_per_head]`.
146 | value: values to be used in attention with shape of `[batch, kv_length,
147 | num_heads, v_depth_per_head]`.
148 | bias: bias for the attention weights. This should be broadcastable to the
149 | shape `[batch, num_heads, q_length, kv_length]` This can be used for
150 | incorporating causal masks, padding masks, proximity bias, etc.
151 | dropout_rng: JAX PRNGKey: to be used for dropout
152 | dropout_rate: dropout rate
153 | deterministic: bool, deterministic or not (to apply dropout)
154 | dtype: the dtype of the computation (default: float32)
155 | float32_logits: bool, if True then compute logits in float32 to avoid
156 | numerical issues with bfloat16.
157 |
158 | Returns:
159 | Output of shape `[batch, length, num_heads, v_depth_per_head]`.
160 | """
161 | assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
162 | assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
163 | assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
164 | assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
165 | assert query.shape[-1] == key.shape[-1], "q, k depths must match."
166 |
167 | # Casting logits and softmax computation for float32 for model stability.
168 | if float32_logits:
169 | query = query.astype(jnp.float32)
170 | key = key.astype(jnp.float32)
171 |
172 | # `attn_weights`: [batch, num_heads, q_length, kv_length]
173 | attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)
174 |
175 | # Apply attention bias: masking, dropout, proximity bias, etc.
176 | if bias is not None:
177 | attn_weights = attn_weights + bias.astype(attn_weights.dtype)
178 |
179 | # Normalize the attention weights across `kv_length` dimension.
180 | attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
181 |
182 | # Apply attention dropout.
183 | if not deterministic and dropout_rate > 0.0:
184 | keep_prob = 1.0 - dropout_rate
185 | # T5 broadcasts along the "length" dim, but unclear which one that
186 | # corresponds to in positional dimensions here, assuming query dim.
187 | dropout_shape = list(attn_weights.shape)
188 | dropout_shape[-2] = 1
189 | keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
190 | keep = jnp.broadcast_to(keep, attn_weights.shape)
191 | multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)
192 | attn_weights = attn_weights * multiplier
193 |
194 | # Take the linear combination of `value`.
195 | return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
196 |
197 |
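# Illustrative shape sketch for `dot_product_attention` (toy sizes, matching the
# docstring above; not a test from this repo):
#
#   q = jnp.ones((2, 5, 4, 8))   # [batch, q_length,  num_heads, depth]
#   k = jnp.ones((2, 7, 4, 8))   # [batch, kv_length, num_heads, depth]
#   v = jnp.ones((2, 7, 4, 8))   # [batch, kv_length, num_heads, depth]
#   out = dot_product_attention(q, k, v)   # -> (2, 5, 4, 8)
#
# i.e. the attention weights are [batch, num_heads, q_length, kv_length] and the
# output keeps the query length with the per-head value depth.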
198 | dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
199 |
200 |
201 | class MultiHeadDotProductAttention(nn.Module):
202 | """Multi-head dot-product attention.
203 |
204 | Attributes:
205 | num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
206 | should be divisible by the number of heads.
207 | head_dim: dimension of each head.
208 | dtype: the dtype of the computation.
209 | dropout_rate: dropout rate
210 | kernel_init: initializer for the kernel of the Dense layers.
211 | float32_logits: bool, if True then compute logits in float32 to avoid
212 | numerical issues with bfloat16.
213 | """
214 |
215 | num_heads: int
216 | head_dim: int
217 | dtype: DType = jnp.float32
218 | dropout_rate: float = 0.0
219 | kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
220 | float32_logits: bool = False # computes logits in float32 for stability.
221 |
222 | @nn.compact
223 | def __call__(
224 | self,
225 | inputs_q: Array,
226 | inputs_kv: Array,
227 | mask: Optional[Array] = None,
228 | bias: Optional[Array] = None,
229 | *,
230 | decode: bool = False,
231 | deterministic: bool = False,
232 | ) -> Array:
233 | """Applies multi-head dot product attention on the input data.
234 |
235 | Projects the inputs into multi-headed query, key, and value vectors,
236 | applies dot-product attention and project the results to an output vector.
237 |
238 | There are two modes: decoding and non-decoding (e.g., training). The mode is
239 | determined by `decode` argument. For decoding, this method is called twice,
240 | first to initialize the cache and then for an actual decoding process. The
241 | two calls are differentiated by the presence of 'cached_key' in the variable
242 | dict. In the cache initialization stage, the cache variables are initialized
243 | as zeros and will be filled in the subsequent decoding process.
244 |
245 | In the cache initialization call, `inputs_q` has a shape [batch, length,
246 | q_features] and `inputs_kv`: [batch, length, kv_features]. During the
247 | incremental decoding stage, query, key and value all have the shape [batch,
248 | 1, qkv_features] corresponding to a single step.
249 |
250 | Args:
251 | inputs_q: input queries of shape `[batch, q_length, q_features]`.
252 | inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
253 | mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
254 | bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
255 | decode: Whether to prepare and use an autoregressive cache.
256 | deterministic: Disables dropout if set to True.
257 |
258 | Returns:
259 | output of shape `[batch, length, q_features]`.
260 | """
261 | projection = functools.partial(
262 | DenseGeneral,
263 | axis=-1,
264 | features=(self.num_heads, self.head_dim),
265 | kernel_axes=("embed", "heads", "kv"),
266 | dtype=self.dtype,
267 | )
268 |
269 | # NOTE: T5 does not explicitly rescale the attention logits by
270 | # 1/sqrt(depth_kq)! This is folded into the initializers of the
271 | # linear transformations, which is equivalent under Adafactor.
272 | depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
273 |
274 | def query_init(*args):
275 | return self.kernel_init(*args) / depth_scaling
276 |
277 | # Project inputs_q to multi-headed q/k/v
278 | # dimensions are then [batch, length, num_heads, head_dim]
279 | query = projection(kernel_init=query_init, name="query")(inputs_q)
280 | key = projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
281 | value = projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
282 |
283 | query = with_sharding_constraint(query, ("batch", "length", "heads", "kv"))
284 | key = with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
285 | value = with_sharding_constraint(value, ("batch", "length", "heads", "kv"))
286 |
287 | if decode:
288 | # Detect if we're initializing by absence of existing cache data.
289 | is_initialized = self.has_variable("cache", "cached_key")
290 |
291 | # The key and value have dimension [batch, length, num_heads, head_dim],
292 | # but we cache them as [batch, num_heads, head_dim, length] as a TPU
293 | # fusion optimization. This also enables the "scatter via one-hot
294 | # broadcast" trick, which means we do a one-hot broadcast instead of a
295 | # scatter/gather operations, resulting in a 3-4x speedup in practice.
296 | def swap_dims(x):
297 | return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
298 |
299 | cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
300 | cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
301 | cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
302 | if is_initialized:
303 | batch, num_heads, head_dim, length = cached_key.value.shape
304 | # During fast autoregressive decoding, we feed one position at a time,
305 | # and cache the keys and values step by step.
306 | # Sanity shape check of cached key against input query.
307 | expected_shape = (batch, 1, num_heads, head_dim)
308 | if expected_shape != query.shape:
309 | raise ValueError(
310 | "Autoregressive cache shape error, "
311 | "expected query shape %s instead got %s." % (expected_shape, query.shape)
312 | )
313 |
314 | # Create a OHE of the current index. NOTE: the index is increased below.
315 | cur_index = cache_index.value
316 | one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype)
317 | # In order to update the key, value caches with the current key and
318 | # value, we move the length axis to the back, similar to what we did for
319 | # the cached ones above.
320 | # Note these are currently the key and value of a single position, since
321 | # we feed one position at a time.
322 | one_token_key = jnp.moveaxis(key, -3, -1)
323 | one_token_value = jnp.moveaxis(value, -3, -1)
324 | # Update key, value caches with our new 1d spatial slices.
325 | # We implement an efficient scatter into the cache via one-hot
326 | # broadcast and addition.
327 | key = cached_key.value + one_token_key * one_hot_indices
328 | value = cached_value.value + one_token_value * one_hot_indices
329 | cached_key.value = key
330 | cached_value.value = value
331 | cache_index.value = cache_index.value + 1
332 | # Move the keys and values back to their original shapes.
333 | key = jnp.moveaxis(key, -1, -3)
334 | value = jnp.moveaxis(value, -1, -3)
335 |
336 | # Causal mask for cached decoder self-attention: our single query
337 | # position should only attend to those key positions that have already
338 | # been generated and cached, not the remaining zero elements.
339 | mask = combine_masks(
340 | mask,
341 | jnp.broadcast_to(
342 | jnp.arange(length) <= cur_index,
343 | # (1, 1, length) represent (head dim, query length, key length)
344 | # query length is 1 because during decoding we deal with one
345 | # index.
346 | # The same mask is applied to all batch elements and heads.
347 | (batch, 1, 1, length),
348 | ),
349 | )
350 |
351 | # Grab the correct relative attention bias during decoding. This is
352 | # only required during single step decoding.
353 | if bias is not None:
354 | # The bias is a full attention matrix, but during decoding we only
355 | # have to take a slice of it.
356 | # This is equivalent to bias[..., cur_index:cur_index+1, :].
357 | bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2)
358 |
359 | # Convert the boolean attention mask to an attention bias.
360 | if mask is not None:
361 | # attention mask in the form of attention bias
362 | attention_bias = lax.select(
363 | mask > 0, jnp.full(mask.shape, 0.0).astype(self.dtype), jnp.full(mask.shape, -1e10).astype(self.dtype)
364 | )
365 | else:
366 | attention_bias = None
367 |
368 | # Add provided bias term (e.g. relative position embedding).
369 | if bias is not None:
370 | attention_bias = combine_biases(attention_bias, bias)
371 |
372 | dropout_rng = None
373 | if not deterministic and self.dropout_rate > 0.0:
374 | dropout_rng = self.make_rng("dropout")
375 |
376 | # Apply attention.
377 | x = dot_product_attention(
378 | query,
379 | key,
380 | value,
381 | bias=attention_bias,
382 | dropout_rng=dropout_rng,
383 | dropout_rate=self.dropout_rate,
384 | deterministic=deterministic,
385 | dtype=self.dtype,
386 | float32_logits=self.float32_logits,
387 | )
388 |
389 | # Back to the original inputs dimensions.
390 | out = DenseGeneral(
391 | features=inputs_q.shape[-1], # output dim is set to the input dim.
392 | axis=(-2, -1),
393 | kernel_init=self.kernel_init,
394 | kernel_axes=("heads", "kv", "embed"),
395 | dtype=self.dtype,
396 | name="out",
397 | )(x)
398 | return out
399 |
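# Illustrative note on the "scatter via one-hot broadcast" cache update above: it
# is equivalent to an indexed write into the length axis. With the cache laid out
# as [batch, num_heads, head_dim, length]:
#
#   one_hot = jax.nn.one_hot(cur_index, length, dtype=key.dtype)   # (length,)
#   cache   = cache + one_token_key * one_hot                      # fills slot cur_index
#
# Because positions are written strictly in order, slot `cur_index` is still zero
# when it is updated, so the addition lands exactly one new column without a
# gather/scatter op.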
400 |
401 | def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
402 | # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
403 | return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
404 |
405 |
406 | def _canonicalize_tuple(x):
407 | if isinstance(x, Iterable):
408 | return tuple(x)
409 | else:
410 | return (x,)
411 |
412 |
413 | # ------------------------------------------------------------------------------
414 | # DenseGeneral for attention layers.
415 | # ------------------------------------------------------------------------------
416 | class DenseGeneral(nn.Module):
417 | """A linear transformation (without bias) with flexible axes.
418 |
419 | Attributes:
420 | features: tuple with numbers of output features.
421 | axis: tuple with axes to apply the transformation on.
422 | dtype: the dtype of the computation (default: float32).
423 | kernel_init: initializer function for the weight matrix.
424 | """
425 |
426 | features: Union[Iterable[int], int]
427 | axis: Union[Iterable[int], int] = -1
428 | dtype: DType = jnp.float32
429 | params_dtype: DType = jnp.float32
430 | kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
431 | kernel_axes: Tuple[str, ...] = ()
432 | use_bias: bool = True
433 | bias_init: Any = nn.initializers.zeros
434 |
435 | @nn.compact
436 | def __call__(self, inputs: Array) -> Array:
437 | """Applies a linear transformation to the inputs along multiple dimensions.
438 |
439 | Args:
440 | inputs: The nd-array to be transformed.
441 |
442 | Returns:
443 | The transformed input.
444 | """
445 | features = _canonicalize_tuple(self.features)
446 | axis = _canonicalize_tuple(self.axis)
447 |
448 | inputs = jnp.asarray(inputs, self.dtype)
449 | axis = _normalize_axes(axis, inputs.ndim)
450 |
451 | kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
452 | kernel_in_axis = np.arange(len(axis))
453 | kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
454 | kernel = param_with_axes(
455 | "kernel",
456 | self.kernel_init,
457 | kernel_shape,
458 | self.params_dtype,
459 | kernel_in_axis,
460 | kernel_out_axis,
461 | axes=self.kernel_axes,
462 | )
463 | if self.use_bias:
464 | bias = param_with_axes("bias", self.bias_init, features, self.params_dtype, axes=(self.kernel_axes[-1],))
465 | kernel = jnp.asarray(kernel, self.dtype)
466 |
467 | contract_ind = tuple(range(0, len(axis)))
468 | y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
469 | if self.use_bias:
470 | bias = jnp.asarray(bias, self.dtype)
471 | # y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
472 | y += jnp.reshape(bias, (1,) * (len(features) - y.ndim) + bias.shape[:])
473 | return y
474 |
475 |
476 | def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
477 | """Convert a string to an activation function."""
478 | if fn_or_string == "linear":
479 | return lambda x: x
480 | elif isinstance(fn_or_string, str):
481 | return getattr(nn, fn_or_string)
482 | elif callable(fn_or_string):
483 | return fn_or_string
484 | else:
485 | raise ValueError("don't know how to convert %s to an activation function" % (fn_or_string,))
486 |
487 |
488 | class MlpBlock(nn.Module):
489 | """Transformer MLP / feed-forward block.
490 |
491 | Attributes:
492 | intermediate_dim: Shared dimension of hidden layers.
493 | activations: Type of activations for each layer. Each element is either
494 | 'linear', a string function name in flax.linen, or a function.
495 | kernel_init: Kernel function, passed to the dense layers.
496 | deterministic: Whether the dropout layers should be deterministic.
497 | intermediate_dropout_rate: Dropout rate used after the intermediate layers.
498 | dtype: Type for the dense layer.
499 | """
500 |
501 | intermediate_dim: int = 2048
502 | activations: Sequence[Union[str, Callable]] = ("relu",)
503 | kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal")
504 | intermediate_dropout_rate: float = 0.1
505 | dtype: Any = jnp.float32
506 |
507 | @nn.compact
508 | def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
509 | """Applies Transformer MlpBlock module."""
510 | # Iterate over specified MLP input activation functions.
511 | # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
512 | activations = []
513 | for idx, act_fn in enumerate(self.activations):
514 | dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
515 | x = DenseGeneral(
516 | self.intermediate_dim,
517 | dtype=self.dtype,
518 | kernel_init=self.kernel_init,
519 | kernel_axes=("embed", "mlp"),
520 | name=dense_name,
521 | )(inputs)
522 | x = _convert_to_activation_function(act_fn)(x)
523 | activations.append(x)
524 |
525 | # Take elementwise product of above intermediate activations.
526 | x = functools.reduce(operator.mul, activations)
527 | # Apply dropout and final dense output projection.
528 | x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
529 | x, deterministic=deterministic
530 | ) # Broadcast along length.
531 | x = with_sharding_constraint(x, ("batch", "length", "mlp"))
532 | output = DenseGeneral(
533 | inputs.shape[-1], dtype=self.dtype, kernel_init=self.kernel_init, kernel_axes=("mlp", "embed"), name="wo"
534 | )(x)
535 | return output
536 |
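# Illustrative note: with activations=("gelu", "linear"), MlpBlock above becomes a
# gated-GELU MLP. Two parallel projections wi_0 and wi_1 are applied to the same
# input, GELU is applied to the first, and their element-wise product is dropped
# out and passed through the output projection "wo":
#
#   h = gelu(wi_0(x)) * wi_1(x)
#   y = wo(dropout(h))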
537 |
538 | class Embed(nn.Module):
539 | """A parameterized function from integers [0, n) to d-dimensional vectors.
540 |
541 | Attributes:
542 | num_embeddings: number of embeddings.
543 | features: number of feature dimensions for each embedding.
544 | dtype: the dtype of the embedding vectors (default: float32).
545 | embedding_init: embedding initializer.
546 | one_hot: performs the gather with a one-hot contraction rather than a true
547 | gather. This is currently needed for SPMD partitioning.
548 | """
549 |
550 | num_embeddings: int
551 | features: int
552 | cast_input_dtype: Optional[DType] = None
553 | dtype: DType = jnp.float32
554 | params_dtype: DType = jnp.float32
555 | attend_dtype: Optional[DType] = None
556 | embedding_init: Initializer = default_embed_init
557 | one_hot: bool = True
558 | embedding: Array = dataclasses.field(init=False)
559 |
560 | def setup(self):
561 | self.embedding = param_with_axes(
562 | "embedding",
563 | self.embedding_init,
564 | (self.num_embeddings, self.features),
565 | self.params_dtype,
566 | axes=("vocab", "embed"),
567 | )
568 |
569 | def __call__(self, inputs: Array) -> Array:
570 | """Embeds the inputs along the last dimension.
571 |
572 | Args:
573 | inputs: input data, all dimensions are considered batch dimensions.
574 |
575 | Returns:
576 | Output which is embedded input data. The output shape follows the input,
577 | with an additional `features` dimension appended.
578 | """
579 | if self.cast_input_dtype:
580 | inputs = inputs.astype(self.cast_input_dtype)
581 | if not jnp.issubdtype(inputs.dtype, jnp.integer):
582 | raise ValueError("Input type must be an integer or unsigned integer.")
583 | if self.one_hot:
584 | iota = lax.iota(jnp.int32, self.num_embeddings)
585 | one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
586 | output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
587 | else:
588 | output = jnp.asarray(self.embedding, self.dtype)[inputs]
589 | output = with_sharding_constraint(output, ("batch", "length", "embed"))
590 | return output
591 |
592 | def attend(self, query: Array) -> Array:
593 | """Attend over the embedding using a query array.
594 |
595 | Args:
596 | query: array with last dimension equal the feature depth `features` of the
597 | embedding.
598 |
599 | Returns:
600 | An array with final dim `num_embeddings` corresponding to the batched
601 | inner-product of the array of query vectors against each embedding.
602 | Commonly used for weight-sharing between embeddings and logit transform
603 | in NLP models.
604 | """
605 | dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
606 | return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)
607 |
608 |
609 | class RelativePositionBiases(nn.Module):
610 | """Adds T5-style relative positional embeddings to the attention logits.
611 |
612 | Attributes:
613 | num_buckets: Number of buckets to bucket distances between key and query
614 | positions into.
615 | max_distance: Maximum distance before everything is lumped into the last
616 | distance bucket.
617 | num_heads: Number of heads in the attention layer. Each head will get a
618 | different relative position weighting.
619 | dtype: Type of arrays through this module.
620 | embedding_init: initializer for relative embedding table.
621 | """
622 |
623 | num_buckets: int
624 | max_distance: int
625 | num_heads: int
626 | dtype: Any
627 | embedding_init: Callable[..., Array] = nn.linear.default_embed_init
628 |
629 | @staticmethod
630 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
631 | """Translate relative position to a bucket number for relative attention.
632 |
633 | The relative position is defined as memory_position - query_position, i.e.
634 | the distance in tokens from the attending position to the attended-to
635 | position. If bidirectional=False, then positive relative positions are
636 | invalid.
637 | We use smaller buckets for small absolute relative_position and larger
638 | buckets for larger absolute relative_positions. All relative
639 | positions >=max_distance map to the same bucket. All relative
640 | positions <=-max_distance map to the same bucket. This should allow for
641 | more graceful generalization to longer sequences than the model has been
642 | trained on.
643 |
644 | Args:
645 | relative_position: an int32 array
646 | bidirectional: a boolean - whether the attention is bidirectional
647 | num_buckets: an integer
648 | max_distance: an integer
649 |
650 | Returns:
651 | a Tensor with the same shape as relative_position, containing int32
652 | values in the range [0, num_buckets)
653 | """
654 | ret = 0
655 | n = -relative_position
656 | if bidirectional:
657 | num_buckets //= 2
658 | ret += (n < 0).astype(np.int32) * num_buckets
659 | n = np.abs(n)
660 | else:
661 | n = np.maximum(n, 0)
662 | # now n is in the range [0, inf)
663 | max_exact = num_buckets // 2
664 | is_small = n < max_exact
665 | val_if_large = max_exact + (
666 | np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
667 | / np.log(max_distance / max_exact)
668 | * (num_buckets - max_exact)
669 | ).astype(np.int32)
670 | val_if_large = np.minimum(val_if_large, num_buckets - 1)
671 | ret += np.where(is_small, n, val_if_large)
672 | return ret
673 |
674 | @nn.compact
675 | def __call__(self, qlen, klen, bidirectional=True):
676 | """Produce relative position embedding attention biases.
677 |
678 | Args:
679 | qlen: attention query length.
680 | klen: attention key length.
681 | bidirectional: whether to allow positive memory-query relative position
682 | embeddings.
683 |
684 | Returns:
685 | output: `(1, len, q_len, k_len)` attention bias
686 | """
687 | # TODO(levskaya): should we be computing this w. numpy as a program
688 | # constant?
689 | context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
690 | memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
691 | relative_position = memory_position - context_position # shape (qlen, klen)
692 | rp_bucket = self._relative_position_bucket(
693 | relative_position,
694 | bidirectional=bidirectional,
695 | num_buckets=self.num_buckets,
696 | max_distance=self.max_distance,
697 | )
698 | relative_attention_bias = param_with_axes(
699 | "rel_embedding",
700 | self.embedding_init,
701 | (self.num_heads, self.num_buckets),
702 | jnp.float32,
703 | axes=("heads", "relpos_buckets"),
704 | )
705 |
706 | relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
707 | # Instead of using a slow gather, we create a leading-dimension one-hot
708 | # array from rp_bucket and use it to perform the gather-equivalent via a
709 | # contraction, i.e.:
710 | # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
711 | # This is equivalent to relative_attention_bias[:, rp_bucket]
712 | bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
713 | rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
714 | # --> shape (qlen, klen, num_heads)
715 | values = lax.dot_general(
716 | relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ())) # rhs, lhs contracting dims
717 | ) # no batched dims
718 | # Add a singleton batch dimension.
719 | # --> shape (1, num_heads, qlen, klen)
720 | return values[jnp.newaxis, ...]
721 |
722 |
723 | # ------------------------------------------------------------------------------
724 | # T5 Layernorm - no subtraction of mean or bias.
725 | # ------------------------------------------------------------------------------
726 | # class LayerNorm(nn.Module):
727 | # """T5 Layer normalization operating on the last axis of the input data."""
728 | # epsilon: float = 1e-6
729 | # dtype: Any = jnp.float32
730 | # scale_init: Initializer = nn.initializers.ones
731 |
732 | # @nn.compact
733 | # def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
734 | # """Applies layer normalization on the input."""
735 | # x = jnp.asarray(x, jnp.float32)
736 | # features = x.shape[-1]
737 | # mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
738 | # y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
739 | # scale = param_with_axes(
740 | # 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',))
741 |
742 | # scale = jnp.asarray(scale, self.dtype)
743 | # return y * scale
744 |
745 |
746 | class LayerNorm(nn.Module):
747 | """Layer normalization (https://arxiv.org/abs/1607.06450).
748 | Operates on the last axis of the input data.
749 | It normalizes the activations of the layer for each given example in a
750 | batch independently, rather than across a batch like Batch Normalization.
751 | i.e. applies a transformation that maintains the mean activation within
752 | each example close to 0 and the activation standard deviation close to 1.
753 | Attributes:
754 | epsilon: A small float added to variance to avoid dividing by zero.
755 | dtype: the dtype of the computation (default: float32).
756 | use_bias: If True, bias (beta) is added.
757 | use_scale: If True, multiply by scale (gamma). When the next layer is linear
758 | (also e.g. nn.relu), this can be disabled since the scaling will be done
759 | by the next layer.
760 | bias_init: Initializer for bias, by default, zero.
761 | scale_init: Initializer for scale, by default, one.
762 | """
763 |
764 | epsilon: float = 1e-6
765 | dtype: Any = jnp.float32
766 | params_dtype: DType = jnp.float32
767 | use_bias: bool = True
768 | use_scale: bool = True
769 | bias_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.zeros
770 | scale_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.ones
771 |
772 | @nn.compact
773 | def __call__(self, x):
774 | """Applies layer normalization on the input.
775 | Args:
776 | x: the inputs
777 | Returns:
778 | Normalized inputs (the same shape as inputs).
779 | """
780 | x = jnp.asarray(x, jnp.float32)
781 | features = x.shape[-1]
782 | mean = jnp.mean(x, axis=-1, keepdims=True)
783 | mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
784 | var = mean2 - lax.square(mean)
785 | mul = lax.rsqrt(var + self.epsilon)
786 | if self.use_scale:
787 | scale = param_with_axes("scale", self.scale_init, (features,), self.params_dtype, axes=("embed",))
788 | mul = mul * jnp.asarray(scale, self.dtype)
789 | y = (x - mean) * mul
790 | if self.use_bias:
791 | bias = param_with_axes("bias", self.bias_init, (features,), self.params_dtype, axes=("embed",))
792 | y = y + jnp.asarray(bias, self.dtype)
793 | return jnp.asarray(y, self.dtype)
794 |
795 |
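# Minimal usage sketch for the LayerNorm above (shapes are assumptions; the "embed"
# axis name is recorded by `param_with_axes` alongside the parameters):
#
#     layer = LayerNorm(dtype=jnp.float32)
#     x = jnp.ones((2, 4, 8))                                  # (batch, length, features)
#     variables = layer.init(jax.random.PRNGKey(0), x)
#     y = layer.apply(variables, x)                            # same shape as x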
796 | # ------------------------------------------------------------------------------
797 | # Mask-making utility functions.
798 | # ------------------------------------------------------------------------------
799 | def make_attention_mask(
800 | query_input: Array,
801 | key_input: Array,
802 | pairwise_fn: Callable = jnp.multiply,
803 | extra_batch_dims: int = 0,
804 | dtype: DType = jnp.float32,
805 | ) -> Array:
806 | """Mask-making helper for attention weights.
807 |
808 |     In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`), the
809 | attention weights will be `[batch, heads, len_q, len_kv]` and this
810 | function will produce `[batch, 1, len_q, len_kv]`.
811 |
812 | Args:
813 | query_input: a batched, flat input of query_length size
814 | key_input: a batched, flat input of key_length size
815 | pairwise_fn: broadcasting elementwise comparison function
816 | extra_batch_dims: number of extra batch dims to add singleton axes for, none
817 | by default
818 | dtype: mask return dtype
819 |
820 | Returns:
821 | A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention.
822 | """
823 | # [batch, len_q, len_kv]
824 | mask = pairwise_fn(
825 | # [batch, len_q] -> [batch, len_q, 1]
826 | jnp.expand_dims(query_input, axis=-1),
827 |         # [batch, len_kv] -> [batch, 1, len_kv]
828 | jnp.expand_dims(key_input, axis=-2),
829 | )
830 |
831 | # [batch, 1, len_q, len_kv]. This creates the head dim.
832 | mask = jnp.expand_dims(mask, axis=-3)
833 | mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
834 | return mask.astype(dtype)
835 |
836 |
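# Illustrative sketch (toy token ids assumed): a padding mask built from token ids,
# where non-zero ids mark real tokens:
#
#     tokens = jnp.array([[5, 7, 0, 0]])                       # (batch, len)
#     pad_mask = make_attention_mask(tokens > 0, tokens > 0)   # (batch, 1, len, len)
#     # pad_mask[0, 0, i, j] == 1.0 only when positions i and j are both non-padding.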
837 | def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32) -> Array:
838 | """Make a causal mask for self-attention.
839 |
840 |     In case of 1d inputs (i.e., `[batch, len]`), the self-attention weights
841 | will be `[batch, heads, len, len]` and this function will produce a
842 | causal mask of shape `[batch, 1, len, len]`.
843 |
844 | Note that a causal mask does not depend on the values of x; it only depends on
845 | the shape. If x has padding elements, they will not be treated in a special
846 | manner.
847 |
848 | Args:
849 | x: input array of shape `[batch, len]`
850 | extra_batch_dims: number of batch dims to add singleton axes for, none by
851 | default
852 | dtype: mask return dtype
853 |
854 | Returns:
855 | A `[batch, 1, len, len]` shaped causal mask for 1d attention.
856 | """
857 | idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
858 | return make_attention_mask(idxs, idxs, jnp.greater_equal, extra_batch_dims=extra_batch_dims, dtype=dtype)
859 |
860 |
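# Illustrative sketch (toy shape assumed): the causal mask depends only on the
# sequence length and is lower-triangular over (query, key) positions:
#
#     x = jnp.zeros((1, 4))
#     causal = make_causal_mask(x)                             # (1, 1, 4, 4)
#     assert jnp.array_equal(causal[0, 0], jnp.tril(jnp.ones((4, 4))))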
861 | def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
862 | """Combine attention masks.
863 |
864 | Args:
865 | *masks: set of attention mask arguments to combine, some can be None.
866 | dtype: final mask dtype
867 |
868 | Returns:
869 | Combined mask, reduced by logical and, returns None if no masks given.
870 | """
871 | masks = [m for m in masks if m is not None]
872 | if not masks:
873 | return None
874 | assert all(
875 | (x.ndim == masks[0].ndim for x in masks)
876 | ), f"masks must have same rank: {tuple((x.ndim for x in masks))}"
877 | mask, *other_masks = masks
878 | for other_mask in other_masks:
879 | mask = jnp.logical_and(mask, other_mask)
880 | return mask.astype(dtype)
881 |
882 |
883 | def combine_biases(*masks: Optional[Array]):
884 | """Combine attention biases.
885 |
886 | Args:
887 | *masks: set of attention bias arguments to combine, some can be None.
888 |
889 | Returns:
890 | Combined mask, reduced by summation, returns None if no masks given.
891 | """
892 | masks = [m for m in masks if m is not None]
893 | if not masks:
894 | return None
895 | assert all(
896 | (x.ndim == masks[0].ndim for x in masks)
897 | ), f"masks must have same rank: {tuple((x.ndim for x in masks))}"
898 | mask, *other_masks = masks
899 | for other_mask in other_masks:
900 | mask = mask + other_mask
901 | return mask
902 |
903 |
904 | def make_decoder_mask(
905 | decoder_target_tokens: Array,
906 | dtype: DType,
907 | decoder_causal_attention: Optional[Array] = None,
908 | decoder_segment_ids: Optional[Array] = None,
909 | ) -> Array:
910 | """Compute the self-attention mask for a decoder.
911 |
912 | Decoder mask is formed by combining a causal mask, a padding mask and an
913 | optional packing mask. If decoder_causal_attention is passed, it makes the
914 | masking non-causal for positions that have value of 1.
915 |
916 | A prefix LM is applied to a dataset which has a notion of "inputs" and
917 | "targets", e.g., a machine translation task. The inputs and targets are
918 | concatenated to form a new target. `decoder_target_tokens` is the concatenated
919 | decoder output tokens.
920 |
921 | The "inputs" portion of the concatenated sequence can attend to other "inputs"
922 |     tokens, even those at later time steps. In order to control this
923 | behavior, `decoder_causal_attention` is necessary. This is a binary mask with
924 |     a value of 1 indicating that the position belongs to the "inputs" portion of the
925 | original dataset.
926 |
927 | Example:
928 |
929 | Suppose we have a dataset with two examples.
930 |
931 | ds = [{"inputs": [6, 7], "targets": [8]},
932 | {"inputs": [3, 4], "targets": [5]}]
933 |
934 | After the data preprocessing with packing, the two examples are packed into
935 | one example with the following three fields (some fields are skipped for
936 | simplicity).
937 |
938 | decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]]
939 | decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
940 | decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]
941 |
942 | where each array has [batch, length] shape with batch size being 1. Then,
943 | this function computes the following mask.
944 |
945 | mask = [[[[1, 1, 0, 0, 0, 0, 0],
946 | [1, 1, 0, 0, 0, 0, 0],
947 | [1, 1, 1, 0, 0, 0, 0],
948 | [0, 0, 0, 1, 1, 0, 0],
949 | [0, 0, 0, 1, 1, 0, 0],
950 | [0, 0, 0, 1, 1, 1, 0],
951 | [0, 0, 0, 0, 0, 0, 0]]]]
952 |
953 |     mask[b, 0, :, :] represents the mask for the example `b` in the batch.
954 | Because mask is for a self-attention layer, the mask's shape is a square of
955 | shape [query length, key length].
956 |
957 |     mask[b, 0, i, j] = 1 means that the query token at position i can attend to
958 | the key token at position j.
959 |
960 | Args:
961 | decoder_target_tokens: decoder output tokens. [batch, length]
962 | dtype: dtype of the output mask.
963 |       decoder_causal_attention: a binary mask with a value of 1 at positions that
964 |         may attend bidirectionally (the "inputs" portion); other positions attend
965 |         only to earlier positions in the sequence. [batch, length]
966 | decoder_segment_ids: decoder segmentation info for packed examples. [batch,
967 | length]
968 |
969 | Returns:
970 | the combined decoder mask.
971 | """
972 | masks = []
973 | # The same mask is applied to all attention heads. So the head dimension is 1,
974 | # i.e., the mask will be broadcast along the heads dim.
975 | # [batch, 1, length, length]
976 | causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype)
977 |
978 |     # Positions with value 1 in `decoder_causal_attention` can attend
979 | # bidirectionally.
980 | if decoder_causal_attention is not None:
981 | # [batch, 1, length, length]
982 | inputs_mask = make_attention_mask(
983 | decoder_causal_attention, decoder_causal_attention, jnp.logical_and, dtype=dtype
984 | )
985 | masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype))
986 | else:
987 | masks.append(causal_mask)
988 |
989 | # Padding mask.
990 | masks.append(make_attention_mask(decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype))
991 |
992 | # Packing mask
993 | if decoder_segment_ids is not None:
994 | masks.append(make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype))
995 |
996 | return combine_masks(*masks, dtype=dtype)
997 |
998 |
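# Illustrative sketch reproducing the packed example from the docstring above
# (values copied from that example):
#
#     mask = make_decoder_mask(
#         decoder_target_tokens=jnp.array([[6, 7, 8, 3, 4, 5, 0]]),
#         dtype=jnp.float32,
#         decoder_causal_attention=jnp.array([[1, 1, 0, 1, 1, 0, 0]]),
#         decoder_segment_ids=jnp.array([[1, 1, 1, 2, 2, 2, 0]]),
#     )
#     # mask.shape == (1, 1, 7, 7); mask[0, 0] equals the 7x7 matrix shown in the
#     # docstring, and the padded position (token 0) can attend to nothing.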
999 | def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
1000 |     """Canonicalizes conv padding to a jax.lax supported format."""
1001 | if isinstance(padding, str):
1002 | return padding
1003 | if isinstance(padding, int):
1004 | return [(padding, padding)] * rank
1005 | if isinstance(padding, Sequence) and len(padding) == rank:
1006 | new_pad = []
1007 | for p in padding:
1008 | if isinstance(p, int):
1009 | new_pad.append((p, p))
1010 | elif isinstance(p, tuple) and len(p) == 2:
1011 | new_pad.append(p)
1012 | else:
1013 | break
1014 | if len(new_pad) == rank:
1015 | return new_pad
1016 | raise ValueError(
1017 | f"Invalid padding format: {padding}, should be str, int,"
1018 | f" or a sequence of len {rank} where each element is an"
1019 | f" int or pair of ints."
1020 | )
1021 |
1022 |
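# Illustrative examples for the canonicalization above (inputs assumed):
#
#     canonicalize_padding("SAME", rank=1)       -> "SAME"
#     canonicalize_padding(2, rank=2)            -> [(2, 2), (2, 2)]
#     canonicalize_padding([1, (2, 3)], rank=2)  -> [(1, 1), (2, 3)]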
1023 | def _conv_dimension_numbers(input_shape):
1024 | """Computes the dimension numbers based on the input shape."""
1025 | ndim = len(input_shape)
1026 | lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
1027 | rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
1028 | out_spec = lhs_spec
1029 | return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
1030 |
1031 |
1032 | class _Conv(nn.Module):
1033 | """Convolution Module wrapping `lax.conv_general_dilated[_local]`.
1034 |
1035 | Attributes:
1036 | features: number of convolution filters.
1037 | kernel_size: shape of the convolutional kernel. For 1D convolution,
1038 | the kernel size can be passed as an integer. For all other cases, it must
1039 | be a sequence of integers.
1040 | strides: an integer or a sequence of `n` integers, representing the
1041 | inter-window strides (default: 1).
1042 | padding: either the string `'SAME'`, the string `'VALID'`, the string
1043 | `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
1044 | high)` integer pairs that give the padding to apply before and after each
1045 |         spatial dimension. A single int is interpreted as applying the same padding
1046 |         in all dims and passing a single int in a sequence causes the same padding
1047 | to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
1048 | left-pad the convolution axis, resulting in same-sized output.
1049 | input_dilation: an integer or a sequence of `n` integers, giving the
1050 | dilation factor to apply in each spatial dimension of `inputs`
1051 | (default: 1). Convolution with input dilation `d` is equivalent to
1052 | transposed convolution with stride `d`.
1053 | kernel_dilation: an integer or a sequence of `n` integers, giving the
1054 | dilation factor to apply in each spatial dimension of the convolution
1055 | kernel (default: 1). Convolution with kernel dilation
1056 | is also known as 'atrous convolution'.
1057 | feature_group_count: integer, default 1. If specified divides the input
1058 | features into groups.
1059 | use_bias: whether to add a bias to the output (default: True).
1060 | mask: Optional mask for the weights during masked convolution. The mask must
1061 | be the same shape as the convolution weight matrix.
1062 | dtype: the dtype of the computation (default: infer from input and params).
1063 | params_dtype: the dtype passed to parameter initializers (default: float32).
1064 | precision: numerical precision of the computation see `jax.lax.Precision`
1065 | for details.
1066 | kernel_init: initializer for the convolutional kernel.
1067 | bias_init: initializer for the bias.
1068 | """
1069 |
1070 | features: int
1071 | kernel_size: Sequence[int]
1072 | strides: Union[None, int, Sequence[int]] = 1
1073 | padding: PaddingLike = "SAME"
1074 | input_dilation: Union[None, int, Sequence[int]] = 1
1075 | kernel_dilation: Union[None, int, Sequence[int]] = 1
1076 | feature_group_count: int = 1
1077 | use_bias: bool = True
1078 | mask: Optional[Array] = None
1079 | dtype: Optional[DType] = None
1080 | params_dtype: DType = jnp.float32
1081 | precision: PrecisionLike = None
1082 | kernel_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.lecun_normal()
1083 | bias_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.zeros
1084 | conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated
1085 | kernel_axes: Tuple[str, ...] = ()
1086 |
1087 | @property
1088 | def shared_weights(self) -> bool: # type: ignore
1089 | """Defines whether weights are shared or not between different pixels.
1090 |
1091 | Returns:
1092 | `True` to use shared weights in convolution (regular convolution).
1093 | `False` to use different weights at different pixels, a.k.a.
1094 | "locally connected layer", "unshared convolution", or "local convolution".
1095 |
1096 | """
1097 | ...
1098 |
1099 | @nn.compact
1100 | def __call__(self, inputs: Array) -> Array:
1101 | """Applies a (potentially unshared) convolution to the inputs.
1102 |
1103 | Args:
1104 | inputs: input data with dimensions (*batch_dims, spatial_dims...,
1105 | features). This is the channels-last convention, i.e. NHWC for a 2d
1106 | convolution and NDHWC for a 3D convolution. Note: this is different from
1107 | the input convention used by `lax.conv_general_dilated`, which puts the
1108 | spatial dimensions last.
1109 | Note: If the input has more than 1 batch dimension, all batch dimensions
1110 | are flattened into a single dimension for the convolution and restored
1111 | before returning. In some cases directly vmap'ing the layer may yield
1112 | better performance than this default flattening approach. If the input
1113 | lacks a batch dimension it will be added for the convolution and removed
1114 |             on return, an allowance made to enable writing single-example code.
1115 |
1116 | Returns:
1117 | The convolved data.
1118 | """
1119 |
1120 | if isinstance(self.kernel_size, int):
1121 | raise TypeError(
1122 | "Expected Conv kernel_size to be a"
1123 | " tuple/list of integers (eg.: [3, 3]) but got"
1124 | f" {self.kernel_size}."
1125 | )
1126 | else:
1127 | kernel_size = tuple(self.kernel_size)
1128 |
1129 | def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> Tuple[int, ...]:
1130 | if x is None:
1131 | # backward compatibility with using None as sentinel for
1132 | # broadcast 1
1133 | x = 1
1134 | if isinstance(x, int):
1135 | return (x,) * len(kernel_size)
1136 | return tuple(x)
1137 |
1138 | # Combine all input batch dimensions into a single leading batch axis.
1139 | num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
1140 | if num_batch_dimensions != 1:
1141 | input_batch_shape = inputs.shape[:num_batch_dimensions]
1142 | total_batch_size = int(np.prod(input_batch_shape))
1143 | flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:]
1144 | inputs = jnp.reshape(inputs, flat_input_shape)
1145 |
1146 | # self.strides or (1,) * (inputs.ndim - 2)
1147 | strides = maybe_broadcast(self.strides)
1148 | input_dilation = maybe_broadcast(self.input_dilation)
1149 | kernel_dilation = maybe_broadcast(self.kernel_dilation)
1150 |
1151 | padding_lax = canonicalize_padding(self.padding, len(kernel_size))
1152 | if padding_lax == "CIRCULAR":
1153 | kernel_size_dilated = [(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)]
1154 | zero_pad: List[Tuple[int, int]] = [(0, 0)]
1155 | pads = zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)]
1156 | inputs = jnp.pad(inputs, pads, mode="wrap")
1157 | padding_lax = "VALID"
1158 | elif padding_lax == "CAUSAL":
1159 | if len(kernel_size) != 1:
1160 | raise ValueError("Causal padding is only implemented for 1D convolutions.")
1161 | left_pad = kernel_dilation[0] * (kernel_size[0] - 1)
1162 | pads = [(0, 0), (left_pad, 0), (0, 0)]
1163 | inputs = jnp.pad(inputs, pads)
1164 | padding_lax = "VALID"
1165 |
1166 | dimension_numbers = _conv_dimension_numbers(inputs.shape)
1167 | in_features = jnp.shape(inputs)[-1]
1168 |
1169 | if self.shared_weights:
1170 | # One shared convolutional kernel for all pixels in the output.
1171 | assert in_features % self.feature_group_count == 0
1172 | kernel_shape = kernel_size + (in_features // self.feature_group_count, self.features)
1173 |
1174 | else:
1175 | if self.feature_group_count != 1:
1176 | raise NotImplementedError(
1177 | f"`lax.conv_general_dilated_local` does not support "
1178 | f"`feature_group_count != 1`, got `{self.feature_group_count}`."
1179 | )
1180 |
1181 | # Need to know the spatial output shape of a standard convolution to
1182 | # create the unshared convolution kernel.
1183 | conv_output_shape = jax.eval_shape(
1184 | lambda lhs, rhs: self.conv_general_dilated( # pylint: disable=g-long-lambda
1185 | lhs=lhs,
1186 | rhs=rhs,
1187 | window_strides=strides,
1188 | padding=padding_lax,
1189 | dimension_numbers=dimension_numbers,
1190 | lhs_dilation=input_dilation,
1191 | rhs_dilation=kernel_dilation,
1192 | ),
1193 | inputs,
1194 | jax.ShapedArray(kernel_size + (in_features, self.features), inputs.dtype),
1195 | ).shape
1196 |
1197 | # One (unshared) convolutional kernel per each pixel in the output.
1198 | kernel_shape = conv_output_shape[1:-1] + (np.prod(kernel_size) * in_features, self.features)
1199 |
1200 | if self.mask is not None and self.mask.shape != kernel_shape:
1201 | raise ValueError(
1202 | "Mask needs to have the same shape as weights. " f"Shapes are: {self.mask.shape}, {kernel_shape}"
1203 | )
1204 |
1205 | kernel = param_with_axes(
1206 | "kernel",
1207 | self.kernel_init,
1208 | kernel_shape,
1209 | self.params_dtype,
1210 | axes=self.kernel_axes,
1211 | )
1212 |
1213 | if self.mask is not None:
1214 | kernel *= self.mask
1215 |
1216 | if self.use_bias:
1217 | if self.shared_weights:
1218 | # One bias weight per output channel, shared between pixels.
1219 | bias_shape = (self.features,)
1220 | else:
1221 |                 # One bias weight per output entry, unshared between pixels.
1222 | bias_shape = conv_output_shape[1:]
1223 |
1224 | bias = param_with_axes(
1225 | "bias",
1226 | self.bias_init,
1227 | bias_shape,
1228 | self.params_dtype,
1229 | axes=(self.kernel_axes[-1],),
1230 | )
1231 | else:
1232 | bias = None
1233 |
1234 | inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
1235 | if self.shared_weights:
1236 | y = self.conv_general_dilated(
1237 | inputs,
1238 | kernel,
1239 | strides,
1240 | padding_lax,
1241 | lhs_dilation=input_dilation,
1242 | rhs_dilation=kernel_dilation,
1243 | dimension_numbers=dimension_numbers,
1244 | feature_group_count=self.feature_group_count,
1245 | precision=self.precision,
1246 | )
1247 | else:
1248 | y = lax.conv_general_dilated_local(
1249 | lhs=inputs,
1250 | rhs=kernel,
1251 | window_strides=strides,
1252 | padding=padding_lax,
1253 | filter_shape=kernel_size,
1254 | lhs_dilation=input_dilation,
1255 | rhs_dilation=kernel_dilation,
1256 | dimension_numbers=dimension_numbers,
1257 | precision=self.precision,
1258 | )
1259 |
1260 | if self.use_bias:
1261 | bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
1262 | y += bias
1263 |
1264 | if num_batch_dimensions != 1:
1265 | output_shape = input_batch_shape + y.shape[1:]
1266 | y = jnp.reshape(y, output_shape)
1267 | return y
1268 |
1269 |
1270 | class Conv(_Conv):
1271 | """Convolution Module wrapping `lax.conv_general_dilated`.
1272 |
1273 | Attributes:
1274 | features: number of convolution filters.
1275 | kernel_size: shape of the convolutional kernel. For 1D convolution,
1276 | the kernel size can be passed as an integer. For all other cases, it must
1277 | be a sequence of integers.
1278 | strides: an integer or a sequence of `n` integers, representing the
1279 | inter-window strides (default: 1).
1280 | padding: either the string `'SAME'`, the string `'VALID'`, the string
1281 | `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
1282 | high)` integer pairs that give the padding to apply before and after each
1283 |         spatial dimension. A single int is interpreted as applying the same padding
1284 |         in all dims and passing a single int in a sequence causes the same padding
1285 | to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
1286 | left-pad the convolution axis, resulting in same-sized output.
1287 | input_dilation: an integer or a sequence of `n` integers, giving the
1288 | dilation factor to apply in each spatial dimension of `inputs`
1289 | (default: 1). Convolution with input dilation `d` is equivalent to
1290 | transposed convolution with stride `d`.
1291 | kernel_dilation: an integer or a sequence of `n` integers, giving the
1292 | dilation factor to apply in each spatial dimension of the convolution
1293 | kernel (default: 1). Convolution with kernel dilation
1294 | is also known as 'atrous convolution'.
1295 | feature_group_count: integer, default 1. If specified divides the input
1296 | features into groups.
1297 | use_bias: whether to add a bias to the output (default: True).
1298 | mask: Optional mask for the weights during masked convolution. The mask must
1299 | be the same shape as the convolution weight matrix.
1300 | dtype: the dtype of the computation (default: infer from input and params).
1301 | params_dtype: the dtype passed to parameter initializers (default: float32).
1302 | precision: numerical precision of the computation see `jax.lax.Precision`
1303 | for details.
1304 | kernel_init: initializer for the convolutional kernel.
1305 | bias_init: initializer for the bias.
1306 | """
1307 |
1308 | @property
1309 | def shared_weights(self) -> bool:
1310 | return True
1311 |
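# Minimal usage sketch for the Conv wrapper above. The kernel axis names below are
# assumptions, chosen only so that `param_with_axes` receives one logical name per
# kernel dimension (a 1D conv kernel has shape (width, in_features, features)):
#
#     conv = Conv(features=8, kernel_size=(3,), padding=1, kernel_axes=("conv_width", "channels", "embed"))
#     x = jnp.ones((2, 16, 4))                                 # (batch, length, channels)
#     variables = conv.init(jax.random.PRNGKey(0), x)
#     y = conv.apply(variables, x)                             # (2, 16, 8)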
--------------------------------------------------------------------------------
/whisper_jax/partitioner.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The T5X Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities for partitioning."""
17 |
18 | import abc
19 | import collections
20 | import dataclasses
21 | import typing
22 | from typing import Any, Callable, Optional, Sequence, Tuple, Union
23 |
24 | import cached_property
25 | import jax
26 | import numpy as np
27 | from absl import logging
28 | from flax import traverse_util
29 | from flax.linen import partitioning as flax_partitioning
30 | from jax import numpy as jnp
31 | from jax import random
32 | from jax.experimental import multihost_utils
33 | from jax.experimental.mesh_utils import create_hybrid_device_mesh
34 | from jax.experimental.pjit import pjit as jax_pjit
35 | from jax.sharding import Mesh, PartitionSpec
36 |
37 |
38 | JaxDevice = Any
39 | TpuMesh = Tuple[int, int, int, int] # (x, y, z, num_cores).
40 | OtherMesh = Tuple[int, int]
41 | HardwareMesh = Union[TpuMesh, OtherMesh]
42 | PyTreeDef = type(jax.tree_util.tree_structure(None))
43 | TrainState = Any
44 | LogicalAxisRules = Sequence[Tuple[str, Optional[str]]]
45 |
46 | if typing.TYPE_CHECKING: # See b/163639353
47 | cached_property = property # pylint: disable=invalid-name
48 | else:
49 | cached_property = cached_property.cached_property
50 |
51 |
52 | class AxisNames(tuple):
53 | """Tuple of strings specifying name for each axis.
54 |
55 | We create a separate class for this so JAX's pytree utilities can distinguish
56 | it from a tuple that should be treated as a pytree, instead treating it as a
57 | leaf.
58 | """
59 |
60 | def __new__(cls, *names):
61 | return tuple.__new__(AxisNames, names)
62 |
63 | def __repr__(self):
64 | return "AxisNames%s" % tuple.__repr__(self)
65 |
66 |
67 | # pjit wrappers for cpu fallback.
68 | # ----------------------------------------------------------------------------
69 | # TODO(levskaya): This function is now no different from jax_pjit, but callers
70 | # currently depend on the `backend` argument.
71 | def pjit(
72 | fun: Callable, # pylint: disable=g-bare-generic
73 | in_axis_resources,
74 | out_axis_resources,
75 | static_argnums: Union[int, Sequence[int]] = (),
76 | donate_argnums: Union[int, Sequence[int]] = (),
77 | backend: Optional[str] = None,
78 | ):
79 | """Wrapper for pjit."""
80 | del backend
81 | return jax_pjit(
82 | fun, in_axis_resources, out_axis_resources, static_argnums=static_argnums, donate_argnums=donate_argnums
83 | )
84 |
85 |
86 | # pjit wrappers for cpu fallback.
87 | # -----------------------------------------------------------------------------
88 | # TODO(levskaya): upstream this fallback behavior to jax pjit.
89 | def pjit_with_cpu_fallback(
90 | fun: Callable, # pylint: disable=g-bare-generic
91 | in_axis_resources,
92 | out_axis_resources,
93 | static_argnums: Union[int, Sequence[int]] = (),
94 | donate_argnums: Union[int, Sequence[int]] = (),
95 | backend: Optional[str] = None,
96 | ):
97 | """Wrapper for pjit that calls normal jit on cpu."""
98 | if jax.devices(backend)[0].platform == "cpu":
99 | return jax.jit(fun, static_argnums=static_argnums, donate_argnums=donate_argnums)
100 | else:
101 | return jax_pjit(
102 | fun, in_axis_resources, out_axis_resources, static_argnums=static_argnums, donate_argnums=donate_argnums
103 | )
104 |
105 |
106 | def with_sharding_constraint(x, axis_resources):
107 | """Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit."""
108 | if jax.devices()[0].platform == "cpu" or not global_mesh_defined():
109 | return x
110 | else:
111 | return jax.experimental.pjit.with_sharding_constraint(x, axis_resources)
112 |
113 |
114 | # pjit Mesh creation functions.
115 | # -----------------------------------------------------------------------------
116 | def bounds_from_last_device(last_device: JaxDevice) -> HardwareMesh:
117 | """Get the bound from the given last device."""
118 | # Must be passed the device at the highest-coordinate corner of the
119 | # relevant mesh, which is a requirement we know is satisfied by the last
120 | # device in jax.devices().
121 | if hasattr(last_device, "coords"):
122 | x, y, z = last_device.coords
123 | return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
124 | else:
125 | # On non-TPU platforms, the "mesh" is hosts x devices per host in order
126 | # to take advantage of faster within-host interconnect.
127 | return jax.host_count(), jax.local_device_count()
128 |
129 |
130 | def get_coords(device: JaxDevice) -> HardwareMesh:
131 | """Returns the coordinates of the given device."""
132 | if hasattr(device, "coords"):
133 | return (*device.coords, device.core_on_chip)
134 | return (device.process_index, device.id % jax.local_device_count())
135 |
136 |
137 | def global_mesh_defined():
138 | """Checks if global xmap/pjit mesh resource environment is defined."""
139 | maps_env = jax.experimental.maps.thread_resources.env
140 | return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison
141 |
142 |
143 | def get_mesh(
144 | model_parallel_submesh: HardwareMesh,
145 | input_devices: Sequence[JaxDevice] = (),
146 | input_local_devices: Sequence[JaxDevice] = (),
147 | tile_by_host_if_needed: bool = True,
148 | backend: Optional[str] = None,
149 | ) -> Mesh:
150 | """Construct an xmap/pjit Mesh for the given model-parallel submesh.
151 |
152 | The resulting mesh has two resource axes: 'model', with the provided submesh
153 | shape, and 'data', which covers the rest of the mesh.
154 |
155 | Args:
156 | model_parallel_submesh: a HardwareMesh spec, namely (x,y,z,core) on TPU for
157 | a single model-parallel replica's "tile" in the physical device mesh. The
158 | first three elements (`x`, `y`, and `z`) should be factors of the pod
159 | slice; e.g., if you are using df_4x8, then `x` should be a factor of 4
160 | (one of 1, 2, 4), `y` should be a factor of 8 (one of 1, 2, 4, 8), and `z`
161 | must be 1, because TPU v3 slices are only 2D. `z` can be >1 for TPU v4
162 | (and maybe later TPUs) that allow 3D slices. `core` is the number of cores
163 | to use from each TPU node. As communication is usually fastest inside the
164 | same node, if you need a tile of more than 1 core, then
165 | you should first increase `core`: e.g., for TPU v3, (1,1,1,2) is better
166 | than (2,1,1,1). To pick a good spec, try a few possible values until you
167 | get high TPU utilization.
168 | input_devices: the devices to use, will use jax.devices() if this is not
169 | set.
170 | input_local_devices: the local devices to use, will use jax.local_devices()
171 | if this is not set.
172 | tile_by_host_if_needed: JAX currently requires that the parts of any sharded
173 | array that are located on one host's local devices form a single
174 | contiguous slice. A best effort will be made to achieve this without
175 | "tiling" the device assignment over hosts (which can reduce XLA collective
176 | performance). If this flag is True, then the device assignment will be
177 | tiled over hosts if necessary to satisfy this constraint and create a
178 | buildable mesh; if false, mesh construction will fail instead.
179 | backend: get devices from the pinned backend, if specified. This is
180 | useful for explicitly specifying the devices other than relying on
181 | jax_platform_name.
182 |
183 | Returns:
184 |       An xmap/pjit Mesh containing the virtual device mesh with data, model axes.
185 | """
186 | input_devices = input_devices or jax.devices(backend)
187 | input_local_devices = input_local_devices or jax.local_devices(0, backend)
188 | # Sort input_devices based on coords, as backends might not return devices
189 | # in order.
190 | last_device = sorted(input_devices, key=get_coords)[-1]
191 | last_input_local_devices = sorted(input_local_devices, key=get_coords)[-1]
192 | logging.info(
193 | "last device coords : %r\nlast local device coords: %r",
194 | get_coords(last_device),
195 | get_coords(last_input_local_devices),
196 | )
197 | global_hardware_mesh = bounds_from_last_device(last_device)
198 | mesh_ndim = len(global_hardware_mesh)
199 | local_hardware_mesh = bounds_from_last_device(last_input_local_devices)
200 | mesh_err = (
201 | f"each dimension of the model parallel submesh {model_parallel_submesh} "
202 | "must be a factor of the corresponding dimension of the global device "
203 | f"mesh {global_hardware_mesh}"
204 | )
205 | assert not any(g % m for g, m in zip(global_hardware_mesh, model_parallel_submesh)), mesh_err
206 | assert not any(g % l for g, l in zip(global_hardware_mesh, local_hardware_mesh))
207 | devices = np.empty(global_hardware_mesh, dtype=object)
208 | for device in input_devices:
209 | device_coords = get_coords(device)
210 | devices[device_coords] = device
211 | tile_by_host = tile_by_host_if_needed
212 | if len(global_hardware_mesh) == 4:
213 | # enable contiguous local chunks without host tiling by making Z major
214 | global_hardware_mesh = typing.cast(Tuple[int, int, int, int], global_hardware_mesh)
215 | model_parallel_submesh = typing.cast(Tuple[int, int, int, int], model_parallel_submesh)
216 | gx, gy, gz, gc = global_hardware_mesh
217 | mx, my, mz, mc = model_parallel_submesh
218 | if (mx == gx > 1 and my == mz == 1) or (mx == 1 and my == gy > 1 and mz == gz > 1):
219 | logging.info("ensuring YZ plane has a Z-major device order")
220 | # YZ should be ZY
221 | assert mc == gc, (mc, gc)
222 | global_hardware_mesh = gx, gz, gy, gc
223 | model_parallel_submesh = mx, mz, my, mc
224 | devices = devices.swapaxes(1, 2)
225 | tile_by_host = False
226 | if (my == gy > 1 and mx == mz == 1) or (my == 1 and mx == gx > 1 and mz == gz > 1):
227 | logging.info("ensuring XZ plane has a Z-major device order")
228 | # XZ should be ZX
229 | assert mc == gc, (mc, gc)
230 | global_hardware_mesh = gz, gy, gx, gc
231 | model_parallel_submesh = mz, my, mx, mc
232 | devices = devices.swapaxes(0, 2)
233 | tile_by_host = False
234 | if tile_by_host:
235 | logging.warning(
236 | "Tiling device assignment mesh by hosts, which may lead to "
237 | "reduced XLA collective performance. To avoid this, modify "
238 | "the model parallel submesh or run with more tasks per host."
239 | )
240 | tile_err = (
241 | "to tile the mesh by hosts, each dimension of the model parallel "
242 | "submesh must be either a factor or a multiple of the corresponding "
243 | "dimension of the per-host submesh"
244 | )
245 |
246 | def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]:
247 | """Split a global mesh dimension into four tiling components.
248 |
249 | Args:
250 | g: global mesh bounds dimension size
251 | m: model-parallel submesh bounds dimension size
252 | l: local submesh bounds dimension size
253 |
254 | Returns:
255 | The resulting tuple divides the dimension into the hosts component of
256 | the data-parallel submesh, the devices component of the data-parallel
257 | submesh, the hosts component of the model-parallel submesh, and the
258 | devices component of the model-parallel submesh.
259 | """
260 | d = g // m
261 | if m >= l:
262 | assert not m % l, tile_err
263 | return (d, 1, m // l, l)
264 | else:
265 | assert not l % m, tile_err
266 | return (d // (l // m), l // m, 1, m)
267 |
268 | # e.g. [(x_data_hosts, x_data_devs, x_model_hosts, x_model_devs), ...]
269 | dh_dd_mh_md_tups = map(dh_dd_mh_md, global_hardware_mesh, model_parallel_submesh, local_hardware_mesh)
270 | # reshape to e.g. (x_dh, x_dd, x_mh, x_md, y_dh, ...)
271 | devices = devices.reshape(*(s for t in dh_dd_mh_md_tups for s in t)) # pylint: disable=g-complex-comprehension
272 | # TODO(jekbradbury): reorder local subgroups for ring locality
273 | # Transpose to [data_host], [data_device], [model_host], [model_device]
274 | # block ordering e.g. (x_dh, y_dh, ..., x_dd, y_dd, ...)
275 | devices = devices.transpose(
276 | *(4 * i for i in range(mesh_ndim)),
277 | *(4 * i + 1 for i in range(mesh_ndim)),
278 | *(4 * i + 2 for i in range(mesh_ndim)),
279 | *(4 * i + 3 for i in range(mesh_ndim)),
280 | )
281 | else:
282 | # e.g. [(x_data, x_model), (y_data, y_model), ...]
283 | model_data_tups = [(g // m, m) for g, m in zip(global_hardware_mesh, model_parallel_submesh)]
284 | # reshape to e.g. (x_data, x_model, y_data, y_model...)
285 | devices = devices.reshape(*(s for t in model_data_tups for s in t)) # pylint: disable=g-complex-comprehension
286 | # TODO(jekbradbury): reorder small subgroups for ring locality
287 | # transpose to e.g. (x_data, y_data, ..., x_model, ...)
288 | devices = devices.transpose(*(2 * i for i in range(mesh_ndim)), *(2 * i + 1 for i in range(mesh_ndim)))
289 | # reshape to (data, model)
290 | devices = devices.reshape(-1, np.prod(model_parallel_submesh))
291 | global_mesh = Mesh(devices, ["data", "model"])
292 | logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
293 | logging.info("global_mesh devices: %s", global_mesh.devices)
294 | logging.info("global_mesh devices shape: %s", global_mesh.devices.shape)
295 | return global_mesh
296 |
297 |
298 | def get_cpu_mesh() -> Mesh:
299 | """Trivial mesh for CPU Testing."""
300 | devices = np.empty((jax.host_count(), jax.local_device_count()), dtype=object)
301 | for device in jax.devices():
302 | devices[device.process_index, device.id % jax.local_device_count()] = device
303 | return Mesh(devices, ["data", "model"])
304 |
305 |
306 | def get_gpu_mesh(num_partitions: int) -> Mesh:
307 | """Mesh for GPUs that preferentially places 'model' on NVLink."""
308 | nvlink_size = jax.local_device_count()
309 | dcn_size = jax.process_count()
310 | nvlink_mp = min(num_partitions, nvlink_size)
311 | nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)
312 | dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)
313 | assert not (extra1 or extra2), (
314 | "number of partitions on GPU must be a factor" " or multiple of the number of local devices"
315 | )
316 | dcn_dp = dcn_size // dcn_mp
317 |
318 | devices = create_hybrid_device_mesh(
319 | mesh_shape=[nvlink_dp, nvlink_mp], dcn_mesh_shape=[dcn_dp, dcn_mp], process_is_granule=True
320 | )
321 |
322 | global_mesh = Mesh(devices, ["data", "model"])
323 | logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
324 | logging.info("global_mesh devices: %s", global_mesh.devices)
325 | return global_mesh
326 |
327 |
328 | def default_mesh(
329 | num_partitions: int, model_parallel_submesh: Optional[HardwareMesh] = None, backend: Optional[str] = None
330 | ) -> Mesh:
331 | """Attempt to return a default mesh for simple cases.
332 |
333 | Args:
334 | num_partitions: number of partitions to use, will be ignored if
335 | model_parallel_submesh is provided.
336 | model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use as
337 | the model-parallel device tile.
338 | backend: get devices from the pinned backend, if specified. This is useful
339 | for explicitly specifying the devices other than relying on
340 | jax_platform_name.
341 |
342 | Returns:
343 | xmap/pjit 2D Mesh with 'data', 'model' mesh axes.
344 | """
345 | last_device = jax.devices(backend)[-1]
346 | platform = last_device.platform
347 | device_kind = last_device.device_kind
348 | bounds = bounds_from_last_device(last_device)
349 |
350 | if model_parallel_submesh:
351 | return get_mesh(model_parallel_submesh, backend=backend)
352 |
353 | if platform == "cpu":
354 | return get_cpu_mesh()
355 | elif platform == "gpu":
356 | return get_gpu_mesh(num_partitions)
357 |
358 | mps = None
359 | if device_kind in ("TPU v2", "TPU v3"):
360 | if num_partitions == 1:
361 | mps = (1, 1, 1, 1)
362 | elif num_partitions == 2:
363 | mps = (1, 1, 1, 2)
364 | elif num_partitions == 4:
365 | mps = (2, 1, 1, 2)
366 | elif num_partitions == 8:
367 | mps = (2, 2, 1, 2)
368 | elif num_partitions == 16:
369 | mps = (4, 2, 1, 2)
370 | # assume the use of megacore on TPU v4
371 | elif (device_kind == "TPU v4" or device_kind == "TPU v4 lite") and bounds[3] == 1:
372 | if num_partitions == 1:
373 | mps = (1, 1, 1, 1)
374 | elif num_partitions == 2:
375 | mps = (1, 2, 1, 1)
376 | elif num_partitions == 4:
377 | if bounds[0] >= 4:
378 | mps = (4, 1, 1, 1)
379 | else:
380 | mps = (2, 2, 1, 1)
381 | elif num_partitions == 8:
382 | if bounds[2] >= 8:
383 | mps = (1, 1, 8, 1)
384 | else:
385 | mps = (4, 2, 1, 1)
386 | elif num_partitions == 16:
387 | if bounds[2] >= 16:
388 | mps = (1, 1, 16, 1)
389 | elif bounds[0] >= 8:
390 | mps = (8, 2, 1, 1)
391 | elif bounds[0] >= 4:
392 | mps = (4, 4, 1, 1)
393 | else:
394 | mps = (2, 2, 4, 1)
395 |
396 | if mps is None:
397 | raise ValueError(
398 | "No default mesh for this configuration: specify " "config.model_parallel_submesh explicitly."
399 | )
400 | return get_mesh(mps, backend=backend)
401 |
402 |
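# Illustrative sketch (hardware-dependent; assumes a single CPU host, in which case
# the trivial mesh from `get_cpu_mesh()` is returned):
#
#     mesh = default_mesh(num_partitions=1)
#     # mesh.axis_names == ("data", "model"); mesh.devices has shape
#     # (jax.host_count(), jax.local_device_count()).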
403 | # Data chunking helper.
404 | # -----------------------------------------------------------------------------
405 | @dataclasses.dataclass
406 | class LocalChunkInfo:
407 | # The logical slice of an array located on this host's local devices.
408 | slice: Tuple[slice, ...]
409 | # A unique index for this host/local chunk among chunks with the same slice.
410 | replica_id: int
411 |
412 |
413 | class LocalChunker:
414 | """Utility class to aid chunking of sharded arrays in multihost settings."""
415 |
416 | def __init__(self, global_mesh: Mesh):
417 | self.global_mesh = global_mesh
418 | local_mesh = global_mesh.local_mesh
419 | first_local_device = local_mesh.devices.reshape(-1)[0]
420 | host_location = collections.OrderedDict(
421 | zip(global_mesh.shape.keys(), list(zip(*np.nonzero(global_mesh.devices == first_local_device)))[0])
422 | )
423 | self.num_chunks = collections.OrderedDict()
424 | self.chunk_ids = collections.OrderedDict()
425 | self.mesh_axes = list(global_mesh.shape.keys())
426 | for mesh_axis in self.mesh_axes:
427 | num_devices_per_chunk = local_mesh.shape[mesh_axis]
428 | self.num_chunks[mesh_axis] = global_mesh.shape[mesh_axis] // num_devices_per_chunk
429 | self.chunk_ids[mesh_axis] = host_location[mesh_axis] // num_devices_per_chunk
430 |
431 | def get_local_chunk_info(
432 | self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
433 | ) -> LocalChunkInfo:
434 | """Get the local chunk info for a given array shape and sharded axes.
435 |
436 | Args:
437 | global_shape: the global, unsharded shape of the array to chunk.
438 | mesh_axes: a sequence of names (or None) of equal rank to `global_shape`
439 | that specifies which mesh dimensions the array is sharded along.
440 |
441 | Returns:
442 | LocalChunkInfo containing the logical slices of the array found on this
443 | host's local devices, as well as the replica index for this chunk among
444 | chunks with the same slice. The latter is used to determine which
445 | host should write this chunk during checkpointing.
446 | """
447 | local_slice = [slice(None) for dim in global_shape]
448 | sharded_mesh_axes = set()
449 | for i, (mesh_axis, size) in enumerate(zip(mesh_axes, global_shape)):
450 | if not mesh_axis:
451 | continue
452 | sharded_mesh_axes.add(mesh_axis)
453 | if not isinstance(mesh_axis, str):
454 | raise NotImplementedError("TODO(jekbradbury)")
455 | chunk_id = self.chunk_ids[mesh_axis]
456 | chunk_size = size // self.num_chunks[mesh_axis]
457 | local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size)
458 |
459 | replicated_mesh_axes = [mesh_axis for mesh_axis in self.mesh_axes if mesh_axis not in sharded_mesh_axes]
460 | replica_id = 0
461 | for mesh_axis in replicated_mesh_axes:
462 | chunk_id = self.chunk_ids[mesh_axis]
463 | replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id
464 |
465 | return LocalChunkInfo(tuple(local_slice), replica_id)
466 |
467 |
468 | def standard_logical_axis_rules(
469 | activation_partitioning_dims: int = 1,
470 | parameter_partitioning_dims: int = 1,
471 | additional_rules: Optional[LogicalAxisRules] = None,
472 | ) -> LogicalAxisRules:
473 | """Default sharding rules for T5X model in terms of logical axis names.
474 |
475 | Args:
476 | activation_partitioning_dims: enables 2-D activation sharding when set to 2.
477 | parameter_partitioning_dims: enables 2-D parameter sharding when set to 2.
478 | additional_rules: additional rules (a sequence of tuples) that will be
479 | appended to the standard rules.
480 |
481 | Returns:
482 | Sequence of logical axis rules
483 | """
484 | logging.info(
485 | "`activation_partitioning_dims` = %d, `parameter_partitioning_dims` = %d",
486 | activation_partitioning_dims,
487 | parameter_partitioning_dims,
488 | )
489 |
490 | if activation_partitioning_dims == 1 and parameter_partitioning_dims == 1:
491 | rules = [
492 | ("batch", "data"),
493 | ("vocab", "model"),
494 | ("embed", None),
495 | ("mlp", "model"),
496 | ("heads", "model"),
497 | ("kv", None),
498 | ("joined_kv", "model"), # joined heads+kv dim in 2D attn param layouts
499 | ]
500 | elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 1:
501 | rules = [
502 | ("batch", "data"),
503 | ("vocab", "model"),
504 | ("mlp", "model"),
505 | ("heads", "model"),
506 | ("kv", None),
507 | ("joined_kv", "model"),
508 | ("embed", "model"),
509 | ]
510 | elif activation_partitioning_dims == 1 and parameter_partitioning_dims == 2:
511 | rules = [
512 | ("batch", "data"),
513 | ("vocab", "model"),
514 | ("mlp", "model"),
515 | ("heads", "model"),
516 | ("kv", None),
517 | ("joined_kv", "model"),
518 | ("embed", "data"),
519 | ]
520 | elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 2:
521 | rules = [
522 | ("batch", "data"),
523 | ("vocab", "model"),
524 | ("mlp", "model"),
525 | ("heads", "model"),
526 | ("kv", None),
527 | ("joined_kv", "model"),
528 | ("embed", "model"),
529 | ("embed", "data"),
530 | ]
531 | else:
532 | raise ValueError(
533 | f"`activation_partitioning_dims` = {activation_partitioning_dims} "
534 | f"`parameter_partitioning_dims` = {parameter_partitioning_dims} "
535 | "is not supported."
536 | )
537 |
538 | # Add the common rules for the replicated logical axes names.
539 | replicated_rules = [
540 | ("relpos_buckets", None),
541 | ("abspos_buckets", None),
542 | ("length", None),
543 | ("layers", None),
544 | ("stack", None),
545 | ("mlp_activations", None),
546 | ]
547 | rules.extend(replicated_rules)
548 |
549 | if additional_rules:
550 | rules.extend(additional_rules)
551 |
552 | return rules
553 |
554 |
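# Illustrative sketch: resolving logical axis names to mesh axes with the default
# rules (the mappings come from the 1-D/1-D branch above):
#
#     rules = standard_logical_axis_rules()
#     spec = flax_partitioning.logical_to_mesh_axes(("batch", "embed"), rules)
#     # spec == PartitionSpec("data", None): such an activation is sharded along the
#     # "data" mesh axis and replicated along "model".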
555 | # NB: This needs to be top-level for the jax compilation cache.
556 | def _id_fn(x, ix):
557 | """Identity function for copying parameters to the devices, sharded."""
558 | # A pure identity such as `lambda x, *: x` can get optimized away, so we
559 | # include a random.split as a cheap function that cannot be optimized away.
560 | y = random.split(random.PRNGKey(jnp.array(ix, dtype=jnp.uint32)))
561 | return x, y
562 |
563 |
564 | @dataclasses.dataclass
565 | class DataLayout:
566 | """Represents data layout for the partitioned model."""
567 |
568 | batch_size: int
569 | shard_id: int
570 | num_shards: int
571 | is_first_host_in_replica_set: bool
572 |
573 |
574 | PartitionedCallable = Callable[..., Any]
575 | CompiledPartitionedCallable = Callable[..., Any]
576 |
577 |
578 | class BasePartitioner(metaclass=abc.ABCMeta):
579 | """Interface for partitioning computations across hardware devices."""
580 |
581 | def __init__(
582 | self,
583 | num_partitions: Optional[int] = None,
584 | model_parallel_submesh: Optional[HardwareMesh] = None,
585 | params_on_devices: bool = True,
586 | backend: Optional[str] = None,
587 | ):
588 | """Configures the partitioner.
589 |
590 | Args:
591 | num_partitions: the number of partitions to use. Ignored if
592 | `model_parallel_submesh` is provided.
593 | model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use
594 | as the model-parallel device tile. This submesh is used for the larger
595 | of the two parameter dimensions, and, if 2-D activation sharding is
596 | enabled, for the model dimension of activations. The rest of the mesh is
597 | used for data parallelism and, if 2-D parameter sharding is enabled, the
598 | other parameter dimension.
599 | params_on_devices: whether to keep the params on devices, if False -
600 | params stay in the host memory. Note that some partitioners might ignore
601 | this setting, for example if they don't support storing all params on
602 | device memory.
603 | backend: get devices from the pinned backend, if specified. This is useful
604 | for explicitly specifying the devices other than relying on
605 | jax_platform_name.
606 | """
607 |
608 | if not num_partitions and not model_parallel_submesh:
609 | raise ValueError("At least one of `num_partitions` or " "`model_parallel_submesh` must be set.")
610 |
611 | if model_parallel_submesh is not None and len(model_parallel_submesh) != 4:
612 | logging.error(
613 | (
614 | "`model_parallel_submesh` must be either None or a 4-tuple. Got"
615 | " `model_parallel_submesh`=%s. A ValueError will be raised"
616 | " beginning March 1, 2022."
617 | ),
618 | model_parallel_submesh,
619 | )
620 |
621 | if bool(num_partitions) and bool(model_parallel_submesh):
622 | logging.error(
623 | "At most one of `num_partitions` or `model_parallel_submesh` can be "
624 | "set. Got `num_partitions=%s` and `model_parallel_submesh`=%s. A "
625 | "ValueError will be raised beginning March 21, 2022.",
626 | num_partitions,
627 | model_parallel_submesh,
628 | )
629 |
630 | self._num_partitions = num_partitions
631 | self._model_parallel_submesh = model_parallel_submesh
632 | self._params_on_devices = params_on_devices
633 | self._data_axis = "data"
634 | self._backend = backend
635 |
636 | @property
637 | def mesh(self) -> Mesh:
638 | raise NotImplementedError
639 |
640 | @property
641 | def data_partition_spec(self) -> PartitionSpec:
642 | return PartitionSpec(self._data_axis)
643 |
644 | def get_data_layout(self, batch_size: Optional[int] = None, host_index: Optional[int] = None) -> DataLayout:
645 | """Returns filled `DataLayout` based on the partitioned model layout.
646 |
647 | Args:
648 | batch_size: if set, indicates the requested batch size. The exception will
649 | be raised if this batch size is not compatible with the layout. If not
650 | set, the batch size is inferred from the layout.
651 | host_index: indicates the host index to use for the calculations, if not
652 | set - use JAX-provided one. Should be in [0, num_hosts) interval and the
653 | order should match the order of corresponding CPU devices in
654 | `jax.devices()`.
655 |
656 | Returns:
657 | Filled `DataLayout` structure.
658 | """
659 | if host_index is not None:
660 | raise NotImplementedError("Explicit host_index is not yet implemented.")
661 | if self._data_axis is None:
662 | return DataLayout(
663 | batch_size=batch_size,
664 | shard_id=0,
665 | num_shards=1,
666 | is_first_host_in_replica_set=(jax.process_index() == 0),
667 | )
668 | mesh_size = self._local_chunker.global_mesh.shape[self._data_axis]
669 | batch_size = batch_size or mesh_size
670 | if batch_size % mesh_size:
671 | raise ValueError(
672 | f"Batch size ({batch_size}) must be divisible by corresponding " f"mesh size ({mesh_size})."
673 | )
674 | num_shards = self._local_chunker.num_chunks[self._data_axis]
675 | if batch_size % num_shards:
676 | raise ValueError(f"Batch size ({batch_size}) must be divisible by number of " f"replicas ({num_shards}).")
677 | replica_id = self._local_chunker.get_local_chunk_info((batch_size,), [self._data_axis]).replica_id
678 | return DataLayout(
679 | batch_size=int(batch_size),
680 | shard_id=int(self._local_chunker.chunk_ids[self._data_axis]),
681 | num_shards=int(num_shards),
682 | is_first_host_in_replica_set=(replica_id == 0),
683 | )
684 |
685 | def get_local_chunk_info(
686 | self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
687 | ) -> LocalChunkInfo:
688 | """Returns the local chunk info for a given array shape and sharded axes."""
689 | return self._local_chunker.get_local_chunk_info(global_shape, mesh_axes)
690 |
691 | @property
692 | def params_on_devices(self):
693 | return self._params_on_devices
694 |
695 | def move_params_to_devices(self, train_state: TrainState, train_state_axes: TrainState) -> TrainState:
696 | """Moves the optimizer parameters to devices."""
697 | p_id_fn = self.partition(
698 | _id_fn,
699 | in_axis_resources=(train_state_axes, None),
700 | out_axis_resources=(train_state_axes, None),
701 | donate_argnums=(0,),
702 | )
703 | if jax.config.jax_array and jax.process_count() > 1:
704 | train_state = multihost_utils.host_local_array_to_global_array(train_state, self.mesh, train_state_axes)
705 | train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
706 | return train_state
707 |
708 | @property
709 | @abc.abstractmethod
710 | def _local_chunker(self):
711 | """Returns the chunker that matches the parameters of this partitioner."""
712 | raise NotImplementedError
713 |
714 | def get_logical_axes(self, train_state: TrainState) -> TrainState:
715 | """Returns a copy of TrainState with Optional[AxisNames] as leaves."""
716 | # By default, return None for the logical axes.
717 | return train_state.restore_state(jax.tree_map(lambda x: None, train_state.state_dict()))
718 |
719 | def get_mesh_axes(self, train_state: TrainState) -> TrainState:
720 | """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
721 | raise NotImplementedError
722 |
723 | @abc.abstractmethod
724 | def partition(
725 | self,
726 | fn: Callable, # pylint: disable=g-bare-generic
727 | in_axis_resources,
728 | out_axis_resources,
729 | static_argnums: Union[int, Sequence[int]] = (),
730 | donate_argnums: Union[int, Sequence[int]] = (),
731 | ) -> PartitionedCallable:
732 | """Partitions the computation using partitioner-specific implementation.
733 |
734 | Args:
735 | fn: the function to partition.
736 | in_axis_resources: Pytree of structure matching that of arguments to `fn`,
737 | with all actual arguments replaced by resource assignment
738 | specifications. It is also valid to specify a pytree prefix (e.g. one
739 | value in place of a whole subtree), in which case the leaves get
740 | broadcast to all values in that subtree.
741 | The valid resource assignment specifications are:
742 | `None`: in which case the value will be replicated on all devices
743 | `PartitionSpec`: a tuple of length at most equal to the rank of the
744 | partitioned value. Each element can be a `None`, a mesh axis or a
745 | tuple of mesh axes, and specifies the set of resources assigned to
746 | partition the value's dimension matching its position in the spec.
747 | out_axis_resources: Like `in_axis_resources`, but specifies resource
748 | assignment for function outputs.
749 | static_argnums: an optional int or collection of ints that specify which
750 | positional arguments to treat as static (compile-time constant) in the
751 | partitioned function.
752 | donate_argnums: an optional int or collection of ints that specify which
753 | argument buffers are "donated" to the computation. It is safe to donate
754 | argument buffers if you no longer need them once the computation has
755 | finished.
756 |
757 | Returns:
758 | A partitioned version of the input function.
759 | """
760 | raise NotImplementedError
761 |
762 | @abc.abstractmethod
763 | def compile(self, partitioned_fn: PartitionedCallable, *args) -> CompiledPartitionedCallable:
764 | """Compiles and returns the partitioned function, or the original.
765 |
766 | Args:
767 | partitioned_fn: The partitioned function.
768 | *args: Sample arguments to the partitioned function matching the input
769 | shapes that will be passed to the compiled function.
770 |
771 | Returns:
772 | The compiled function, or the original if this partitioner does not
773 | support compilation.
774 | """
775 | raise NotImplementedError
776 |
777 |
778 | class PjittedFnWithContext(PartitionedCallable):
779 | """Wraps pjitted function to apply the appropriate contexts."""
780 |
781 | def __init__(self, pjitted_fn, partition_mesh: Mesh, logical_axis_rules: flax_partitioning.LogicalRules = ()):
782 | self._pjitted_fn = pjitted_fn
783 | self._mesh = partition_mesh
784 | self._logical_axis_rules = logical_axis_rules
785 |
786 | def __call__(self, *args):
787 | with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
788 | return self._pjitted_fn(*args)
789 |
790 | def lower(self, *args):
791 | with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
792 | return self._pjitted_fn.lower(*args)
793 |
794 |
795 | class BasePjitPartitioner(BasePartitioner):
796 | """Partitioner that uses T5X version of jax.pjit."""
797 |
798 | @cached_property
799 | def _local_chunker(self) -> LocalChunker:
800 | return LocalChunker(self.mesh)
801 |
802 | @cached_property
803 | def mesh(self) -> Mesh:
804 | return default_mesh(self._num_partitions, self._model_parallel_submesh, self._backend)
805 |
806 | def partition(
807 | self,
808 | fn: Callable, # pylint: disable=g-bare-generic
809 | in_axis_resources,
810 | out_axis_resources,
811 | static_argnums: Union[int, Sequence[int]] = (),
812 | donate_argnums: Union[int, Sequence[int]] = (),
813 | ) -> PjittedFnWithContext:
814 | pjitted = pjit(
815 | fn,
816 | in_axis_resources=in_axis_resources,
817 | out_axis_resources=out_axis_resources,
818 | static_argnums=static_argnums,
819 | donate_argnums=donate_argnums,
820 | backend=self._backend,
821 | )
822 |
823 | return PjittedFnWithContext(pjitted, self.mesh)
824 |
825 | def compile(self, partitioned_fn: PjittedFnWithContext, *args) -> CompiledPartitionedCallable:
826 | return partitioned_fn.lower(*args).compile()
827 |
828 |
829 | class PjitPartitioner(BasePjitPartitioner):
830 | """Partitioner that uses named axes and jax.pjit."""
831 |
832 | def __init__(
833 | self,
834 | num_partitions: Optional[int] = None,
835 | model_parallel_submesh: Optional[HardwareMesh] = None,
836 | params_on_devices: bool = True,
837 | backend: Optional[str] = None,
838 | logical_axis_rules: Optional[LogicalAxisRules] = None,
839 | use_cpu_pjit: Optional[bool] = False,
840 | ):
841 | """PjitPartitioner constructor.
842 |
843 | See https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/partitioning for details.
844 |
845 | Args:
846 | num_partitions: an integer that specifies the size of the model parallel
847 | submesh to be automatically selected for the current topology. See
848 | `model_parallel_submesh` for details on how this submesh is used.
849 |         Mutually exclusive with `model_parallel_submesh`.
850 | model_parallel_submesh: is a 4-tuple that specifies the `(x, y, z, c)`
851 | submesh model-parallel device tile, an axis of accelerator parallelism
852 | orthogonal to data parallelism. Array axes in a model's parameters or
853 | activations can be sharded over this submesh using axis rules (see
854 | `logical_axis_rules`) that map them to 'model'. The effective number of
855 | model sub-partitions is equal to `np.prod(model_parallel_submesh)` and
856 | must evenly divide the total number of devices (i.e.,
857 | `jax.device_count() % np.prod(model_parallel_submesh) == 0`). The rest
858 | of the TPU mesh is the data parallel submesh, providing
859 | `jax.device_count() // np.prod(model_parallel_submesh)` partitions. It
860 | is used for data (batch) parallelism and to shard other array axes that
861 | are mapped to 'data'. This argument is mutually exclusive with
862 | `num_partitions`.
863 |       params_on_devices: whether to keep the params on devices; if False, the
864 |         params stay in host memory. Note that some partitioners might ignore
865 | this setting, for example if they don't support storing all params on
866 | device memory.
867 | backend: get devices from the pinned backend, if specified. This is
868 |         useful for explicitly specifying the devices rather than relying on
869 | jax_platform_name.
870 | logical_axis_rules: a priority-ordered sequence of KV tuples that maps
871 | logical axis names to either `None` (not sharded), 'model' (to shard
872 | across the model-parallel submesh), or 'data' (to shard across the
873 | data-parallel submesh).
874 |       use_cpu_pjit: enables a wrapper around pjit that simply jits the
875 |         function when using the CPU backend.
876 | """
877 | super().__init__(
878 | num_partitions=num_partitions,
879 | model_parallel_submesh=model_parallel_submesh,
880 | params_on_devices=params_on_devices,
881 | backend=backend,
882 | )
883 | if logical_axis_rules is None:
884 | logical_axis_rules = standard_logical_axis_rules()
885 | self._logical_axis_rules = tuple(logical_axis_rules)
886 | (self._data_axis,) = flax_partitioning.logical_to_mesh_axes(["batch"], logical_axis_rules)
887 | self._use_cpu_pjit = use_cpu_pjit
888 |
889 | def partition(
890 | self,
891 | fn: Callable, # pylint: disable=g-bare-generic
892 | in_axis_resources,
893 | out_axis_resources,
894 | static_argnums: Union[int, Sequence[int]] = (),
895 | donate_argnums: Union[int, Sequence[int]] = (),
896 | ) -> PjittedFnWithContext:
897 | """Partitions the function using jax.pjit."""
898 | if self._use_cpu_pjit:
899 | pjit_fn = pjit_with_cpu_fallback
900 | else:
901 | pjit_fn = pjit
902 | pjitted = pjit_fn(
903 | fn,
904 | in_axis_resources=in_axis_resources,
905 | out_axis_resources=out_axis_resources,
906 | static_argnums=static_argnums,
907 | donate_argnums=donate_argnums,
908 | backend=self._backend,
909 | )
910 |
911 | return PjittedFnWithContext(pjitted, self.mesh, self._logical_axis_rules)
912 |
913 | @property
914 | def logical_axis_rules(self):
915 | """Returns the logical axis rules."""
916 | return self._logical_axis_rules
917 |
918 | def get_logical_axes(self, train_state: TrainState) -> TrainState:
919 | """Returns a copy of TrainState with Optional[AxisNames] as leaves."""
920 | return train_state.as_logical_axes()
921 |
922 | def get_mesh_axes(self, train_state: TrainState) -> TrainState:
923 | """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
924 | logical_axes = self.get_logical_axes(train_state)
925 |
926 | def _logical_to_mesh_axes(param_name, logical_axes):
927 | if logical_axes is None:
928 | return None
929 | elif logical_axes is traverse_util.empty_node:
930 | return traverse_util.empty_node
931 | try:
932 | return flax_partitioning.logical_to_mesh_axes(logical_axes, self._logical_axis_rules)
933 | except ValueError as e:
934 | raise ValueError(f"Failed to map logical axes for {param_name}") from e
935 |
936 | flat_logical_axes = traverse_util.flatten_dict(logical_axes.state_dict(), keep_empty_nodes=True, sep="/")
937 | flat_mesh_axes = {k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()}
938 |
939 | return logical_axes.restore_state(traverse_util.unflatten_dict(flat_mesh_axes, sep="/"))
940 |
--------------------------------------------------------------------------------
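
A minimal, hypothetical usage sketch for the partitioner above (not part of the repository files): it assumes the `whisper_jax` package is importable, that at least one JAX device is visible, and that the leading batch dimension divides the size of the data-parallel mesh axis. The function `scale` and the shapes are illustrative only.

    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P

    from whisper_jax.partitioner import PjitPartitioner

    # num_partitions=1 -> no model parallelism; the whole mesh is the 'data' axis.
    partitioner = PjitPartitioner(num_partitions=1)

    def scale(x):
        return 2.0 * x

    # Shard the batch (first) dimension of the input and output over the 'data' axis.
    p_scale = partitioner.partition(
        scale,
        in_axis_resources=(P("data"),),
        out_axis_resources=P("data"),
    )

    x = jnp.ones((8, 16))
    y = p_scale(x)  # PjittedFnWithContext enters the mesh and axis-rules contexts for us
    print(y.shape)  # (8, 16)
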
/whisper_jax/pipeline.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import math
18 |
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | import requests
23 | from flax import jax_utils
24 | from flax.core.frozen_dict import freeze
25 | from flax.training.common_utils import shard
26 | from jax.sharding import PartitionSpec as P
27 | from transformers import WhisperProcessor, is_tokenizers_available, WhisperFeatureExtractor, WhisperTokenizerFast
28 | from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, WhisperTokenizer
29 | from transformers.pipelines.audio_utils import ffmpeg_read
30 | from transformers.utils import logging
31 |
32 | from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
33 | from .partitioner import PjitPartitioner
34 | from .train_state import InferenceState
35 |
36 |
37 | logger = logging.get_logger(__name__)
38 |
39 | # 2D parameter and activation partitioning for DP
40 | logical_axis_rules_dp = (
41 | ("batch", "data"),
42 | ("mlp", None),
43 | ("heads", None),
44 | ("vocab", None),
45 | ("embed", None),
46 | ("embed", None),
47 | ("joined_kv", None),
48 | ("kv", None),
49 | ("length", None),
50 | ("num_mel", None),
51 | ("channels", None),
52 | )
53 |
54 |
55 | class FlaxWhisperPipline:
56 | def __init__(
57 | self,
58 | checkpoint="openai/whisper-large-v2",
59 | dtype=jnp.float32,
60 | batch_size=None,
61 | max_length=None,
62 | ):
63 | """
64 |         Args:
65 |             checkpoint (`str`, *optional*, defaults to `"openai/whisper-large-v2"`):
66 | The Whisper checkpoint to use with the pipeline. Must be an available checkpoint on the Hugging Face Hub
67 | with Flax weights.
68 | dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
69 | The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
70 | `jax.numpy.bfloat16` (on TPUs). This can be used to enable half-precision inference on GPUs or TPUs.
71 |                 If specified, all the computation will be performed with the given `dtype`. **Note that this only
72 | specifies the dtype of the computation and does not influence the dtype of model parameters.**
73 | batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
74 |                 The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. A batch
75 |                 size passed to the `__init__` method will be superseded by any batch size passed to the `__call__` method.
76 | max_length (`int`, *optional*):
77 | The maximum numbers of tokens to generate. Defaults to `model.config.max_length`.
78 | """
79 | self.checkpoint = checkpoint
80 | self.dtype = dtype
81 |
82 | self.processor = WhisperProcessor.from_pretrained(self.checkpoint)
83 | self.feature_extractor = self.processor.feature_extractor
84 | # potentially load fast tokenizer if available
85 | tokenizer_cls = WhisperTokenizerFast if is_tokenizers_available() else WhisperTokenizer
86 | self.tokenizer = tokenizer_cls.from_pretrained(checkpoint)
87 |
88 | self.model, self.params = FlaxWhisperForConditionalGeneration.from_pretrained(
89 | self.checkpoint,
90 | _do_init=False,
91 | dtype=self.dtype,
92 | )
93 |
94 | self.max_length = max_length if max_length is not None else self.model.generation_config.max_length
95 | self.min_batch_size = jax.local_device_count()
96 | self.batch_size = (
97 | batch_size if batch_size is not None else self.min_batch_size
98 | ) # we need a minimum of 1 batch per-device
99 |
100 | def generate(params, input_features, forced_decoder_ids, return_timestamps):
101 | output_ids = self.model.pipeline_generate(
102 | input_features,
103 | params=params,
104 | forced_decoder_ids=forced_decoder_ids,
105 | return_timestamps=return_timestamps,
106 | max_length=self.max_length,
107 | )
108 | return output_ids
109 |
110 |         # use pmap for DP by default - this is compatible with a Colab TPU v2
111 | self.params = jax_utils.replicate(self.params)
112 | self.p_generate = jax.pmap(
113 | generate, "input_features", in_axes=(0, 0, None), out_axes=0, static_broadcasted_argnums=(3,)
114 | )
115 | self.is_sharded = False
116 |
117 | def shard_params(self, num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp):
118 | def init_fn():
119 | input_shape = (1, self.model.config.num_mel_bins, 2 * self.model.config.max_source_positions)
120 |
121 | input_features = jnp.zeros(input_shape, dtype="f4")
122 | input_features = input_features.at[(..., -1)].set(self.model.config.eos_token_id)
123 |
124 | decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
125 | decoder_attention_mask = jnp.ones_like(decoder_input_ids)
126 |
127 | batch_size, sequence_length = decoder_input_ids.shape
128 | decoder_position_ids = jnp.broadcast_to(
129 | jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
130 | )
131 |
132 | rng = jax.random.PRNGKey(0)
133 | init_params = self.model.module.init(
134 | rng,
135 | input_features=input_features,
136 | decoder_input_ids=decoder_input_ids,
137 | decoder_attention_mask=decoder_attention_mask,
138 | decoder_position_ids=decoder_position_ids,
139 | return_dict=False,
140 | )
141 | return init_params
142 |
143 | # Axis names metadata
144 | param_axes = jax.eval_shape(init_fn)["params_axes"]
145 |
146 | # Create InferenceState, since the partitioner expects it
147 | state = InferenceState(
148 | step=jnp.array(0),
149 | params=freeze(self.model.params_shape_tree),
150 | params_axes=freeze(param_axes),
151 | flax_mutables=None,
152 | flax_mutables_axes=param_axes,
153 | )
154 |
155 | partitioner = PjitPartitioner(num_partitions=num_mp_partitions, logical_axis_rules=logical_axis_rules)
156 |
157 | mesh_axes = partitioner.get_mesh_axes(state)
158 | params_spec = mesh_axes.params
159 |
160 | p_shard_params = partitioner.partition(self.model.to_bf16, (params_spec,), params_spec)
161 |
162 | # This will auto-magically run in mesh context
163 | self.params = p_shard_params(freeze(jax_utils.unreplicate(self.params)))
164 | self.is_sharded = True
165 |
166 | def generate(params, input_features, forced_decoder_ids, return_timestamps):
167 | output_ids = self.model.pipeline_generate(
168 | input_features,
169 | params=params,
170 | forced_decoder_ids=forced_decoder_ids,
171 | return_timestamps=return_timestamps,
172 | max_length=self.max_length,
173 | )
174 | return output_ids
175 |
176 | # Use pjit for generate only once we've sharded the params
177 | self.p_generate = partitioner.partition(
178 | generate,
179 | in_axis_resources=(params_spec, P("data"), None),
180 | out_axis_resources=P("data"),
181 | static_argnums=(3,),
182 | )
183 |
184 | def generate(self, input_features, language=None, task=None, return_timestamps=False):
185 | forced_decoder_ids = self.get_forced_decoder_ids(
186 | language=language, task=task, return_timestamps=return_timestamps
187 | )
188 | if not self.is_sharded:
189 | # if we're using pmap we need to manually replicate the input data across devices and gather the output tokens
190 | output_ids = self.p_generate(
191 | freeze(self.params), shard(input_features), forced_decoder_ids, return_timestamps
192 | ).sequences
193 | output_ids = jax.device_get(output_ids.reshape(-1, self.max_length))
194 | else:
195 | # pjit handles replication / gathering for us auto-magically
196 | output_ids = self.p_generate(
197 | freeze(self.params), input_features, forced_decoder_ids, return_timestamps
198 | ).sequences
199 | return output_ids
200 |
201 | def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False):
202 | if generation_config is None:
203 | generation_config = self.model.generation_config
204 |
205 | if hasattr(generation_config, "is_multilingual"):
206 | is_multilingual = generation_config.is_multilingual
207 | else:
208 | is_multilingual = None
209 |
210 | forced_decoder_ids = []
211 |
212 | if is_multilingual:
213 | if language is not None:
214 | language = language.lower()
215 | if language in generation_config.lang_to_id.keys():
216 | language_token = language
217 | elif language in TO_LANGUAGE_CODE.values():
218 | language_token = f"<|{language}|>"
219 | elif language in TO_LANGUAGE_CODE.keys():
220 | language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
221 | else:
222 | if len(language) == 2:
223 | # ISO 639-1 language code
224 | acceptable_languages = list(TO_LANGUAGE_CODE.values())
225 | elif "<" in language or "|" in language or ">" in language:
226 | # generation config language code
227 | acceptable_languages = list(generation_config.lang_to_id.keys())
228 | else:
229 | # language passed as a string
230 | acceptable_languages = list(TO_LANGUAGE_CODE.keys())
231 | raise ValueError(
232 | f"Unsupported language: {language}. Language should be one of:" f" {acceptable_languages}."
233 | )
234 | forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
235 |
236 | if task is not None:
237 | forced_decoder_ids.append((2, generation_config.task_to_id[task]))
238 | else:
239 | forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
240 |
241 | if not return_timestamps:
242 |             if forced_decoder_ids and forced_decoder_ids[-1][1] != generation_config.no_timestamps_token_id:
243 | idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
244 | forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
245 |
246 | return forced_decoder_ids
247 |
248 | def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
249 | inputs_len = inputs.shape[0]
250 | step = chunk_len - stride_left - stride_right
251 |
252 | all_chunk_start_idx = np.arange(0, inputs_len, step)
253 | num_samples = len(all_chunk_start_idx)
254 |
255 | num_batches = math.ceil(num_samples / batch_size)
256 | batch_idx = np.array_split(np.arange(num_samples), num_batches)
257 |
258 | for idx in batch_idx:
259 | chunk_start_idx = all_chunk_start_idx[idx]
260 | chunk_end_idx = chunk_start_idx + chunk_len
261 |
262 | chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
263 | processed = self.feature_extractor(
264 | chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
265 | )
266 |
267 | _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
268 | is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
269 | _stride_right = np.where(is_last, 0, stride_right)
270 |
271 | chunk_lens = [chunk.shape[0] for chunk in chunks]
272 | strides = [
273 | (chunk_l, _stride_l, _stride_r)
274 | for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
275 | ]
276 |
277 | yield {"stride": strides, **processed}
278 |
279 | def preprocess_batch(self, inputs, chunk_length_s=30.0, stride_length_s=None, batch_size=None):
280 | if isinstance(inputs, np.ndarray):
281 | logger.warning(
282 |                 "Numpy array passed as input - no sampling rate checks will be performed. "
283 |                 "It is strongly recommended to pass the input as a dictionary with an 'array' key "
284 |                 "containing the numpy array representing the audio, and a 'sampling_rate' key "
285 |                 "containing the sampling rate associated with the audio array. "
286 | "Failing to do so can result in silent errors that might be hard to debug."
287 | )
288 |
289 | if isinstance(inputs, str):
290 | if inputs.startswith("http://") or inputs.startswith("https://"):
291 | # We need to actually check for a real protocol, otherwise it's impossible to use a local file
292 | # like http_huggingface_co.png
293 | inputs = requests.get(inputs).content
294 | else:
295 | with open(inputs, "rb") as f:
296 | inputs = f.read()
297 |
298 | if isinstance(inputs, bytes):
299 | inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
300 |
301 | stride = None
302 | if isinstance(inputs, dict):
303 | stride = inputs.get("stride", None)
304 | # Accepting `"array"` which is the key defined in `datasets` for
305 | # better integration
306 | if not ("sampling_rate" in inputs and "array" in inputs):
307 | raise ValueError(
308 | "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain an 'array' key "
309 | "containing the numpy array representing the audio, and a 'sampling_rate' key "
310 | "containing the sampling rate associated with the audio array."
311 | )
312 |
313 | in_sampling_rate = inputs.get("sampling_rate")
314 | inputs = inputs.get("array", None)
315 |
316 | if in_sampling_rate != self.feature_extractor.sampling_rate:
317 | try:
318 | import librosa
319 | except ImportError as err:
320 | raise ImportError(
321 | "To support resampling audio files, please install 'librosa' and 'soundfile'."
322 | ) from err
323 |
324 | inputs = librosa.resample(
325 | inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
326 | )
327 | ratio = self.feature_extractor.sampling_rate / in_sampling_rate
328 | else:
329 | ratio = 1
330 |
331 | if not isinstance(inputs, np.ndarray):
332 | raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
333 | if len(inputs.shape) != 1:
334 | raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
335 |
336 | if stride is not None:
337 | if stride[0] + stride[1] > inputs.shape[0]:
338 | raise ValueError("Stride is too large for input")
339 |
340 | # Stride needs to get the chunk length here, it's going to get
341 | # swallowed by the `feature_extractor` later, and then batching
342 | # can add extra data in the inputs, so we need to keep track
343 | # of the original length in the stride so we can cut properly.
344 | stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
345 |
346 | if chunk_length_s:
347 | if stride_length_s is None:
348 | stride_length_s = chunk_length_s / 6
349 |
350 | if isinstance(stride_length_s, (int, float)):
351 | stride_length_s = [stride_length_s, stride_length_s]
352 |
353 | chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
354 | stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
355 | stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)
356 |
357 | if chunk_len < stride_left + stride_right:
358 |                 raise ValueError("Chunk length must be greater than stride length")
359 |
360 | for item in self.chunk_iter_with_batch(
361 | inputs,
362 | chunk_len,
363 | stride_left,
364 | stride_right,
365 | batch_size,
366 | ):
367 | yield item
368 | else:
369 | processed = self.feature_extractor(
370 | inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
371 | )
372 | if stride is not None:
373 | processed["stride"] = stride
374 | yield processed
375 |
376 | def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
377 | # unpack the outputs from list(dict(list)) to list(dict)
378 | model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]
379 |
380 | time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
381 |         # Send the chunking back to seconds; it's easier to handle in Whisper
382 | sampling_rate = self.feature_extractor.sampling_rate
383 | for output in model_outputs:
384 | if "stride" in output:
385 | chunk_len, stride_left, stride_right = output["stride"]
386 | # Go back in seconds
387 | chunk_len /= sampling_rate
388 | stride_left /= sampling_rate
389 | stride_right /= sampling_rate
390 | output["stride"] = chunk_len, stride_left, stride_right
391 |
392 | text, optional = self.tokenizer._decode_asr(
393 | model_outputs,
394 | return_timestamps=return_timestamps,
395 | return_language=return_language,
396 | time_precision=time_precision,
397 | )
398 | return {"text": text, **optional}
399 |
400 | def forward(self, model_inputs, batch_size=None, language=None, task=None, return_timestamps=False):
401 |         # We need to keep track of some additional input arguments for post-processing, so we forward them on after running generation
402 | input_features = model_inputs.pop("input_features")
403 | input_batch_size = input_features.shape[0]
404 |
405 | if input_batch_size != batch_size:
406 | padding = np.zeros([batch_size - input_batch_size, *input_features.shape[1:]], input_features.dtype)
407 | input_features = np.concatenate([input_features, padding])
408 |
409 | pred_ids = self.generate(input_features, language=language, task=task, return_timestamps=return_timestamps)[
410 | :input_batch_size
411 | ]
412 |
413 | # tokenizer's decode method expects an extra dim - we insert it here for convenience
414 | out = {"tokens": pred_ids[:, None, :]}
415 |
416 | stride = model_inputs.pop("stride", None)
417 | if stride is not None:
418 | out["stride"] = stride
419 |
420 | return out
421 |
422 | def __call__(
423 | self,
424 | inputs,
425 | chunk_length_s=30.0,
426 | stride_length_s=None,
427 | batch_size=None,
428 | language=None,
429 | task=None,
430 | return_timestamps=None,
431 | generate_kwargs=None,
432 | ):
433 | """
434 | Transcribe an audio input sequence to a text transcription, optionally with timestamps.
435 |
436 | Args:
437 | inputs (`np.ndarray` or `bytes` or `str` or `dict`):
438 |                 The inputs are either:
439 |                     - `str` that is the filename of the audio file; the file will be read at the correct sampling rate
440 | to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
441 | - `bytes` is the byte content of an audio file and is interpreted by *ffmpeg* in the
442 | same way.
443 | - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
444 | Raw audio assumed to be at the correct sampling rate (16kHz). Note that no further sampling
445 | rate check will be done.
446 | - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
447 | pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "array":
448 | np.array}`. Optionally an additional argument `"stride": (left: int, right: int)` can be used to
449 | ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in
450 | decoding (but used at inference to provide more context to the model). In general, this additional
451 | stride argument is not required.
452 | chunk_length_s (`float`, *optional*, defaults to 30.0):
453 | The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled. By default, the chunk
454 |                 length is set to 30.0s, equal to Whisper's context window.
455 | stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
456 | The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
457 | the model to *see* more context and infer letters better than without this context but the pipeline
458 | discards the stride bits at the end to make the final reconstitution as perfect as possible.
459 |
460 |
461 |
462 | For more information on how to effectively use `stride_length_s`, refer to the [ASR chunking
463 | blog post](https://huggingface.co/blog/asr-chunking).
464 |
465 |
466 | batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
467 | The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
468 | a batch size in the `__call__` method will supersede any batch size passed to the `__init__`.
469 | task (`str`, *optional*):
470 | Task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
471 | language (`str`, *optional*):
472 | Language token to use for generation, can be either in the form of `"<|en|>"`, `"en"` or `"english"`.
473 | Defaults to `None`, meaning the language is automatically inferred from the audio input.
474 |             return_timestamps (`bool`, *optional*, defaults to `False`):
475 |                 Whether to return timestamps in the prediction. If set to `True`, the pipeline
476 | will return two keys in the output dictionary: `"text"` containing the text transcription, and `"chunks"`
477 | containing the transcription segments chunked by their utterance-level timestamps.
478 |
479 | Return:
480 | `Dict`: A dictionary with the following keys:
481 |             - **text** (`str`) -- The recognised text.
482 |             - **chunks** (*optional*, `List[Dict]`)
483 |                 When using `return_timestamps`, the `chunks` will become a list containing all the various text
484 |                 chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5, 0.9)}, {"text":
485 | "there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
486 | `"".join(chunk["text"] for chunk in output["chunks"])`.
487 | """
488 | batch_size = batch_size if batch_size is not None else self.batch_size
489 | if batch_size % self.min_batch_size != 0:
490 | raise ValueError(
491 | f"Batch size must be a multiple of the number of JAX devices, but got batch size {batch_size} and num devices {self.min_batch_size}."
492 | )
493 |
494 | dataloader = self.preprocess_batch(
495 | inputs, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, batch_size=batch_size
496 | )
497 | model_outputs = []
498 | # iterate over our chunked audio samples
499 | for batch in dataloader:
500 | model_outputs.append(
501 | self.forward(
502 | batch, batch_size=batch_size, language=language, task=task, return_timestamps=return_timestamps
503 | )
504 | )
505 | post_processed = self.postprocess(model_outputs, return_timestamps=return_timestamps)
506 | return post_processed
507 |
--------------------------------------------------------------------------------
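
The pipeline above is typically driven end-to-end rather than method by method. The following is a minimal, hypothetical usage sketch (not part of the repository files): it assumes the `whisper_jax` package is installed, that the Flax weights for the checkpoint can be downloaded from the Hub, and that "audio.mp3" is a placeholder for a local audio file readable by ffmpeg.

    import jax
    import jax.numpy as jnp

    from whisper_jax.pipeline import FlaxWhisperPipline

    # Instantiate once; compilation is cached across calls with the same shapes.
    # batch_size must be a multiple of jax.local_device_count().
    pipeline = FlaxWhisperPipline(
        "openai/whisper-large-v2",
        dtype=jnp.bfloat16,
        batch_size=jax.local_device_count(),
    )

    # Optional: switch from pmap (pure data parallelism) to pjit with sharded params,
    # e.g. two model-parallel partitions (see shard_params above).
    # pipeline.shard_params(num_mp_partitions=2)

    # Transcribe with utterance-level timestamps: "text" holds the full transcription,
    # "chunks" the timestamped segments.
    outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
    print(outputs["text"])
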
/whisper_jax/train_state.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The T5X Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities for partitioning."""
17 |
18 | from typing import Any, Mapping, MutableMapping, Optional, Tuple
19 |
20 | import flax.core
21 | import flax.serialization
22 | import flax.struct
23 | import jax.numpy as jnp
24 | from flax import traverse_util
25 | from flax.core import scope as flax_scope
26 | from flax.linen import partitioning as flax_partitioning
27 |
28 |
29 | EMPTY_DICT = flax.core.freeze({})
30 | FrozenDict = flax_scope.FrozenDict
31 | FrozenVariableDict = flax_scope.FrozenVariableDict
32 | MutableVariableDict = flax_scope.MutableVariableDict
33 | VariableDict = flax_scope.VariableDict
34 |
35 |
36 | def _validate_params_axes(params_axes, params):
37 | axis_names = flax_partitioning.get_axis_names(params_axes)
38 | missing_params_axes = set(traverse_util.flatten_dict(params, sep="/")) - set(
39 | traverse_util.flatten_dict(axis_names, sep="/")
40 | )
41 | if missing_params_axes:
42 | raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")
43 |
44 |
45 | def _split_variables_and_axes(variables_and_axes: FrozenVariableDict) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
46 | """Splits `variables_and_axes` into two separate dicts with the same keys."""
47 | # For each `key`, `key_axes` (if any) are its axes in `variables_and_axes`.
48 | variables = {}
49 | axes = {}
50 | for k, v in variables_and_axes.items():
51 | if k.endswith("_axes"):
52 | axes[k[:-5]] = v # k without "_axes".
53 | _validate_params_axes(v, variables_and_axes[k[:-5]]) # k without "_axes".
54 | else:
55 | variables[k] = v
56 | return flax.core.freeze(variables), flax.core.freeze(axes)
57 |
58 |
59 | class InferenceState(flax.struct.PyTreeNode):
60 | """State compatible with FlaxOptimTrainState without optimizer state."""
61 |
62 | step: jnp.ndarray
63 | params: flax_scope.FrozenVariableDict
64 | params_axes: Optional[flax_scope.FrozenVariableDict] = None
65 | flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
66 | flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None
67 |
68 | @classmethod
69 | def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
70 | other_variables, params = model_variables.pop("params")
71 | if "params_axes" in other_variables:
72 | other_variables, params_axes = other_variables.pop("params_axes")
73 | _validate_params_axes(params_axes, params)
74 | else:
75 | params_axes = None
76 |
77 | # Split other_variables into mutables and their corresponding axes.
78 | flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables)
79 | flax_mutables_axes = flax_mutables_axes or None
80 | return InferenceState(
81 | step=jnp.array(0),
82 | params=params,
83 | params_axes=params_axes,
84 | flax_mutables=flax_mutables,
85 | flax_mutables_axes=flax_mutables_axes,
86 | )
87 |
88 | @property
89 | def param_states(self) -> FrozenVariableDict:
90 | """The optimizer states of the parameters as a PyTree."""
91 | raise NotImplementedError("InferenceState has no optimizer states.")
92 |
93 | def apply_gradient(self, *args, **kwargs) -> "InferenceState":
94 | raise NotImplementedError("InferenceState does not support `apply_gradient`.")
95 |
96 | def state_dict(self) -> MutableMapping[str, Any]:
97 | state_dict = {"target": flax.core.unfreeze(self.params), "state": {"step": self.step}}
98 | if self.flax_mutables:
99 | state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
100 | return state_dict
101 |
102 | def replace_step(self, step: jnp.ndarray) -> "InferenceState":
103 | return self.replace(step=step)
104 |
105 | def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
106 | return self.replace(params=params)
107 |
108 | def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
109 | return self.replace(flax_mutables=flax_mutables)
110 |
111 | def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
112 | return self.replace(
113 | params=flax.core.freeze(state_dict["target"]),
114 | step=state_dict["state"]["step"],
115 | flax_mutables=flax.core.freeze(state_dict["flax_mutables"])
116 | if "flax_mutables" in state_dict
117 | else EMPTY_DICT,
118 | )
119 |
120 | def as_logical_axes(self) -> "InferenceState":
121 | # Set step to None so that when the logical axes are processed by the
122 | # flax.partitioning.logical_to_mesh_axes function, it will be skipped
123 |         # because jax.tree_map will short-circuit and never call the function on the
124 | # step.
125 | flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT
126 | return InferenceState(
127 | step=None,
128 | params=flax_partitioning.get_axis_names(self.params_axes),
129 | flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes),
130 | )
131 |
--------------------------------------------------------------------------------
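
To illustrate how `InferenceState` ties into the partitioner, here is a minimal, hypothetical sketch (not part of the repository files) with a toy Flax module standing in for the Whisper model. It assumes the `whisper_jax` package is importable; `TinyDense` and its axis names are illustrative only, and mirror how `shard_params` in pipeline.py builds an `InferenceState` and asks the partitioner for `PartitionSpec`s.

    import jax
    import jax.numpy as jnp
    from flax import linen as nn
    from flax.core import freeze
    from flax.linen import partitioning as nn_partitioning

    from whisper_jax.partitioner import PjitPartitioner
    from whisper_jax.train_state import InferenceState


    class TinyDense(nn.Module):
        features: int = 4

        @nn.compact
        def __call__(self, x):
            # param_with_axes records logical axis names in a "params_axes" collection,
            # which InferenceState.create splits out and validates against the params.
            kernel = nn_partitioning.param_with_axes(
                "kernel",
                nn.initializers.lecun_normal(),
                (x.shape[-1], self.features),
                axes=("embed", "mlp"),
            )
            return x @ kernel


    variables = freeze(TinyDense().init(jax.random.PRNGKey(0), jnp.ones((1, 8))))
    state = InferenceState.create(variables)

    # Explicit rules so the sketch does not rely on the default standard rules;
    # "batch" must be present because PjitPartitioner looks it up in __init__.
    rules = (("batch", "data"), ("embed", None), ("mlp", None))
    partitioner = PjitPartitioner(num_partitions=1, logical_axis_rules=rules)

    mesh_axes = partitioner.get_mesh_axes(state)
    print(mesh_axes.params)  # pytree of PartitionSpecs for the (toy) params
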