├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── imgs
│   └── imagenet64.gif
├── md4
│   ├── __init__.py
│   ├── binary_search.py
│   ├── configs
│   │   ├── genmd4
│   │   │   ├── openwebtext.py
│   │   │   └── text8.py
│   │   └── md4
│   │       ├── cifar10.py
│   │       ├── fineweb_edu.py
│   │       ├── imagenet64.py
│   │       ├── openwebtext.py
│   │       └── text8.py
│   ├── input_pipeline.py
│   ├── input_pipeline_v2.py
│   ├── main.py
│   ├── models
│   │   ├── backward.py
│   │   ├── diffusion
│   │   │   ├── genmd4.py
│   │   │   └── md4.py
│   │   └── utils.py
│   ├── multihost_dataloading.py
│   ├── networks
│   │   ├── dit.py
│   │   ├── sharded_transformer.py
│   │   ├── transformer.py
│   │   ├── unet.py
│   │   └── uvit.py
│   ├── sampling.py
│   ├── sharded_train.py
│   ├── train.py
│   └── utils.py
├── prepare_openwebtext_data.py
├── requirements_gpu.txt
├── requirements_tpu.txt
└── run_gcp.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MD4: Simplified and Generalized Masked Diffusion for Discrete Data
2 |
3 | ![Samples from MD4 trained on ImageNet 64x64](imgs/imagenet64.gif)
4 |
5 | ## Installation
6 |
7 | ### Create Python environment
8 |
9 | Please use `requirements_gpu.txt` if your accelerator is a GPU, or
10 | `requirements_tpu.txt` if you are using Google Cloud TPUs.
11 |
12 | ```
13 | python -m venv md4_venv
14 | source md4_venv/bin/activate
15 | pip install -r requirements_[gpu/tpu].txt
16 | export PYTHONPATH="$PYTHONPATH:~/path/to/md4"
17 | ```
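
To confirm the environment can actually see your accelerator, a quick sanity
check (plain JAX, nothing specific to this repo):

```
python -c "import jax; print(jax.devices())"
```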
18 |
19 | ## Usage
20 |
21 | Prepare OpenWebText for training (i.e., tokenize and pack examples):
22 |
23 | ```
24 | mkdir data_dir
25 | python prepare_openwebtext_data.py
26 | ```
27 |
28 | Train an MD4-S model on text data (OpenWebText, FineWeb-Edu):
29 |
30 | ```
31 | python md4/main.py --config=md4/configs/md4/openwebtext.py --sharded=false --workdir=./expt
32 | ```
33 |
34 | Alternatively, you can train an MD4-S model on image data (CIFAR-10,
35 | ImageNet-64):
36 |
37 | ```
38 | python md4/main.py --config=md4/configs/md4/cifar10.py --sharded=false --workdir=./expt
39 | ```
40 |
41 | ### Choosing the batch size
42 |
43 | The feasible batch size depends on your compute resources. For training an
44 | MD4-S model with sequence length 1024, eight `A100` GPUs can support a
45 | maximum batch size of `128`. If running on TPUs, eight `v5litepod` chips can
46 | support a maximum batch size of `32`.
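
If `main.py` registers these configs through `ml_collections` config flags
(the usual pattern for `config_dict`-based configs like those in
`md4/configs/`), fields such as the batch size can likely be overridden from
the command line instead of editing the config file, e.g.:

```
python md4/main.py --config=md4/configs/md4/openwebtext.py --config.batch_size=128 --sharded=false --workdir=./expt
```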
47 |
48 | ## Citing this work
49 |
50 | If you use this work, please cite the paper:
51 |
52 | ```
53 | @inproceedings{shi2024simplified,
54 | title={Simplified and Generalized Masked Diffusion for Discrete Data},
55 | author={Shi, Jiaxin and Han, Kehang and Wang, Zhe and Doucet, Arnaud and Titsias, Michalis K.},
56 | booktitle={Advances in Neural Information Processing Systems},
57 | year={2024}
58 | }
59 | ```
60 |
61 | ## License and disclaimer
62 |
63 | Copyright 2024 DeepMind Technologies Limited
64 |
65 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
66 | you may not use this file except in compliance with the Apache 2.0 license.
67 | You may obtain a copy of the Apache 2.0 license at:
68 | https://www.apache.org/licenses/LICENSE-2.0
69 |
70 | Unless required by applicable law or agreed to in writing, all software and
71 | materials distributed here under the Apache 2.0 or CC-BY licenses are
72 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
73 | either express or implied. See the licenses for the specific language governing
74 | permissions and limitations under those licenses.
75 |
76 | This is not an official Google product.
77 |
--------------------------------------------------------------------------------
/imgs/imagenet64.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/md4/b9fbf29216f818cccff46cb5727199f86f35593e/imgs/imagenet64.gif
--------------------------------------------------------------------------------
/md4/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """MD4."""
17 |
--------------------------------------------------------------------------------
/md4/binary_search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Binary search over float32 bits.
17 |
18 | Includes fast algorithms for top-k masking and top-p masking on probability
19 | distributions.
20 |
21 | Adapted from:
22 | https://github.com/google-research/t5x/blob/main/t5x/binary_search.py
23 | """
24 |
25 | from typing import Callable, Sequence
26 |
27 | import jax
28 | from jax import lax
29 | from jax import numpy as jnp
30 |
31 |
32 | def int32_bsearch(
33 | batch_shape: Sequence[int], predicate: Callable[[jnp.ndarray], jnp.ndarray],
34 | ) -> jnp.ndarray:
35 | """Batched binary search over int32 values.
36 |
37 | For each element of the batch, search for the largest int32 (closest to
38 | positive infinity) for which the predicate is False. If the predicate is
39 | always True, returns the minimum int32 value.
40 |
41 | Args:
42 | batch_shape: Shape of the search that we're batching over.
43 | predicate: the query we're searching for. For every batch element, this is
44 | required to be a monotonic function from int32 to bool. In other words,
45 | the predicate must return False for all numbers <= some threshold and then
46 | return True for all numbers > that threshold. The threshold may be
47 | different for different elements of the batch.
48 |
49 | Returns:
50 | For each element of the batch, the largest int32 for which the predicate
51 | returns False. Shape: batch_shape.
52 | """
53 | current_bits = jnp.zeros(batch_shape, dtype=jnp.int32)
54 |
55 | # bit 31 is special, because it compares in the opposite order of all other
56 | # bits. we use uint32 due to numpy promotion/casting rules.
57 | midpoint = current_bits
58 | predicate_satisfied = predicate(midpoint)
59 | current_bits = current_bits | jnp.where(
60 | predicate_satisfied, jnp.uint32(1 << 31), jnp.uint32(0)
61 | )
62 | del midpoint, predicate_satisfied
63 |
64 | def loop_body(i, current_bits):
65 | bit_index = 30 - i
66 | bit = jnp.int32(1 << bit_index)
67 | midpoint = current_bits | bit
68 | predicate_satisfied = predicate(midpoint)
69 | current_bits = current_bits | jnp.where(
70 | predicate_satisfied, jnp.int32(0), bit
71 | )
72 | return current_bits
73 |
74 | return lax.fori_loop(0, 31, loop_body, current_bits)
75 |
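# Illustrative sketch of how int32_bsearch can be used (hypothetical example,
# with a per-element threshold array): `lambda x: x > thresholds` is a valid
# monotonic predicate (False for all x <= thresholds, True above), so the
# search recovers the thresholds themselves:
#
#   thresholds = jnp.array([5, 100])
#   result = int32_bsearch((2,), lambda x: x > thresholds)
#   # result == [5, 100]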
76 |
77 | def _monotonic_int32_to_float32_bit_pattern(x: int) -> int:
78 | """Converts an int32 to a float32 bit pattern with consistent ordering.
79 |
80 | This function is the unique function that is monotonic with respect to the
81 | floating point total order, see
82 | https://en.wikipedia.org/wiki/IEEE_754#Total-ordering_predicate. Note that
83 | this function returns an int32, not a float32. For the function that returns
84 | float32, see `monotonic_int32_to_float32`.
85 |
86 | Args:
87 | x: int bit pattern.
88 |
89 | Returns:
90 | Bit pattern of a float32 number.
91 | """
92 | non_sign_bits = jnp.int32((1 << 31) - 1)
93 | # See
94 | # https://stackoverflow.com/questions/20097380/iee-754-total-order-in-standard-c11
95 | # for the relationship between int32 order and f32 total order, including
96 | # the "xor trick".
97 |
98 | # Flip the sort order for numbers where the sign bit is set. On int32,
99 | # the bit pattern with sign bit set and all other bits clear is the most
100 | # negative bit pattern (it's int32::MIN), whereas on float32 it's the least
101 | # negative bit pattern (it's -0.0). Flipping all the non-sign bits makes the
102 | # int32 sort order consistent with the float32 sort order.
103 | x = x ^ jnp.where(x < 0, non_sign_bits, jnp.int32(0))
104 | return x
105 |
106 |
107 | def _monotonic_int32_to_float32(x: int) -> jax.Array:
108 | """Converts an int32 to a float32 with consistent ordering.
109 |
110 | This function is the unique function that is monotonic with respect to the
111 | floating point total order, see
112 | https://en.wikipedia.org/wiki/IEEE_754#Total-ordering_predicate.
113 |
114 | Args:
115 | x: int bit pattern.
116 |
117 | Returns:
118 | float32 number with consistent ordering.
119 | """
120 | x = _monotonic_int32_to_float32_bit_pattern(x)
121 | return lax.bitcast_convert_type(x, jnp.float32)
122 |
123 |
124 | def float32_bsearch(batch_shape, predicate):
125 | """Binary search on finite float32 numbers.
126 |
127 | For each element of the batch, this function searches for the largest finite
128 | non-NaN float32 for which the predicate is False.
129 |
130 | Args:
131 | batch_shape: Shape of the search that we're batching over.
132 | predicate: the query we're searching for. This is required to be monotonic
133 | with respect to the floating point order, i.e. it must be False for all
134 | numbers <= a threshold, and then True for all numbers > the threshold. The
135 | threshold may be different for different elements of the batch.
136 |
137 | Returns:
138 | For each element of the batch, the largest float32 for which the predicate
139 | returns False. Shape: f32[batch_shape].
140 | """
141 | exponent_bits = jnp.int32((1 << 31) - (1 << (31 - 8)))
142 |
143 | def int32_predicate(x):
144 | x = _monotonic_int32_to_float32_bit_pattern(x)
145 | is_finite = (x & exponent_bits) != exponent_bits
146 |
147 | # Non-finite numbers (infinity and NaN) are at the very extremes of the
148 | # int32 range, i.e. they include int32::MAX and int32::MIN, plus the numbers
149 | # adjacent to them. For the nonfinite numbers touching int32::MIN, we
150 | # arrange for them to return False from the predicate, and for the nonfinite
151 | # numbers touching int32::MAX, we arrange for them to return True from the
152 | # predicate. x>=0 is an easy way to achieve that.
153 | predicate_on_nonfinite = x >= 0
154 | x_float32 = lax.bitcast_convert_type(x, jnp.float32)
155 | return jnp.where(is_finite, predicate(x_float32), predicate_on_nonfinite)
156 |
157 | # We search over bit patterns, which requires bit shifting and ordering of bit
158 | # patterns. This is natively supported on int32 but not on float32.
159 | # Additionally, it's more common to reason about int32 bit arithmetic and
160 | # ordering than float32 bit arithmetic and ordering, so we do the core of our
161 | # search in int32. Additionally, this allows us to test the underlying binary
162 | # search on int32 values.
163 | #
164 | # The function _monotonic_int32_to_float32 encapsulates all of the knowledge
165 | # we need about float32 bit patterns.
166 | result = int32_bsearch(batch_shape, int32_predicate)
167 | return _monotonic_int32_to_float32(result)
168 |
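# Illustrative sketch of float32_bsearch (hypothetical example): searching for
# the largest finite float32 that does not exceed a target value:
#
#   target = jnp.float32(0.25)
#   result = float32_bsearch((), lambda x: x > target)
#   # result == 0.25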
169 |
170 | def topk_mask(x: jnp.ndarray, k: int, replace_val: jnp.ndarray) -> jnp.ndarray:
171 | """Sets everything to replace_val, except the top k values per batch element.
172 |
173 | Sharding considerations: this function does 32 reductions over the vocab_size
174 | axis of the input array. To avoid excessive latency from these reductions, you
175 | should ensure that the vocab_size axis is unsharded on input to this function.
176 | Prefer to shard the batch axes instead.
177 |
178 | Scratchpad memory considerations: this function is most efficient if the
179 | entire input array can fit in a fast memory tier. To help ensure this, you may
180 |   wish to split the batch axes into microbatches and process the
181 |   microbatches in a sequential loop.
182 |
183 | Args:
184 | x: Values before masking. [batch..., vocab_size]
185 |     k: Number of values to keep per batch element. In the presence of ties,
186 |       more than k values might be kept.
187 | replace_val: For the masked values of x, what to overwrite them with.
188 |
189 | Returns:
190 | masked version of x. [batch..., vocab_size]
191 | """
192 | batch_shape = tuple(list(x.shape)[:-1]) # [batch...]
193 |
194 | x_for_loop = x
195 | reduce_axis = x.ndim - 1
196 | if x.ndim > 1:
197 | # We're going to be doing 32 reductions over 'reduce_axis'. Generally,
198 | # reductions over the last dimension are the most expensive, because they
199 | # involve reducing across vector lanes, which is often not efficient. So
200 | # we transpose the reduce_axis to be the second-last dimension, to avoid
201 | # this inefficiency.
202 | #
203 |     # Normally the XLA compiler would automatically perform this optimization,
204 | # but it doesn't yet see through loops to do so. So we do it ourselves.
205 | x_for_loop = jnp.swapaxes(x_for_loop, -1, -2)
206 | reduce_axis = x.ndim - 2
207 |
208 |   # x_for_loop: [batch..., vocab_size, last_batch]
209 | def predicate(threshold):
210 | # threshold: [batch...]
211 |
212 | # Since we've negated, we now want a predicate that is True for small
213 | # numbers and False for large numbers. The result of the bsearch is the
214 | # smallest float32 for which the predicate is False.
215 | threshold = -threshold
216 |
217 | threshold = lax.expand_dims(threshold, (reduce_axis,))
218 | # threshold: [batch..., 1, last_batch]
219 |
220 |     # count_gt: [batch...]
221 | count_gt = jnp.sum(x_for_loop > threshold, axis=reduce_axis)
222 |
223 | return count_gt >= k
224 |
225 | # cutoff: [batch...]
226 | cutoff = float32_bsearch(batch_shape, predicate)
227 | cutoff = -cutoff
228 | # cutoff: [batch..., 1]
229 | cutoff = lax.expand_dims(cutoff, (cutoff.ndim,))
230 | return jnp.where(x >= cutoff, x, jnp.full_like(x, replace_val))
231 |
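# Illustrative sketch of topk_mask (hypothetical example):
#
#   x = jnp.array([[0.1, 2.0, 1.0, -1.0]])
#   topk_mask(x, k=2, replace_val=jnp.float32(-jnp.inf))
#   # -> [[-inf, 2.0, 1.0, -inf]]; only the top-2 entries per row survive.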
232 |
233 | def topp_mask(
234 | logits: jnp.ndarray, p: float, replace_val: jnp.ndarray
235 | ) -> jnp.ndarray:
236 | """Applies top-p masking to logits.
237 |
238 | Masks logits down to the smallest set of choices, such that the total
239 | probability mass is >= p. Values in this set are left as they are. All other
240 |   values are set to `replace_val`.
241 |
242 | Sharding considerations: this function does 33 reductions over the vocab_size
243 | axis of the input array. To avoid excessive latency from these reductions, you
244 | should ensure that the vocab_size axis is unsharded on input to this function.
245 | Prefer to shard the batch axes instead.
246 |
247 | Scratchpad memory considerations: this function is most efficient if the
248 | entire input array can fit in a fast memory tier. To help ensure this, you may
249 |   wish to split the batch axes into microbatches and process the
250 |   microbatches in a sequential loop.
251 |
252 | Args:
253 | logits: Logits before masking. [batch..., vocab_size]
254 | p: Minimum probability mass requested.
255 | replace_val: For the masked values of logits, what to overwrite them with.
256 |
257 | Returns:
258 |     masked version of logits. [batch..., vocab_size]
259 | """
260 | batch_shape = tuple(list(logits.shape)[:-1]) # [batch...]
261 |
262 | probs = jax.nn.softmax(logits, axis=-1)
263 |
264 | probs_for_reduction = probs
265 | reduce_axis = probs_for_reduction.ndim - 1
266 | if probs_for_reduction.ndim > 1:
267 | # We're going to be doing 33 reductions over 'reduce_axis'. Generally,
268 | # reductions over the last dimension are the most expensive, because they
269 | # involve reducing across vector lanes, which is often not efficient. So
270 | # we transpose the reduce_axis to be the second-last dimension, to avoid
271 | # this inefficiency.
272 | probs_for_reduction = jnp.swapaxes(probs_for_reduction, -1, -2)
273 | reduce_axis = probs_for_reduction.ndim - 2
274 |
275 | # As we increase the threshold, the probability mass decreases, and the number
276 | # selected decreases.
277 | #
278 | # We want the largest threshold with the probability mass >= p. Binary search
279 | # searches for when the predicate is False, so we negate the output of the
280 | # predicate, i.e. probability mass < p.
281 |
282 |   # probs_for_reduction: [batch..., vocab_size, last_batch]
283 | def predicate(threshold):
284 | # threshold: [batch...]
285 | threshold = lax.expand_dims(threshold, (reduce_axis,))
286 | # threshold: [batch..., 1, last_batch]
287 |
288 |     # probability_mass: [batch...]
289 | probability_mass = jnp.sum(
290 | jnp.where(probs_for_reduction >= threshold, probs_for_reduction, 0.0),
291 | axis=reduce_axis,
292 | )
293 |
294 | return probability_mass < p
295 |
296 | # threshold: [batch...]
297 | threshold = float32_bsearch(batch_shape, predicate)
298 | # threshold: [batch..., 1]
299 | threshold = lax.expand_dims(threshold, (threshold.ndim,))
300 | return jnp.where(
301 | probs >= threshold, logits, jnp.full_like(logits, replace_val)
302 | )
303 |
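# Illustrative sketch of topp_mask (hypothetical example): with softmax
# probabilities [0.5, 0.3, 0.15, 0.05] and p=0.9, the smallest set with mass
# >= 0.9 is the top three entries, so only the last logit is replaced with
# replace_val.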
--------------------------------------------------------------------------------
/md4/configs/genmd4/openwebtext.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""A config for training GenMD4-S on OpenWebText."""
17 |
18 | from collections import abc
19 |
20 | from ml_collections import config_dict
21 |
22 |
23 | def get_config() -> config_dict.ConfigDict:
24 | """Default config."""
25 |
26 | config = config_dict.ConfigDict()
27 |
28 | # dataset configs
29 | config.vocab_size = 50259
30 | config.dataset = "openwebtext"
31 | config.classes = -1
32 |
33 | config.task_type = "text" # text or image
34 | config.model_type = "genmd4"
35 | config.data_shape = (1024,)
36 |
37 | # timesteps: int or None
38 | config.timesteps = 1000
39 | config.noise_schedule = "poly"
40 | config.power_init = 1.0
41 | config.outside_embed = False
42 | config.time_features = "t"
43 | config.cont_time = True
44 |
45 | config.feature_dim = 64
46 | config.n_layers = 12
47 | config.ch_mult = (1,) # not used
48 | config.n_dit_layers = 0 # not used
49 | config.dit_num_heads = 12 # not used
50 | config.dit_hidden_size = 768 # not used
51 | config.dropout_rate = 0.0
52 |
53 | config.num_heads = 12
54 | config.mlp_type = "glu"
55 | config.depth_scaled_init = True
56 | config.cond_type = "adaln_zero"
57 |
58 | config.learning_rate = 3e-4
59 | config.learning_rate_schedule = "cosine"
60 | config.warmup_steps = 2000
61 | config.weight_decay = 0.0
62 | config.clip = 0.0
63 | config.b2 = 0.999
64 | config.num_epochs = -1
65 | config.ema_rate = 0.9999
66 | # If num_train_steps==-1 then the number of training steps is calculated from
67 | # num_epochs.
68 | config.num_train_steps = 1_000_000
69 | # Evaluates for a full epoch if num_eval_steps==-1. Set to a smaller value for
70 | # fast iteration when running train.train_and_eval() from a Colab.
71 | config.num_eval_steps = -1
72 | config.batch_size = 512
73 | config.num_microbatches = 1
74 | config.per_device_batch_size = -1
75 |   # Whether to pad the last batch so the entire dataset is evaluated.
76 | config.eval_pad_last_batch = False
77 | config.check_nans = False
78 |
79 | config.log_loss_every_steps = 500
80 | config.eval_every_steps = 10000
81 | config.checkpoint_every_steps = 5000
82 | config.checkpoint_keep_period = -1
83 |
84 | # Single integer or tuple. If None will use (XManager ID, work unit).
85 | config.seed = 42
86 |
87 | # Number of workers for Grain loaders.
88 | config.grain_num_workers = 15
89 |
90 | config.trial = 0 # Dummy for repeated runs.
91 | config.test_in_colab = False
92 | return config
93 |
94 |
95 | # By default, the launcher calls `sweep()`.
96 | # To disable the sweep, the `sweep()` function can be commented (or renamed),
97 | # or the flag `--nosweep` can be specified to the launcher.
98 | def sweep(add: abc.Callable[..., None]):
99 | """Starts multiple work units with varying config args."""
100 | add(
101 | learning_rate=3e-4,
102 | dropout_rate=0.02,
103 | weight_decay=0.03,
104 | )
105 |
--------------------------------------------------------------------------------
/md4/configs/genmd4/text8.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""A config for training GenMD4 on text8."""
17 |
18 | from collections import abc
19 |
20 | from ml_collections import config_dict
21 |
22 |
23 | def get_config() -> config_dict.ConfigDict:
24 | """Default config."""
25 |
26 | config = config_dict.ConfigDict()
27 |
28 | # dataset configs
29 | config.vocab_size = 27
30 | config.dataset = "text8"
31 | config.classes = -1
32 |
33 | config.task_type = "text" # text or image
34 | config.model_type = "genmd4"
35 | config.data_shape = (256,)
36 |
37 | # timesteps: int or None
38 | config.timesteps = 1000
39 | config.noise_schedule = "poly"
40 | config.power_init = 1.0
41 | config.outside_embed = True
42 | config.time_features = "t"
43 | config.cont_time = True
44 |
45 | config.feature_dim = 64
46 | config.n_layers = 12
47 | config.ch_mult = (1,) # not used
48 | config.n_dit_layers = 0 # not used
49 | config.dit_num_heads = 12 # not used
50 | config.dit_hidden_size = 768 # not used
51 | config.dropout_rate = 0.0
52 |
53 | config.num_heads = 12
54 | config.mlp_type = "glu"
55 | config.depth_scaled_init = True
56 | config.cond_type = "adaln_zero"
57 |
58 | config.learning_rate = 3e-4
59 | config.learning_rate_schedule = "cosine"
60 | config.warmup_steps = 2000
61 | config.weight_decay = 0.0
62 | config.clip = 0.0
63 | config.b2 = 0.999
64 | config.num_epochs = -1
65 | config.ema_rate = 0.9999
66 | # If num_train_steps==-1 then the number of training steps is calculated from
67 | # num_epochs.
68 | config.num_train_steps = 1_000_000
69 | # Evaluates for a full epoch if num_eval_steps==-1. Set to a smaller value for
70 | # fast iteration when running train.train_and_eval() from a Colab.
71 | config.num_eval_steps = -1
72 | config.batch_size = 512
73 | config.num_microbatches = 1
74 | config.per_device_batch_size = -1
75 |   # Whether to pad the last batch so the entire dataset is evaluated.
76 | config.eval_pad_last_batch = False
77 | config.check_nans = False
78 |
79 | config.log_loss_every_steps = 500
80 | config.eval_every_steps = 5000
81 | config.checkpoint_every_steps = 5000
82 | config.checkpoint_keep_period = 10000
83 |
84 | # Single integer or tuple. If None will use (XManager ID, work unit).
85 | config.seed = 42
86 |
87 | # Number of workers for Grain loaders.
88 | config.grain_num_workers = 15
89 |
90 | config.trial = 0 # Dummy for repeated runs.
91 | config.test_in_colab = False
92 | return config
93 |
94 |
95 | # By default, the launcher calls `sweep()`.
96 | # To disable the sweep, the `sweep()` function can be commented (or renamed),
97 | # or the flag `--nosweep` can be specified to the launcher.
98 | def sweep(add: abc.Callable[..., None]):
99 | """Starts multiple work units with varying config args."""
100 | add(
101 | learning_rate=3e-4,
102 | dropout_rate=0.05,
103 | weight_decay=0.03,
104 | )
105 |
--------------------------------------------------------------------------------
/md4/configs/md4/cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""A config for training MD4 on CIFAR10."""
17 |
18 | from collections import abc
19 |
20 | from ml_collections import config_dict
21 |
22 |
23 | def get_config() -> config_dict.ConfigDict:
24 | """Default config."""
25 |
26 | config = config_dict.ConfigDict()
27 |
28 | # dataset configs
29 | config.vocab_size = 256
30 | config.dataset = "cifar10"
31 | config.classes = -1
32 |
33 | config.task_type = "image" # text or image
34 | config.model_type = "md4"
35 | config.data_shape = (32, 32, 3)
36 |
37 | # timesteps: int or None
38 | config.timesteps = 256
39 | # linear, cosine, poly[exponent], e.g., poly3
40 | config.noise_schedule = "linear"
41 | config.outside_embed = True # not used
42 | # t or none (removes time dependence)
43 | config.time_features = "t"
44 | config.cont_time = True
45 |
46 | config.feature_dim = 128
47 | config.n_layers = 32
48 | config.ch_mult = (1,) # not used
49 | config.n_dit_layers = 0 # not used
50 | config.dit_num_heads = 12 # not used
51 | config.dit_hidden_size = 768 # not used
52 | config.dropout_rate = 0.1
53 |
54 | config.num_heads = 12 # not used
55 | config.mlp_type = "glu" # not used
56 | config.depth_scaled_init = True # not used
57 | config.cond_type = "adaln_zero" # not used
58 |
59 | config.learning_rate = 2e-4
60 | config.learning_rate_schedule = "cosine"
61 | config.warmup_steps = 100
62 | config.weight_decay = 0.01
63 | config.clip = 0.0
64 | config.b2 = 0.99
65 | config.num_epochs = -1
66 | config.ema_rate = 0.9999
67 | # If num_train_steps==-1 then the number of training steps is calculated from
68 | # num_epochs.
69 | config.num_train_steps = 2_000_000
70 | # Evaluates for a full epoch if num_eval_steps==-1
71 | config.num_eval_steps = -1
72 | config.batch_size = 256
73 | config.per_device_batch_size = -1
74 |   # Whether to pad the last batch so the entire dataset is evaluated.
75 | config.eval_pad_last_batch = False
76 | config.check_nans = False
77 |
78 | # Sampling
79 | # ancestral, mean, or topp
80 | config.sampler = "ancestral"
81 | # uniform, cosine
82 | config.sampling_grid = "cosine"
83 | # for topp sampler
84 | config.topp = 0.98
85 |
86 | config.log_loss_every_steps = 500
87 | config.eval_every_steps = 10000
88 | config.checkpoint_every_steps = 5000
89 | config.checkpoint_keep_period = 10000
90 |
91 | # Single integer or tuple. If None will use (XManager ID, work unit).
92 | config.seed = 0
93 |
94 | # Number of workers for Grain loaders.
95 | config.grain_num_workers = 15
96 |
97 | config.trial = 0 # Dummy for repeated runs.
98 | config.test_in_colab = False
99 | return config
100 |
101 |
102 | # By default, the launcher calls `sweep()`.
103 | # To disable the sweep, the `sweep()` function can be commented (or renamed),
104 | # or the flag `--nosweep` can be specified to the launcher.
105 | def sweep(add: abc.Callable[..., None]):
106 | """Starts multiple work units with varying config args."""
107 | # For best likelihood results
108 | add(
109 | noise_schedule="linear",
110 | sampling_grid="cosine",
111 | )
112 | # For best sample quality
113 | # add(
114 | # noise_schedule="cosine",
115 | # sampling_grid="uniform",
116 | # )
117 |
--------------------------------------------------------------------------------
/md4/configs/md4/fineweb_edu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""A config for training MD4 on FineWeb EDU."""
17 |
18 | from collections import abc
19 |
20 | from ml_collections import config_dict
21 |
22 |
23 | def get_config() -> config_dict.ConfigDict:
24 | """Default config."""
25 |
26 | config = config_dict.ConfigDict()
27 |
28 | # dataset configs
29 | config.vocab_size = 50257
30 | config.use_v2_input_pipeline = True
31 | config.dataset = "fineweb_edu"
32 | config.classes = -1
33 |
34 | config.task_type = "text" # text or image
35 | config.model_type = "md4"
36 | config.data_shape = (1024,)
37 |
38 | # timesteps: int or None
39 | config.timesteps = 1000
40 | # linear, cosine, poly[exponent], e.g., poly3
41 | config.noise_schedule = "linear"
42 | config.outside_embed = False
43 | # t or none (removes time dependence)
44 | config.time_features = "t"
45 | config.cont_time = True
46 |
47 | config.feature_dim = 64
48 | config.n_layers = 12
49 | config.ch_mult = (1,) # not used
50 | config.n_dit_layers = 0 # not used
51 | config.dit_num_heads = 12 # not used
52 | config.dit_hidden_size = 768 # not used
53 | config.dropout_rate = 0.0
54 |
55 | config.num_heads = 12
56 | config.mlp_type = "glu"
57 | config.depth_scaled_init = True
58 | config.cond_type = "adaln_zero"
59 |
60 | config.learning_rate = 3e-4
61 | config.learning_rate_schedule = "cosine"
62 | config.warmup_steps = 2000
63 | config.weight_decay = 0.0
64 | config.clip = 0.0
65 | config.b2 = 0.999
66 | config.num_epochs = -1
67 | config.ema_rate = 0.9999
68 | # If num_train_steps==-1 then the number of training steps is calculated from
69 | # num_epochs.
70 | config.num_train_steps = 1_000_000
71 | # Evaluates for a full epoch if num_eval_steps==-1.
72 | config.num_eval_steps = -1
73 | config.batch_size = 512
74 | config.num_microbatches = 1
75 | config.per_device_batch_size = -1
76 |   # Whether to pad the last batch so the entire dataset is evaluated.
77 | config.eval_pad_last_batch = False
78 | config.check_nans = False
79 |
80 | # Sampling
81 | # ancestral, mean, or topp
82 | config.sampler = "ancestral"
83 | # uniform, cosine
84 | config.sampling_grid = "cosine"
85 | # for topp sampler
86 | config.topp = 0.98
87 |
88 | config.log_loss_every_steps = 500
89 | config.eval_every_steps = 10000
90 | config.checkpoint_every_steps = 5000
91 | config.checkpoint_keep_period = 200000
92 |
93 | # Single integer or tuple. If None will use (XManager ID, work unit).
94 | config.seed = 42
95 |
96 | # Number of workers for Grain loaders.
97 | # HF data source (OSS version) only supports one worker for now, more workers
98 | # result in duplicated data.
99 | config.grain_num_workers = 1
100 |
101 | config.trial = 0 # Dummy for repeated runs.
102 | config.test_in_colab = False
103 | return config
104 |
105 |
106 | # By default, the launcher calls `sweep()`.
107 | # To disable the sweep, the `sweep()` function can be commented (or renamed),
108 | # or the flag `--nosweep` can be specified to the launcher.
109 | def sweep(add: abc.Callable[..., None]):
110 | """Starts multiple work units with varying config args."""
111 | # Small size
112 | add(
113 | noise_schedule="linear",
114 | dropout_rate=0.02,
115 | weight_decay=0.03,
116 | )
117 | # Medium size
118 | # add(
119 | # noise_schedule="cosine",
120 | # n_layers=24,
121 | # num_heads=16,
122 | # batch_size=512,
123 | # sampling_grid="uniform",
124 | # dropout_rate=0.0,
125 | # weight_decay=0.03,
126 | # )
127 |
--------------------------------------------------------------------------------
/md4/configs/md4/imagenet64.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""A config for training MD4 on downsampled ImageNet 64x64."""
17 |
18 | from collections import abc
19 |
20 | from ml_collections import config_dict
21 |
22 |
23 | def get_config() -> config_dict.ConfigDict:
24 | """Default config."""
25 |
26 | config = config_dict.ConfigDict()
27 |
28 | # dataset configs
29 | config.vocab_size = 256
30 | config.dataset = "downsampled_imagenet/64x64"
31 | config.classes = -1
32 |
33 | config.task_type = "image" # text or image
34 | config.model_type = "md4"
35 | config.data_shape = (64, 64, 3)
36 |
37 | # timesteps: int or None
38 | config.timesteps = 256
39 | # linear, cosine, poly[exponent], e.g., poly3
40 | config.noise_schedule = "linear"
41 | config.outside_embed = True # not used
42 | # t or none (removes time dependence)
43 | config.time_features = "t"
44 | config.cont_time = True
45 |
46 | config.feature_dim = 256
47 | config.n_layers = 8
48 | config.ch_mult = (1,)
49 | config.n_dit_layers = 20
50 | config.dit_num_heads = 12
51 | config.dit_hidden_size = 768
52 | config.dropout_rate = 0.1
53 |
54 | config.num_heads = 12 # not used
55 | config.mlp_type = "glu" # not used
56 | config.depth_scaled_init = True # not used
57 | config.cond_type = "adaln_zero" # not used
58 |
59 | config.learning_rate = 2e-4
60 | config.learning_rate_schedule = "cosine"
61 | config.warmup_steps = 100
62 | config.weight_decay = 0.01
63 | config.clip = 0.0
64 | config.b2 = 0.99
65 | config.num_epochs = -1
66 | config.ema_rate = 0.9999
67 | # If num_train_steps==-1 then the number of training steps is calculated from
68 | # num_epochs.
69 | config.num_train_steps = 2_000_000
70 | # Evaluates for a full epoch if num_eval_steps==-1
71 | config.num_eval_steps = -1
72 | config.batch_size = 512
73 | config.per_device_batch_size = -1
74 |   # Whether to pad the last batch so the entire dataset is evaluated.
75 | config.eval_pad_last_batch = False
76 | config.check_nans = False
77 |
78 | # Sampling
79 | # ancestral, mean, or topp
80 | config.sampler = "ancestral"
81 | # uniform, cosine
82 | config.sampling_grid = "cosine"
83 | # for topp sampler
84 | config.topp = 0.98
85 |
86 | config.log_loss_every_steps = 500
87 | config.eval_every_steps = 10000
88 | config.checkpoint_every_steps = 5000
89 | config.checkpoint_keep_period = 10000
90 |
91 | # Single integer or tuple. If None will use (XManager ID, work unit).
92 | config.seed = 0
93 |
94 | # Number of workers for Grain loaders.
95 | config.grain_num_workers = 15
96 |
97 | config.trial = 0 # Dummy for repeated runs.
98 | config.test_in_colab = False
99 | return config
100 |
101 |
102 | # By default, the launcher calls `sweep()`.
103 | # To disable the sweep, the `sweep()` function can be commented (or renamed),
104 | # or the flag `--nosweep` can be specified to the launcher.
105 | def sweep(add: abc.Callable[..., None]):
106 | """Starts multiple work units with varying config args."""
107 | # Train unconditional model
108 | # Default: linear schedule, cosine sampling grid
109 | add(
110 | dataset="downsampled_imagenet/64x64",
111 | sampling_grid="cosine",
112 | noise_schedule="linear",
113 | n_layers=8, # use 24 for best likelihood in the paper
114 | n_dit_layers=20, # use 12 for best likelihood in the paper
115 | )
116 | # # Cosine schedule: worse likelihood, for best sample quality
117 | # add(
118 | # dataset="downsampled_imagenet/64x64",
119 | # sampling_grid="uniform",
120 | # noise_schedule="cosine",
121 | # n_layers=8,
122 | # n_dit_layers=20,
123 | # )
124 | # # Train class conditioned
125 | # add(
126 | # dataset="class_cond_imagenet64",
127 | # classes=1000,
128 | # sampling_grid="uniform",
129 | # noise_schedule="cosine",
130 | # )
131 |
--------------------------------------------------------------------------------
/md4/configs/md4/openwebtext.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""A config for training small model on OpenWebText."""
17 |
18 | from collections import abc
19 |
20 | from ml_collections import config_dict
21 |
22 |
23 | def get_config() -> config_dict.ConfigDict:
24 | """Default config."""
25 |
26 | config = config_dict.ConfigDict()
27 |
28 | # dataset configs
29 | config.vocab_size = 50257
30 | config.dataset = "openwebtext"
31 | config.classes = -1
32 |
33 | config.task_type = "text" # text or image
34 | config.model_type = "md4"
35 | config.data_shape = (1024,)
36 |
37 | # timesteps: int or None
38 | config.timesteps = 1000
39 | # linear, cosine, poly[exponent], e.g., poly3
40 | config.noise_schedule = "linear"
41 | config.outside_embed = False
42 | # t or none (removes time dependence)
43 | config.time_features = "t"
44 | config.cont_time = True
45 |
46 | config.feature_dim = 64
47 | config.n_layers = 12
48 | config.ch_mult = (1,) # not used
49 | config.n_dit_layers = 0 # not used
50 | config.dit_num_heads = 12 # not used
51 | config.dit_hidden_size = 768 # not used
52 | config.dropout_rate = 0.0
53 |
54 | config.num_heads = 12
55 | config.mlp_type = "glu"
56 | config.depth_scaled_init = True
57 | config.cond_type = "adaln_zero"
58 |
59 | config.learning_rate = 3e-4
60 | config.learning_rate_schedule = "cosine"
61 | config.warmup_steps = 2000
62 | config.weight_decay = 0.0
63 | config.clip = 0.0
64 | config.b2 = 0.999
65 | config.num_epochs = -1
66 | config.ema_rate = 0.9999
67 | # If num_train_steps==-1 then the number of training steps is calculated from
68 | # num_epochs.
69 | config.num_train_steps = 1_000_000
70 | # Evaluates for a full epoch if num_eval_steps==-1.
71 | config.num_eval_steps = -1
72 | config.batch_size = 512
73 | config.num_microbatches = 1
74 | config.per_device_batch_size = -1
75 |   # Whether to pad the last batch so the entire dataset is evaluated.
76 | config.eval_pad_last_batch = False
77 | config.check_nans = False
78 |
79 | # Sampling
80 | # ancestral, mean, or topp
81 | config.sampler = "ancestral"
82 | # uniform, cosine
83 | config.sampling_grid = "cosine"
84 | # for topp sampler
85 | config.topp = 0.98
86 |
87 | config.log_loss_every_steps = 500
88 | config.eval_every_steps = 10000
89 | config.checkpoint_every_steps = 5000
90 | config.checkpoint_keep_period = 200000
91 |
92 | # Single integer or tuple. If None will use (XManager ID, work unit).
93 | config.seed = 42
94 |
95 | # Number of workers for Grain loaders.
96 | # HF data source (OSS version) only supports one worker for now, more workers
97 | # result in duplicated data.
98 | config.grain_num_workers = 1
99 |
100 | config.trial = 0 # Dummy for repeated runs.
101 | config.test_in_colab = False
102 | return config
103 |
104 |
105 | # By default, the launcher calls `sweep()`.
106 | # To disable the sweep, the `sweep()` function can be commented (or renamed),
107 | # or the flag `--nosweep` can be specified to the launcher.
108 | def sweep(add: abc.Callable[..., None]):
109 | """Starts multiple work units with varying config args."""
110 | # Small size
111 | add(
112 | noise_schedule="linear",
113 | dropout_rate=0.02,
114 | weight_decay=0.03,
115 | )
116 | # Medium size
117 | # add(
118 | # noise_schedule="cosine",
119 | # n_layers=24,
120 | # num_heads=16,
121 | # batch_size=512,
122 | # sampling_grid="uniform",
123 | # dropout_rate=0.0,
124 | # weight_decay=0.03,
125 | # )
126 |
--------------------------------------------------------------------------------
/md4/configs/md4/text8.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""A config for training MD4 on text8."""
17 |
18 | from collections import abc
19 |
20 | from ml_collections import config_dict
21 |
22 |
23 | def get_config() -> config_dict.ConfigDict:
24 | """Default config."""
25 |
26 | config = config_dict.ConfigDict()
27 |
28 | # dataset configs
29 | config.vocab_size = 27
30 | config.dataset = "text8"
31 | config.classes = -1
32 |
33 | config.task_type = "text" # text or image
34 | config.model_type = "md4"
35 | config.data_shape = (256,)
36 |
37 | # timesteps: int or None
38 | config.timesteps = 1000
39 | # linear, cosine, poly[exponent], e.g., poly3
40 | config.noise_schedule = "linear"
41 | config.outside_embed = True
42 | # t or none (removes time dependence)
43 | config.time_features = "t"
44 | config.cont_time = True
45 |
46 | config.feature_dim = 64
47 | config.n_layers = 12
48 | config.ch_mult = (1,) # not used
49 | config.n_dit_layers = 0 # not used
50 | config.dit_num_heads = 12 # not used
51 | config.dit_hidden_size = 768 # not used
52 | config.dropout_rate = 0.0
53 |
54 | config.num_heads = 12
55 | config.mlp_type = "glu"
56 | config.depth_scaled_init = True
57 | config.cond_type = "adaln_zero"
58 |
59 | config.learning_rate = 3e-4
60 | config.learning_rate_schedule = "cosine"
61 | config.warmup_steps = 2000
62 | config.weight_decay = 0.0
63 | config.clip = 0.0
64 | config.b2 = 0.999
65 | config.num_epochs = -1
66 | config.ema_rate = 0.9999
67 | # If num_train_steps==-1 then the number of training steps is calculated from
68 | # num_epochs.
69 | config.num_train_steps = 1_000_000
70 | # Evaluates for a full epoch if num_eval_steps==-1.
71 | config.num_eval_steps = -1
72 | config.batch_size = 512
73 | config.num_microbatches = 1
74 | config.per_device_batch_size = -1
75 |   # Whether to pad the final batch so the entire eval dataset is covered.
76 | config.eval_pad_last_batch = False
77 | config.check_nans = False
78 |
79 | # Sampling
80 | # ancestral, mean, or topp
81 | config.sampler = "ancestral"
82 | # uniform, cosine
83 | config.sampling_grid = "cosine"
84 | # for topp sampler
85 | config.topp = 0.98
86 |
87 | config.log_loss_every_steps = 500
88 | config.eval_every_steps = 5000
89 | config.checkpoint_every_steps = 5000
90 | config.checkpoint_keep_period = 10000
91 |
92 |   # Single integer or tuple. If None, (XManager ID, work unit) is used.
93 | config.seed = 42
94 |
95 | # Number of workers for Grain loaders.
96 | config.grain_num_workers = 15
97 |
98 | config.trial = 0 # Dummy for repeated runs.
99 | config.test_in_colab = False
100 | return config
101 |
102 |
103 | # By default, the launcher calls `sweep()`.
104 | # To disable the sweep, the `sweep()` function can be commented out (or
105 | # renamed), or the flag `--nosweep` can be passed to the launcher.
106 | def sweep(add: abc.Callable[..., None]):
107 | """Starts multiple work units with varying config args."""
108 | add(
109 | learning_rate=5e-4,
110 | dropout_rate=0.05,
111 | weight_decay=0.03,
112 | )
113 |
--------------------------------------------------------------------------------
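
The optimizer itself is constructed in the training code, not in this config.
As a rough sketch of what `learning_rate_schedule = "cosine"` with the warmup
settings above corresponds to in optax (the exact construction in train.py
may differ):

    import optax

    # Warm up for 2000 steps to a peak of 3e-4, then cosine-decay over the
    # 1M training steps; mirrors the config fields above.
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=3e-4,
        warmup_steps=2_000,
        decay_steps=1_000_000,
    )
    optimizer = optax.adamw(learning_rate=schedule, b2=0.999, weight_decay=0.0)
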
/md4/input_pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Deterministic input pipeline."""
17 |
18 | from collections.abc import Sequence
19 | import dataclasses
20 | import os
21 | from typing import Any, Union
22 | import urllib.request
23 | import zipfile
24 |
25 | import grain.python as grain
26 | import jax
27 | from ml_collections import config_dict
28 | import numpy as np
29 | import tensorflow as tf
30 | import tensorflow_datasets as tfds
31 | import transformers
32 |
33 |
34 | # pylint: disable=g-import-not-at-top
35 | try:
36 | import cv2
37 | except ImportError:
38 | print("cv2 not found")
39 | FlatFeatures = dict[str, Any]
40 |
41 | _DataSet = Union[grain.MapDataset, grain.DataLoader, grain.IterDataset]
42 |
43 | _GPT2_TOKENIZER = "gpt2"
44 | _OWT_DATASETS = dict(
45 | # OSS version. Please prepare the OWT datasets using the following command:
46 | # python ./prepare_openwebtext_data.py
47 | dataset_train_path=("./data_dir/openwebtext_splits_1024_train"),
48 | dataset_eval_path=("./data_dir/openwebtext_splits_1024_eval"),
49 | )
50 |
51 |
52 | class Text8Tokenizer:
53 | """Simple text8 tokenizer."""
54 |
55 | def __init__(self, num_extra_tokens=0):
56 | self.num_extra_tokens = num_extra_tokens
57 |
58 | @property
59 | def vocab_size(self):
60 | return 27 + self.num_extra_tokens
61 |
62 | @property
63 | def pad_token(self):
64 | return 26
65 |
66 | def encode(self, text):
67 | tokens = np.array([i - 97 for i in text], dtype=np.int32)
68 | tokens = np.where(tokens < 0, self.pad_token, tokens)
69 | return tokens
70 |
71 | def decode(self, tokens):
72 | tokens = np.where(np.equal(tokens, self.pad_token), 32 - 97, tokens) + 97
73 | text = tokens.astype(np.uint8).tobytes()
74 | return text.decode("utf-8")
75 |
76 |
77 | def preprocess_text8(
78 | data_dir,
79 | doc_length: int = 512,
80 | ):
81 | """Load the 27-char text8 dataset."""
82 | if not os.path.exists(os.path.join(data_dir, "text8.train.txt")):
83 | if not os.path.exists(data_dir):
84 | os.makedirs(data_dir)
85 | if not os.path.exists(os.path.join(data_dir, "text8.zip")):
86 | url = "http://mattmahoney.net/dc/text8.zip"
87 | print("Downloading text8 from URL {}.".format(url))
88 | urllib.request.urlretrieve(url, os.path.join(data_dir, "text8.zip"))
89 | with open(os.path.join(data_dir, "text8.zip"), "rb") as f:
90 | rawdata = zipfile.ZipFile(f).read("text8").decode("utf-8")
91 | splits = {
92 | "train": rawdata[:90000000],
93 | "valid": rawdata[90000000:95000000],
94 | "test": rawdata[95000000:],
95 | }
96 | for split, data in splits.items():
97 | with open(os.path.join(data_dir, "text8." + split + ".txt"), "w") as f:
98 | f.write(data)
99 |
100 | def load_text8_split(split: str):
101 | def _split_chars(arr):
102 | return tf.compat.v1.string_split(
103 | [arr], sep="", result_type="RaggedTensor"
104 | ).flat_values
105 |
106 | def _join_and_rename(x):
107 | text = tf.strings.reduce_join(x, axis=0)
108 | return {"text": text}
109 |
110 | path = os.path.join(data_dir, "text8." + split + ".txt")
111 | ds = tf.data.TextLineDataset(path).map(_split_chars).unbatch()
112 | ds = ds.batch(doc_length, drop_remainder=True)
113 | ds = ds.map(_join_and_rename)
114 | return ds
115 |
116 | # Define the builder.
117 | text8_builder = tfds.dataset_builders.store_as_tfds_dataset(
118 | name="text8",
119 | version="1.0.0",
120 | features=tfds.features.FeaturesDict({
121 | "text": tfds.features.Text(),
122 | }),
123 | split_datasets={
124 | "train": load_text8_split("train"),
125 | "valid": load_text8_split("valid"),
126 | "test": load_text8_split("test"),
127 | },
128 | config="text8",
129 | data_dir=data_dir,
130 | description="text8 dataset, document length 512.",
131 | file_format="array_record",
132 | disable_shuffling=True,
133 | )
134 |
135 | return text8_builder
136 |
137 |
138 | class ChunkDataSource(grain.RandomAccessDataSource):
139 | """Chunk text data source."""
140 |
141 | def __init__(self, tensor, chunk_size=256, overlapping=False):
142 | self.chunk_size = chunk_size
143 | self.overlapping = overlapping
144 | tensor = tensor.encode("utf-8")
145 | if not overlapping:
146 | extra_len = len(tensor) % chunk_size
147 | if extra_len > 0:
148 | tensor = tensor[:-extra_len]
149 | self.tensor = np.array(list(tensor)).reshape(-1, chunk_size)
150 | else:
151 | self.tensor = tensor
152 |
153 | def __len__(self):
154 | if not self.overlapping:
155 | return self.tensor.shape[0]
156 | else:
157 | return len(self.tensor) - self.chunk_size + 1
158 |
159 | def __getitem__(self, record_key):
160 | if not self.overlapping:
161 | return {"text": self.tensor[record_key]}
162 | else:
163 | start_idx = record_key
164 | end_idx = record_key + self.chunk_size
165 | chunk = self.tensor[start_idx:end_idx]
166 | return {"text": chunk}
167 |
168 | def __repr__(self) -> str:
169 | return f"ChunkDataSource(len={len(self)},overlapping={self.overlapping})"
170 |
171 |
172 | @dataclasses.dataclass
173 | class Tokenize(grain.MapTransform):
174 | tokenizer: Text8Tokenizer
175 |
176 | def map(self, features):
177 | text = features["text"]
178 | features["text"] = self.tokenizer.encode(text)
179 | return features
180 |
181 |
182 | @dataclasses.dataclass(frozen=True)
183 | class DiscreteWithoutLabel(grain.MapTransform):
184 | """Discrete image data with zero labels."""
185 |
186 | def map(self, features):
187 | features["image"] = features["image"].astype(np.int32)
188 | if "label" in features:
189 | del features["label"]
190 | if "id" in features:
191 | del features["id"]
192 | return features
193 |
194 |
195 | @dataclasses.dataclass(frozen=True)
196 | class ResizeSmall(grain.MapTransform):
197 | """Resizes the smaller side to `size` keeping aspect ratio.
198 |
199 | Attr:
200 | size: Smaller side of an input image (might be adjusted if max_size given).
201 | """
202 |
203 | size: int
204 |
205 | def map(self, features: FlatFeatures) -> FlatFeatures:
206 | image = features["image"]
207 | size = self.size
208 | image = cv2.resize(image, dsize=(size, size), interpolation=cv2.INTER_AREA)
209 | features["image"] = image.astype(np.int32)
210 | return features
211 |
212 |
213 | @dataclasses.dataclass(frozen=True)
214 | class CentralSquareCrop(grain.MapTransform):
215 | """Makes a square central square crop of a given size."""
216 |
217 | def map(self, features: FlatFeatures) -> FlatFeatures:
218 | image = features["image"]
219 | h, w = image.shape[:2]
220 | size = min(h, w)
221 | top = (h - size) // 2
222 | left = (w - size) // 2
223 | image = image[top : top + size, left : left + size, :]
224 | features["image"] = image
225 | return features
226 |
227 |
228 | def get_data_shape(config):
229 | return config.data_shape
230 |
231 |
232 | @dataclasses.dataclass(frozen=True)
233 | class DropFeatures(grain.MapTransform):
234 | feature_names: Sequence[str]
235 |
236 | def map(self, features: FlatFeatures) -> FlatFeatures:
237 | for feature_name in self.feature_names:
238 | del features[feature_name]
239 | return features
240 |
241 |
242 | @dataclasses.dataclass
243 | class ParseFeatures(grain.MapTransform):
244 | """Parse serialized example."""
245 |
246 | def __init__(self, data_column):
247 | self.data_column = data_column
248 |
249 | def map(self, features):
250 | def _parse(example):
251 | parsed = tf.io.parse_example(
252 | example,
253 | {
254 | self.data_column: tf.io.FixedLenSequenceFeature(
255 | [], dtype=tf.int64, allow_missing=True
256 | )
257 | },
258 | )
259 | return parsed
260 |
261 | return _parse(features)
262 |
263 |
264 | def get_num_train_steps(config: config_dict.ConfigDict) -> int:
265 | """Calculates the total number of training steps."""
266 | if config.num_train_steps > 0:
267 | return config.num_train_steps
268 | # From the beginning. We first shard the data (shard by process_count), then
269 | # combine all epochs, batch for all local devices.
270 | # In all steps we would drop the remainder (hence the use of integer
271 | # division).
272 |   # When start_index is 0, train_ds.cardinality() and num_train_steps should
273 |   # be equivalent.
274 | if config.task_type == "image":
275 | tfds_info = tfds.builder(config.dataset).info
276 | num_train_records = tfds_info.splits["train"].num_examples
277 | return int(
278 | num_train_records // jax.process_count() * config.num_epochs
279 | ) // (config.per_device_batch_size * jax.local_device_count())
280 | else:
281 | raise NotImplementedError()
282 |
283 |
284 | def create_datasets(
285 | config: config_dict.ConfigDict, seed: int
286 | ) -> tuple[_DataSet, dict[str, _DataSet], dict[str, Any]]:
287 | """Create Grain data loaders for training and evaluation.
288 |
289 | Args:
290 | config: Configuration to use.
291 | seed: Seed for shuffle and random operations in the training dataset.
292 |
293 | Returns:
294 | A tuple with the training dataset loader, the evaluation dataset
295 | loader, and a dictionary of other infos.
296 | """
297 | info = {}
298 | assert config.batch_size % jax.process_count() == 0
299 | process_batch_size = config.batch_size // jax.process_count()
300 | eval_batch_size = config.get("eval_batch_size", config.batch_size)
301 | process_eval_batch_size = eval_batch_size // jax.process_count()
302 |
303 | if config.dataset == "text8":
304 | seq_len = config.data_shape[0]
305 |     # The current train/valid format only supports a length of 256.
306 | assert seq_len == 256
307 |
308 | with tf.io.gfile.GFile(
309 |         os.path.join("./data_dir", "text8", "text8.zip"), "rb"
310 | ) as f:
311 | rawdata = zipfile.ZipFile(f).read("text8").decode("utf-8")
312 | splits = {
313 | "train": rawdata[:90000000],
314 | "valid": rawdata[90000000:95000000],
315 | "test": rawdata[95000000:],
316 | }
317 |
318 | tokenizer = Text8Tokenizer(num_extra_tokens=0)
319 | train_transformations = [Tokenize(tokenizer)]
320 | train_source = ChunkDataSource(
321 | splits["train"], chunk_size=seq_len, overlapping=True
322 | )
323 |
324 | eval_transformations = [Tokenize(tokenizer)]
325 | eval_source = {
326 | k: ChunkDataSource(splits[k], chunk_size=seq_len)
327 | for k in ["valid", "test"]
328 | }
329 | info["tokenizer"] = tokenizer
330 | info["rawdata"] = rawdata
331 | elif config.dataset == "openwebtext":
332 |     # Pretrain a GPT-2-sized model with a context length of 1024.
333 | seq_len = config.data_shape[0]
334 | assert seq_len == 1024
335 | train_transformations = [ParseFeatures(data_column="text")]
336 | eval_transformations = [ParseFeatures(data_column="text")]
337 |
338 | train_table_path = _OWT_DATASETS["dataset_train_path"]
339 | train_source = grain.ArrayRecordDataSource(paths=train_table_path)
340 |
341 | eval_source = {
342 | "owt_eval": grain.ArrayRecordDataSource(
343 | paths=_OWT_DATASETS["dataset_eval_path"]
344 | ),
345 | }
346 |
347 | tokenizer = transformers.GPT2Tokenizer.from_pretrained(_GPT2_TOKENIZER)
348 | info["tokenizer"] = tokenizer
349 | elif (
350 | config.dataset.startswith("mnist")
351 | or config.dataset.startswith("cifar")
352 | or config.dataset.startswith("downsampled_imagenet")
353 | ):
354 | data_source = tfds.data_source(config.dataset)
355 | train_transformations = [DiscreteWithoutLabel()]
356 | train_source = data_source["train"]
357 | eval_transformations = [DiscreteWithoutLabel()]
358 | eval_source = {k: v for k, v in data_source.items() if k != "train"}
359 | elif config.dataset == "class_cond_imagenet64":
360 | data_source = tfds.data_source("imagenet2012")
361 | train_transformations = [
362 | CentralSquareCrop(),
363 | ResizeSmall(64),
364 | DropFeatures(("file_name",)),
365 | ]
366 | train_source = data_source["train"]
367 | eval_transformations = [
368 | CentralSquareCrop(),
369 | ResizeSmall(64),
370 | DropFeatures(("file_name",)),
371 | ]
372 | eval_source = {"validation": data_source["validation"]}
373 | else:
374 | raise NotImplementedError("Unsupported datasets.")
375 |
376 | train_loader = grain.load(
377 | source=train_source,
378 | shuffle=True,
379 | seed=seed,
380 | shard_options=grain.ShardByJaxProcess(drop_remainder=True),
381 | transformations=train_transformations,
382 | batch_size=process_batch_size,
383 | worker_count=config.grain_num_workers,
384 | )
385 |
386 | if config.eval_pad_last_batch:
387 | raise NotImplementedError(
388 | "BatchWithPadElements is not implemented in PyGrain yet."
389 | )
390 | else:
391 | drop_remainder = True
392 | shard_options = grain.ShardByJaxProcess(drop_remainder=drop_remainder)
393 |
394 | eval_loaders = {}
395 | for split in eval_source:
396 | eval_loader = grain.load(
397 | source=eval_source[split],
398 | num_epochs=1,
399 | shard_options=shard_options,
400 | transformations=eval_transformations,
401 | batch_size=process_eval_batch_size,
402 | worker_count=0,
403 | drop_remainder=drop_remainder,
404 | )
405 | eval_loaders[split] = eval_loader
406 |
407 | return train_loader, eval_loaders, info
408 |
--------------------------------------------------------------------------------
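
Text8Tokenizer above maps the 26 lowercase letters to ids 0-25 and anything
below 'a' (in text8, the space character) to the pad id 26. A quick
round-trip sketch, assuming the module is importable as md4.input_pipeline:

    from md4.input_pipeline import Text8Tokenizer

    tok = Text8Tokenizer()
    # encode() iterates over byte values, so pass bytes (or an int array).
    ids = tok.encode(b"hello world")
    print(ids)              # [ 7  4 11 11 14 26 22 14 17 11  3]
    print(tok.decode(ids))  # "hello world" (pad id 26 decodes to a space)
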
/md4/input_pipeline_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Input pipeline for grain based datasets."""
17 |
18 | from collections.abc import Sequence
19 | import dataclasses
20 | import threading
21 | from typing import Any
22 |
23 | from absl import logging
24 | import datasets
25 | import datasets.distributed
26 | from etils import epath
27 | import grain.python as grain
28 | import jax
29 | from ml_collections import config_dict
30 | import numpy as np
31 | import tensorflow as tf
32 | import tensorflow_datasets as tfds
33 | import transformers
34 |
35 |
36 | _GPT2_TOKENIZER = "gpt2"
37 | _OWT_DATASETS = dict(
38 | # OSS version. Please prepare the OWT datasets using the following command:
39 | # python ./prepare_openwebtext_data.py
40 | dataset_train_path=("./data_dir/openwebtext_splits_1024_train"),
41 | dataset_eval_path=("./data_dir/openwebtext_splits_1024_eval"),
42 | )
43 |
44 |
45 | class HFDataSource(grain.RandomAccessDataSource):
46 | """A class that makes HuggingFace IterableDataset a grain datasource without random access support."""
47 |
48 | def __init__(
49 | self,
50 | dataset: datasets.IterableDataset,
51 | dataloading_host_index: int,
52 | dataloading_host_count: int,
53 | num_threads: int,
54 | generate_padding_example: bool,
55 | max_target_length: int,
56 | data_column_name: str,
57 | ):
58 | self.dataset = dataset
59 | self.num_threads = num_threads
60 | self.dataloading_host_count = dataloading_host_count
61 | self.dataloading_host_index = dataloading_host_index
62 | self.generate_padding_example = generate_padding_example
63 |     self.max_target_length = max_target_length
64 | self.data_column_name = data_column_name
65 | self.n_shards = dataset.n_shards
66 | self._check_shard_count()
67 | self.dataset_shards = [
68 | dataloading_host_index * self.num_threads + i
69 | for i in range(self.num_threads)
70 | ]
71 | self.datasets = [
72 | datasets.distributed.split_dataset_by_node(
73 | dataset, world_size=self.n_shards, rank=x
74 | )
75 | for x in self.dataset_shards
76 | ]
77 | self.data_iters = []
78 | self.out_of_data = False
79 |
80 | def _check_shard_count(self):
81 | if self.n_shards < (self.dataloading_host_count * self.num_threads):
82 | print(
83 | f"WARNING: Inefficient dataloading. Your train or eval dataset"
84 | f" contains {self.n_shards} shards, smaller than number of host"
85 | " loading data. This is known to lead to inefficient dataloading."
86 | " see"
87 | " https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice"
88 | )
89 | self.n_shards = self.dataloading_host_count * self.num_threads
90 |
91 | def _update_shard(self, idx):
92 | new_shard = (
93 | self.dataset_shards[idx]
94 | + self.dataloading_host_count * self.num_threads
95 | )
96 | if new_shard < self.n_shards:
97 | print(
98 | f"Updating host {self.dataloading_host_index} dataset {idx}, was on"
99 | f" shard {self.dataset_shards[idx]}"
100 | )
101 | print(f"New shard is {new_shard}")
102 | self.dataset_shards[idx] = new_shard
103 | self.datasets[idx] = datasets.distributed.split_dataset_by_node(
104 | self.dataset, world_size=self.n_shards, rank=self.dataset_shards[idx]
105 | )
106 | self.data_iters[idx] = iter(self.datasets[idx])
107 | else:
108 | print(
109 | f"Run out of shards on host {self.dataloading_host_index}, shard"
110 | f" {self.dataset_shards[idx]} is not available"
111 | )
112 | self.out_of_data = True
113 | if self.generate_padding_example:
114 | print(
115 | f"Host {self.dataloading_host_index} will start generating all-0"
116 | " padding examples until step number is met."
117 | )
118 |
119 | def __len__(self):
120 | """Return length of the HF dataset.
121 |
122 |     Since HuggingFace IterableDataset does not have a length,
123 |     a placeholder length larger than the dataset is returned.
124 | """
125 | return 10_000_000_000
126 |
127 | def __getitem__(self, index):
128 | """Since HuggingFace IterableDataset doesn't support random access by index.
129 |
130 | The next item in the iterator is returned.
131 |
132 | Args:
133 |       index: Ignored; items are read sequentially per worker thread.
134 |
135 | Returns:
136 | The next item in the iterator.
137 | """
138 | if not self.data_iters:
139 | self.data_iters = [iter(x) for x in self.datasets]
140 | idx = int(threading.current_thread().name.split("_")[1])
141 |
142 | while True:
143 | try:
144 | if self.out_of_data:
145 | if self.generate_padding_example:
146 | return {
147 | self.data_column_name: np.zeros(
148 |               self.max_target_length, dtype=np.int32
149 | )
150 | }
151 | else:
152 | return None
153 | data = next(self.data_iters[idx])
154 | return data
155 | except StopIteration:
156 | self._update_shard(idx)
157 |
158 |
159 | @dataclasses.dataclass
160 | class NormalizeFeatures(grain.MapTransform):
161 | """Normalize text feature keys."""
162 |
163 | def __init__(self, column_name):
164 | self.column_name = column_name
165 |
166 | def map(self, features):
167 | return {"text": features[self.column_name].decode()}
168 |
169 |
170 | @dataclasses.dataclass
171 | class HFNormalizeFeatures(grain.MapTransform):
172 | """Normalize feature keys for HuggingFace input."""
173 |
174 | def __init__(self, column_name):
175 | self.column_name = column_name
176 |
177 | def map(self, features):
178 | return {
179 | "text": np.asarray(features[self.column_name], dtype=np.int32),
180 | }
181 |
182 |
183 | @dataclasses.dataclass
184 | class ReformatPacking(grain.MapTransform):
185 | """Reformat packing outputs."""
186 |
187 | def map(self, data):
188 | return {
189 | "text": data[0]["text"],
190 | "text_segmentation": data[1]["text"],
191 | "text_position": data[2]["text"],
192 | }
193 |
194 |
195 | @dataclasses.dataclass
196 | class PadToMaxLength(grain.MapTransform):
197 | """Pads each input to the specified length."""
198 |
199 | def __init__(self, max_length):
200 | self.max_length = max_length
201 |
202 | def map(self, data):
203 | """map to each element."""
204 |
205 | def _pad(x, max_length):
206 | pad_amount = max(max_length - x.shape[0], 0)
207 | pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
208 | return np.pad(x, pad_amount)
209 |
210 | data["text_segmentation"] = np.ones(data["text"].shape, dtype=np.int32)
211 | data["text_position"] = np.arange(data["text"].shape[0], dtype=np.int32)
212 | for key, _ in data.items():
213 | data[key] = _pad(data[key], self.max_length)
214 | return data
215 |
216 |
217 | @dataclasses.dataclass
218 | class TokenizeAndTrim(grain.MapTransform):
219 | """Tokenize and trim features to sequence length."""
220 |
221 | # pylint: disable=attribute-defined-outside-init
222 | feature_names: Sequence[str]
223 | sequence_length: Sequence[int]
224 | tokenizer: Any
225 | add_bos: bool
226 | add_eos: bool
227 |
228 | def map(self, features: dict[str, Any]) -> dict[str, Any]:
229 | """Maps to each element."""
230 | for feature_name, sequence_length in zip(
231 | self.feature_names, self.sequence_length, strict=True
232 | ):
233 | text = features[feature_name]
234 | token_ids = self.tokenizer(text)["input_ids"]
235 | if self.add_bos:
236 | token_ids = [self.tokenizer.bos_token_id] + token_ids
237 |
238 | if self.add_eos:
239 | token_ids = token_ids[: sequence_length - 1]
240 | token_ids = token_ids + [self.tokenizer.eos_token_id]
241 | else:
242 | token_ids = token_ids[:sequence_length]
243 |
244 | features[feature_name] = np.asarray(token_ids, dtype=np.int32)
245 | return features
246 |
247 |
248 | @dataclasses.dataclass
249 | class ParseFeatures(grain.MapTransform):
250 | """Parse serialized example."""
251 |
252 | def __init__(self, data_column):
253 | self.data_column = data_column
254 |
255 | def map(self, features):
256 | def _parse(example):
257 | parsed = tf.io.parse_example(
258 | example,
259 | {
260 | self.data_column: tf.io.FixedLenSequenceFeature(
261 | [], dtype=tf.int64, allow_missing=True
262 | )
263 | },
264 | )
265 | return parsed
266 |
267 | return _parse(features)
268 |
269 |
270 |
271 |
272 | def tokenization(example, hf_tokenizer, max_length, column_name):
273 | """Tokenize a HuggingFace dataset."""
274 | return hf_tokenizer(
275 | example[column_name], truncation=True, max_length=max_length
276 | )
277 |
278 |
279 | def load_fineweb_edu_hf_source():
280 | """Loads fineweb_edu data source from HuggingFace.
281 |
282 | Returns:
283 | A grain data source for fineweb_edu.
284 | """
285 | tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
286 | fw = datasets.load_dataset(
287 | "HuggingFaceFW/fineweb-edu", name="default", split="train", streaming=True
288 | )
289 | fw = fw.map(
290 | tokenization,
291 | batched=True,
292 | fn_kwargs={
293 | "hf_tokenizer": tokenizer,
294 | "max_length": 1023,
295 | "column_name": "text",
296 | },
297 | )
298 | fw = fw.select_columns(["input_ids"]).rename_column("input_ids", "text")
299 | return HFDataSource(fw, 0, 1, 1, False, 1024, "text")
300 |
301 |
302 | def compile_transformations(
303 | seq_len,
304 | tokenizer,
305 | data_column="text",
306 | add_bos=False,
307 | add_eos=True,
308 | packing=True,
309 | drop_remainder=True,
310 | process_batch_size=32,
311 | ):
312 | """Collects transformations for the grain input pipeline."""
313 | # Normalize: convert bytes to string ready for tokenization
314 | operations = []
315 | operations.append(NormalizeFeatures(data_column))
316 | operations.append(
317 | TokenizeAndTrim([data_column], [seq_len], tokenizer, add_bos, add_eos)
318 | )
319 |
320 | # Pack and Batch examples.
321 | if packing:
322 | operations.append(
323 | grain.experimental.PackAndBatchOperation(
324 | batch_size=process_batch_size,
325 | length_struct={data_column: seq_len},
326 | )
327 | )
328 | operations.append(ReformatPacking())
329 | else:
330 | operations.append(PadToMaxLength(seq_len))
331 | operations.append(
332 | grain.Batch(
333 | batch_size=process_batch_size,
334 | drop_remainder=drop_remainder,
335 | )
336 | )
337 | return operations
338 |
339 |
340 | def compile_hf_transformations(
341 | seq_len,
342 | data_column="text",
343 | process_batch_size=32,
344 | ):
345 | """Collects transformations for the grain input pipeline."""
346 | operations = []
347 | operations.append(HFNormalizeFeatures(data_column))
348 | operations.append(
349 | grain.experimental.PackAndBatchOperation(
350 | batch_size=process_batch_size,
351 | length_struct={data_column: seq_len},
352 | )
353 | )
354 | operations.append(ReformatPacking())
355 | return operations
356 |
357 |
358 | def create_datasets(
359 | config: config_dict.ConfigDict, seed: int
360 | ) -> tuple[grain.DataLoader, dict[str, grain.DataLoader], dict[str, Any]]:
361 | """Create Grain data loaders for training and evaluation.
362 |
363 | For the same seed and config this will return the same datasets.
364 | The user is responsible to save()/load() the dataset iterators (for training)
365 | or calling reset() to restart the iterator (for eval).
366 |
367 | Args:
368 | config: Configuration to use.
369 | seed: Seed for shuffle and random operations in the training dataset.
370 |
371 | Returns:
372 | A tuple with the training dataset loader, the evaluation dataset
373 | loader, and a dictionary of other infos.
374 | """
375 | info = {}
376 | assert config.batch_size % jax.process_count() == 0
377 | process_batch_size = config.batch_size // jax.process_count()
378 | eval_batch_size = config.get("eval_batch_size", config.batch_size)
379 | process_eval_batch_size = eval_batch_size // jax.process_count()
380 |
381 | if config.dataset == "fineweb_edu":
382 |     # Pretrain a GPT-2-sized model with a context length of 1024.
383 | seq_len = config.data_shape[0]
384 | assert seq_len == 1024
385 |
386 | tokenizer = transformers.GPT2Tokenizer.from_pretrained(_GPT2_TOKENIZER)
387 |
388 | load_fineweb_edu_source = load_fineweb_edu_hf_source
389 | train_source = load_fineweb_edu_source()
390 | transformations = compile_hf_transformations(
391 | seq_len,
392 | data_column="text",
393 | process_batch_size=process_batch_size,
394 | )
395 | eval_sources = {
396 | "owt_eval": grain.ArrayRecordDataSource(
397 | paths=_OWT_DATASETS["dataset_eval_path"]
398 | ),
399 | # "fwe_eval": eval_source,
400 | }
401 | eval_transformations = {
402 | "owt_eval": [
403 | ParseFeatures(data_column="text"),
404 | grain.Batch(
405 | batch_size=process_eval_batch_size,
406 | drop_remainder=True,
407 | ),
408 | ],
409 | # "fwe_eval": transformations,
410 | }
411 | info["tokenizer"] = tokenizer
412 |
413 | else:
414 | raise NotImplementedError("Unsupported datasets.")
415 | index_sampler = grain.IndexSampler(
416 | num_records=len(train_source),
417 | shard_options=grain.ShardByJaxProcess(drop_remainder=True),
418 | shuffle=True,
419 | seed=seed,
420 | )
421 | train_loader = grain.DataLoader(
422 | data_source=train_source,
423 | operations=transformations,
424 | sampler=index_sampler,
425 | worker_count=config.grain_num_workers,
426 | worker_buffer_size=1,
427 | read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=1024),
428 | )
429 |
430 | if config.eval_pad_last_batch:
431 | raise NotImplementedError(
432 | "BatchWithPadElements is not implemented in PyGrain yet."
433 | )
434 | else:
435 | drop_remainder = True
436 | shard_options = grain.ShardByJaxProcess(drop_remainder=drop_remainder)
437 |
438 | eval_loaders = {}
439 | for split in eval_sources:
440 | eval_loader = grain.load(
441 | source=eval_sources[split],
442 | num_epochs=1,
443 | shard_options=shard_options,
444 | transformations=eval_transformations[split],
445 |         # For now, we do not parallelize the evaluation because there is a
446 |         # bug in DataLoader.__iter__ when used with JAX.
447 | worker_count=0,
448 | read_options=grain.ReadOptions(prefetch_buffer_size=1024),
449 | )
450 | eval_loaders[split] = eval_loader
451 |
452 | return train_loader, eval_loaders, info
453 |
--------------------------------------------------------------------------------
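
When packing is disabled, compile_transformations pads each example with
PadToMaxLength, which also synthesizes segmentation and position features
before zero-padding every key. A small sketch of its behavior on a toy
example (import path assumed):

    import numpy as np
    from md4.input_pipeline_v2 import PadToMaxLength

    pad = PadToMaxLength(max_length=8)
    out = pad.map({"text": np.arange(5, dtype=np.int32)})
    print(out["text"])               # [0 1 2 3 4 0 0 0]
    print(out["text_segmentation"])  # [1 1 1 1 1 0 0 0], real tokens vs. pad
    print(out["text_position"])      # [0 1 2 3 4 0 0 0]
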
/md4/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Main file for running the example.
17 |
18 | This file is intentionally kept short. The majority of the logic is in
19 | libraries that can be easily tested and imported in Colab.
20 | """
21 |
22 | from absl import app
23 | from absl import flags
24 | from absl import logging
25 | # Required import to set up work units when running through XManager.
26 | from clu import platform
27 | import jax
28 | from ml_collections import config_flags
29 | import tensorflow.compat.v2 as tf
30 |
31 | from md4 import sharded_train
32 | from md4 import train
33 |
34 |
35 | FLAGS = flags.FLAGS
36 |
37 | config_flags.DEFINE_config_file(
38 | "config", None, "Training configuration.", lock_config=True)
39 | flags.DEFINE_string("workdir", None, "Work unit directory.")
40 | flags.DEFINE_boolean("sharded", False, "Whether to use sharded training.")
41 | flags.mark_flags_as_required(["config", "workdir"])
42 | # Flags --jax_backend_target and --jax_xla_backend are available through JAX.
43 |
44 |
45 | def main(argv):
46 | del argv
47 |
48 | tf.enable_v2_behavior()
49 |   # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
50 | # it unavailable to JAX.
51 | tf.config.experimental.set_visible_devices([], "GPU")
52 |
53 | if FLAGS.jax_backend_target:
54 | logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
55 | jax_xla_backend = ("None" if FLAGS.jax_xla_backend is None else
56 | FLAGS.jax_xla_backend)
57 | logging.info("Using JAX XLA backend %s", jax_xla_backend)
58 |
59 | logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
60 | logging.info("JAX devices: %r", jax.devices())
61 |
62 | platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
63 | f"process_count: {jax.process_count()}")
64 | platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
65 | FLAGS.workdir, "workdir")
66 |
67 | if FLAGS.sharded:
68 | sharded_train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
69 | else:
70 | train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
71 |
72 |
73 | if __name__ == "__main__":
74 | # Provide access to --jax_backend_target and --jax_xla_backend flags.
75 | jax.config.config_with_absl()
76 | run_main = app.run
77 | run_main(main)
78 |
--------------------------------------------------------------------------------
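
Given the flags defined above, a typical OSS launch of the text8 config would
look like the following (the workdir path is illustrative):

    python -m md4.main \
        --config=md4/configs/md4/text8.py \
        --workdir=./workdir/text8_md4

Passing --sharded routes training through sharded_train.train_and_evaluate
instead of the default train.train_and_evaluate.
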
/md4/models/backward.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Classifier implementation."""
17 |
18 | from collections.abc import Sequence
19 |
20 | import flax.linen as nn
21 | import jax
22 | import jax.numpy as jnp
23 |
24 | from md4.networks import sharded_transformer
25 | from md4.networks import transformer
26 | from md4.networks import unet
27 | from md4.networks import uvit
28 |
29 |
30 | def get_timestep_embedding(timesteps, embedding_dim, dtype='float'):
31 | """Build sinusoidal embeddings."""
32 |
33 | assert embedding_dim > 2
34 | # timesteps: [bs]
35 | half_dim = embedding_dim // 2
36 | emb = jnp.log(10_000) / (half_dim - 1)
37 | emb = jnp.exp(jnp.arange(half_dim, dtype='float32') * -emb)
38 | emb = timesteps.astype('float32')[:, None] * emb[None, :]
39 | emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1)
40 | if embedding_dim % 2 == 1: # zero pad
41 | emb = jax.lax.pad(emb, jnp.array(0, dtype), ((0, 0, 0), (0, 1, 0)))
42 | # ret: [bs, embedding_dim]
43 | return emb
44 |
45 |
46 | class CondEmbedding(nn.Module):
47 | """Time and cond embeddings."""
48 |
49 | embedding_dim: int = 256
50 |
51 | @nn.compact
52 | def __call__(self, t, cond=None):
53 | # t: [bs]
54 | n_embd = self.embedding_dim
55 | temb = get_timestep_embedding(t, n_embd)
56 | if cond is None:
57 | cond = temb
58 | else:
59 | cond = jnp.concatenate([temb, cond], axis=-1)
60 | cond = nn.swish(nn.Dense(features=n_embd * 4, name='dense0')(cond))
61 | cond = nn.Dense(n_embd)(cond)
62 | return cond
63 |
64 |
65 | class UNet5DWrapper(nn.Module):
66 | """5D to 5D UNet wrapper."""
67 |
68 | feature_dim: int = 128
69 | n_layers: int = 32
70 | n_dit_layers: int = 0
71 | dit_num_heads: int = 12
72 | dit_hidden_size: int = 768
73 | ch_mult: Sequence[int] = (1,)
74 | output_channels: int = 256
75 | dropout_rate: float = 0.1
76 |
77 | @nn.compact
78 | def __call__(self, z, cond=None, train=False):
79 |     # [bs, h, w, c, d or |V|] -> [bs, h, w, c, output_channels]
80 | # Flatten the last two dimensions to pass to UNet
81 | h = z.reshape(list(z.shape)[:-2] + [-1])
82 |
83 | if self.n_dit_layers > 0:
84 | h = uvit.UNet(
85 | d_channels=self.feature_dim,
86 | n_layers=self.n_layers,
87 | n_dit_layers=self.n_dit_layers,
88 | dit_num_heads=self.dit_num_heads,
89 | dit_hidden_size=self.dit_hidden_size,
90 | ch_mult=self.ch_mult,
91 | output_channels=self.output_channels * z.shape[-2],
92 | dropout_rate=self.dropout_rate,
93 | )(h, cond=cond, train=train)
94 | else:
95 | h = unet.UNet(
96 | d_channels=self.feature_dim,
97 | n_layers=self.n_layers,
98 | output_channels=self.output_channels * z.shape[-2],
99 | dropout_rate=self.dropout_rate,
100 | )(h, cond=cond, train=train)
101 |
102 | # ret: [bs, h, w, c, output_channels]
103 | return h.reshape(list(z.shape)[:-1] + [self.output_channels])
104 |
105 |
106 | class DiscreteClassifier(nn.Module):
107 | """Discrete input classifier implementation."""
108 |
109 | n_layers: int = 12
110 | n_dit_layers: int = 0
111 | dit_num_heads: int = 12
112 | dit_hidden_size: int = 768
113 | ch_mult: Sequence[int] = (1,)
114 | feature_dim: int = 64
115 | num_heads: int = 12
116 | vocab_size: int = 1000
117 | dropout_rate: float = 0.0
118 | use_attn_dropout: bool = True
119 | mlp_type: str = 'swiglu'
120 | depth_scaled_init: bool = False
121 | cond_type: str = 'adaln'
122 | outside_embed: bool = False
123 | model_sharding: bool = False
124 |
125 | @nn.compact
126 | def __call__(self, z, t=None, cond=None, train=False):
127 | if t is not None:
128 | # z: [bs, seq_len] or [bs, h, w, c]
129 | assert jnp.isscalar(t) or t.ndim == 0 or t.ndim == 1
130 | t = t * jnp.ones(z.shape[0]) # ensure t is a vector
131 | cond = CondEmbedding(self.feature_dim)(t * 1000, cond=cond)
132 |
133 | if z.ndim == 2:
134 | if self.outside_embed:
135 | z = nn.Embed(self.vocab_size + 1, self.feature_dim)(z)
136 | if self.model_sharding:
137 | args = sharded_transformer.ModelArgs(
138 | dim=self.feature_dim * self.num_heads,
139 | n_layers=self.n_layers,
140 | n_heads=self.num_heads,
141 | n_kv_heads=self.num_heads,
142 | output_channels=self.vocab_size,
143 | multiple_of=32,
144 | dropout_rate=self.dropout_rate,
145 | depth_scaled_init=self.depth_scaled_init,
146 | mlp_type=self.mlp_type,
147 | cond_type=self.cond_type,
148 | embed_input=not self.outside_embed,
149 | n_embed_classes=self.vocab_size + 1,
150 | use_attn_dropout=self.use_attn_dropout,
151 | )
152 | # [bs, seq_len] -> [bs, seq_len, |V|]
153 | net = sharded_transformer.Transformer(args)
154 | else:
155 | args = transformer.ModelArgs(
156 | dim=self.feature_dim * self.num_heads,
157 | n_layers=self.n_layers,
158 | n_heads=self.num_heads,
159 | n_kv_heads=self.num_heads,
160 | output_channels=self.vocab_size,
161 | multiple_of=32,
162 | dropout_rate=self.dropout_rate,
163 | depth_scaled_init=self.depth_scaled_init,
164 | mlp_type=self.mlp_type,
165 | cond_type=self.cond_type,
166 | embed_input=not self.outside_embed,
167 | n_embed_classes=self.vocab_size + 1,
168 | )
169 | # [bs, seq_len] -> [bs, seq_len, |V|]
170 | net = transformer.Transformer(args)
171 | logits = net(z, cond=cond, train=train)
172 | elif z.ndim == 4:
173 | z = nn.Embed(self.vocab_size + 1, self.feature_dim)(z)
174 |
175 | # [bs, h, w, c, d] -> [bs, h, w, c, |V|]
176 | net = UNet5DWrapper(
177 | feature_dim=self.feature_dim,
178 | n_layers=self.n_layers,
179 | n_dit_layers=self.n_dit_layers,
180 | dit_num_heads=self.dit_num_heads,
181 | dit_hidden_size=self.dit_hidden_size,
182 | ch_mult=self.ch_mult,
183 | output_channels=self.vocab_size,
184 | dropout_rate=self.dropout_rate,
185 | )
186 | logits = net(z, cond=cond, train=train)
187 | else:
188 | raise NotImplementedError()
189 |
190 | return logits, {}
191 |
--------------------------------------------------------------------------------
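
get_timestep_embedding above is the standard sinusoidal construction: the
first half of each embedding holds sines and the second half cosines, with
per-dimension frequencies geometrically spaced from 1 down to 1/10000. A
quick shape and sanity check (import path assumed):

    import jax.numpy as jnp
    from md4.models.backward import get_timestep_embedding

    t = jnp.array([0.0, 250.0, 1000.0])
    emb = get_timestep_embedding(t, embedding_dim=128)
    print(emb.shape)     # (3, 128)
    print(emb[0, :64])   # all zeros: the sine half at t = 0
    print(emb[0, 64:])   # all ones: the cosine half at t = 0
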
/md4/models/diffusion/genmd4.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Generalized state-dependent masked diffusion (GenMD4)."""
17 |
18 | from typing import Sequence
19 |
20 | import flax.linen as nn
21 | import jax
22 | import jax.numpy as jnp
23 | import tensorflow_probability.substrates.jax as tfp
24 |
25 | from md4 import utils
26 | from md4.models import backward
27 |
28 |
29 | tfd = tfp.distributions
30 |
31 |
32 | class LearnableVecMaskingSchedule(nn.Module):
33 | """Learnable vector-valued masking schedule for GenMD4."""
34 |
35 | data_shape: tuple[int, ...]
36 | schedule_fn_type: str = 'poly'
37 | vocab_size: int = 256
38 | eps: float = 1e-4
39 | power_init: float = 1.0
40 |
41 | def setup(self):
42 | if self.schedule_fn_type == 'poly':
43 | w_init = jnp.log(jnp.exp(self.power_init) - 1.0)
44 | self.w = self.param('w', utils.constant_init(w_init), [self.vocab_size])
45 | # Reduce to MD4 with a shared scalar schedule:
46 | # self.w = self.param('w', utils.constant_init(w_init), [])
47 | # self.power = jnp.tile(nn.softplus(self.w)[..., None], [self.vocab_size])
48 | self.power = nn.softplus(self.w)
49 | else:
50 | raise NotImplementedError()
51 |
52 | def __call__(self, t):
53 | # return logSNR
54 | return jnp.log(self.alpha(t) / (1.0 - self.alpha(t)))
55 |
56 | def dalpha(self, t):
57 | if self.schedule_fn_type == 'poly':
58 | # ret: [..., |V|]
59 | return (
60 | -(1.0 - self.eps)
61 | * self.power
62 | * jnp.array(t)[..., None] ** (self.power - 1.0)
63 | )
64 | else:
65 | raise NotImplementedError()
66 |
67 | def alpha(self, t):
68 | if self.schedule_fn_type == 'poly':
69 |       # Instead of offsetting alpha_0 by eps as in the MD4 class, we set
70 |       # alpha_0 = 1 and use a small non-zero t1 to avoid numerical issues.
71 |       # This gives a nicer form of dgamma_times_alpha, which is -w/t for the
72 |       # polynomial schedule 1 - t**w.
73 | # ret: [..., |V|]
74 | return 1.0 - (1.0 - self.eps) * jnp.array(t)[..., None] ** self.power
75 | else:
76 | raise NotImplementedError()
77 |
78 | def dgamma_times_alpha(self, t):
79 | # ret: [..., |V|]
80 | return -self.power / jnp.array(t)[..., None]
81 |
82 |
83 | class GenMD4(nn.Module):
84 | """Generalized state-Dependent masked discrete diffusion model."""
85 |
86 | data_shape: tuple[int, ...]
87 | cont_time: bool = False
88 | timesteps: int = 1000
89 | feature_dim: int = 128
90 | num_heads: int = 12
91 | antithetic_time_sampling: bool = True
92 | n_layers: int = 32
93 | n_dit_layers: int = 0
94 | dit_num_heads: int = 12
95 | dit_hidden_size: int = 768
96 | ch_mult: Sequence[int] = (1,)
97 | vocab_size: int = 256
98 | noise_schedule_type: str = 'poly'
99 | power_init: float = 1.0
100 | t1: float = 1e-3
101 | dropout_rate: float = 0.0
102 | use_attn_dropout: bool = True
103 | mlp_type: str = 'swiglu'
104 | depth_scaled_init: bool = False
105 | cond_type: str = 'adaln'
106 | outside_embed: bool = False
107 | # time_features: t or none
108 | time_features: str = 't'
109 | classes: int = 10 + 1 # image classes
110 |
111 | def setup(self):
112 | self.noise_schedule = LearnableVecMaskingSchedule(
113 | self.data_shape,
114 | schedule_fn_type=self.noise_schedule_type,
115 | vocab_size=self.vocab_size,
116 | power_init=self.power_init,
117 | )
118 |
119 | if self.classes > 0:
120 | self.cond_embeddings = nn.Embed(self.classes, self.feature_dim)
121 | self.classifier = backward.DiscreteClassifier(
122 | n_layers=self.n_layers,
123 | n_dit_layers=self.n_dit_layers,
124 | dit_num_heads=self.dit_num_heads,
125 | dit_hidden_size=self.dit_hidden_size,
126 | ch_mult=self.ch_mult,
127 | feature_dim=self.feature_dim,
128 | num_heads=self.num_heads,
129 | vocab_size=self.vocab_size,
130 | dropout_rate=self.dropout_rate,
131 | use_attn_dropout=self.use_attn_dropout,
132 | mlp_type=self.mlp_type,
133 | depth_scaled_init=self.depth_scaled_init,
134 | cond_type=self.cond_type,
135 | outside_embed=self.outside_embed,
136 | )
137 |
138 | def forward_sample(self, x, t):
139 | t = utils.reverse_broadcast(t, x.ndim)
140 | # alpha_t: [bs, 1, |V|] or [bs, 1, 1, 1, |V|]
141 | a = self.noise_schedule.alpha(t)
142 | # un_mask_p: [bs, seq_len] or [bs, h, w, c]
143 | un_mask_p = jnp.sum(a * nn.one_hot(x, self.vocab_size), axis=-1)
144 | un_mask = jax.random.bernoulli(self.make_rng('sample'), un_mask_p, x.shape)
145 | # MASK = vocab_size
146 | return jnp.where(un_mask, x, self.vocab_size)
147 |
148 | def prior_sample(self, batch_size):
149 | return self.vocab_size * jnp.ones(
150 | [batch_size] + list(self.data_shape), dtype='int32'
151 | )
152 |
153 | def get_cond_embedding(self, conditioning):
154 | if conditioning is not None:
155 | return self.cond_embeddings(conditioning)
156 | return None
157 |
158 | def predict_x(self, zt, t, cond=None, train=False):
159 | t = None if self.time_features == 'none' else t
160 | return self.classifier(zt, t=t, cond=cond, train=train)
161 |
162 | def visualize_classifier(self, x, t, conditioning=None):
163 | # if it's image, x: [bs, h, w, c]
164 | # if it's text, x: [bs, seq_len]
165 | cond = self.get_cond_embedding(conditioning)
166 | # t: []
167 | # if it's image, zt: [bs, h, w, c]
168 | # if it's text, zt: [bs, seq_len]
169 | zt = self.forward_sample(x, t)
170 | # logits: [bs, h, w, c, vocab_size] for images
171 | # [bs, seq_len, vocab_size] for text
172 | logits, _ = self.predict_x(zt, t, cond=cond)
173 | n_indep_axes = logits.ndim - 2
174 | dist = tfd.Independent(tfd.Categorical(logits=logits), n_indep_axes)
175 | return dist
176 |
177 | def encode(self, x, conditioning=None):
178 | del conditioning
179 | return x
180 |
181 | def recon_loss(self, x):
182 | """The reconstruction loss measures the gap in the first step."""
183 | assert self.noise_schedule_type == 'poly'
184 | eps = self.noise_schedule.eps
185 | # w: [|V|]
186 | w = self.noise_schedule.power
187 | # w_x: [bs, seq_len] or [bs, h, w, c]
188 | w_x = jnp.sum(w * nn.one_hot(x, self.vocab_size), axis=-1)
189 | t = jnp.array(self.t1)
190 | # wlogt_x: [bs, seq_len] or [bs, h, w, c]
191 | wlogt_x = w_x * jnp.log(t)
192 | # wlogt: [|V|]
193 | wlogt = w * jnp.log(t)
194 | remaining_axis = list(range(x.ndim)[1:])
195 |     # loss_recon: [bs]
196 | loss_recon = (
197 | -(1 - eps) * jnp.exp(wlogt_x) * (wlogt_x - nn.logsumexp(wlogt, -1))
198 | ).sum(remaining_axis)
199 | return loss_recon
200 |
201 | def latent_loss(self):
202 | # negligible
203 | return jnp.array(0.0)
204 |
205 | def diffusion_loss(self, t, x, cond=None, train=False):
206 | assert self.cont_time
207 |
208 | # sample z_t
209 | zt = self.forward_sample(x, t)
210 | logits, _ = self.predict_x(zt, t, cond=cond, train=train)
211 | log_p = jax.nn.log_softmax(logits, axis=-1)
212 | one_hot_x = jax.nn.one_hot(x, self.vocab_size)
213 | neg_cross_ent = one_hot_x * log_p
214 | neg_cross_ent = jnp.where(one_hot_x, neg_cross_ent, 0.0)
215 | neg_cross_ent = jnp.sum(neg_cross_ent, axis=-1, keepdims=True)
216 | integrand = (neg_cross_ent + 1.0) * one_hot_x - jnp.exp(log_p)
217 |     mask = (zt == self.vocab_size).astype('float32')
218 |
219 | remaining_axis = list(range(x.ndim)[1:])
220 | # masked_neg_cross_ent: [bs, |V|]
221 | masked_neg_cross_ent = jnp.sum(mask[..., None] * integrand, remaining_axis)
222 |
223 | # cont-time loss
224 | loss_diff = (
225 | self.noise_schedule.dgamma_times_alpha(t) * masked_neg_cross_ent
226 | ).sum(axis=-1)
227 |
228 | # loss_diff: [bs]
229 | return loss_diff, zt
230 |
231 | def reinforce_loss(self, t, x, zt_1, zt_2, loss_diff_1, loss_diff_2):
232 | assert self.noise_schedule_type == 'poly'
233 | eps = self.noise_schedule.eps
234 | # w: [|V|]
235 | w = self.noise_schedule.power
236 | # w_x: [bs, seq_len] or [bs, h, w, c]
237 | w_x = jnp.sum(w * nn.one_hot(x, self.vocab_size), axis=-1)
238 | # t: [bs, 1] or [bs, 1, 1, 1]
239 | t = utils.reverse_broadcast(t, x.ndim)
240 | # alpha_t_x: [bs, seq_len] or [bs, h, w, c]
241 | alpha_t_x = 1.0 - (1.0 - eps) * t**w_x
242 | # log_q_mask = jnp.log(1.0 - alpha_t_x)
243 | log_q_mask = jnp.log(1.0 - eps) + w_x * jnp.log(t)
244 | log_q_unmask = jnp.log(alpha_t_x)
245 | log_q1 = jnp.where(zt_1 == self.vocab_size, log_q_mask, log_q_unmask)
246 | log_q2 = jnp.where(zt_2 == self.vocab_size, log_q_mask, log_q_unmask)
247 | remaining_axis = list(range(x.ndim)[1:])
248 | rloo_1 = (
249 | 0.5
250 | * jax.lax.stop_gradient(loss_diff_1 - loss_diff_2)
251 | * (log_q1.sum(remaining_axis) - log_q2.sum(remaining_axis))
252 | )
253 | return rloo_1
254 |
255 | @nn.compact
256 | def __call__(self, x, cond=None, train=False):
257 | bs = x.shape[0]
258 | cond = self.get_cond_embedding(cond)
259 |
260 | # 1. RECONSTRUCTION LOSS: []
261 | # add noise and reconstruct
262 | loss_recon = self.recon_loss(x).mean()
263 |
264 | # 2. LATENT LOSS: []
265 | loss_prior = self.latent_loss()
266 |
267 | # 3. DIFFUSION LOSS: [bs]
268 | # sample time steps
269 | rng1 = self.make_rng('sample')
270 | if self.antithetic_time_sampling:
271 | t0 = jax.random.uniform(rng1)
272 | t = jnp.mod(t0 + jnp.arange(0.0, 1.0, step=1.0 / bs), 1.0)
273 | else:
274 | t = jax.random.uniform(rng1, shape=[bs])
275 | # rescale t to be in [t1, 1.0]
276 | t = (1 - self.t1) * t + self.t1
277 |
278 | loss_diff_1, zt_1 = self.diffusion_loss(t, x, cond=cond, train=train)
279 | loss_diff_2, zt_2 = self.diffusion_loss(t, x, cond=cond, train=train)
280 | rloo_1 = self.reinforce_loss(t, x, zt_1, zt_2, loss_diff_1, loss_diff_2)
281 | loss_diff = 0.5 * (loss_diff_1 + loss_diff_2)
282 | loss_diff_sg = loss_diff + rloo_1
283 |
284 | # surrogate loss that includes the reinforce term
285 | loss = loss_diff_sg.mean() + loss_prior + loss_recon
286 | loss_diff = loss_diff.mean()
287 | # negative elbo
288 | loss_nelbo = loss_diff + loss_prior + loss_recon
289 |
290 | model_stats = {
291 | 'loss': loss,
292 | 'loss_nelbo': loss_nelbo,
293 | 'loss_diff': loss_diff,
294 | 'loss_prior': loss_prior,
295 | 'loss_recon': loss_recon,
296 | 'power_max': self.noise_schedule.power.max(),
297 | 'power_min': self.noise_schedule.power.min(),
298 | 'power_avg': self.noise_schedule.power.mean(),
299 | }
300 | model_stats = utils.loss2bpt(model_stats, self.data_shape)
301 | return model_stats
302 |
--------------------------------------------------------------------------------
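
The comment in LearnableVecMaskingSchedule.alpha notes that the polynomial
schedule 1 - (1 - eps) * t**w makes dgamma_times_alpha collapse to -w/t. A
quick numerical check of that identity for a scalar power w, mirroring the
alpha/dalpha definitions above:

    import jax.numpy as jnp

    eps, w = 1e-4, 2.0
    t = jnp.linspace(0.01, 0.99, 5)

    alpha = 1.0 - (1.0 - eps) * t**w            # alpha(t)
    dalpha = -(1.0 - eps) * w * t ** (w - 1.0)  # d(alpha)/dt
    # dalpha / (1 - alpha): the (1 - eps) * t**w factors cancel, leaving -w/t.
    print(dalpha / (1.0 - alpha))
    print(-w / t)  # identical up to floating-point error
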
/md4/models/diffusion/md4.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Simplified masked diffusion (MD4)."""
17 |
18 | import math
19 | from typing import Sequence
20 |
21 | import flax.linen as nn
22 | import jax
23 | import jax.numpy as jnp
24 | import tensorflow_probability.substrates.jax as tfp
25 |
26 | from md4 import binary_search
27 | from md4 import utils
28 | from md4.models import backward
29 |
30 |
31 | tfd = tfp.distributions
32 |
33 |
34 | class MaskingSchedule(nn.Module):
35 | """Masking noise schedule."""
36 |
37 | data_shape: tuple[int, ...]
38 | schedule_fn_type: str = 'cosine'
39 | eps: float = 1e-4
40 |
41 | def __call__(self, t):
42 | # return logSNR
43 | return jnp.log(self.alpha(t) / (1.0 - self.alpha(t)))
44 |
45 | def _dalpha(self, t):
46 | if self.schedule_fn_type == 'cosine':
47 | return -math.pi / 2.0 * jax.lax.sin(math.pi / 2.0 * (1.0 - t))
48 | elif self.schedule_fn_type == 'linear':
49 | return -jnp.ones_like(t)
50 | elif 'poly' in self.schedule_fn_type:
51 | exponent = float(self.schedule_fn_type.replace('poly', ''))
52 | return -exponent * t ** (exponent - 1.0)
53 | else:
54 | raise NotImplementedError()
55 |
56 | def dalpha(self, t):
57 | return (1.0 - 2 * self.eps) * self._dalpha(t)
58 |
59 | def _alpha(self, t):
60 | if self.schedule_fn_type == 'linear':
61 | return 1.0 - t
62 | elif 'poly' in self.schedule_fn_type:
63 | exponent = float(self.schedule_fn_type.replace('poly', ''))
64 | return 1.0 - t**exponent
65 | elif self.schedule_fn_type == 'cosine':
66 | return 1.0 - jax.lax.cos(math.pi / 2.0 * (1.0 - t))
67 | else:
68 | raise NotImplementedError()
69 |
70 | def alpha(self, t):
71 | return (1.0 - 2 * self.eps) * self._alpha(t) + self.eps
72 |
73 | def dgamma_times_alpha(self, t):
74 | return self.dalpha(t) / (1.0 - self.alpha(t))
75 |
76 |
77 | class MD4(nn.Module):
78 | """Simplified masked discrete diffusion model."""
79 |
80 | data_shape: tuple[int, ...]
81 | cont_time: bool = False
82 | timesteps: int = 1000
83 | feature_dim: int = 128
84 | num_heads: int = 12
85 | antithetic_time_sampling: bool = True
86 | n_layers: int = 32
87 | n_dit_layers: int = 0
88 | dit_num_heads: int = 12
89 | dit_hidden_size: int = 768
90 | ch_mult: Sequence[int] = (1,)
91 | vocab_size: int = 256
92 | noise_schedule_type: str = 'linear'
93 | dropout_rate: float = 0.0
94 | use_attn_dropout: bool = True
95 | mlp_type: str = 'swiglu'
96 | depth_scaled_init: bool = False
97 | cond_type: str = 'adaln'
98 | outside_embed: bool = False
99 | # time_features: t or none
100 | time_features: str = 't'
101 | classes: int = 10 + 1 # image classes
102 | sampler: str = 'analytic'
103 | # uniform, cosine
104 | sampling_grid: str = 'cosine'
105 | topp: float = 0.98
106 | model_sharding: bool = False
107 |
108 | def setup(self):
109 | self.noise_schedule = MaskingSchedule(
110 | self.data_shape, self.noise_schedule_type
111 | )
112 |
113 | if self.classes > 0:
114 | self.cond_embeddings = nn.Embed(self.classes, self.feature_dim)
115 | self.classifier = backward.DiscreteClassifier(
116 | n_layers=self.n_layers,
117 | n_dit_layers=self.n_dit_layers,
118 | dit_num_heads=self.dit_num_heads,
119 | dit_hidden_size=self.dit_hidden_size,
120 | ch_mult=self.ch_mult,
121 | feature_dim=self.feature_dim,
122 | num_heads=self.num_heads,
123 | vocab_size=self.vocab_size,
124 | dropout_rate=self.dropout_rate,
125 | use_attn_dropout=self.use_attn_dropout,
126 | mlp_type=self.mlp_type,
127 | depth_scaled_init=self.depth_scaled_init,
128 | cond_type=self.cond_type,
129 | outside_embed=self.outside_embed,
130 | model_sharding=self.model_sharding,
131 | )
132 |
133 | def forward_sample(self, x, t):
134 | t = utils.reverse_broadcast(t, x.ndim)
135 | a = self.noise_schedule.alpha(t)
136 | un_mask = jax.random.bernoulli(self.make_rng('sample'), a, x.shape)
137 | # MASK = vocab_size
138 | return jnp.where(un_mask, x, self.vocab_size)
139 |
140 | def prior_sample(self, batch_size):
141 | return self.vocab_size * jnp.ones(
142 | [batch_size] + list(self.data_shape), dtype='int32'
143 | )
144 |
145 | def get_cond_embedding(self, conditioning):
146 | if conditioning is not None:
147 | return self.cond_embeddings(conditioning)
148 | return None
149 |
150 | def predict_x(self, zt, t, cond=None, train=False):
151 | t = None if self.time_features == 'none' else t
152 | return self.classifier(zt, t=t, cond=cond, train=train)
153 |
154 | def visualize_classifier(self, x, t, conditioning=None):
155 | # if it's image, x: [bs, h, w, c]
156 | # if it's text, x: [bs, seq_len]
157 | cond = self.get_cond_embedding(conditioning)
158 | # t: []
159 | # if it's image, zt: [bs, h, w, c]
160 | # if it's text, zt: [bs, seq_len]
161 | zt = self.forward_sample(x, t)
162 | # logits: [bs, h, w, c, vocab_size] for images
163 | # [bs, seq_len, vocab_size] for text
164 | logits, _ = self.predict_x(zt, t, cond=cond)
165 | n_indep_axes = logits.ndim - 2
166 | dist = tfd.Independent(tfd.Categorical(logits=logits), n_indep_axes)
167 | return dist
168 |
169 | def encode(self, x, conditioning=None):
170 | del conditioning
171 | return x
172 |
173 | def decode(self, z0, conditioning=None):
174 | # Remove any mask tokens left in the last step of sampling.
175 | masked = z0 == self.vocab_size
176 |     z0_clipped = jnp.where(masked, jnp.zeros_like(z0), z0)
177 | masked = masked[..., None]
178 | cond = self.get_cond_embedding(conditioning)
179 | logits, _ = self.predict_x(z0, jnp.array(0.0), cond=cond)
180 | probs = jnp.where(
181 | masked,
182 | nn.softmax(logits, axis=-1),
183 |         jax.nn.one_hot(z0_clipped, self.vocab_size),
184 | )
185 | n_indep_axes = probs.ndim - 2
186 | dist = tfd.Independent(tfd.Categorical(probs=probs), n_indep_axes)
187 | return dist.mode().astype('int32')
188 |
189 | def recon_loss(self):
190 | """The reconstruction loss measures the gap in the first step."""
191 | alpha_t1 = self.noise_schedule.alpha(0.0)
192 | loss_recon = (
193 | jnp.prod(jnp.array(self.data_shape))
194 | * (1.0 - alpha_t1)
195 | * jnp.log(self.vocab_size)
196 | )
197 | return loss_recon
198 |
199 | def latent_loss(self):
200 | # negligible
201 | return jnp.array(0.0)
202 |
203 | def diffusion_loss(self, t, x, cond=None, train=False):
204 | if not self.cont_time:
205 | # discretize time steps
206 | t = (jnp.floor(t * self.timesteps) + 1) / self.timesteps
207 |
208 | # sample z_t
209 | zt = self.forward_sample(x, t)
210 | logits, _ = self.predict_x(zt, t, cond=cond, train=train)
211 | log_p = jax.nn.log_softmax(logits, axis=-1)
212 | one_hot_x = jax.nn.one_hot(x, self.vocab_size)
213 | neg_cross_ent = one_hot_x * log_p
214 | neg_cross_ent = jnp.where(one_hot_x, neg_cross_ent, 0.0)
215 | neg_cross_ent = jnp.sum(neg_cross_ent, axis=-1)
216 | mask = (zt == self.vocab_size).astype('float32')
217 |
218 |     remaining_axis = list(range(1, x.ndim))
219 | # masked_neg_cross_ent: [bs]
220 | masked_neg_cross_ent = jnp.sum(mask * neg_cross_ent, remaining_axis)
221 |
222 | if not self.cont_time:
223 | # loss for finite depth T, i.e. discrete time
224 | s = t - (1.0 / self.timesteps)
225 | gt = self.noise_schedule(t)
226 | gs = self.noise_schedule(s)
227 | loss_diff = (
228 | self.timesteps
229 | * jnp.expm1(gt - gs)
230 | * self.noise_schedule.alpha(s)
231 | * masked_neg_cross_ent
232 | )
233 | else:
234 | # cont-time loss
235 | loss_diff = (
236 | self.noise_schedule.dgamma_times_alpha(t) * masked_neg_cross_ent
237 | )
238 |
239 | # loss_diff: [bs]
240 | return loss_diff
241 |
242 | @nn.compact
243 | def __call__(self, x, cond=None, train=False):
244 | bs = x.shape[0]
245 | cond = self.get_cond_embedding(cond)
246 |
247 | # 1. RECONSTRUCTION LOSS: []
248 | # add noise and reconstruct
249 | loss_recon = self.recon_loss()
250 |
251 | # 2. LATENT LOSS: []
252 | loss_prior = self.latent_loss()
253 |
254 | # 3. DIFFUSION LOSS: [bs]
255 | # sample time steps
256 | rng1 = self.make_rng('sample')
257 | if self.antithetic_time_sampling:
258 | t0 = jax.random.uniform(rng1)
259 | t = jnp.mod(t0 + jnp.arange(0.0, 1.0, step=1.0 / bs), 1.0)
260 | else:
261 | t = jax.random.uniform(rng1, shape=[bs])
262 |
263 | loss_diff = self.diffusion_loss(t, x, cond=cond, train=train).mean()
264 | loss = loss_diff + loss_prior + loss_recon
265 |
266 | model_stats = {
267 | 'loss': loss,
268 | 'loss_diff': loss_diff,
269 | 'loss_prior': loss_prior,
270 | 'loss_recon': loss_recon,
271 | }
272 | model_stats = utils.loss2bpt(model_stats, self.data_shape)
273 | return model_stats
274 |
275 | def get_sampling_grid(self, i, timesteps):
276 | t = (timesteps - i) / timesteps
277 | s = t - 1 / timesteps
278 | if self.sampling_grid == 'cosine':
279 | t = jnp.cos(math.pi / 2.0 * (1.0 - t))
280 | s = jnp.cos(math.pi / 2.0 * (1.0 - s))
281 | return s, t
282 |
283 | def ancestral_sample_step(self, rng, i, timesteps, zt, conditioning=None):
284 | rng_body = jax.random.fold_in(rng, i)
285 | s, t = self.get_sampling_grid(i, timesteps)
286 | cond = self.get_cond_embedding(conditioning)
287 |
288 | alpha_t = self.noise_schedule.alpha(t)
289 | alpha_s = self.noise_schedule.alpha(s)
290 |
291 | logits, _ = self.predict_x(zt, t, cond=cond)
292 | mean_preds = jax.nn.softmax(logits, axis=-1)
293 |
294 |     unmask_prob = (alpha_s - alpha_t) / (1 - alpha_t)  # prob. a mask is revealed from t to s
295 | probs_vocab = unmask_prob * mean_preds
296 |
297 | probs_mask = jnp.ones(list(zt.shape) + [1]) * (1 - unmask_prob)
298 | probs = jnp.concatenate([probs_vocab, probs_mask], axis=-1)
299 |
300 | to_unmask = tfd.Categorical(probs=probs).sample(seed=rng_body)
301 | is_mask = zt == self.vocab_size
302 | zs = jnp.where(is_mask, to_unmask, zt)
303 | return zs
304 |
305 | def topp_sample_step(
306 | self, rng, i, timesteps, zt, conditioning=None, topp=0.98
307 | ):
308 | rng_body = jax.random.fold_in(rng, i)
309 | s, t = self.get_sampling_grid(i, timesteps)
310 | cond = self.get_cond_embedding(conditioning)
311 |
312 | alpha_t = self.noise_schedule.alpha(t)
313 | alpha_s = self.noise_schedule.alpha(s)
314 |
315 | logits, _ = self.predict_x(zt, t, cond=cond)
316 | logits = binary_search.topp_mask(logits, topp, replace_val=jnp.array(-1e7))
317 | # mean_preds: [bs, ..., vocab]
318 | mean_preds = jax.nn.softmax(logits, axis=-1)
319 |
320 | unmask_prob = (alpha_s - alpha_t) / (1 - alpha_t)
321 | probs_vocab = unmask_prob * mean_preds
322 |
323 | probs_mask = jnp.ones(list(zt.shape) + [1]) * (1 - unmask_prob)
324 | probs = jnp.concatenate([probs_vocab, probs_mask], axis=-1)
325 |
326 | to_unmask = tfd.Categorical(probs=probs).sample(seed=rng_body)
327 | is_mask = zt == self.vocab_size
328 | zs = jnp.where(is_mask, to_unmask, zt)
329 | return zs
330 |
331 | def mean_sample_step(self, rng, i, timesteps, zt, conditioning=None):
332 |     # Ancestral sampling done in two steps -- tends to be worse than the
333 |     # one-step implementation in ancestral_sample_step. See App. G of
334 | # https://arxiv.org/abs/2406.04329.
335 | rng_body = jax.random.fold_in(rng, i)
336 | s, t = self.get_sampling_grid(i, timesteps)
337 | cond = self.get_cond_embedding(conditioning)
338 |
339 | alpha_t = self.noise_schedule.alpha(t)
340 | alpha_s = self.noise_schedule.alpha(s)
341 |
342 | logits, _ = self.predict_x(zt, t, cond=cond)
343 | unmask_prob = (alpha_s - alpha_t) / (1 - alpha_t)
344 |
345 |     rng_body, rng = jax.random.split(rng_body)
346 |     z0 = tfd.Categorical(logits=logits).sample(seed=rng_body)
347 |     # Independently choose which masked positions get revealed with z0.
348 |     rng_body, _ = jax.random.split(rng)
349 |     unmask = jax.random.bernoulli(rng_body, unmask_prob, zt.shape)
350 |
351 | to_unmask = jnp.where(unmask, z0, zt)
352 | is_mask = zt == self.vocab_size
353 | zs = jnp.where(is_mask, to_unmask, zt)
354 | return zs
355 |
356 | def sample_step(self, rng, i, timesteps, zt, conditioning=None, topp=None):
357 | if self.sampler == 'ancestral':
358 | return self.ancestral_sample_step(
359 | rng, i, timesteps, zt, conditioning=conditioning
360 | )
361 | elif self.sampler == 'topp':
362 | topp = self.topp if topp is None else topp
363 | return self.topp_sample_step(
364 | rng, i, timesteps, zt, conditioning=conditioning, topp=topp
365 | )
366 | elif self.sampler == 'mean':
367 | return self.mean_sample_step(
368 | rng, i, timesteps, zt, conditioning=conditioning
369 | )
370 | else:
371 | raise NotImplementedError()
372 |
--------------------------------------------------------------------------------
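
The samplers above share one update: between times t and s, each still-masked position is revealed with probability (alpha_s - alpha_t) / (1 - alpha_t), and a revealed position draws its token from the model's softmax. A minimal self-contained sketch of that step, substituting a plain linear schedule alpha(t) = 1 - t for MaskingSchedule (so the eps clamping is dropped) and jax.random.categorical for tfd.Categorical:

import jax
import jax.numpy as jnp

VOCAB = 4    # toy vocabulary; as in MD4, the MASK token id equals vocab_size
MASK = VOCAB

def ancestral_step(rng, zt, logits, s, t):
  # With alpha(t) = 1 - t, (alpha_s - alpha_t) / (1 - alpha_t) = (t - s) / t.
  unmask_prob = (t - s) / t
  # Mix the model's distribution over real tokens with a "stay masked" slot.
  probs_vocab = unmask_prob * jax.nn.softmax(logits, axis=-1)
  probs_mask = (1.0 - unmask_prob) * jnp.ones(zt.shape + (1,))
  probs = jnp.concatenate([probs_vocab, probs_mask], axis=-1)
  to_unmask = jax.random.categorical(rng, jnp.log(probs), axis=-1)
  # Already-revealed tokens are frozen; only masked positions are resampled.
  return jnp.where(zt == MASK, to_unmask, zt)

rng = jax.random.PRNGKey(0)
zt = jnp.array([MASK, 2, MASK, 1])           # two positions still masked
logits = jax.random.normal(rng, (4, VOCAB))  # stand-in for predict_x logits
zs = ancestral_step(rng, zt, logits, s=0.5, t=0.75)
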
/md4/models/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Model utils."""
17 |
18 | import ml_collections
19 |
20 | from md4.models.diffusion import genmd4
21 | from md4.models.diffusion import md4
22 |
23 |
24 | def get_model(config: ml_collections.ConfigDict):
25 | """Get model instances."""
26 | if config.model_type == "md4":
27 | return md4.MD4(
28 | config.data_shape,
29 | cont_time=config.cont_time,
30 | timesteps=config.timesteps,
31 | feature_dim=config.feature_dim,
32 | num_heads=config.num_heads,
33 | n_layers=config.n_layers,
34 | n_dit_layers=config.n_dit_layers,
35 | dit_num_heads=config.dit_num_heads,
36 | dit_hidden_size=config.dit_hidden_size,
37 | ch_mult=config.ch_mult,
38 | vocab_size=config.vocab_size,
39 | noise_schedule_type=config.noise_schedule,
40 | dropout_rate=config.dropout_rate,
41 | use_attn_dropout=config.get("use_attn_dropout", True),
42 | mlp_type=config.mlp_type,
43 | depth_scaled_init=config.depth_scaled_init,
44 | cond_type=config.cond_type,
45 | outside_embed=config.outside_embed,
46 | time_features=config.time_features,
47 | classes=config.classes,
48 | sampler=config.sampler,
49 | sampling_grid=config.sampling_grid,
50 | topp=config.topp,
51 | model_sharding=config.get("model_sharding", False),
52 | )
53 | elif config.model_type == "genmd4":
54 | return genmd4.GenMD4(
55 | config.data_shape,
56 | cont_time=config.cont_time,
57 | timesteps=config.timesteps,
58 | feature_dim=config.feature_dim,
59 | num_heads=config.num_heads,
60 | n_layers=config.n_layers,
61 | n_dit_layers=config.n_dit_layers,
62 | dit_num_heads=config.dit_num_heads,
63 | dit_hidden_size=config.dit_hidden_size,
64 | ch_mult=config.ch_mult,
65 | vocab_size=config.vocab_size,
66 | noise_schedule_type=config.noise_schedule,
67 | power_init=config.power_init,
68 | dropout_rate=config.dropout_rate,
69 | use_attn_dropout=config.get("use_attn_dropout", True),
70 | mlp_type=config.mlp_type,
71 | depth_scaled_init=config.depth_scaled_init,
72 | cond_type=config.cond_type,
73 | outside_embed=config.outside_embed,
74 | time_features=config.time_features,
75 | )
76 | else:
77 | raise NotImplementedError()
78 |
--------------------------------------------------------------------------------
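
get_model expects a fully populated ConfigDict; the files under md4/configs/ provide them. A hedged usage sketch, assuming the config modules follow the usual ml_collections get_config() convention that the training scripts rely on:

import jax
import jax.numpy as jnp

from md4.configs.md4 import text8
from md4.models import utils as model_utils

config = text8.get_config()            # ConfigDict for the text8 experiment
model = model_utils.get_model(config)

# Flax modules build parameters lazily: initializing with a dummy batch of
# the configured shape yields the full variable tree. MD4 also draws from a
# 'sample' RNG stream inside __call__, so one is supplied here.
dummy = jnp.zeros((2,) + tuple(config.data_shape), dtype='int32')
rngs = {'params': jax.random.PRNGKey(0), 'sample': jax.random.PRNGKey(1)}
variables = model.init(rngs, dummy, train=False)
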
/md4/multihost_dataloading.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Multihost dataloading utilities.
17 |
18 | Adapted from:
19 | https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/multihost_dataloading.py
20 | """
21 |
22 | from collections.abc import Iterable, Iterator
23 | import functools
24 | import time
25 | from typing import Union
26 |
27 | import grain.python as grain
28 | import jax
29 | from jax.sharding import Mesh
30 | from jax.sharding import NamedSharding
31 | from jax.sharding import PartitionSpec
32 | import jax.tree_util as jtu
33 | import numpy as np
34 | import tensorflow as tf # pylint: disable=g-import-not-at-top
35 |
36 |
37 | def _build_global_shape_and_sharding(
38 | local_shape: tuple[int, ...], global_mesh: Mesh
39 | ) -> tuple[tuple[int, ...], NamedSharding]:
40 | sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
41 |
42 | global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
43 |
44 | return global_shape, sharding
45 |
46 |
47 | def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
48 |   """Puts a locally sharded array onto the local devices."""
49 | global_shape, sharding = _build_global_shape_and_sharding(
50 | np.shape(array), global_mesh
51 | )
52 |
53 | try:
54 | local_device_arrays = np.split(
55 | array, len(global_mesh.local_devices), axis=0
56 | )
57 | except ValueError as array_split_error:
58 | raise ValueError(
59 |         f"Unable to split array of shape {array.shape} evenly across "
60 |         f"{len(global_mesh.local_devices)} local devices "
61 |         f"at {jtu.keystr(path)}"
62 | ) from array_split_error
63 |
64 | local_device_buffers = jax.device_put(
65 | local_device_arrays, global_mesh.local_devices
66 | )
67 | return jax.make_array_from_single_device_arrays(
68 | global_shape, sharding, local_device_buffers
69 | )
70 |
71 |
72 | def get_next_batch_sharded(
73 | local_iterator: Iterator[jax.Array], global_mesh: Mesh
74 | ) -> jax.Array:
75 |   """Splits the host-loaded data equally over all devices."""
76 |
77 | sleep_time = 10
78 | max_data_load_attempts = 30
79 |
80 | data_load_attempts = 0
81 | loaded_data_success = False
82 | while not loaded_data_success and data_load_attempts < max_data_load_attempts:
83 | data_load_attempts += 1
84 | try:
85 | local_data = next(local_iterator)
86 | loaded_data_success = True
87 | except tf.errors.FailedPreconditionError:
88 | print("Failed to get next data batch, retrying")
89 | time.sleep(sleep_time)
90 |
91 | # Try one last time, if this fails we will see the full stack trace.
92 | if not loaded_data_success:
93 | local_data = next(local_iterator)
94 |
95 | input_gdas = jtu.tree_map_with_path(
96 | functools.partial(_form_global_array, global_mesh=global_mesh), local_data
97 | )
98 |
99 | return input_gdas
100 |
101 |
102 | class MultiHostDataLoadIterator:
103 |   """Folds get_next_batch_sharded into an iterator class."""
104 |
105 | def __init__(
106 | self,
107 | dataloader: Union[tf.data.Dataset, grain.DataLoader],
108 | global_mesh: Mesh,
109 | ):
110 | self.global_mesh = global_mesh
111 | self.dataloader = dataloader
112 | if isinstance(self.dataloader, tf.data.Dataset):
113 | self.local_iterator = self.dataloader.as_numpy_iterator()
114 | elif isinstance(self.dataloader, Iterable):
115 | self.local_iterator = iter(self.dataloader)
116 | else:
117 | raise ValueError(
118 |           "Type error: dataloader should be either tf.data.Dataset or grain.DataLoader."
119 | )
120 |
121 | def reset(self):
122 | if isinstance(self.dataloader, tf.data.Dataset):
123 | self.local_iterator = self.dataloader.as_numpy_iterator()
124 | elif isinstance(self.dataloader, Iterable):
125 | self.local_iterator = iter(self.dataloader)
126 | else:
127 | raise ValueError(
128 | "Type error: dataloader should be either tf.data.Dataset or"
129 | " grain.DataLoader."
130 | )
131 |
132 | def __iter__(self):
133 | self.reset()
134 | return self
135 |
136 | def __next__(self):
137 | return get_next_batch_sharded(self.local_iterator, self.global_mesh)
138 |
--------------------------------------------------------------------------------
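
A single-process sketch of how this iterator is used: build a one-axis mesh over the local devices, wrap a tf.data.Dataset, and pull globally sharded batches. The per-host batch size must divide evenly by the local device count, or _form_global_array raises the ValueError above. Shapes and names here are illustrative:

import jax
import numpy as np
import tensorflow as tf
from jax.sharding import Mesh

from md4 import multihost_dataloading

# One mesh axis spanning all local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))

# Per-host batch of 8 examples; 8 must split evenly across local devices.
ds = tf.data.Dataset.from_tensor_slices(
    {'x': np.arange(64, dtype=np.int32).reshape(16, 4)}
).batch(8)

it = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
batch = next(iter(it))  # dict of jax.Array, sharded along the 'data' axis
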
/md4/networks/dit.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """DiT architecture.
17 |
18 | Jax implementation of https://arxiv.org/abs/2212.09748, based on PyTorch
19 | implementation https://github.com/facebookresearch/DiT/blob/main/models.py.
20 | """
21 |
22 | import math
23 | from typing import Any
24 |
25 | import flax.linen as nn
26 | import jax.numpy as jnp
27 | import numpy as np
28 |
29 | from md4.networks import transformer
30 | # pylint: disable=missing-class-docstring
31 |
32 |
33 | def modulate(x, shift, scale):
34 | return x * (1 + scale) + shift
35 |
36 |
37 | class PatchEmbed(nn.Module):
38 | """2D Image to Patch Embedding."""
39 |
40 | img_size: int = 224
41 | patch_size: int = 16
42 | embed_dim: int = 768
43 | flatten: bool = True
44 | use_bias: bool = True
45 |
46 | def setup(self):
47 | self.proj = nn.Conv(
48 | self.embed_dim,
49 | kernel_size=(self.patch_size, self.patch_size),
50 | strides=self.patch_size,
51 | padding='VALID',
52 | use_bias=self.use_bias,
53 | )
54 |
55 | def __call__(self, x):
56 | x = self.proj(x)
57 | if self.flatten:
58 | x = x.reshape(x.shape[0], -1, x.shape[-1])
59 | return x
60 |
61 |
62 | class Mlp(nn.Module):
63 | """MLP as used in Vision Transformer, MLP-Mixer and related networks."""
64 |
65 | out_features: int
66 | hidden_features: int
67 | act: Any = nn.gelu
68 | use_bias: bool = True
69 | dropout_rate: float = 0.0
70 |
71 | def setup(self):
72 | self.fc1 = nn.Dense(self.hidden_features, use_bias=self.use_bias)
73 | self.drop1 = nn.Dropout(self.dropout_rate)
74 | self.fc2 = nn.Dense(self.out_features, use_bias=self.use_bias)
75 | self.drop2 = nn.Dropout(self.dropout_rate)
76 |
77 | def __call__(self, x, train=False):
78 | x = self.fc1(x)
79 | x = self.act(x)
80 | x = self.drop1(x, deterministic=not train)
81 | x = self.fc2(x)
82 | x = self.drop2(x, deterministic=not train)
83 | return x
84 |
85 |
86 | class Attention(nn.Module):
87 |
88 | dim: int
89 | n_heads: int
90 | n_kv_heads: int | None = None
91 | dropout_rate: float = 0.0
92 | qkv_bias: bool = False
93 |
94 | def setup(self):
95 | self._n_kv_heads = (
96 | self.n_heads if self.n_kv_heads is None else self.n_kv_heads
97 | )
98 | assert self.n_heads % self._n_kv_heads == 0
99 | self.n_rep = self.n_heads // self._n_kv_heads
100 | self.head_dim = self.dim // self.n_heads
101 | self.wq = nn.Dense(self.n_heads * self.head_dim, use_bias=self.qkv_bias)
102 | self.wk = nn.Dense(self._n_kv_heads * self.head_dim, use_bias=self.qkv_bias)
103 | self.wv = nn.Dense(self._n_kv_heads * self.head_dim, use_bias=self.qkv_bias)
104 | self.wo = nn.Dense(self.dim, use_bias=False)
105 | self.attn_dropout = nn.Dropout(self.dropout_rate)
106 | # self.resid_dropout = nn.Dropout(self.dropout_rate)
107 | self.resid_dropout = transformer.Dropout1d(self.dropout_rate)
108 |
109 | def __call__(self, x, train=False):
110 | bsz, seqlen, _ = x.shape
111 |
112 | # QKV
113 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
114 | xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim)
115 | xk = xk.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim)
116 | xv = xv.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim)
117 |
118 | # grouped multiquery attention: expand out keys and values
119 | xk = transformer.repeat_kv(xk, self.n_rep)
120 | xv = transformer.repeat_kv(xv, self.n_rep)
121 |
122 | # make heads into a batch dimension
123 | xq = xq.swapaxes(1, 2) # (bs, n_heads, seqlen, head_dim)
124 | xk = xk.swapaxes(1, 2)
125 | xv = xv.swapaxes(1, 2)
126 |
127 | scores = jnp.matmul(xq, xk.swapaxes(2, 3)) / math.sqrt(self.head_dim)
128 | scores = nn.softmax(scores, axis=-1)
129 | scores = self.attn_dropout(scores, deterministic=not train)
130 | output = jnp.matmul(scores, xv) # (bs, n_heads, seqlen, head_dim)
131 |
132 | # restore time as batch dimension and concat heads
133 | output = output.swapaxes(1, 2).reshape(bsz, seqlen, -1)
134 |
135 | # final projection into the residual stream
136 | output = self.wo(output)
137 | output = self.resid_dropout(output, deterministic=not train)
138 | return output
139 |
140 |
141 | class DiTBlock(nn.Module):
142 | """A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning."""
143 |
144 | hidden_size: int
145 | num_heads: int
146 | mlp_ratio: float = 4.0
147 | dropout_rate: float = 0.0
148 |
149 | def setup(self):
150 | self.attn = Attention(
151 | self.hidden_size,
152 | self.num_heads,
153 | dropout_rate=self.dropout_rate,
154 | qkv_bias=True,
155 | )
156 | mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio)
157 | self.mlp = Mlp(
158 | out_features=self.hidden_size,
159 | hidden_features=mlp_hidden_dim,
160 | act=nn.gelu,
161 | dropout_rate=self.dropout_rate,
162 | )
163 |
164 | @nn.compact
165 | def __call__(self, x, cond=None, train=False):
166 | if cond is not None:
167 | adaln_modulation = nn.Sequential([
168 | nn.swish,
169 | nn.Dense(
170 | 6 * self.hidden_size,
171 | kernel_init=nn.zeros_init(),
172 | bias_init=nn.zeros_init(),
173 | ),
174 | ])
175 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
176 | jnp.split(adaln_modulation(cond)[:, None, :], 6, axis=-1)
177 | )
178 | norm1 = nn.LayerNorm(use_bias=False, use_scale=False)
179 | norm2 = nn.LayerNorm(use_bias=False, use_scale=False)
180 | x = x + gate_msa * self.attn(
181 | modulate(norm1(x), shift_msa, scale_msa), train=train
182 | )
183 | x = x + gate_mlp * self.mlp(
184 | modulate(norm2(x), shift_mlp, scale_mlp), train=train
185 | )
186 | else:
187 | x = x + self.attn(nn.RMSNorm()(x), train=train)
188 | x = x + self.mlp(nn.RMSNorm()(x), train=train)
189 | return x
190 |
191 |
192 | class FinalLayer(nn.Module):
193 | """The final layer of DiT."""
194 |
195 | hidden_size: int
196 | patch_size: int
197 | out_channels: int
198 |
199 | def setup(self):
200 | self.linear = nn.Dense(
201 | self.patch_size * self.patch_size * self.out_channels,
202 | kernel_init=nn.zeros_init(),
203 | bias_init=nn.zeros_init(),
204 | )
205 |
206 | @nn.compact
207 | def __call__(self, x, cond=None):
208 | if cond is not None:
209 | adaln_modulation = nn.Sequential([
210 | nn.swish,
211 | nn.Dense(
212 | 2 * self.hidden_size,
213 | kernel_init=nn.zeros_init(),
214 | bias_init=nn.zeros_init(),
215 | ),
216 | ])
217 | shift, scale = jnp.split(adaln_modulation(cond)[:, None, :], 2, axis=-1)
218 | norm_final = nn.LayerNorm(use_bias=False, use_scale=False)
219 | x = modulate(norm_final(x), shift, scale)
220 | else:
221 | x = nn.RMSNorm()(x)
222 | x = self.linear(x)
223 | return x
224 |
225 |
226 | class DiT(nn.Module):
227 | """Diffusion model with a Transformer backbone."""
228 |
229 | img_size: int
230 | patch_size: int
231 | in_channels: int
232 | out_channels: int
233 | hidden_size: int
234 | depth: int
235 | num_heads: int
236 | mlp_ratio: float = 4.0
237 | dropout_rate: float = 0.0
238 |
239 | def setup(self):
240 | self.x_embedder = PatchEmbed(
241 | img_size=self.img_size,
242 | patch_size=self.patch_size,
243 | embed_dim=self.hidden_size,
244 | use_bias=True,
245 | )
246 | self.grid_size = self.img_size // self.patch_size
247 | num_patches = self.grid_size * self.grid_size
248 | self.pos_embed = self.param(
249 | 'pos_embed',
250 | lambda k, s: get_2d_sincos_pos_embed(s[-1], int(num_patches**0.5)),
251 | [num_patches, self.hidden_size],
252 | )
253 | self.blocks = [
254 | DiTBlock(
255 | self.hidden_size,
256 | self.num_heads,
257 | mlp_ratio=self.mlp_ratio,
258 | dropout_rate=self.dropout_rate,
259 | )
260 | for _ in range(self.depth)
261 | ]
262 | self.final_layer = FinalLayer(
263 | self.hidden_size, self.patch_size, self.out_channels
264 | )
265 |
266 | def __call__(self, x, cond=None, train=False):
267 | c = x.shape[-1]
268 | p = self.patch_size
269 | grid_size = self.grid_size
270 | x = (
271 | self.x_embedder(x) + self.pos_embed
272 | ) # (N, T, D), where T = H * W / p ** 2
273 | for block in self.blocks:
274 | x = block(x, cond=cond, train=train) # (N, T, D)
275 | x = self.final_layer(x, cond=cond) # (N, T, p ** 2 * c)
276 | x = x.reshape(-1, grid_size, grid_size, p, p, c)
277 | x = jnp.einsum('nhwpqc->nhpwqc', x)
278 | x = x.reshape(-1, grid_size * p, grid_size * p, c) # (N, H, W, c)
279 | return x
280 |
281 |
282 | def get_2d_sincos_pos_embed(embed_dim, grid_size):
283 | """2D sin-cos position embedding."""
284 | grid_h = np.arange(grid_size, dtype=np.float32)
285 | grid_w = np.arange(grid_size, dtype=np.float32)
286 | grid = np.meshgrid(grid_w, grid_h)
287 | grid = np.stack(grid, axis=0)
288 |
289 | grid = grid.reshape([2, 1, grid_size, grid_size])
290 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
291 | return pos_embed
292 |
293 |
294 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
295 | """Gets 2D sin-cos position embedding from grid."""
296 | assert embed_dim % 2 == 0
297 |
298 | # use half of dimensions to encode grid_h
299 | emb_h = get_1d_sincos_pos_embed_from_grid(
300 | embed_dim // 2, grid[0]
301 | ) # (H*W, D/2)
302 | emb_w = get_1d_sincos_pos_embed_from_grid(
303 | embed_dim // 2, grid[1]
304 | ) # (H*W, D/2)
305 |
306 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
307 | return emb
308 |
309 |
310 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
311 | """Gets 1D sin-cos position embedding from grid."""
312 | assert embed_dim % 2 == 0
313 | omega = np.arange(embed_dim // 2, dtype=np.float64)
314 | omega /= embed_dim / 2.0
315 | omega = 1.0 / 10000**omega # (D/2,)
316 |
317 | pos = pos.reshape(-1) # (M,)
318 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
319 |
320 | emb_sin = np.sin(out) # (M, D/2)
321 | emb_cos = np.cos(out) # (M, D/2)
322 |
323 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
324 | return emb
325 |
--------------------------------------------------------------------------------
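
A toy shape check for the DiT above: 32x32x3 inputs with 4x4 patches give an 8x8 grid of 64 tokens, and FinalLayer folds each token back into a 4x4x3 patch. The hyperparameters are illustrative only:

import jax
import jax.numpy as jnp

from md4.networks.dit import DiT

model = DiT(img_size=32, patch_size=4, in_channels=3, out_channels=3,
            hidden_size=128, depth=2, num_heads=4)

x = jnp.zeros((2, 32, 32, 3))
cond = jnp.zeros((2, 128))  # per-example conditioning vector
params = model.init(jax.random.PRNGKey(0), x, cond=cond)
y = model.apply(params, x, cond=cond)
assert y.shape == (2, 32, 32, 3)
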
/md4/networks/sharded_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Jax implementation of LLAMA2-like Transformer with model sharding.
17 |
18 | Based on PyTorch implementation
19 | https://github.com/karpathy/llama2.c/blob/master/model.py
20 | """
21 |
22 | import dataclasses
23 | import math
24 | from typing import Optional
25 |
26 | import flax.linen as nn
27 | import jax
28 | import jax.numpy as jnp
29 | # pylint: disable=missing-class-docstring
30 | # pylint: disable=missing-function-docstring
31 |
32 |
33 | activation_map = dict(
34 | swiglu=nn.swish,
35 | geglu=nn.gelu,
36 | glu=nn.sigmoid,
37 | )
38 |
39 |
40 | @dataclasses.dataclass(unsafe_hash=True)
41 | class ModelArgs:
42 | dim: int = 288
43 | n_layers: int = 6
44 | n_heads: int = 6
45 | n_kv_heads: Optional[int] = None
46 | output_channels: int = 1024
47 | hidden_dim: Optional[int] = None
48 |   multiple_of: int = 32  # MLP hidden layer size is rounded up to a multiple of this
49 | norm_eps: float = 1e-5
50 | dropout_rate: float = 0.0
51 | use_attn_dropout: bool = True
52 | weight_tying: bool = False
53 | w_init_scale: float = 1.0
54 | depth_scaled_init: bool = False
55 | # glu, geglu, swiglu
56 | mlp_type: str = 'swiglu'
57 | # adaln, adaln_zero
58 | cond_type: str = 'adaln'
59 | embed_input: bool = False
60 | n_embed_classes: int = 1024
61 | causal: bool = False
62 |
63 |
64 | class RMSNorm(nn.Module):
65 |
66 | dim: int
67 | eps: float
68 |
69 | def setup(self):
70 | self.scale = self.param(
71 | 'scale', lambda key, shape: jnp.ones(shape), (self.dim,)
72 | )
73 |
74 | def _norm(self, x):
75 | return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
76 |
77 | def __call__(self, x):
78 | output = self._norm(x)
79 | return output * self.scale
80 |
81 |
82 | def precompute_freqs_cis(dim, end, theta: float = 10000.0):
83 | freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)] / dim))
84 | t = jnp.arange(end)
85 | freqs = jnp.outer(t, freqs)
86 | freqs_cos = jnp.cos(freqs)
87 | freqs_sin = jnp.sin(freqs)
88 | return freqs_cos, freqs_sin
89 |
90 |
91 | def reshape_for_broadcast(freqs_cis, x):
92 | ndim = x.ndim
93 | assert 1 < ndim
94 | assert freqs_cis.shape == (x.shape[1], x.shape[-1])
95 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
96 | return freqs_cis.reshape(shape)
97 |
98 |
99 | def jax_unstack(x, axis=0):
100 | return [
101 | jax.lax.index_in_dim(x, i, axis, keepdims=False)
102 | for i in range(x.shape[axis])
103 | ]
104 |
105 |
106 | def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
107 | # reshape xq and xk to match the complex representation
108 | # [bs, seq_len, n_head, head_dim // 2]
109 | xq_r, xq_i = jax_unstack(xq.reshape(xq.shape[:-1] + (-1, 2)), -1)
110 | xk_r, xk_i = jax_unstack(xk.reshape(xk.shape[:-1] + (-1, 2)), -1)
111 |
112 | # reshape freqs_cos and freqs_sin for broadcasting
113 | # [1, seq_len, 1, head_dim // 2]
114 | freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
115 | freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
116 |
117 | # apply rotation using real numbers
118 | xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
119 | xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
120 | xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
121 | xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
122 |
123 | # flatten last two dimensions
124 | # [bs, seq_len, n_head, head_dim // 2, 2] -> [bs, seq_len, n_head, head_dim]
125 | xq_out = jnp.stack([xq_out_r, xq_out_i], axis=-1).reshape(
126 | xq_out_r.shape[:3] + (-1,)
127 | )
128 | xk_out = jnp.stack([xk_out_r, xk_out_i], axis=-1).reshape(
129 | xk_out_r.shape[:3] + (-1,)
130 | )
131 |
132 | return xq_out, xk_out
133 |
134 |
135 | def repeat_kv(x, n_rep):
136 | bs, slen, n_kv_heads, head_dim = x.shape
137 | if n_rep == 1:
138 | return x
139 | return jnp.tile(x[:, :, :, None, :], [1, 1, 1, n_rep, 1]).reshape(
140 | bs, slen, n_kv_heads * n_rep, head_dim
141 | )
142 |
143 |
144 | class Dropout1d(nn.Module):
145 |
146 | dropout_rate: float = 0.0
147 |
148 | def __call__(self, x, deterministic=True):
149 | if (self.dropout_rate > 0.0) and not deterministic:
150 | drop = jax.random.bernoulli(
151 | self.make_rng('dropout'),
152 | 1 - self.dropout_rate,
153 | (x.shape[0], 1, x.shape[-1]),
154 | )
155 | x = x * drop / (1 - self.dropout_rate)
156 | return x
157 |
158 |
159 | class Attention(nn.Module):
160 |
161 | dim: int
162 | n_heads: int
163 | n_kv_heads: int | None = None
164 | dropout_rate: float = 0.0
165 | causal: bool = False
166 | qkv_bias: bool = False
167 | use_attn_dropout: bool = True
168 |
169 | def setup(self):
170 | self._n_kv_heads = (
171 | self.n_heads if self.n_kv_heads is None else self.n_kv_heads
172 | )
173 | assert self.n_heads % self._n_kv_heads == 0
174 | self.n_rep = self.n_heads // self._n_kv_heads
175 | self.head_dim = self.dim // self.n_heads
176 | self.wq = nn.Dense(
177 | self.n_heads * self.head_dim,
178 | use_bias=self.qkv_bias,
179 | kernel_init=nn.with_logical_partitioning(
180 | nn.linear.default_kernel_init, ('embed', 'qkv')
181 | ),
182 |     )  # logical axes: e.g. 'embed' -> fsdp, 'qkv' -> tensor parallelism
183 | self.wk = nn.Dense(
184 | self._n_kv_heads * self.head_dim,
185 | use_bias=self.qkv_bias,
186 | kernel_init=nn.with_logical_partitioning(
187 | nn.linear.default_kernel_init, ('embed', 'qkv')
188 | ),
189 | )
190 | self.wv = nn.Dense(
191 | self._n_kv_heads * self.head_dim,
192 | use_bias=self.qkv_bias,
193 | kernel_init=nn.with_logical_partitioning(
194 | nn.linear.default_kernel_init, ('embed', 'qkv')
195 | ),
196 | )
197 | self.wo = nn.Dense(
198 | self.dim,
199 | use_bias=False,
200 | kernel_init=nn.with_logical_partitioning(
201 | nn.linear.default_kernel_init, ('qkv', 'embed')
202 | )
203 | )
204 | if self.use_attn_dropout and self.dropout_rate > 0.0:
205 | self.attn_dropout = nn.Dropout(self.dropout_rate)
206 | if self.dropout_rate > 0.0:
207 | self.resid_dropout = Dropout1d(self.dropout_rate)
208 |
209 | def __call__(self, x, freqs_cos, freqs_sin, train=False):
210 | bsz, seqlen, _ = x.shape
211 |
212 | # QKV
213 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
214 | xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim)
215 | xk = xk.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim)
216 | xv = xv.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim)
217 |
218 | # RoPE relative positional embeddings
219 | xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
220 |
221 | # grouped multiquery attention: expand out keys and values
222 | xk = repeat_kv(xk, self.n_rep)
223 | xv = repeat_kv(xv, self.n_rep)
224 |
225 | # make heads into a batch dimension
226 | xq = xq.swapaxes(1, 2) # (bs, n_heads, seqlen, head_dim)
227 | xk = xk.swapaxes(1, 2)
228 | xv = xv.swapaxes(1, 2)
229 |
230 | scores = jnp.matmul(xq, xk.swapaxes(2, 3)) / math.sqrt(self.head_dim)
231 | if self.causal:
232 | mask = jnp.full((1, 1, seqlen, seqlen), -jnp.inf)
233 | mask = jnp.triu(mask, k=1)
234 | scores = (
235 | scores + mask[:, :, :seqlen, :seqlen]
236 | ) # (bs, n_heads, seqlen, seqlen)
237 | scores = nn.softmax(scores, axis=-1)
238 | if self.use_attn_dropout and self.dropout_rate > 0.0:
239 | scores = self.attn_dropout(scores, deterministic=not train)
240 | output = jnp.matmul(scores, xv) # (bs, n_heads, seqlen, head_dim)
241 |
242 | # restore time as batch dimension and concat heads
243 | output = output.swapaxes(1, 2).reshape(bsz, seqlen, -1)
244 |
245 | # final projection into the residual stream
246 | output = self.wo(output)
247 | if self.dropout_rate > 0.0:
248 | output = self.resid_dropout(output, deterministic=not train)
249 | return output
250 |
251 |
252 | class FeedForward(nn.Module):
253 |
254 | dim: int
255 | multiple_of: int
256 | dropout_rate: float
257 | hidden_dim: int | None = None
258 | w_init_scale: float = 1.0
259 | mlp_type: str = 'swiglu'
260 |
261 | def setup(self):
262 | multiple_of = self.multiple_of
263 | hidden_dim = self.hidden_dim
264 | if hidden_dim is None:
265 | hidden_dim = 4 * self.dim
266 | hidden_dim = int(2 * hidden_dim / 3)
267 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
268 | w_init = nn.initializers.variance_scaling(
269 | self.w_init_scale, 'fan_in', 'truncated_normal'
270 | )
271 | self.w1 = nn.Dense(
272 | hidden_dim,
273 | use_bias=False,
274 | kernel_init=nn.with_logical_partitioning(w_init, ('embed', 'mlp')),
275 | )
276 | self.w2 = nn.Dense(
277 | self.dim,
278 | use_bias=False,
279 | kernel_init=nn.with_logical_partitioning(w_init, ('mlp', 'embed')),
280 | )
281 | self.w3 = nn.Dense(
282 | hidden_dim,
283 | use_bias=False,
284 | kernel_init=nn.with_logical_partitioning(w_init, ('embed', 'mlp')),
285 | )
286 | # self.dropout = nn.Dropout(self.dropout_rate)
287 | if self.dropout_rate > 0.0:
288 | self.dropout = Dropout1d(self.dropout_rate)
289 |
290 | def __call__(self, x, train=False):
291 | activation = activation_map[self.mlp_type]
292 | y = self.w2(activation(self.w1(x)) * self.w3(x))
293 | if self.dropout_rate > 0.0:
294 | return self.dropout(y, deterministic=not train)
295 | else:
296 | return y
297 |
298 |
299 | class TransformerBlock(nn.Module):
300 |
301 | layer_id: int
302 | args: ModelArgs
303 |
304 | def setup(self):
305 | args = self.args
306 | self.attention = Attention(
307 | args.dim,
308 | args.n_heads,
309 | n_kv_heads=args.n_kv_heads,
310 | dropout_rate=args.dropout_rate,
311 | causal=args.causal,
312 | use_attn_dropout=args.use_attn_dropout,
313 | )
314 |
315 | if args.depth_scaled_init:
316 | w_init_scale = 2.0 / args.n_layers
317 | else:
318 | w_init_scale = args.w_init_scale
319 |
320 | self.feed_forward = FeedForward(
321 | dim=args.dim,
322 | multiple_of=args.multiple_of,
323 | dropout_rate=args.dropout_rate,
324 | hidden_dim=args.hidden_dim,
325 | w_init_scale=w_init_scale,
326 | mlp_type=args.mlp_type,
327 | )
328 |
329 | @nn.compact
330 | def __call__(self, x, freqs_cos, freqs_sin, cond=None, train=False):
331 | if cond is not None:
332 | activation = activation_map[self.args.mlp_type]
333 | if self.args.cond_type == 'adaln':
334 | ln = nn.Sequential([
335 | # nn.swish,
336 | activation,
337 | nn.Dense(
338 | 6 * self.args.dim,
339 | use_bias=True,
340 | ),
341 | ])
342 | elif self.args.cond_type == 'adaln_zero':
343 | ln = nn.Sequential([
344 | # nn.swish,
345 | activation,
346 | nn.Dense(
347 | 6 * self.args.dim,
348 | use_bias=True,
349 | kernel_init=nn.initializers.zeros,
350 | bias_init=nn.initializers.zeros,
351 | ),
352 | ])
353 | else:
354 | raise NotImplementedError()
355 | (shift_att, scale_att, gate_att, shift_mlp, scale_mlp, gate_mlp) = (
356 | jnp.split(ln(cond)[:, None, :], 6, axis=-1)
357 | )
358 | attention_norm = nn.LayerNorm(
359 | epsilon=self.args.norm_eps, use_bias=False, use_scale=False
360 | )
361 | ffn_norm = nn.LayerNorm(
362 | epsilon=self.args.norm_eps, use_bias=False, use_scale=False
363 | )
364 | h = x + gate_att * self.attention(
365 | attention_norm(x) * (scale_att + 1.0) + shift_att,
366 | freqs_cos,
367 | freqs_sin,
368 | train=train,
369 | )
370 | out = h + gate_mlp * self.feed_forward(
371 | ffn_norm(h) * (scale_mlp + 1.0) + shift_mlp, train=train
372 | )
373 | else:
374 | attention_norm = RMSNorm(self.args.dim, eps=self.args.norm_eps)
375 | ffn_norm = RMSNorm(self.args.dim, eps=self.args.norm_eps)
376 | h = x + self.attention(
377 | attention_norm(x), freqs_cos, freqs_sin, train=train
378 | )
379 | out = h + self.feed_forward(ffn_norm(h), train=train)
380 |
381 | return out
382 |
383 |
384 | class Transformer(nn.Module):
385 |
386 | args: ModelArgs
387 |
388 | @nn.compact
389 | def __call__(self, x, cond=None, train=False, output_channels=None):
390 | args = self.args
391 | if output_channels is None:
392 | output_channels = args.output_channels
393 |
394 | if args.embed_input:
395 | h = nn.Embed(
396 | args.n_embed_classes,
397 | args.dim,
398 | embedding_init=nn.with_logical_partitioning(
399 | nn.linear.default_embed_init, ('embed_vocab', 'embed')
400 | ),
401 | )(x)
402 | if args.dropout_rate > 0.0:
403 | h = nn.Dropout(args.dropout_rate, deterministic=not train)(h)
404 | else:
405 | h = nn.Dense(
406 | args.dim,
407 | kernel_init=nn.with_logical_partitioning(
408 | nn.linear.default_kernel_init, ('input_embed', 'embed')
409 | ),
410 | )(x)
411 |
412 | seqlen = x.shape[1]
413 | freqs_cos, freqs_sin = precompute_freqs_cis(
414 | args.dim // args.n_heads, seqlen
415 | )
416 |
417 | freqs_cos = freqs_cos[:seqlen]
418 | freqs_sin = freqs_sin[:seqlen]
419 |
420 | for layer_id in range(args.n_layers):
421 | h = TransformerBlock(layer_id, args)(
422 | h, freqs_cos, freqs_sin, cond=cond, train=train
423 | )
424 |
425 | if cond is not None:
426 | output_norm = nn.LayerNorm(
427 | epsilon=args.norm_eps, use_bias=False, use_scale=False
428 | )
429 | activation = activation_map[args.mlp_type]
430 | if args.cond_type == 'adaln':
431 | ln = nn.Sequential([
432 | # nn.swish,
433 | activation,
434 | nn.Dense(
435 | 2 * args.dim,
436 | use_bias=True,
437 | ),
438 | ])
439 | elif args.cond_type == 'adaln_zero':
440 | ln = nn.Sequential([
441 | # nn.swish,
442 | activation,
443 | nn.Dense(
444 | 2 * args.dim,
445 | use_bias=True,
446 | kernel_init=nn.initializers.zeros,
447 | bias_init=nn.initializers.zeros,
448 | ),
449 | ])
450 | else:
451 | raise NotImplementedError()
452 | shift_out, scale_out = jnp.split(ln(cond)[:, None, :], 2, axis=-1)
453 | logits = nn.Dense(
454 | output_channels,
455 | use_bias=False,
456 | kernel_init=nn.with_logical_partitioning(
457 | nn.initializers.zeros, ('embed', 'output_vocab')
458 | ),
459 | )(output_norm(h) * (scale_out + 1) + shift_out)
460 | else:
461 | h = RMSNorm(args.dim, args.norm_eps)(h)
462 | logits = nn.Dense(
463 | features=output_channels,
464 | use_bias=False,
465 | kernel_init=nn.with_logical_partitioning(
466 | nn.initializers.zeros, ('embed', 'output_vocab')
467 | ),
468 | )(h)
469 |
470 | return logits
471 |
--------------------------------------------------------------------------------
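
This module differs from transformer.py below only in that every kernel and embedding is created through nn.with_logical_partitioning, tagging it with logical axis names. A sketch of how those tags can be consumed; nn.get_partition_spec is a flax.linen helper, but the actual training-time wiring lives in sharded_train.py and is assumed rather than shown here:

import jax
import jax.numpy as jnp
import flax.linen as nn

from md4.networks import sharded_transformer as st

args = st.ModelArgs(dim=64, n_layers=2, n_heads=4, output_channels=32,
                    embed_input=True, n_embed_classes=32)
model = st.Transformer(args)

tokens = jnp.zeros((2, 16), dtype='int32')
variables = model.init(jax.random.PRNGKey(0), tokens)

# Kernels now carry logical axes such as ('embed', 'qkv') or ('embed', 'mlp');
# these resolve to PartitionSpecs once mapped onto a physical device mesh.
specs = nn.get_partition_spec(variables)
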
/md4/networks/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Jax implementation of LLAMA2-like Transformer.
17 |
18 | Based on PyTorch implementation
19 | https://github.com/karpathy/llama2.c/blob/master/model.py
20 | """
21 |
22 | import dataclasses
23 | import math
24 | from typing import Optional
25 |
26 | import flax.linen as nn
27 | import jax
28 | import jax.numpy as jnp
29 | # pylint: disable=missing-class-docstring
30 | # pylint: disable=missing-function-docstring
31 |
32 |
33 | activation_map = dict(
34 | swiglu=nn.swish,
35 | geglu=nn.gelu,
36 | glu=nn.sigmoid,
37 | )
38 |
39 |
40 | @dataclasses.dataclass(unsafe_hash=True)
41 | class ModelArgs:
42 | dim: int = 288
43 | n_layers: int = 6
44 | n_heads: int = 6
45 | n_kv_heads: Optional[int] = None
46 | output_channels: int = 1024
47 | hidden_dim: Optional[int] = None
48 |   multiple_of: int = 32  # MLP hidden layer size is rounded up to a multiple of this
49 | norm_eps: float = 1e-5
50 | dropout_rate: float = 0.0
51 | weight_tying: bool = False
52 | w_init_scale: float = 1.0
53 | depth_scaled_init: bool = False
54 | # glu, geglu, swiglu
55 | mlp_type: str = 'swiglu'
56 | # adaln, adaln_zero
57 | cond_type: str = 'adaln'
58 | embed_input: bool = False
59 | n_embed_classes: int = 1024
60 | causal: bool = False
61 |
62 |
63 | class RMSNorm(nn.Module):
64 |
65 | dim: int
66 | eps: float
67 |
68 | def setup(self):
69 | self.scale = self.param(
70 | 'scale', lambda key, shape: jnp.ones(shape), (self.dim,)
71 | )
72 |
73 | def _norm(self, x):
74 | return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
75 |
76 | def __call__(self, x):
77 | output = self._norm(x)
78 | return output * self.scale
79 |
80 |
81 | def precompute_freqs_cis(dim, end, theta: float = 10000.0):
82 | freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)] / dim))
83 | t = jnp.arange(end)
84 | freqs = jnp.outer(t, freqs)
85 | freqs_cos = jnp.cos(freqs)
86 | freqs_sin = jnp.sin(freqs)
87 | return freqs_cos, freqs_sin
88 |
89 |
90 | def reshape_for_broadcast(freqs_cis, x):
91 | ndim = x.ndim
92 | assert 1 < ndim
93 | assert freqs_cis.shape == (x.shape[1], x.shape[-1])
94 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
95 | return freqs_cis.reshape(shape)
96 |
97 |
98 | def jax_unstack(x, axis=0):
99 | return [
100 | jax.lax.index_in_dim(x, i, axis, keepdims=False)
101 | for i in range(x.shape[axis])
102 | ]
103 |
104 |
105 | def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
106 | # reshape xq and xk to match the complex representation
107 | # [bs, seq_len, n_head, head_dim // 2]
108 | xq_r, xq_i = jax_unstack(xq.reshape(xq.shape[:-1] + (-1, 2)), -1)
109 | xk_r, xk_i = jax_unstack(xk.reshape(xk.shape[:-1] + (-1, 2)), -1)
110 |
111 | # reshape freqs_cos and freqs_sin for broadcasting
112 | # [1, seq_len, 1, head_dim // 2]
113 | freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
114 | freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
115 |
116 | # apply rotation using real numbers
117 | xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
118 | xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
119 | xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
120 | xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
121 |
122 | # flatten last two dimensions
123 | # [bs, seq_len, n_head, head_dim // 2, 2] -> [bs, seq_len, n_head, head_dim]
124 | xq_out = jnp.stack([xq_out_r, xq_out_i], axis=-1).reshape(
125 | xq_out_r.shape[:3] + (-1,)
126 | )
127 | xk_out = jnp.stack([xk_out_r, xk_out_i], axis=-1).reshape(
128 | xk_out_r.shape[:3] + (-1,)
129 | )
130 |
131 | return xq_out, xk_out
132 |
133 |
134 | def repeat_kv(x, n_rep):
135 | bs, slen, n_kv_heads, head_dim = x.shape
136 | if n_rep == 1:
137 | return x
138 | return jnp.tile(x[:, :, :, None, :], [1, 1, 1, n_rep, 1]).reshape(
139 | bs, slen, n_kv_heads * n_rep, head_dim
140 | )
141 |
142 |
143 | class Dropout1d(nn.Module):
144 |
145 | dropout_rate: float = 0.0
146 |
147 | def __call__(self, x, deterministic=True):
148 | if (self.dropout_rate > 0.0) and not deterministic:
149 | drop = jax.random.bernoulli(
150 | self.make_rng('dropout'),
151 | 1 - self.dropout_rate,
152 | (x.shape[0], 1, x.shape[-1]),
153 | )
154 | x = x * drop / (1 - self.dropout_rate)
155 | return x
156 |
157 |
158 | class Attention(nn.Module):
159 |
160 | dim: int
161 | n_heads: int
162 | n_kv_heads: int | None = None
163 | dropout_rate: float = 0.0
164 | causal: bool = False
165 | qkv_bias: bool = False
166 |
167 | def setup(self):
168 | self._n_kv_heads = (
169 | self.n_heads if self.n_kv_heads is None else self.n_kv_heads
170 | )
171 | assert self.n_heads % self._n_kv_heads == 0
172 | self.n_rep = self.n_heads // self._n_kv_heads
173 | self.head_dim = self.dim // self.n_heads
174 | self.wq = nn.Dense(self.n_heads * self.head_dim, use_bias=self.qkv_bias)
175 | self.wk = nn.Dense(self._n_kv_heads * self.head_dim, use_bias=self.qkv_bias)
176 | self.wv = nn.Dense(self._n_kv_heads * self.head_dim, use_bias=self.qkv_bias)
177 | self.wo = nn.Dense(self.dim, use_bias=False)
178 | if self.dropout_rate > 0.0:
179 | self.attn_dropout = nn.Dropout(self.dropout_rate)
180 | self.resid_dropout = Dropout1d(self.dropout_rate)
181 |
182 | def __call__(self, x, freqs_cos, freqs_sin, train=False):
183 | bsz, seqlen, _ = x.shape
184 |
185 | # QKV
186 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
187 | xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim)
188 | xk = xk.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim)
189 | xv = xv.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim)
190 |
191 | # RoPE relative positional embeddings
192 | xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
193 |
194 | # grouped multiquery attention: expand out keys and values
195 | xk = repeat_kv(xk, self.n_rep)
196 | xv = repeat_kv(xv, self.n_rep)
197 |
198 | # make heads into a batch dimension
199 | xq = xq.swapaxes(1, 2) # (bs, n_heads, seqlen, head_dim)
200 | xk = xk.swapaxes(1, 2)
201 | xv = xv.swapaxes(1, 2)
202 |
203 | scores = jnp.matmul(xq, xk.swapaxes(2, 3)) / math.sqrt(self.head_dim)
204 | if self.causal:
205 | mask = jnp.full((1, 1, seqlen, seqlen), -jnp.inf)
206 | mask = jnp.triu(mask, k=1)
207 | scores = (
208 | scores + mask[:, :, :seqlen, :seqlen]
209 | ) # (bs, n_heads, seqlen, seqlen)
210 | scores = nn.softmax(scores, axis=-1)
211 | if self.dropout_rate > 0.0:
212 | scores = self.attn_dropout(scores, deterministic=not train)
213 | output = jnp.matmul(scores, xv) # (bs, n_heads, seqlen, head_dim)
214 |
215 | # restore time as batch dimension and concat heads
216 | output = output.swapaxes(1, 2).reshape(bsz, seqlen, -1)
217 |
218 | # final projection into the residual stream
219 | output = self.wo(output)
220 | if self.dropout_rate > 0.0:
221 | output = self.resid_dropout(output, deterministic=not train)
222 | return output
223 |
224 |
225 | class FeedForward(nn.Module):
226 |
227 | dim: int
228 | multiple_of: int
229 | dropout_rate: float
230 | hidden_dim: int | None = None
231 | w_init_scale: float = 1.0
232 | mlp_type: str = 'swiglu'
233 |
234 | def setup(self):
235 | multiple_of = self.multiple_of
236 | hidden_dim = self.hidden_dim
237 | if hidden_dim is None:
238 | hidden_dim = 4 * self.dim
239 | hidden_dim = int(2 * hidden_dim / 3)
240 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
241 | w_init = nn.initializers.variance_scaling(
242 | self.w_init_scale, 'fan_in', 'truncated_normal'
243 | )
244 | self.w1 = nn.Dense(hidden_dim, use_bias=False, kernel_init=w_init)
245 | self.w2 = nn.Dense(self.dim, use_bias=False, kernel_init=w_init)
246 | self.w3 = nn.Dense(hidden_dim, use_bias=False, kernel_init=w_init)
247 | # self.dropout = nn.Dropout(self.dropout_rate)
248 | if self.dropout_rate > 0.0:
249 | self.dropout = Dropout1d(self.dropout_rate)
250 |
251 | def __call__(self, x, train=False):
252 | activation = activation_map[self.mlp_type]
253 | y = self.w2(activation(self.w1(x)) * self.w3(x))
254 | if self.dropout_rate > 0.0:
255 | return self.dropout(y, deterministic=not train)
256 | else:
257 | return y
258 |
259 |
260 | class TransformerBlock(nn.Module):
261 |
262 | layer_id: int
263 | args: ModelArgs
264 |
265 | def setup(self):
266 | args = self.args
267 | self.attention = Attention(
268 | args.dim,
269 | args.n_heads,
270 | n_kv_heads=args.n_kv_heads,
271 | dropout_rate=args.dropout_rate,
272 | causal=args.causal,
273 | )
274 |
275 | if args.depth_scaled_init:
276 | w_init_scale = 2.0 / args.n_layers
277 | else:
278 | w_init_scale = args.w_init_scale
279 |
280 | self.feed_forward = FeedForward(
281 | dim=args.dim,
282 | multiple_of=args.multiple_of,
283 | dropout_rate=args.dropout_rate,
284 | hidden_dim=args.hidden_dim,
285 | w_init_scale=w_init_scale,
286 | mlp_type=args.mlp_type,
287 | )
288 |
289 | @nn.compact
290 | def __call__(self, x, freqs_cos, freqs_sin, cond=None, train=False):
291 | if cond is not None:
292 | activation = activation_map[self.args.mlp_type]
293 | if self.args.cond_type == 'adaln':
294 | ln = nn.Sequential([
295 | # nn.swish,
296 | activation,
297 | nn.Dense(6 * self.args.dim, use_bias=True),
298 | ])
299 | elif self.args.cond_type == 'adaln_zero':
300 | ln = nn.Sequential([
301 | # nn.swish,
302 | activation,
303 | nn.Dense(
304 | 6 * self.args.dim,
305 | use_bias=True,
306 | kernel_init=nn.initializers.zeros,
307 | bias_init=nn.initializers.zeros,
308 | ),
309 | ])
310 | else:
311 | raise NotImplementedError()
312 | (shift_att, scale_att, gate_att, shift_mlp, scale_mlp, gate_mlp) = (
313 | jnp.split(ln(cond)[:, None, :], 6, axis=-1)
314 | )
315 | attention_norm = nn.LayerNorm(
316 | epsilon=self.args.norm_eps, use_bias=False, use_scale=False
317 | )
318 | ffn_norm = nn.LayerNorm(
319 | epsilon=self.args.norm_eps, use_bias=False, use_scale=False
320 | )
321 | h = x + gate_att * self.attention(
322 | attention_norm(x) * (scale_att + 1.0) + shift_att,
323 | freqs_cos,
324 | freqs_sin,
325 | train=train,
326 | )
327 | out = h + gate_mlp * self.feed_forward(
328 | ffn_norm(h) * (scale_mlp + 1.0) + shift_mlp, train=train
329 | )
330 | else:
331 | attention_norm = RMSNorm(self.args.dim, eps=self.args.norm_eps)
332 | ffn_norm = RMSNorm(self.args.dim, eps=self.args.norm_eps)
333 | h = x + self.attention(
334 | attention_norm(x), freqs_cos, freqs_sin, train=train
335 | )
336 | out = h + self.feed_forward(ffn_norm(h), train=train)
337 |
338 | return out
339 |
340 |
341 | class Transformer(nn.Module):
342 |
343 | args: ModelArgs
344 |
345 | @nn.compact
346 | def __call__(self, x, cond=None, train=False, output_channels=None):
347 | args = self.args
348 | if output_channels is None:
349 | output_channels = args.output_channels
350 |
351 | if args.embed_input:
352 | h = nn.Embed(args.n_embed_classes, args.dim)(x)
353 | if args.dropout_rate > 0.0:
354 | h = nn.Dropout(args.dropout_rate, deterministic=not train)(h)
355 | else:
356 | h = nn.Dense(args.dim)(x)
357 |
358 | seqlen = x.shape[1]
359 | freqs_cos, freqs_sin = precompute_freqs_cis(
360 | args.dim // args.n_heads, seqlen
361 | )
362 |
363 | freqs_cos = freqs_cos[:seqlen]
364 | freqs_sin = freqs_sin[:seqlen]
365 |
366 | for layer_id in range(args.n_layers):
367 | h = TransformerBlock(layer_id, args)(
368 | h, freqs_cos, freqs_sin, cond=cond, train=train
369 | )
370 |
371 | if cond is not None:
372 | output_norm = nn.LayerNorm(
373 | epsilon=args.norm_eps, use_bias=False, use_scale=False
374 | )
375 | activation = activation_map[args.mlp_type]
376 | if args.cond_type == 'adaln':
377 | ln = nn.Sequential([
378 | # nn.swish,
379 | activation,
380 | nn.Dense(2 * args.dim, use_bias=True),
381 | ])
382 | elif args.cond_type == 'adaln_zero':
383 | ln = nn.Sequential([
384 | # nn.swish,
385 | activation,
386 | nn.Dense(
387 | 2 * args.dim,
388 | use_bias=True,
389 | kernel_init=nn.initializers.zeros,
390 | bias_init=nn.initializers.zeros,
391 | ),
392 | ])
393 | else:
394 | raise NotImplementedError()
395 | shift_out, scale_out = jnp.split(ln(cond)[:, None, :], 2, axis=-1)
396 | logits = nn.Dense(
397 | output_channels, use_bias=False, kernel_init=nn.initializers.zeros
398 | )(output_norm(h) * (scale_out + 1) + shift_out)
399 | else:
400 | h = RMSNorm(args.dim, args.norm_eps)(h)
401 | logits = nn.Dense(
402 | features=output_channels,
403 | use_bias=False,
404 | kernel_init=nn.initializers.zeros,
405 | )(h)
406 |
407 | return logits
408 |
--------------------------------------------------------------------------------
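
FeedForward sizes its SwiGLU hidden layer the LLaMA way: start from 4 * dim, take two thirds (compensating for the third weight matrix, so the parameter count roughly matches a plain 4x MLP), then round up to a multiple of multiple_of. A quick standalone check of that arithmetic:

def swiglu_hidden_dim(dim: int, multiple_of: int = 32) -> int:
  hidden = int(2 * (4 * dim) / 3)
  return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

assert swiglu_hidden_dim(288) == 768   # the ModelArgs defaults
assert swiglu_hidden_dim(512) == 1376  # 2/3 of 2048 is 1365.3, rounded up
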
/md4/networks/unet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """UNet implementation.
17 |
18 | Adapted from https://github.com/google-research/vdm/blob/main/model_vdm.py
19 | """
20 |
21 | import flax.linen as nn
22 | import jax.numpy as jnp
23 |
24 |
25 | class SelfAttention(nn.Module):
26 | """Self attention layer in UNets."""
27 |
28 | num_heads: int = 1
29 |
30 | @nn.compact
31 | def __call__(self, x):
32 | _, h, w, c = x.shape
33 | z = nn.GroupNorm()(x)
34 | z = z.reshape(z.shape[0], -1, c)
35 | mha = nn.MultiHeadDotProductAttention(
36 | num_heads=self.num_heads, qkv_features=c
37 | )
38 | z = mha(z, z)
39 | z = z.reshape(-1, h, w, c) + x
40 | return z
41 |
42 |
43 | class ResBlock(nn.Module):
44 | """Residual block in UNets."""
45 |
46 | out_channels: int
47 | dropout_rate: float = 0.1
48 |
49 | @nn.compact
50 | def __call__(self, x, cond=None, train=False):
51 | in_channels = x.shape[-1]
52 | h = nn.GroupNorm()(x)
53 | h = nn.swish(h)
54 | h = nn.Conv(self.out_channels, kernel_size=(3, 3))(h)
55 | if cond is not None:
56 | cond_act = nn.Dense(
57 | self.out_channels,
58 | use_bias=False,
59 | kernel_init=nn.initializers.zeros,
60 | )(cond)
61 | h = h + cond_act[:, None, None, :]
62 | h = nn.GroupNorm()(h)
63 | h = nn.swish(h)
64 | h = nn.Dropout(rate=self.dropout_rate)(h, deterministic=not train)
65 | h = nn.Conv(
66 | self.out_channels, kernel_size=(3, 3), kernel_init=nn.initializers.zeros
67 | )(h)
68 | if in_channels != self.out_channels:
69 | h = nn.Dense(self.out_channels)(x) + h
70 | else:
71 | h = x + h
72 | return h
73 |
74 |
75 | class UNet(nn.Module):
76 | """UNet for Diffusion."""
77 |
78 | d_channels: int = 128
79 | n_layers: int = 32
80 | add_input: bool = False
81 | output_channels: int | None = None
82 | dropout_rate: float = 0.1
83 |
84 | @nn.compact
85 | def __call__(self, x, cond=None, train=False, output_channels=None):
86 | if output_channels is None:
87 | if self.output_channels is None:
88 | output_channels = x.shape[-1]
89 | else:
90 | output_channels = self.output_channels
91 |
92 | # Linear projection of input
93 | h = nn.Conv(self.d_channels, kernel_size=(3, 3))(x)
94 | hs = [h]
95 |
96 | # Downsampling
97 | for _ in range(self.n_layers):
98 | h = ResBlock(
99 | out_channels=self.d_channels, dropout_rate=self.dropout_rate
100 | )(h, cond, train)
101 | hs.append(h)
102 |
103 | # Middle
104 | h = ResBlock(out_channels=self.d_channels, dropout_rate=self.dropout_rate)(
105 | h, cond, train
106 | )
107 | h = SelfAttention(num_heads=1)(h)
108 | h = ResBlock(out_channels=self.d_channels, dropout_rate=self.dropout_rate)(
109 | h, cond, train
110 | )
111 |
112 | # Upsampling
113 | for _ in range(self.n_layers + 1):
114 | h = jnp.concatenate([h, hs.pop()], axis=-1)
115 | h = ResBlock(
116 | out_channels=self.d_channels, dropout_rate=self.dropout_rate
117 | )(h, cond, train)
118 |
119 | assert not hs
120 |
121 | # Predict noise
122 | h = nn.swish(nn.GroupNorm()(h))
123 | h = nn.Conv(output_channels, (3, 3), kernel_init=nn.initializers.zeros)(h)
124 |
125 | if self.add_input:
126 | h += x
127 | return h
128 |
--------------------------------------------------------------------------------
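Annotation (not part of the repo): the UNet above pushes every downsampling activation onto the stack `hs` and concatenates them back during upsampling, so `n_layers` blocks down are matched by `n_layers + 1` blocks up (one extra for the initial projection). A hedged sketch of initializing and applying it; the image size, channel counts and 128-dim conditioning vector are illustrative assumptions:

import jax
import jax.numpy as jnp
from md4.networks.unet import UNet

model = UNet(d_channels=64, n_layers=4)
x = jnp.ones((2, 32, 32, 3))     # (batch, height, width, channels)
cond = jnp.ones((2, 128))        # per-example conditioning vector
variables = model.init(jax.random.PRNGKey(0), x, cond=cond, train=False)
out = model.apply(variables, x, cond=cond, train=False)
assert out.shape == (2, 32, 32, 3)  # output_channels defaults to input channels
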
/md4/networks/uvit.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """UViT implementation based on DiT."""
17 |
18 | from typing import Sequence
19 |
20 | import flax.linen as nn
21 | import jax.numpy as jnp
22 |
23 | from md4.networks.dit import DiT
24 |
25 |
26 | def nearest_neighbor_upsample(x):
27 | bs, h, w, c = x.shape
28 | x = x.reshape(bs, h, 1, w, 1, c)
29 | x = jnp.broadcast_to(x, (bs, h, 2, w, 2, c))
30 | return x.reshape(bs, h * 2, w * 2, c)
31 |
32 |
33 | class CondGroupNorm(nn.Module):
34 | """Conditional normalization."""
35 |
36 | @nn.compact
37 | def __call__(self, x, cond=None):
38 | c = x.shape[-1]
39 | if cond is not None:
40 | cond_act = nn.Dense(c * 2, kernel_init=nn.initializers.zeros)(cond)
41 | scale, shift = jnp.split(cond_act[:, None, None, :], 2, axis=-1)
42 | x = nn.GroupNorm(use_bias=False, use_scale=False)(x) * (1 + scale) + shift
43 | else:
44 | x = nn.GroupNorm()(x)
45 | return x
46 |
47 |
48 | class SelfAttention(nn.Module):
49 | """Self attention layer in UNets."""
50 |
51 | num_heads: int = 1
52 |
53 | @nn.compact
54 | def __call__(self, x, cond=None):
55 | bs, h, w, c = x.shape
56 | z = CondGroupNorm()(x, cond=cond)
57 | z = z.reshape(bs, -1, c)
58 |
59 | mha = nn.MultiHeadDotProductAttention(
60 | num_heads=self.num_heads,
61 | qkv_features=c,
62 | out_kernel_init=nn.zeros_init(),
63 | )
64 | z = mha(z, z)
65 | z = z.reshape(-1, h, w, c) + x
66 | return z
67 |
68 |
69 | class ResBlock(nn.Module):
70 | """Residual block in UNets."""
71 |
72 | out_channels: int | None = None
73 | dropout_rate: float = 0.1
74 | resample: str | None = None
75 |
76 | @nn.compact
77 | def __call__(self, x, cond=None, train=False):
78 | in_channels = x.shape[-1]
79 | out_channels = (
80 | in_channels if self.out_channels is None else self.out_channels
81 | )
82 | h = CondGroupNorm()(x, cond=cond)
83 | h = nn.swish(h)
84 |
85 | if self.resample is not None:
86 |
87 | def updown(z):
88 | return {
89 | 'up': nearest_neighbor_upsample(z),
90 | 'down': nn.avg_pool(z, (2, 2), (2, 2)),
91 | }[self.resample]
92 |
93 | h = updown(h)
94 | x = updown(x)
95 | h = nn.Conv(out_channels, kernel_size=(3, 3))(h)
96 | h = CondGroupNorm()(h, cond=cond)
97 | h = nn.swish(h)
98 | h = nn.Dropout(rate=self.dropout_rate)(h, deterministic=not train)
99 | h = nn.Conv(
100 | out_channels, kernel_size=(3, 3), kernel_init=nn.initializers.zeros
101 | )(h)
102 | if in_channels != out_channels:
103 | x = nn.Dense(out_channels)(x)
104 | return x + h
105 |
106 |
107 | class UNet(nn.Module):
108 | """UNet for Diffusion."""
109 |
110 | d_channels: int = 128
111 | n_layers: int = 32
112 | n_dit_layers: int = 0
113 | dit_num_heads: int = 12
114 | dit_hidden_size: int = 768
115 | ch_mult: Sequence[int] = (1,)
116 | add_input: bool = False
117 | output_channels: int | None = None
118 | dropout_rate: float = 0.1
119 |
120 | @nn.compact
121 | def __call__(self, x, cond=None, train=False, output_channels=None):
122 | if output_channels is None:
123 | if self.output_channels is None:
124 | output_channels = x.shape[-1]
125 | else:
126 | output_channels = self.output_channels
127 | num_resolutions = len(self.ch_mult)
128 |
129 | # Linear projection of input
130 | h = nn.Conv(self.d_channels, kernel_size=(3, 3))(x)
131 | hs = [h]
132 |
133 | # Downsampling
134 | for i_level in range(num_resolutions):
135 | for _ in range(self.n_layers):
136 | h = ResBlock(
137 | out_channels=self.d_channels * self.ch_mult[i_level],
138 | dropout_rate=self.dropout_rate,
139 | )(h, cond=cond, train=train)
140 | hs.append(h)
141 | # Downsample
142 | if i_level != num_resolutions - 1:
143 | h = ResBlock(
144 | dropout_rate=self.dropout_rate,
145 | resample='down',
146 | )(h, cond=cond, train=train)
147 | hs.append(h)
148 |
149 | # Middle
150 | _, img_size, _, c = h.shape
151 | h = DiT(
152 | img_size=img_size,
153 | patch_size=2,
154 | in_channels=c,
155 | out_channels=c,
156 | hidden_size=self.dit_hidden_size, # c * 2, c * 3..
157 | depth=self.n_dit_layers, # 8, 12, 16, 20
158 | num_heads=self.dit_num_heads, # 8, 12..
159 | dropout_rate=self.dropout_rate,
160 | )(h, cond=cond, train=train)
161 |
162 | # Upsampling
163 | for i_level in reversed(range(num_resolutions)):
164 | for _ in range(self.n_layers + 1):
165 | h = jnp.concatenate([h, hs.pop()], axis=-1)
166 | h = ResBlock(
167 | out_channels=self.d_channels * self.ch_mult[i_level],
168 | dropout_rate=self.dropout_rate,
169 | )(h, cond=cond, train=train)
170 | # Upsample
171 | if i_level != 0:
172 | h = ResBlock(
173 | dropout_rate=self.dropout_rate,
174 | resample='up',
175 | )(h, cond=cond, train=train)
176 |
177 | assert not hs
178 |
179 | # Predict noise
180 | h = nn.swish(CondGroupNorm()(h, cond=cond))
181 | h = nn.Conv(output_channels, (3, 3), kernel_init=nn.initializers.zeros)(h)
182 |
183 | if self.add_input:
184 | h += x
185 | return h
186 |
--------------------------------------------------------------------------------
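Annotation: a quick check of `nearest_neighbor_upsample` above; the reshape/broadcast pair repeats each pixel into a 2x2 block, doubling height and width with no parameters:

import jax.numpy as jnp
from md4.networks.uvit import nearest_neighbor_upsample

x = jnp.arange(4.0).reshape(1, 2, 2, 1)   # pixel values [[0, 1], [2, 3]]
y = nearest_neighbor_upsample(x)
assert y.shape == (1, 4, 4, 1)
assert (y[0, :2, :2, 0] == 0).all()       # top-left 2x2 block repeats pixel 0
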
/md4/sampling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Sampling functions."""
17 |
18 | import functools
19 |
20 | import jax
21 | import jax.numpy as jnp
22 |
23 |
24 | def get_attr(train_state, key):
25 | if hasattr(train_state, key):
26 | return getattr(train_state, key)
27 | else:
28 | return train_state[key]
29 |
30 |
31 | @functools.partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=0)
32 | def generate(model, train_state, rng, dummy_inputs, conditioning=None):
33 | """Generate samples from the diffusion model."""
34 | rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))
35 | variables = {
36 | 'params': get_attr(train_state, 'ema_params'),
37 | **get_attr(train_state, 'state'),
38 | }
39 | rng, sub_rng = jax.random.split(rng)
40 | zt = model.apply(
41 | variables,
42 | dummy_inputs.shape[0],
43 | method=model.prior_sample,
44 | rngs={'sample': sub_rng},
45 | )
46 | rng, sub_rng = jax.random.split(rng)
47 |
48 | def body_fn(i, zt):
49 | return model.apply(
50 | variables,
51 | sub_rng,
52 | i,
53 | model.timesteps,
54 | zt,
55 | conditioning=conditioning,
56 | method=model.sample_step,
57 | )
58 |
59 | z0 = jax.lax.fori_loop(
60 | lower=0, upper=model.timesteps, body_fun=body_fn, init_val=zt
61 | )
62 | return model.apply(
63 | variables,
64 | z0,
65 | conditioning=conditioning,
66 | method=model.decode,
67 | rngs={'sample': rng},
68 | )
69 |
70 |
71 | def simple_generate(rng, train_state, batch_size, model, conditioning=None):
72 | """Generate samples from the diffusion model."""
73 | variables = {'params': train_state.params, **train_state.state}
74 | rng, sub_rng = jax.random.split(rng)
75 | zt = model.apply(
76 | variables,
77 | batch_size,
78 | method=model.prior_sample,
79 | rngs={'sample': sub_rng},
80 | )
81 | rng, sub_rng = jax.random.split(rng)
82 |
83 | def body_fn(i, zt):
84 | return model.apply(
85 | variables,
86 | sub_rng,
87 | i,
88 | model.timesteps,
89 | zt,
90 | conditioning=conditioning,
91 | method=model.sample_step,
92 | )
93 |
94 | z0 = jax.lax.fori_loop(
95 | lower=0, upper=model.timesteps, body_fun=body_fn, init_val=zt
96 | )
97 | return model.apply(
98 | variables,
99 | z0,
100 | conditioning=conditioning,
101 | method=model.decode,
102 | rngs={'sample': rng},
103 | )
104 |
105 |
106 | @functools.partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=0)
107 | def reconstruct(model, train_state, rng, t, inputs, conditioning=None):
108 | """Reconstruct from the latent at t."""
109 | rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))
110 | variables = {
111 | 'params': get_attr(train_state, 'ema_params'),
112 | **get_attr(train_state, 'state'),
113 | }
114 | f = model.apply(variables, inputs, conditioning, method=model.encode)
115 |
116 | timesteps = model.timesteps
117 | tn = jnp.ceil(t * timesteps).astype('int32')
118 | t = tn / timesteps
119 | rng, sub_rng = jax.random.split(rng)
120 | zt = model.apply(
121 | variables, f, t, method=model.forward_sample, rngs={'sample': sub_rng}
122 | )
123 | rng, sub_rng = jax.random.split(rng)
124 |
125 | def body_fn(i, zt):
126 | return model.apply(
127 | variables,
128 | sub_rng,
129 | i,
130 | timesteps,
131 | zt,
132 | conditioning=conditioning,
133 | method=model.sample_step,
134 | )
135 |
136 | z0 = jax.lax.fori_loop(
137 | lower=timesteps - tn, upper=timesteps, body_fun=body_fn, init_val=zt
138 | )
139 | return model.apply(
140 | variables,
141 | z0,
142 | conditioning=conditioning,
143 | method=model.decode,
144 | rngs={'sample': rng},
145 | )
146 |
--------------------------------------------------------------------------------
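Annotation: a worked example of the time discretization in `reconstruct` above. The continuous noise level `t` is rounded up to the step grid, and the reverse loop then runs only the last `tn` of the `timesteps` denoising steps (the values below are illustrative):

import jax.numpy as jnp

timesteps = 1000
t = 0.3141
tn = jnp.ceil(t * timesteps).astype('int32')  # 315 grid steps of noise
t_on_grid = tn / timesteps                    # 0.315, the level actually sampled
# fori_loop(lower=timesteps - tn, upper=timesteps, ...) -> 315 denoising steps
assert int(tn) == 315 and int(timesteps - tn) == 685
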
/md4/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Methods for training MD4/GenMD4 on text/image datasets."""
17 |
18 | from collections.abc import Callable, Mapping, Sequence
19 | import copy
20 | import functools
21 | from typing import Any
22 |
23 | from absl import logging
24 | from clu import metric_writers
25 | from clu import metrics
26 | from clu import parameter_overview
27 | from clu import periodic_actions
28 | from etils import epath
29 | import flax
30 | import flax.jax_utils as flax_utils
31 | import flax.linen as nn
32 | import grain.python as grain
33 | import jax
34 | from jax.experimental import checkify
35 | import jax.numpy as jnp
36 | import ml_collections
37 | import numpy as np
38 | import optax
39 | from orbax import checkpoint as orbax_checkpoint
40 |
41 | from md4 import input_pipeline
42 | from md4 import input_pipeline_v2
43 | from md4 import sampling
44 | from md4 import utils
45 | from md4.models import utils as model_utils
46 |
47 |
48 | @flax.struct.dataclass
49 | class TrainState:
50 | """State of the model and the training.
51 |
52 |   This includes the parameters, EMA parameters, mutable model state and
53 |   optimizer state.
53 | """
54 |
55 | rng: jnp.ndarray
56 | step: int
57 | params: Any
58 | ema_params: Any
59 | opt_state: optax.OptState
60 | state: Any
61 |
62 |
63 | def merge_batch_stats(replicated_state: TrainState) -> TrainState:
64 | """Merge model batch stats."""
65 | if jax.tree.leaves(replicated_state.state):
66 | cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, "batch"), "batch")
67 | return replicated_state.replace(
68 | state=cross_replica_mean(replicated_state.state)
69 | )
70 | else:
71 | return replicated_state
72 |
73 |
74 | def _get_checkpoint_manager(
75 | config: ml_collections.ConfigDict, workdir: epath.PathLike
76 | ) -> orbax_checkpoint.CheckpointManager:
77 | """Loads the orbax checkpoint manager for train state and data iterator."""
78 | # The keys in this dict should match the keys in `checkpointed_state`.
79 | checkpointers = dict(
80 | train_state=orbax_checkpoint.PyTreeCheckpointer(),
81 | train_iter=orbax_checkpoint.Checkpointer(
82 | grain.PyGrainCheckpointHandler()
83 | ), # pytype:disable=wrong-arg-types
84 | )
85 | checkpoint_dir = epath.Path(workdir) / "checkpoints"
86 | keep_period = (
87 | config.checkpoint_keep_period
88 | if config.checkpoint_keep_period > 0
89 | else None
90 | )
91 | return orbax_checkpoint.CheckpointManager(
92 | checkpoint_dir,
93 | checkpointers=checkpointers,
94 | options=orbax_checkpoint.CheckpointManagerOptions(
95 | max_to_keep=5, create=True, keep_period=keep_period
96 | ),
97 | )
98 |
99 |
100 | def create_train_state(
101 | config: ml_collections.ConfigDict,
102 | rng: jnp.ndarray,
103 | input_shape: Sequence[int] | Mapping[str, Sequence[int]],
104 | schedule_fn: Callable[[Any], Any],
105 | ) -> tuple[nn.Module, optax.GradientTransformation, TrainState, Any]:
106 | """Create and initialize the model."""
107 | model = model_utils.get_model(config)
108 |
109 | if config.classes > 0:
110 | conditioning = jnp.zeros(input_shape[0], dtype="int32")
111 | else:
112 | conditioning = None
113 | rng, sample_rng, init_rng = jax.random.split(rng, 3)
114 | dummy_input = jnp.ones(input_shape, dtype="int32")
115 |
116 | output, variables = model.init_with_output(
117 | {"sample": sample_rng, "params": init_rng},
118 | dummy_input,
119 | cond=conditioning,
120 | train=False,
121 | )
122 | metric_keys = sorted(list(output.keys()) + ["learning_rate"])
123 | logging.info("metric_keys: %s", metric_keys)
124 | metrics_class = create_metrics_class_from_keys(metric_keys)
125 | state, params = flax.core.pop(variables, "params")
126 | del variables
127 | parameter_overview.log_parameter_overview(
128 | state, msg="############# state #############"
129 | )
130 | parameter_overview.log_parameter_overview(
131 | params, msg="############# params #############"
132 | )
133 |
134 | optimizer = optax.chain(
135 | optax.clip(config.clip) if config.clip > 0.0 else optax.identity(),
136 | optax.adamw(
137 | schedule_fn,
138 | b1=0.9,
139 | b2=config.b2,
140 | weight_decay=config.weight_decay,
141 | ),
142 | )
143 | return (
144 | model,
145 | optimizer,
146 | TrainState(
147 | step=0,
148 | rng=rng,
149 | params=params,
150 | ema_params=copy.deepcopy(params) if config.ema_rate > 0.0 else None,
151 | opt_state=optimizer.init(params),
152 | state=state,
153 | ),
154 | metrics_class,
155 | )
156 |
157 |
158 | def create_metrics_class_from_keys(metric_keys):
159 | """Create train/eval metrics collection from dictionary."""
160 | average_keys = []
161 | stats = dict(
162 | (k, metrics.Average.from_output(k))
163 | if (k in average_keys) or ("loss" in k)
164 | else (k, metrics.LastValue.from_output(k))
165 | for k in metric_keys
166 | )
167 | return metrics.Collection.create(**stats)
168 |
169 |
170 | def cosine_decay(lr: float, current_step: float, total_steps: float) -> float:
171 | ratio = jnp.maximum(0.0, current_step / total_steps)
172 | mult = 0.5 * (1.0 + jnp.cos(jnp.pi * ratio))
173 | return mult * lr # pytype: disable=bad-return-type # jax-types
174 |
175 |
176 | def get_learning_rate(
177 | step: int,
178 | *,
179 | base_learning_rate: float,
180 | num_steps: int,
181 | warmup_steps: int | None = None,
182 | schedule_type: str = "cosine",
183 | ) -> float:
184 | """Cosine learning rate schedule."""
185 | logging.info(
186 | "get_learning_rate(step=%s, base_learning_rate=%s, num_steps=%s",
187 | step,
188 | base_learning_rate,
189 | num_steps,
190 | )
191 | warmup = jnp.minimum(1.0, step / warmup_steps)
192 | if schedule_type == "cosine":
193 | lr = cosine_decay(
194 | base_learning_rate, step - warmup_steps, num_steps - warmup_steps
195 | )
196 | elif schedule_type == "constant":
197 | lr = base_learning_rate
198 | else:
199 | raise NotImplementedError()
200 | return lr * warmup # pytype: disable=bad-return-type # jax-types
201 |
202 |
203 | def loss_fn(params, state, rng, model, batch, train=False):
204 | """Loss function."""
205 | rng, sample_rng = jax.random.split(rng)
206 | rngs = {"sample": sample_rng}
207 | if train:
208 | _, dropout_rng = jax.random.split(rng)
209 | rngs["dropout"] = dropout_rng
210 |
211 | variables = {"params": params, **state}
212 | if "image" in batch:
213 | x = batch["image"]
214 | elif "text" in batch:
215 | x = batch["text"]
216 | else:
217 | raise ValueError("Unsupported targets/tasks.")
218 |
219 | if "label" in batch:
220 | conditioning = batch["label"].astype("int32")
221 | else:
222 | conditioning = None
223 |
224 | new_state = {}
225 | if train:
226 | metrics_dict, new_state = model.apply(
227 | variables,
228 | x,
229 | cond=conditioning,
230 | train=train,
231 | rngs=rngs,
232 | mutable=list(state.keys()),
233 | )
234 | else:
235 | metrics_dict = model.apply(
236 | variables, x, cond=conditioning, train=train, rngs=rngs
237 | )
238 |
239 | loss = metrics_dict["loss"]
240 | if train:
241 | return loss, (new_state, metrics_dict)
242 | return loss, metrics_dict
243 |
244 |
245 | @jax.jit
246 | def merge_metrics(a_tree, b_tree):
247 | return jax.tree.map(lambda a, b: a + b, a_tree, b_tree)
248 |
249 |
250 | def train_step(
251 | model: nn.Module,
252 | optimizer: optax.GradientTransformation,
253 | train_state: TrainState,
254 | batch: Mapping[str, jnp.ndarray],
255 | learning_rate_fn: Callable[[int], float],
256 | train_metrics_class: Any,
257 | ema_rate: float = 0.0,
258 | num_microbatches: int | None = None,
259 | ) -> tuple[TrainState, metrics.Collection]:
260 | """Perform a single training step."""
261 | logging.info("train_step(batch=%s)", batch)
262 | rng, new_rng = jax.random.split(train_state.rng)
263 | rng = jax.random.fold_in(rng, jax.lax.axis_index("batch"))
264 |
265 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
266 | if num_microbatches is None or num_microbatches <= 1:
267 | (_, (new_state, metrics_dict)), grads = grad_fn(
268 | train_state.params, train_state.state, rng, model, batch, train=True
269 | )
270 | else:
271 | batch_size = next(iter(batch.values())).shape[0]
272 | assert (
273 | batch_size % num_microbatches == 0
274 | ), "Batch size isn't divided evenly by num_microbatches."
275 | microbatch_size = batch_size // num_microbatches
276 | logging.info(
277 | "using microbatches: %d microbatches, %d size",
278 | num_microbatches,
279 | microbatch_size,
280 | )
281 |
282 | def get_microbatch(
283 | batch: Mapping[str, jnp.ndarray], idx: int
284 | ) -> Mapping[str, jnp.ndarray]:
285 | """Fetch microbatch slice from possibly-packed input data."""
286 | offset = idx * microbatch_size
287 | length = microbatch_size
288 | starts = {k: [offset] + [0] * (b.ndim - 1) for k, b in batch.items()}
289 | limits = {k: [length] + list(b.shape[1:]) for k, b in batch.items()}
290 | return {
291 | k: jax.lax.dynamic_slice(b, starts[k], limits[k])
292 | for k, b in batch.items()
293 | }
294 |
295 | def metrics_and_grad(loop_cnt, rng, train_state_state):
296 | _, mbrng = jax.random.split(rng)
297 | mb = get_microbatch(batch, loop_cnt)
298 |
299 | (_, (new_state, metrics_dict)), grads = grad_fn(
300 | train_state.params, train_state_state, mbrng, model, mb, train=True
301 | )
302 | return metrics_dict, grads, new_state
303 |
304 | def per_microbatch_train_step(loop_cnt, carry):
305 | (rng, grad_accum, prev_metrics_dict, train_state_state) = carry
306 | metrics_dict, grads, train_state_state = metrics_and_grad(
307 | loop_cnt, rng, train_state_state
308 | )
309 |
310 | grad_accum = jax.tree.map(jnp.add, grad_accum, grads)
311 | metrics_dict = jax.lax.cond(
312 | loop_cnt == 0,
313 | lambda _: metrics_dict,
314 | lambda _: merge_metrics(prev_metrics_dict, metrics_dict),
315 | None,
316 | )
317 | return rng, grad_accum, metrics_dict, train_state_state
318 |
319 | # Initialize gradient accumulation loop state.
320 | accum_dtype = jnp.float32
321 | grad_accum_init = jax.tree.map(
322 | lambda x: jnp.zeros(x.shape, accum_dtype), train_state.params
323 | )
324 | initial_metrics_shape, _, _ = jax.eval_shape(
325 | metrics_and_grad,
326 | loop_cnt=0,
327 | rng=rng,
328 | train_state_state=train_state.state,
329 | )
330 |
331 | initial_metrics = {
332 | k: jnp.zeros(shape=v.shape, dtype=v.dtype)
333 | for k, v in initial_metrics_shape.items()
334 | }
335 |
336 | loop_init = (
337 | rng,
338 | grad_accum_init,
339 | initial_metrics,
340 | train_state.state,
341 | )
342 | _, grads, metrics_dict, train_state_state = jax.lax.fori_loop(
343 | 0, num_microbatches, per_microbatch_train_step, loop_init
344 | )
345 | metrics_dict = jax.tree.map(lambda x: x / num_microbatches, metrics_dict)
346 | new_state = train_state_state
347 |
348 | # Compute average gradient across multiple workers.
349 | grads = jax.lax.pmean(grads, axis_name="batch")
350 | updates, new_opt_state = optimizer.update(
351 | grads, train_state.opt_state, train_state.params
352 | )
353 | new_params = optax.apply_updates(train_state.params, updates)
354 | if ema_rate > 0.0:
355 | new_ema_params = jax.tree_util.tree_map(
356 | lambda x, y: x + (1.0 - ema_rate) * (y - x),
357 | train_state.ema_params,
358 | new_params,
359 | )
360 | else:
361 | new_ema_params = None
362 | new_train_state = train_state.replace(
363 | step=train_state.step + 1,
364 | rng=new_rng,
365 | params=new_params,
366 | ema_params=new_ema_params,
367 | opt_state=new_opt_state,
368 | state=new_state,
369 | )
370 |
371 | metrics_update = train_metrics_class.gather_from_model_output(
372 | learning_rate=learning_rate_fn(train_state.step),
373 | **metrics_dict,
374 | )
375 | return new_train_state, metrics_update
376 |
377 |
378 | def eval_step(
379 | model: nn.Module,
380 | rng: jnp.ndarray,
381 | train_state: TrainState,
382 | batch: Mapping[str, jnp.ndarray],
383 | eval_metrics_class: Any,
384 | ema_rate: float = 0.0,
385 | ) -> metrics.Collection:
386 | """Compute the metrics for the given model in inference mode."""
387 | logging.info("eval_step(batch=%s)", batch)
388 | rng = jax.random.fold_in(rng, jax.lax.axis_index("batch"))
389 | params = train_state.ema_params if ema_rate > 0.0 else train_state.params
390 |
391 | _, metrics_dict = loss_fn(
392 | params, train_state.state, rng, model, batch, train=False
393 | )
394 | return eval_metrics_class.gather_from_model_output(
395 | learning_rate=0.0, **metrics_dict
396 | )
397 |
398 |
399 | def evaluate(
400 | p_eval_step: Any,
401 | rng: jnp.ndarray,
402 | train_state: TrainState,
403 | eval_loader: grain.DataLoader,
404 | num_eval_steps: int = -1,
405 | ):
406 | """Evaluate the model on the given dataset."""
407 | logging.info("Starting evaluation.")
408 | eval_metrics = None
409 | with utils.StepTraceContextHelper("eval", 0) as trace_context:
410 | # Use `iter` to reset the eval_loader before each evaluation.
411 | for step, batch in enumerate(iter(eval_loader)):
412 | rng, sub_rng = jax.random.split(rng)
413 | sub_rng = flax_utils.replicate(sub_rng)
414 | batch = utils.reshape_batch(batch)
415 | metrics_update = flax_utils.unreplicate(
416 | p_eval_step(rng=sub_rng, train_state=train_state, batch=batch)
417 | )
418 | eval_metrics = (
419 | metrics_update
420 | if eval_metrics is None
421 | else eval_metrics.merge(metrics_update)
422 | )
423 | if num_eval_steps > 0 and step + 1 == num_eval_steps:
424 | break
425 | trace_context.next_step()
426 | if eval_metrics is None:
427 | raise ValueError(f"Eval dataset {eval_loader} was empty.")
428 | return eval_metrics
429 |
430 |
431 | def train_and_evaluate(
432 | config: ml_collections.ConfigDict, workdir: epath.PathLike
433 | ):
434 | """Runs a training and evaluation loop.
435 |
436 | Args:
437 | config: Configuration to use.
438 | workdir: Working directory for checkpoints and TF summaries. If this
439 | contains checkpoint training will be resumed from the latest checkpoint.
440 | """
441 | workdir = epath.Path(workdir)
442 | workdir.mkdir(parents=True, exist_ok=True)
443 |
444 | rng = utils.get_rng(config.seed)
445 | logging.info("Using random seed %s.", rng)
446 | writer = metric_writers.create_default_writer(
447 | workdir, just_logging=jax.process_index() > 0
448 | )
449 | # Learning rate schedule.
450 | assert config.batch_size % jax.device_count() == 0
451 | per_device_batch_size = config.batch_size // jax.device_count()
452 | num_train_steps = input_pipeline.get_num_train_steps(config)
453 | steps_per_epoch = num_train_steps // config.num_epochs
454 | logging.info(
455 | "num_train_steps=%d, steps_per_epoch=%d", num_train_steps, steps_per_epoch
456 | )
457 | schedule_fn = functools.partial(
458 | get_learning_rate,
459 | base_learning_rate=config.learning_rate,
460 | num_steps=num_train_steps,
461 | warmup_steps=config.warmup_steps,
462 | schedule_type=config.learning_rate_schedule,
463 | )
464 |
465 | # Build input pipeline.
466 | rng, data_seed = jax.random.split(rng)
467 | data_seed = int(
468 | jax.random.randint(data_seed, [], minval=0, maxval=np.iinfo(np.int32).max)
469 | )
470 | # The input pipeline runs on each process and loads data for local TPUs.
471 | create_datasets = (
472 | input_pipeline_v2.create_datasets
473 | if config.get("use_v2_input_pipeline", None)
474 | else input_pipeline.create_datasets
475 | )
476 | train_loader, eval_loaders, dataset_info = create_datasets(config, data_seed)
477 | train_iter = iter(train_loader)
478 |
479 | # Initialize model.
480 | rng, model_rng = jax.random.split(rng)
481 | data_shape = input_pipeline.get_data_shape(config)
482 | model, optimizer, train_state, metrics_class = create_train_state( # pylint: disable=invalid-name
483 | config,
484 | model_rng,
485 | input_shape=(per_device_batch_size,) + data_shape,
486 | schedule_fn=schedule_fn,
487 | )
488 |
489 | # Set up checkpointing of the model and the input pipeline.
490 | checkpoint_manager = _get_checkpoint_manager(config, workdir)
491 |
492 | # Retrieve data from previous checkpoints if possible.
493 | checkpointed_state = dict(train_state=train_state, train_iter=train_iter)
494 | if checkpoint_manager.latest_step() is not None:
495 | checkpointed_state = checkpoint_manager.restore(
496 | checkpoint_manager.latest_step(), items=checkpointed_state
497 | )
498 | train_state = checkpointed_state["train_state"]
499 | train_iter = checkpointed_state["train_iter"]
500 |
501 | # Distribute training.
502 | train_state = flax_utils.replicate(train_state)
503 | train_step_func = functools.partial(
504 | train_step,
505 | model=model,
506 | optimizer=optimizer,
507 | train_metrics_class=metrics_class,
508 | learning_rate_fn=schedule_fn,
509 | ema_rate=config.ema_rate,
510 | num_microbatches=config.get("num_microbatches", None),
511 | )
512 | if config.check_nans:
513 | train_step_func = checkify.checkify(
514 | train_step_func, errors=checkify.float_checks
515 | )
516 | p_train_step = jax.pmap(
517 | train_step_func, axis_name="batch", donate_argnums=(0,)
518 | )
519 | p_eval_step = jax.pmap(
520 | functools.partial(
521 | eval_step,
522 | model=model,
523 | eval_metrics_class=metrics_class,
524 | ema_rate=config.ema_rate,
525 | ),
526 | axis_name="batch",
527 | )
528 |
529 | hooks = []
530 | report_progress = periodic_actions.ReportProgress(
531 | num_train_steps=num_train_steps, writer=writer
532 | )
533 | if jax.process_index() == 0:
534 | hooks += [
535 | report_progress,
536 | periodic_actions.Profile(num_profile_steps=5, logdir=workdir),
537 | ]
538 | train_metrics = None
539 |
540 | # Unreplicating from TPU is costly, so we only do it once at the start.
541 | initial_step = int(flax.jax_utils.unreplicate(train_state.step))
542 |
543 | with metric_writers.ensure_flushes(writer):
544 | # Steps are in interval [1, num_train_steps], not [0, num_train_steps - 1].
545 | for step in range(initial_step + 1, num_train_steps + 1):
546 | is_last_step = step == num_train_steps
547 |
548 | with jax.profiler.StepTraceAnnotation("train", step_num=step):
549 | batch = utils.reshape_batch(next(train_iter))
550 | if config.check_nans:
551 | errs, (train_state, metrics_update) = p_train_step(
552 | train_state=train_state, batch=batch
553 | )
554 | errs.throw()
555 | else:
556 | train_state, metrics_update = p_train_step(
557 | train_state=train_state, batch=batch
558 | )
559 | metric_update = flax_utils.unreplicate(metrics_update)
560 | train_metrics = (
561 | metric_update
562 | if train_metrics is None
563 | else train_metrics.merge(metric_update)
564 | )
565 |
566 | # Quick indication that training is happening.
567 | logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
568 | for h in hooks:
569 | h(step)
570 |
571 | if step % config.log_loss_every_steps == 0 or is_last_step:
572 | writer.write_scalars(step, train_metrics.compute())
573 | train_metrics = None
574 |
575 | if step == 1 or step % config.eval_every_steps == 0 or is_last_step:
576 | for split, eval_loader in eval_loaders.items():
577 | rng, eval_rng = jax.random.split(rng)
578 | with report_progress.timed("eval"):
579 | train_state = merge_batch_stats(train_state)
580 | eval_metrics = evaluate(
581 | p_eval_step,
582 | eval_rng,
583 | train_state,
584 | eval_loader,
585 | config.num_eval_steps,
586 | )
587 | eval_metrics_cpu = jax.tree_util.tree_map(
588 | np.array, eval_metrics.compute()
589 | )
590 | eval_metrics_cpu = {
591 | split + "_" + k: v for k, v in eval_metrics_cpu.items()
592 | }
593 | writer.write_scalars(step, eval_metrics_cpu)
594 |
595 | if hasattr(model, "sample_step"):
596 | with report_progress.timed("sample"):
597 | _, sample_rng = jax.random.split(rng)
598 | dummy_loader = train_loader
599 | dummy_batch = utils.reshape_batch(next(iter(dummy_loader)))
600 | dummy_inputs = dummy_batch[config.task_type]
601 | if "label" in dummy_batch:
602 | conditioning = dummy_batch["label"].astype("int32")
603 | else:
604 | conditioning = None
605 |
606 | samples = sampling.generate(
607 | model,
608 | train_state,
609 | flax_utils.replicate(sample_rng),
610 | dummy_inputs,
611 | conditioning=conditioning,
612 | )
613 |
614 | all_samples = jax.pmap(
615 | lambda x: jax.lax.all_gather(x, "batch"), axis_name="batch"
616 | )(samples)
617 | all_samples = flax_utils.unreplicate(all_samples)
618 | all_samples = all_samples.reshape(-1, *data_shape)
619 | if config.task_type == "image":
620 | sample_grid = utils.generate_image_grids(all_samples)
621 | writer.write_images(step, {"samples": sample_grid})
622 | del all_samples, sample_grid
623 | elif config.task_type == "text":
624 | tokenizer = dataset_info["tokenizer"]
625 | texts = utils.detokenize_texts(all_samples, tokenizer)
626 | writer.write_texts(step, {"samples": texts})
627 |
628 | if step % config.checkpoint_every_steps == 0 or is_last_step:
629 | with report_progress.timed("checkpoint"):
630 | train_state = merge_batch_stats(train_state)
631 | checkpoint_manager.save(
632 | step,
633 | items=dict(
634 | train_state=jax.tree_util.tree_map(
635 | np.array, flax_utils.unreplicate(train_state)
636 | ),
637 | train_iter=train_iter,
638 | ),
639 | )
640 |
641 | logging.info("Finishing training at step %d", num_train_steps)
642 |
--------------------------------------------------------------------------------
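Annotation: one detail of `train_step` above worth spelling out. The EMA update `x + (1 - ema_rate) * (y - x)` is the usual convex combination `ema_rate * x + (1 - ema_rate) * y`, rewritten so each parameter leaf needs a single multiply; a quick numeric check:

ema_rate, x, y = 0.999, 2.0, 4.0
lhs = x + (1.0 - ema_rate) * (y - x)
rhs = ema_rate * x + (1.0 - ema_rate) * y
assert abs(lhs - rhs) < 1e-12
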
/md4/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utils."""
17 |
18 | from collections.abc import Mapping
19 | import math
20 | import time
21 | from typing import Any, Tuple
22 | from absl import logging
23 | import chex
24 | from clu import platform
25 | import distrax
26 | import jax
27 | import jax.numpy as jnp
28 | import matplotlib.pyplot as plt
29 | import numpy as np
30 | from orbax import checkpoint as orbax_checkpoint
31 | import seaborn as sns
32 |
33 |
34 | def loss2bpt(loss_dict, data_shape):
35 | """Normalize loss to bits per token."""
36 | seq_len = jnp.prod(jnp.array(data_shape))
37 |   rescale_to_bpt = 1.0 / (seq_len * jnp.log(2.0))
38 | bpt_loss_dict = {}
39 | for k, v in loss_dict.items():
40 | if "loss" in k:
41 |       bpt_loss_dict[k] = v * rescale_to_bpt
42 | else:
43 | bpt_loss_dict[k] = v
44 | return bpt_loss_dict
45 |
46 |
47 | def constant_init(value, dtype="float32"):
48 | def _init(key, shape, dtype=dtype):
49 | del key
50 | return value * jnp.ones(shape, dtype)
51 |
52 | return _init
53 |
54 |
55 | def _logistic_pdf_fn(z, log_scales):
56 | return -z - 2.0 * jax.nn.softplus(-z) - log_scales
57 |
58 |
59 | class DiscretizedLogisticMixture(distrax.Distribution):
60 | """Discretized mixture of Logistics defined in PixelCNN++."""
61 |
62 | def __init__(
63 | self,
64 | w_logits,
65 | locs,
66 | log_scales,
67 | min_val=0.0,
68 | max_val=255.0,
69 | bin_size=1.0,
70 | ):
71 | self._w_logits = w_logits
72 | self._locs = locs
73 | self._log_scales = log_scales
74 | self._min_val = min_val
75 | self._max_val = max_val
76 | self._bin_size = bin_size
77 |
78 | self._batch_shape = jax.lax.broadcast_shapes(
79 | self._locs.shape[:-1], self._log_scales.shape[:-1]
80 | )
81 |
82 | self._cdf = jax.nn.sigmoid
83 | self._log_prob_fn = _logistic_pdf_fn
84 |
85 | @property
86 | def event_shape(self) -> Tuple[int, ...]:
87 | """Shape of event of distribution samples."""
88 | return ()
89 |
90 | @property
91 | def batch_shape(self) -> Tuple[int, ...]:
92 | return self._batch_shape
93 |
94 | @property
95 | def _n_components(self) -> int:
96 | return self._locs.shape[-1]
97 |
98 | def _mean(self):
99 |     # The mean of the mixture is the sum of the component means weighted
100 |     # by their mixing probabilities.
101 |     # Apply softmax on the logits to get those probabilities.
102 | probs = jax.nn.softmax(self._w_logits)
103 | return jnp.sum(self._locs * probs, axis=-1)
104 |
105 | def log_prob(self, event: chex.Array) -> chex.Array:
106 | # expand the mixture dim
107 | event = jnp.expand_dims(event, -1)
108 | assert len(self._locs.shape) == len(event.shape)
109 |
110 | # Expand the dimensions of the params for tiling and broadcasting.
111 | locs = self._locs
112 | w_logits = self._w_logits
113 | inv_scales = jnp.exp(-self._log_scales)
114 |
115 |     # Log-pdf at the bin midpoint, used when the bin probability is too small.
116 | z = (event - locs) * inv_scales
117 | mid_log_prob = self._log_prob_fn(z, self._log_scales)
118 |
119 | # Calculate difference of sigmoid.
120 | half_bin = self._bin_size / 2
121 | b = (event - locs + half_bin) * inv_scales
122 | a = (event - locs - half_bin) * inv_scales
123 | a = self._cdf(a)
124 | b = self._cdf(b)
125 | diff = b - a
126 |
127 |     # Handle the edge bins: the first and last bins absorb the tails.
128 | edge_b = (
129 | jax.nn.sigmoid((self._min_val - locs + half_bin) * inv_scales) - 0.0
130 | )
131 | edge_a = 1.0 - jax.nn.sigmoid(
132 | (self._max_val - locs - half_bin) * inv_scales
133 | )
134 | diff = jnp.where(event > self._max_val - half_bin, edge_a, diff)
135 | diff = jnp.where(event < self._min_val + half_bin, edge_b, diff)
136 |
137 | # Avoid small values for the subsequent log operation.
138 | diff = jnp.maximum(diff, 1e-12)
139 | log_prob = jnp.where(diff > 1e-8, jnp.log(diff), mid_log_prob)
140 |
141 | # Normalize logits.
142 | w_logits -= jax.nn.logsumexp(w_logits, axis=-1, keepdims=True)
143 |
144 | # Total loss, summed over the mixture dimension.
145 | total = w_logits + log_prob.sum(axis=-2)
146 | total = jax.nn.logsumexp(total, -1)
147 | return total
148 |
149 | def _sample_n(self, key: chex.PRNGKey, n: int):
150 | # First sample from the mixing weights.
151 | w_dist = distrax.Categorical(logits=self._w_logits)
152 | key, sub_key = jax.random.split(key)
153 | index = w_dist.sample(seed=sub_key, sample_shape=n)
154 | index = jax.nn.one_hot(index, num_classes=self._n_components)
155 |
156 |     # Pick the mixture component (per pixel) and broadcast to n samples.
157 | log_scales = jnp.sum(jnp.expand_dims(self._log_scales, 0) * index, -1)
158 | loc = jnp.sum(jnp.expand_dims(self._locs, 0) * index, -1)
159 | scales = jnp.exp(log_scales)
160 |
161 |     # Sample from the selected logistic distribution.
162 | _, sub_key = jax.random.split(key)
163 | logistic_sample = jax.random.logistic(sub_key, shape=loc.shape)
164 | sample_values = loc + scales * logistic_sample
165 |
166 | return jnp.clip(jnp.round(sample_values), 0, 255).astype("int32")
167 |
168 |
169 | def shifted_softplus(x, b=1.0):
170 | """log(exp(x) + b)."""
171 | return x + jax.nn.softplus(jnp.log(b) - x)
172 |
173 |
174 | def reverse_broadcast(value, ndim):
175 | """Broadcast by adding singleton axes to the right, instead of to the left."""
176 | if value.ndim > ndim:
177 | raise ValueError(
178 | f"Cannot reverse broadcast a value with {value.ndim} dimensions "
179 | f"to {ndim} dimensions."
180 | )
181 |
182 | if value.ndim < ndim:
183 | difference = ndim - value.ndim
184 | return value.reshape(value.shape + difference * (1,))
185 | else:
186 | return value
187 |
188 |
189 | def get_rng(seed: None | int | tuple[int, int]) -> np.ndarray:
190 | """Returns a JAX RNGKey."""
191 | if seed is None:
192 | # Case 1: No random seed given, use XManager ID.
193 | # All processes (and restarts) get exactly the same seed but every work unit
194 | # and experiment is different.
195 | work_unit = platform.work_unit()
196 | rng = (work_unit.experiment_id, work_unit.id)
197 | elif isinstance(seed, int):
198 | # Case 2: Single integer given.
199 | rng = (0, seed)
200 | else:
201 | # Case 3: tuple[int, int] given.
202 | if not isinstance(seed, (tuple, list)) or len(seed) != 2:
203 | raise ValueError(
204 | "Random seed must be an integer or tuple of 2 integers "
205 | f"but got {seed!r}"
206 | )
207 | rng = seed
208 | # JAX RNGKeys are arrays of np.uint32 and shape [2].
209 | return np.asarray(rng, dtype=np.uint32)
210 |
211 |
212 | class StepTraceContextHelper:
213 | """Helper class to use jax.profiler.StepTraceAnnotation."""
214 |
215 | def __init__(self, name: str, init_step_num: int):
216 | self.name = name
217 | self.step_num = init_step_num
218 | self.context = None
219 |
220 | def __enter__(self):
221 | self.context = jax.profiler.StepTraceAnnotation(
222 | self.name, step_num=self.step_num
223 | )
224 | self.step_num += 1
225 | self.context.__enter__()
226 | return self
227 |
228 | def __exit__(self, exc_type, exc_value, tb):
229 | assert self.context is not None, "Exited context without entering."
230 | self.context.__exit__(exc_type, exc_value, tb)
231 | self.context = None
232 |
233 | def next_step(self):
234 | if self.context is None:
235 | raise ValueError("Must call next_step() within a context.")
236 | self.__exit__(None, None, None)
237 | self.__enter__()
238 |
239 |
240 | def plot_embeddings(step, workdir, embeddings, annotations=None):
241 | """Helper function to plot embeddings."""
242 | fig, ax = plt.subplots()
243 | ax.set_title("Embeddings")
244 | if embeddings.ndim == 1:
245 | ax.scatter(np.arange(256), embeddings)
246 | else:
247 | assert embeddings.ndim == 2
248 | colors = np.linspace(0, 1, embeddings.shape[0])
249 | ax.scatter(embeddings[:, 0], embeddings[:, 1], c=colors, cmap="rainbow")
250 | if annotations:
251 | for i in range(embeddings.shape[0]):
252 | ax.annotate(annotations[i], (embeddings[i, 0], embeddings[i, 1]))
253 | embedding_plot = workdir / "embedding_{}.png".format(step)
254 | with embedding_plot.open("wb") as f:
255 | fig.savefig(f)
256 |
257 |
258 | def plot_heatmap(step, workdir, emb_gram_matrix, token_labels=None):
259 | """Helper function to plot embeddings."""
260 | fig, ax = plt.subplots()
261 | assert emb_gram_matrix.ndim == 2
262 | if token_labels:
263 | _ = sns.heatmap(emb_gram_matrix, linewidth=0.5, ax=ax,
264 | xticklabels=token_labels,
265 | yticklabels=token_labels)
266 | else:
267 | _ = sns.heatmap(emb_gram_matrix, linewidth=0.5, ax=ax)
268 |
269 | plt.xticks(rotation=90)
270 | plt.yticks(rotation=0)
271 |
272 | heatmap_plot = workdir / "embedding_heatmap_{}.png".format(step)
273 | with heatmap_plot.open("wb") as f:
274 | fig.savefig(f)
275 |
276 |
277 | def generate_image_grids(images):
278 | """Simple helper to generate a single image from a mini batch."""
279 |
280 | def image_grid(nrow, ncol, imagevecs, imshape):
281 | images = iter(imagevecs.reshape((-1,) + imshape))
282 | return jnp.vstack([
283 | jnp.hstack([next(images) for _ in range(ncol)][::-1])
284 | for _ in range(nrow)
285 | ])
286 |
287 | batch_size = images.shape[0]
288 | grid_size = int(np.floor(np.sqrt(batch_size)))
289 |
290 | image_shape = images.shape[1:]
291 | return image_grid(
292 | nrow=grid_size,
293 | ncol=grid_size,
294 | imagevecs=images[0 : grid_size**2],
295 | imshape=image_shape,
296 | ).astype("uint8")
297 |
298 |
299 | def detokenize_texts(tokens, tokenizer):
300 | """Detokenize the outputs."""
301 |
302 | assert len(tokens.shape) == 2, "Invalid token shape."
303 |
304 | np_tokens = np.asarray(tokens)
305 | detokenized = np.apply_along_axis(tokenizer.decode, -1, np_tokens)
306 |
307 | return detokenized
308 |
309 |
310 | def get_topk_token_mask(tokenizer, k=100):
311 | """Get the indices of Top-K tokens."""
312 |
313 | id_unigram_scores = [
314 | (id_, math.exp(tokenizer._model.GetScore(id_))) # pylint: disable=protected-access
315 | for id_ in range(tokenizer.vocab_size)]
316 |
317 | id_unigram_scores_sorted = sorted(id_unigram_scores,
318 | key=lambda x: x[1])
319 |
320 |   # Take exactly the k highest-scoring elements.
321 | topk = id_unigram_scores_sorted[-k:]
322 |
323 | topk_mask = [False for _ in range(tokenizer.vocab_size)]
324 | topk_ids = []
325 |
326 | for id_, _ in topk:
327 | topk_mask[id_] = True
328 | topk_ids.append(id_)
329 |
330 | topk_tokens = tokenizer.to_string_list(np.array(topk_ids))
331 |
332 | return np.array(topk_mask), topk_tokens
333 |
334 |
335 | def reshape_batch(batch: Mapping[str, Any]) -> Mapping[str, np.ndarray]:
336 | """Reshapes a batch to have the leading dimension for the local devices."""
337 | leading_dims = [jax.local_device_count(), -1]
338 | return jax.tree_util.tree_map(
339 | lambda x: np.reshape(x, leading_dims + list(x.shape[1:])), batch
340 | )
341 |
342 |
343 | def checkpoints_iterator(
344 | ckpt_manager, timeout=None, min_interval_secs=0, period=10000
345 | ):
346 | """Repeatedly yield new checkpoints as they appear.
347 |
348 | Args:
349 | ckpt_manager: CheckpointManager object.
350 | timeout: int: maximum number of seconds to wait. If left as `None`, then the
351 | process will wait indefinitely.
352 | min_interval_secs: int: minimum number of seconds between yielding
353 | checkpoints.
354 |     period: Only yield checkpoint steps that are a multiple of this period.
355 |
356 | Yields:
357 | new checkpoint step.
358 | """
359 | last_step = None
360 | while True:
361 | cur_step = wait_for_new_checkpoint(
362 | ckpt_manager, last_step, timeout=timeout, period=period
363 | )
364 | if cur_step is None:
365 | # timed out
366 | logging.info("Timed-out waiting for a checkpoint.")
367 | return
368 | start = time.time()
369 | last_step = cur_step
370 |
371 | yield cur_step
372 |
373 | time_to_next_eval = start + min_interval_secs - time.time()
374 | if time_to_next_eval > 0:
375 | time.sleep(time_to_next_eval)
376 |
377 |
378 | def wait_for_new_checkpoint(
379 | ckpt_manager: orbax_checkpoint.CheckpointManager,
380 | last_step=None,
381 | seconds_to_sleep=1,
382 | timeout=None,
383 | period=10000,
384 | ):
385 | """Waits until a new checkpoint file is found.
386 |
387 | Args:
388 |     ckpt_manager: The CheckpointManager that tracks the checkpoints.
389 |     last_step: The last checkpoint step used or `None` if we're expecting a
390 |       checkpoint for the first time.
391 | seconds_to_sleep: The number of seconds to sleep for before looking for a
392 | new checkpoint.
393 | timeout: The maximum number of seconds to wait. If left as `None`, then the
394 | process will wait indefinitely.
395 |     period: Only return checkpoint steps that are a multiple of this period.
396 |
397 | Returns:
398 |     a new checkpoint step, or None if the timeout was reached.
399 | """
400 | logging.info("Waiting for new checkpoint at %s", ckpt_manager.directory)
401 | stop_time = time.time() + timeout if timeout is not None else None
402 | while True:
403 | ckpt_manager.reload()
404 | cur_step = ckpt_manager.latest_step()
405 | if cur_step is None or cur_step == last_step or cur_step % period != 0:
406 | if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
407 | return None
408 | time.sleep(seconds_to_sleep)
409 | else:
410 | logging.info("Found new checkpoint at step %d", cur_step)
411 | return cur_step
412 |
--------------------------------------------------------------------------------
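Annotation: a hedged usage sketch of `DiscretizedLogisticMixture` above. The location/scale parameters carry a trailing component axis, the mixing logits omit the second-to-last (e.g. channel) axis, which `log_prob` sums over, and each logistic is integrated over a width-1 bin. All shapes below are illustrative assumptions:

import jax.numpy as jnp
from md4.utils import DiscretizedLogisticMixture

batch, channels, components = 2, 3, 5
locs = jnp.full((batch, channels, components), 128.0)  # component means
log_scales = jnp.zeros((batch, channels, components))  # unit scales
w_logits = jnp.zeros((batch, components))              # uniform mixture weights
dist = DiscretizedLogisticMixture(w_logits, locs, log_scales)
log_p = dist.log_prob(jnp.full((batch, channels), 128.0))
assert log_p.shape == (batch,)  # channel axis is summed inside log_prob
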
/prepare_openwebtext_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Prepares the input pipeline for OpenWebText (OWT).
17 |
18 | This script tokenizes the OWT dataset and splits it into train and eval sets.
19 | The train and eval sets are saved as ArrayRecord files.
20 | """
21 |
22 | from array_record.python import array_record_module
23 | import datasets
24 | import numpy as np
25 | import tensorflow as tf
26 | import tqdm
27 | import transformers
28 |
29 |
30 | source = datasets.load_dataset(
31 | "Skylion007/openwebtext", name="plain_text", split="train", streaming=True
32 | )
33 |
34 | _GPT2_TOKENIZER = "gpt2"
35 | tokenizer = transformers.GPT2Tokenizer.from_pretrained(_GPT2_TOKENIZER)
36 |
37 | ArrayRecordWriter = array_record_module.ArrayRecordWriter
38 | ArrayRecordReader = array_record_module.ArrayRecordReader
39 |
40 |
41 | def _int64_feature(value):
42 | """Returns an int64_list from a bool / enum / int / uint."""
43 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
44 |
45 |
46 | ds_output_file_train = "./data_dir/openwebtext_splits_1024_train"
47 | ds_output_file_eval = "./data_dir/openwebtext_splits_1024_eval"
48 |
49 | n_examples = 8013769 # tiny: 2; small: 10_000; full: 8013769
50 | save_every_examples = 10_000
51 | block_size = 1024 # size of the chunk
52 |
53 | data_iter = iter(source)
54 |
55 | all_tokens = []
56 | count = 0
57 | count_per_save = 0
58 | eval_chunks = []
59 |
60 | writer_train = ArrayRecordWriter(ds_output_file_train, "group_size:1")
61 | writer_eval = ArrayRecordWriter(ds_output_file_eval, "group_size:1")
62 |
63 | for example in data_iter:
64 | tokens = tokenizer(example["text"])["input_ids"]
65 | all_tokens.extend(tokens + [tokenizer.eos_token_id])
66 | count += 1
67 | count_per_save += 1
68 |
69 |   # Pause to save once enough examples have been tokenized since the last save.
70 | if count_per_save >= save_every_examples:
71 | # save to disk
72 | saved_length = (len(all_tokens) // block_size) * block_size
73 | chunks = [
74 | all_tokens[i : i + block_size]
75 | for i in range(0, saved_length, block_size)
76 | ]
77 |
78 | print(f"\nsaving chunks @ {count}th example mark...")
79 | np.random.shuffle(chunks)
80 | num_eval = int(len(chunks) * 0.02) # put 2% of chunks into eval split.
81 | for eval_i in tqdm.tqdm(range(num_eval)):
82 | feature = {
83 | "text": _int64_feature(chunks[eval_i]),
84 | }
85 | example_proto = tf.train.Example(
86 | features=tf.train.Features(feature=feature)
87 | )
88 | writer_eval.write(example_proto.SerializeToString())
89 |
90 | for train_i in tqdm.tqdm(range(num_eval, len(chunks))):
91 | feature = {
92 | "text": _int64_feature(chunks[train_i]),
93 | }
94 | example_proto = tf.train.Example(
95 | features=tf.train.Features(feature=feature)
96 | )
97 | writer_train.write(example_proto.SerializeToString())
98 |
99 |     # Keep the leftover tokens for the next round of tokenize-and-save.
100 | all_tokens = all_tokens[saved_length:]
101 | count_per_save = 0
102 |
103 |   # Stop once the requested total number of examples has been tokenized.
104 | if count >= n_examples:
105 | # save to disk
106 | saved_length = (len(all_tokens) // block_size) * block_size
107 | chunks = [
108 | all_tokens[i : i + block_size]
109 | for i in range(0, saved_length, block_size)
110 | ]
111 |
112 | print(f"\nsaving chunks @ {count}th example mark...")
113 | np.random.shuffle(chunks)
114 | num_eval = int(len(chunks) * 0.02) # put 2% of chunks into eval split.
115 | for eval_i in tqdm.tqdm(range(num_eval)):
116 | feature = {
117 | "text": _int64_feature(chunks[eval_i]),
118 | }
119 | example_proto = tf.train.Example(
120 | features=tf.train.Features(feature=feature)
121 | )
122 | writer_eval.write(example_proto.SerializeToString())
123 |
124 | for train_i in tqdm.tqdm(range(num_eval, len(chunks))):
125 | feature = {
126 | "text": _int64_feature(chunks[train_i]),
127 | }
128 | example_proto = tf.train.Example(
129 | features=tf.train.Features(feature=feature)
130 | )
131 | writer_train.write(example_proto.SerializeToString())
132 | break
133 |
134 | writer_train.close()
135 | writer_eval.close()
136 |
--------------------------------------------------------------------------------
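Annotation: a hedged sketch of reading the output back, assuming array_record's `ArrayRecordReader.read([...])` indices API. Each record is a serialized `tf.train.Example` whose int64 "text" feature holds one block_size-token chunk:

import tensorflow as tf
from array_record.python import array_record_module

reader = array_record_module.ArrayRecordReader(
    "./data_dir/openwebtext_splits_1024_eval"
)
raw = reader.read([0])[0]  # first serialized record
example = tf.train.Example.FromString(raw)
tokens = example.features.feature["text"].int64_list.value
assert len(tokens) == 1024  # one block_size chunk
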
/requirements_gpu.txt:
--------------------------------------------------------------------------------
1 | clu
2 | datasets
3 | distrax
4 | grain
5 | matplotlib
6 | seaborn
7 | tensorflow
8 | tensorflow-datasets
9 | tf-keras
10 | transformers
11 | # for jax on GPU
12 | --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
13 | jax[cuda]
14 |
--------------------------------------------------------------------------------
/requirements_tpu.txt:
--------------------------------------------------------------------------------
1 | clu
2 | datasets
3 | distrax
4 | grain
5 | matplotlib
6 | seaborn
7 | tensorflow
8 | tensorflow-datasets
9 | tf-keras
10 | transformers
11 | # for jax on TPU
12 | --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
13 | --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
14 | jax[tpu]
15 |
--------------------------------------------------------------------------------
/run_gcp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | export PYTHONPATH="$PYTHONPATH:$(pwd)"
18 | source md4_venv/bin/activate
19 |
20 | EXPT_DIR="$(pwd)"/expt
21 | python md4/main.py \
22 | --config=md4/configs/md4/fineweb_edu.py \
23 | --sharded=false \
24 | --workdir=${EXPT_DIR}
25 |
--------------------------------------------------------------------------------