├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── imgs
│   └── imagenet64.gif
├── md4
│   ├── __init__.py
│   ├── binary_search.py
│   ├── configs
│   │   ├── genmd4
│   │   │   ├── openwebtext.py
│   │   │   └── text8.py
│   │   └── md4
│   │       ├── cifar10.py
│   │       ├── fineweb_edu.py
│   │       ├── imagenet64.py
│   │       ├── openwebtext.py
│   │       └── text8.py
│   ├── input_pipeline.py
│   ├── input_pipeline_v2.py
│   ├── main.py
│   ├── models
│   │   ├── backward.py
│   │   ├── diffusion
│   │   │   ├── genmd4.py
│   │   │   └── md4.py
│   │   └── utils.py
│   ├── multihost_dataloading.py
│   ├── networks
│   │   ├── dit.py
│   │   ├── sharded_transformer.py
│   │   ├── transformer.py
│   │   ├── unet.py
│   │   └── uvit.py
│   ├── sampling.py
│   ├── sharded_train.py
│   ├── train.py
│   └── utils.py
├── prepare_openwebtext_data.py
├── requirements_gpu.txt
├── requirements_tpu.txt
└── run_gcp.sh

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | ## Contributor License Agreement
4 | 
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 | 
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 | 
15 | ## Code reviews
16 | 
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 | 
22 | ## Community Guidelines
23 | 
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 | 
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 | 1. Definitions.
9 | 
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 | 
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity.
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at
195 | 
196 | http://www.apache.org/licenses/LICENSE-2.0
197 | 
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MD4: Simplified and Generalized Masked Diffusion for Discrete Data
2 | 
3 | 
4 | 
5 | ## Installation
6 | 
7 | ### Create Python environment
8 | 
9 | Please use `requirements_gpu.txt` if your accelerators are GPUs, and
10 | `requirements_tpu.txt` when using Google Cloud TPUs.
11 | 
12 | ```
13 | python -m venv md4_venv
14 | source md4_venv/bin/activate
15 | pip install -r requirements_[gpu/tpu].txt
16 | export PYTHONPATH="$PYTHONPATH:~/path/to/md4"
17 | ```
18 | 
19 | ## Usage
20 | 
21 | Prepare OpenWebText for training (i.e., tokenize and pack examples):
22 | 
23 | ```
24 | mkdir data_dir
25 | python prepare_openwebtext_data.py
26 | ```
27 | 
28 | Train an MD4-S model on text data (OpenWebText, FineWeb-Edu):
29 | 
30 | ```
31 | python md4/main.py --config=md4/configs/md4/openwebtext.py --sharded=false --workdir=./expt
32 | ```
33 | 
34 | Alternatively, you can train an MD4 model on image data (CIFAR-10,
35 | ImageNet-64):
36 | 
37 | ```
38 | python md4/main.py --config=md4/configs/md4/cifar10.py --sharded=false --workdir=./expt
39 | ```
40 | 
41 | ### Choose the batch size
42 | 
43 | The feasible batch size depends on your compute resources. For training an
44 | MD4-S model with sequence length 1024, eight `A100` GPUs can support a maximum
45 | batch size of `128`. If running on TPUs, eight `v5litepod` chips can support a
46 | maximum batch size of `32`.
47 | 
48 | ## Citing this work
49 | 
50 | If you use this work, please cite:
51 | 
52 | ```
53 | @inproceedings{shi2024simplified,
54 |   title={Simplified and Generalized Masked Diffusion for Discrete Data},
55 |   author={Shi, Jiaxin and Han, Kehang and Wang, Zhe and Doucet, Arnaud and Titsias, Michalis K.},
56 |   booktitle={Advances in Neural Information Processing Systems},
57 |   year={2024}
58 | }
59 | ```
60 | 
61 | ## License and disclaimer
62 | 
63 | Copyright 2024 DeepMind Technologies Limited
64 | 
65 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
66 | you may not use this file except in compliance with the Apache 2.0 license.
67 | You may obtain a copy of the Apache 2.0 license at:
68 | https://www.apache.org/licenses/LICENSE-2.0
69 | 
70 | Unless required by applicable law or agreed to in writing, all software and
71 | materials distributed here under the Apache 2.0 or CC-BY licenses are
72 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
73 | either express or implied. See the licenses for the specific language governing
74 | permissions and limitations under those licenses.
75 | 
76 | This is not an official Google product.
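
## Quick iteration from Python

The config files note that `train.train_and_eval()` can be run directly (e.g.,
from a Colab) for fast iteration. Below is a minimal sketch; it assumes
`train.train_and_eval(config, workdir)` accepts a `ConfigDict` and a working
directory, so check `md4/main.py` for the exact entry point and signature:

```
from md4 import train
from md4.configs.md4 import text8

config = text8.get_config()
config.num_eval_steps = 100  # evaluate on a subset for fast iteration
config.test_in_colab = True

train.train_and_eval(config, "./expt_text8")
```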
77 | 
--------------------------------------------------------------------------------
/imgs/imagenet64.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/md4/b9fbf29216f818cccff46cb5727199f86f35593e/imgs/imagenet64.gif
--------------------------------------------------------------------------------
/md4/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """MD4."""
--------------------------------------------------------------------------------
/md4/binary_search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Binary search over float32 bits.
17 | 
18 | Includes fast algorithms for top-k and top-p masking on probability
19 | distributions.
20 | 
21 | Adapted from:
22 | https://github.com/google-research/t5x/blob/main/t5x/binary_search.py
23 | """
24 | 
25 | from typing import Callable, Sequence
26 | 
27 | import jax
28 | from jax import lax
29 | from jax import numpy as jnp
30 | 
31 | 
32 | def int32_bsearch(
33 |     batch_shape: Sequence[int], predicate: Callable[[jnp.ndarray], jnp.ndarray],
34 | ) -> jnp.ndarray:
35 |   """Batched binary search over int32 values.
36 | 
37 |   For each element of the batch, search for the largest int32 (closest to
38 |   positive infinity) for which the predicate is False. If the predicate is
39 |   always True, returns the minimum int32 value.
40 | 
41 |   Args:
42 |     batch_shape: Shape of the search that we're batching over.
43 |     predicate: the query we're searching for. For every batch element, this is
44 |       required to be a monotonic function from int32 to bool. In other words,
45 |       the predicate must return False for all numbers <= some threshold and then
46 |       return True for all numbers > that threshold. The threshold may be
47 |       different for different elements of the batch.
48 | 
49 |   Returns:
50 |     For each element of the batch, the largest int32 for which the predicate
51 |     returns False. Shape: batch_shape.
52 | """ 53 | current_bits = jnp.zeros(batch_shape, dtype=jnp.int32) 54 | 55 | # bit 31 is special, because it compares in the opposite order of all other 56 | # bits. we use uint32 due to numpy promotion/casting rules. 57 | midpoint = current_bits 58 | predicate_satisfied = predicate(midpoint) 59 | current_bits = current_bits | jnp.where( 60 | predicate_satisfied, jnp.uint32(1 << 31), jnp.uint32(0) 61 | ) 62 | del midpoint, predicate_satisfied 63 | 64 | def loop_body(i, current_bits): 65 | bit_index = 30 - i 66 | bit = jnp.int32(1 << bit_index) 67 | midpoint = current_bits | bit 68 | predicate_satisfied = predicate(midpoint) 69 | current_bits = current_bits | jnp.where( 70 | predicate_satisfied, jnp.int32(0), bit 71 | ) 72 | return current_bits 73 | 74 | return lax.fori_loop(0, 31, loop_body, current_bits) 75 | 76 | 77 | def _monotonic_int32_to_float32_bit_pattern(x: int) -> int: 78 | """Converts an int32 to a float32 bit pattern with consistent ordering. 79 | 80 | This function is the unique function that is monotonic with respect to the 81 | floating point total order, see 82 | https://en.wikipedia.org/wiki/IEEE_754#Total-ordering_predicate. Note that 83 | this function returns an int32, not a float32. For the function that returns 84 | float32, see `monotonic_int32_to_float32`. 85 | 86 | Args: 87 | x: int bit pattern. 88 | 89 | Returns: 90 | Bit pattern of a float32 number. 91 | """ 92 | non_sign_bits = jnp.int32((1 << 31) - 1) 93 | # See 94 | # https://stackoverflow.com/questions/20097380/iee-754-total-order-in-standard-c11 95 | # for the relationship between int32 order and f32 total order, including 96 | # the "xor trick". 97 | 98 | # Flip the sort order for numbers where the sign bit is set. On int32, 99 | # the bit pattern with sign bit set and all other bits clear is the most 100 | # negative bit pattern (it's int32::MIN), whereas on float32 it's the least 101 | # negative bit pattern (it's -0.0). Flipping all the non-sign bits makes the 102 | # int32 sort order consistent with the float32 sort order. 103 | x = x ^ jnp.where(x < 0, non_sign_bits, jnp.int32(0)) 104 | return x 105 | 106 | 107 | def _monotonic_int32_to_float32(x: int) -> jax.Array: 108 | """Converts an int32 to a float32 with consistent ordering. 109 | 110 | This function is the unique function that is monotonic with respect to the 111 | floating point total order, see 112 | https://en.wikipedia.org/wiki/IEEE_754#Total-ordering_predicate. 113 | 114 | Args: 115 | x: int bit pattern. 116 | 117 | Returns: 118 | float32 number with consistent ordering. 119 | """ 120 | x = _monotonic_int32_to_float32_bit_pattern(x) 121 | return lax.bitcast_convert_type(x, jnp.float32) 122 | 123 | 124 | def float32_bsearch(batch_shape, predicate): 125 | """Binary search on finite float32 numbers. 126 | 127 | For each element of the batch, this function searches for the largest finite 128 | non-NaN float32 for which the predicate is False. 129 | 130 | Args: 131 | batch_shape: Shape of the search that we're batching over. 132 | predicate: the query we're searching for. This is required to be monotonic 133 | with respect to the floating point order, i.e. it must be False for all 134 | numbers <= a threshold, and then True for all numbers > the threshold. The 135 | threshold may be different for different elements of the batch. 136 | 137 | Returns: 138 | For each element of the batch, the largest float32 for which the predicate 139 | returns False. Shape: f32[batch_shape]. 
140 | """ 141 | exponent_bits = jnp.int32((1 << 31) - (1 << (31 - 8))) 142 | 143 | def int32_predicate(x): 144 | x = _monotonic_int32_to_float32_bit_pattern(x) 145 | is_finite = (x & exponent_bits) != exponent_bits 146 | 147 | # Non-finite numbers (infinity and NaN) are at the very extremes of the 148 | # int32 range, i.e. they include int32::MAX and int32::MIN, plus the numbers 149 | # adjacent to them. For the nonfinite numbers touching int32::MIN, we 150 | # arrange for them to return False from the predicate, and for the nonfinite 151 | # numbers touching int32::MAX, we arrange for them to return True from the 152 | # predicate. x>=0 is an easy way to achieve that. 153 | predicate_on_nonfinite = x >= 0 154 | x_float32 = lax.bitcast_convert_type(x, jnp.float32) 155 | return jnp.where(is_finite, predicate(x_float32), predicate_on_nonfinite) 156 | 157 | # We search over bit patterns, which requires bit shifting and ordering of bit 158 | # patterns. This is natively supported on int32 but not on float32. 159 | # Additionally, it's more common to reason about int32 bit arithmetic and 160 | # ordering than float32 bit arithmetic and ordering, so we do the core of our 161 | # search in int32. Additionally, this allows us to test the underlying binary 162 | # search on int32 values. 163 | # 164 | # The function _monotonic_int32_to_float32 encapsulates all of the knowledge 165 | # we need about float32 bit patterns. 166 | result = int32_bsearch(batch_shape, int32_predicate) 167 | return _monotonic_int32_to_float32(result) 168 | 169 | 170 | def topk_mask(x: jnp.ndarray, k: int, replace_val: jnp.ndarray) -> jnp.ndarray: 171 | """Sets everything to replace_val, except the top k values per batch element. 172 | 173 | Sharding considerations: this function does 32 reductions over the vocab_size 174 | axis of the input array. To avoid excessive latency from these reductions, you 175 | should ensure that the vocab_size axis is unsharded on input to this function. 176 | Prefer to shard the batch axes instead. 177 | 178 | Scratchpad memory considerations: this function is most efficient if the 179 | entire input array can fit in a fast memory tier. To help ensure this, you may 180 | wish to split the batch axes into microbatches and the microbatches in a 181 | sequential loop. 182 | 183 | Args: 184 | x: Values before masking. [batch..., vocab_size] 185 | k: Number of masked values to return. In presence of ties, more than k 186 | values might be returned. 187 | replace_val: For the masked values of x, what to overwrite them with. 188 | 189 | Returns: 190 | masked version of x. [batch..., vocab_size] 191 | """ 192 | batch_shape = tuple(list(x.shape)[:-1]) # [batch...] 193 | 194 | x_for_loop = x 195 | reduce_axis = x.ndim - 1 196 | if x.ndim > 1: 197 | # We're going to be doing 32 reductions over 'reduce_axis'. Generally, 198 | # reductions over the last dimension are the most expensive, because they 199 | # involve reducing across vector lanes, which is often not efficient. So 200 | # we transpose the reduce_axis to be the second-last dimension, to avoid 201 | # this inefficiency. 202 | # 203 | # Normaly the XLA compiler would automatically perform this optimization, 204 | # but it doesn't yet see through loops to do so. So we do it ourselves. 205 | x_for_loop = jnp.swapaxes(x_for_loop, -1, -2) 206 | reduce_axis = x.ndim - 2 207 | 208 | # x: [batch..., vocab_size, batch] 209 | def predicate(threshold): 210 | # threshold: [batch...] 
211 | 
212 |     # Since we've negated, we now want a predicate that is True for small
213 |     # numbers and False for large numbers. The result of the bsearch is the
214 |     # smallest float32 for which the predicate is False.
215 |     threshold = -threshold
216 | 
217 |     threshold = lax.expand_dims(threshold, (reduce_axis,))
218 |     # threshold: [batch..., 1, last_batch]
219 | 
220 |     # count_gt: [batch...]
221 |     count_gt = jnp.sum(x_for_loop > threshold, axis=reduce_axis)
222 | 
223 |     return count_gt >= k
224 | 
225 |   # cutoff: [batch...]
226 |   cutoff = float32_bsearch(batch_shape, predicate)
227 |   cutoff = -cutoff
228 |   # cutoff: [batch..., 1]
229 |   cutoff = lax.expand_dims(cutoff, (cutoff.ndim,))
230 |   return jnp.where(x >= cutoff, x, jnp.full_like(x, replace_val))
231 | 
232 | 
233 | def topp_mask(
234 |     logits: jnp.ndarray, p: float, replace_val: jnp.ndarray
235 | ) -> jnp.ndarray:
236 |   """Applies top-p masking to logits.
237 | 
238 |   Masks logits down to the smallest set of choices, such that the total
239 |   probability mass is >= p. Values in this set are left as they are. All other
240 |   values are replaced with `replace_val`.
241 | 
242 |   Sharding considerations: this function does 33 reductions over the vocab_size
243 |   axis of the input array. To avoid excessive latency from these reductions, you
244 |   should ensure that the vocab_size axis is unsharded on input to this function.
245 |   Prefer to shard the batch axes instead.
246 | 
247 |   Scratchpad memory considerations: this function is most efficient if the
248 |   entire input array can fit in a fast memory tier. To help ensure this, you may
249 |   wish to split the batch axes into microbatches and run the microbatches in a
250 |   sequential loop.
251 | 
252 |   Args:
253 |     logits: Logits before masking. [batch..., vocab_size]
254 |     p: Minimum probability mass requested.
255 |     replace_val: For the masked values of logits, what to overwrite them with.
256 | 
257 |   Returns:
258 |     masked version of logits. [batch..., vocab_size]
259 |   """
260 |   batch_shape = tuple(list(logits.shape)[:-1])  # [batch...]
261 | 
262 |   probs = jax.nn.softmax(logits, axis=-1)
263 | 
264 |   probs_for_reduction = probs
265 |   reduce_axis = probs_for_reduction.ndim - 1
266 |   if probs_for_reduction.ndim > 1:
267 |     # We're going to be doing 33 reductions over 'reduce_axis'. Generally,
268 |     # reductions over the last dimension are the most expensive, because they
269 |     # involve reducing across vector lanes, which is often not efficient. So
270 |     # we transpose the reduce_axis to be the second-last dimension, to avoid
271 |     # this inefficiency.
272 |     probs_for_reduction = jnp.swapaxes(probs_for_reduction, -1, -2)
273 |     reduce_axis = probs_for_reduction.ndim - 2
274 | 
275 |   # As we increase the threshold, the probability mass decreases, and the number
276 |   # selected decreases.
277 |   #
278 |   # We want the largest threshold with the probability mass >= p. Binary search
279 |   # searches for when the predicate is False, so we negate the output of the
280 |   # predicate, i.e. probability mass < p.
281 | 
282 |   # probs_for_reduction: [batch..., vocab_size, last_batch]
283 |   def predicate(threshold):
284 |     # threshold: [batch...]
285 |     threshold = lax.expand_dims(threshold, (reduce_axis,))
286 |     # threshold: [batch..., 1, last_batch]
287 | 
288 |     # probability_mass: [batch...]
289 |     probability_mass = jnp.sum(
290 |         jnp.where(probs_for_reduction >= threshold, probs_for_reduction, 0.0),
291 |         axis=reduce_axis,
292 |     )
293 | 
294 |     return probability_mass < p
295 | 
296 |   # threshold: [batch...]
297 | threshold = float32_bsearch(batch_shape, predicate) 298 | # threshold: [batch..., 1] 299 | threshold = lax.expand_dims(threshold, (threshold.ndim,)) 300 | return jnp.where( 301 | probs >= threshold, logits, jnp.full_like(logits, replace_val) 302 | ) 303 | -------------------------------------------------------------------------------- /md4/configs/genmd4/openwebtext.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""A config for training GenMD4-S on OpenWebText.""" 17 | 18 | from collections import abc 19 | 20 | from ml_collections import config_dict 21 | 22 | 23 | def get_config() -> config_dict.ConfigDict: 24 | """Default config.""" 25 | 26 | config = config_dict.ConfigDict() 27 | 28 | # dataset configs 29 | config.vocab_size = 50259 30 | config.dataset = "openwebtext" 31 | config.classes = -1 32 | 33 | config.task_type = "text" # text or image 34 | config.model_type = "genmd4" 35 | config.data_shape = (1024,) 36 | 37 | # timesteps: int or None 38 | config.timesteps = 1000 39 | config.noise_schedule = "poly" 40 | config.power_init = 1.0 41 | config.outside_embed = False 42 | config.time_features = "t" 43 | config.cont_time = True 44 | 45 | config.feature_dim = 64 46 | config.n_layers = 12 47 | config.ch_mult = (1,) # not used 48 | config.n_dit_layers = 0 # not used 49 | config.dit_num_heads = 12 # not used 50 | config.dit_hidden_size = 768 # not used 51 | config.dropout_rate = 0.0 52 | 53 | config.num_heads = 12 54 | config.mlp_type = "glu" 55 | config.depth_scaled_init = True 56 | config.cond_type = "adaln_zero" 57 | 58 | config.learning_rate = 3e-4 59 | config.learning_rate_schedule = "cosine" 60 | config.warmup_steps = 2000 61 | config.weight_decay = 0.0 62 | config.clip = 0.0 63 | config.b2 = 0.999 64 | config.num_epochs = -1 65 | config.ema_rate = 0.9999 66 | # If num_train_steps==-1 then the number of training steps is calculated from 67 | # num_epochs. 68 | config.num_train_steps = 1_000_000 69 | # Evaluates for a full epoch if num_eval_steps==-1. Set to a smaller value for 70 | # fast iteration when running train.train_and_eval() from a Colab. 71 | config.num_eval_steps = -1 72 | config.batch_size = 512 73 | config.num_microbatches = 1 74 | config.per_device_batch_size = -1 75 | # If batches should be added to evaluate the entire dataset. 76 | config.eval_pad_last_batch = False 77 | config.check_nans = False 78 | 79 | config.log_loss_every_steps = 500 80 | config.eval_every_steps = 10000 81 | config.checkpoint_every_steps = 5000 82 | config.checkpoint_keep_period = -1 83 | 84 | # Single integer or tuple. If None will use (XManager ID, work unit). 85 | config.seed = 42 86 | 87 | # Number of workers for Grain loaders. 88 | config.grain_num_workers = 15 89 | 90 | config.trial = 0 # Dummy for repeated runs. 
91 | config.test_in_colab = False 92 | return config 93 | 94 | 95 | # By default, the launcher calls `sweep()`. 96 | # To disable the sweep, the `sweep()` function can be commented (or renamed), 97 | # or the flag `--nosweep` can be specified to the launcher. 98 | def sweep(add: abc.Callable[..., None]): 99 | """Starts multiple work units with varying config args.""" 100 | add( 101 | learning_rate=3e-4, 102 | dropout_rate=0.02, 103 | weight_decay=0.03, 104 | ) 105 | -------------------------------------------------------------------------------- /md4/configs/genmd4/text8.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""A config for training GenMD4 on text8.""" 17 | 18 | from collections import abc 19 | 20 | from ml_collections import config_dict 21 | 22 | 23 | def get_config() -> config_dict.ConfigDict: 24 | """Default config.""" 25 | 26 | config = config_dict.ConfigDict() 27 | 28 | # dataset configs 29 | config.vocab_size = 27 30 | config.dataset = "text8" 31 | config.classes = -1 32 | 33 | config.task_type = "text" # text or image 34 | config.model_type = "genmd4" 35 | config.data_shape = (256,) 36 | 37 | # timesteps: int or None 38 | config.timesteps = 1000 39 | config.noise_schedule = "poly" 40 | config.power_init = 1.0 41 | config.outside_embed = True 42 | config.time_features = "t" 43 | config.cont_time = True 44 | 45 | config.feature_dim = 64 46 | config.n_layers = 12 47 | config.ch_mult = (1,) # not used 48 | config.n_dit_layers = 0 # not used 49 | config.dit_num_heads = 12 # not used 50 | config.dit_hidden_size = 768 # not used 51 | config.dropout_rate = 0.0 52 | 53 | config.num_heads = 12 54 | config.mlp_type = "glu" 55 | config.depth_scaled_init = True 56 | config.cond_type = "adaln_zero" 57 | 58 | config.learning_rate = 3e-4 59 | config.learning_rate_schedule = "cosine" 60 | config.warmup_steps = 2000 61 | config.weight_decay = 0.0 62 | config.clip = 0.0 63 | config.b2 = 0.999 64 | config.num_epochs = -1 65 | config.ema_rate = 0.9999 66 | # If num_train_steps==-1 then the number of training steps is calculated from 67 | # num_epochs. 68 | config.num_train_steps = 1_000_000 69 | # Evaluates for a full epoch if num_eval_steps==-1. Set to a smaller value for 70 | # fast iteration when running train.train_and_eval() from a Colab. 71 | config.num_eval_steps = -1 72 | config.batch_size = 512 73 | config.num_microbatches = 1 74 | config.per_device_batch_size = -1 75 | # If batches should be added to evaluate the entire dataset. 76 | config.eval_pad_last_batch = False 77 | config.check_nans = False 78 | 79 | config.log_loss_every_steps = 500 80 | config.eval_every_steps = 5000 81 | config.checkpoint_every_steps = 5000 82 | config.checkpoint_keep_period = 10000 83 | 84 | # Single integer or tuple. If None will use (XManager ID, work unit). 
85 | config.seed = 42 86 | 87 | # Number of workers for Grain loaders. 88 | config.grain_num_workers = 15 89 | 90 | config.trial = 0 # Dummy for repeated runs. 91 | config.test_in_colab = False 92 | return config 93 | 94 | 95 | # By default, the launcher calls `sweep()`. 96 | # To disable the sweep, the `sweep()` function can be commented (or renamed), 97 | # or the flag `--nosweep` can be specified to the launcher. 98 | def sweep(add: abc.Callable[..., None]): 99 | """Starts multiple work units with varying config args.""" 100 | add( 101 | learning_rate=3e-4, 102 | dropout_rate=0.05, 103 | weight_decay=0.03, 104 | ) 105 | -------------------------------------------------------------------------------- /md4/configs/md4/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""A config for training MD4 on CIFAR10.""" 17 | 18 | from collections import abc 19 | 20 | from ml_collections import config_dict 21 | 22 | 23 | def get_config() -> config_dict.ConfigDict: 24 | """Default config.""" 25 | 26 | config = config_dict.ConfigDict() 27 | 28 | # dataset configs 29 | config.vocab_size = 256 30 | config.dataset = "cifar10" 31 | config.classes = -1 32 | 33 | config.task_type = "image" # text or image 34 | config.model_type = "md4" 35 | config.data_shape = (32, 32, 3) 36 | 37 | # timesteps: int or None 38 | config.timesteps = 256 39 | # linear, cosine, poly[exponent], e.g., poly3 40 | config.noise_schedule = "linear" 41 | config.outside_embed = True # not used 42 | # t or none (removes time dependence) 43 | config.time_features = "t" 44 | config.cont_time = True 45 | 46 | config.feature_dim = 128 47 | config.n_layers = 32 48 | config.ch_mult = (1,) # not used 49 | config.n_dit_layers = 0 # not used 50 | config.dit_num_heads = 12 # not used 51 | config.dit_hidden_size = 768 # not used 52 | config.dropout_rate = 0.1 53 | 54 | config.num_heads = 12 # not used 55 | config.mlp_type = "glu" # not used 56 | config.depth_scaled_init = True # not used 57 | config.cond_type = "adaln_zero" # not used 58 | 59 | config.learning_rate = 2e-4 60 | config.learning_rate_schedule = "cosine" 61 | config.warmup_steps = 100 62 | config.weight_decay = 0.01 63 | config.clip = 0.0 64 | config.b2 = 0.99 65 | config.num_epochs = -1 66 | config.ema_rate = 0.9999 67 | # If num_train_steps==-1 then the number of training steps is calculated from 68 | # num_epochs. 69 | config.num_train_steps = 2_000_000 70 | # Evaluates for a full epoch if num_eval_steps==-1 71 | config.num_eval_steps = -1 72 | config.batch_size = 256 73 | config.per_device_batch_size = -1 74 | # If batches should be added to evaluate the entire dataset. 
75 | config.eval_pad_last_batch = False 76 | config.check_nans = False 77 | 78 | # Sampling 79 | # ancestral, mean, or topp 80 | config.sampler = "ancestral" 81 | # uniform, cosine 82 | config.sampling_grid = "cosine" 83 | # for topp sampler 84 | config.topp = 0.98 85 | 86 | config.log_loss_every_steps = 500 87 | config.eval_every_steps = 10000 88 | config.checkpoint_every_steps = 5000 89 | config.checkpoint_keep_period = 10000 90 | 91 | # Single integer or tuple. If None will use (XManager ID, work unit). 92 | config.seed = 0 93 | 94 | # Number of workers for Grain loaders. 95 | config.grain_num_workers = 15 96 | 97 | config.trial = 0 # Dummy for repeated runs. 98 | config.test_in_colab = False 99 | return config 100 | 101 | 102 | # By default, the launcher calls `sweep()`. 103 | # To disable the sweep, the `sweep()` function can be commented (or renamed), 104 | # or the flag `--nosweep` can be specified to the launcher. 105 | def sweep(add: abc.Callable[..., None]): 106 | """Starts multiple work units with varying config args.""" 107 | # For best likelihood results 108 | add( 109 | noise_schedule="linear", 110 | sampling_grid="cosine", 111 | ) 112 | # For best sample quality 113 | # add( 114 | # noise_schedule="cosine", 115 | # sampling_grid="uniform", 116 | # ) 117 | -------------------------------------------------------------------------------- /md4/configs/md4/fineweb_edu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | r"""A config for training MD4 on FineWeb EDU.""" 17 | 18 | from collections import abc 19 | 20 | from ml_collections import config_dict 21 | 22 | 23 | def get_config() -> config_dict.ConfigDict: 24 | """Default config.""" 25 | 26 | config = config_dict.ConfigDict() 27 | 28 | # dataset configs 29 | config.vocab_size = 50257 30 | config.use_v2_input_pipeline = True 31 | config.dataset = "fineweb_edu" 32 | config.classes = -1 33 | 34 | config.task_type = "text" # text or image 35 | config.model_type = "md4" 36 | config.data_shape = (1024,) 37 | 38 | # timesteps: int or None 39 | config.timesteps = 1000 40 | # linear, cosine, poly[exponent], e.g., poly3 41 | config.noise_schedule = "linear" 42 | config.outside_embed = False 43 | # t or none (removes time dependence) 44 | config.time_features = "t" 45 | config.cont_time = True 46 | 47 | config.feature_dim = 64 48 | config.n_layers = 12 49 | config.ch_mult = (1,) # not used 50 | config.n_dit_layers = 0 # not used 51 | config.dit_num_heads = 12 # not used 52 | config.dit_hidden_size = 768 # not used 53 | config.dropout_rate = 0.0 54 | 55 | config.num_heads = 12 56 | config.mlp_type = "glu" 57 | config.depth_scaled_init = True 58 | config.cond_type = "adaln_zero" 59 | 60 | config.learning_rate = 3e-4 61 | config.learning_rate_schedule = "cosine" 62 | config.warmup_steps = 2000 63 | config.weight_decay = 0.0 64 | config.clip = 0.0 65 | config.b2 = 0.999 66 | config.num_epochs = -1 67 | config.ema_rate = 0.9999 68 | # If num_train_steps==-1 then the number of training steps is calculated from 69 | # num_epochs. 70 | config.num_train_steps = 1_000_000 71 | # Evaluates for a full epoch if num_eval_steps==-1. 72 | config.num_eval_steps = -1 73 | config.batch_size = 512 74 | config.num_microbatches = 1 75 | config.per_device_batch_size = -1 76 | # If batches should be added to evaluate the entire dataset. 77 | config.eval_pad_last_batch = False 78 | config.check_nans = False 79 | 80 | # Sampling 81 | # ancestral, mean, or topp 82 | config.sampler = "ancestral" 83 | # uniform, cosine 84 | config.sampling_grid = "cosine" 85 | # for topp sampler 86 | config.topp = 0.98 87 | 88 | config.log_loss_every_steps = 500 89 | config.eval_every_steps = 10000 90 | config.checkpoint_every_steps = 5000 91 | config.checkpoint_keep_period = 200000 92 | 93 | # Single integer or tuple. If None will use (XManager ID, work unit). 94 | config.seed = 42 95 | 96 | # Number of workers for Grain loaders. 97 | # HF data source (OSS version) only supports one worker for now, more workers 98 | # result in duplicated data. 99 | config.grain_num_workers = 1 100 | 101 | config.trial = 0 # Dummy for repeated runs. 102 | config.test_in_colab = False 103 | return config 104 | 105 | 106 | # By default, the launcher calls `sweep()`. 107 | # To disable the sweep, the `sweep()` function can be commented (or renamed), 108 | # or the flag `--nosweep` can be specified to the launcher. 
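# A quick sanity check on the scale implied by this config (a worked example,
# not code from the repo): each optimizer step consumes
# batch_size * seq_len = 512 * 1024 = 524,288 tokens, so the default
# num_train_steps = 1_000_000 corresponds to roughly 5.2e11 (~0.52T)
# training tokens.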
109 | def sweep(add: abc.Callable[..., None]): 110 | """Starts multiple work units with varying config args.""" 111 | # Small size 112 | add( 113 | noise_schedule="linear", 114 | dropout_rate=0.02, 115 | weight_decay=0.03, 116 | ) 117 | # Medium size 118 | # add( 119 | # noise_schedule="cosine", 120 | # n_layers=24, 121 | # num_heads=16, 122 | # batch_size=512, 123 | # sampling_grid="uniform", 124 | # dropout_rate=0.0, 125 | # weight_decay=0.03, 126 | # ) 127 | -------------------------------------------------------------------------------- /md4/configs/md4/imagenet64.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""A config for training MD4 on downsampled ImageNet 64x64.""" 17 | 18 | from collections import abc 19 | 20 | from ml_collections import config_dict 21 | 22 | 23 | def get_config() -> config_dict.ConfigDict: 24 | """Default config.""" 25 | 26 | config = config_dict.ConfigDict() 27 | 28 | # dataset configs 29 | config.vocab_size = 256 30 | config.dataset = "downsampled_imagenet/64x64" 31 | config.classes = -1 32 | 33 | config.task_type = "image" # text or image 34 | config.model_type = "md4" 35 | config.data_shape = (64, 64, 3) 36 | 37 | # timesteps: int or None 38 | config.timesteps = 256 39 | # linear, cosine, poly[exponent], e.g., poly3 40 | config.noise_schedule = "linear" 41 | config.outside_embed = True # not used 42 | # t or none (removes time dependence) 43 | config.time_features = "t" 44 | config.cont_time = True 45 | 46 | config.feature_dim = 256 47 | config.n_layers = 8 48 | config.ch_mult = (1,) 49 | config.n_dit_layers = 20 50 | config.dit_num_heads = 12 51 | config.dit_hidden_size = 768 52 | config.dropout_rate = 0.1 53 | 54 | config.num_heads = 12 # not used 55 | config.mlp_type = "glu" # not used 56 | config.depth_scaled_init = True # not used 57 | config.cond_type = "adaln_zero" # not used 58 | 59 | config.learning_rate = 2e-4 60 | config.learning_rate_schedule = "cosine" 61 | config.warmup_steps = 100 62 | config.weight_decay = 0.01 63 | config.clip = 0.0 64 | config.b2 = 0.99 65 | config.num_epochs = -1 66 | config.ema_rate = 0.9999 67 | # If num_train_steps==-1 then the number of training steps is calculated from 68 | # num_epochs. 69 | config.num_train_steps = 2_000_000 70 | # Evaluates for a full epoch if num_eval_steps==-1 71 | config.num_eval_steps = -1 72 | config.batch_size = 512 73 | config.per_device_batch_size = -1 74 | # If batches should be added to evaluate the entire dataset. 
75 | config.eval_pad_last_batch = False 76 | config.check_nans = False 77 | 78 | # Sampling 79 | # ancestral, mean, or topp 80 | config.sampler = "ancestral" 81 | # uniform, cosine 82 | config.sampling_grid = "cosine" 83 | # for topp sampler 84 | config.topp = 0.98 85 | 86 | config.log_loss_every_steps = 500 87 | config.eval_every_steps = 10000 88 | config.checkpoint_every_steps = 5000 89 | config.checkpoint_keep_period = 10000 90 | 91 | # Single integer or tuple. If None will use (XManager ID, work unit). 92 | config.seed = 0 93 | 94 | # Number of workers for Grain loaders. 95 | config.grain_num_workers = 15 96 | 97 | config.trial = 0 # Dummy for repeated runs. 98 | config.test_in_colab = False 99 | return config 100 | 101 | 102 | # By default, the launcher calls `sweep()`. 103 | # To disable the sweep, the `sweep()` function can be commented (or renamed), 104 | # or the flag `--nosweep` can be specified to the launcher. 105 | def sweep(add: abc.Callable[..., None]): 106 | """Starts multiple work units with varying config args.""" 107 | # Train unconditional model 108 | # Default: linear schedule, cosine sampling grid 109 | add( 110 | dataset="downsampled_imagenet/64x64", 111 | sampling_grid="cosine", 112 | noise_schedule="linear", 113 | n_layers=8, # use 24 for best likelihood in the paper 114 | n_dit_layers=20, # use 12 for best likelihood in the paper 115 | ) 116 | # # Cosine schedule: worse likelihood, for best sample quality 117 | # add( 118 | # dataset="downsampled_imagenet/64x64", 119 | # sampling_grid="uniform", 120 | # noise_schedule="cosine", 121 | # n_layers=8, 122 | # n_dit_layers=20, 123 | # ) 124 | # # Train class conditioned 125 | # add( 126 | # dataset="class_cond_imagenet64", 127 | # classes=1000, 128 | # sampling_grid="uniform", 129 | # noise_schedule="cosine", 130 | # ) 131 | -------------------------------------------------------------------------------- /md4/configs/md4/openwebtext.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | r"""A config for training small model on OpenWebText.""" 17 | 18 | from collections import abc 19 | 20 | from ml_collections import config_dict 21 | 22 | 23 | def get_config() -> config_dict.ConfigDict: 24 | """Default config.""" 25 | 26 | config = config_dict.ConfigDict() 27 | 28 | # dataset configs 29 | config.vocab_size = 50257 30 | config.dataset = "openwebtext" 31 | config.classes = -1 32 | 33 | config.task_type = "text" # text or image 34 | config.model_type = "md4" 35 | config.data_shape = (1024,) 36 | 37 | # timesteps: int or None 38 | config.timesteps = 1000 39 | # linear, cosine, poly[exponent], e.g., poly3 40 | config.noise_schedule = "linear" 41 | config.outside_embed = False 42 | # t or none (removes time dependence) 43 | config.time_features = "t" 44 | config.cont_time = True 45 | 46 | config.feature_dim = 64 47 | config.n_layers = 12 48 | config.ch_mult = (1,) # not used 49 | config.n_dit_layers = 0 # not used 50 | config.dit_num_heads = 12 # not used 51 | config.dit_hidden_size = 768 # not used 52 | config.dropout_rate = 0.0 53 | 54 | config.num_heads = 12 55 | config.mlp_type = "glu" 56 | config.depth_scaled_init = True 57 | config.cond_type = "adaln_zero" 58 | 59 | config.learning_rate = 3e-4 60 | config.learning_rate_schedule = "cosine" 61 | config.warmup_steps = 2000 62 | config.weight_decay = 0.0 63 | config.clip = 0.0 64 | config.b2 = 0.999 65 | config.num_epochs = -1 66 | config.ema_rate = 0.9999 67 | # If num_train_steps==-1 then the number of training steps is calculated from 68 | # num_epochs. 69 | config.num_train_steps = 1_000_000 70 | # Evaluates for a full epoch if num_eval_steps==-1. 71 | config.num_eval_steps = -1 72 | config.batch_size = 512 73 | config.num_microbatches = 1 74 | config.per_device_batch_size = -1 75 | # If batches should be added to evaluate the entire dataset. 76 | config.eval_pad_last_batch = False 77 | config.check_nans = False 78 | 79 | # Sampling 80 | # ancestral, mean, or topp 81 | config.sampler = "ancestral" 82 | # uniform, cosine 83 | config.sampling_grid = "cosine" 84 | # for topp sampler 85 | config.topp = 0.98 86 | 87 | config.log_loss_every_steps = 500 88 | config.eval_every_steps = 10000 89 | config.checkpoint_every_steps = 5000 90 | config.checkpoint_keep_period = 200000 91 | 92 | # Single integer or tuple. If None will use (XManager ID, work unit). 93 | config.seed = 42 94 | 95 | # Number of workers for Grain loaders. 96 | # HF data source (OSS version) only supports one worker for now, more workers 97 | # result in duplicated data. 98 | config.grain_num_workers = 1 99 | 100 | config.trial = 0 # Dummy for repeated runs. 101 | config.test_in_colab = False 102 | return config 103 | 104 | 105 | # By default, the launcher calls `sweep()`. 106 | # To disable the sweep, the `sweep()` function can be commented (or renamed), 107 | # or the flag `--nosweep` can be specified to the launcher. 
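# Besides sweeps, individual fields can usually be overridden at launch time
# through ml_collections config flags (a hedged example; it assumes main.py
# registers this file via config_flags, the standard ml_collections pattern):
#   python md4/main.py --config=md4/configs/md4/openwebtext.py \
#     --config.batch_size=128 --sharded=false --workdir=./expt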
108 | def sweep(add: abc.Callable[..., None]): 109 | """Starts multiple work units with varying config args.""" 110 | # Small size 111 | add( 112 | noise_schedule="linear", 113 | dropout_rate=0.02, 114 | weight_decay=0.03, 115 | ) 116 | # Medium size 117 | # add( 118 | # noise_schedule="cosine", 119 | # n_layers=24, 120 | # num_heads=16, 121 | # batch_size=512, 122 | # sampling_grid="uniform", 123 | # dropout_rate=0.0, 124 | # weight_decay=0.03, 125 | # ) 126 | -------------------------------------------------------------------------------- /md4/configs/md4/text8.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""A config for training MD4 on text8.""" 17 | 18 | from collections import abc 19 | 20 | from ml_collections import config_dict 21 | 22 | 23 | def get_config() -> config_dict.ConfigDict: 24 | """Default config.""" 25 | 26 | config = config_dict.ConfigDict() 27 | 28 | # dataset configs 29 | config.vocab_size = 27 30 | config.dataset = "text8" 31 | config.classes = -1 32 | 33 | config.task_type = "text" # text or image 34 | config.model_type = "md4" 35 | config.data_shape = (256,) 36 | 37 | # timesteps: int or None 38 | config.timesteps = 1000 39 | # linear, cosine, poly[exponent], e.g., poly3 40 | config.noise_schedule = "linear" 41 | config.outside_embed = True 42 | # t or none (removes time dependence) 43 | config.time_features = "t" 44 | config.cont_time = True 45 | 46 | config.feature_dim = 64 47 | config.n_layers = 12 48 | config.ch_mult = (1,) # not used 49 | config.n_dit_layers = 0 # not used 50 | config.dit_num_heads = 12 # not used 51 | config.dit_hidden_size = 768 # not used 52 | config.dropout_rate = 0.0 53 | 54 | config.num_heads = 12 55 | config.mlp_type = "glu" 56 | config.depth_scaled_init = True 57 | config.cond_type = "adaln_zero" 58 | 59 | config.learning_rate = 3e-4 60 | config.learning_rate_schedule = "cosine" 61 | config.warmup_steps = 2000 62 | config.weight_decay = 0.0 63 | config.clip = 0.0 64 | config.b2 = 0.999 65 | config.num_epochs = -1 66 | config.ema_rate = 0.9999 67 | # If num_train_steps==-1 then the number of training steps is calculated from 68 | # num_epochs. 69 | config.num_train_steps = 1_000_000 70 | # Evaluates for a full epoch if num_eval_steps==-1. 71 | config.num_eval_steps = -1 72 | config.batch_size = 512 73 | config.num_microbatches = 1 74 | config.per_device_batch_size = -1 75 | # If batches should be added to evaluate the entire dataset. 
76 | config.eval_pad_last_batch = False 77 | config.check_nans = False 78 | 79 | # Sampling 80 | # ancestral, mean, or topp 81 | config.sampler = "ancestral" 82 | # uniform, cosine 83 | config.sampling_grid = "cosine" 84 | # for topp sampler 85 | config.topp = 0.98 86 | 87 | config.log_loss_every_steps = 500 88 | config.eval_every_steps = 5000 89 | config.checkpoint_every_steps = 5000 90 | config.checkpoint_keep_period = 10000 91 | 92 | # Single integer or tuple. If None will use (XManager ID, work unit). 93 | config.seed = 42 94 | 95 | # Number of workers for Grain loaders. 96 | config.grain_num_workers = 15 97 | 98 | config.trial = 0 # Dummy for repeated runs. 99 | config.test_in_colab = False 100 | return config 101 | 102 | 103 | # By default, the launcher calls `sweep()`. 104 | # To disable the sweep, the `sweep()` function can be commented (or renamed), 105 | # or the flag `--nosweep` can be specified to the launcher. 106 | def sweep(add: abc.Callable[..., None]): 107 | """Starts multiple work units with varying config args.""" 108 | add( 109 | learning_rate=5e-4, 110 | dropout_rate=0.05, 111 | weight_decay=0.03, 112 | ) 113 | -------------------------------------------------------------------------------- /md4/input_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Deterministic input pipeline.""" 17 | 18 | from collections.abc import Sequence 19 | import dataclasses 20 | import os 21 | from typing import Any, Union 22 | import urllib.request 23 | import zipfile 24 | 25 | import grain.python as grain 26 | import jax 27 | from ml_collections import config_dict 28 | import numpy as np 29 | import tensorflow as tf 30 | import tensorflow_datasets as tfds 31 | import transformers 32 | 33 | 34 | # pylint: disable=g-import-not-at-top 35 | try: 36 | import cv2 37 | except ImportError: 38 | print("cv2 not found") 39 | FlatFeatures = dict[str, Any] 40 | 41 | _DataSet = Union[grain.MapDataset, grain.DataLoader, grain.IterDataset] 42 | 43 | _GPT2_TOKENIZER = "gpt2" 44 | _OWT_DATASETS = dict( 45 | # OSS version. 
Please prepare the OWT datasets using the following command: 46 | # python ./prepare_openwebtext_data.py 47 | dataset_train_path=("./data_dir/openwebtext_splits_1024_train"), 48 | dataset_eval_path=("./data_dir/openwebtext_splits_1024_eval"), 49 | ) 50 | 51 | 52 | class Text8Tokenizer: 53 | """Simple text8 tokenizer.""" 54 | 55 | def __init__(self, num_extra_tokens=0): 56 | self.num_extra_tokens = num_extra_tokens 57 | 58 | @property 59 | def vocab_size(self): 60 | return 27 + self.num_extra_tokens 61 | 62 | @property 63 | def pad_token(self): 64 | return 26 65 | 66 | def encode(self, text): 67 | tokens = np.array([i - 97 for i in text], dtype=np.int32) 68 | tokens = np.where(tokens < 0, self.pad_token, tokens) 69 | return tokens 70 | 71 | def decode(self, tokens): 72 | tokens = np.where(np.equal(tokens, self.pad_token), 32 - 97, tokens) + 97 73 | text = tokens.astype(np.uint8).tobytes() 74 | return text.decode("utf-8") 75 | 76 | 77 | def preprocess_text8( 78 | data_dir, 79 | doc_length: int = 512, 80 | ): 81 | """Load the 27-char text8 dataset.""" 82 | if not os.path.exists(os.path.join(data_dir, "text8.train.txt")): 83 | if not os.path.exists(data_dir): 84 | os.makedirs(data_dir) 85 | if not os.path.exists(os.path.join(data_dir, "text8.zip")): 86 | url = "http://mattmahoney.net/dc/text8.zip" 87 | print("Downloading text8 from URL {}.".format(url)) 88 | urllib.request.urlretrieve(url, os.path.join(data_dir, "text8.zip")) 89 | with open(os.path.join(data_dir, "text8.zip"), "rb") as f: 90 | rawdata = zipfile.ZipFile(f).read("text8").decode("utf-8") 91 | splits = { 92 | "train": rawdata[:90000000], 93 | "valid": rawdata[90000000:95000000], 94 | "test": rawdata[95000000:], 95 | } 96 | for split, data in splits.items(): 97 | with open(os.path.join(data_dir, "text8." + split + ".txt"), "w") as f: 98 | f.write(data) 99 | 100 | def load_text8_split(split: str): 101 | def _split_chars(arr): 102 | return tf.compat.v1.string_split( 103 | [arr], sep="", result_type="RaggedTensor" 104 | ).flat_values 105 | 106 | def _join_and_rename(x): 107 | text = tf.strings.reduce_join(x, axis=0) 108 | return {"text": text} 109 | 110 | path = os.path.join(data_dir, "text8." + split + ".txt") 111 | ds = tf.data.TextLineDataset(path).map(_split_chars).unbatch() 112 | ds = ds.batch(doc_length, drop_remainder=True) 113 | ds = ds.map(_join_and_rename) 114 | return ds 115 | 116 | # Define the builder. 
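  # `store_as_tfds_dataset` materializes the train/valid/test splits defined
  # above as an ArrayRecord-backed TFDS dataset under `data_dir`, so text8 can
  # be reloaded later without re-downloading or re-splitting it.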
117 |   text8_builder = tfds.dataset_builders.store_as_tfds_dataset(
118 |       name="text8",
119 |       version="1.0.0",
120 |       features=tfds.features.FeaturesDict({
121 |           "text": tfds.features.Text(),
122 |       }),
123 |       split_datasets={
124 |           "train": load_text8_split("train"),
125 |           "valid": load_text8_split("valid"),
126 |           "test": load_text8_split("test"),
127 |       },
128 |       config="text8",
129 |       data_dir=data_dir,
130 |       description="text8 dataset, document length 512.",
131 |       file_format="array_record",
132 |       disable_shuffling=True,
133 |   )
134 | 
135 |   return text8_builder
136 | 
137 | 
138 | class ChunkDataSource(grain.RandomAccessDataSource):
139 |   """Chunk text data source."""
140 | 
141 |   def __init__(self, tensor, chunk_size=256, overlapping=False):
142 |     self.chunk_size = chunk_size
143 |     self.overlapping = overlapping
144 |     tensor = tensor.encode("utf-8")
145 |     if not overlapping:
146 |       extra_len = len(tensor) % chunk_size
147 |       if extra_len > 0:
148 |         tensor = tensor[:-extra_len]
149 |       self.tensor = np.array(list(tensor)).reshape(-1, chunk_size)
150 |     else:
151 |       self.tensor = tensor
152 | 
153 |   def __len__(self):
154 |     if not self.overlapping:
155 |       return self.tensor.shape[0]
156 |     else:
157 |       return len(self.tensor) - self.chunk_size + 1
158 | 
159 |   def __getitem__(self, record_key):
160 |     if not self.overlapping:
161 |       return {"text": self.tensor[record_key]}
162 |     else:
163 |       start_idx = record_key
164 |       end_idx = record_key + self.chunk_size
165 |       chunk = self.tensor[start_idx:end_idx]
166 |       return {"text": chunk}
167 | 
168 |   def __repr__(self) -> str:
169 |     return f"ChunkDataSource(len={len(self)},overlapping={self.overlapping})"
170 | 
171 | 
172 | @dataclasses.dataclass
173 | class Tokenize(grain.MapTransform):
174 |   tokenizer: Text8Tokenizer
175 | 
176 |   def map(self, features):
177 |     text = features["text"]
178 |     features["text"] = self.tokenizer.encode(text)
179 |     return features
180 | 
181 | 
182 | @dataclasses.dataclass(frozen=True)
183 | class DiscreteWithoutLabel(grain.MapTransform):
184 |   """Discrete image data with zero labels."""
185 | 
186 |   def map(self, features):
187 |     features["image"] = features["image"].astype(np.int32)
188 |     if "label" in features:
189 |       del features["label"]
190 |     if "id" in features:
191 |       del features["id"]
192 |     return features
193 | 
194 | 
195 | @dataclasses.dataclass(frozen=True)
196 | class ResizeSmall(grain.MapTransform):
197 |   """Resizes an image to `size` x `size` (applied after a square crop).
198 | 
199 |   Attributes:
200 |     size: Target size for both sides of the output image.
201 | """ 202 | 203 | size: int 204 | 205 | def map(self, features: FlatFeatures) -> FlatFeatures: 206 | image = features["image"] 207 | size = self.size 208 | image = cv2.resize(image, dsize=(size, size), interpolation=cv2.INTER_AREA) 209 | features["image"] = image.astype(np.int32) 210 | return features 211 | 212 | 213 | @dataclasses.dataclass(frozen=True) 214 | class CentralSquareCrop(grain.MapTransform): 215 | """Makes a square central square crop of a given size.""" 216 | 217 | def map(self, features: FlatFeatures) -> FlatFeatures: 218 | image = features["image"] 219 | h, w = image.shape[:2] 220 | size = min(h, w) 221 | top = (h - size) // 2 222 | left = (w - size) // 2 223 | image = image[top : top + size, left : left + size, :] 224 | features["image"] = image 225 | return features 226 | 227 | 228 | def get_data_shape(config): 229 | return config.data_shape 230 | 231 | 232 | @dataclasses.dataclass(frozen=True) 233 | class DropFeatures(grain.MapTransform): 234 | feature_names: Sequence[str] 235 | 236 | def map(self, features: FlatFeatures) -> FlatFeatures: 237 | for feature_name in self.feature_names: 238 | del features[feature_name] 239 | return features 240 | 241 | 242 | @dataclasses.dataclass 243 | class ParseFeatures(grain.MapTransform): 244 | """Parse serialized example.""" 245 | 246 | def __init__(self, data_column): 247 | self.data_column = data_column 248 | 249 | def map(self, features): 250 | def _parse(example): 251 | parsed = tf.io.parse_example( 252 | example, 253 | { 254 | self.data_column: tf.io.FixedLenSequenceFeature( 255 | [], dtype=tf.int64, allow_missing=True 256 | ) 257 | }, 258 | ) 259 | return parsed 260 | 261 | return _parse(features) 262 | 263 | 264 | def get_num_train_steps(config: config_dict.ConfigDict) -> int: 265 | """Calculates the total number of training steps.""" 266 | if config.num_train_steps > 0: 267 | return config.num_train_steps 268 | # From the beginning. We first shard the data (shard by process_count), then 269 | # combine all epochs, batch for all local devices. 270 | # In all steps we would drop the remainder (hence the use of integer 271 | # division). 272 | # When start_index is 0 the train_ds.cardinality() and num_train_steps should 273 | # be equivalent. 274 | if config.task_type == "image": 275 | tfds_info = tfds.builder(config.dataset).info 276 | num_train_records = tfds_info.splits["train"].num_examples 277 | return int( 278 | num_train_records // jax.process_count() * config.num_epochs 279 | ) // (config.per_device_batch_size * jax.local_device_count()) 280 | else: 281 | raise NotImplementedError() 282 | 283 | 284 | def create_datasets( 285 | config: config_dict.ConfigDict, seed: int 286 | ) -> tuple[_DataSet, dict[str, _DataSet], dict[str, Any]]: 287 | """Create Grain data loaders for training and evaluation. 288 | 289 | Args: 290 | config: Configuration to use. 291 | seed: Seed for shuffle and random operations in the training dataset. 292 | 293 | Returns: 294 | A tuple with the training dataset loader, the evaluation dataset 295 | loader, and a dictionary of other infos. 
296 | """ 297 | info = {} 298 | assert config.batch_size % jax.process_count() == 0 299 | process_batch_size = config.batch_size // jax.process_count() 300 | eval_batch_size = config.get("eval_batch_size", config.batch_size) 301 | process_eval_batch_size = eval_batch_size // jax.process_count() 302 | 303 | if config.dataset == "text8": 304 | seq_len = config.data_shape[0] 305 | # Current train/valid format only support length of 256 306 | assert seq_len == 256 307 | 308 | with tf.io.gfile.GFile( 309 | os.path.join(DATA_DIR, "text8", "text8.zip"), "rb" 310 | ) as f: 311 | rawdata = zipfile.ZipFile(f).read("text8").decode("utf-8") 312 | splits = { 313 | "train": rawdata[:90000000], 314 | "valid": rawdata[90000000:95000000], 315 | "test": rawdata[95000000:], 316 | } 317 | 318 | tokenizer = Text8Tokenizer(num_extra_tokens=0) 319 | train_transformations = [Tokenize(tokenizer)] 320 | train_source = ChunkDataSource( 321 | splits["train"], chunk_size=seq_len, overlapping=True 322 | ) 323 | 324 | eval_transformations = [Tokenize(tokenizer)] 325 | eval_source = { 326 | k: ChunkDataSource(splits[k], chunk_size=seq_len) 327 | for k in ["valid", "test"] 328 | } 329 | info["tokenizer"] = tokenizer 330 | info["rawdata"] = rawdata 331 | elif config.dataset == "openwebtext": 332 | # we need to pretrain a GPT2 size model with context length of 1024. 333 | seq_len = config.data_shape[0] 334 | assert seq_len == 1024 335 | train_transformations = [ParseFeatures(data_column="text")] 336 | eval_transformations = [ParseFeatures(data_column="text")] 337 | 338 | train_table_path = _OWT_DATASETS["dataset_train_path"] 339 | train_source = grain.ArrayRecordDataSource(paths=train_table_path) 340 | 341 | eval_source = { 342 | "owt_eval": grain.ArrayRecordDataSource( 343 | paths=_OWT_DATASETS["dataset_eval_path"] 344 | ), 345 | } 346 | 347 | tokenizer = transformers.GPT2Tokenizer.from_pretrained(_GPT2_TOKENIZER) 348 | info["tokenizer"] = tokenizer 349 | elif ( 350 | config.dataset.startswith("mnist") 351 | or config.dataset.startswith("cifar") 352 | or config.dataset.startswith("downsampled_imagenet") 353 | ): 354 | data_source = tfds.data_source(config.dataset) 355 | train_transformations = [DiscreteWithoutLabel()] 356 | train_source = data_source["train"] 357 | eval_transformations = [DiscreteWithoutLabel()] 358 | eval_source = {k: v for k, v in data_source.items() if k != "train"} 359 | elif config.dataset == "class_cond_imagenet64": 360 | data_source = tfds.data_source("imagenet2012") 361 | train_transformations = [ 362 | CentralSquareCrop(), 363 | ResizeSmall(64), 364 | DropFeatures(("file_name",)), 365 | ] 366 | train_source = data_source["train"] 367 | eval_transformations = [ 368 | CentralSquareCrop(), 369 | ResizeSmall(64), 370 | DropFeatures(("file_name",)), 371 | ] 372 | eval_source = {"validation": data_source["validation"]} 373 | else: 374 | raise NotImplementedError("Unsupported datasets.") 375 | 376 | train_loader = grain.load( 377 | source=train_source, 378 | shuffle=True, 379 | seed=seed, 380 | shard_options=grain.ShardByJaxProcess(drop_remainder=True), 381 | transformations=train_transformations, 382 | batch_size=process_batch_size, 383 | worker_count=config.grain_num_workers, 384 | ) 385 | 386 | if config.eval_pad_last_batch: 387 | raise NotImplementedError( 388 | "BatchWithPadElements is not implemented in PyGrain yet." 
389 | ) 390 | else: 391 | drop_remainder = True 392 | shard_options = grain.ShardByJaxProcess(drop_remainder=drop_remainder) 393 | 394 | eval_loaders = {} 395 | for split in eval_source: 396 | eval_loader = grain.load( 397 | source=eval_source[split], 398 | num_epochs=1, 399 | shard_options=shard_options, 400 | transformations=eval_transformations, 401 | batch_size=process_eval_batch_size, 402 | worker_count=0, 403 | drop_remainder=drop_remainder, 404 | ) 405 | eval_loaders[split] = eval_loader 406 | 407 | return train_loader, eval_loaders, info 408 | -------------------------------------------------------------------------------- /md4/input_pipeline_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Input pipeline for grain based datasets.""" 17 | 18 | from collections.abc import Sequence 19 | import dataclasses 20 | import threading 21 | from typing import Any 22 | 23 | from absl import logging 24 | import datasets 25 | import datasets.distributed 26 | from etils import epath 27 | import grain.python as grain 28 | import jax 29 | from ml_collections import config_dict 30 | import numpy as np 31 | import tensorflow as tf 32 | import tensorflow_datasets as tfds 33 | import transformers 34 | 35 | 36 | _GPT2_TOKENIZER = "gpt2" 37 | _OWT_DATASETS = dict( 38 | # OSS version. 
Please prepare the OWT datasets using the following command:
39 |     # python ./prepare_openwebtext_data.py
40 |     dataset_train_path=("./data_dir/openwebtext_splits_1024_train"),
41 |     dataset_eval_path=("./data_dir/openwebtext_splits_1024_eval"),
42 | )
43 | 
44 | 
45 | class HFDataSource(grain.RandomAccessDataSource):
46 |   """A class that makes HuggingFace IterableDataset a grain datasource without random access support."""
47 | 
48 |   def __init__(
49 |       self,
50 |       dataset: datasets.IterableDataset,
51 |       dataloading_host_index: int,
52 |       dataloading_host_count: int,
53 |       num_threads: int,
54 |       generate_padding_example: bool,
55 |       max_target_length: int,
56 |       data_column_name: str,
57 |   ):
58 |     self.dataset = dataset
59 |     self.num_threads = num_threads
60 |     self.dataloading_host_count = dataloading_host_count
61 |     self.dataloading_host_index = dataloading_host_index
62 |     self.generate_padding_example = generate_padding_example
63 |     self.max_target_length = max_target_length
64 |     self.data_column_name = data_column_name
65 |     self.n_shards = dataset.n_shards
66 |     self._check_shard_count()
67 |     self.dataset_shards = [
68 |         dataloading_host_index * self.num_threads + i
69 |         for i in range(self.num_threads)
70 |     ]
71 |     self.datasets = [
72 |         datasets.distributed.split_dataset_by_node(
73 |             dataset, world_size=self.n_shards, rank=x
74 |         )
75 |         for x in self.dataset_shards
76 |     ]
77 |     self.data_iters = []
78 |     self.out_of_data = False
79 | 
80 |   def _check_shard_count(self):
81 |     if self.n_shards < (self.dataloading_host_count * self.num_threads):
82 |       print(
83 |           f"WARNING: Inefficient dataloading. Your train or eval dataset"
84 |           f" contains {self.n_shards} shards, fewer than the number of hosts"
85 |           " loading data. This is known to lead to inefficient dataloading."
86 |           " See"
87 |           " https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice"
88 |       )
89 |       self.n_shards = self.dataloading_host_count * self.num_threads
90 | 
91 |   def _update_shard(self, idx):
92 |     new_shard = (
93 |         self.dataset_shards[idx]
94 |         + self.dataloading_host_count * self.num_threads
95 |     )
96 |     if new_shard < self.n_shards:
97 |       print(
98 |           f"Updating host {self.dataloading_host_index} dataset {idx}, was on"
99 |           f" shard {self.dataset_shards[idx]}"
100 |       )
101 |       print(f"New shard is {new_shard}")
102 |       self.dataset_shards[idx] = new_shard
103 |       self.datasets[idx] = datasets.distributed.split_dataset_by_node(
104 |           self.dataset, world_size=self.n_shards, rank=self.dataset_shards[idx]
105 |       )
106 |       self.data_iters[idx] = iter(self.datasets[idx])
107 |     else:
108 |       print(
109 |           f"Ran out of shards on host {self.dataloading_host_index}, shard"
110 |           f" {self.dataset_shards[idx]} is not available"
111 |       )
112 |       self.out_of_data = True
113 |       if self.generate_padding_example:
114 |         print(
115 |             f"Host {self.dataloading_host_index} will start generating all-0"
116 |             " padding examples until the step count is met."
117 |         )
118 | 
119 |   def __len__(self):
120 |     """Return length of the HF dataset.
121 | 
122 |     Since HuggingFace IterableDataset does not have a length,
123 |     a fake length bigger than the dataset is returned.
124 |     """
125 |     return 10_000_000_000
126 | 
127 |   def __getitem__(self, index):
128 |     """Returns the next item from the underlying iterator.
129 | 
130 |     HuggingFace IterableDataset does not support random access by index.
131 | 
132 |     Args:
133 |       index: Unused; items are returned in iteration order.
134 | 
135 |     Returns:
136 |       The next item in the iterator.
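      If the underlying shards are exhausted and `generate_padding_example`
      is set, an all-zero example of length `max_target_length` is returned
      instead.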
137 | """ 138 | if not self.data_iters: 139 | self.data_iters = [iter(x) for x in self.datasets] 140 | idx = int(threading.current_thread().name.split("_")[1]) 141 | 142 | while True: 143 | try: 144 | if self.out_of_data: 145 | if self.generate_padding_example: 146 | return { 147 | self.data_column_name: np.zeros( 148 | self.max_target_lenth, dtype=np.int32 149 | ) 150 | } 151 | else: 152 | return None 153 | data = next(self.data_iters[idx]) 154 | return data 155 | except StopIteration: 156 | self._update_shard(idx) 157 | 158 | 159 | @dataclasses.dataclass 160 | class NormalizeFeatures(grain.MapTransform): 161 | """Normalize text feature keys.""" 162 | 163 | def __init__(self, column_name): 164 | self.column_name = column_name 165 | 166 | def map(self, features): 167 | return {"text": features[self.column_name].decode()} 168 | 169 | 170 | @dataclasses.dataclass 171 | class HFNormalizeFeatures(grain.MapTransform): 172 | """Normalize feature keys for HuggingFace input.""" 173 | 174 | def __init__(self, column_name): 175 | self.column_name = column_name 176 | 177 | def map(self, features): 178 | return { 179 | "text": np.asarray(features[self.column_name], dtype=np.int32), 180 | } 181 | 182 | 183 | @dataclasses.dataclass 184 | class ReformatPacking(grain.MapTransform): 185 | """Reformat packing outputs.""" 186 | 187 | def map(self, data): 188 | return { 189 | "text": data[0]["text"], 190 | "text_segmentation": data[1]["text"], 191 | "text_position": data[2]["text"], 192 | } 193 | 194 | 195 | @dataclasses.dataclass 196 | class PadToMaxLength(grain.MapTransform): 197 | """Pads each input to the specified length.""" 198 | 199 | def __init__(self, max_length): 200 | self.max_length = max_length 201 | 202 | def map(self, data): 203 | """map to each element.""" 204 | 205 | def _pad(x, max_length): 206 | pad_amount = max(max_length - x.shape[0], 0) 207 | pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1) 208 | return np.pad(x, pad_amount) 209 | 210 | data["text_segmentation"] = np.ones(data["text"].shape, dtype=np.int32) 211 | data["text_position"] = np.arange(data["text"].shape[0], dtype=np.int32) 212 | for key, _ in data.items(): 213 | data[key] = _pad(data[key], self.max_length) 214 | return data 215 | 216 | 217 | @dataclasses.dataclass 218 | class TokenizeAndTrim(grain.MapTransform): 219 | """Tokenize and trim features to sequence length.""" 220 | 221 | # pylint: disable=attribute-defined-outside-init 222 | feature_names: Sequence[str] 223 | sequence_length: Sequence[int] 224 | tokenizer: Any 225 | add_bos: bool 226 | add_eos: bool 227 | 228 | def map(self, features: dict[str, Any]) -> dict[str, Any]: 229 | """Maps to each element.""" 230 | for feature_name, sequence_length in zip( 231 | self.feature_names, self.sequence_length, strict=True 232 | ): 233 | text = features[feature_name] 234 | token_ids = self.tokenizer(text)["input_ids"] 235 | if self.add_bos: 236 | token_ids = [self.tokenizer.bos_token_id] + token_ids 237 | 238 | if self.add_eos: 239 | token_ids = token_ids[: sequence_length - 1] 240 | token_ids = token_ids + [self.tokenizer.eos_token_id] 241 | else: 242 | token_ids = token_ids[:sequence_length] 243 | 244 | features[feature_name] = np.asarray(token_ids, dtype=np.int32) 245 | return features 246 | 247 | 248 | @dataclasses.dataclass 249 | class ParseFeatures(grain.MapTransform): 250 | """Parse serialized example.""" 251 | 252 | def __init__(self, data_column): 253 | self.data_column = data_column 254 | 255 | def map(self, features): 256 | def _parse(example): 257 | 
parsed = tf.io.parse_example( 258 | example, 259 | { 260 | self.data_column: tf.io.FixedLenSequenceFeature( 261 | [], dtype=tf.int64, allow_missing=True 262 | ) 263 | }, 264 | ) 265 | return parsed 266 | 267 | return _parse(features) 268 | 269 | 270 | 271 | 272 | def tokenization(example, hf_tokenizer, max_length, column_name): 273 | """Tokenize a HuggingFace dataset.""" 274 | return hf_tokenizer( 275 | example[column_name], truncation=True, max_length=max_length 276 | ) 277 | 278 | 279 | def load_fineweb_edu_hf_source(): 280 | """Loads fineweb_edu data source from HuggingFace. 281 | 282 | Returns: 283 | A grain data source for fineweb_edu. 284 | """ 285 | tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2") 286 | fw = datasets.load_dataset( 287 | "HuggingFaceFW/fineweb-edu", name="default", split="train", streaming=True 288 | ) 289 | fw = fw.map( 290 | tokenization, 291 | batched=True, 292 | fn_kwargs={ 293 | "hf_tokenizer": tokenizer, 294 | "max_length": 1023, 295 | "column_name": "text", 296 | }, 297 | ) 298 | fw = fw.select_columns(["input_ids"]).rename_column("input_ids", "text") 299 | return HFDataSource(fw, 0, 1, 1, False, 1024, "text") 300 | 301 | 302 | def compile_transformations( 303 | seq_len, 304 | tokenizer, 305 | data_column="text", 306 | add_bos=False, 307 | add_eos=True, 308 | packing=True, 309 | drop_remainder=True, 310 | process_batch_size=32, 311 | ): 312 | """Collects transformations for the grain input pipeline.""" 313 | # Normalize: convert bytes to string ready for tokenization 314 | operations = [] 315 | operations.append(NormalizeFeatures(data_column)) 316 | operations.append( 317 | TokenizeAndTrim([data_column], [seq_len], tokenizer, add_bos, add_eos) 318 | ) 319 | 320 | # Pack and Batch examples. 321 | if packing: 322 | operations.append( 323 | grain.experimental.PackAndBatchOperation( 324 | batch_size=process_batch_size, 325 | length_struct={data_column: seq_len}, 326 | ) 327 | ) 328 | operations.append(ReformatPacking()) 329 | else: 330 | operations.append(PadToMaxLength(seq_len)) 331 | operations.append( 332 | grain.Batch( 333 | batch_size=process_batch_size, 334 | drop_remainder=drop_remainder, 335 | ) 336 | ) 337 | return operations 338 | 339 | 340 | def compile_hf_transformations( 341 | seq_len, 342 | data_column="text", 343 | process_batch_size=32, 344 | ): 345 | """Collects transformations for the grain input pipeline.""" 346 | operations = [] 347 | operations.append(HFNormalizeFeatures(data_column)) 348 | operations.append( 349 | grain.experimental.PackAndBatchOperation( 350 | batch_size=process_batch_size, 351 | length_struct={data_column: seq_len}, 352 | ) 353 | ) 354 | operations.append(ReformatPacking()) 355 | return operations 356 | 357 | 358 | def create_datasets( 359 | config: config_dict.ConfigDict, seed: int 360 | ) -> tuple[grain.DataLoader, dict[str, grain.DataLoader], dict[str, Any]]: 361 | """Create Grain data loaders for training and evaluation. 362 | 363 | For the same seed and config this will return the same datasets. 364 | The user is responsible to save()/load() the dataset iterators (for training) 365 | or calling reset() to restart the iterator (for eval). 366 | 367 | Args: 368 | config: Configuration to use. 369 | seed: Seed for shuffle and random operations in the training dataset. 370 | 371 | Returns: 372 | A tuple with the training dataset loader, the evaluation dataset 373 | loader, and a dictionary of other infos. 
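  Currently only config.dataset == "fineweb_edu" is supported; training data
  is streamed from HuggingFace, and the OpenWebText ArrayRecord split is used
  for evaluation.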
374 | """ 375 | info = {} 376 | assert config.batch_size % jax.process_count() == 0 377 | process_batch_size = config.batch_size // jax.process_count() 378 | eval_batch_size = config.get("eval_batch_size", config.batch_size) 379 | process_eval_batch_size = eval_batch_size // jax.process_count() 380 | 381 | if config.dataset == "fineweb_edu": 382 | # we need to pretrain a GPT2 size model with context length of 1024. 383 | seq_len = config.data_shape[0] 384 | assert seq_len == 1024 385 | 386 | tokenizer = transformers.GPT2Tokenizer.from_pretrained(_GPT2_TOKENIZER) 387 | 388 | load_fineweb_edu_source = load_fineweb_edu_hf_source 389 | train_source = load_fineweb_edu_source() 390 | transformations = compile_hf_transformations( 391 | seq_len, 392 | data_column="text", 393 | process_batch_size=process_batch_size, 394 | ) 395 | eval_sources = { 396 | "owt_eval": grain.ArrayRecordDataSource( 397 | paths=_OWT_DATASETS["dataset_eval_path"] 398 | ), 399 | # "fwe_eval": eval_source, 400 | } 401 | eval_transformations = { 402 | "owt_eval": [ 403 | ParseFeatures(data_column="text"), 404 | grain.Batch( 405 | batch_size=process_eval_batch_size, 406 | drop_remainder=True, 407 | ), 408 | ], 409 | # "fwe_eval": transformations, 410 | } 411 | info["tokenizer"] = tokenizer 412 | 413 | else: 414 | raise NotImplementedError("Unsupported datasets.") 415 | index_sampler = grain.IndexSampler( 416 | num_records=len(train_source), 417 | shard_options=grain.ShardByJaxProcess(drop_remainder=True), 418 | shuffle=True, 419 | seed=seed, 420 | ) 421 | train_loader = grain.DataLoader( 422 | data_source=train_source, 423 | operations=transformations, 424 | sampler=index_sampler, 425 | worker_count=config.grain_num_workers, 426 | worker_buffer_size=1, 427 | read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=1024), 428 | ) 429 | 430 | if config.eval_pad_last_batch: 431 | raise NotImplementedError( 432 | "BatchWithPadElements is not implemented in PyGrain yet." 433 | ) 434 | else: 435 | drop_remainder = True 436 | shard_options = grain.ShardByJaxProcess(drop_remainder=drop_remainder) 437 | 438 | eval_loaders = {} 439 | for split in eval_sources: 440 | eval_loader = grain.load( 441 | source=eval_sources[split], 442 | num_epochs=1, 443 | shard_options=shard_options, 444 | transformations=eval_transformations[split], 445 | # For now, we do not parallelize the evaluation, because there is a 446 | # bug on DataLoader.__iter__ when used with Jax. 447 | worker_count=0, 448 | read_options=grain.ReadOptions(prefetch_buffer_size=1024), 449 | ) 450 | eval_loaders[split] = eval_loader 451 | 452 | return train_loader, eval_loaders, info 453 | -------------------------------------------------------------------------------- /md4/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | 
16 | """Main file for running the example.
17 | 
18 | This file is intentionally kept short. The majority of the logic is in
19 | libraries that can be easily tested and imported in Colab.
20 | """
21 | 
22 | from absl import app
23 | from absl import flags
24 | from absl import logging
25 | # Required import to set up work units when running through XManager.
26 | from clu import platform
27 | import jax
28 | from ml_collections import config_flags
29 | import tensorflow.compat.v2 as tf
30 | 
31 | from md4 import sharded_train
32 | from md4 import train
33 | 
34 | 
35 | FLAGS = flags.FLAGS
36 | 
37 | config_flags.DEFINE_config_file(
38 |     "config", None, "Training configuration.", lock_config=True)
39 | flags.DEFINE_string("workdir", None, "Work unit directory.")
40 | flags.DEFINE_boolean("sharded", False, "Whether to use sharded training.")
41 | flags.mark_flags_as_required(["config", "workdir"])
42 | # Flags --jax_backend_target and --jax_xla_backend are available through JAX.
43 | 
44 | 
45 | def main(argv):
46 |   del argv
47 | 
48 |   tf.enable_v2_behavior()
49 |   # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
50 |   # it unavailable to JAX.
51 |   tf.config.experimental.set_visible_devices([], "GPU")
52 | 
53 |   if FLAGS.jax_backend_target:
54 |     logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
55 |     jax_xla_backend = ("None" if FLAGS.jax_xla_backend is None else
56 |                        FLAGS.jax_xla_backend)
57 |     logging.info("Using JAX XLA backend %s", jax_xla_backend)
58 | 
59 |   logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
60 |   logging.info("JAX devices: %r", jax.devices())
61 | 
62 |   platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
63 |                                        f"process_count: {jax.process_count()}")
64 |   platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
65 |                                        FLAGS.workdir, "workdir")
66 | 
67 |   if FLAGS.sharded:
68 |     sharded_train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
69 |   else:
70 |     train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
71 | 
72 | 
73 | if __name__ == "__main__":
74 |   # Provide access to --jax_backend_target and --jax_xla_backend flags.
75 |   jax.config.config_with_absl()
76 |   run_main = app.run
77 |   run_main(main)
--------------------------------------------------------------------------------
/md4/models/backward.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Classifier implementation.""" 17 | 18 | from collections.abc import Sequence 19 | 20 | import flax.linen as nn 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | from md4.networks import sharded_transformer 25 | from md4.networks import transformer 26 | from md4.networks import unet 27 | from md4.networks import uvit 28 | 29 | 30 | def get_timestep_embedding(timesteps, embedding_dim, dtype='float'): 31 | """Build sinusoidal embeddings.""" 32 | 33 | assert embedding_dim > 2 34 | # timesteps: [bs] 35 | half_dim = embedding_dim // 2 36 | emb = jnp.log(10_000) / (half_dim - 1) 37 | emb = jnp.exp(jnp.arange(half_dim, dtype='float32') * -emb) 38 | emb = timesteps.astype('float32')[:, None] * emb[None, :] 39 | emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1) 40 | if embedding_dim % 2 == 1: # zero pad 41 | emb = jax.lax.pad(emb, jnp.array(0, dtype), ((0, 0, 0), (0, 1, 0))) 42 | # ret: [bs, embedding_dim] 43 | return emb 44 | 45 | 46 | class CondEmbedding(nn.Module): 47 | """Time and cond embeddings.""" 48 | 49 | embedding_dim: int = 256 50 | 51 | @nn.compact 52 | def __call__(self, t, cond=None): 53 | # t: [bs] 54 | n_embd = self.embedding_dim 55 | temb = get_timestep_embedding(t, n_embd) 56 | if cond is None: 57 | cond = temb 58 | else: 59 | cond = jnp.concatenate([temb, cond], axis=-1) 60 | cond = nn.swish(nn.Dense(features=n_embd * 4, name='dense0')(cond)) 61 | cond = nn.Dense(n_embd)(cond) 62 | return cond 63 | 64 | 65 | class UNet5DWrapper(nn.Module): 66 | """5D to 5D UNet wrapper.""" 67 | 68 | feature_dim: int = 128 69 | n_layers: int = 32 70 | n_dit_layers: int = 0 71 | dit_num_heads: int = 12 72 | dit_hidden_size: int = 768 73 | ch_mult: Sequence[int] = (1,) 74 | output_channels: int = 256 75 | dropout_rate: float = 0.1 76 | 77 | @nn.compact 78 | def __call__(self, z, cond=None, train=False): 79 | # [bs, h, w, c, d or |V|] -> [bs, h, w, c, d or |V|] 80 | # Flatten the last two dimensions to pass to UNet 81 | h = z.reshape(list(z.shape)[:-2] + [-1]) 82 | 83 | if self.n_dit_layers > 0: 84 | h = uvit.UNet( 85 | d_channels=self.feature_dim, 86 | n_layers=self.n_layers, 87 | n_dit_layers=self.n_dit_layers, 88 | dit_num_heads=self.dit_num_heads, 89 | dit_hidden_size=self.dit_hidden_size, 90 | ch_mult=self.ch_mult, 91 | output_channels=self.output_channels * z.shape[-2], 92 | dropout_rate=self.dropout_rate, 93 | )(h, cond=cond, train=train) 94 | else: 95 | h = unet.UNet( 96 | d_channels=self.feature_dim, 97 | n_layers=self.n_layers, 98 | output_channels=self.output_channels * z.shape[-2], 99 | dropout_rate=self.dropout_rate, 100 | )(h, cond=cond, train=train) 101 | 102 | # ret: [bs, h, w, c, output_channels] 103 | return h.reshape(list(z.shape)[:-1] + [self.output_channels]) 104 | 105 | 106 | class DiscreteClassifier(nn.Module): 107 | """Discrete input classifier implementation.""" 108 | 109 | n_layers: int = 12 110 | n_dit_layers: int = 0 111 | dit_num_heads: int = 12 112 | dit_hidden_size: int = 768 113 | ch_mult: Sequence[int] = (1,) 114 | feature_dim: int = 64 115 | num_heads: int = 12 116 | vocab_size: int = 1000 117 | dropout_rate: float = 0.0 118 | use_attn_dropout: bool = True 119 | mlp_type: str = 'swiglu' 120 | depth_scaled_init: bool = False 121 | cond_type: str = 'adaln' 122 | outside_embed: bool = False 123 | model_sharding: bool = False 124 | 125 | @nn.compact 126 | def __call__(self, z, t=None, cond=None, train=False): 127 | if t is not None: 128 | # z: [bs, seq_len] or 
[bs, h, w, c] 129 | assert jnp.isscalar(t) or t.ndim == 0 or t.ndim == 1 130 | t = t * jnp.ones(z.shape[0]) # ensure t is a vector 131 | cond = CondEmbedding(self.feature_dim)(t * 1000, cond=cond) 132 | 133 | if z.ndim == 2: 134 | if self.outside_embed: 135 | z = nn.Embed(self.vocab_size + 1, self.feature_dim)(z) 136 | if self.model_sharding: 137 | args = sharded_transformer.ModelArgs( 138 | dim=self.feature_dim * self.num_heads, 139 | n_layers=self.n_layers, 140 | n_heads=self.num_heads, 141 | n_kv_heads=self.num_heads, 142 | output_channels=self.vocab_size, 143 | multiple_of=32, 144 | dropout_rate=self.dropout_rate, 145 | depth_scaled_init=self.depth_scaled_init, 146 | mlp_type=self.mlp_type, 147 | cond_type=self.cond_type, 148 | embed_input=not self.outside_embed, 149 | n_embed_classes=self.vocab_size + 1, 150 | use_attn_dropout=self.use_attn_dropout, 151 | ) 152 | # [bs, seq_len] -> [bs, seq_len, |V|] 153 | net = sharded_transformer.Transformer(args) 154 | else: 155 | args = transformer.ModelArgs( 156 | dim=self.feature_dim * self.num_heads, 157 | n_layers=self.n_layers, 158 | n_heads=self.num_heads, 159 | n_kv_heads=self.num_heads, 160 | output_channels=self.vocab_size, 161 | multiple_of=32, 162 | dropout_rate=self.dropout_rate, 163 | depth_scaled_init=self.depth_scaled_init, 164 | mlp_type=self.mlp_type, 165 | cond_type=self.cond_type, 166 | embed_input=not self.outside_embed, 167 | n_embed_classes=self.vocab_size + 1, 168 | ) 169 | # [bs, seq_len] -> [bs, seq_len, |V|] 170 | net = transformer.Transformer(args) 171 | logits = net(z, cond=cond, train=train) 172 | elif z.ndim == 4: 173 | z = nn.Embed(self.vocab_size + 1, self.feature_dim)(z) 174 | 175 | # [bs, h, w, c, d] -> [bs, h, w, c, |V|] 176 | net = UNet5DWrapper( 177 | feature_dim=self.feature_dim, 178 | n_layers=self.n_layers, 179 | n_dit_layers=self.n_dit_layers, 180 | dit_num_heads=self.dit_num_heads, 181 | dit_hidden_size=self.dit_hidden_size, 182 | ch_mult=self.ch_mult, 183 | output_channels=self.vocab_size, 184 | dropout_rate=self.dropout_rate, 185 | ) 186 | logits = net(z, cond=cond, train=train) 187 | else: 188 | raise NotImplementedError() 189 | 190 | return logits, {} 191 | -------------------------------------------------------------------------------- /md4/models/diffusion/genmd4.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | 
16 | """Generalized state-dependent masked diffusion (GenMD4)."""
17 | 
18 | from typing import Sequence
19 | 
20 | import flax.linen as nn
21 | import jax
22 | import jax.numpy as jnp
23 | import tensorflow_probability.substrates.jax as tfp
24 | 
25 | from md4 import utils
26 | from md4.models import backward
27 | 
28 | 
29 | tfd = tfp.distributions
30 | 
31 | 
32 | class LearnableVecMaskingSchedule(nn.Module):
33 |   """Learnable vector-valued masking schedule for GenMD4."""
34 | 
35 |   data_shape: tuple[int, ...]
36 |   schedule_fn_type: str = 'poly'
37 |   vocab_size: int = 256
38 |   eps: float = 1e-4
39 |   power_init: float = 1.0
40 | 
41 |   def setup(self):
42 |     if self.schedule_fn_type == 'poly':
43 |       w_init = jnp.log(jnp.exp(self.power_init) - 1.0)
44 |       self.w = self.param('w', utils.constant_init(w_init), [self.vocab_size])
45 |       # Reduce to MD4 with a shared scalar schedule:
46 |       # self.w = self.param('w', utils.constant_init(w_init), [])
47 |       # self.power = jnp.tile(nn.softplus(self.w)[..., None], [self.vocab_size])
48 |       self.power = nn.softplus(self.w)
49 |     else:
50 |       raise NotImplementedError()
51 | 
52 |   def __call__(self, t):
53 |     # return logSNR
54 |     return jnp.log(self.alpha(t) / (1.0 - self.alpha(t)))
55 | 
56 |   def dalpha(self, t):
57 |     if self.schedule_fn_type == 'poly':
58 |       # ret: [..., |V|]
59 |       return (
60 |           -(1.0 - self.eps)
61 |           * self.power
62 |           * jnp.array(t)[..., None] ** (self.power - 1.0)
63 |       )
64 |     else:
65 |       raise NotImplementedError()
66 | 
67 |   def alpha(self, t):
68 |     if self.schedule_fn_type == 'poly':
69 |       # Instead of offsetting alpha_0 by eps as in the MD4 class, we set
70 |       # alpha_0 = 1 and use a small non-zero t1 to avoid numerical issues.
71 |       # This gives a nicer form of dgamma_times_alpha, which is -w/t for a
72 |       # polynomial schedule 1 - t**w.
73 |       # ret: [..., |V|]
74 |       return 1.0 - (1.0 - self.eps) * jnp.array(t)[..., None] ** self.power
75 |     else:
76 |       raise NotImplementedError()
77 | 
78 |   def dgamma_times_alpha(self, t):
79 |     # ret: [..., |V|]
80 |     return -self.power / jnp.array(t)[..., None]
81 | 
82 | 
83 | class GenMD4(nn.Module):
84 |   """Generalized state-dependent masked discrete diffusion model."""
85 | 
86 |   data_shape: tuple[int, ...]
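  # Shape of one example, e.g. (1024,) for OpenWebText token sequences or
  # (256,) for text8 (see md4/configs).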
87 | cont_time: bool = False 88 | timesteps: int = 1000 89 | feature_dim: int = 128 90 | num_heads: int = 12 91 | antithetic_time_sampling: bool = True 92 | n_layers: int = 32 93 | n_dit_layers: int = 0 94 | dit_num_heads: int = 12 95 | dit_hidden_size: int = 768 96 | ch_mult: Sequence[int] = (1,) 97 | vocab_size: int = 256 98 | noise_schedule_type: str = 'poly' 99 | power_init: float = 1.0 100 | t1: float = 1e-3 101 | dropout_rate: float = 0.0 102 | use_attn_dropout: bool = True 103 | mlp_type: str = 'swiglu' 104 | depth_scaled_init: bool = False 105 | cond_type: str = 'adaln' 106 | outside_embed: bool = False 107 | # time_features: t or none 108 | time_features: str = 't' 109 | classes: int = 10 + 1 # image classes 110 | 111 | def setup(self): 112 | self.noise_schedule = LearnableVecMaskingSchedule( 113 | self.data_shape, 114 | schedule_fn_type=self.noise_schedule_type, 115 | vocab_size=self.vocab_size, 116 | power_init=self.power_init, 117 | ) 118 | 119 | if self.classes > 0: 120 | self.cond_embeddings = nn.Embed(self.classes, self.feature_dim) 121 | self.classifier = backward.DiscreteClassifier( 122 | n_layers=self.n_layers, 123 | n_dit_layers=self.n_dit_layers, 124 | dit_num_heads=self.dit_num_heads, 125 | dit_hidden_size=self.dit_hidden_size, 126 | ch_mult=self.ch_mult, 127 | feature_dim=self.feature_dim, 128 | num_heads=self.num_heads, 129 | vocab_size=self.vocab_size, 130 | dropout_rate=self.dropout_rate, 131 | use_attn_dropout=self.use_attn_dropout, 132 | mlp_type=self.mlp_type, 133 | depth_scaled_init=self.depth_scaled_init, 134 | cond_type=self.cond_type, 135 | outside_embed=self.outside_embed, 136 | ) 137 | 138 | def forward_sample(self, x, t): 139 | t = utils.reverse_broadcast(t, x.ndim) 140 | # alpha_t: [bs, 1, |V|] or [bs, 1, 1, 1, |V|] 141 | a = self.noise_schedule.alpha(t) 142 | # un_mask_p: [bs, seq_len] or [bs, h, w, c] 143 | un_mask_p = jnp.sum(a * nn.one_hot(x, self.vocab_size), axis=-1) 144 | un_mask = jax.random.bernoulli(self.make_rng('sample'), un_mask_p, x.shape) 145 | # MASK = vocab_size 146 | return jnp.where(un_mask, x, self.vocab_size) 147 | 148 | def prior_sample(self, batch_size): 149 | return self.vocab_size * jnp.ones( 150 | [batch_size] + list(self.data_shape), dtype='int32' 151 | ) 152 | 153 | def get_cond_embedding(self, conditioning): 154 | if conditioning is not None: 155 | return self.cond_embeddings(conditioning) 156 | return None 157 | 158 | def predict_x(self, zt, t, cond=None, train=False): 159 | t = None if self.time_features == 'none' else t 160 | return self.classifier(zt, t=t, cond=cond, train=train) 161 | 162 | def visualize_classifier(self, x, t, conditioning=None): 163 | # if it's image, x: [bs, h, w, c] 164 | # if it's text, x: [bs, seq_len] 165 | cond = self.get_cond_embedding(conditioning) 166 | # t: [] 167 | # if it's image, zt: [bs, h, w, c] 168 | # if it's text, zt: [bs, seq_len] 169 | zt = self.forward_sample(x, t) 170 | # logits: [bs, h, w, c, vocab_size] for images 171 | # [bs, seq_len, vocab_size] for text 172 | logits, _ = self.predict_x(zt, t, cond=cond) 173 | n_indep_axes = logits.ndim - 2 174 | dist = tfd.Independent(tfd.Categorical(logits=logits), n_indep_axes) 175 | return dist 176 | 177 | def encode(self, x, conditioning=None): 178 | del conditioning 179 | return x 180 | 181 | def recon_loss(self, x): 182 | """The reconstruction loss measures the gap in the first step.""" 183 | assert self.noise_schedule_type == 'poly' 184 | eps = self.noise_schedule.eps 185 | # w: [|V|] 186 | w = self.noise_schedule.power 187 | # w_x: 
[bs, seq_len] or [bs, h, w, c] 188 | w_x = jnp.sum(w * nn.one_hot(x, self.vocab_size), axis=-1) 189 | t = jnp.array(self.t1) 190 | # wlogt_x: [bs, seq_len] or [bs, h, w, c] 191 | wlogt_x = w_x * jnp.log(t) 192 | # wlogt: [|V|] 193 | wlogt = w * jnp.log(t) 194 | remaining_axis = list(range(x.ndim)[1:]) 195 | # loss_recon: [bs, seq_len] or [bs, h, w, c] 196 | loss_recon = ( 197 | -(1 - eps) * jnp.exp(wlogt_x) * (wlogt_x - nn.logsumexp(wlogt, -1)) 198 | ).sum(remaining_axis) 199 | return loss_recon 200 | 201 | def latent_loss(self): 202 | # negligible 203 | return jnp.array(0.0) 204 | 205 | def diffusion_loss(self, t, x, cond=None, train=False): 206 | assert self.cont_time 207 | 208 | # sample z_t 209 | zt = self.forward_sample(x, t) 210 | logits, _ = self.predict_x(zt, t, cond=cond, train=train) 211 | log_p = jax.nn.log_softmax(logits, axis=-1) 212 | one_hot_x = jax.nn.one_hot(x, self.vocab_size) 213 | neg_cross_ent = one_hot_x * log_p 214 | neg_cross_ent = jnp.where(one_hot_x, neg_cross_ent, 0.0) 215 | neg_cross_ent = jnp.sum(neg_cross_ent, axis=-1, keepdims=True) 216 | integrand = (neg_cross_ent + 1.0) * one_hot_x - jnp.exp(log_p) 217 | mask = (zt == self.vocab_size).astype('float') 218 | 219 | remaining_axis = list(range(x.ndim)[1:]) 220 | # masked_neg_cross_ent: [bs, |V|] 221 | masked_neg_cross_ent = jnp.sum(mask[..., None] * integrand, remaining_axis) 222 | 223 | # cont-time loss 224 | loss_diff = ( 225 | self.noise_schedule.dgamma_times_alpha(t) * masked_neg_cross_ent 226 | ).sum(axis=-1) 227 | 228 | # loss_diff: [bs] 229 | return loss_diff, zt 230 | 231 | def reinforce_loss(self, t, x, zt_1, zt_2, loss_diff_1, loss_diff_2): 232 | assert self.noise_schedule_type == 'poly' 233 | eps = self.noise_schedule.eps 234 | # w: [|V|] 235 | w = self.noise_schedule.power 236 | # w_x: [bs, seq_len] or [bs, h, w, c] 237 | w_x = jnp.sum(w * nn.one_hot(x, self.vocab_size), axis=-1) 238 | # t: [bs, 1] or [bs, 1, 1, 1] 239 | t = utils.reverse_broadcast(t, x.ndim) 240 | # alpha_t_x: [bs, seq_len] or [bs, h, w, c] 241 | alpha_t_x = 1.0 - (1.0 - eps) * t**w_x 242 | # log_q_mask = jnp.log(1.0 - alpha_t_x) 243 | log_q_mask = jnp.log(1.0 - eps) + w_x * jnp.log(t) 244 | log_q_unmask = jnp.log(alpha_t_x) 245 | log_q1 = jnp.where(zt_1 == self.vocab_size, log_q_mask, log_q_unmask) 246 | log_q2 = jnp.where(zt_2 == self.vocab_size, log_q_mask, log_q_unmask) 247 | remaining_axis = list(range(x.ndim)[1:]) 248 | rloo_1 = ( 249 | 0.5 250 | * jax.lax.stop_gradient(loss_diff_1 - loss_diff_2) 251 | * (log_q1.sum(remaining_axis) - log_q2.sum(remaining_axis)) 252 | ) 253 | return rloo_1 254 | 255 | @nn.compact 256 | def __call__(self, x, cond=None, train=False): 257 | bs = x.shape[0] 258 | cond = self.get_cond_embedding(cond) 259 | 260 | # 1. RECONSTRUCTION LOSS: [] 261 | # add noise and reconstruct 262 | loss_recon = self.recon_loss(x).mean() 263 | 264 | # 2. LATENT LOSS: [] 265 | loss_prior = self.latent_loss() 266 | 267 | # 3. 
DIFFUSION LOSS: [bs] 268 | # sample time steps 269 | rng1 = self.make_rng('sample') 270 | if self.antithetic_time_sampling: 271 | t0 = jax.random.uniform(rng1) 272 | t = jnp.mod(t0 + jnp.arange(0.0, 1.0, step=1.0 / bs), 1.0) 273 | else: 274 | t = jax.random.uniform(rng1, shape=[bs]) 275 | # rescale t to be in [t1, 1.0] 276 | t = (1 - self.t1) * t + self.t1 277 | 278 | loss_diff_1, zt_1 = self.diffusion_loss(t, x, cond=cond, train=train) 279 | loss_diff_2, zt_2 = self.diffusion_loss(t, x, cond=cond, train=train) 280 | rloo_1 = self.reinforce_loss(t, x, zt_1, zt_2, loss_diff_1, loss_diff_2) 281 | loss_diff = 0.5 * (loss_diff_1 + loss_diff_2) 282 | loss_diff_sg = loss_diff + rloo_1 283 | 284 | # surrogate loss that includes the reinforce term 285 | loss = loss_diff_sg.mean() + loss_prior + loss_recon 286 | loss_diff = loss_diff.mean() 287 | # negative elbo 288 | loss_nelbo = loss_diff + loss_prior + loss_recon 289 | 290 | model_stats = { 291 | 'loss': loss, 292 | 'loss_nelbo': loss_nelbo, 293 | 'loss_diff': loss_diff, 294 | 'loss_prior': loss_prior, 295 | 'loss_recon': loss_recon, 296 | 'power_max': self.noise_schedule.power.max(), 297 | 'power_min': self.noise_schedule.power.min(), 298 | 'power_avg': self.noise_schedule.power.mean(), 299 | } 300 | model_stats = utils.loss2bpt(model_stats, self.data_shape) 301 | return model_stats 302 | -------------------------------------------------------------------------------- /md4/models/diffusion/md4.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Simplified masked diffusion (MD4).""" 17 | 18 | import math 19 | from typing import Sequence 20 | 21 | import flax.linen as nn 22 | import jax 23 | import jax.numpy as jnp 24 | import tensorflow_probability.substrates.jax as tfp 25 | 26 | from md4 import binary_search 27 | from md4 import utils 28 | from md4.models import backward 29 | 30 | 31 | tfd = tfp.distributions 32 | 33 | 34 | class MaskingSchedule(nn.Module): 35 | """Masking noise schedule.""" 36 | 37 | data_shape: tuple[int, ...] 
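  # One of 'linear', 'cosine', or 'poly<exponent>'; e.g. 'poly3' gives
  # alpha(t) = 1 - t**3 (up to the eps rescaling in alpha()).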
38 | schedule_fn_type: str = 'cosine' 39 | eps: float = 1e-4 40 | 41 | def __call__(self, t): 42 | # return logSNR 43 | return jnp.log(self.alpha(t) / (1.0 - self.alpha(t))) 44 | 45 | def _dalpha(self, t): 46 | if self.schedule_fn_type == 'cosine': 47 | return -math.pi / 2.0 * jax.lax.sin(math.pi / 2.0 * (1.0 - t)) 48 | elif self.schedule_fn_type == 'linear': 49 | return -jnp.ones_like(t) 50 | elif 'poly' in self.schedule_fn_type: 51 | exponent = float(self.schedule_fn_type.replace('poly', '')) 52 | return -exponent * t ** (exponent - 1.0) 53 | else: 54 | raise NotImplementedError() 55 | 56 | def dalpha(self, t): 57 | return (1.0 - 2 * self.eps) * self._dalpha(t) 58 | 59 | def _alpha(self, t): 60 | if self.schedule_fn_type == 'linear': 61 | return 1.0 - t 62 | elif 'poly' in self.schedule_fn_type: 63 | exponent = float(self.schedule_fn_type.replace('poly', '')) 64 | return 1.0 - t**exponent 65 | elif self.schedule_fn_type == 'cosine': 66 | return 1.0 - jax.lax.cos(math.pi / 2.0 * (1.0 - t)) 67 | else: 68 | raise NotImplementedError() 69 | 70 | def alpha(self, t): 71 | return (1.0 - 2 * self.eps) * self._alpha(t) + self.eps 72 | 73 | def dgamma_times_alpha(self, t): 74 | return self.dalpha(t) / (1.0 - self.alpha(t)) 75 | 76 | 77 | class MD4(nn.Module): 78 | """Simplified masked discrete diffusion model.""" 79 | 80 | data_shape: tuple[int, ...] 81 | cont_time: bool = False 82 | timesteps: int = 1000 83 | feature_dim: int = 128 84 | num_heads: int = 12 85 | antithetic_time_sampling: bool = True 86 | n_layers: int = 32 87 | n_dit_layers: int = 0 88 | dit_num_heads: int = 12 89 | dit_hidden_size: int = 768 90 | ch_mult: Sequence[int] = (1,) 91 | vocab_size: int = 256 92 | noise_schedule_type: str = 'linear' 93 | dropout_rate: float = 0.0 94 | use_attn_dropout: bool = True 95 | mlp_type: str = 'swiglu' 96 | depth_scaled_init: bool = False 97 | cond_type: str = 'adaln' 98 | outside_embed: bool = False 99 | # time_features: t or none 100 | time_features: str = 't' 101 | classes: int = 10 + 1 # image classes 102 | sampler: str = 'analytic' 103 | # uniform, cosine 104 | sampling_grid: str = 'cosine' 105 | topp: float = 0.98 106 | model_sharding: bool = False 107 | 108 | def setup(self): 109 | self.noise_schedule = MaskingSchedule( 110 | self.data_shape, self.noise_schedule_type 111 | ) 112 | 113 | if self.classes > 0: 114 | self.cond_embeddings = nn.Embed(self.classes, self.feature_dim) 115 | self.classifier = backward.DiscreteClassifier( 116 | n_layers=self.n_layers, 117 | n_dit_layers=self.n_dit_layers, 118 | dit_num_heads=self.dit_num_heads, 119 | dit_hidden_size=self.dit_hidden_size, 120 | ch_mult=self.ch_mult, 121 | feature_dim=self.feature_dim, 122 | num_heads=self.num_heads, 123 | vocab_size=self.vocab_size, 124 | dropout_rate=self.dropout_rate, 125 | use_attn_dropout=self.use_attn_dropout, 126 | mlp_type=self.mlp_type, 127 | depth_scaled_init=self.depth_scaled_init, 128 | cond_type=self.cond_type, 129 | outside_embed=self.outside_embed, 130 | model_sharding=self.model_sharding, 131 | ) 132 | 133 | def forward_sample(self, x, t): 134 | t = utils.reverse_broadcast(t, x.ndim) 135 | a = self.noise_schedule.alpha(t) 136 | un_mask = jax.random.bernoulli(self.make_rng('sample'), a, x.shape) 137 | # MASK = vocab_size 138 | return jnp.where(un_mask, x, self.vocab_size) 139 | 140 | def prior_sample(self, batch_size): 141 | return self.vocab_size * jnp.ones( 142 | [batch_size] + list(self.data_shape), dtype='int32' 143 | ) 144 | 145 | def get_cond_embedding(self, conditioning): 146 | if conditioning is 
not None: 147 | return self.cond_embeddings(conditioning) 148 | return None 149 | 150 | def predict_x(self, zt, t, cond=None, train=False): 151 | t = None if self.time_features == 'none' else t 152 | return self.classifier(zt, t=t, cond=cond, train=train) 153 | 154 | def visualize_classifier(self, x, t, conditioning=None): 155 | # if it's image, x: [bs, h, w, c] 156 | # if it's text, x: [bs, seq_len] 157 | cond = self.get_cond_embedding(conditioning) 158 | # t: [] 159 | # if it's image, zt: [bs, h, w, c] 160 | # if it's text, zt: [bs, seq_len] 161 | zt = self.forward_sample(x, t) 162 | # logits: [bs, h, w, c, vocab_size] for images 163 | # [bs, seq_len, vocab_size] for text 164 | logits, _ = self.predict_x(zt, t, cond=cond) 165 | n_indep_axes = logits.ndim - 2 166 | dist = tfd.Independent(tfd.Categorical(logits=logits), n_indep_axes) 167 | return dist 168 | 169 | def encode(self, x, conditioning=None): 170 | del conditioning 171 | return x 172 | 173 | def decode(self, z0, conditioning=None): 174 | # Remove any mask tokens left in the last step of sampling. 175 | masked = z0 == self.vocab_size 176 | z0_cliped = jnp.where(masked, jnp.zeros_like(z0), z0) 177 | masked = masked[..., None] 178 | cond = self.get_cond_embedding(conditioning) 179 | logits, _ = self.predict_x(z0, jnp.array(0.0), cond=cond) 180 | probs = jnp.where( 181 | masked, 182 | nn.softmax(logits, axis=-1), 183 | jax.nn.one_hot(z0_cliped, self.vocab_size), 184 | ) 185 | n_indep_axes = probs.ndim - 2 186 | dist = tfd.Independent(tfd.Categorical(probs=probs), n_indep_axes) 187 | return dist.mode().astype('int32') 188 | 189 | def recon_loss(self): 190 | """The reconstruction loss measures the gap in the first step.""" 191 | alpha_t1 = self.noise_schedule.alpha(0.0) 192 | loss_recon = ( 193 | jnp.prod(jnp.array(self.data_shape)) 194 | * (1.0 - alpha_t1) 195 | * jnp.log(self.vocab_size) 196 | ) 197 | return loss_recon 198 | 199 | def latent_loss(self): 200 | # negligible 201 | return jnp.array(0.0) 202 | 203 | def diffusion_loss(self, t, x, cond=None, train=False): 204 | if not self.cont_time: 205 | # discretize time steps 206 | t = (jnp.floor(t * self.timesteps) + 1) / self.timesteps 207 | 208 | # sample z_t 209 | zt = self.forward_sample(x, t) 210 | logits, _ = self.predict_x(zt, t, cond=cond, train=train) 211 | log_p = jax.nn.log_softmax(logits, axis=-1) 212 | one_hot_x = jax.nn.one_hot(x, self.vocab_size) 213 | neg_cross_ent = one_hot_x * log_p 214 | neg_cross_ent = jnp.where(one_hot_x, neg_cross_ent, 0.0) 215 | neg_cross_ent = jnp.sum(neg_cross_ent, axis=-1) 216 | mask = (zt == self.vocab_size).astype('float32') 217 | 218 | remaining_axis = list(range(x.ndim)[1:]) 219 | # masked_neg_cross_ent: [bs] 220 | masked_neg_cross_ent = jnp.sum(mask * neg_cross_ent, remaining_axis) 221 | 222 | if not self.cont_time: 223 | # loss for finite depth T, i.e. discrete time 224 | s = t - (1.0 / self.timesteps) 225 | gt = self.noise_schedule(t) 226 | gs = self.noise_schedule(s) 227 | loss_diff = ( 228 | self.timesteps 229 | * jnp.expm1(gt - gs) 230 | * self.noise_schedule.alpha(s) 231 | * masked_neg_cross_ent 232 | ) 233 | else: 234 | # cont-time loss 235 | loss_diff = ( 236 | self.noise_schedule.dgamma_times_alpha(t) * masked_neg_cross_ent 237 | ) 238 | 239 | # loss_diff: [bs] 240 | return loss_diff 241 | 242 | @nn.compact 243 | def __call__(self, x, cond=None, train=False): 244 | bs = x.shape[0] 245 | cond = self.get_cond_embedding(cond) 246 | 247 | # 1. 
RECONSTRUCTION LOSS: [] 248 | # add noise and reconstruct 249 | loss_recon = self.recon_loss() 250 | 251 | # 2. LATENT LOSS: [] 252 | loss_prior = self.latent_loss() 253 | 254 | # 3. DIFFUSION LOSS: [bs] 255 | # sample time steps 256 | rng1 = self.make_rng('sample') 257 | if self.antithetic_time_sampling: 258 | t0 = jax.random.uniform(rng1) 259 | t = jnp.mod(t0 + jnp.arange(0.0, 1.0, step=1.0 / bs), 1.0) 260 | else: 261 | t = jax.random.uniform(rng1, shape=[bs]) 262 | 263 | loss_diff = self.diffusion_loss(t, x, cond=cond, train=train).mean() 264 | loss = loss_diff + loss_prior + loss_recon 265 | 266 | model_stats = { 267 | 'loss': loss, 268 | 'loss_diff': loss_diff, 269 | 'loss_prior': loss_prior, 270 | 'loss_recon': loss_recon, 271 | } 272 | model_stats = utils.loss2bpt(model_stats, self.data_shape) 273 | return model_stats 274 | 275 | def get_sampling_grid(self, i, timesteps): 276 | t = (timesteps - i) / timesteps 277 | s = t - 1 / timesteps 278 | if self.sampling_grid == 'cosine': 279 | t = jnp.cos(math.pi / 2.0 * (1.0 - t)) 280 | s = jnp.cos(math.pi / 2.0 * (1.0 - s)) 281 | return s, t 282 | 283 | def ancestral_sample_step(self, rng, i, timesteps, zt, conditioning=None): 284 | rng_body = jax.random.fold_in(rng, i) 285 | s, t = self.get_sampling_grid(i, timesteps) 286 | cond = self.get_cond_embedding(conditioning) 287 | 288 | alpha_t = self.noise_schedule.alpha(t) 289 | alpha_s = self.noise_schedule.alpha(s) 290 | 291 | logits, _ = self.predict_x(zt, t, cond=cond) 292 | mean_preds = jax.nn.softmax(logits, axis=-1) 293 | 294 | unmask_prob = (alpha_s - alpha_t) / (1 - alpha_t) 295 | probs_vocab = unmask_prob * mean_preds 296 | 297 | probs_mask = jnp.ones(list(zt.shape) + [1]) * (1 - unmask_prob) 298 | probs = jnp.concatenate([probs_vocab, probs_mask], axis=-1) 299 | 300 | to_unmask = tfd.Categorical(probs=probs).sample(seed=rng_body) 301 | is_mask = zt == self.vocab_size 302 | zs = jnp.where(is_mask, to_unmask, zt) 303 | return zs 304 | 305 | def topp_sample_step( 306 | self, rng, i, timesteps, zt, conditioning=None, topp=0.98 307 | ): 308 | rng_body = jax.random.fold_in(rng, i) 309 | s, t = self.get_sampling_grid(i, timesteps) 310 | cond = self.get_cond_embedding(conditioning) 311 | 312 | alpha_t = self.noise_schedule.alpha(t) 313 | alpha_s = self.noise_schedule.alpha(s) 314 | 315 | logits, _ = self.predict_x(zt, t, cond=cond) 316 | logits = binary_search.topp_mask(logits, topp, replace_val=jnp.array(-1e7)) 317 | # mean_preds: [bs, ..., vocab] 318 | mean_preds = jax.nn.softmax(logits, axis=-1) 319 | 320 | unmask_prob = (alpha_s - alpha_t) / (1 - alpha_t) 321 | probs_vocab = unmask_prob * mean_preds 322 | 323 | probs_mask = jnp.ones(list(zt.shape) + [1]) * (1 - unmask_prob) 324 | probs = jnp.concatenate([probs_vocab, probs_mask], axis=-1) 325 | 326 | to_unmask = tfd.Categorical(probs=probs).sample(seed=rng_body) 327 | is_mask = zt == self.vocab_size 328 | zs = jnp.where(is_mask, to_unmask, zt) 329 | return zs 330 | 331 | def mean_sample_step(self, rng, i, timesteps, zt, conditioning=None): 332 | # Ancestral sampling done in two steps -- tends to be worse than one-step 333 | # implementation in ancestral_sample_step. See App. G of 334 | # https://arxiv.org/abs/2406.04329. 
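# Sketch of the two-step factorization implemented below (names as in this
# method):
#   1) draw a full sample z0 ~ Categorical(softmax(logits)) from the
#      predicted denoising distribution;
#   2) for each currently masked position, keep the sampled token with
#      probability unmask_prob = (alpha_s - alpha_t) / (1 - alpha_t),
#      otherwise leave the position masked.
# ancestral_sample_step instead merges both draws into a single categorical
# over vocab_size + 1 outcomes (the extra outcome being "stay masked").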
335 | rng_body = jax.random.fold_in(rng, i) 336 | s, t = self.get_sampling_grid(i, timesteps) 337 | cond = self.get_cond_embedding(conditioning) 338 | 339 | alpha_t = self.noise_schedule.alpha(t) 340 | alpha_s = self.noise_schedule.alpha(s) 341 | 342 | logits, _ = self.predict_x(zt, t, cond=cond) 343 | unmask_prob = (alpha_s - alpha_t) / (1 - alpha_t) 344 | 345 | rng_body, rng = jax.random.split(rng_body) 346 | z0 = tfd.Categorical(logits=logits).sample(seed=rng_body) 347 | 348 | rng_body, _ = jax.random.split(rng) 349 | unmask = jax.random.bernoulli(rng_body, unmask_prob, zt.shape) 350 | 351 | to_unmask = jnp.where(unmask, z0, zt) 352 | is_mask = zt == self.vocab_size 353 | zs = jnp.where(is_mask, to_unmask, zt) 354 | return zs 355 | 356 | def sample_step(self, rng, i, timesteps, zt, conditioning=None, topp=None): 357 | if self.sampler == 'ancestral': 358 | return self.ancestral_sample_step( 359 | rng, i, timesteps, zt, conditioning=conditioning 360 | ) 361 | elif self.sampler == 'topp': 362 | topp = self.topp if topp is None else topp 363 | return self.topp_sample_step( 364 | rng, i, timesteps, zt, conditioning=conditioning, topp=topp 365 | ) 366 | elif self.sampler == 'mean': 367 | return self.mean_sample_step( 368 | rng, i, timesteps, zt, conditioning=conditioning 369 | ) 370 | else: 371 | raise NotImplementedError() 372 | -------------------------------------------------------------------------------- /md4/models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Model utils.""" 17 | 18 | import ml_collections 19 | 20 | from md4.models.diffusion import genmd4 21 | from md4.models.diffusion import md4 22 | 23 | 24 | def get_model(config: ml_collections.ConfigDict): 25 | """Get model instances.""" 26 | if config.model_type == "md4": 27 | return md4.MD4( 28 | config.data_shape, 29 | cont_time=config.cont_time, 30 | timesteps=config.timesteps, 31 | feature_dim=config.feature_dim, 32 | num_heads=config.num_heads, 33 | n_layers=config.n_layers, 34 | n_dit_layers=config.n_dit_layers, 35 | dit_num_heads=config.dit_num_heads, 36 | dit_hidden_size=config.dit_hidden_size, 37 | ch_mult=config.ch_mult, 38 | vocab_size=config.vocab_size, 39 | noise_schedule_type=config.noise_schedule, 40 | dropout_rate=config.dropout_rate, 41 | use_attn_dropout=config.get("use_attn_dropout", True), 42 | mlp_type=config.mlp_type, 43 | depth_scaled_init=config.depth_scaled_init, 44 | cond_type=config.cond_type, 45 | outside_embed=config.outside_embed, 46 | time_features=config.time_features, 47 | classes=config.classes, 48 | sampler=config.sampler, 49 | sampling_grid=config.sampling_grid, 50 | topp=config.topp, 51 | model_sharding=config.get("model_sharding", False), 52 | ) 53 | elif config.model_type == "genmd4": 54 | return genmd4.GenMD4( 55 | config.data_shape, 56 | cont_time=config.cont_time, 57 | timesteps=config.timesteps, 58 | feature_dim=config.feature_dim, 59 | num_heads=config.num_heads, 60 | n_layers=config.n_layers, 61 | n_dit_layers=config.n_dit_layers, 62 | dit_num_heads=config.dit_num_heads, 63 | dit_hidden_size=config.dit_hidden_size, 64 | ch_mult=config.ch_mult, 65 | vocab_size=config.vocab_size, 66 | noise_schedule_type=config.noise_schedule, 67 | power_init=config.power_init, 68 | dropout_rate=config.dropout_rate, 69 | use_attn_dropout=config.get("use_attn_dropout", True), 70 | mlp_type=config.mlp_type, 71 | depth_scaled_init=config.depth_scaled_init, 72 | cond_type=config.cond_type, 73 | outside_embed=config.outside_embed, 74 | time_features=config.time_features, 75 | ) 76 | else: 77 | raise NotImplementedError() 78 | -------------------------------------------------------------------------------- /md4/multihost_dataloading.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Multihost dataloading utilities. 
17 | 18 | Adapted from: 19 | https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/multihost_dataloading.py 20 | """ 21 | 22 | from collections.abc import Iterable, Iterator 23 | import functools 24 | import time 25 | from typing import Union 26 | 27 | import grain.python as grain 28 | import jax 29 | from jax.sharding import Mesh 30 | from jax.sharding import NamedSharding 31 | from jax.sharding import PartitionSpec 32 | import jax.tree_util as jtu 33 | import numpy as np 34 | import tensorflow as tf # pylint: disable=g-import-not-at-top 35 | 36 | 37 | def _build_global_shape_and_sharding( 38 | local_shape: tuple[int, ...], global_mesh: Mesh 39 | ) -> tuple[tuple[int, ...], NamedSharding]: 40 | sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names)) 41 | 42 | global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:] 43 | 44 | return global_shape, sharding 45 | 46 | 47 | def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: 48 | """Puts a host-local array onto local devices as one global jax.Array.""" 49 | global_shape, sharding = _build_global_shape_and_sharding( 50 | np.shape(array), global_mesh 51 | ) 52 | 53 | try: 54 | local_device_arrays = np.split( 55 | array, len(global_mesh.local_devices), axis=0 56 | ) 57 | except ValueError as array_split_error: 58 | raise ValueError( 59 | f"Unable to split array of shape {array.shape} across " 60 | f"local device count {len(global_mesh.local_devices)} " 61 | f"at {jtu.keystr(path)}" 62 | ) from array_split_error 63 | 64 | local_device_buffers = jax.device_put( 65 | local_device_arrays, global_mesh.local_devices 66 | ) 67 | return jax.make_array_from_single_device_arrays( 68 | global_shape, sharding, local_device_buffers 69 | ) 70 | 71 | 72 | def get_next_batch_sharded( 73 | local_iterator: Iterator[jax.Array], global_mesh: Mesh 74 | ) -> jax.Array: 75 | """Splits the host-loaded data equally over all devices.""" 76 | 77 | sleep_time = 10 78 | max_data_load_attempts = 30 79 | 80 | data_load_attempts = 0 81 | loaded_data_success = False 82 | while not loaded_data_success and data_load_attempts < max_data_load_attempts: 83 | data_load_attempts += 1 84 | try: 85 | local_data = next(local_iterator) 86 | loaded_data_success = True 87 | except tf.errors.FailedPreconditionError: 88 | print("Failed to get next data batch, retrying") 89 | time.sleep(sleep_time) 90 | 91 | # Try one last time; if this fails, we will see the full stack trace. 92 | if not loaded_data_success: 93 | local_data = next(local_iterator) 94 | 95 | input_gdas = jtu.tree_map_with_path( 96 | functools.partial(_form_global_array, global_mesh=global_mesh), local_data 97 | ) 98 | 99 | return input_gdas 100 | 101 | 102 | class MultiHostDataLoadIterator: 103 | """Folds get_next_batch_sharded into an iterator class.""" 104 | 105 | def __init__( 106 | self, 107 | dataloader: Union[tf.data.Dataset, grain.DataLoader], 108 | global_mesh: Mesh, 109 | ): 110 | self.global_mesh = global_mesh 111 | self.dataloader = dataloader 112 | if isinstance(self.dataloader, tf.data.Dataset): 113 | self.local_iterator = self.dataloader.as_numpy_iterator() 114 | elif isinstance(self.dataloader, Iterable): 115 | self.local_iterator = iter(self.dataloader) 116 | else: 117 | raise ValueError( 118 | "Type error: dataloader should be either tf.data.Dataset or Iterable. 
119 | ) 120 | 121 | def reset(self): 122 | if isinstance(self.dataloader, tf.data.Dataset): 123 | self.local_iterator = self.dataloader.as_numpy_iterator() 124 | elif isinstance(self.dataloader, Iterable): 125 | self.local_iterator = iter(self.dataloader) 126 | else: 127 | raise ValueError( 128 | "Type error: dataloader should be either tf.data.Dataset or" 129 | " grain.DataLoader." 130 | ) 131 | 132 | def __iter__(self): 133 | self.reset() 134 | return self 135 | 136 | def __next__(self): 137 | return get_next_batch_sharded(self.local_iterator, self.global_mesh) 138 | -------------------------------------------------------------------------------- /md4/networks/dit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """DiT architecture. 17 | 18 | Jax implementation of https://arxiv.org/abs/2212.09748, based on PyTorch 19 | implementation https://github.com/facebookresearch/DiT/blob/main/models.py. 20 | """ 21 | 22 | import math 23 | from typing import Any 24 | 25 | import flax.linen as nn 26 | import jax.numpy as jnp 27 | import numpy as np 28 | 29 | from md4.networks import transformer 30 | # pylint: disable=missing-class-docstring 31 | 32 | 33 | def modulate(x, shift, scale): 34 | return x * (1 + scale) + shift 35 | 36 | 37 | class PatchEmbed(nn.Module): 38 | """2D Image to Patch Embedding.""" 39 | 40 | img_size: int = 224 41 | patch_size: int = 16 42 | embed_dim: int = 768 43 | flatten: bool = True 44 | use_bias: bool = True 45 | 46 | def setup(self): 47 | self.proj = nn.Conv( 48 | self.embed_dim, 49 | kernel_size=(self.patch_size, self.patch_size), 50 | strides=self.patch_size, 51 | padding='VALID', 52 | use_bias=self.use_bias, 53 | ) 54 | 55 | def __call__(self, x): 56 | x = self.proj(x) 57 | if self.flatten: 58 | x = x.reshape(x.shape[0], -1, x.shape[-1]) 59 | return x 60 | 61 | 62 | class Mlp(nn.Module): 63 | """MLP as used in Vision Transformer, MLP-Mixer and related networks.""" 64 | 65 | out_features: int 66 | hidden_features: int 67 | act: Any = nn.gelu 68 | use_bias: bool = True 69 | dropout_rate: float = 0.0 70 | 71 | def setup(self): 72 | self.fc1 = nn.Dense(self.hidden_features, use_bias=self.use_bias) 73 | self.drop1 = nn.Dropout(self.dropout_rate) 74 | self.fc2 = nn.Dense(self.out_features, use_bias=self.use_bias) 75 | self.drop2 = nn.Dropout(self.dropout_rate) 76 | 77 | def __call__(self, x, train=False): 78 | x = self.fc1(x) 79 | x = self.act(x) 80 | x = self.drop1(x, deterministic=not train) 81 | x = self.fc2(x) 82 | x = self.drop2(x, deterministic=not train) 83 | return x 84 | 85 | 86 | class Attention(nn.Module): 87 | 88 | dim: int 89 | n_heads: int 90 | n_kv_heads: int | None = None 91 | dropout_rate: float = 0.0 92 | qkv_bias: bool = False 93 | 94 | def setup(self): 95 | self._n_kv_heads = ( 96 | self.n_heads if 
self.n_kv_heads is None else self.n_kv_heads 97 | ) 98 | assert self.n_heads % self._n_kv_heads == 0 99 | self.n_rep = self.n_heads // self._n_kv_heads 100 | self.head_dim = self.dim // self.n_heads 101 | self.wq = nn.Dense(self.n_heads * self.head_dim, use_bias=self.qkv_bias) 102 | self.wk = nn.Dense(self._n_kv_heads * self.head_dim, use_bias=self.qkv_bias) 103 | self.wv = nn.Dense(self._n_kv_heads * self.head_dim, use_bias=self.qkv_bias) 104 | self.wo = nn.Dense(self.dim, use_bias=False) 105 | self.attn_dropout = nn.Dropout(self.dropout_rate) 106 | # self.resid_dropout = nn.Dropout(self.dropout_rate) 107 | self.resid_dropout = transformer.Dropout1d(self.dropout_rate) 108 | 109 | def __call__(self, x, train=False): 110 | bsz, seqlen, _ = x.shape 111 | 112 | # QKV 113 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 114 | xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim) 115 | xk = xk.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim) 116 | xv = xv.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim) 117 | 118 | # grouped multiquery attention: expand out keys and values 119 | xk = transformer.repeat_kv(xk, self.n_rep) 120 | xv = transformer.repeat_kv(xv, self.n_rep) 121 | 122 | # make heads into a batch dimension 123 | xq = xq.swapaxes(1, 2) # (bs, n_heads, seqlen, head_dim) 124 | xk = xk.swapaxes(1, 2) 125 | xv = xv.swapaxes(1, 2) 126 | 127 | scores = jnp.matmul(xq, xk.swapaxes(2, 3)) / math.sqrt(self.head_dim) 128 | scores = nn.softmax(scores, axis=-1) 129 | scores = self.attn_dropout(scores, deterministic=not train) 130 | output = jnp.matmul(scores, xv) # (bs, n_heads, seqlen, head_dim) 131 | 132 | # restore time as batch dimension and concat heads 133 | output = output.swapaxes(1, 2).reshape(bsz, seqlen, -1) 134 | 135 | # final projection into the residual stream 136 | output = self.wo(output) 137 | output = self.resid_dropout(output, deterministic=not train) 138 | return output 139 | 140 | 141 | class DiTBlock(nn.Module): 142 | """A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.""" 143 | 144 | hidden_size: int 145 | num_heads: int 146 | mlp_ratio: float = 4.0 147 | dropout_rate: float = 0.0 148 | 149 | def setup(self): 150 | self.attn = Attention( 151 | self.hidden_size, 152 | self.num_heads, 153 | dropout_rate=self.dropout_rate, 154 | qkv_bias=True, 155 | ) 156 | mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio) 157 | self.mlp = Mlp( 158 | out_features=self.hidden_size, 159 | hidden_features=mlp_hidden_dim, 160 | act=nn.gelu, 161 | dropout_rate=self.dropout_rate, 162 | ) 163 | 164 | @nn.compact 165 | def __call__(self, x, cond=None, train=False): 166 | if cond is not None: 167 | adaln_modulation = nn.Sequential([ 168 | nn.swish, 169 | nn.Dense( 170 | 6 * self.hidden_size, 171 | kernel_init=nn.zeros_init(), 172 | bias_init=nn.zeros_init(), 173 | ), 174 | ]) 175 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 176 | jnp.split(adaln_modulation(cond)[:, None, :], 6, axis=-1) 177 | ) 178 | norm1 = nn.LayerNorm(use_bias=False, use_scale=False) 179 | norm2 = nn.LayerNorm(use_bias=False, use_scale=False) 180 | x = x + gate_msa * self.attn( 181 | modulate(norm1(x), shift_msa, scale_msa), train=train 182 | ) 183 | x = x + gate_mlp * self.mlp( 184 | modulate(norm2(x), shift_mlp, scale_mlp), train=train 185 | ) 186 | else: 187 | x = x + self.attn(nn.RMSNorm()(x), train=train) 188 | x = x + self.mlp(nn.RMSNorm()(x), train=train) 189 | return x 190 | 191 | 192 | class FinalLayer(nn.Module): 193 | """The final layer of DiT.""" 194 | 195 | 
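# Maps each token embedding back to a flattened patch of
# patch_size * patch_size * out_channels values; when conditioning is given,
# an adaLN shift/scale is applied before the zero-initialized projection.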
hidden_size: int 196 | patch_size: int 197 | out_channels: int 198 | 199 | def setup(self): 200 | self.linear = nn.Dense( 201 | self.patch_size * self.patch_size * self.out_channels, 202 | kernel_init=nn.zeros_init(), 203 | bias_init=nn.zeros_init(), 204 | ) 205 | 206 | @nn.compact 207 | def __call__(self, x, cond=None): 208 | if cond is not None: 209 | adaln_modulation = nn.Sequential([ 210 | nn.swish, 211 | nn.Dense( 212 | 2 * self.hidden_size, 213 | kernel_init=nn.zeros_init(), 214 | bias_init=nn.zeros_init(), 215 | ), 216 | ]) 217 | shift, scale = jnp.split(adaln_modulation(cond)[:, None, :], 2, axis=-1) 218 | norm_final = nn.LayerNorm(use_bias=False, use_scale=False) 219 | x = modulate(norm_final(x), shift, scale) 220 | else: 221 | x = nn.RMSNorm()(x) 222 | x = self.linear(x) 223 | return x 224 | 225 | 226 | class DiT(nn.Module): 227 | """Diffusion model with a Transformer backbone.""" 228 | 229 | img_size: int 230 | patch_size: int 231 | in_channels: int 232 | out_channels: int 233 | hidden_size: int 234 | depth: int 235 | num_heads: int 236 | mlp_ratio: float = 4.0 237 | dropout_rate: float = 0.0 238 | 239 | def setup(self): 240 | self.x_embedder = PatchEmbed( 241 | img_size=self.img_size, 242 | patch_size=self.patch_size, 243 | embed_dim=self.hidden_size, 244 | use_bias=True, 245 | ) 246 | self.grid_size = self.img_size // self.patch_size 247 | num_patches = self.grid_size * self.grid_size 248 | self.pos_embed = self.param( 249 | 'pos_embed', 250 | lambda k, s: get_2d_sincos_pos_embed(s[-1], int(num_patches**0.5)), 251 | [num_patches, self.hidden_size], 252 | ) 253 | self.blocks = [ 254 | DiTBlock( 255 | self.hidden_size, 256 | self.num_heads, 257 | mlp_ratio=self.mlp_ratio, 258 | dropout_rate=self.dropout_rate, 259 | ) 260 | for _ in range(self.depth) 261 | ] 262 | self.final_layer = FinalLayer( 263 | self.hidden_size, self.patch_size, self.out_channels 264 | ) 265 | 266 | def __call__(self, x, cond=None, train=False): 267 | c = x.shape[-1] 268 | p = self.patch_size 269 | grid_size = self.grid_size 270 | x = ( 271 | self.x_embedder(x) + self.pos_embed 272 | ) # (N, T, D), where T = H * W / p ** 2 273 | for block in self.blocks: 274 | x = block(x, cond=cond, train=train) # (N, T, D) 275 | x = self.final_layer(x, cond=cond) # (N, T, p ** 2 * c) 276 | x = x.reshape(-1, grid_size, grid_size, p, p, c) 277 | x = jnp.einsum('nhwpqc->nhpwqc', x) 278 | x = x.reshape(-1, grid_size * p, grid_size * p, c) # (N, H, W, c) 279 | return x 280 | 281 | 282 | def get_2d_sincos_pos_embed(embed_dim, grid_size): 283 | """2D sin-cos position embedding.""" 284 | grid_h = np.arange(grid_size, dtype=np.float32) 285 | grid_w = np.arange(grid_size, dtype=np.float32) 286 | grid = np.meshgrid(grid_w, grid_h) 287 | grid = np.stack(grid, axis=0) 288 | 289 | grid = grid.reshape([2, 1, grid_size, grid_size]) 290 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 291 | return pos_embed 292 | 293 | 294 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 295 | """Gets 2D sin-cos position embedding from grid.""" 296 | assert embed_dim % 2 == 0 297 | 298 | # use half of dimensions to encode grid_h 299 | emb_h = get_1d_sincos_pos_embed_from_grid( 300 | embed_dim // 2, grid[0] 301 | ) # (H*W, D/2) 302 | emb_w = get_1d_sincos_pos_embed_from_grid( 303 | embed_dim // 2, grid[1] 304 | ) # (H*W, D/2) 305 | 306 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 307 | return emb 308 | 309 | 310 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 311 | """Gets 1D sin-cos position embedding from 
grid.""" 312 | assert embed_dim % 2 == 0 313 | omega = np.arange(embed_dim // 2, dtype=np.float64) 314 | omega /= embed_dim / 2.0 315 | omega = 1.0 / 10000**omega # (D/2,) 316 | 317 | pos = pos.reshape(-1) # (M,) 318 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 319 | 320 | emb_sin = np.sin(out) # (M, D/2) 321 | emb_cos = np.cos(out) # (M, D/2) 322 | 323 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 324 | return emb 325 | -------------------------------------------------------------------------------- /md4/networks/sharded_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Jax implementation of LLAMA2-like Transformer with model sharding. 17 | 18 | Based on PyTorch implementation 19 | https://github.com/karpathy/llama2.c/blob/master/model.py 20 | """ 21 | 22 | import dataclasses 23 | import math 24 | from typing import Optional 25 | 26 | import flax.linen as nn 27 | import jax 28 | import jax.numpy as jnp 29 | # pylint: disable=missing-class-docstring 30 | # pylint: disable=missing-function-docstring 31 | 32 | 33 | activation_map = dict( 34 | swiglu=nn.swish, 35 | geglu=nn.gelu, 36 | glu=nn.sigmoid, 37 | ) 38 | 39 | 40 | @dataclasses.dataclass(unsafe_hash=True) 41 | class ModelArgs: 42 | dim: int = 288 43 | n_layers: int = 6 44 | n_heads: int = 6 45 | n_kv_heads: Optional[int] = None 46 | output_channels: int = 1024 47 | hidden_dim: Optional[int] = None 48 | multiple_of: int = 32 # MLP hidden layer size will be multiple of 49 | norm_eps: float = 1e-5 50 | dropout_rate: float = 0.0 51 | use_attn_dropout: bool = True 52 | weight_tying: bool = False 53 | w_init_scale: float = 1.0 54 | depth_scaled_init: bool = False 55 | # glu, geglu, swiglu 56 | mlp_type: str = 'swiglu' 57 | # adaln, adaln_zero 58 | cond_type: str = 'adaln' 59 | embed_input: bool = False 60 | n_embed_classes: int = 1024 61 | causal: bool = False 62 | 63 | 64 | class RMSNorm(nn.Module): 65 | 66 | dim: int 67 | eps: float 68 | 69 | def setup(self): 70 | self.scale = self.param( 71 | 'scale', lambda key, shape: jnp.ones(shape), (self.dim,) 72 | ) 73 | 74 | def _norm(self, x): 75 | return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) 76 | 77 | def __call__(self, x): 78 | output = self._norm(x) 79 | return output * self.scale 80 | 81 | 82 | def precompute_freqs_cis(dim, end, theta: float = 10000.0): 83 | freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)] / dim)) 84 | t = jnp.arange(end) 85 | freqs = jnp.outer(t, freqs) 86 | freqs_cos = jnp.cos(freqs) 87 | freqs_sin = jnp.sin(freqs) 88 | return freqs_cos, freqs_sin 89 | 90 | 91 | def reshape_for_broadcast(freqs_cis, x): 92 | ndim = x.ndim 93 | assert 1 < ndim 94 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 95 | shape = [d if i == 1 or i == ndim - 1 
else 1 for i, d in enumerate(x.shape)] 96 | return freqs_cis.reshape(shape) 97 | 98 | 99 | def jax_unstack(x, axis=0): 100 | return [ 101 | jax.lax.index_in_dim(x, i, axis, keepdims=False) 102 | for i in range(x.shape[axis]) 103 | ] 104 | 105 | 106 | def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin): 107 | # reshape xq and xk to match the complex representation 108 | # [bs, seq_len, n_head, head_dim // 2] 109 | xq_r, xq_i = jax_unstack(xq.reshape(xq.shape[:-1] + (-1, 2)), -1) 110 | xk_r, xk_i = jax_unstack(xk.reshape(xk.shape[:-1] + (-1, 2)), -1) 111 | 112 | # reshape freqs_cos and freqs_sin for broadcasting 113 | # [1, seq_len, 1, head_dim // 2] 114 | freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) 115 | freqs_sin = reshape_for_broadcast(freqs_sin, xq_r) 116 | 117 | # apply rotation using real numbers 118 | xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin 119 | xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos 120 | xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin 121 | xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos 122 | 123 | # flatten last two dimensions 124 | # [bs, seq_len, n_head, head_dim // 2, 2] -> [bs, seq_len, n_head, head_dim] 125 | xq_out = jnp.stack([xq_out_r, xq_out_i], axis=-1).reshape( 126 | xq_out_r.shape[:3] + (-1,) 127 | ) 128 | xk_out = jnp.stack([xk_out_r, xk_out_i], axis=-1).reshape( 129 | xk_out_r.shape[:3] + (-1,) 130 | ) 131 | 132 | return xq_out, xk_out 133 | 134 | 135 | def repeat_kv(x, n_rep): 136 | bs, slen, n_kv_heads, head_dim = x.shape 137 | if n_rep == 1: 138 | return x 139 | return jnp.tile(x[:, :, :, None, :], [1, 1, 1, n_rep, 1]).reshape( 140 | bs, slen, n_kv_heads * n_rep, head_dim 141 | ) 142 | 143 | 144 | class Dropout1d(nn.Module): 145 | 146 | dropout_rate: float = 0.0 147 | 148 | def __call__(self, x, deterministic=True): 149 | if (self.dropout_rate > 0.0) and not deterministic: 150 | drop = jax.random.bernoulli( 151 | self.make_rng('dropout'), 152 | 1 - self.dropout_rate, 153 | (x.shape[0], 1, x.shape[-1]), 154 | ) 155 | x = x * drop / (1 - self.dropout_rate) 156 | return x 157 | 158 | 159 | class Attention(nn.Module): 160 | 161 | dim: int 162 | n_heads: int 163 | n_kv_heads: int | None = None 164 | dropout_rate: float = 0.0 165 | causal: bool = False 166 | qkv_bias: bool = False 167 | use_attn_dropout: bool = True 168 | 169 | def setup(self): 170 | self._n_kv_heads = ( 171 | self.n_heads if self.n_kv_heads is None else self.n_kv_heads 172 | ) 173 | assert self.n_heads % self._n_kv_heads == 0 174 | self.n_rep = self.n_heads // self._n_kv_heads 175 | self.head_dim = self.dim // self.n_heads 176 | self.wq = nn.Dense( 177 | self.n_heads * self.head_dim, 178 | use_bias=self.qkv_bias, 179 | kernel_init=nn.with_logical_partitioning( 180 | nn.linear.default_kernel_init, ('embed', 'qkv') 181 | ), 182 | ) # fsdp, tensor 183 | self.wk = nn.Dense( 184 | self._n_kv_heads * self.head_dim, 185 | use_bias=self.qkv_bias, 186 | kernel_init=nn.with_logical_partitioning( 187 | nn.linear.default_kernel_init, ('embed', 'qkv') 188 | ), 189 | ) 190 | self.wv = nn.Dense( 191 | self._n_kv_heads * self.head_dim, 192 | use_bias=self.qkv_bias, 193 | kernel_init=nn.with_logical_partitioning( 194 | nn.linear.default_kernel_init, ('embed', 'qkv') 195 | ), 196 | ) 197 | self.wo = nn.Dense( 198 | self.dim, 199 | use_bias=False, 200 | kernel_init=nn.with_logical_partitioning( 201 | nn.linear.default_kernel_init, ('qkv', 'embed') 202 | ) 203 | ) 204 | if self.use_attn_dropout and self.dropout_rate > 0.0: 205 | self.attn_dropout = nn.Dropout(self.dropout_rate) 206 | if 
self.dropout_rate > 0.0: 207 | self.resid_dropout = Dropout1d(self.dropout_rate) 208 | 209 | def __call__(self, x, freqs_cos, freqs_sin, train=False): 210 | bsz, seqlen, _ = x.shape 211 | 212 | # QKV 213 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 214 | xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim) 215 | xk = xk.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim) 216 | xv = xv.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim) 217 | 218 | # RoPE relative positional embeddings 219 | xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) 220 | 221 | # grouped multiquery attention: expand out keys and values 222 | xk = repeat_kv(xk, self.n_rep) 223 | xv = repeat_kv(xv, self.n_rep) 224 | 225 | # make heads into a batch dimension 226 | xq = xq.swapaxes(1, 2) # (bs, n_heads, seqlen, head_dim) 227 | xk = xk.swapaxes(1, 2) 228 | xv = xv.swapaxes(1, 2) 229 | 230 | scores = jnp.matmul(xq, xk.swapaxes(2, 3)) / math.sqrt(self.head_dim) 231 | if self.causal: 232 | mask = jnp.full((1, 1, seqlen, seqlen), -jnp.inf) 233 | mask = jnp.triu(mask, k=1) 234 | scores = ( 235 | scores + mask[:, :, :seqlen, :seqlen] 236 | ) # (bs, n_heads, seqlen, seqlen) 237 | scores = nn.softmax(scores, axis=-1) 238 | if self.use_attn_dropout and self.dropout_rate > 0.0: 239 | scores = self.attn_dropout(scores, deterministic=not train) 240 | output = jnp.matmul(scores, xv) # (bs, n_heads, seqlen, head_dim) 241 | 242 | # restore time as batch dimension and concat heads 243 | output = output.swapaxes(1, 2).reshape(bsz, seqlen, -1) 244 | 245 | # final projection into the residual stream 246 | output = self.wo(output) 247 | if self.dropout_rate > 0.0: 248 | output = self.resid_dropout(output, deterministic=not train) 249 | return output 250 | 251 | 252 | class FeedForward(nn.Module): 253 | 254 | dim: int 255 | multiple_of: int 256 | dropout_rate: float 257 | hidden_dim: int | None = None 258 | w_init_scale: float = 1.0 259 | mlp_type: str = 'swiglu' 260 | 261 | def setup(self): 262 | multiple_of = self.multiple_of 263 | hidden_dim = self.hidden_dim 264 | if hidden_dim is None: 265 | hidden_dim = 4 * self.dim 266 | hidden_dim = int(2 * hidden_dim / 3) 267 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 268 | w_init = nn.initializers.variance_scaling( 269 | self.w_init_scale, 'fan_in', 'truncated_normal' 270 | ) 271 | self.w1 = nn.Dense( 272 | hidden_dim, 273 | use_bias=False, 274 | kernel_init=nn.with_logical_partitioning(w_init, ('embed', 'mlp')), 275 | ) 276 | self.w2 = nn.Dense( 277 | self.dim, 278 | use_bias=False, 279 | kernel_init=nn.with_logical_partitioning(w_init, ('mlp', 'embed')), 280 | ) 281 | self.w3 = nn.Dense( 282 | hidden_dim, 283 | use_bias=False, 284 | kernel_init=nn.with_logical_partitioning(w_init, ('embed', 'mlp')), 285 | ) 286 | # self.dropout = nn.Dropout(self.dropout_rate) 287 | if self.dropout_rate > 0.0: 288 | self.dropout = Dropout1d(self.dropout_rate) 289 | 290 | def __call__(self, x, train=False): 291 | activation = activation_map[self.mlp_type] 292 | y = self.w2(activation(self.w1(x)) * self.w3(x)) 293 | if self.dropout_rate > 0.0: 294 | return self.dropout(y, deterministic=not train) 295 | else: 296 | return y 297 | 298 | 299 | class TransformerBlock(nn.Module): 300 | 301 | layer_id: int 302 | args: ModelArgs 303 | 304 | def setup(self): 305 | args = self.args 306 | self.attention = Attention( 307 | args.dim, 308 | args.n_heads, 309 | n_kv_heads=args.n_kv_heads, 310 | dropout_rate=args.dropout_rate, 311 | causal=args.causal, 312 | 
use_attn_dropout=args.use_attn_dropout, 313 | ) 314 | 315 | if args.depth_scaled_init: 316 | w_init_scale = 2.0 / args.n_layers 317 | else: 318 | w_init_scale = args.w_init_scale 319 | 320 | self.feed_forward = FeedForward( 321 | dim=args.dim, 322 | multiple_of=args.multiple_of, 323 | dropout_rate=args.dropout_rate, 324 | hidden_dim=args.hidden_dim, 325 | w_init_scale=w_init_scale, 326 | mlp_type=args.mlp_type, 327 | ) 328 | 329 | @nn.compact 330 | def __call__(self, x, freqs_cos, freqs_sin, cond=None, train=False): 331 | if cond is not None: 332 | activation = activation_map[self.args.mlp_type] 333 | if self.args.cond_type == 'adaln': 334 | ln = nn.Sequential([ 335 | # nn.swish, 336 | activation, 337 | nn.Dense( 338 | 6 * self.args.dim, 339 | use_bias=True, 340 | ), 341 | ]) 342 | elif self.args.cond_type == 'adaln_zero': 343 | ln = nn.Sequential([ 344 | # nn.swish, 345 | activation, 346 | nn.Dense( 347 | 6 * self.args.dim, 348 | use_bias=True, 349 | kernel_init=nn.initializers.zeros, 350 | bias_init=nn.initializers.zeros, 351 | ), 352 | ]) 353 | else: 354 | raise NotImplementedError() 355 | (shift_att, scale_att, gate_att, shift_mlp, scale_mlp, gate_mlp) = ( 356 | jnp.split(ln(cond)[:, None, :], 6, axis=-1) 357 | ) 358 | attention_norm = nn.LayerNorm( 359 | epsilon=self.args.norm_eps, use_bias=False, use_scale=False 360 | ) 361 | ffn_norm = nn.LayerNorm( 362 | epsilon=self.args.norm_eps, use_bias=False, use_scale=False 363 | ) 364 | h = x + gate_att * self.attention( 365 | attention_norm(x) * (scale_att + 1.0) + shift_att, 366 | freqs_cos, 367 | freqs_sin, 368 | train=train, 369 | ) 370 | out = h + gate_mlp * self.feed_forward( 371 | ffn_norm(h) * (scale_mlp + 1.0) + shift_mlp, train=train 372 | ) 373 | else: 374 | attention_norm = RMSNorm(self.args.dim, eps=self.args.norm_eps) 375 | ffn_norm = RMSNorm(self.args.dim, eps=self.args.norm_eps) 376 | h = x + self.attention( 377 | attention_norm(x), freqs_cos, freqs_sin, train=train 378 | ) 379 | out = h + self.feed_forward(ffn_norm(h), train=train) 380 | 381 | return out 382 | 383 | 384 | class Transformer(nn.Module): 385 | 386 | args: ModelArgs 387 | 388 | @nn.compact 389 | def __call__(self, x, cond=None, train=False, output_channels=None): 390 | args = self.args 391 | if output_channels is None: 392 | output_channels = args.output_channels 393 | 394 | if args.embed_input: 395 | h = nn.Embed( 396 | args.n_embed_classes, 397 | args.dim, 398 | embedding_init=nn.with_logical_partitioning( 399 | nn.linear.default_embed_init, ('embed_vocab', 'embed') 400 | ), 401 | )(x) 402 | if args.dropout_rate > 0.0: 403 | h = nn.Dropout(args.dropout_rate, deterministic=not train)(h) 404 | else: 405 | h = nn.Dense( 406 | args.dim, 407 | kernel_init=nn.with_logical_partitioning( 408 | nn.linear.default_kernel_init, ('input_embed', 'embed') 409 | ), 410 | )(x) 411 | 412 | seqlen = x.shape[1] 413 | freqs_cos, freqs_sin = precompute_freqs_cis( 414 | args.dim // args.n_heads, seqlen 415 | ) 416 | 417 | freqs_cos = freqs_cos[:seqlen] 418 | freqs_sin = freqs_sin[:seqlen] 419 | 420 | for layer_id in range(args.n_layers): 421 | h = TransformerBlock(layer_id, args)( 422 | h, freqs_cos, freqs_sin, cond=cond, train=train 423 | ) 424 | 425 | if cond is not None: 426 | output_norm = nn.LayerNorm( 427 | epsilon=args.norm_eps, use_bias=False, use_scale=False 428 | ) 429 | activation = activation_map[args.mlp_type] 430 | if args.cond_type == 'adaln': 431 | ln = nn.Sequential([ 432 | # nn.swish, 433 | activation, 434 | nn.Dense( 435 | 2 * args.dim, 436 | use_bias=True, 437 | 
), 438 | ]) 439 | elif args.cond_type == 'adaln_zero': 440 | ln = nn.Sequential([ 441 | # nn.swish, 442 | activation, 443 | nn.Dense( 444 | 2 * args.dim, 445 | use_bias=True, 446 | kernel_init=nn.initializers.zeros, 447 | bias_init=nn.initializers.zeros, 448 | ), 449 | ]) 450 | else: 451 | raise NotImplementedError() 452 | shift_out, scale_out = jnp.split(ln(cond)[:, None, :], 2, axis=-1) 453 | logits = nn.Dense( 454 | output_channels, 455 | use_bias=False, 456 | kernel_init=nn.with_logical_partitioning( 457 | nn.initializers.zeros, ('embed', 'output_vocab') 458 | ), 459 | )(output_norm(h) * (scale_out + 1) + shift_out) 460 | else: 461 | h = RMSNorm(args.dim, args.norm_eps)(h) 462 | logits = nn.Dense( 463 | features=output_channels, 464 | use_bias=False, 465 | kernel_init=nn.with_logical_partitioning( 466 | nn.initializers.zeros, ('embed', 'output_vocab') 467 | ), 468 | )(h) 469 | 470 | return logits 471 | -------------------------------------------------------------------------------- /md4/networks/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Jax implementation of LLAMA2-like Transformer. 
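This is the single-host counterpart of sharded_transformer.py, i.e. the same
architecture without the logical partitioning annotations.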
17 | 18 | Based on PyTorch implementation 19 | https://github.com/karpathy/llama2.c/blob/master/model.py 20 | """ 21 | 22 | import dataclasses 23 | import math 24 | from typing import Optional 25 | 26 | import flax.linen as nn 27 | import jax 28 | import jax.numpy as jnp 29 | # pylint: disable=missing-class-docstring 30 | # pylint: disable=missing-function-docstring 31 | 32 | 33 | activation_map = dict( 34 | swiglu=nn.swish, 35 | geglu=nn.gelu, 36 | glu=nn.sigmoid, 37 | ) 38 | 39 | 40 | @dataclasses.dataclass(unsafe_hash=True) 41 | class ModelArgs: 42 | dim: int = 288 43 | n_layers: int = 6 44 | n_heads: int = 6 45 | n_kv_heads: Optional[int] = None 46 | output_channels: int = 1024 47 | hidden_dim: Optional[int] = None 48 | multiple_of: int = 32 # MLP hidden layer size will be multiple of 49 | norm_eps: float = 1e-5 50 | dropout_rate: float = 0.0 51 | weight_tying: bool = False 52 | w_init_scale: float = 1.0 53 | depth_scaled_init: bool = False 54 | # glu, geglu, swiglu 55 | mlp_type: str = 'swiglu' 56 | # adaln, adaln_zero 57 | cond_type: str = 'adaln' 58 | embed_input: bool = False 59 | n_embed_classes: int = 1024 60 | causal: bool = False 61 | 62 | 63 | class RMSNorm(nn.Module): 64 | 65 | dim: int 66 | eps: float 67 | 68 | def setup(self): 69 | self.scale = self.param( 70 | 'scale', lambda key, shape: jnp.ones(shape), (self.dim,) 71 | ) 72 | 73 | def _norm(self, x): 74 | return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) 75 | 76 | def __call__(self, x): 77 | output = self._norm(x) 78 | return output * self.scale 79 | 80 | 81 | def precompute_freqs_cis(dim, end, theta: float = 10000.0): 82 | freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)] / dim)) 83 | t = jnp.arange(end) 84 | freqs = jnp.outer(t, freqs) 85 | freqs_cos = jnp.cos(freqs) 86 | freqs_sin = jnp.sin(freqs) 87 | return freqs_cos, freqs_sin 88 | 89 | 90 | def reshape_for_broadcast(freqs_cis, x): 91 | ndim = x.ndim 92 | assert 1 < ndim 93 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 94 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 95 | return freqs_cis.reshape(shape) 96 | 97 | 98 | def jax_unstack(x, axis=0): 99 | return [ 100 | jax.lax.index_in_dim(x, i, axis, keepdims=False) 101 | for i in range(x.shape[axis]) 102 | ] 103 | 104 | 105 | def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin): 106 | # reshape xq and xk to match the complex representation 107 | # [bs, seq_len, n_head, head_dim // 2] 108 | xq_r, xq_i = jax_unstack(xq.reshape(xq.shape[:-1] + (-1, 2)), -1) 109 | xk_r, xk_i = jax_unstack(xk.reshape(xk.shape[:-1] + (-1, 2)), -1) 110 | 111 | # reshape freqs_cos and freqs_sin for broadcasting 112 | # [1, seq_len, 1, head_dim // 2] 113 | freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) 114 | freqs_sin = reshape_for_broadcast(freqs_sin, xq_r) 115 | 116 | # apply rotation using real numbers 117 | xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin 118 | xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos 119 | xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin 120 | xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos 121 | 122 | # flatten last two dimensions 123 | # [bs, seq_len, n_head, head_dim // 2, 2] -> [bs, seq_len, n_head, head_dim] 124 | xq_out = jnp.stack([xq_out_r, xq_out_i], axis=-1).reshape( 125 | xq_out_r.shape[:3] + (-1,) 126 | ) 127 | xk_out = jnp.stack([xk_out_r, xk_out_i], axis=-1).reshape( 128 | xk_out_r.shape[:3] + (-1,) 129 | ) 130 | 131 | return xq_out, xk_out 132 | 133 | 134 | def repeat_kv(x, n_rep): 135 | bs, slen, n_kv_heads, head_dim = x.shape 
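# Equivalent to torch.repeat_interleave(x, repeats=n_rep, dim=2): each kv
# head is duplicated n_rep times in place so the keys/values line up with
# the n_kv_heads * n_rep query heads in grouped-query attention.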
136 | if n_rep == 1: 137 | return x 138 | return jnp.tile(x[:, :, :, None, :], [1, 1, 1, n_rep, 1]).reshape( 139 | bs, slen, n_kv_heads * n_rep, head_dim 140 | ) 141 | 142 | 143 | class Dropout1d(nn.Module): 144 | 145 | dropout_rate: float = 0.0 146 | 147 | def __call__(self, x, deterministic=True): 148 | if (self.dropout_rate > 0.0) and not deterministic: 149 | drop = jax.random.bernoulli( 150 | self.make_rng('dropout'), 151 | 1 - self.dropout_rate, 152 | (x.shape[0], 1, x.shape[-1]), 153 | ) 154 | x = x * drop / (1 - self.dropout_rate) 155 | return x 156 | 157 | 158 | class Attention(nn.Module): 159 | 160 | dim: int 161 | n_heads: int 162 | n_kv_heads: int | None = None 163 | dropout_rate: float = 0.0 164 | causal: bool = False 165 | qkv_bias: bool = False 166 | 167 | def setup(self): 168 | self._n_kv_heads = ( 169 | self.n_heads if self.n_kv_heads is None else self.n_kv_heads 170 | ) 171 | assert self.n_heads % self._n_kv_heads == 0 172 | self.n_rep = self.n_heads // self._n_kv_heads 173 | self.head_dim = self.dim // self.n_heads 174 | self.wq = nn.Dense(self.n_heads * self.head_dim, use_bias=self.qkv_bias) 175 | self.wk = nn.Dense(self._n_kv_heads * self.head_dim, use_bias=self.qkv_bias) 176 | self.wv = nn.Dense(self._n_kv_heads * self.head_dim, use_bias=self.qkv_bias) 177 | self.wo = nn.Dense(self.dim, use_bias=False) 178 | if self.dropout_rate > 0.0: 179 | self.attn_dropout = nn.Dropout(self.dropout_rate) 180 | self.resid_dropout = Dropout1d(self.dropout_rate) 181 | 182 | def __call__(self, x, freqs_cos, freqs_sin, train=False): 183 | bsz, seqlen, _ = x.shape 184 | 185 | # QKV 186 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 187 | xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim) 188 | xk = xk.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim) 189 | xv = xv.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim) 190 | 191 | # RoPE relative positional embeddings 192 | xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) 193 | 194 | # grouped multiquery attention: expand out keys and values 195 | xk = repeat_kv(xk, self.n_rep) 196 | xv = repeat_kv(xv, self.n_rep) 197 | 198 | # make heads into a batch dimension 199 | xq = xq.swapaxes(1, 2) # (bs, n_heads, seqlen, head_dim) 200 | xk = xk.swapaxes(1, 2) 201 | xv = xv.swapaxes(1, 2) 202 | 203 | scores = jnp.matmul(xq, xk.swapaxes(2, 3)) / math.sqrt(self.head_dim) 204 | if self.causal: 205 | mask = jnp.full((1, 1, seqlen, seqlen), -jnp.inf) 206 | mask = jnp.triu(mask, k=1) 207 | scores = ( 208 | scores + mask[:, :, :seqlen, :seqlen] 209 | ) # (bs, n_heads, seqlen, seqlen) 210 | scores = nn.softmax(scores, axis=-1) 211 | if self.dropout_rate > 0.0: 212 | scores = self.attn_dropout(scores, deterministic=not train) 213 | output = jnp.matmul(scores, xv) # (bs, n_heads, seqlen, head_dim) 214 | 215 | # restore time as batch dimension and concat heads 216 | output = output.swapaxes(1, 2).reshape(bsz, seqlen, -1) 217 | 218 | # final projection into the residual stream 219 | output = self.wo(output) 220 | if self.dropout_rate > 0.0: 221 | output = self.resid_dropout(output, deterministic=not train) 222 | return output 223 | 224 | 225 | class FeedForward(nn.Module): 226 | 227 | dim: int 228 | multiple_of: int 229 | dropout_rate: float 230 | hidden_dim: int | None = None 231 | w_init_scale: float = 1.0 232 | mlp_type: str = 'swiglu' 233 | 234 | def setup(self): 235 | multiple_of = self.multiple_of 236 | hidden_dim = self.hidden_dim 237 | if hidden_dim is None: 238 | hidden_dim = 4 * self.dim 239 | hidden_dim = int(2 * hidden_dim / 3) 
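# SwiGLU uses three weight matrices (w1, w2, w3) rather than two, so the
# hidden width is scaled by 2/3 to keep the parameter count close to a
# standard 4 * dim MLP; the next line rounds up to a multiple of multiple_of.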
240 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 241 | w_init = nn.initializers.variance_scaling( 242 | self.w_init_scale, 'fan_in', 'truncated_normal' 243 | ) 244 | self.w1 = nn.Dense(hidden_dim, use_bias=False, kernel_init=w_init) 245 | self.w2 = nn.Dense(self.dim, use_bias=False, kernel_init=w_init) 246 | self.w3 = nn.Dense(hidden_dim, use_bias=False, kernel_init=w_init) 247 | # self.dropout = nn.Dropout(self.dropout_rate) 248 | if self.dropout_rate > 0.0: 249 | self.dropout = Dropout1d(self.dropout_rate) 250 | 251 | def __call__(self, x, train=False): 252 | activation = activation_map[self.mlp_type] 253 | y = self.w2(activation(self.w1(x)) * self.w3(x)) 254 | if self.dropout_rate > 0.0: 255 | return self.dropout(y, deterministic=not train) 256 | else: 257 | return y 258 | 259 | 260 | class TransformerBlock(nn.Module): 261 | 262 | layer_id: int 263 | args: ModelArgs 264 | 265 | def setup(self): 266 | args = self.args 267 | self.attention = Attention( 268 | args.dim, 269 | args.n_heads, 270 | n_kv_heads=args.n_kv_heads, 271 | dropout_rate=args.dropout_rate, 272 | causal=args.causal, 273 | ) 274 | 275 | if args.depth_scaled_init: 276 | w_init_scale = 2.0 / args.n_layers 277 | else: 278 | w_init_scale = args.w_init_scale 279 | 280 | self.feed_forward = FeedForward( 281 | dim=args.dim, 282 | multiple_of=args.multiple_of, 283 | dropout_rate=args.dropout_rate, 284 | hidden_dim=args.hidden_dim, 285 | w_init_scale=w_init_scale, 286 | mlp_type=args.mlp_type, 287 | ) 288 | 289 | @nn.compact 290 | def __call__(self, x, freqs_cos, freqs_sin, cond=None, train=False): 291 | if cond is not None: 292 | activation = activation_map[self.args.mlp_type] 293 | if self.args.cond_type == 'adaln': 294 | ln = nn.Sequential([ 295 | # nn.swish, 296 | activation, 297 | nn.Dense(6 * self.args.dim, use_bias=True), 298 | ]) 299 | elif self.args.cond_type == 'adaln_zero': 300 | ln = nn.Sequential([ 301 | # nn.swish, 302 | activation, 303 | nn.Dense( 304 | 6 * self.args.dim, 305 | use_bias=True, 306 | kernel_init=nn.initializers.zeros, 307 | bias_init=nn.initializers.zeros, 308 | ), 309 | ]) 310 | else: 311 | raise NotImplementedError() 312 | (shift_att, scale_att, gate_att, shift_mlp, scale_mlp, gate_mlp) = ( 313 | jnp.split(ln(cond)[:, None, :], 6, axis=-1) 314 | ) 315 | attention_norm = nn.LayerNorm( 316 | epsilon=self.args.norm_eps, use_bias=False, use_scale=False 317 | ) 318 | ffn_norm = nn.LayerNorm( 319 | epsilon=self.args.norm_eps, use_bias=False, use_scale=False 320 | ) 321 | h = x + gate_att * self.attention( 322 | attention_norm(x) * (scale_att + 1.0) + shift_att, 323 | freqs_cos, 324 | freqs_sin, 325 | train=train, 326 | ) 327 | out = h + gate_mlp * self.feed_forward( 328 | ffn_norm(h) * (scale_mlp + 1.0) + shift_mlp, train=train 329 | ) 330 | else: 331 | attention_norm = RMSNorm(self.args.dim, eps=self.args.norm_eps) 332 | ffn_norm = RMSNorm(self.args.dim, eps=self.args.norm_eps) 333 | h = x + self.attention( 334 | attention_norm(x), freqs_cos, freqs_sin, train=train 335 | ) 336 | out = h + self.feed_forward(ffn_norm(h), train=train) 337 | 338 | return out 339 | 340 | 341 | class Transformer(nn.Module): 342 | 343 | args: ModelArgs 344 | 345 | @nn.compact 346 | def __call__(self, x, cond=None, train=False, output_channels=None): 347 | args = self.args 348 | if output_channels is None: 349 | output_channels = args.output_channels 350 | 351 | if args.embed_input: 352 | h = nn.Embed(args.n_embed_classes, args.dim)(x) 353 | if args.dropout_rate > 0.0: 354 | h = 
nn.Dropout(args.dropout_rate, deterministic=not train)(h) 355 | else: 356 | h = nn.Dense(args.dim)(x) 357 | 358 | seqlen = x.shape[1] 359 | freqs_cos, freqs_sin = precompute_freqs_cis( 360 | args.dim // args.n_heads, seqlen 361 | ) 362 | 363 | freqs_cos = freqs_cos[:seqlen] 364 | freqs_sin = freqs_sin[:seqlen] 365 | 366 | for layer_id in range(args.n_layers): 367 | h = TransformerBlock(layer_id, args)( 368 | h, freqs_cos, freqs_sin, cond=cond, train=train 369 | ) 370 | 371 | if cond is not None: 372 | output_norm = nn.LayerNorm( 373 | epsilon=args.norm_eps, use_bias=False, use_scale=False 374 | ) 375 | activation = activation_map[args.mlp_type] 376 | if args.cond_type == 'adaln': 377 | ln = nn.Sequential([ 378 | # nn.swish, 379 | activation, 380 | nn.Dense(2 * args.dim, use_bias=True), 381 | ]) 382 | elif args.cond_type == 'adaln_zero': 383 | ln = nn.Sequential([ 384 | # nn.swish, 385 | activation, 386 | nn.Dense( 387 | 2 * args.dim, 388 | use_bias=True, 389 | kernel_init=nn.initializers.zeros, 390 | bias_init=nn.initializers.zeros, 391 | ), 392 | ]) 393 | else: 394 | raise NotImplementedError() 395 | shift_out, scale_out = jnp.split(ln(cond)[:, None, :], 2, axis=-1) 396 | logits = nn.Dense( 397 | output_channels, use_bias=False, kernel_init=nn.initializers.zeros 398 | )(output_norm(h) * (scale_out + 1) + shift_out) 399 | else: 400 | h = RMSNorm(args.dim, args.norm_eps)(h) 401 | logits = nn.Dense( 402 | features=output_channels, 403 | use_bias=False, 404 | kernel_init=nn.initializers.zeros, 405 | )(h) 406 | 407 | return logits 408 | -------------------------------------------------------------------------------- /md4/networks/unet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """UNet implementation. 
17 | 18 | Adapted from https://github.com/google-research/vdm/blob/main/model_vdm.py 19 | """ 20 | 21 | import flax.linen as nn 22 | import jax.numpy as jnp 23 | 24 | 25 | class SelfAttention(nn.Module): 26 | """Self attention layer in UNets.""" 27 | 28 | num_heads: int = 1 29 | 30 | @nn.compact 31 | def __call__(self, x): 32 | _, h, w, c = x.shape 33 | z = nn.GroupNorm()(x) 34 | z = z.reshape(z.shape[0], -1, c) 35 | mha = nn.MultiHeadDotProductAttention( 36 | num_heads=self.num_heads, qkv_features=c 37 | ) 38 | z = mha(z, z) 39 | z = z.reshape(-1, h, w, c) + x 40 | return z 41 | 42 | 43 | class ResBlock(nn.Module): 44 | """Residual block in UNets.""" 45 | 46 | out_channels: int 47 | dropout_rate: float = 0.1 48 | 49 | @nn.compact 50 | def __call__(self, x, cond=None, train=False): 51 | in_channels = x.shape[-1] 52 | h = nn.GroupNorm()(x) 53 | h = nn.swish(h) 54 | h = nn.Conv(self.out_channels, kernel_size=(3, 3))(h) 55 | if cond is not None: 56 | cond_act = nn.Dense( 57 | self.out_channels, 58 | use_bias=False, 59 | kernel_init=nn.initializers.zeros, 60 | )(cond) 61 | h = h + cond_act[:, None, None, :] 62 | h = nn.GroupNorm()(h) 63 | h = nn.swish(h) 64 | h = nn.Dropout(rate=self.dropout_rate)(h, deterministic=not train) 65 | h = nn.Conv( 66 | self.out_channels, kernel_size=(3, 3), kernel_init=nn.initializers.zeros 67 | )(h) 68 | if in_channels != self.out_channels: 69 | h = nn.Dense(self.out_channels)(x) + h 70 | else: 71 | h = x + h 72 | return h 73 | 74 | 75 | class UNet(nn.Module): 76 | """UNet for Diffusion.""" 77 | 78 | d_channels: int = 128 79 | n_layers: int = 32 80 | add_input: bool = False 81 | output_channels: int | None = None 82 | dropout_rate: float = 0.1 83 | 84 | @nn.compact 85 | def __call__(self, x, cond=None, train=False, output_channels=None): 86 | if output_channels is None: 87 | if self.output_channels is None: 88 | output_channels = x.shape[-1] 89 | else: 90 | output_channels = self.output_channels 91 | 92 | # Linear projection of input 93 | h = nn.Conv(self.d_channels, kernel_size=(3, 3))(x) 94 | hs = [h] 95 | 96 | # Downsampling 97 | for _ in range(self.n_layers): 98 | h = ResBlock( 99 | out_channels=self.d_channels, dropout_rate=self.dropout_rate 100 | )(h, cond, train) 101 | hs.append(h) 102 | 103 | # Middle 104 | h = ResBlock(out_channels=self.d_channels, dropout_rate=self.dropout_rate)( 105 | h, cond, train 106 | ) 107 | h = SelfAttention(num_heads=1)(h) 108 | h = ResBlock(out_channels=self.d_channels, dropout_rate=self.dropout_rate)( 109 | h, cond, train 110 | ) 111 | 112 | # Upsampling 113 | for _ in range(self.n_layers + 1): 114 | h = jnp.concatenate([h, hs.pop()], axis=-1) 115 | h = ResBlock( 116 | out_channels=self.d_channels, dropout_rate=self.dropout_rate 117 | )(h, cond, train) 118 | 119 | assert not hs 120 | 121 | # Predict noise 122 | h = nn.swish(nn.GroupNorm()(h)) 123 | h = nn.Conv(output_channels, (3, 3), kernel_init=nn.initializers.zeros)(h) 124 | 125 | if self.add_input: 126 | h += x 127 | return h 128 | -------------------------------------------------------------------------------- /md4/networks/uvit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """UViT implementation based on DiT.""" 17 | 18 | from typing import Sequence 19 | 20 | import flax.linen as nn 21 | import jax.numpy as jnp 22 | 23 | from md4.networks.dit import DiT 24 | 25 | 26 | def nearest_neighbor_upsample(x): 27 | bs, h, w, c = x.shape 28 | x = x.reshape(bs, h, 1, w, 1, c) 29 | x = jnp.broadcast_to(x, (bs, h, 2, w, 2, c)) 30 | return x.reshape(bs, h * 2, w * 2, c) 31 | 32 | 33 | class CondGroupNorm(nn.Module): 34 | """Conditional normalization.""" 35 | 36 | @nn.compact 37 | def __call__(self, x, cond=None): 38 | c = x.shape[-1] 39 | if cond is not None: 40 | cond_act = nn.Dense(c * 2, kernel_init=nn.initializers.zeros)(cond) 41 | scale, shift = jnp.split(cond_act[:, None, None, :], 2, axis=-1) 42 | x = nn.GroupNorm(use_bias=False, use_scale=False)(x) * (1 + scale) + shift 43 | else: 44 | x = nn.GroupNorm()(x) 45 | return x 46 | 47 | 48 | class SelfAttention(nn.Module): 49 | """Self attention layer in UNets.""" 50 | 51 | num_heads: int = 1 52 | 53 | @nn.compact 54 | def __call__(self, x, cond=None): 55 | bs, h, w, c = x.shape 56 | z = CondGroupNorm()(x, cond=cond) 57 | z = z.reshape(bs, -1, c) 58 | 59 | mha = nn.MultiHeadDotProductAttention( 60 | num_heads=self.num_heads, 61 | qkv_features=c, 62 | out_kernel_init=nn.zeros_init(), 63 | ) 64 | z = mha(z, z) 65 | z = z.reshape(-1, h, w, c) + x 66 | return z 67 | 68 | 69 | class ResBlock(nn.Module): 70 | """Residual block in UNets.""" 71 | 72 | out_channels: int | None = None 73 | dropout_rate: float = 0.1 74 | resample: str | None = None 75 | 76 | @nn.compact 77 | def __call__(self, x, cond=None, train=False): 78 | in_channels = x.shape[-1] 79 | out_channels = ( 80 | in_channels if self.out_channels is None else self.out_channels 81 | ) 82 | h = CondGroupNorm()(x, cond=cond) 83 | h = nn.swish(h) 84 | 85 | if self.resample is not None: 86 | 87 | def updown(z): 88 | return { 89 | 'up': nearest_neighbor_upsample(z), 90 | 'down': nn.avg_pool(z, (2, 2), (2, 2)), 91 | }[self.resample] 92 | 93 | h = updown(h) 94 | x = updown(x) 95 | h = nn.Conv(out_channels, kernel_size=(3, 3))(h) 96 | h = CondGroupNorm()(h, cond=cond) 97 | h = nn.swish(h) 98 | h = nn.Dropout(rate=self.dropout_rate)(h, deterministic=not train) 99 | h = nn.Conv( 100 | out_channels, kernel_size=(3, 3), kernel_init=nn.initializers.zeros 101 | )(h) 102 | if in_channels != out_channels: 103 | x = nn.Dense(out_channels)(x) 104 | return x + h 105 | 106 | 107 | class UNet(nn.Module): 108 | """UNet for Diffusion.""" 109 | 110 | d_channels: int = 128 111 | n_layers: int = 32 112 | n_dit_layers: int = 0 113 | dit_num_heads: int = 12 114 | dit_hidden_size: int = 768 115 | ch_mult: Sequence[int] = (1,) 116 | add_input: bool = False 117 | output_channels: int | None = None 118 | dropout_rate: float = 0.1 119 | 120 | @nn.compact 121 | def __call__(self, x, cond=None, train=False, output_channels=None): 122 | if output_channels is None: 123 | if self.output_channels is None: 124 | output_channels = x.shape[-1] 125 | else: 126 | output_channels 
= self.output_channels 127 | num_resolutions = len(self.ch_mult) 128 | 129 | # Linear projection of input 130 | h = nn.Conv(self.d_channels, kernel_size=(3, 3))(x) 131 | hs = [h] 132 | 133 | # Downsampling 134 | for i_level in range(num_resolutions): 135 | for _ in range(self.n_layers): 136 | h = ResBlock( 137 | out_channels=self.d_channels * self.ch_mult[i_level], 138 | dropout_rate=self.dropout_rate, 139 | )(h, cond=cond, train=train) 140 | hs.append(h) 141 | # Downsample 142 | if i_level != num_resolutions - 1: 143 | h = ResBlock( 144 | dropout_rate=self.dropout_rate, 145 | resample='down', 146 | )(h, cond=cond, train=train) 147 | hs.append(h) 148 | 149 | # Middle 150 | _, img_size, _, c = h.shape 151 | h = DiT( 152 | img_size=img_size, 153 | patch_size=2, 154 | in_channels=c, 155 | out_channels=c, 156 | hidden_size=self.dit_hidden_size, # c * 2, c * 3.. 157 | depth=self.n_dit_layers, # 8, 12, 16, 20 158 | num_heads=self.dit_num_heads, # 8, 12.. 159 | dropout_rate=self.dropout_rate, 160 | )(h, cond=cond, train=train) 161 | 162 | # Upsampling 163 | for i_level in reversed(range(num_resolutions)): 164 | for _ in range(self.n_layers + 1): 165 | h = jnp.concatenate([h, hs.pop()], axis=-1) 166 | h = ResBlock( 167 | out_channels=self.d_channels * self.ch_mult[i_level], 168 | dropout_rate=self.dropout_rate, 169 | )(h, cond=cond, train=train) 170 | # Upsample 171 | if i_level != 0: 172 | h = ResBlock( 173 | dropout_rate=self.dropout_rate, 174 | resample='up', 175 | )(h, cond=cond, train=train) 176 | 177 | assert not hs 178 | 179 | # Predict noise 180 | h = nn.swish(CondGroupNorm()(h, cond=cond)) 181 | h = nn.Conv(output_channels, (3, 3), kernel_init=nn.initializers.zeros)(h) 182 | 183 | if self.add_input: 184 | h += x 185 | return h 186 | -------------------------------------------------------------------------------- /md4/sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Sampling functions.""" 17 | 18 | import functools 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | 24 | def get_attr(train_state, key): 25 | if hasattr(train_state, key): 26 | return getattr(train_state, key) 27 | else: 28 | return train_state[key] 29 | 30 | 31 | @functools.partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=0) 32 | def generate(model, train_state, rng, dummy_inputs, conditioning=None): 33 | """Generate samples from the diffusion model.""" 34 | rng = jax.random.fold_in(rng, jax.lax.axis_index('batch')) 35 | variables = { 36 | 'params': get_attr(train_state, 'ema_params'), 37 | **get_attr(train_state, 'state'), 38 | } 39 | rng, sub_rng = jax.random.split(rng) 40 | zt = model.apply( 41 | variables, 42 | dummy_inputs.shape[0], 43 | method=model.prior_sample, 44 | rngs={'sample': sub_rng}, 45 | ) 46 | rng, sub_rng = jax.random.split(rng) 47 | 48 | def body_fn(i, zt): 49 | return model.apply( 50 | variables, 51 | sub_rng, 52 | i, 53 | model.timesteps, 54 | zt, 55 | conditioning=conditioning, 56 | method=model.sample_step, 57 | ) 58 | 59 | z0 = jax.lax.fori_loop( 60 | lower=0, upper=model.timesteps, body_fun=body_fn, init_val=zt 61 | ) 62 | return model.apply( 63 | variables, 64 | z0, 65 | conditioning=conditioning, 66 | method=model.decode, 67 | rngs={'sample': rng}, 68 | ) 69 | 70 | 71 | def simple_generate(rng, train_state, batch_size, model, conditioning=None): 72 | """Generate samples from the diffusion model.""" 73 | variables = {'params': train_state.params, **train_state.state} 74 | rng, sub_rng = jax.random.split(rng) 75 | zt = model.apply( 76 | variables, 77 | batch_size, 78 | method=model.prior_sample, 79 | rngs={'sample': sub_rng}, 80 | ) 81 | rng, sub_rng = jax.random.split(rng) 82 | 83 | def body_fn(i, zt): 84 | return model.apply( 85 | variables, 86 | sub_rng, 87 | i, 88 | model.timesteps, 89 | zt, 90 | conditioning=conditioning, 91 | method=model.sample_step, 92 | ) 93 | 94 | z0 = jax.lax.fori_loop( 95 | lower=0, upper=model.timesteps, body_fun=body_fn, init_val=zt 96 | ) 97 | return model.apply( 98 | variables, 99 | z0, 100 | conditioning=conditioning, 101 | method=model.decode, 102 | rngs={'sample': rng}, 103 | ) 104 | 105 | 106 | @functools.partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=0) 107 | def reconstruct(model, train_state, rng, t, inputs, conditioning=None): 108 | """Reconstruct from the latent at t.""" 109 | rng = jax.random.fold_in(rng, jax.lax.axis_index('batch')) 110 | variables = { 111 | 'params': get_attr(train_state, 'ema_params'), 112 | **get_attr(train_state, 'state'), 113 | } 114 | f = model.apply(variables, inputs, conditioning, method=model.encode) 115 | 116 | timesteps = model.timesteps 117 | tn = jnp.ceil(t * timesteps).astype('int32') 118 | t = tn / timesteps 119 | rng, sub_rng = jax.random.split(rng) 120 | zt = model.apply( 121 | variables, f, t, method=model.forward_sample, rngs={'sample': sub_rng} 122 | ) 123 | rng, sub_rng = jax.random.split(rng) 124 | 125 | def body_fn(i, zt): 126 | return model.apply( 127 | variables, 128 | sub_rng, 129 | i, 130 | timesteps, 131 | zt, 132 | conditioning=conditioning, 133 | method=model.sample_step, 134 | ) 135 | 136 | z0 = jax.lax.fori_loop( 137 | lower=timesteps - tn, upper=timesteps, body_fun=body_fn, init_val=zt 138 | ) 139 | return model.apply( 140 | variables, 141 | z0, 142 | conditioning=conditioning, 143 | method=model.decode, 144 | rngs={'sample': rng}, 
145 | ) 146 | -------------------------------------------------------------------------------- /md4/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Methods for training MD4/GenMD4 on text/image datasets.""" 17 | 18 | from collections.abc import Callable, Mapping, Sequence 19 | import copy 20 | import functools 21 | from typing import Any 22 | 23 | from absl import logging 24 | from clu import metric_writers 25 | from clu import metrics 26 | from clu import parameter_overview 27 | from clu import periodic_actions 28 | from etils import epath 29 | import flax 30 | import flax.jax_utils as flax_utils 31 | import flax.linen as nn 32 | import grain.python as grain 33 | import jax 34 | from jax.experimental import checkify 35 | import jax.numpy as jnp 36 | import ml_collections 37 | import numpy as np 38 | import optax 39 | from orbax import checkpoint as orbax_checkpoint 40 | 41 | from md4 import input_pipeline 42 | from md4 import input_pipeline_v2 43 | from md4 import sampling 44 | from md4 import utils 45 | from md4.models import utils as model_utils 46 | 47 | 48 | @flax.struct.dataclass 49 | class TrainState: 50 | """State of the model and the training. 51 | 52 | This includes parameters, statistics and optimizer. 53 | """ 54 | 55 | rng: jnp.ndarray 56 | step: int 57 | params: Any 58 | ema_params: Any 59 | opt_state: optax.OptState 60 | state: Any 61 | 62 | 63 | def merge_batch_stats(replicated_state: TrainState) -> TrainState: 64 | """Merge model batch stats.""" 65 | if jax.tree.leaves(replicated_state.state): 66 | cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, "batch"), "batch") 67 | return replicated_state.replace( 68 | state=cross_replica_mean(replicated_state.state) 69 | ) 70 | else: 71 | return replicated_state 72 | 73 | 74 | def _get_checkpoint_manager( 75 | config: ml_collections.ConfigDict, workdir: epath.PathLike 76 | ) -> orbax_checkpoint.CheckpointManager: 77 | """Loads the orbax checkpoint manager for train state and data iterator.""" 78 | # The keys in this dict should match the keys in `checkpointed_state`. 
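  # A minimal sketch of that pairing as it appears later in train_and_evaluate
  # (names below are the ones already used in this file, not new API):
  #
  #   checkpointed_state = dict(train_state=train_state, train_iter=train_iter)
  #   checkpoint_manager.save(step, items=checkpointed_state)
  #
  # so orbax routes each item to the checkpointer registered under its key.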
79 | checkpointers = dict( 80 | train_state=orbax_checkpoint.PyTreeCheckpointer(), 81 | train_iter=orbax_checkpoint.Checkpointer( 82 | grain.PyGrainCheckpointHandler() 83 | ), # pytype:disable=wrong-arg-types 84 | ) 85 | checkpoint_dir = epath.Path(workdir) / "checkpoints" 86 | keep_period = ( 87 | config.checkpoint_keep_period 88 | if config.checkpoint_keep_period > 0 89 | else None 90 | ) 91 | return orbax_checkpoint.CheckpointManager( 92 | checkpoint_dir, 93 | checkpointers=checkpointers, 94 | options=orbax_checkpoint.CheckpointManagerOptions( 95 | max_to_keep=5, create=True, keep_period=keep_period 96 | ), 97 | ) 98 | 99 | 100 | def create_train_state( 101 | config: ml_collections.ConfigDict, 102 | rng: jnp.ndarray, 103 | input_shape: Sequence[int] | Mapping[str, Sequence[int]], 104 | schedule_fn: Callable[[Any], Any], 105 | ) -> tuple[nn.Module, optax.GradientTransformation, TrainState, Any]: 106 | """Create and initialize the model.""" 107 | model = model_utils.get_model(config) 108 | 109 | if config.classes > 0: 110 | conditioning = jnp.zeros(input_shape[0], dtype="int32") 111 | else: 112 | conditioning = None 113 | rng, sample_rng, init_rng = jax.random.split(rng, 3) 114 | dummy_input = jnp.ones(input_shape, dtype="int32") 115 | 116 | output, variables = model.init_with_output( 117 | {"sample": sample_rng, "params": init_rng}, 118 | dummy_input, 119 | cond=conditioning, 120 | train=False, 121 | ) 122 | metric_keys = sorted(list(output.keys()) + ["learning_rate"]) 123 | logging.info("metric_keys: %s", metric_keys) 124 | metrics_class = create_metrics_class_from_keys(metric_keys) 125 | state, params = flax.core.pop(variables, "params") 126 | del variables 127 | parameter_overview.log_parameter_overview( 128 | state, msg="############# state #############" 129 | ) 130 | parameter_overview.log_parameter_overview( 131 | params, msg="############# params #############" 132 | ) 133 | 134 | optimizer = optax.chain( 135 | optax.clip(config.clip) if config.clip > 0.0 else optax.identity(), 136 | optax.adamw( 137 | schedule_fn, 138 | b1=0.9, 139 | b2=config.b2, 140 | weight_decay=config.weight_decay, 141 | ), 142 | ) 143 | return ( 144 | model, 145 | optimizer, 146 | TrainState( 147 | step=0, 148 | rng=rng, 149 | params=params, 150 | ema_params=copy.deepcopy(params) if config.ema_rate > 0.0 else None, 151 | opt_state=optimizer.init(params), 152 | state=state, 153 | ), 154 | metrics_class, 155 | ) 156 | 157 | 158 | def create_metrics_class_from_keys(metric_keys): 159 | """Create train/eval metrics collection from dictionary.""" 160 | average_keys = [] 161 | stats = dict( 162 | (k, metrics.Average.from_output(k)) 163 | if (k in average_keys) or ("loss" in k) 164 | else (k, metrics.LastValue.from_output(k)) 165 | for k in metric_keys 166 | ) 167 | return metrics.Collection.create(**stats) 168 | 169 | 170 | def cosine_decay(lr: float, current_step: float, total_steps: float) -> float: 171 | ratio = jnp.maximum(0.0, current_step / total_steps) 172 | mult = 0.5 * (1.0 + jnp.cos(jnp.pi * ratio)) 173 | return mult * lr # pytype: disable=bad-return-type # jax-types 174 | 175 | 176 | def get_learning_rate( 177 | step: int, 178 | *, 179 | base_learning_rate: float, 180 | num_steps: int, 181 | warmup_steps: int | None = None, 182 | schedule_type: str = "cosine", 183 | ) -> float: 184 | """Cosine learning rate schedule.""" 185 | logging.info( 186 | "get_learning_rate(step=%s, base_learning_rate=%s, num_steps=%s", 187 | step, 188 | base_learning_rate, 189 | num_steps, 190 | ) 191 | warmup = 
jnp.minimum(1.0, step / warmup_steps) 192 | if schedule_type == "cosine": 193 | lr = cosine_decay( 194 | base_learning_rate, step - warmup_steps, num_steps - warmup_steps 195 | ) 196 | elif schedule_type == "constant": 197 | lr = base_learning_rate 198 | else: 199 | raise NotImplementedError() 200 | return lr * warmup # pytype: disable=bad-return-type # jax-types 201 | 202 | 203 | def loss_fn(params, state, rng, model, batch, train=False): 204 | """Loss function.""" 205 | rng, sample_rng = jax.random.split(rng) 206 | rngs = {"sample": sample_rng} 207 | if train: 208 | _, dropout_rng = jax.random.split(rng) 209 | rngs["dropout"] = dropout_rng 210 | 211 | variables = {"params": params, **state} 212 | if "image" in batch: 213 | x = batch["image"] 214 | elif "text" in batch: 215 | x = batch["text"] 216 | else: 217 | raise ValueError("Unsupported targets/tasks.") 218 | 219 | if "label" in batch: 220 | conditioning = batch["label"].astype("int32") 221 | else: 222 | conditioning = None 223 | 224 | new_state = {} 225 | if train: 226 | metrics_dict, new_state = model.apply( 227 | variables, 228 | x, 229 | cond=conditioning, 230 | train=train, 231 | rngs=rngs, 232 | mutable=list(state.keys()), 233 | ) 234 | else: 235 | metrics_dict = model.apply( 236 | variables, x, cond=conditioning, train=train, rngs=rngs 237 | ) 238 | 239 | loss = metrics_dict["loss"] 240 | if train: 241 | return loss, (new_state, metrics_dict) 242 | return loss, metrics_dict 243 | 244 | 245 | @jax.jit 246 | def merge_metrics(a_tree, b_tree): 247 | return jax.tree.map(lambda a, b: a + b, a_tree, b_tree) 248 | 249 | 250 | def train_step( 251 | model: nn.Module, 252 | optimizer: optax.GradientTransformation, 253 | train_state: TrainState, 254 | batch: Mapping[str, jnp.ndarray], 255 | learning_rate_fn: Callable[[int], float], 256 | train_metrics_class: Any, 257 | ema_rate: float = 0.0, 258 | num_microbatches: int | None = None, 259 | ) -> tuple[TrainState, metrics.Collection]: 260 | """Perform a single training step.""" 261 | logging.info("train_step(batch=%s)", batch) 262 | rng, new_rng = jax.random.split(train_state.rng) 263 | rng = jax.random.fold_in(rng, jax.lax.axis_index("batch")) 264 | 265 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 266 | if num_microbatches is None or num_microbatches <= 1: 267 | (_, (new_state, metrics_dict)), grads = grad_fn( 268 | train_state.params, train_state.state, rng, model, batch, train=True 269 | ) 270 | else: 271 | batch_size = next(iter(batch.values())).shape[0] 272 | assert ( 273 | batch_size % num_microbatches == 0 274 | ), "Batch size isn't divided evenly by num_microbatches." 
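    # Worked example with hypothetical numbers: batch_size=32 and
    # num_microbatches=4 give microbatch_size=8, so get_microbatch(batch, 1)
    # below slices rows 8..15 out of every array in the batch via
    # jax.lax.dynamic_slice, and the per-microbatch gradients are summed into
    # grad_accum by the fori_loop further down.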
275 | microbatch_size = batch_size // num_microbatches 276 | logging.info( 277 | "using microbatches: %d microbatches, %d size", 278 | num_microbatches, 279 | microbatch_size, 280 | ) 281 | 282 | def get_microbatch( 283 | batch: Mapping[str, jnp.ndarray], idx: int 284 | ) -> Mapping[str, jnp.ndarray]: 285 | """Fetch microbatch slice from possibly-packed input data.""" 286 | offset = idx * microbatch_size 287 | length = microbatch_size 288 | starts = {k: [offset] + [0] * (b.ndim - 1) for k, b in batch.items()} 289 | limits = {k: [length] + list(b.shape[1:]) for k, b in batch.items()} 290 | return { 291 | k: jax.lax.dynamic_slice(b, starts[k], limits[k]) 292 | for k, b in batch.items() 293 | } 294 | 295 | def metrics_and_grad(loop_cnt, rng, train_state_state): 296 | _, mbrng = jax.random.split(rng) 297 | mb = get_microbatch(batch, loop_cnt) 298 | 299 | (_, (new_state, metrics_dict)), grads = grad_fn( 300 | train_state.params, train_state_state, mbrng, model, mb, train=True 301 | ) 302 | return metrics_dict, grads, new_state 303 | 304 | def per_microbatch_train_step(loop_cnt, carry): 305 | (rng, grad_accum, prev_metrics_dict, train_state_state) = carry 306 | metrics_dict, grads, train_state_state = metrics_and_grad( 307 | loop_cnt, rng, train_state_state 308 | ) 309 | 310 | grad_accum = jax.tree.map(jnp.add, grad_accum, grads) 311 | metrics_dict = jax.lax.cond( 312 | loop_cnt == 0, 313 | lambda _: metrics_dict, 314 | lambda _: merge_metrics(prev_metrics_dict, metrics_dict), 315 | None, 316 | ) 317 | return rng, grad_accum, metrics_dict, train_state_state 318 | 319 | # Initialize gradient accumulation loop state. 320 | accum_dtype = jnp.float32 321 | grad_accum_init = jax.tree.map( 322 | lambda x: jnp.zeros(x.shape, accum_dtype), train_state.params 323 | ) 324 | initial_metrics_shape, _, _ = jax.eval_shape( 325 | metrics_and_grad, 326 | loop_cnt=0, 327 | rng=rng, 328 | train_state_state=train_state.state, 329 | ) 330 | 331 | initial_metrics = { 332 | k: jnp.zeros(shape=v.shape, dtype=v.dtype) 333 | for k, v in initial_metrics_shape.items() 334 | } 335 | 336 | loop_init = ( 337 | rng, 338 | grad_accum_init, 339 | initial_metrics, 340 | train_state.state, 341 | ) 342 | _, grads, metrics_dict, train_state_state = jax.lax.fori_loop( 343 | 0, num_microbatches, per_microbatch_train_step, loop_init 344 | ) 345 | metrics_dict = jax.tree.map(lambda x: x / num_microbatches, metrics_dict) 346 | new_state = train_state_state 347 | 348 | # Compute average gradient across multiple workers. 
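  # jax.lax.pmean below averages gradients over the "batch" pmap axis so every
  # device applies an identical update; the EMA branch that follows implements
  #   new_ema = ema_rate * old_ema + (1 - ema_rate) * new_params,
  # written in the incremental form old_ema + (1 - ema_rate) * (new - old).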
349 | grads = jax.lax.pmean(grads, axis_name="batch") 350 | updates, new_opt_state = optimizer.update( 351 | grads, train_state.opt_state, train_state.params 352 | ) 353 | new_params = optax.apply_updates(train_state.params, updates) 354 | if ema_rate > 0.0: 355 | new_ema_params = jax.tree_util.tree_map( 356 | lambda x, y: x + (1.0 - ema_rate) * (y - x), 357 | train_state.ema_params, 358 | new_params, 359 | ) 360 | else: 361 | new_ema_params = None 362 | new_train_state = train_state.replace( 363 | step=train_state.step + 1, 364 | rng=new_rng, 365 | params=new_params, 366 | ema_params=new_ema_params, 367 | opt_state=new_opt_state, 368 | state=new_state, 369 | ) 370 | 371 | metrics_update = train_metrics_class.gather_from_model_output( 372 | learning_rate=learning_rate_fn(train_state.step), 373 | **metrics_dict, 374 | ) 375 | return new_train_state, metrics_update 376 | 377 | 378 | def eval_step( 379 | model: nn.Module, 380 | rng: jnp.ndarray, 381 | train_state: TrainState, 382 | batch: Mapping[str, jnp.ndarray], 383 | eval_metrics_class: Any, 384 | ema_rate: float = 0.0, 385 | ) -> metrics.Collection: 386 | """Compute the metrics for the given model in inference mode.""" 387 | logging.info("eval_step(batch=%s)", batch) 388 | rng = jax.random.fold_in(rng, jax.lax.axis_index("batch")) 389 | params = train_state.ema_params if ema_rate > 0.0 else train_state.params 390 | 391 | _, metrics_dict = loss_fn( 392 | params, train_state.state, rng, model, batch, train=False 393 | ) 394 | return eval_metrics_class.gather_from_model_output( 395 | learning_rate=0.0, **metrics_dict 396 | ) 397 | 398 | 399 | def evaluate( 400 | p_eval_step: Any, 401 | rng: jnp.ndarray, 402 | train_state: TrainState, 403 | eval_loader: grain.DataLoader, 404 | num_eval_steps: int = -1, 405 | ): 406 | """Evaluate the model on the given dataset.""" 407 | logging.info("Starting evaluation.") 408 | eval_metrics = None 409 | with utils.StepTraceContextHelper("eval", 0) as trace_context: 410 | # Use `iter` to reset the eval_loader before each evaluation. 411 | for step, batch in enumerate(iter(eval_loader)): 412 | rng, sub_rng = jax.random.split(rng) 413 | sub_rng = flax_utils.replicate(sub_rng) 414 | batch = utils.reshape_batch(batch) 415 | metrics_update = flax_utils.unreplicate( 416 | p_eval_step(rng=sub_rng, train_state=train_state, batch=batch) 417 | ) 418 | eval_metrics = ( 419 | metrics_update 420 | if eval_metrics is None 421 | else eval_metrics.merge(metrics_update) 422 | ) 423 | if num_eval_steps > 0 and step + 1 == num_eval_steps: 424 | break 425 | trace_context.next_step() 426 | if eval_metrics is None: 427 | raise ValueError(f"Eval dataset {eval_loader} was empty.") 428 | return eval_metrics 429 | 430 | 431 | def train_and_evaluate( 432 | config: ml_collections.ConfigDict, workdir: epath.PathLike 433 | ): 434 | """Runs a training and evaluation loop. 435 | 436 | Args: 437 | config: Configuration to use. 438 | workdir: Working directory for checkpoints and TF summaries. If this 439 | contains checkpoint training will be resumed from the latest checkpoint. 440 | """ 441 | workdir = epath.Path(workdir) 442 | workdir.mkdir(parents=True, exist_ok=True) 443 | 444 | rng = utils.get_rng(config.seed) 445 | logging.info("Using random seed %s.", rng) 446 | writer = metric_writers.create_default_writer( 447 | workdir, just_logging=jax.process_index() > 0 448 | ) 449 | # Learning rate schedule. 
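  # Combining get_learning_rate and cosine_decay above, the schedule built
  # below evaluates, for schedule_type == "cosine" and s >= warmup_steps,
  #   lr(s) = base_lr * min(1, s / warmup_steps)
  #           * 0.5 * (1 + cos(pi * (s - warmup_steps) / (num_steps - warmup_steps))),
  # i.e. linear warmup followed by cosine decay to zero at num_steps.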
450 | assert config.batch_size % jax.device_count() == 0 451 | per_device_batch_size = config.batch_size // jax.device_count() 452 | num_train_steps = input_pipeline.get_num_train_steps(config) 453 | steps_per_epoch = num_train_steps // config.num_epochs 454 | logging.info( 455 | "num_train_steps=%d, steps_per_epoch=%d", num_train_steps, steps_per_epoch 456 | ) 457 | schedule_fn = functools.partial( 458 | get_learning_rate, 459 | base_learning_rate=config.learning_rate, 460 | num_steps=num_train_steps, 461 | warmup_steps=config.warmup_steps, 462 | schedule_type=config.learning_rate_schedule, 463 | ) 464 | 465 | # Build input pipeline. 466 | rng, data_seed = jax.random.split(rng) 467 | data_seed = int( 468 | jax.random.randint(data_seed, [], minval=0, maxval=np.iinfo(np.int32).max) 469 | ) 470 | # The input pipeline runs on each process and loads data for local TPUs. 471 | create_datasets = ( 472 | input_pipeline_v2.create_datasets 473 | if config.get("use_v2_input_pipeline", None) 474 | else input_pipeline.create_datasets 475 | ) 476 | train_loader, eval_loaders, dataset_info = create_datasets(config, data_seed) 477 | train_iter = iter(train_loader) 478 | 479 | # Initialize model. 480 | rng, model_rng = jax.random.split(rng) 481 | data_shape = input_pipeline.get_data_shape(config) 482 | model, optimizer, train_state, metrics_class = create_train_state( # pylint: disable=invalid-name 483 | config, 484 | model_rng, 485 | input_shape=(per_device_batch_size,) + data_shape, 486 | schedule_fn=schedule_fn, 487 | ) 488 | 489 | # Set up checkpointing of the model and the input pipeline. 490 | checkpoint_manager = _get_checkpoint_manager(config, workdir) 491 | 492 | # Retrieve data from previous checkpoints if possible. 493 | checkpointed_state = dict(train_state=train_state, train_iter=train_iter) 494 | if checkpoint_manager.latest_step() is not None: 495 | checkpointed_state = checkpoint_manager.restore( 496 | checkpoint_manager.latest_step(), items=checkpointed_state 497 | ) 498 | train_state = checkpointed_state["train_state"] 499 | train_iter = checkpointed_state["train_iter"] 500 | 501 | # Distribute training. 502 | train_state = flax_utils.replicate(train_state) 503 | train_step_func = functools.partial( 504 | train_step, 505 | model=model, 506 | optimizer=optimizer, 507 | train_metrics_class=metrics_class, 508 | learning_rate_fn=schedule_fn, 509 | ema_rate=config.ema_rate, 510 | num_microbatches=config.get("num_microbatches", None), 511 | ) 512 | if config.check_nans: 513 | train_step_func = checkify.checkify( 514 | train_step_func, errors=checkify.float_checks 515 | ) 516 | p_train_step = jax.pmap( 517 | train_step_func, axis_name="batch", donate_argnums=(0,) 518 | ) 519 | p_eval_step = jax.pmap( 520 | functools.partial( 521 | eval_step, 522 | model=model, 523 | eval_metrics_class=metrics_class, 524 | ema_rate=config.ema_rate, 525 | ), 526 | axis_name="batch", 527 | ) 528 | 529 | hooks = [] 530 | report_progress = periodic_actions.ReportProgress( 531 | num_train_steps=num_train_steps, writer=writer 532 | ) 533 | if jax.process_index() == 0: 534 | hooks += [ 535 | report_progress, 536 | periodic_actions.Profile(num_profile_steps=5, logdir=workdir), 537 | ] 538 | train_metrics = None 539 | 540 | # Unreplicating from TPU is costly, so we only do it once at the start. 541 | initial_step = int(flax.jax_utils.unreplicate(train_state.step)) 542 | 543 | with metric_writers.ensure_flushes(writer): 544 | # Steps are in interval [1, num_train_steps], not [0, num_train_steps - 1]. 
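    # Example with hypothetical numbers: resuming from a checkpoint saved at
    # step 10_000 with num_train_steps=100_000 restores initial_step=10_000,
    # so the loop below replays steps 10_001..100_000.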
545 | for step in range(initial_step + 1, num_train_steps + 1): 546 | is_last_step = step == num_train_steps 547 | 548 | with jax.profiler.StepTraceAnnotation("train", step_num=step): 549 | batch = utils.reshape_batch(next(train_iter)) 550 | if config.check_nans: 551 | errs, (train_state, metrics_update) = p_train_step( 552 | train_state=train_state, batch=batch 553 | ) 554 | errs.throw() 555 | else: 556 | train_state, metrics_update = p_train_step( 557 | train_state=train_state, batch=batch 558 | ) 559 | metric_update = flax_utils.unreplicate(metrics_update) 560 | train_metrics = ( 561 | metric_update 562 | if train_metrics is None 563 | else train_metrics.merge(metric_update) 564 | ) 565 | 566 | # Quick indication that training is happening. 567 | logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) 568 | for h in hooks: 569 | h(step) 570 | 571 | if step % config.log_loss_every_steps == 0 or is_last_step: 572 | writer.write_scalars(step, train_metrics.compute()) 573 | train_metrics = None 574 | 575 | if step == 1 or step % config.eval_every_steps == 0 or is_last_step: 576 | for split, eval_loader in eval_loaders.items(): 577 | rng, eval_rng = jax.random.split(rng) 578 | with report_progress.timed("eval"): 579 | train_state = merge_batch_stats(train_state) 580 | eval_metrics = evaluate( 581 | p_eval_step, 582 | eval_rng, 583 | train_state, 584 | eval_loader, 585 | config.num_eval_steps, 586 | ) 587 | eval_metrics_cpu = jax.tree_util.tree_map( 588 | np.array, eval_metrics.compute() 589 | ) 590 | eval_metrics_cpu = { 591 | split + "_" + k: v for k, v in eval_metrics_cpu.items() 592 | } 593 | writer.write_scalars(step, eval_metrics_cpu) 594 | 595 | if hasattr(model, "sample_step"): 596 | with report_progress.timed("sample"): 597 | _, sample_rng = jax.random.split(rng) 598 | dummy_loader = train_loader 599 | dummy_batch = utils.reshape_batch(next(iter(dummy_loader))) 600 | dummy_inputs = dummy_batch[config.task_type] 601 | if "label" in dummy_batch: 602 | conditioning = dummy_batch["label"].astype("int32") 603 | else: 604 | conditioning = None 605 | 606 | samples = sampling.generate( 607 | model, 608 | train_state, 609 | flax_utils.replicate(sample_rng), 610 | dummy_inputs, 611 | conditioning=conditioning, 612 | ) 613 | 614 | all_samples = jax.pmap( 615 | lambda x: jax.lax.all_gather(x, "batch"), axis_name="batch" 616 | )(samples) 617 | all_samples = flax_utils.unreplicate(all_samples) 618 | all_samples = all_samples.reshape(-1, *data_shape) 619 | if config.task_type == "image": 620 | sample_grid = utils.generate_image_grids(all_samples) 621 | writer.write_images(step, {"samples": sample_grid}) 622 | del all_samples, sample_grid 623 | elif config.task_type == "text": 624 | tokenizer = dataset_info["tokenizer"] 625 | texts = utils.detokenize_texts(all_samples, tokenizer) 626 | writer.write_texts(step, {"samples": texts}) 627 | 628 | if step % config.checkpoint_every_steps == 0 or is_last_step: 629 | with report_progress.timed("checkpoint"): 630 | train_state = merge_batch_stats(train_state) 631 | checkpoint_manager.save( 632 | step, 633 | items=dict( 634 | train_state=jax.tree_util.tree_map( 635 | np.array, flax_utils.unreplicate(train_state) 636 | ), 637 | train_iter=train_iter, 638 | ), 639 | ) 640 | 641 | logging.info("Finishing training at step %d", num_train_steps) 642 | -------------------------------------------------------------------------------- /md4/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 
Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utils.""" 17 | 18 | from collections.abc import Mapping 19 | import math 20 | import time 21 | from typing import Any, Tuple 22 | from absl import logging 23 | import chex 24 | from clu import platform 25 | import distrax 26 | import jax 27 | import jax.numpy as jnp 28 | import matplotlib.pyplot as plt 29 | import numpy as np 30 | from orbax import checkpoint as orbax_checkpoint 31 | import seaborn as sns 32 | 33 | 34 | def loss2bpt(loss_dict, data_shape): 35 | """Normalize loss to bits per token.""" 36 | seq_len = jnp.prod(jnp.array(data_shape)) 37 | rescale_to_bpd = 1.0 / (seq_len * jnp.log(2.0)) 38 | bpt_loss_dict = {} 39 | for k, v in loss_dict.items(): 40 | if "loss" in k: 41 | bpt_loss_dict[k] = v * rescale_to_bpd 42 | else: 43 | bpt_loss_dict[k] = v 44 | return bpt_loss_dict 45 | 46 | 47 | def constant_init(value, dtype="float32"): 48 | def _init(key, shape, dtype=dtype): 49 | del key 50 | return value * jnp.ones(shape, dtype) 51 | 52 | return _init 53 | 54 | 55 | def _logistic_pdf_fn(z, log_scales): 56 | return -z - 2.0 * jax.nn.softplus(-z) - log_scales 57 | 58 | 59 | class DiscretizedLogisticMixture(distrax.Distribution): 60 | """Discretized mixture of Logistics defined in PixelCNN++.""" 61 | 62 | def __init__( 63 | self, 64 | w_logits, 65 | locs, 66 | log_scales, 67 | min_val=0.0, 68 | max_val=255.0, 69 | bin_size=1.0, 70 | ): 71 | self._w_logits = w_logits 72 | self._locs = locs 73 | self._log_scales = log_scales 74 | self._min_val = min_val 75 | self._max_val = max_val 76 | self._bin_size = bin_size 77 | 78 | self._batch_shape = jax.lax.broadcast_shapes( 79 | self._locs.shape[:-1], self._log_scales.shape[:-1] 80 | ) 81 | 82 | self._cdf = jax.nn.sigmoid 83 | self._log_prob_fn = _logistic_pdf_fn 84 | 85 | @property 86 | def event_shape(self) -> Tuple[int, ...]: 87 | """Shape of event of distribution samples.""" 88 | return () 89 | 90 | @property 91 | def batch_shape(self) -> Tuple[int, ...]: 92 | return self._batch_shape 93 | 94 | @property 95 | def _n_components(self) -> int: 96 | return self._locs.shape[-1] 97 | 98 | def _mean(self): 99 | # The expected value of the mixture is the sum of the linear probabilities 100 | # of each component multiplied by their means. 101 | # Apply softmax on the logits to get the linear probabilities. 102 | probs = jax.nn.softmax(self._w_logits) 103 | return jnp.sum(self._locs * probs, axis=-1) 104 | 105 | def log_prob(self, event: chex.Array) -> chex.Array: 106 | # expand the mixture dim 107 | event = jnp.expand_dims(event, -1) 108 | assert len(self._locs.shape) == len(event.shape) 109 | 110 | # Expand the dimensions of the params for tiling and broadcasting. 
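    # Per mixture component, the discretized likelihood computed below is
    #   P(x) = sigmoid((x - loc + bin_size/2) / scale)
    #        - sigmoid((x - loc - bin_size/2) / scale),
    # with the two edge bins widened to absorb the tails and a logistic
    # log-pdf at the bin center substituted when the difference underflows;
    # the mixture dimension is then reduced with logsumexp over the
    # normalized weight logits.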
111 |     locs = self._locs
112 |     w_logits = self._w_logits
113 |     inv_scales = jnp.exp(-self._log_scales)
114 | 
115 |     # pdf at the middle of the bin, used when bins are too small
116 |     z = (event - locs) * inv_scales
117 |     mid_log_prob = self._log_prob_fn(z, self._log_scales)
118 | 
119 |     # Calculate the difference of sigmoids.
120 |     half_bin = self._bin_size / 2
121 |     b = (event - locs + half_bin) * inv_scales
122 |     a = (event - locs - half_bin) * inv_scales
123 |     a = self._cdf(a)
124 |     b = self._cdf(b)
125 |     diff = b - a
126 | 
127 |     # Handle edge case.
128 |     edge_b = (
129 |         jax.nn.sigmoid((self._min_val - locs + half_bin) * inv_scales) - 0.0
130 |     )
131 |     edge_a = 1.0 - jax.nn.sigmoid(
132 |         (self._max_val - locs - half_bin) * inv_scales
133 |     )
134 |     diff = jnp.where(event > self._max_val - half_bin, edge_a, diff)
135 |     diff = jnp.where(event < self._min_val + half_bin, edge_b, diff)
136 | 
137 |     # Avoid small values for the subsequent log operation.
138 |     diff = jnp.maximum(diff, 1e-12)
139 |     log_prob = jnp.where(diff > 1e-8, jnp.log(diff), mid_log_prob)
140 | 
141 |     # Normalize logits.
142 |     w_logits -= jax.nn.logsumexp(w_logits, axis=-1, keepdims=True)
143 | 
144 |     # Combine: sum log-probs over the channel axis, then logsumexp over the
144 |     # mixture dimension.
145 |     total = w_logits + log_prob.sum(axis=-2)
146 |     total = jax.nn.logsumexp(total, -1)
147 |     return total
148 | 
149 |   def _sample_n(self, key: chex.PRNGKey, n: int):
150 |     # First sample from the mixing weights.
151 |     w_dist = distrax.Categorical(logits=self._w_logits)
152 |     key, sub_key = jax.random.split(key)
153 |     index = w_dist.sample(seed=sub_key, sample_shape=n)
154 |     index = jax.nn.one_hot(index, num_classes=self._n_components)
155 | 
156 |     # Pick the mixture component (per pixel) and broadcast to n samples.
157 |     log_scales = jnp.sum(jnp.expand_dims(self._log_scales, 0) * index, -1)
158 |     loc = jnp.sum(jnp.expand_dims(self._locs, 0) * index, -1)
159 |     scales = jnp.exp(log_scales)
160 | 
161 |     # Compute the logistic sample.
162 |     _, sub_key = jax.random.split(key)
163 |     logistic_sample = jax.random.logistic(sub_key, shape=loc.shape)
164 |     sample_values = loc + scales * logistic_sample
165 | 
166 |     return jnp.clip(jnp.round(sample_values), 0, 255).astype("int32")
167 | 
168 | 
169 | def shifted_softplus(x, b=1.0):
170 |   """log(exp(x) + b)."""
171 |   return x + jax.nn.softplus(jnp.log(b) - x)
172 | 
173 | 
174 | def reverse_broadcast(value, ndim):
175 |   """Broadcast by adding singleton axes to the right, instead of to the left."""
176 |   if value.ndim > ndim:
177 |     raise ValueError(
178 |         f"Cannot reverse broadcast a value with {value.ndim} dimensions "
179 |         f"to {ndim} dimensions."
180 |     )
181 | 
182 |   if value.ndim < ndim:
183 |     difference = ndim - value.ndim
184 |     return value.reshape(value.shape + difference * (1,))
185 |   else:
186 |     return value
187 | 
188 | 
189 | def get_rng(seed: None | int | tuple[int, int]) -> np.ndarray:
190 |   """Returns a JAX RNGKey."""
191 |   if seed is None:
192 |     # Case 1: No random seed given, use XManager ID.
193 |     # All processes (and restarts) get exactly the same seed but every work unit
194 |     # and experiment is different.
195 |     work_unit = platform.work_unit()
196 |     rng = (work_unit.experiment_id, work_unit.id)
197 |   elif isinstance(seed, int):
198 |     # Case 2: Single integer given.
199 |     rng = (0, seed)
200 |   else:
201 |     # Case 3: tuple[int, int] given.
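    # e.g. get_rng((1234, 0)) passes the validation below and returns
    # np.array([1234, 0], dtype=np.uint32), mirroring the
    # (experiment_id, work_unit_id) layout of Case 1 (values illustrative).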
202 |     if not isinstance(seed, (tuple, list)) or len(seed) != 2:
203 |       raise ValueError(
204 |           "Random seed must be an integer or tuple of 2 integers "
205 |           f"but got {seed!r}"
206 |       )
207 |     rng = seed
208 |   # JAX RNGKeys are arrays of np.uint32 and shape [2].
209 |   return np.asarray(rng, dtype=np.uint32)
210 | 
211 | 
212 | class StepTraceContextHelper:
213 |   """Helper class to use jax.profiler.StepTraceAnnotation."""
214 | 
215 |   def __init__(self, name: str, init_step_num: int):
216 |     self.name = name
217 |     self.step_num = init_step_num
218 |     self.context = None
219 | 
220 |   def __enter__(self):
221 |     self.context = jax.profiler.StepTraceAnnotation(
222 |         self.name, step_num=self.step_num
223 |     )
224 |     self.step_num += 1
225 |     self.context.__enter__()
226 |     return self
227 | 
228 |   def __exit__(self, exc_type, exc_value, tb):
229 |     assert self.context is not None, "Exited context without entering."
230 |     self.context.__exit__(exc_type, exc_value, tb)
231 |     self.context = None
232 | 
233 |   def next_step(self):
234 |     if self.context is None:
235 |       raise ValueError("Must call next_step() within a context.")
236 |     self.__exit__(None, None, None)
237 |     self.__enter__()
238 | 
239 | 
240 | def plot_embeddings(step, workdir, embeddings, annotations=None):
241 |   """Helper function to plot embeddings."""
242 |   fig, ax = plt.subplots()
243 |   ax.set_title("Embeddings")
244 |   if embeddings.ndim == 1:
245 |     ax.scatter(np.arange(256), embeddings)
246 |   else:
247 |     assert embeddings.ndim == 2
248 |     colors = np.linspace(0, 1, embeddings.shape[0])
249 |     ax.scatter(embeddings[:, 0], embeddings[:, 1], c=colors, cmap="rainbow")
250 |     if annotations:
251 |       for i in range(embeddings.shape[0]):
252 |         ax.annotate(annotations[i], (embeddings[i, 0], embeddings[i, 1]))
253 |   embedding_plot = workdir / "embedding_{}.png".format(step)
254 |   with embedding_plot.open("wb") as f:
255 |     fig.savefig(f)
256 | 
257 | 
258 | def plot_heatmap(step, workdir, emb_gram_matrix, token_labels=None):
259 |   """Helper function to plot a heatmap of the embedding Gram matrix."""
260 |   fig, ax = plt.subplots()
261 |   assert emb_gram_matrix.ndim == 2
262 |   if token_labels:
263 |     _ = sns.heatmap(emb_gram_matrix, linewidth=0.5, ax=ax,
264 |                     xticklabels=token_labels,
265 |                     yticklabels=token_labels)
266 |   else:
267 |     _ = sns.heatmap(emb_gram_matrix, linewidth=0.5, ax=ax)
268 | 
269 |   plt.xticks(rotation=90)
270 |   plt.yticks(rotation=0)
271 | 
272 |   heatmap_plot = workdir / "embedding_heatmap_{}.png".format(step)
273 |   with heatmap_plot.open("wb") as f:
274 |     fig.savefig(f)
275 | 
276 | 
277 | def generate_image_grids(images):
278 |   """Simple helper to generate a single image from a mini batch."""
279 | 
280 |   def image_grid(nrow, ncol, imagevecs, imshape):
281 |     images = iter(imagevecs.reshape((-1,) + imshape))
282 |     return jnp.vstack([
283 |         jnp.hstack([next(images) for _ in range(ncol)][::-1])
284 |         for _ in range(nrow)
285 |     ])
286 | 
287 |   batch_size = images.shape[0]
288 |   grid_size = int(np.floor(np.sqrt(batch_size)))
289 | 
290 |   image_shape = images.shape[1:]
291 |   return image_grid(
292 |       nrow=grid_size,
293 |       ncol=grid_size,
294 |       imagevecs=images[0 : grid_size**2],
295 |       imshape=image_shape,
296 |   ).astype("uint8")
297 | 
298 | 
299 | def detokenize_texts(tokens, tokenizer):
300 |   """Detokenize the outputs."""
301 | 
302 |   assert len(tokens.shape) == 2, "Invalid token shape."
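  # e.g. an int array of shape (4, 1024) decodes row by row into a length-4
  # array of strings (shapes illustrative; tokenizer.decode is applied along
  # the last axis).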
303 | 
304 |   np_tokens = np.asarray(tokens)
305 |   detokenized = np.apply_along_axis(tokenizer.decode, -1, np_tokens)
306 | 
307 |   return detokenized
308 | 
309 | 
310 | def get_topk_token_mask(tokenizer, k=100):
311 |   """Get the indices of Top-K tokens."""
312 | 
313 |   id_unigram_scores = [
314 |       (id_, math.exp(tokenizer._model.GetScore(id_)))  # pylint: disable=protected-access
315 |       for id_ in range(tokenizer.vocab_size)]
316 | 
317 |   id_unigram_scores_sorted = sorted(id_unigram_scores,
318 |                                     key=lambda x: x[1])
319 | 
320 |   # Take exactly k elements.
321 |   topk = id_unigram_scores_sorted[-k:]
322 | 
323 |   topk_mask = [False for _ in range(tokenizer.vocab_size)]
324 |   topk_ids = []
325 | 
326 |   for id_, _ in topk:
327 |     topk_mask[id_] = True
328 |     topk_ids.append(id_)
329 | 
330 |   topk_tokens = tokenizer.to_string_list(np.array(topk_ids))
331 | 
332 |   return np.array(topk_mask), topk_tokens
333 | 
334 | 
335 | def reshape_batch(batch: Mapping[str, Any]) -> Mapping[str, np.ndarray]:
336 |   """Reshapes a batch to have the leading dimension for the local devices."""
337 |   leading_dims = [jax.local_device_count(), -1]
338 |   return jax.tree_util.tree_map(
339 |       lambda x: np.reshape(x, leading_dims + list(x.shape[1:])), batch
340 |   )
341 | 
342 | 
343 | def checkpoints_iterator(
344 |     ckpt_manager, timeout=None, min_interval_secs=0, period=10000
345 | ):
346 |   """Repeatedly yield new checkpoints as they appear.
347 | 
348 |   Args:
349 |     ckpt_manager: CheckpointManager object.
350 |     timeout: int: maximum number of seconds to wait. If left as `None`, then the
351 |       process will wait indefinitely.
352 |     min_interval_secs: int: minimum number of seconds between yielding
353 |       checkpoints.
354 |     period: Only yield checkpoints whose step is a multiple of this period.
355 | 
356 |   Yields:
357 |     new checkpoint step.
358 |   """
359 |   last_step = None
360 |   while True:
361 |     cur_step = wait_for_new_checkpoint(
362 |         ckpt_manager, last_step, timeout=timeout, period=period
363 |     )
364 |     if cur_step is None:
365 |       # timed out
366 |       logging.info("Timed-out waiting for a checkpoint.")
367 |       return
368 |     start = time.time()
369 |     last_step = cur_step
370 | 
371 |     yield cur_step
372 | 
373 |     time_to_next_eval = start + min_interval_secs - time.time()
374 |     if time_to_next_eval > 0:
375 |       time.sleep(time_to_next_eval)
376 | 
377 | 
378 | def wait_for_new_checkpoint(
379 |     ckpt_manager: orbax_checkpoint.CheckpointManager,
380 |     last_step=None,
381 |     seconds_to_sleep=1,
382 |     timeout=None,
383 |     period=10000,
384 | ):
385 |   """Waits until a new checkpoint is found.
386 | 
387 |   Args:
388 |     ckpt_manager: The CheckpointManager that tracks saved checkpoints.
389 |     last_step: The last checkpoint step used or `None` if we're expecting a
390 |       checkpoint for the first time.
391 |     seconds_to_sleep: The number of seconds to sleep for before looking for a
392 |       new checkpoint.
393 |     timeout: The maximum number of seconds to wait. If left as `None`, then the
394 |       process will wait indefinitely.
395 |     period: Only return checkpoints whose step is a multiple of this period.
396 | 
397 |   Returns:
398 |     a new checkpoint step, or None if the timeout was reached.
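  Example (a sketch; `ckpt_manager` is assumed to be an initialized orbax
  CheckpointManager and `run_eval` a caller-provided function):

    step = wait_for_new_checkpoint(ckpt_manager, last_step=None, timeout=600)
    if step is not None:
      run_eval(step)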
399 | """ 400 | logging.info("Waiting for new checkpoint at %s", ckpt_manager.directory) 401 | stop_time = time.time() + timeout if timeout is not None else None 402 | while True: 403 | ckpt_manager.reload() 404 | cur_step = ckpt_manager.latest_step() 405 | if cur_step is None or cur_step == last_step or cur_step % period != 0: 406 | if stop_time is not None and time.time() + seconds_to_sleep > stop_time: 407 | return None 408 | time.sleep(seconds_to_sleep) 409 | else: 410 | logging.info("Found new checkpoint at step %d", cur_step) 411 | return cur_step 412 | -------------------------------------------------------------------------------- /prepare_openwebtext_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Prepares the input pipeline for OpenWebText (OWT). 17 | 18 | This script tokenizes the OWT dataset and splits it into train and eval sets. 19 | The train and eval sets are saved as ArrayRecord files. 20 | """ 21 | 22 | from array_record.python import array_record_module 23 | import datasets 24 | import numpy as np 25 | import tensorflow as tf 26 | import tqdm 27 | import transformers 28 | 29 | 30 | source = datasets.load_dataset( 31 | "Skylion007/openwebtext", name="plain_text", split="train", streaming=True 32 | ) 33 | 34 | _GPT2_TOKENIZER = "gpt2" 35 | tokenizer = transformers.GPT2Tokenizer.from_pretrained(_GPT2_TOKENIZER) 36 | 37 | ArrayRecordWriter = array_record_module.ArrayRecordWriter 38 | ArrayRecordReader = array_record_module.ArrayRecordReader 39 | 40 | 41 | def _int64_feature(value): 42 | """Returns an int64_list from a bool / enum / int / uint.""" 43 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 44 | 45 | 46 | ds_output_file_train = "./data_dir/openwebtext_splits_1024_train" 47 | ds_output_file_eval = "./data_dir/openwebtext_splits_1024_eval" 48 | 49 | n_examples = 8013769 # tiny: 2; small: 10_000; full: 8013769 50 | save_every_examples = 10_000 51 | block_size = 1024 # size of the chunk 52 | 53 | data_iter = (iter(source)) 54 | 55 | all_tokens = [] 56 | count = 0 57 | count_per_save = 0 58 | eval_chunks = [] 59 | 60 | writer_train = ArrayRecordWriter(ds_output_file_train, "group_size:1") 61 | writer_eval = ArrayRecordWriter(ds_output_file_eval, "group_size:1") 62 | 63 | for example in data_iter: 64 | tokens = tokenizer(example["text"])["input_ids"] 65 | all_tokens.extend(tokens + [tokenizer.eos_token_id]) 66 | count += 1 67 | count_per_save += 1 68 | 69 | # pause to save when having tokenized enough examples for saving. 
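  # Worked example with made-up counts: if the first 10_000 documents
  # tokenize to 5_120_500 tokens, then saved_length = 5_120_000, i.e. 5_000
  # blocks of 1_024 tokens; the 500 leftover tokens stay in all_tokens and
  # roll over into the next save round.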
70 | if count_per_save >= save_every_examples: 71 | # save to disk 72 | saved_length = (len(all_tokens) // block_size) * block_size 73 | chunks = [ 74 | all_tokens[i : i + block_size] 75 | for i in range(0, saved_length, block_size) 76 | ] 77 | 78 | print(f"\nsaving chunks @ {count}th example mark...") 79 | np.random.shuffle(chunks) 80 | num_eval = int(len(chunks) * 0.02) # put 2% of chunks into eval split. 81 | for eval_i in tqdm.tqdm(range(num_eval)): 82 | feature = { 83 | "text": _int64_feature(chunks[eval_i]), 84 | } 85 | example_proto = tf.train.Example( 86 | features=tf.train.Features(feature=feature) 87 | ) 88 | writer_eval.write(example_proto.SerializeToString()) 89 | 90 | for train_i in tqdm.tqdm(range(num_eval, len(chunks))): 91 | feature = { 92 | "text": _int64_feature(chunks[train_i]), 93 | } 94 | example_proto = tf.train.Example( 95 | features=tf.train.Features(feature=feature) 96 | ) 97 | writer_train.write(example_proto.SerializeToString()) 98 | 99 | # prepare for the next round of tokenize-n-save. 100 | all_tokens = all_tokens[saved_length:] 101 | count_per_save = 0 102 | 103 | # stop when having tokenized enough examples for total #. 104 | if count >= n_examples: 105 | # save to disk 106 | saved_length = (len(all_tokens) // block_size) * block_size 107 | chunks = [ 108 | all_tokens[i : i + block_size] 109 | for i in range(0, saved_length, block_size) 110 | ] 111 | 112 | print(f"\nsaving chunks @ {count}th example mark...") 113 | np.random.shuffle(chunks) 114 | num_eval = int(len(chunks) * 0.02) # put 2% of chunks into eval split. 115 | for eval_i in tqdm.tqdm(range(num_eval)): 116 | feature = { 117 | "text": _int64_feature(chunks[eval_i]), 118 | } 119 | example_proto = tf.train.Example( 120 | features=tf.train.Features(feature=feature) 121 | ) 122 | writer_eval.write(example_proto.SerializeToString()) 123 | 124 | for train_i in tqdm.tqdm(range(num_eval, len(chunks))): 125 | feature = { 126 | "text": _int64_feature(chunks[train_i]), 127 | } 128 | example_proto = tf.train.Example( 129 | features=tf.train.Features(feature=feature) 130 | ) 131 | writer_train.write(example_proto.SerializeToString()) 132 | break 133 | 134 | writer_train.close() 135 | writer_eval.close() 136 | -------------------------------------------------------------------------------- /requirements_gpu.txt: -------------------------------------------------------------------------------- 1 | clu 2 | datasets 3 | distrax 4 | grain 5 | matplotlib 6 | seaborn 7 | tensorflow 8 | tensorflow-datasets 9 | tf-keras 10 | transformers 11 | # for jax on GPU 12 | --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 13 | jax[cuda] 14 | -------------------------------------------------------------------------------- /requirements_tpu.txt: -------------------------------------------------------------------------------- 1 | clu 2 | datasets 3 | distrax 4 | grain 5 | matplotlib 6 | seaborn 7 | tensorflow 8 | tensorflow-datasets 9 | tf-keras 10 | transformers 11 | # for jax on TPU 12 | --find-links https://storage.googleapis.com/jax-releases/jax_releases.html 13 | --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html 14 | jax[tpu] 15 | -------------------------------------------------------------------------------- /run_gcp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in 
compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | export PYTHONPATH="$PYTHONPATH:$(pwd)" 18 | source md4_venv/bin/activate 19 | 20 | EXPT_DIR="$(pwd)"/expt 21 | python md4/main.py \ 22 | --config=md4/configs/md4/fineweb_edu.py \ 23 | --sharded=false \ 24 | --workdir=${EXPT_DIR} 25 | --------------------------------------------------------------------------------
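The launcher above generalizes to the other configs under md4/configs by
swapping the --config flag; a sketch (the workdir name here is an arbitrary
choice, and the flags are the same ones used above):

  export PYTHONPATH="$PYTHONPATH:$(pwd)"
  source md4_venv/bin/activate
  EXPT_DIR="$(pwd)"/expt_owt
  python md4/main.py \
    --config=md4/configs/md4/openwebtext.py \
    --sharded=false \
    --workdir=${EXPT_DIR}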