├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── hessian
│ ├── model_debugger.py
│ ├── model_debugger_callback.py
│ ├── precondition.py
│ ├── test_model_debugger.py
│ └── test_precondition.py
├── init2winit
│ ├── __init__.py
│ ├── base_callback.py
│ ├── callbacks.py
│ ├── checkpoint.py
│ ├── dataset_lib
│ │ ├── __init__.py
│ │ ├── autoaugment.py
│ │ ├── criteo_terabyte_dataset.py
│ │ ├── data_selectors.py
│ │ ├── data_utils.py
│ │ ├── datasets.py
│ │ ├── fake_dataset.py
│ │ ├── fastmri_dataset.py
│ │ ├── image_preprocessing.py
│ │ ├── imagenet_dataset.py
│ │ ├── imagenet_preprocessing.py
│ │ ├── librispeech.py
│ │ ├── librispeech_input_pipeline.py
│ │ ├── lm1b_input_pipeline_v2.py
│ │ ├── lm1b_v2.py
│ │ ├── mlperf_imagenet_dataset.py
│ │ ├── mlperf_input_pipeline.py
│ │ ├── mt_pipeline.py
│ │ ├── mt_pipeline_test.py
│ │ ├── mt_tokenizer.py
│ │ ├── nanodo_c4.py
│ │ ├── nanodo_data_loader_shared.py
│ │ ├── nanodo_fineweb_edu.py
│ │ ├── nqm_noise.py
│ │ ├── ogbg_molpcba.py
│ │ ├── pg19.py
│ │ ├── protein_vocab.py
│ │ ├── proteins.py
│ │ ├── small_image_datasets.py
│ │ ├── spm_tokenizer.py
│ │ ├── test_data_utils.py
│ │ ├── test_datasets.py
│ │ ├── test_ogbg_molpcba.py
│ │ ├── test_small_image_datasets.py
│ │ ├── test_wikitext_tokenizer.py
│ │ ├── translate_wmt.py
│ │ ├── wikitext103.py
│ │ ├── wikitext103_input_pipeline.py
│ │ ├── wikitext103_spm.py
│ │ ├── wikitext2.py
│ │ ├── wikitext2_input_pipeline.py
│ │ └── wikitext_tokenizer.py
│ ├── gradient_statistics_callback.py
│ ├── hyperparameters.py
│ ├── init_lib
│ │ ├── __init__.py
│ │ ├── initializers.py
│ │ ├── meta_init.py
│ │ ├── sparse_init.py
│ │ └── test_initializers.py
│ ├── main.py
│ ├── model_lib
│ │ ├── __init__.py
│ │ ├── adabelief_densenet.py
│ │ ├── adabelief_resnet.py
│ │ ├── adabelief_vgg.py
│ │ ├── attention.py
│ │ ├── autoencoder.py
│ │ ├── base_model.py
│ │ ├── binarize_layers.py
│ │ ├── conformer.py
│ │ ├── convolutional_autoencoder.py
│ │ ├── deepspeech.py
│ │ ├── dlrm.py
│ │ ├── fully_connected.py
│ │ ├── gnn.py
│ │ ├── librispeech_preprocessor.py
│ │ ├── local_attention_transformer.py
│ │ ├── losses.py
│ │ ├── lstm.py
│ │ ├── lstm_lm.py
│ │ ├── max_pooling_cnn.py
│ │ ├── metrics.py
│ │ ├── mlperf_resnet.py
│ │ ├── model_utils.py
│ │ ├── models.py
│ │ ├── nanodo.py
│ │ ├── normalization.py
│ │ ├── nqm.py
│ │ ├── partition_tree.py
│ │ ├── resnet.py
│ │ ├── simple_cnn.py
│ │ ├── spectrum_augmenter.py
│ │ ├── test_local_attention_transformer.py
│ │ ├── test_losses.py
│ │ ├── test_metrics.py
│ │ ├── test_models.py
│ │ ├── test_normalization.py
│ │ ├── transformer_lm.py
│ │ ├── transformer_stu_lm.py
│ │ ├── transformer_stu_tensordot_lm.py
│ │ ├── unet.py
│ │ ├── vit.py
│ │ ├── wide_resnet.py
│ │ ├── xformer_translate.py
│ │ ├── xformer_translate_binary.py
│ │ └── xformer_translate_mlc_variant.py
│ ├── mt_eval
│ │ ├── decode.py
│ │ ├── eval_utils.py
│ │ ├── inference.py
│ │ ├── main.py
│ │ └── mt_callback.py
│ ├── optimizer_lib
│ │ ├── __init__.py
│ │ ├── factor_sam.py
│ │ ├── gradient_accumulator.py
│ │ ├── kitchen_sink
│ │ │ ├── __init__.py
│ │ │ └── _src
│ │ │   ├── alias.py
│ │ │   ├── combine.py
│ │ │   ├── core.py
│ │ │   ├── mask.py
│ │ │   ├── preconditioner.py
│ │ │   ├── test_core.py
│ │ │   ├── test_mask.py
│ │ │   ├── test_preconditioner.py
│ │ │   ├── test_transform.py
│ │ │   ├── transform.py
│ │ │   └── utils.py
│ │ ├── linalg
│ │ │ ├── README.md
│ │ │ ├── low_rank_root_update.py
│ │ │ ├── low_rank_root_update_test.py
│ │ │ ├── paterson_stockmeyer.py
│ │ │ ├── pth_inv_root_rmn.py
│ │ │ ├── pth_inv_root_rmn_coefficients.py
│ │ │ ├── pth_inv_root_rmn_test.py
│ │ │ └── root_selector.py
│ │ ├── muon.py
│ │ ├── online_newton_step.py
│ │ ├── optimizers.py
│ │ ├── pax_adafactor.py
│ │ ├── samuel.py
│ │ ├── search_subspace.py
│ │ ├── sharpness_aware_minimization.py
│ │ ├── sla.py
│ │ ├── sla_test.py
│ │ ├── test_gradient_accumulator.py
│ │ ├── test_optimizers.py
│ │ ├── test_search_subspace.py
│ │ ├── test_utils.py
│ │ └── utils.py
│ ├── projects
│ │ └── optlrschedule
│ │   ├── README.md
│ │   ├── notebook_utils
│ │   │ ├── pandas_util.py
│ │   │ ├── parquet_util.py
│ │   │ ├── plot_util.py
│ │   │ ├── schedule_util.py
│ │   │ └── test_pandas_util.py
│ │   ├── run_search_decoupled.py
│ │   ├── scheduler
│ │   │ ├── base_schedule_family.py
│ │   │ ├── constant_schedule_family.py
│ │   │ ├── cosine_schedule_family.py
│ │   │ ├── cosine_standard_schedule_family.py
│ │   │ ├── piecewise_schedule_family.py
│ │   │ ├── rex_schedule_family.py
│ │   │ ├── schedule_families.py
│ │   │ ├── smooth_nonmonotonic_schedule_family.py
│ │   │ ├── sqrt_schedule_family.py
│ │   │ ├── test_scheduler.py
│ │   │ ├── twopointslinear_schedule_family.py
│ │   │ └── twopointsspline_schedule_family.py
│ │   ├── search_algorithm
│ │   │ ├── coordinate_descent_search.py
│ │   │ ├── grid_search.py
│ │   │ ├── random_search.py
│ │   │ ├── search_algorithms.py
│ │   │ └── test_search_algorithm.py
│ │   └── workload
│ │     ├── base_workload.py
│ │     ├── cifar10_cnn.py
│ │     ├── datasets
│ │     │ └── wikitext_103.py
│ │     ├── linear_regression.py
│ │     ├── optimizers.py
│ │     ├── test_workload.py
│ │     ├── wikitext103_transformer.py
│ │     └── workloads.py
│ ├── references.md
│ ├── schedules.py
│ ├── shared_test_utilities.py
│ ├── test_checkpoint.py
│ ├── test_hyperparameters.py
│ ├── test_schedules.py
│ ├── test_training_metrics_grabber.py
│ ├── test_utils.py
│ ├── testdata
│ │ └── wikitext_tokenizer_fake_data.txt
│ ├── tools
│ │ └── inspect_dataset.py
│ ├── trainer_lib
│ │ ├── base_trainer.py
│ │ ├── test_trainer.py
│ │ ├── trainer.py
│ │ ├── trainer_utils.py
│ │ └── trainers.py
│ ├── training_metrics_grabber.py
│ └── utils.py
└── setup.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file
13 | or to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # init2winit
2 |
3 | A Jax/Flax codebase for running deterministic, scalable, and well-documented deep learning experiments, with a particular emphasis on neural network initialization, optimization, and tuning experiments.
4 |
5 | There is not yet a stable version (nor an official release of this library).
6 | All APIs are subject to change.
7 |
8 | This is a research project, not an official Google product.
9 |
10 |
11 | ## Installation
12 | The current development version requires Python 3.6-3.8.
13 |
14 | To install the latest development version inside a virtual environment, run
15 |
16 | ```
17 | python3 -m venv env-i2w
18 | source env-i2w/bin/activate
19 | pip install --upgrade pip
20 | pip install "git+https://github.com/google/init2winit.git#egg=init2winit"
21 | pip install --upgrade jax jaxlib==0.1.66+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
22 | ```
23 |
24 | where `cuda111` corresponds to the installed version of CUDA. For more Jax install information see the [Jax README](https://github.com/google/jax#installation).
25 |
26 | ## Usage
27 |
28 | An example MNIST experiment can be run with the following command:
29 |
30 | ```sh
31 | python3 main.py \
32 | --experiment_dir=/tmp/test_mnist \
33 | --model=fully_connected \
34 | --dataset=mnist \
35 | --num_train_steps=10
36 | ```
37 |
38 | For local debugging we recommend using the `fake` dataset:
39 |
40 | ```sh
41 | python3 main.py \
42 | --experiment_dir=/tmp/test_fake \
43 | --num_train_steps=10 \
44 | --dataset=fake \
45 | --hparam_overrides='{"input_shape": [28, 28, 1], "output_shape": [10]}'
46 | ```
47 |
48 | The `hparam_overrides` flag accepts a serialized JSON object of hyperparameter names and values to override. See the flags in `main.py` for more information on possible configurations.
49 |
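Other hyperparameters can be overridden the same way. As an illustrative sketch (not an official recipe; the hyperparameter name `batch_size` is assumed here, check `main.py` and `hyperparameters.py` for the authoritative set), the MNIST experiment above could be run with a smaller batch size:

```sh
python3 main.py \
  --experiment_dir=/tmp/test_mnist_small_batch \
  --model=fully_connected \
  --dataset=mnist \
  --num_train_steps=10 \
  --hparam_overrides='{"batch_size": 16}'
```
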
50 | See the [`dataset_lib`](https://github.com/google/init2winit/tree/master/init2winit/dataset_lib) and [`model_lib`](https://github.com/google/init2winit/tree/master/init2winit/model_lib) directories for currently implemented datasets and models.
51 |
52 |
53 | ## Citing
54 | To cite this repository:
55 |
56 | ```bibtex
57 | @software{init2winit2021github,
58 | author = {Justin M. Gilmer and George E. Dahl and Zachary Nado and Priya Kasimbeg and Sourabh Medapati},
59 | title = {{init2winit}: a JAX codebase for initialization, optimization, and tuning research},
60 | url = {http://github.com/google/init2winit},
61 | version = {0.0.2},
62 | year = {2023},
63 | }
64 | ```
65 |
66 | For a list of references to the models, datasets, and techniques implemented in this codebase, see [`references.md`](https://github.com/google/init2winit/tree/master/init2winit/references.md).
67 |
68 |
69 | ## Contributors
70 | Contributors (past and present):
71 |
72 | - Ankush Garg
73 | - Behrooz Ghorbani
74 | - Cheolmin Kim
75 | - David Cardoze
76 | - George E. Dahl
77 | - Justin M. Gilmer
78 | - Michal Badura
79 | - Priya Kasimbeg
80 | - Rohan Anil
81 | - Sourabh Medapati
82 | - Sneha Kudugunta
83 | - Varun Godbole
84 | - Zachary Nado
85 | - Vlad Feinberg
86 | - Derrick Xin
87 | - Naman Agarwal
88 | - Daniel Suo
89 | - Bilal Khan
90 | - Jeremy Cohen
91 | - Kacper Krasowiak
92 |
93 |
--------------------------------------------------------------------------------
/hessian/model_debugger_callback.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Callback which runs the model debugger."""
17 |
18 | import functools
19 | import os
20 |
21 | import flax
22 | from init2winit import utils
23 | from init2winit.dataset_lib import data_utils
24 | from init2winit.hessian import model_debugger
25 | import jax
26 | import jax.numpy as jnp
27 |
28 | DEFAULT_CONFIG = {
29 | 'name': 'model_debugger',
30 | }
31 |
32 |
33 | def get_grad(params,
34 | batch,
35 | rng,
36 | batch_stats=None,
37 | module_flags=None,
38 | training_cost=None):
39 |   """Computes the gradient of the training cost on a single batch.
40 |
41 | Args:
42 | params: the Flax param pytree. batch norm statistics.
43 | batch: the per-device batch of data to process.
44 | rng: the RNG used for calling the model. Assumes the step and device index
45 |       have already been folded in.
46 | batch_stats: Same as in trainer.py
47 | module_flags: Used in the skip analysis.
48 | training_cost: a function used to calculate the training objective that will
49 | be differentiated to generate updates. Takes (`params`, `batch_stats`,
50 | `batch`, `rng`) as inputs.
51 |
52 | Returns:
53 | Gradient of the given loss.
54 | """
55 |
56 | if module_flags is not None:
57 | kwargs = {'module_flags': module_flags}
58 | else:
59 | kwargs = {}
60 | def opt_cost(params):
61 | return training_cost(
62 | params,
63 | batch,
64 | batch_stats=batch_stats,
65 | dropout_rng=rng,
66 | **kwargs)
67 |
68 | grad_fn = jax.value_and_grad(opt_cost, has_aux=True)
69 | _, grad = grad_fn(params)
70 |
71 | grad = jax.lax.pmean(grad, axis_name='batch')
72 | return grad
73 |
74 |
75 | class ModelDebugCallback:
76 |   """Used to run the model debugger in the trainer binary."""
77 |
78 | def __init__(self, model, params, batch_stats, optimizer_state,
79 | optimizer_update_fn, dataset, hps, callback_config, train_dir,
80 | rng):
81 | del hps
82 | del params
83 | del optimizer_state
84 | del optimizer_update_fn
85 | checkpoint_dir = os.path.join(train_dir, 'checkpoints')
86 | # copy batch_stats as we close over it, and it gets modified.
87 | self.dataset = dataset
88 | checkpoint_dir = os.path.join(train_dir, 'checkpoints')
89 | pytree_path = os.path.join(checkpoint_dir, 'debugger')
90 | logger = utils.MetricLogger(pytree_path=pytree_path)
91 |
92 | get_act_stats_fn = model_debugger.create_forward_pass_stats_fn(
93 | model.apply_on_batch,
94 | capture_activation_norms=True,
95 | sown_collection_names=callback_config.get('sown_collection_names'))
96 | batch_stats = jax.tree.map(lambda x: x[:][0], batch_stats)
97 | grad_fn = functools.partial(
98 | get_grad,
99 | batch_stats=batch_stats,
100 | training_cost=model.training_cost,
101 | )
102 | debugger = model_debugger.ModelDebugger(
103 | use_pmap=True,
104 | forward_pass=get_act_stats_fn,
105 | metrics_logger=logger,
106 | grad_fn=grad_fn,
107 | skip_flags=callback_config.get('skip_flags'),
108 | skip_groups=callback_config.get('skip_groups'))
109 | # pmap functions for the training loop
110 | # in_axes = (params = 0, batch_stats = 0, batch = 0, step = None,
111 | # lr = None, rng = None, local_device_index = 0, training_metrics_grabber=0,
112 | # training_metrics_grabber, training_cost )
113 | # Also, we can donate buffers for 'optimizer', 'batch_stats',
114 | # 'batch' and 'training_metrics_grabber' for update's pmapped computation.
115 | self.debugger = debugger
116 | self.logger = logger
117 | self.dataset = dataset
118 |
119 | batch = next(dataset.train_iterator_fn())
120 | self.batch = data_utils.shard(batch)
121 |
122 | self.batch_rng = flax.jax_utils.replicate(rng)
123 |
124 | def run_eval(self, params, batch_stats, optimizer_state, global_step):
125 | """Runs ModelDebugger.full_eval on the given params.
126 |
127 |     Note, the debugger statistics are saved via the logger to
128 |     train_dir/checkpoints/debugger.
129 |
130 | Args:
131 | params: Replicated model parameter tree.
132 | batch_stats: Replicated batch_stats from the trainer.
133 | optimizer_state: Replicated optimizer state from the trainer.
134 | global_step: Current training step.
135 |
136 | Returns:
137 |       An empty dict; the debugger statistics are saved to disk via the logger.
138 | """
139 | del optimizer_state
140 | del batch_stats
141 | p_norms = jax.tree.map(lambda x: jnp.linalg.norm(x[0].reshape(-1))**2,
142 | params)
143 |
144 | self.debugger.full_eval(
145 | step=global_step,
146 | params=params,
147 | param_norms_sql2=p_norms,
148 | batch=self.batch,
149 | rng=self.batch_rng)
150 |
151 | return {}
152 |
--------------------------------------------------------------------------------
/init2winit/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/init2winit/base_callback.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Infrastructure for arbitrary code to be run during training.
17 |
18 | Callbacks can be stateful, and in trainer are meant to be called as follows:
19 |
20 |
21 | callback_builder = callbacks.get_callback(config['callback_name'])
22 | callback = callback_builder(model, params, batch_stats, optimizer_state,
23 | dataset, hps, config, train_dir, rng, mesh)
24 |
25 | callback_metrics = callback.run_eval(params, batch_stats,
26 | optimizer_state, global_step).
27 |
28 | We require that the config has a field 'callback_name', which the trainer
29 | uses to determine which callbacks to run. The dictionary, callback_metrics
30 | should be scalar valued, and will be automatically added to the existing trainer
31 | scalar metrics.
32 | """
33 |
34 | # TODO(gilmer) Add serialization so that we can checkpoint callback state.
35 |
36 |
37 | class BaseCallBack:
38 | """Base callback to specify the required API."""
39 |
40 | def __init__(self, model, params, batch_stats, optimizer_state,
41 | optimizer_update_fn, dataset, hps, callback_config, train_dir,
42 | rng, mesh):
43 | """Defines the API for callback construction."""
44 | pass
45 |
46 | def run_eval(self, params, batch_stats, optimizer_state, global_step):
47 | """Define the API for running the callback during eval.
48 |
49 | Args:
50 | params: Replicated params from the trainer.
51 | batch_stats: Replicated batch_stats from the trainer.
52 | optimizer_state: Replicated optimizer state from the trainer.
53 | global_step: Current training step.
54 |
55 | Returns:
56 |       A dictionary of scalar metrics. Note, any metric name already returned
57 |       by trainer.evaluate is forbidden, e.g. including 'train/ce_loss' will
58 |       result in the trainer throwing an exception.
59 | """
60 | raise NotImplementedError('Subclasses must implement run_eval().')
61 |
--------------------------------------------------------------------------------
/init2winit/callbacks.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Registry for the available callbacks."""
17 |
18 | from init2winit import gradient_statistics_callback
19 | from init2winit.hessian import model_debugger_callback
20 | from init2winit.mt_eval import mt_callback
21 |
22 |
23 | _ALL_CALLBACKS = {
24 | 'mt': mt_callback.MTEvaluationCallback,
25 | 'model_debugger': model_debugger_callback.ModelDebugCallback,
26 | 'gradient_statistics': (
27 | gradient_statistics_callback.GradientStatisticsCallback
28 | ),
29 | }
30 |
31 |
32 | def get_callback(callback_name):
33 | """Get the corresponding callback builder based on the callback_name.
34 |
35 | Args:
36 | callback_name: (str) e.g. mt.
37 |
38 | Returns:
39 | Callback builder class.
40 | Raises:
41 | ValueError if callback_name is unrecognized.
42 | """
43 | try:
44 | return _ALL_CALLBACKS[callback_name]
45 | except KeyError:
46 | raise ValueError('Unrecognized callback name: {}'.format(callback_name))
47 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/data_selectors.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Definitions for supported data selection functions."""
17 |
18 |
19 | def noop(
20 | dataset_iterator,
21 | optimizer_state,
22 | params,
23 | batch_stats,
24 | hps,
25 | global_step,
26 | constant_base_rng):
27 | """An example no-op data selector that just yields the next batch.
28 |
29 | Args:
30 | dataset_iterator: the (preprocessed, batched, prefetched) dataset iterator.
31 | optimizer_state: the current optimizer state.
32 | params: the model parameters.
33 | batch_stats: the model batch statistics.
34 | hps: the experiment hyperparameters.
35 | global_step: the current global step.
36 | constant_base_rng: the RNG used for the experiment. IMPORTANT NOTE: this
37 | will be constant for all calls to this function, in order to get a unique
38 |       RNG each time, we need to do
39 | `rng = jax.random.fold_in(constant_base_rng, global_step)`.
40 |
41 | Yields:
42 | A batch of data.
43 | """
44 | del optimizer_state
45 | del params
46 | del batch_stats
47 | del hps
48 | del global_step
49 | del constant_base_rng
50 | yield from dataset_iterator
51 |
52 |
53 | def data_echoing(
54 | dataset_iterator,
55 | optimizer_state,
56 | params,
57 | batch_stats,
58 | hps,
59 | global_step,
60 | constant_base_rng):
61 | """An example data echoing selector.
62 |
63 | Args:
64 | dataset_iterator: the (preprocessed, batched, prefetched) dataset iterator.
65 | optimizer_state: the current optimizer state.
66 | params: the model parameters.
67 | batch_stats: the model batch statistics.
68 | hps: the experiment hyperparameters.
69 | global_step: the current global step.
70 | constant_base_rng: the RNG used for the experiment. IMPORTANT NOTE: this
71 | will be constant for all calls to this function, in order to get a unique
72 |       RNG each time, we need to do
73 | `rng = jax.random.fold_in(constant_base_rng, global_step)`.
74 |
75 | Yields:
76 | A batch of data.
77 | """
78 | del optimizer_state
79 | del params
80 | del batch_stats
81 | del global_step
82 | del constant_base_rng
83 | for x in dataset_iterator:
84 | for _ in range(hps.num_data_echoes):
85 | yield x
86 |
87 |
88 | ALL_SELECTORS = {
89 | 'noop': noop,
90 | 'data_echoing': data_echoing,
91 | }
92 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/fake_dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Fake image input pipeline. Returns the same batch of ones over and over."""
17 | import copy
18 |
19 | from init2winit.dataset_lib import data_utils
20 | import jax
21 | import jax.numpy as jnp
22 | from ml_collections.config_dict import config_dict
23 | import numpy as np
24 |
25 |
26 | TRAIN_IMAGES = 1281167
27 | EVAL_IMAGES = 50000
28 |
29 |
30 | NUM_CLASSES = 1000
31 | IMAGE_SIZE = 224
32 |
33 |
34 | DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
35 | input_shape=(224, 224, 3),
36 | output_shape=(NUM_CLASSES,),
37 | train_size=TRAIN_IMAGES,
38 | valid_size=EVAL_IMAGES))
39 |
40 | METADATA = {
41 | 'apply_one_hot_in_loss': False,
42 | }
43 |
44 |
45 | def get_fake_batch(hps):
46 | """Generate batches of images of all ones and one-hot labels."""
47 | batch_size = hps.batch_size
48 | input_shape = hps.input_shape
49 | num_classes = hps.output_shape[0]
50 | train_input_shape = (batch_size, *input_shape)
51 | images = jnp.ones(train_input_shape, dtype=jnp.float32)
52 | labels = jax.nn.one_hot(
53 | np.zeros((batch_size,)), num_classes, dtype=jnp.int32)
54 | batch = {
55 | 'inputs': images,
56 | 'targets': labels,
57 | 'weights': jnp.ones(batch_size, dtype=images.dtype),
58 | }
59 | return batch
60 |
61 |
62 | def get_fake(shuffle_rng, batch_size, eval_batch_size, hps=None):
63 |   """Data generators for the fake dataset."""
64 | del shuffle_rng
65 | per_host_batch_size = batch_size // jax.process_count()
66 | per_host_eval_batch_size = eval_batch_size // jax.process_count()
67 |
68 | train_hps = copy.copy(hps)
69 | train_hps.unlock()
70 | train_hps.batch_size = per_host_batch_size
71 | fake_train_batch = get_fake_batch(train_hps)
72 |
73 | test_hps = copy.copy(hps)
74 | test_hps.unlock()
75 | test_hps.batch_size = per_host_eval_batch_size
76 | fake_test_batch = get_fake_batch(test_hps)
77 |
78 | def train_iterator_fn():
79 | while True:
80 | yield fake_train_batch
81 |
82 | def valid_epoch(epoch, num_batches=None):
83 | del num_batches
84 | del epoch
85 |     # Note that we do // because we do not support partial batching for the fake
86 | # dataset.
87 | for _ in range(hps.valid_size // eval_batch_size):
88 | yield fake_test_batch
89 |
90 | # pylint: disable=unreachable
91 | def eval_train_epoch(*args, **kwargs):
92 | del args
93 | del kwargs
94 | return
95 | yield # This yield is needed to make this a valid (null) iterator.
96 | # pylint: enable=unreachable
97 | # pylint: disable=unreachable
98 |
99 | def test_epoch(*args, **kwargs):
100 | del args
101 | del kwargs
102 | return
103 | yield # This yield is needed to make this a valid (null) iterator.
104 | # pylint: enable=unreachable
105 |
106 | return data_utils.Dataset(
107 | train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
108 |
109 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/librispeech.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Librispeech input pipeline."""
17 |
18 | import itertools
19 |
20 | from init2winit.dataset_lib import data_utils
21 | from init2winit.dataset_lib import librispeech_input_pipeline
22 | from init2winit.dataset_lib.data_utils import Dataset
23 | import jax
24 | from ml_collections.config_dict import config_dict
25 | import numpy as np
26 |
27 | MAX_INPUT_LENGTH = 320000
28 | MAX_TARGET_LENGTH = 256
29 | VOCAB_SIZE = 1024
30 |
31 | DEFAULT_HPARAMS = config_dict.ConfigDict(
32 | dict(
33 | max_input_length=MAX_INPUT_LENGTH,
34 | max_target_length=MAX_TARGET_LENGTH,
35 | train_split='train_clean100+train_clean360+train_other500',
36 | eval_split='dev_clean+dev_other',
37 | test_split='test_clean',
38 | input_shape=[(MAX_INPUT_LENGTH,), (MAX_INPUT_LENGTH,)],
39 | output_shape=(-1, VOCAB_SIZE),
40 | train_size=281241,
41 | tokenizer_vocab_path='',
42 | tokenizer_type='SPM'))
43 |
44 | METADATA = {'apply_one_hot_in_loss': False}
45 |
46 |
47 | def _batch_to_dict(batch):
48 | batch_np = data_utils.tf_to_numpy(batch)
49 | return batch_np
50 |
51 |
52 | def get_librispeech(shuffle_rng, batch_size, eval_batch_size=None, hps=None):
53 | """Wrapper to conform to the general dataset API."""
54 | process_count = jax.process_count()
55 | if batch_size % process_count != 0:
56 | raise ValueError('process_count={} must divide batch_size={}.'.format(
57 | process_count, batch_size))
58 |
59 | per_host_batch_size = batch_size // process_count
60 | if eval_batch_size is None:
61 | eval_batch_size = batch_size
62 |
63 | if eval_batch_size % process_count != 0:
64 | raise ValueError('process_count={} must divide eval_batch_size={}.'.format(
65 | process_count, eval_batch_size))
66 | per_host_eval_batch_size = eval_batch_size // process_count
67 |
68 | return _get_librispeech(hps, per_host_batch_size, per_host_eval_batch_size,
69 | shuffle_rng)
70 |
71 |
72 | def _get_librispeech(hps, per_host_batch_size, per_host_eval_batch_size,
73 | shuffle_rng):
74 |   """Data generators for librispeech."""
75 | n_devices = jax.local_device_count()
76 | if per_host_batch_size % n_devices != 0:
77 | raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format(
78 | n_devices, per_host_batch_size))
79 |
80 | if per_host_eval_batch_size % n_devices != 0:
81 | raise ValueError(
82 | 'n_devices={} must divide per_host_eval_batch_size={}.'.format(
83 | n_devices, per_host_eval_batch_size))
84 |
85 | train_ds, eval_ds, test_ds = librispeech_input_pipeline.get_librispeech_datasets(
86 | hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng)
87 |
88 | def train_iterator_fn():
89 | for batch in iter(train_ds):
90 | yield _batch_to_dict(batch)
91 |
92 | def eval_train_epoch(num_batches=None):
93 | eval_train_iter = iter(train_ds)
94 | for batch in itertools.islice(eval_train_iter, num_batches):
95 | yield _batch_to_dict(batch)
96 |
97 | def valid_epoch(num_batches=None):
98 | valid_iter = iter(eval_ds)
99 | for batch in itertools.islice(valid_iter, num_batches):
100 | batch = _batch_to_dict(batch)
101 | yield data_utils.maybe_pad_batch(
102 | batch, desired_batch_size=per_host_eval_batch_size, padding_value=1.0)
103 |
104 | def test_epoch(num_batches=None):
105 | test_iter = iter(test_ds)
106 | for batch in itertools.islice(test_iter, num_batches):
107 | batch = _batch_to_dict(batch)
108 | yield data_utils.maybe_pad_batch(
109 | batch, desired_batch_size=per_host_eval_batch_size, padding_value=1.0)
110 |
111 | # pylint: enable=unreachable
112 | return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
113 |
114 |
115 | def get_fake_batch(hps):
116 | return {
117 | 'inputs':
118 | np.ones((hps.batch_size, hps.max_input_length),
119 | dtype=hps.model_dtype),
120 | 'input_paddings':
121 | np.ones((hps.batch_size, hps.max_input_length),
122 | dtype=hps.model_dtype),
123 | 'targets':
124 | np.ones((hps.batch_size, hps.max_target_length),
125 | dtype=hps.model_dtype),
126 | 'target_paddings':
127 | np.ones((hps.batch_size, hps.max_target_length),
128 | dtype=hps.model_dtype),
129 | }
130 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/mlperf_imagenet_dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """ImageNet input pipeline with MLPerf preprocessing."""
17 |
18 | import itertools
19 |
20 | from init2winit.dataset_lib import data_utils
21 | from init2winit.dataset_lib import imagenet_dataset
22 | from init2winit.dataset_lib import mlperf_input_pipeline
23 | import jax
24 | from ml_collections.config_dict import config_dict
25 | import numpy as np
26 | import tensorflow.compat.v2 as tf
27 |
28 |
29 | DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
30 | input_shape=(224, 224, 3),
31 | output_shape=(1000,),
32 | train_size=1281167,
33 | valid_size=50000,
34 | test_size=10000, # ImageNet-v2.
35 | use_imagenetv2_test=True))
36 |
37 | METADATA = {
38 | 'apply_one_hot_in_loss': False,
39 | }
40 |
41 |
42 | def get_mlperf_imagenet(rng,
43 | batch_size,
44 | eval_batch_size,
45 | hps=None):
46 | """Data generators for imagenet.
47 |
48 | Args:
49 | rng: RNG seed that is split into a shuffle seed and a seed that is folded
50 | into a per-example seed.
51 | batch_size: the *global* batch size used for training.
52 | eval_batch_size: the *global* batch size used for evaluation.
53 | hps: the hparams for the experiment, only required field is valid_size.
54 |
55 | Returns:
56 | A data_utils.Dataset for the MLPerf version of ImageNet.
57 | """
58 | if batch_size % jax.device_count() != 0:
59 | raise ValueError(
60 |         'Require batch_size % jax.device_count() == 0, received '
61 | 'batch_size={}, device_count={}.'.format(
62 | batch_size, jax.device_count()))
63 | if eval_batch_size % jax.device_count() != 0:
64 | raise ValueError(
65 |         'Require eval_batch_size % jax.device_count() == 0, received '
66 | 'eval_batch_size={}, device_count={}.'.format(
67 | eval_batch_size, jax.device_count()))
68 | host_batch_size = batch_size // jax.process_count()
69 | eval_host_batch_size = eval_batch_size // jax.process_count()
70 |
71 | max_eval_steps = hps.valid_size // eval_batch_size + 1
72 |
73 | input_dtype = tf.bfloat16
74 | shuffle_buffer_size = 16384
75 |
76 | train_ds = mlperf_input_pipeline.load_split(
77 | host_batch_size,
78 | dtype=input_dtype,
79 | split='train',
80 | rng=rng,
81 | shuffle_size=shuffle_buffer_size)
82 |
83 | eval_train_ds = mlperf_input_pipeline.load_split(
84 | host_batch_size,
85 | dtype=input_dtype,
86 | split='eval_train',
87 | rng=rng,
88 | shuffle_size=shuffle_buffer_size)
89 |
90 | eval_ds = mlperf_input_pipeline.load_split(
91 | eval_host_batch_size,
92 | dtype=input_dtype,
93 | split='validation',
94 | rng=rng,
95 | shuffle_size=shuffle_buffer_size)
96 |
97 | # We do not have TFRecords of ImageNet-v2 in the same format as the
98 | # train/validation splits above, so we reuse the same test split from the
99 | # non-MLPerf pipeline.
100 | test_ds = None
101 | if hps.use_imagenetv2_test:
102 | test_ds = imagenet_dataset.load_split(
103 | eval_host_batch_size,
104 | 'test',
105 | hps=hps,
106 | image_size=224,
107 | tfds_dataset_name='imagenet_v2/matched-frequency')
108 |
109 | # We cannot use tfds.as_numpy because this calls tensor.numpy() which does an
110 | # additional copy of the tensor, instead we call tensor._numpy() below.
111 | def train_iterator_fn():
112 | return data_utils.iterator_as_numpy(iter(train_ds))
113 |
114 | def eval_train_epoch(num_batches=None):
115 | if num_batches is None:
116 | num_batches = 0
117 | eval_train_iter = iter(eval_train_ds)
118 | np_iter = data_utils.iterator_as_numpy(
119 | itertools.islice(eval_train_iter, num_batches))
120 | for batch in np_iter:
121 | yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)
122 |
123 | def valid_epoch(num_batches=None):
124 | if num_batches is None:
125 | num_batches = max_eval_steps
126 | valid_iter = iter(eval_ds)
127 | np_iter = data_utils.iterator_as_numpy(
128 | itertools.islice(valid_iter, num_batches))
129 | for batch in np_iter:
130 | yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)
131 |
132 | def test_epoch(num_batches=None):
133 | if test_ds:
134 | test_iter = iter(test_ds)
135 | np_iter = data_utils.iterator_as_numpy(
136 | itertools.islice(test_iter, num_batches))
137 | for batch in np_iter:
138 | yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)
139 | else:
140 | # pylint: disable=unreachable
141 | return
142 | yield # This yield is needed to make this a valid (null) iterator.
143 | # pylint: enable=unreachable
144 |
145 | return data_utils.Dataset(
146 | train_iterator_fn,
147 | eval_train_epoch,
148 | valid_epoch,
149 | test_epoch)
150 |
151 |
152 | def get_fake_batch(hps):
153 | return {
154 | 'inputs':
155 | np.ones((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype),
156 | 'targets':
157 | np.ones((hps.batch_size, *hps.output_shape), dtype=hps.model_dtype),
158 | 'weights':
159 | np.ones((hps.batch_size,), dtype=hps.model_dtype),
160 | }
161 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/nanodo_data_loader_shared.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Common utils for nanodo data loaders.
17 |
18 | Implementation based on:
19 | https://github.com/google-deepmind/nanodo/blob/main/nanodo/data.py
20 | """
21 |
22 | from collections.abc import Mapping, Sequence
23 | import dataclasses
24 | import enum
25 | from typing import Iterable, Iterator, Union
26 |
27 | import grain.python as grain
28 | import jax
29 | import jax.numpy as jnp
30 |
31 | import numpy as np
32 |
33 | import sentencepiece as spm
34 |
35 |
36 | PAD_ID = 0
37 |
38 |
39 | class Preprocess(enum.Enum):
40 | NOAM_PACKED = 1
41 | PADDED = 2
42 |
43 |
44 | ### pure python helpers for use with grain ###
45 | # Need this because we can't pickle SentencePieceProcessor object
46 | class SPTokenizer:
47 | """Wrapper class for SentencePiece tokenizer."""
48 |
49 | def __init__(self, vocab_path):
50 | self._tokenizer = None
51 | self._vocab_path = vocab_path
52 |
53 | def get_tokenizer(self) -> spm.SentencePieceProcessor:
54 | if not self._tokenizer:
55 | self._tokenizer = get_py_tokenizer(self._vocab_path)
56 | return self._tokenizer
57 |
58 |
59 | class SentencePieceByteTokenizer(spm.SentencePieceProcessor):
60 | """A simple Byte level tokenizer."""
61 |
62 | def eos_id(self) -> int:
63 | return 1
64 |
65 | def bos_id(self) -> int:
66 | return 2
67 |
68 | def pad_id(self) -> int:
69 | return PAD_ID
70 |
71 | def GetPieceSize(self) -> int:
72 | return 256
73 |
74 | # pylint: disable=invalid-name
75 | def EncodeAsIds(self, text: Union[bytes, str]) -> list[int]:
76 | if isinstance(text, str):
77 | return list(bytes(text, 'utf-8'))
78 | if isinstance(text, bytes):
79 | return [int(x) for x in text]
80 | raise ValueError(f'Invalid text: {text} type={type(text)}')
81 |
82 | def DecodeIds(self, ids: Iterable[int]) -> str:
83 | return bytes(ids).decode('utf-8')
84 | # pylint: enable=invalid-name
85 |
86 |
87 | def py_tokenize(
88 | features: Mapping[str, str],
89 | spt: SPTokenizer,
90 | pad_len: int | None = None,
91 | pad_id: int = PAD_ID,
92 | ) -> Sequence[int]:
93 | """Tokenizes text into ids, optionally pads or truncates to pad_len."""
94 | text = features['text']
95 | tokenizer = spt.get_tokenizer()
96 | bos_id = tokenizer.bos_id()
97 | eos_id = tokenizer.eos_id()
98 | ids = tokenizer.EncodeAsIds(text)
99 |
100 | ids.insert(0, bos_id)
101 | ids.append(eos_id)
102 | if pad_len is not None:
103 | if len(ids) < pad_len:
104 | ids.extend([pad_id] * (pad_len - len(ids)))
105 | elif len(ids) > pad_len:
106 | ids = ids[:pad_len]
107 | return ids
108 |
109 |
110 | @dataclasses.dataclass
111 | class NoamPack:
112 | """Pygrain operation for tokenizing and Noam packing text."""
113 |
114 | context_size: int
115 |
116 | def __call__(
117 | self, idseq_iterator: Iterator[grain.Record]
118 | ) -> Iterator[grain.Record]:
119 | packed_ids = []
120 | for input_record in idseq_iterator:
121 | start = 0
122 | while start < len(input_record.data):
123 | rem_data = input_record.data[start:]
124 | if len(packed_ids) + len(rem_data) < self.context_size:
125 | packed_ids.extend(rem_data) # use rest of example, move-on
126 | break
127 | else:
128 | take = self.context_size - len(packed_ids)
129 | packed_ids.extend(rem_data[:take])
130 | last_record_key = input_record.metadata.remove_record_key()
131 | yield grain.Record(
132 | last_record_key, np.array(packed_ids, dtype=np.int32)
133 | )
134 | start += take
135 | packed_ids = []
136 | # Drop remainder for simplicity.
137 | # We lose the rest of the example on restore.
138 |
139 |
140 | # pylint: disable=invalid-name
141 |
142 |
143 | def get_py_tokenizer(path: str) -> spm.SentencePieceProcessor:
144 | if not path:
145 | # byte tokenizer shortcut
146 | return SentencePieceByteTokenizer()
147 | sp = spm.SentencePieceProcessor()
148 | sp.Load(path)
149 | assert sp.pad_id() == PAD_ID
150 | assert sp.eos_id() != -1
151 | assert sp.bos_id() != -1
152 | return sp
153 |
154 |
155 | def get_in_out(
156 | in_BxL: jax.Array,
157 | pad_id: int = PAD_ID,
158 | ) -> tuple[jax.Array, jax.Array, jax.Array]:
159 | """Returns input, output, and weights for a batch of examples."""
160 |   # Assumes input of the form <BOS> <IDs> <EOS> for eval.
161 | x_BxL = in_BxL
162 | y_BxL = jnp.pad(
163 | in_BxL[:, 1:],
164 | ((0, 0), (0, 1)),
165 | mode='constant',
166 | constant_values=pad_id,
167 | )
168 | weights_BxL = jnp.where(y_BxL != pad_id, 1, 0).astype(jnp.float32)
169 |
170 | return x_BxL, y_BxL, weights_BxL
171 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/nqm_noise.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data generators for init2winit."""
17 |
18 | from init2winit.dataset_lib import data_utils
19 | import jax.random
20 | from ml_collections.config_dict import config_dict
21 | import numpy as np
22 |
23 |
24 | NQM_HPARAMS = config_dict.ConfigDict(
25 | dict(
26 | train_size=1e10,
27 | valid_size=0,
28 | test_size=0,
29 | input_shape=(100,), # This determines the dimension.
30 | output_shape=(1,),
31 | ))
32 | NQM_METADATA = {
33 | 'apply_one_hot_in_loss': False,
34 | }
35 |
36 |
37 | def get_nqm_noise(shuffle_rng, batch_size, eval_batch_size, hps=None):
38 | """Returns the noise seed for the nqm model.
39 |
40 | NOTE: This dataset is only meant to be used with the nqm model.
41 | This just generates isotropic Gaussian noise of the desired dimension.
42 | The nqm model will then multiple this noise by a matrix D, with the properly
43 | that D^T D = C. This yields noise with gradient covariance C.
44 |
45 | Args:
46 | shuffle_rng: Not used.
47 | batch_size: The global train batch size, used to determine the batch size
48 | yielded from train_epoch().
49 | eval_batch_size: Not used.
50 | hps: Hparams object. We only refer to hps.input_shape to determine the
51 | dimension of the noise.
52 | Returns:
53 |     train_epoch, eval_train_epoch, valid_epoch, test_epoch: four generators.
54 | Only train_epoch is used.
55 | """
56 | del eval_batch_size
57 |
58 | per_host_batch_size = batch_size // jax.process_count()
59 | # Should train_rng / eval_rng possibly have different seeds?
60 | seed = data_utils.convert_jax_to_tf_random_seed(shuffle_rng)
61 | train_rng = np.random.RandomState(seed=seed)
62 | eval_rng = np.random.RandomState(seed=seed)
63 |
64 | def train_iterator_fn():
65 | while True:
66 | yield {
67 | 'inputs': train_rng.normal(
68 | size=(per_host_batch_size, *hps.input_shape)
69 | )
70 | }
71 |
72 | def eval_train_epoch(num_batches):
73 | for _ in range(num_batches):
74 | yield {
75 | 'inputs': eval_rng.normal(
76 | size=(per_host_batch_size, *hps.input_shape)
77 | )
78 | }
79 |
80 | # pylint: disable=unreachable
81 | def valid_epoch(*args, **kwargs):
82 | del args
83 | del kwargs
84 | return
85 | yield # This yield is needed to make this a valid (null) iterator.
86 | # pylint: enable=unreachable
87 |
88 | # pylint: disable=unreachable
89 | def test_epoch(*args, **kwargs):
90 | del args
91 | del kwargs
92 | return
93 | yield # This yield is needed to make this a valid (null) iterator.
94 |
95 | # pylint: enable=unreachable
96 |
97 | return data_utils.Dataset(
98 | train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch
99 | )
100 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/test_data_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Unit tests for datasets.py."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from init2winit.dataset_lib import data_utils
21 | import numpy as np
22 |
23 | desired_batch_size = 23
24 | batch_axes = [0, 0, 3, 2]
25 | test_names = ['default', 'NHWC', 'HWCN', 'HWNC']
26 | image_formats = [None, 'NHWC', 'HWCN', 'HWNC']
27 | batch_size = 13
28 | width = 11
29 | num_channels = 3
30 | input_shapes = [
31 | (batch_size, width, width, num_channels),
32 | (batch_size, width, width, num_channels),
33 | (width, width, num_channels, batch_size),
34 | (width, width, batch_size, num_channels),
35 | ]
36 | test_parameters = zip(test_names, image_formats, batch_axes, input_shapes)
37 |
38 |
39 | class DataUtilsTest(parameterized.TestCase):
40 | """Unit tests for datasets.py."""
41 |
42 | @parameterized.named_parameters(*test_parameters)
43 | def test_padding(self, image_format, batch_axis, input_shape):
44 | """Test that the shape is the expected padded shape."""
45 | batch = {'inputs': np.ones(input_shape)}
46 | padded_batch = data_utils.maybe_pad_batch(
47 | batch, desired_batch_size, image_format)
48 | expected_shapes = list(input_shape)
49 | expected_shapes[batch_axis] = desired_batch_size
50 | self.assertEqual(padded_batch['inputs'].shape, tuple(expected_shapes))
51 | self.assertEqual(padded_batch['weights'].shape, (desired_batch_size,))
52 |
53 | def test_padding_seq2seq(self):
54 | """Test padding for sequence-to-sequence models."""
55 | input_len_max = 25
56 | input_len_true = 22 # true input_seq_length for each example in batch.
57 | target_len_max = 25
58 | target_len_true = 21 # true target_seq_length for each example in batch.
59 |
60 | inputs_shape = (batch_size, input_len_max)
61 | targets_shape = (batch_size, target_len_max)
62 | batch = {'inputs': np.ones(inputs_shape), 'targets': np.ones(targets_shape)}
63 | batch['inputs'][:, input_len_true:] = 0 # zero-pad extra inputs tokens
64 | batch['targets'][:, target_len_true:] = 0 # zero-pad extra targets tokens
65 | expected_inputs_shape = (desired_batch_size, input_len_max)
66 | expected_targets_shape = (desired_batch_size, target_len_max)
67 | expected_weights_shape = (desired_batch_size, target_len_max)
68 | padded_batch = data_utils.maybe_pad_batch(
69 | batch, desired_batch_size, data_format=None, mask_key='targets')
70 | self.assertEqual(padded_batch['inputs'].shape, expected_inputs_shape)
71 | self.assertEqual(padded_batch['targets'].shape, expected_targets_shape)
72 | self.assertEqual(padded_batch['weights'].shape, expected_weights_shape)
73 |
74 | batch_pad = desired_batch_size - batch_size
75 | expected_weights_array = np.ones((desired_batch_size, target_len_max))
76 | # pad at batch axis
77 | expected_weights_array[-batch_pad:] = 0
78 | # # pad at sequence_len axis
79 | expected_weights_array[:, target_len_true:] = 0
80 | self.assertTrue(
81 | np.array_equal(padded_batch['weights'], expected_weights_array))
82 |
83 |
84 | if __name__ == '__main__':
85 | absltest.main()
86 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/test_small_image_datasets.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for init2winit.dataset_lib.small_image_datasets."""
17 |
18 | import itertools
19 |
20 | from absl.testing import absltest
21 | from init2winit.dataset_lib import small_image_datasets
22 | from jax import random
23 | from ml_collections.config_dict import config_dict
24 |
25 |
26 | class SmallImageDatasetsTest(absltest.TestCase):
27 | """Unit tests for small_image_datasets.py."""
28 |
29 | def test_cifar10(self):
30 | """Test example generation in CIFAR10 is reproducible."""
31 | dataset = small_image_datasets.get_cifar10(
32 | random.PRNGKey(0), 1, 1,
33 | config_dict.ConfigDict(
34 | dict(
35 | flip_probability=0.5,
36 | alpha=1.0,
37 | crop_num_pixels=4,
38 | use_mixup=True,
39 | train_size=45000,
40 | valid_size=5000,
41 | test_size=10000,
42 | include_example_keys=True,
43 | input_shape=(32, 32, 3),
44 | output_shape=(10,))))
45 |
46 | examples = itertools.islice(dataset.valid_epoch(), 10)
47 | example_keys = [
48 | example['example_key'][0].decode('utf-8') for example in examples
49 | ]
50 | self.assertEqual(example_keys, [
51 | 'cifar10-train.array_record-00000-of-00001__45000',
52 | 'cifar10-train.array_record-00000-of-00001__45001',
53 | 'cifar10-train.array_record-00000-of-00001__45002',
54 | 'cifar10-train.array_record-00000-of-00001__45003',
55 | 'cifar10-train.array_record-00000-of-00001__45004',
56 | 'cifar10-train.array_record-00000-of-00001__45005',
57 | 'cifar10-train.array_record-00000-of-00001__45006',
58 | 'cifar10-train.array_record-00000-of-00001__45007',
59 | 'cifar10-train.array_record-00000-of-00001__45008',
60 | 'cifar10-train.array_record-00000-of-00001__45009',
61 | ])
62 |
63 |
64 | if __name__ == '__main__':
65 | absltest.main()
66 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/test_wikitext_tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for init2winit.dataset_lib.wikitext103."""
17 |
18 | from absl.testing import absltest
19 | from init2winit.dataset_lib import wikitext_tokenizer
20 | import tensorflow as tf
21 |
22 |
23 |
24 | class TestWikitextTokenizer(absltest.TestCase):
25 | """Unit tests for wikitext103.py."""
26 |
27 | def test_tokenizer_vocab_size(self):
28 | """Test vocab size.
29 |
30 |     Vocab size should be the number of unique words in the text file + 2 for
31 |     the two special tokens added by the tokenizer.
32 | """
33 | # Get number of unique tokens from tokenizer.
34 | text_dataset = tf.data.TextLineDataset(file_name)
35 |
36 | tokenizer = wikitext_tokenizer.Tokenizer()
37 | tokenizer.train(text_dataset)
38 |
39 | num_unique_tokens = len(tokenizer.dictionary.idx2word)
40 |
41 | # Get number of unique words from fake data.
42 | with open(file_name, 'r') as f:
43 | data = ''
44 | for line in f:
45 | # Not removing this would count tokens like '\n\n' and '\n' while the
46 | # TextLineDataset strips them.
47 | line = line.strip('\n')
48 | data = data + line
49 |
50 | words = data.split(' ')
51 | num_unique_words = len(set(words))
52 |
53 | self.assertEqual(num_unique_tokens, num_unique_words + 2)
54 |
55 | if __name__ == '__main__':
56 | absltest.main()
57 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/wikitext103.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Module containing hyperparameters, metadata and dataset getter for Wikitext-103 dataset."""
17 |
18 | import itertools
19 |
20 | from init2winit.dataset_lib import data_utils
21 | from init2winit.dataset_lib import wikitext103_input_pipeline as input_pipeline
22 | from init2winit.dataset_lib import wikitext2_input_pipeline
23 | import jax
24 | from ml_collections.config_dict import config_dict
25 | import numpy as np
26 |
27 | PAD_ID = wikitext2_input_pipeline.PAD_ID
28 | Dataset = data_utils.Dataset
29 |
30 | VOCAB_SIZE = 267735
31 |
32 | DEFAULT_HPARAMS = config_dict.ConfigDict(
33 | dict(
34 | sequence_length=128,
35 | max_target_length=128,
36 | max_eval_target_length=128,
37 | eval_sequence_length=128,
38 | input_shape=(128,),
39 | output_shape=(input_pipeline.WORD_VOCAB_SIZE,),
40 | train_size=800210, # Number of sequences.
41 | tokenizer='word',
42 | tokenizer_vocab_path=None,
43 | vocab_size=input_pipeline.WORD_VOCAB_SIZE,
44 | ))
45 |
46 |
47 | METADATA = {
48 | 'apply_one_hot_in_loss': True,
49 | 'shift_inputs': True,
50 | 'causal': True,
51 | 'pad_token': -1,
52 | }
53 |
54 |
55 | def add_weights_to_batch(batch, pad_id: int = PAD_ID):
56 | """Add weights for the input values so that paddings have 0 weight.
57 |
58 | Args:
59 | batch: Batch represented by dict containing 'inputs' and 'targets'.
60 | pad_id: Value for 'inputs' that will have weight 0.
61 |
62 | Returns:
63 | batch with weights
64 | """
65 | batch['weights'] = np.where(batch['inputs'] == pad_id, 0.0, 1.0)
66 | return batch
67 |
68 |
69 | def get_wikitext103(
70 | shuffle_rng,
71 | batch_size: int,
72 | eval_batch_size: int = None,
73 | hps: config_dict.ConfigDict = None,
74 | pad_id: int = PAD_ID) -> Dataset:
75 | """Returns Wikitext-103 Dataset.
76 |
77 | Args:
78 | shuffle_rng: jax.random.PRNGKey
79 | batch_size: training batch size
80 | eval_batch_size: validation batch size
81 | hps: Hyper parameters
82 | pad_id: Value for 'inputs' that will have weight 0.
83 |
84 | Returns:
85 | Dataset
86 |
87 | Raises:
88 | ValueError: If batch_size is not divisible by jax process count.
89 | ValueError: If eval_batch_size is not divisible by jax process count.
90 | """
91 | process_count = jax.process_count()
92 |
93 | if batch_size % process_count != 0:
94 | raise ValueError(
95 | 'process_count={} must divide batch_size={}.'.format(
96 | process_count, batch_size))
97 |
98 | if eval_batch_size is None:
99 | eval_batch_size = batch_size
100 |
101 | if eval_batch_size % process_count != 0:
102 | raise ValueError(
103 |         'process_count={} must divide eval_batch_size={}.'.format(
104 |             process_count, eval_batch_size))
105 |
106 | train_dataset, eval_train_dataset, valid_dataset, test_dataset = (
107 | input_pipeline.get_wikitext103_dataset(
108 | hps,
109 | train_batch_size=batch_size,
110 | valid_batch_size=eval_batch_size,
111 | test_batch_size=eval_batch_size,
112 | shuffle_seed=data_utils.convert_jax_to_tf_random_seed(shuffle_rng),
113 | )
114 | )
115 |
116 | def train_iterator_fn():
117 | for batch in train_dataset:
118 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch), pad_id)
119 |
120 | def eval_train_epoch(num_batches=None):
121 | for batch in itertools.islice(iter(eval_train_dataset), num_batches):
122 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch), pad_id)
123 |
124 | def valid_epoch(num_batches=None):
125 | for batch in itertools.islice(iter(valid_dataset), num_batches):
126 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch), pad_id)
127 |
128 | def test_epoch(num_batches=None):
129 | for batch in itertools.islice(iter(test_dataset), num_batches):
130 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch), pad_id)
131 |
132 | return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
133 | test_epoch)
134 |
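Usage sketch for add_weights_to_batch above (illustrative only, not part of the file; the toy arrays are assumptions):

    import numpy as np
    from init2winit.dataset_lib import wikitext103

    # Toy batch: two length-4 sequences, padded with PAD_ID.
    batch = {
        'inputs': np.array([[5, 8, 2, wikitext103.PAD_ID],
                            [3, 9, wikitext103.PAD_ID, wikitext103.PAD_ID]]),
        'targets': np.array([[8, 2, 6, wikitext103.PAD_ID],
                             [9, 4, wikitext103.PAD_ID, wikitext103.PAD_ID]]),
    }
    weighted = wikitext103.add_weights_to_batch(batch)
    # weighted['weights'] is 1.0 at real tokens and 0.0 at padding positions:
    # [[1., 1., 1., 0.],
    #  [1., 1., 0., 0.]]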
--------------------------------------------------------------------------------
/init2winit/dataset_lib/wikitext103_spm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Module containing hyperparameters, metadata and dataset getter for the SentencePiece-tokenized Wikitext-103 dataset."""
17 |
18 | import functools
19 | from init2winit.dataset_lib import wikitext103
20 | from init2winit.dataset_lib import wikitext103_input_pipeline
21 | from ml_collections.config_dict import config_dict
22 |
23 | SPM_TOKENIZER_VOCAB_SIZE = wikitext103_input_pipeline.SPM_TOKENIZER_VOCAB_SIZE
24 | SPM_TOKENIZER_VOCAB_PATH = wikitext103_input_pipeline.SPM_TOKENIZER_VOCAB_PATH
25 | PAD_ID = -1
26 | get_wikitext103 = functools.partial(wikitext103.get_wikitext103, pad_id=PAD_ID)
27 |
28 | DEFAULT_HPARAMS = config_dict.ConfigDict(
29 | dict(
30 | sequence_length=128,
31 | max_target_length=128,
32 | max_eval_target_length=128,
33 | eval_sequence_length=128,
34 | input_shape=(128,),
35 | output_shape=(SPM_TOKENIZER_VOCAB_SIZE,),
36 | tokenizer='sentencepiece',
37 | tokenizer_vocab_path=SPM_TOKENIZER_VOCAB_PATH,
38 | vocab_size=SPM_TOKENIZER_VOCAB_SIZE,
39 | train_size=800210, # TODO(kasimbeg): Update this
40 | )
41 | )
42 |
43 |
44 | METADATA = {
45 | 'apply_one_hot_in_loss': True,
46 | 'shift_inputs': True,
47 | 'causal': True,
48 | 'pad_token': -1,
49 | }
50 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/wikitext2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Module containing hyperparameters, metadata and dataset getter for Wikitext-2 dataset."""
17 |
18 | import itertools
19 |
20 | from init2winit.dataset_lib import data_utils
21 | from init2winit.dataset_lib import wikitext2_input_pipeline as input_pipeline
22 | from init2winit.dataset_lib.data_utils import Dataset
23 | from init2winit.dataset_lib.wikitext2_input_pipeline import PAD_ID
24 | import jax
25 | from ml_collections.config_dict import config_dict
26 | import numpy as np
27 |
28 | VOCAB_SIZE = 33278
29 |
30 | DEFAULT_HPARAMS = config_dict.ConfigDict(
31 | dict(
32 | sequence_length=34,
33 | max_target_length=34,
34 | max_eval_target_length=34,
35 | input_shape=(34,),
36 | output_shape=(VOCAB_SIZE,),
37 | vocab_size=VOCAB_SIZE,
38 |         # TODO(kasimbeg): add vocab path after separating out tokenizer
39 | # vocab_path=None,
40 | train_size=59676 # Number of sequences.
41 | ))
42 |
43 | METADATA = {
44 | 'apply_one_hot_in_loss': True,
45 | 'shift_inputs': True,
46 | 'causal': True,
47 | }
48 |
49 |
50 | def add_weights_to_batch(batch, pad_id: int = PAD_ID):
51 | """Add weights for the input values so that paddings have 0 weight.
52 |
53 | Args:
54 | batch: Batch represented by dict containing 'inputs' and 'targets'.
55 | pad_id: Value for 'inputs' that will have weight 0.
56 |
57 | Returns:
58 | batch with weights
59 | """
60 | batch['weights'] = np.where(batch['inputs'] == pad_id, 0.0, 1.0)
61 | return batch
62 |
63 |
64 | def get_wikitext2(
65 | data_rng,
66 | batch_size: int,
67 | eval_batch_size: int = None,
68 | hps: config_dict.ConfigDict = None,) -> Dataset:
69 | """Returns Wikitext-2 Dataset.
70 |
71 | Args:
72 | data_rng: jax.random.PRNGKey
73 | batch_size: training batch size
74 | eval_batch_size: validation batch size
75 | hps: Hyper parameters
76 |
77 | Returns:
78 | Dataset
79 |
80 | Raises:
81 | ValueError: If batch_size is not divisible by jax process count.
82 | ValueError: If eval_batch_size is not divisible by jax process count.
83 | """
84 | process_count = jax.process_count()
85 |
86 | if batch_size % process_count != 0:
87 | raise ValueError(
88 | 'process_count={} must divide batch_size={}.'.format(
89 | process_count, batch_size))
90 |
91 |   if eval_batch_size is None:
92 |     eval_batch_size = batch_size
93 |
94 |   if eval_batch_size % process_count != 0:
95 |     raise ValueError(
96 |         'process_count={} must divide eval_batch_size={}.'.format(
97 |             process_count, eval_batch_size))
98 |
99 | train_dataset, eval_train_dataset, valid_dataset, test_dataset = input_pipeline.get_wikitext2_dataset(
100 | hps,
101 | train_batch_size=batch_size,
102 | valid_batch_size=eval_batch_size,
103 | test_batch_size=eval_batch_size,
104 | shuffle_seed=data_utils.convert_jax_to_tf_random_seed(data_rng),
105 | )
106 |
107 | def train_iterator_fn():
108 | for batch in train_dataset:
109 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch))
110 |
111 | def eval_train_epoch(num_batches=None):
112 | for batch in itertools.islice(iter(eval_train_dataset), num_batches):
113 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch))
114 |
115 | def valid_epoch(num_batches=None):
116 | for batch in itertools.islice(iter(valid_dataset), num_batches):
117 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch))
118 |
119 | def test_epoch(num_batches=None):
120 | for batch in itertools.islice(iter(test_dataset), num_batches):
121 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch))
122 |
123 | return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
124 | test_epoch)
125 |
--------------------------------------------------------------------------------
/init2winit/dataset_lib/wikitext_tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Contains Tokenizer class for word level tokenization.
17 |
18 | Note that the current tokenization workflow is not yet optimized for time or
19 | memory.
20 |
21 | """
22 |
23 | import tensorflow as tf
24 |
25 | EOS_TOKEN = b'<eos>'
26 | UNKNOWN_TOKEN = b'<unk>'
27 |
28 |
29 | class _Dictionary:
30 | """Dictionary contains word-to-id mappings and id-to-word mappings.
31 |
32 | Attributes:
33 | word2idx: dict containing key-values where keys are words and values are
34 | tokens.
35 | idx2word: list where the index of each word in the list is the token value.
36 | """
37 |
38 | def __init__(self):
39 | self.word2idx = {}
40 | self.idx2word = []
41 |
42 | def add_word(self, word):
43 | if word not in self.word2idx:
44 | self.idx2word.append(word)
45 | # Start the first token idx at 1, because 0 is reserved for special tokens
46 | # e.g. for padding and masking
47 | self.word2idx[word] = len(self.idx2word)
48 | return self.word2idx[word]
49 |
50 | def __len__(self):
51 | return len(self.idx2word)
52 |
53 |
54 | class Tokenizer:
55 | """Tokenizer object for word level tokenization from words to unique ids.
56 |
57 | Attributes:
58 | dictionary: Dictionary containing word-to-id and id-to-word mappings
59 | lookup_table: tf.lookup.StaticHashTable for looking up token ids from words
60 | """
61 |
62 | def __init__(self):
63 | self.dictionary = _Dictionary()
64 |
65 | def train(self, dataset: tf.data.TextLineDataset):
66 | """Trains a Tokenizer from a TextLineDataset."""
67 | # Add words to the dictionary
68 | self.dictionary.add_word(UNKNOWN_TOKEN) # add default unknown token
69 | for line in dataset:
70 | words = line.numpy().split() + [EOS_TOKEN]
71 | for word in words:
72 | self.dictionary.add_word(word)
73 | # Make static vocabulary table for tf.data style tokenization
74 | self.lookup_table = tf.lookup.StaticHashTable(
75 | tf.lookup.KeyValueTensorInitializer(
76 | tf.constant(list(self.dictionary.word2idx.keys()), dtype=tf.string),
77 | tf.constant(
78 | list(self.dictionary.word2idx.values()), dtype=tf.int32
79 | ),
80 | ),
81 | default_value=self.dictionary.word2idx[UNKNOWN_TOKEN],
82 | )
83 |
84 | def tokenize(self, input_tensor: tf.Tensor) -> tf.Tensor:
85 | """Tokenizes a tensor of UTF-8 strings.
86 |
87 | Args:
88 | input_tensor: A `RaggedTensor` or `Tensor` of UTF-8 strings with any
89 | shape.
90 |
91 | Returns:
92 | A `RaggedTensor` or `Tensor` of tokenized text. The returned shape is
93 | the shape of the input tensor.
94 | """
95 | eos_tensor = tf.constant([EOS_TOKEN], dtype=tf.string)
96 | input_tensor_split = tf.strings.split(input_tensor)
97 | input_tensor_extended = tf.concat([input_tensor_split, eos_tensor], axis=-1)
98 | return self.lookup_table.lookup(input_tensor_extended)
99 |
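Usage sketch for the Tokenizer above (illustrative only); 'corpus.txt' is a hypothetical path to a small plain-text file:

    import tensorflow as tf
    from init2winit.dataset_lib import wikitext_tokenizer

    # Train a word-level tokenizer over the lines of a text file.
    text_dataset = tf.data.TextLineDataset('corpus.txt')  # hypothetical path
    tokenizer = wikitext_tokenizer.Tokenizer()
    tokenizer.train(text_dataset)

    # Map each line to token ids; unseen words map to UNKNOWN_TOKEN and every
    # line is terminated with EOS_TOKEN.
    tokenized = text_dataset.map(tokenizer.tokenize)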
--------------------------------------------------------------------------------
/init2winit/gradient_statistics_callback.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Callback for computing gradient statistics given a set of params.
17 | """
18 |
19 | import functools
20 | import itertools
21 | import os
22 |
23 | import flax.linen as nn
24 | from init2winit import base_callback
25 | from init2winit import checkpoint
26 | from init2winit.dataset_lib import data_utils
27 | import jax
28 | import jax.numpy as jnp
29 |
30 |
31 | class GradientStatisticsCallback(base_callback.BaseCallBack):
32 |   """Computes per-parameter gradient statistics over the training data."""
33 |
34 | def __init__(self,
35 | model,
36 | params,
37 | batch_stats,
38 | optimizer_state,
39 | optimizer_update_fn,
40 | dataset,
41 | hps,
42 | callback_config,
43 | train_dir,
44 | rng,
45 | mesh):
46 | del optimizer_state
47 | del optimizer_update_fn
48 |
49 | self.dataset = dataset
50 | self.model = model
51 | self.hps = hps
52 | self.callback_config = callback_config
53 | self.rng = rng
54 | self.save_path = os.path.join(train_dir, 'gradient_statistics/')
55 | self.mesh = mesh
56 |
57 | self.num_batches_in_training_epoch = (
58 | self.hps.train_size // self.hps.batch_size
59 | )
60 | if callback_config is not None:
61 | if 'num_batches_in_training_epoch' in callback_config.keys():
62 | self.num_batches_in_training_epoch = callback_config[
63 | 'num_batches_in_training_epoch'
64 | ]
65 |
66 | self.num_updates = 0
67 |
68 | def update(params, batch, batch_stats, dropout_rng):
69 | def opt_cost(params):
70 | return self.model.training_cost(
71 | params,
72 | batch=batch,
73 | batch_stats=batch_stats,
74 | dropout_rng=dropout_rng,
75 | )
76 |
77 | grad_fn = jax.value_and_grad(opt_cost, has_aux=True)
78 | _, grad = grad_fn(params)
79 |
80 | return grad
81 |
82 | params_sharding = jax.tree_util.tree_map(
83 | lambda x: x.sharding, params
84 | )
85 | batch_stats_sharding = nn.get_sharding(batch_stats, self.mesh)
86 |
87 | self.jitted_update = jax.jit(
88 | update,
89 | in_shardings=(
90 | params_sharding,
91 | jax.sharding.NamedSharding(
92 | self.mesh, jax.sharding.PartitionSpec('devices')),
93 | batch_stats_sharding,
94 | None
95 | ),
96 | out_shardings=(params_sharding)
97 | )
98 |
99 | def run_eval(self, params, batch_stats, optimizer_state, global_step):
100 | """Computes gradient statistics from mini batches over full training data.
101 | """
102 | del optimizer_state
103 | train_iter = itertools.islice(
104 | self.dataset.train_iterator_fn(), self.num_batches_in_training_epoch
105 | )
106 |
107 | grad_sum = jax.tree.map(jnp.zeros_like, params)
108 | grad_squared_sum = jax.tree.map(jnp.zeros_like, params)
109 | self.num_updates = 0
110 |
111 | make_global_array_fn = functools.partial(
112 | data_utils.make_global_array, mesh=self.mesh
113 | )
114 |
115 | for batch in train_iter:
116 | sharded_batch = jax.tree_util.tree_map(make_global_array_fn, batch)
117 | grads = self.jitted_update(params, sharded_batch, batch_stats, self.rng)
118 |
119 | grad_sum = jax.tree_util.tree_map(
120 | lambda g_sum, g: g_sum + g, grad_sum, grads
121 | )
122 |
123 | grad_squared_sum = jax.tree_util.tree_map(
124 | lambda g_squared, g: g_squared + g**2, grad_squared_sum, grads
125 | )
126 |
127 | self.num_updates += 1
128 |
129 | grad_mean = jax.tree_util.tree_map(
130 | lambda g_sum: g_sum / self.num_updates, grad_sum
131 | )
132 | grad_std = jax.tree_util.tree_map(
133 | lambda g_squared, g_mean: jnp.sqrt( # pylint: disable=g-long-lambda
134 | g_squared / self.num_updates - g_mean**2
135 | ),
136 | grad_squared_sum,
137 | grad_mean,
138 | )
139 |
140 | state = dict(
141 | grad_std=jax.device_get(grad_std),
142 | grad_mean=jax.device_get(grad_mean),
143 | step=global_step
144 | )
145 |
146 | checkpoint.save_checkpoint(
147 | self.save_path,
148 | step=global_step,
149 | state=state,
150 | prefix='measurement_',
151 | max_to_keep=None)
152 |
153 | return {}
154 |
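The per-parameter statistics in run_eval reduce to the identity std = sqrt(E[g^2] - E[g]^2), accumulated with running sums over minibatches; a NumPy sketch with toy gradients (values are assumptions, not from the codebase):

    import numpy as np

    grads_per_batch = [np.array([0.1, -0.2]),
                       np.array([0.3, 0.0]),
                       np.array([0.2, 0.1])]
    grad_sum = np.zeros(2)
    grad_squared_sum = np.zeros(2)
    for g in grads_per_batch:
      grad_sum += g
      grad_squared_sum += g ** 2
    num_updates = len(grads_per_batch)
    grad_mean = grad_sum / num_updates
    grad_std = np.sqrt(grad_squared_sum / num_updates - grad_mean ** 2)
    # Same quantities that run_eval checkpoints as 'grad_mean' and 'grad_std'.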
--------------------------------------------------------------------------------
/init2winit/init_lib/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/init2winit/init_lib/initializers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Registry for the available initializers we can test.
17 |
18 | API of an initializer:
19 | new_params = init(loss, init_params, hps, num_outputs, input_shape)
20 |
21 | TODO(gilmer, gdahl, schsam): The API of an initializer should in general be
22 | aware of the moments of the data. Currently we assume that all input coordinates
23 | are iid standard normal distributions.
24 | """
25 |
26 | from init2winit.init_lib import meta_init
27 | from init2winit.init_lib import sparse_init
28 | from ml_collections.config_dict import config_dict
29 |
30 |
31 | # This function is meant to match the general API of an initializer
32 | # pylint: disable=unused-argument
33 | def noop(
34 | loss_fn=None,
35 | flax_module=None,
36 | params=None,
37 | hps=None,
38 | input_shape=None,
39 | output_shape=None,
40 | rng_key=None,
41 | metrics_logger=None,
42 | ):
43 | """No-op init."""
44 | return params
45 | # pylint: enable=unused-argument
46 |
47 | DEFAULT_HPARAMS = config_dict.ConfigDict()
48 |
49 | _ALL_INITIALIZERS = {
50 | 'noop': (noop, DEFAULT_HPARAMS),
51 | 'meta_init': (meta_init.meta_init, meta_init.DEFAULT_HPARAMS),
52 | 'sparse_init': (sparse_init.sparse_init, sparse_init.DEFAULT_HPARAMS),
53 | }
54 |
55 |
56 | def get_initializer(initializer_name):
57 | """Get the corresponding initializer function based on the initializer string.
58 |
59 | API of an initializer:
60 |     init_fn = get_initializer(init)
61 | new_params, final_l = init_fn(loss, init_params, hps,
62 | num_outputs, input_shape)
63 |
64 | Args:
65 | initializer_name: (str) e.g. default.
66 |
67 | Returns:
68 | initializer
69 | Raises:
70 |     ValueError if the initializer is unrecognized.
71 | """
72 | try:
73 | return _ALL_INITIALIZERS[initializer_name][0]
74 | except KeyError:
75 | raise ValueError('Unrecognized initializer: {}'.format(initializer_name))
76 |
77 |
78 | def get_initializer_hparams(initializer_name):
79 | """Get the corresponding hyperparameters based on the initializer string.
80 |
81 | Args:
82 | initializer_name: (str) e.g. default.
83 |
84 | Returns:
85 | hps
86 | Raises:
87 |     ValueError if the initializer is unrecognized.
88 | """
89 | try:
90 | return _ALL_INITIALIZERS[initializer_name][1]
91 | except KeyError:
92 | raise ValueError('Unrecognized initializer: {}'.format(initializer_name))
93 |
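Usage sketch for the registry above (illustrative only; the placeholder params dict is an assumption):

    from init2winit.init_lib import initializers

    # Look up an initializer and its default hyperparameters by name.
    init_fn = initializers.get_initializer('sparse_init')
    init_hps = initializers.get_initializer_hparams('sparse_init')

    # 'noop' returns the params it was given, unchanged.
    params = {'Dense_0': {'kernel': None}}  # placeholder params
    assert initializers.get_initializer('noop')(params=params) is params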
--------------------------------------------------------------------------------
/init2winit/init_lib/sparse_init.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines the SparseInit initializer.
17 |
18 | This initializer limits the number of non-zero incoming and outgoing connection
19 | weights. For more information, see Section 5 of (Martens, 2010), which can be
20 | found at https://www.cs.toronto.edu/~jmartens/docs/Deep_HessianFree.pdf.
21 | """
22 |
23 | from flax.core import frozen_dict
24 | from flax.core import unfreeze
25 | import jax
26 | from ml_collections.config_dict import config_dict
27 | import numpy as np
28 |
29 | DEFAULT_HPARAMS = config_dict.ConfigDict(dict(non_zero_connection_weights=15,))
30 |
31 |
32 | def sparse_init(loss_fn,
33 | flax_module,
34 | params,
35 | hps,
36 | input_shape,
37 | output_shape,
38 | rng_key,
39 | metrics_logger=None,
40 | log_every=10):
41 | """Implements SparseInit initializer.
42 |
43 | Args:
44 | loss_fn: Loss function.
45 | flax_module: Flax nn.Module class.
46 | params: The dict of model parameters.
47 |     hps: HParam object. Required hparams are non_zero_connection_weights,
48 |       hid_sizes, and activation_function.
49 | input_shape: Must agree with batch[0].shape[1:].
50 | output_shape: Must agree with batch[1].shape[1:].
51 | rng_key: jax.PRNGKey, used to seed all randomness.
52 | metrics_logger: Instance of utils.MetricsLogger
53 | log_every: Print meta loss every k steps.
54 |
55 | Returns:
56 |     A frozen dict of sparsely initialized model parameters.
57 | """
58 |
59 | del flax_module, loss_fn, input_shape, output_shape, metrics_logger, log_every
60 |
61 | params = unfreeze(params)
62 | activation_functions = hps.activation_function
63 | num_hidden_layers = len(hps.hid_sizes)
64 | if isinstance(hps.activation_function, str):
65 | activation_functions = [hps.activation_function] * num_hidden_layers
66 | for i, key in enumerate(params):
67 | num_units_in, num_units_out = params[key]['kernel'].shape
68 |
69 | mask = np.full((num_units_in, num_units_out), True, dtype=bool)
70 |
71 | # Restrict the number of non-zero weights from input units.
72 | rng_key, *rng_keys_in = jax.random.split(rng_key, num_units_in + 1)
73 | for k in range(num_units_in):
74 | if num_units_out > hps.non_zero_connection_weights:
75 | non_zero_units_out = jax.random.choice(
76 | rng_keys_in[k], num_units_out, (hps.non_zero_connection_weights,),
77 | replace=False)
78 | mask[k, non_zero_units_out] = False
79 | else:
80 | mask[k, :] = False
81 |
82 | # Restrict the number of non-zero weights to output units.
83 | rng_key, *rng_keys_out = jax.random.split(rng_key, num_units_out + 1)
84 | for k in range(num_units_out):
85 | if num_units_in > hps.non_zero_connection_weights:
86 | non_zero_units_in = jax.random.choice(
87 | rng_keys_out[k], num_units_in, (hps.non_zero_connection_weights,),
88 | replace=False)
89 | mask[non_zero_units_in, k] = False
90 | else:
91 | mask[:, k] = False
92 | params[key]['kernel'] = params[key]['kernel'].at[mask].set(0.0)
93 |
94 | if i < num_hidden_layers and activation_functions[i] == 'tanh':
95 | params[key]['bias'] = params[key]['bias'].at[:].set(0.5)
96 | else:
97 | params[key]['bias'] = params[key]['bias'].at[:].set(0.0)
98 | return frozen_dict.freeze(params)
99 |
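The core of sparse_init is a boolean mask that keeps at most non_zero_connection_weights entries per row and per column and zeroes the rest; a toy NumPy sketch of the row-wise step only (illustrative, not the function above):

    import numpy as np

    rng = np.random.default_rng(0)
    kernel = rng.normal(size=(4, 6))
    non_zero = 2  # plays the role of hps.non_zero_connection_weights

    mask = np.full(kernel.shape, True)
    for row in range(kernel.shape[0]):
      keep = rng.choice(kernel.shape[1], size=non_zero, replace=False)
      mask[row, keep] = False  # False marks weights that stay non-zero
    kernel[mask] = 0.0  # each row now has at most `non_zero` non-zero weights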
--------------------------------------------------------------------------------
/init2winit/model_lib/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/init2winit/model_lib/adabelief_vgg.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Flax implementation of Adabelief VGG.
17 |
18 | This module ports the Adabelief implementation of VGG to Flax. The
19 | Adabelief paper and GitHub repository can be found here:
20 |
21 | https://arxiv.org/abs/2010.07468
22 |
23 | https://github.com/juntang-zhuang/Adabelief-Optimizer/blob/update_0.2.0/PyTorch_Experiments/classification_cifar10/models/vgg.py
24 |
25 | The original VGGNet paper can be found here:
26 |
27 | https://arxiv.org/abs/1409.1556
28 | """
29 |
30 | import functools
31 |
32 | from flax import linen as nn
33 | from init2winit.model_lib import base_model
34 | from init2winit.model_lib import model_utils
35 | import jax.numpy as jnp
36 | from ml_collections.config_dict import config_dict
37 |
38 |
39 | DEFAULT_HPARAMS = config_dict.ConfigDict(
40 | dict(
41 | num_layers=11, # Must be one of [11, 13, 16, 19]
42 | layer_rescale_factors={},
43 | lr_hparams={
44 | 'schedule': 'constant',
45 | 'base_lr': 0.2,
46 | },
47 | normalizer='none',
48 | optimizer='momentum',
49 | opt_hparams={
50 | 'momentum': 0.9,
51 | },
52 | batch_size=128,
53 | l2_decay_factor=0.0001,
54 | l2_decay_rank_threshold=2,
55 | label_smoothing=None,
56 | rng_seed=-1,
57 | use_shallue_label_smoothing=False,
58 | model_dtype='float32',
59 | grad_clip=None,
60 | ))
61 |
62 |
63 | def classifier(x, num_outputs, dropout_rate, deterministic):
64 | """Implements the classification portion of the network."""
65 |
66 | x = nn.Dropout(rate=dropout_rate, deterministic=deterministic)(x)
67 | x = nn.Dense(512)(x)
68 | x = nn.relu(x)
69 | x = nn.Dropout(rate=dropout_rate, deterministic=deterministic)(x)
70 | x = nn.Dense(512)(x)
71 | x = nn.relu(x)
72 | x = nn.Dense(num_outputs)(x)
73 | return x
74 |
75 |
76 | def features(x, num_layers, normalizer, dtype, train):
77 | """Implements the feature extraction portion of the network."""
78 |
79 | layers = _layer_size_options[num_layers]
80 | conv = functools.partial(nn.Conv, use_bias=False, dtype=dtype)
81 | maybe_normalize = model_utils.get_normalizer(normalizer, train)
82 | for l in layers:
83 | if l == 'M':
84 | x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
85 | else:
86 | x = conv(features=l, kernel_size=(3, 3), padding=((1, 1), (1, 1)))(x)
87 | x = maybe_normalize()(x)
88 | x = nn.relu(x)
89 | return x
90 |
91 |
92 | class VGG(nn.Module):
93 | """Adabelief VGG."""
94 | num_layers: int
95 | num_outputs: int
96 | normalizer: str = 'none'
97 | dtype: str = 'float32'
98 |
99 | @nn.compact
100 | def __call__(self, x, train):
101 | x = features(x, self.num_layers, self.normalizer, self.dtype, train)
102 | x = jnp.reshape(x, (x.shape[0], -1))
103 | x = classifier(
104 | x, self.num_outputs, dropout_rate=0.5, deterministic=not train)
105 | return x
106 |
107 |
108 | # Specifies the sequence of layers in the feature extraction section of the
109 | # network for a given size.
110 | # The numbers indicate the feature size of a convolutional layer, the
111 | # letter M indicates a max pooling layer.
112 | _layer_size_options = {
113 | 1: [
114 | 8, 'M', 16, 'M', 32, 32, 'M', 64, 64, 'M', 64, 64, 'M'
115 | ], # used for testing only.
116 | 11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
117 | 13: [
118 | 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
119 | ],
120 | 16: [
121 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512,
122 | 512, 512, 'M'
123 | ],
124 | 19: [
125 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512,
126 | 'M', 512, 512, 512, 512, 'M'
127 | ],
128 | }
129 |
130 |
131 | # pylint: disable=[missing-class-docstring]
132 | class AdaBeliefVGGModel(base_model.BaseModel):
133 | def build_flax_module(self):
134 | """Adabelief VGG."""
135 | return VGG(
136 | num_layers=self.hps.num_layers,
137 | num_outputs=self.hps['output_shape'][-1],
138 | dtype=self.hps.model_dtype,
139 | normalizer=self.hps.normalizer)
140 |
141 | def get_fake_inputs(self, hps):
142 |     """Helper method solely for the purpose of initializing the model."""
143 | dummy_inputs = [
144 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype)
145 | ]
146 | return dummy_inputs
147 |
--------------------------------------------------------------------------------
/init2winit/model_lib/autoencoder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Fully connected autoencoder.
17 |
18 | This model builds an autoencoder using FullyConnected module.
19 | More information on the fully connected autoencoder model can be found here:
20 |
21 | https://www.cs.toronto.edu/~hinton/science.pdf
22 |
23 | """
24 |
25 | from init2winit.model_lib import base_model
26 | from init2winit.model_lib.fully_connected import FullyConnected
27 | from jax.nn import initializers
28 | import jax.numpy as jnp
29 | from ml_collections.config_dict import config_dict
30 |
31 | # small test hparams
32 | # https://blog.keras.io/building-autoencoders-in-keras.html
33 | # for the configuration of a standard fully connected autoencoder model,
34 | # see https://www.cs.toronto.edu/~hinton/science.pdf.
35 | DEFAULT_HPARAMS = config_dict.ConfigDict(
36 | dict(
37 | hid_sizes=[128, 64, 32, 64, 128],
38 | activation_function=['relu', 'relu', 'relu', 'relu', 'relu'],
39 | kernel_scales=[1.0] * 6,
40 | lr_hparams={
41 | 'base_lr': 0.1,
42 | 'schedule': 'constant'
43 | },
44 | layer_rescale_factors={},
45 | optimizer='hessian_free',
46 | opt_hparams={
47 | 'cg_max_iter': 250,
48 | 'cg_iter_tracking_method': 'back_tracking',
49 | 'use_line_search': True,
50 | 'init_damping': 50.0,
51 | 'damping_ub': 10 ** 2,
52 | 'damping_lb': 10 ** -6,
53 | },
54 | batch_size=128,
55 | label_smoothing=None,
56 | rng_seed=-1,
57 | use_shallue_label_smoothing=False,
58 | model_dtype='float32',
59 | l2_decay_factor=2e-5,
60 | l2_decay_rank_threshold=1,
61 | ))
62 |
63 |
64 | class AutoEncoderModel(base_model.BaseModel):
65 | """Model class for AutoEncoder model."""
66 |
67 | def build_flax_module(self):
68 | kernel_inits = [
69 | initializers.normal(scale)
70 | for scale in self.hps.kernel_scales
71 | ]
72 |
73 | return FullyConnected(
74 | num_outputs=self.hps['output_shape'][-1],
75 | hid_sizes=self.hps.hid_sizes,
76 | activation_function=self.hps.activation_function,
77 | kernel_inits=kernel_inits)
78 |
79 | def get_fake_inputs(self, hps):
80 |     """Helper method solely for the purpose of initializing the model."""
81 | dummy_inputs = [
82 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype)
83 | ]
84 | return dummy_inputs
85 |
--------------------------------------------------------------------------------
/init2winit/model_lib/convolutional_autoencoder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Convolutional autoencoder.
17 |
18 | This model uses a convolutional encoder-decoder network to reconstruct input
19 | images as outputs.
20 |
21 | """
22 |
23 | from typing import Any, Dict, Sequence
24 |
25 | from flax import linen as nn
26 | from init2winit.model_lib import base_model
27 | from init2winit.model_lib import model_utils
28 | from jax import numpy as jnp
29 | from ml_collections.config_dict import config_dict
30 |
31 |
32 | # small test hparams from
33 | # https://blog.keras.io/building-autoencoders-in-keras.html
34 | DEFAULT_HPARAMS = config_dict.ConfigDict(
35 | dict(
36 | encoder={
37 | 'filter_sizes': [16, 8, 8],
38 | 'kernel_sizes': [(3, 3), (3, 3), (3, 3)],
39 | 'kernel_paddings': ['SAME', 'SAME', 'SAME'],
40 | 'window_sizes': [(2, 2), (2, 2), (2, 2)],
41 | 'window_paddings': ['SAME', 'SAME', 'SAME'],
42 | 'strides': [(2, 2), (2, 2), (2, 2)],
43 | 'activations': ['relu', 'relu', 'relu'],
44 | },
45 | decoder={
46 | 'filter_sizes': [8, 8, 16, 1],
47 | 'kernel_sizes': [(3, 3), (3, 3), (3, 3), (3, 3)],
48 | 'window_sizes': [(2, 2), (2, 2), (2, 2), None],
49 | 'paddings': ['SAME', ((1, 0), (1, 0)), 'SAME', 'SAME'],
50 | 'activations': ['relu', 'relu', 'relu', 'id'],
51 | },
52 |
53 | activation_function='relu',
54 | lr_hparams={
55 | 'base_lr': 0.02,
56 | 'schedule': 'constant'
57 | },
58 | layer_rescale_factors={},
59 | optimizer='momentum',
60 | opt_hparams={
61 | 'momentum': 0,
62 | },
63 | batch_size=128,
64 | l2_decay_factor=None,
65 | l2_decay_rank_threshold=0,
66 | label_smoothing=None,
67 | rng_seed=-1,
68 | use_shallue_label_smoothing=False,
69 | model_dtype='float32',
70 | grad_clip=None,
71 | ))
72 |
73 |
74 | class ConvAutoEncoder(nn.Module):
75 |   """Defines a convolutional autoencoder.
76 |
77 |   The model assumes the input data has shape
78 |   [batch_size_per_device, H, W, C]. A convolutional encoder downsamples the
79 |   input, and a transposed-convolution decoder reconstructs it.
80 |   """
81 | output_shape: Sequence[int]
82 | encoder: Dict[str, Any]
83 | decoder: Dict[str, Any]
84 |
85 | @nn.compact
86 | def __call__(self, x, train):
87 | del train
88 | encoder_keys = [
89 | 'filter_sizes',
90 | 'kernel_sizes',
91 | 'kernel_paddings',
92 | 'window_sizes',
93 | 'window_paddings',
94 | 'strides',
95 | 'activations',
96 | ]
97 | if len(set(len(self.encoder[k]) for k in encoder_keys)) > 1:
98 | raise ValueError(
99 | 'The elements in encoder dict do not have the same length.')
100 |
101 | decoder_keys = [
102 | 'filter_sizes',
103 | 'kernel_sizes',
104 | 'window_sizes',
105 | 'paddings',
106 | 'activations',
107 | ]
108 | if len(set(len(self.decoder[k]) for k in decoder_keys)) > 1:
109 | raise ValueError(
110 | 'The elements in decoder dict do not have the same length.')
111 |
112 | # encoder
113 | for i in range(len(self.encoder['filter_sizes'])):
114 | x = nn.Conv(
115 | self.encoder['filter_sizes'][i],
116 | self.encoder['kernel_sizes'][i],
117 | padding=self.encoder['kernel_paddings'][i])(x)
118 | x = model_utils.ACTIVATIONS[self.encoder['activations'][i]](x)
119 | x = nn.max_pool(
120 | x, self.encoder['window_sizes'][i],
121 | strides=self.encoder['strides'][i],
122 | padding=self.encoder['window_paddings'][i])
123 |
124 | # decoder
125 | for i in range(len(self.decoder['filter_sizes'])):
126 | x = nn.ConvTranspose(
127 | self.decoder['filter_sizes'][i],
128 | self.decoder['kernel_sizes'][i],
129 | self.decoder['window_sizes'][i],
130 | padding=self.decoder['paddings'][i])(x)
131 | x = model_utils.ACTIVATIONS[self.decoder['activations'][i]](x)
132 | return x
133 |
134 |
135 | # pylint: disable=[missing-class-docstring]
136 | class ConvAutoEncoderModel(base_model.BaseModel):
137 |
138 | def build_flax_module(self):
139 | return ConvAutoEncoder(
140 | output_shape=self.hps.output_shape,
141 | encoder=self.hps.encoder,
142 | decoder=self.hps.decoder)
143 |
144 | def get_fake_inputs(self, hps):
145 |     """Helper method solely for the purpose of initializing the model."""
146 | dummy_inputs = [
147 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype)
148 | ]
149 | return dummy_inputs
150 |
--------------------------------------------------------------------------------
/init2winit/model_lib/fully_connected.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Simple fully connected feedforward neural network classifier."""
17 | import copy
18 | from typing import Any, Tuple
19 |
20 | from flax import linen as nn
21 | from init2winit.model_lib import base_model
22 | from init2winit.model_lib import model_utils
23 | from jax.nn import initializers
24 | import jax.numpy as jnp
25 | from ml_collections.config_dict import config_dict
26 |
27 |
28 | # small hparams used for unit tests
29 | DEFAULT_HPARAMS = config_dict.ConfigDict(
30 | dict(
31 | hid_sizes=[20, 10],
32 | kernel_scales=[1.0, 1.0, 1.0],
33 | lr_hparams={
34 | 'base_lr': 0.1,
35 | 'schedule': 'constant'
36 | },
37 | layer_rescale_factors={},
38 | optimizer='momentum',
39 | opt_hparams={
40 | 'momentum': 0.9,
41 | },
42 | batch_size=128,
43 | total_accumulated_batch_size=None,
44 | activation_function='relu',
45 | l2_decay_factor=.0005,
46 | l2_decay_rank_threshold=2,
47 | label_smoothing=None,
48 | rng_seed=-1,
49 | use_shallue_label_smoothing=False,
50 | model_dtype='float32',
51 | grad_clip=None,
52 | ))
53 |
54 |
55 | class FullyConnected(nn.Module):
56 | """Defines a fully connected neural network.
57 |
58 | The model assumes the input data has shape
59 | [batch_size_per_device, *input_shape] where input_shape may be of arbitrary
60 |   rank. The model flattens the input before applying a dense layer.
61 | """
62 | num_outputs: int
63 | hid_sizes: Tuple[int]
64 | activation_function: Any
65 | kernel_inits: Tuple[model_utils.Initializer]
66 | bias_init: model_utils.Initializer = initializers.zeros
67 |
68 | @nn.compact
69 | def __call__(self, x, train):
70 | del train
71 | if not isinstance(self.activation_function, str):
72 | if len(self.activation_function) != len(self.hid_sizes):
73 | raise ValueError(
74 | 'The number of activation functions must be equal to the number '
75 | 'of hidden layers')
76 | activation_function = copy.deepcopy(self.activation_function)
77 | else:
78 | activation_function = [self.activation_function] * len(self.hid_sizes)
79 |
80 | x = jnp.reshape(x, (x.shape[0], -1))
81 | for i, (num_hid, init) in enumerate(
82 | zip(self.hid_sizes, self.kernel_inits[:-1])):
83 | x = nn.Dense(num_hid, kernel_init=init, bias_init=self.bias_init)(x)
84 | x = model_utils.ACTIVATIONS[activation_function[i]](x)
85 | x = nn.Dense(
86 | self.num_outputs,
87 | kernel_init=self.kernel_inits[-1],
88 | bias_init=self.bias_init)(x)
89 | return x
90 |
91 |
92 | # pylint: disable=missing-class-docstring
93 | class FullyConnectedModel(base_model.BaseModel):
94 | """Model class for fully connected model."""
95 |
96 | def build_flax_module(self):
97 | kernel_inits = [
98 | initializers.variance_scaling(scale, 'fan_in', 'truncated_normal')
99 | for scale in self.hps.kernel_scales
100 | ]
101 | return FullyConnected(
102 | num_outputs=self.hps['output_shape'][-1],
103 | hid_sizes=tuple(self.hps.hid_sizes),
104 | activation_function=self.hps.activation_function,
105 | kernel_inits=tuple(kernel_inits))
106 |
107 | def get_fake_inputs(self, hps):
108 |     """Helper method solely for the purpose of initializing the model."""
109 | dummy_inputs = [
110 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype)
111 | ]
112 | return dummy_inputs
113 |
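Sketch of building and initializing the FullyConnected module directly (illustrative only; the shapes and hyperparameter values are assumptions):

    import jax
    import jax.numpy as jnp
    from jax.nn import initializers
    from init2winit.model_lib.fully_connected import FullyConnected

    # One kernel init per hidden layer plus one for the output layer.
    kernel_inits = tuple(initializers.lecun_normal() for _ in range(3))
    model = FullyConnected(
        num_outputs=10,
        hid_sizes=(20, 10),
        activation_function='relu',
        kernel_inits=kernel_inits)

    dummy_inputs = jnp.zeros((2, 28, 28, 1))  # flattened inside __call__
    variables = model.init(jax.random.PRNGKey(0), dummy_inputs, train=False)
    logits = model.apply(variables, dummy_inputs, train=False)  # shape (2, 10)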
--------------------------------------------------------------------------------
/init2winit/model_lib/max_pooling_cnn.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Max pooling convnet classifier.
17 |
18 | This model can be used to implement the 3c3d architecture from:
19 | https://github.com/fsschneider/DeepOBS/blob/master/deepobs/tensorflow/testproblems/_3c3d.py
20 | """
21 | from typing import Any, Sequence
22 |
23 | from flax import linen as nn
24 | from init2winit.model_lib import base_model
25 | from init2winit.model_lib import model_utils
26 | from jax.nn import initializers
27 | import jax.numpy as jnp
28 |
29 | from ml_collections.config_dict import config_dict
30 |
31 |
32 | # small hparams used for unit tests
33 | DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
34 | num_filters=[64, 96, 128],
35 | kernel_sizes=[5, 3, 3],
36 | kernel_paddings=['VALID', 'VALID', 'SAME'],
37 | window_sizes=[3, 3, 3],
38 | window_paddings=['SAME', 'SAME', 'SAME'],
39 | strides=[2, 2, 2],
40 | num_dense_units=[512, 256],
41 | lr_hparams={
42 | 'base_lr': 0.001,
43 | 'schedule': 'constant'
44 | },
45 | layer_rescale_factors={},
46 | optimizer='momentum',
47 | opt_hparams={
48 | 'momentum': 0.9,
49 | },
50 | batch_size=128,
51 | activation_fn='relu',
52 | normalizer='none',
53 | l2_decay_factor=.0005,
54 | l2_decay_rank_threshold=2,
55 | label_smoothing=None,
56 | rng_seed=-1,
57 | use_shallue_label_smoothing=False,
58 | model_dtype='float32',
59 | grad_clip=None,
60 | total_accumulated_batch_size=None,
61 | ))
62 |
63 |
64 | class MaxPoolingCNN(nn.Module):
65 | """Defines a CNN model with max pooling.
66 |
67 | The model assumes the input shape is [batch, H, W, C].
68 | """
69 | num_outputs: int
70 | num_filters: Sequence[int]
71 | kernel_sizes: Sequence[int]
72 | kernel_paddings: Sequence[str]
73 | window_sizes: Sequence[int]
74 | window_paddings: Sequence[str]
75 | strides: Sequence[int]
76 |   num_dense_units: Sequence[int]
77 | activation_fn: Any
78 | normalizer: str = 'none'
79 | kernel_init: model_utils.Initializer = initializers.lecun_normal()
80 | bias_init: model_utils.Initializer = initializers.zeros
81 |
82 | @nn.compact
83 | def __call__(self, x, train):
84 | maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
85 | iterator = zip(
86 | self.num_filters, self.kernel_sizes, self.kernel_paddings,
87 | self.window_sizes, self.window_paddings, self.strides)
88 | for num_filters, kernel_size, kernel_padding, window_size, window_padding, stride in iterator:
89 | x = nn.Conv(
90 | num_filters, (kernel_size, kernel_size), (1, 1),
91 | padding=kernel_padding,
92 | kernel_init=self.kernel_init,
93 | bias_init=self.bias_init)(x)
94 | x = model_utils.ACTIVATIONS[self.activation_fn](x)
95 | x = maybe_normalize()(x)
96 | x = nn.max_pool(
97 | x,
98 | window_shape=(window_size, window_size),
99 | strides=(stride, stride),
100 | padding=window_padding)
101 | x = jnp.reshape(x, (x.shape[0], -1))
102 | for num_units in self.num_dense_units:
103 | x = nn.Dense(
104 | num_units, kernel_init=self.kernel_init, bias_init=self.bias_init)(x)
105 | x = model_utils.ACTIVATIONS[self.activation_fn](x)
106 | x = maybe_normalize()(x)
107 | x = nn.Dense(
108 | self.num_outputs,
109 | kernel_init=self.kernel_init,
110 | bias_init=self.bias_init)(x)
111 | return x
112 |
113 |
114 | class MaxPoolingCNNModel(base_model.BaseModel):
115 | """Model class for MaxPooling CNNModel."""
116 |
117 | def build_flax_module(self):
118 | """CNN with a set of conv layers with max pooling followed by fully connected layers."""
119 | return MaxPoolingCNN(
120 | num_outputs=self.hps['output_shape'][-1],
121 | num_filters=self.hps.num_filters,
122 | kernel_sizes=self.hps.kernel_sizes,
123 | kernel_paddings=self.hps.kernel_paddings,
124 | window_sizes=self.hps.window_sizes,
125 | window_paddings=self.hps.window_paddings,
126 | strides=self.hps.strides,
127 | num_dense_units=self.hps.num_dense_units,
128 | activation_fn=self.hps.activation_fn,
129 | normalizer=self.hps.normalizer)
130 |
131 | def get_fake_inputs(self, hps):
132 |     """Helper method solely for the purpose of initializing the model."""
133 | dummy_inputs = [
134 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype)
135 | ]
136 | return dummy_inputs
137 |
--------------------------------------------------------------------------------
/init2winit/model_lib/partition_tree.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Contains functions to partition a parameter pytree."""
17 |
18 |
19 | def outer_key(x):
20 | return x[0]
21 |
22 |
23 | def create_partition_flat_params_fn(key_map):
24 | """Partitions a flattened pytree according to the provided key_map.
25 |
26 |   Subsets are determined by the key_map, which hashes the flattened model
27 |   parameter keys into disjoint groups. For example, if the flattened param
28 | tree is
29 |
30 | {('a', 'b'): 1.0,
31 | ('a', 'c'): 1.0,
32 | ('d', 'b'): 2.0}
33 |
34 |   and we partition on outer_key, then the output is
35 | {'a': {('a', 'b'): 1.0, ('a', 'c'): 1.0}
36 | 'd': {('d', 'b'): 2.0}}.
37 |
38 | Args:
39 | key_map: Maps a tuple of strings to a hashable value.
40 |
41 | Returns:
42 | partition_flat_params, a function which returns a partitioned param
43 | dictionary.
44 | """
45 | def partition_flat_params(flat_params):
46 | subparam_groups = {}
47 | for tup in flat_params:
48 | mapped_key = key_map(tup)
49 | if mapped_key not in subparam_groups:
50 | subparam_groups[mapped_key] = {}
51 | subparam_groups[mapped_key][tup] = flat_params[tup]
52 |
53 | return subparam_groups
54 | return partition_flat_params
55 |
56 |
57 | registry = {
58 | 'outer_key': create_partition_flat_params_fn(outer_key),
59 | }
60 |
61 |
62 | def get_param_partition_fn(name):
63 | return registry[name]
64 |
65 |
66 | # Used in test_model_debugger.py
67 | def get_test_group(params):
68 | del params
69 | return ['B_0/C_0', 'C_0']
70 |
71 |
72 | skip_analysis_registry = {
73 | 'test_group': get_test_group,
74 | }
75 |
76 |
77 | def get_skip_analysis_fn(name):
78 | return skip_analysis_registry[name]
79 |
80 |
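Usage sketch for the registry above (mirrors the docstring example):

    from init2winit.model_lib import partition_tree

    flat_params = {
        ('a', 'b'): 1.0,
        ('a', 'c'): 1.0,
        ('d', 'b'): 2.0,
    }
    partition_fn = partition_tree.get_param_partition_fn('outer_key')
    groups = partition_fn(flat_params)
    # groups == {'a': {('a', 'b'): 1.0, ('a', 'c'): 1.0},
    #            'd': {('d', 'b'): 2.0}}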
--------------------------------------------------------------------------------
/init2winit/model_lib/simple_cnn.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Simple convnet classifier."""
17 | from typing import Sequence
18 |
19 | from flax import linen as nn
20 | from init2winit.model_lib import base_model
21 | from init2winit.model_lib import model_utils
22 | from jax.nn import initializers
23 | import jax.numpy as jnp
24 |
25 | from ml_collections.config_dict import config_dict
26 |
27 |
28 | # small hparams used for unit tests
29 | DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
30 | num_filters=[20, 10],
31 | kernel_sizes=[3, 3],
32 | lr_hparams={
33 | 'base_lr': 0.001,
34 | 'schedule': 'constant'
35 | },
36 | layer_rescale_factors={},
37 | optimizer='momentum',
38 | opt_hparams={
39 | 'momentum': 0.9,
40 | },
41 | batch_size=128,
42 | activation_function='relu',
43 | l2_decay_factor=.0005,
44 | l2_decay_rank_threshold=2,
45 | label_smoothing=None,
46 | rng_seed=-1,
47 | use_shallue_label_smoothing=False,
48 | model_dtype='float32',
49 | ))
50 |
51 |
52 | class SimpleCNN(nn.Module):
53 | """Defines a simple CNN model.
54 |
55 | The model assumes the input shape is [batch, H, W, C].
56 | """
57 | num_outputs: int
58 | num_filters: Sequence[int]
59 | kernel_sizes: Sequence[int]
60 |   activation_function: str
61 | kernel_init: model_utils.Initializer = initializers.lecun_normal()
62 | bias_init: model_utils.Initializer = initializers.zeros
63 |
64 | @nn.compact
65 | def __call__(self, x, train):
66 | for num_filters, kernel_size in zip(self.num_filters, self.kernel_sizes):
67 | x = nn.Conv(
68 | num_filters, (kernel_size, kernel_size), (1, 1),
69 | kernel_init=self.kernel_init,
70 | bias_init=self.bias_init)(x)
71 | x = model_utils.ACTIVATIONS[self.activation_function](x)
72 | x = jnp.reshape(x, (x.shape[0], -1))
73 | x = nn.Dense(
74 | self.num_outputs,
75 | kernel_init=self.kernel_init,
76 | bias_init=self.bias_init)(x)
77 | return x
78 |
79 |
80 | class SimpleCNNModel(base_model.BaseModel):
81 | """Model class for Simple CNN Model."""
82 |
83 | def build_flax_module(self):
84 | """Simple CNN with a set of conv layers followed by fully connected layers."""
85 | return SimpleCNN(
86 | num_outputs=self.hps['output_shape'][-1],
87 | num_filters=self.hps.num_filters,
88 | kernel_sizes=self.hps.kernel_sizes,
89 | activation_function=self.hps.activation_function)
90 |
91 | def get_fake_inputs(self, hps):
92 |     """Helper method solely for the purpose of initializing the model."""
93 | dummy_inputs = [
94 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype)
95 | ]
96 | return dummy_inputs
97 |
--------------------------------------------------------------------------------
/init2winit/mt_eval/eval_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """BLEU calculation utilities."""
17 |
18 | import os
19 | import pathlib
20 |
21 | from init2winit import checkpoint
22 | import jax
23 | import sacrebleu
24 | from tensorflow.io import gfile
25 |
26 | exists = gfile.exists
27 | glob = gfile.glob
28 |
29 |
30 | def compute_bleu_from_predictions(predictions, references, language_code, name):
31 | """Computes BLEU score given predictions and references."""
32 | sacrebleu_tokenizer = 'zh' if language_code == 'zh' else sacrebleu.DEFAULT_TOKENIZER
33 | bleu_score = sacrebleu.corpus_bleu(
34 | predictions, [references], tokenize=sacrebleu_tokenizer).score
35 | return {name: bleu_score}
36 |
37 |
38 | def get_eval_fpath(ckpt_dir, ckpt_step, eval_split):
39 | output_dir = str(pathlib.Path(ckpt_dir).parents[0])
40 | return os.path.join(output_dir, 'bleu_' + eval_split + '_' + str(ckpt_step))
41 |
42 |
43 | def load_evals(ckpt_dir, ckpt_step, eval_split):
44 | """Loads results if already available, else return None."""
45 | ckpt_eval_fpath = get_eval_fpath(ckpt_dir, ckpt_step, eval_split)
46 | if not exists(ckpt_eval_fpath):
47 | return None
48 | else:
49 | with gfile.GFile(ckpt_eval_fpath, 'r') as f:
50 | bleu_score = f.readlines()[-1]
51 | return float(bleu_score.strip())
52 |
53 |
54 | def save_evals(ckpt_dir, ckpt_step, eval_split, bleu_score):
55 | ckpt_eval_fpath = get_eval_fpath(ckpt_dir, ckpt_step, eval_split)
56 | with gfile.GFile(ckpt_eval_fpath, 'w') as f:
57 | f.write(str(bleu_score))
58 |
59 |
60 | def _load_checkpoint(checkpoint_path, params):
61 | """Load model (and batch stats) from checkpoint."""
62 | target = dict(
63 | params=params,
64 | global_step=-1,
65 | preemption_count=0,
66 | sum_train_cost=0.0)
67 | ckpt = checkpoint.load_checkpoint(
68 | checkpoint_path,
69 | target=target,
70 | )
71 | params = ckpt['params']
72 | return params
73 |
74 |
75 | def average_checkpoints(checkpoint_paths, params):
76 | """Averages a set of checkpoints in input checkpoints."""
77 | assert len(checkpoint_paths) >= 1
78 | # Sum parameters of separate models together.
79 | params = _load_checkpoint(checkpoint_paths[0], params)
80 | for checkpoint_path in checkpoint_paths[1:]:
81 | params_update = _load_checkpoint(
82 | checkpoint_path, params
83 | )
84 | # TODO(dxin): Make this averaging process more numerically stable.
85 | params = jax.tree.map(lambda x, y: x + y, params, params_update)
86 |
87 | # Average checkpoints.
88 | params = jax.tree.map(lambda x: x / float(len(checkpoint_paths)), params)
89 | return params
90 |
91 |
92 | def get_checkpoints_in_range(checkpoint_dir, lower_bound, upper_bound):
93 | """Get checkpoint paths in step range [lower_bound, upper_bound]."""
94 | checkpoint_paths = []
95 | for checkpoint_path in glob(os.path.join(checkpoint_dir, 'ckpt_*')):
96 | ckpt_step = int(checkpoint_path.split('_')[-1])
97 | if ckpt_step >= lower_bound and ckpt_step <= upper_bound:
98 | checkpoint_paths.append(os.path.join(checkpoint_dir, checkpoint_path))
99 | return checkpoint_paths
100 |
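# Illustrative sketch (not part of the original file): computing BLEU for a toy
# prediction, plus the intended checkpoint-averaging flow. The
# '/tmp/exp/checkpoints' path and the `params` pytree are assumptions, so that
# part is left commented out.
bleu = compute_bleu_from_predictions(
    predictions=['the cat sat on the mat'],
    references=['the cat sat on the mat'],
    language_code='en',
    name='test_bleu')  # -> {'test_bleu': 100.0}

# ckpt_paths = get_checkpoints_in_range('/tmp/exp/checkpoints', 1000, 5000)
# params = average_checkpoints(ckpt_paths, params)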
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/factor_sam.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Efficient implementation of Sharpness Aware Minimization (SAM).
17 |
18 | Applies SAM learning rule every k steps, and factorizes the perturbation
19 | radius and the regularization strength.
20 | """
21 |
22 | import functools
23 | from typing import Optional
24 |
25 | from init2winit.model_lib import model_utils
26 | import jax
27 | import jax.numpy as jnp
28 | import optax
29 |
30 | _GRAD_CLIP_EPS = 1e-6
31 |
32 |
33 | def normalize_vector(y: jnp.ndarray) -> jnp.ndarray:
34 | """Returns unit norm version of original pytree.
35 |
36 | Args:
37 | y: A pytree of numpy ndarrays to normalize.
38 | """
39 | gradient_norm = jnp.sqrt(
40 | sum([jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)]))
41 | normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y)
42 | return normalized_gradient
43 |
44 |
45 | def clean_update(updates, state, unused_grad_fn_params_tuple):
46 | """Returns the clean update function."""
47 | return updates, state
48 |
49 |
50 | def sam_update(
51 | updates,
52 | state,
53 | grad_fn_params_tuple,
54 | rho=0.1,
55 | alpha=1.0,
56 | ):
57 | """SAM update function."""
58 | (grad_fn, params) = grad_fn_params_tuple
59 | updates = normalize_vector(updates)
60 | noised_params = jax.tree_util.tree_map(
61 | lambda p, u: p + rho * u, params, updates
62 | )
63 | _, sam_updates = grad_fn(noised_params)
64 |
65 | # Regularizer gradient - difference between SAM and clean updates.
66 | sam_updates = jax.tree.map(lambda x, y: x - y, sam_updates, updates)
67 |
68 | # Rescale and apply regularizer
69 | updates = jax.tree.map(lambda x, y: x + alpha * y, updates, sam_updates)
70 | return updates, state
71 |
72 |
73 | def sharpness_aware_minimization(
74 | rho: float,
75 | alpha: float,
76 | k: int,
77 | grad_clip: Optional[float],
78 | base_opt_init_fn,
79 | base_opt_update_fn,
80 | ) -> optax.GradientTransformation:
81 | """Implementation of Sharpness Aware Minimization (SAM).
82 |
83 | Paper: https://arxiv.org/abs/2010.01412
84 | Code: https://github.com/google-research/sam
85 |
86 | References:
87 | Foret et al, 2021: https://arxiv.org/abs/2010.01412
88 | Args:
89 | rho: The size of the neighborhood for the sharpness aware minimization
90 | gradient updates. Defaults to 0.1.
91 | alpha: Additional scaling factor for regularization strength.
92 | k: Period on which to apply SAM. Regularization strength is scaled by k.
93 | grad_clip: The optional value to clip the updates by. Defaults to None.
94 | base_opt_init_fn: The initialization function for the base optimizer used to
95 | generate updates given the total gradient.
96 | base_opt_update_fn: The update function for the base optimizer used to
97 | generate updates given the total gradient.
98 |
99 | Returns:
100 | The corresponding `GradientTransformation`.
101 | """
102 |
103 | def init_fn(params):
104 | return base_opt_init_fn(params)
105 |
106 | # TODO(thetish): Implement version which applies SAM before averaging over
107 | # devices.
108 | def update_fn(updates, state, grad_fn_params_tuple):
109 | # Updates here have been averaged across devices in Trainer before being
110 | # sent to the optimizer.
111 | (_, params) = grad_fn_params_tuple
112 |
113 | # Update function in between SAM steps.
114 | intermediate_update_fn = clean_update
115 |
116 | # Sam update. Scale alpha by k to keep optimal rho independent of k.
117 | alpha_eff = alpha * k
118 | sam_update_fn = functools.partial(sam_update, rho=rho, alpha=alpha_eff)
119 | updates, state = jax.lax.cond( # Apply SAM every k steps.
120 | state.count % k == 0,
121 | sam_update_fn,
122 | intermediate_update_fn,
123 | updates,
124 | state,
125 | grad_fn_params_tuple,
126 | )
127 | # Clipping
128 | if grad_clip:
129 | updates_norm = jnp.sqrt(model_utils.l2_regularization(updates, 0))
130 | scaled_updates = jax.tree.map(
131 | lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates)
132 | updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates,
133 | lambda _: updates, None)
134 | # TODO(thetish): Explore different order for base optimizer and SAM. For
135 | # example, in Adam preconditioning the SAM perturbation is helpful.
136 | return base_opt_update_fn(updates, state, params) # Apply base optimizer
137 |
138 | return optax.GradientTransformation(init_fn, update_fn)
139 |
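# Illustrative sketch (not part of the original file): wrapping a base optax
# transformation whose state carries a `count` field (scale_by_adam here) and
# driving the update with a (grad_fn, params) tuple, as update_fn above
# expects. jax.tree_util.Partial keeps grad_fn a valid pytree for lax.cond.
import jax
import jax.numpy as jnp
import optax

base = optax.scale_by_adam()
sam = sharpness_aware_minimization(
    rho=0.05, alpha=0.5, k=2, grad_clip=None,
    base_opt_init_fn=base.init, base_opt_update_fn=base.update)

params = {'w': jnp.ones((3,))}
grad_fn = jax.tree_util.Partial(
    jax.value_and_grad(lambda p: jnp.sum(p['w'] ** 2)))
state = sam.init(params)
_, grads = grad_fn(params)
updates, state = sam.update(grads, state, (grad_fn, params))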
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/kitchen_sink/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Kitchen Sink: decomposing optimizers in JAX."""
17 |
18 | from init2winit.optimizer_lib.kitchen_sink._src.alias import adamw_generic
19 | from init2winit.optimizer_lib.kitchen_sink._src.alias import adapropw
20 | from init2winit.optimizer_lib.kitchen_sink._src.alias import nadampw
21 | from init2winit.optimizer_lib.kitchen_sink._src.alias import nadamw
22 | from init2winit.optimizer_lib.kitchen_sink._src.core import kitchen_sink
23 | from init2winit.optimizer_lib.kitchen_sink._src.transform import add_decayed_weights
24 | from init2winit.optimizer_lib.kitchen_sink._src.transform import bias_correction
25 | from init2winit.optimizer_lib.kitchen_sink._src.transform import BiasCorrectionState
26 | from init2winit.optimizer_lib.kitchen_sink._src.transform import clip_updates
27 | from init2winit.optimizer_lib.kitchen_sink._src.transform import first_moment_ema
28 | from init2winit.optimizer_lib.kitchen_sink._src.transform import nesterov
29 | from init2winit.optimizer_lib.kitchen_sink._src.transform import nesterovpp
30 | from init2winit.optimizer_lib.kitchen_sink._src.transform import polyak_averaging
31 | from init2winit.optimizer_lib.kitchen_sink._src.transform import Polyak_AveragingState
32 | from init2winit.optimizer_lib.kitchen_sink._src.transform import polyak_hb
33 | from init2winit.optimizer_lib.kitchen_sink._src.transform import precondition_by_amsgrad
34 | from init2winit.optimizer_lib.kitchen_sink._src.transform import precondition_by_layered_adaptive_rms
35 | from init2winit.optimizer_lib.kitchen_sink._src.transform import precondition_by_rms
36 | from init2winit.optimizer_lib.kitchen_sink._src.transform import precondition_by_rss
37 | from init2winit.optimizer_lib.kitchen_sink._src.transform import precondition_by_yogi
38 | from init2winit.optimizer_lib.kitchen_sink._src.transform import PreconditionByLayeredAdaptiveRMSState
39 | from init2winit.optimizer_lib.kitchen_sink._src.transform import PreconditionByRssState
40 | from init2winit.optimizer_lib.kitchen_sink._src.transform import PreconditionBySecondMomentCoordinateWiseState
41 | from init2winit.optimizer_lib.kitchen_sink._src.transform import sanitize_values
42 | from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_adam
43 | from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_adaprop
44 | from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_amsgrad
45 | from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_learning_rate
46 | from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_nadam
47 | from init2winit.optimizer_lib.kitchen_sink._src.transform import ScaleByAdamState
48 | from init2winit.optimizer_lib.kitchen_sink._src.transform import ScaleByAMSGradState
49 | from init2winit.optimizer_lib.kitchen_sink._src.utils import unfreeze_wrapper
50 |
51 |
52 | __version__ = '0.0.1'
53 |
54 | __all__ = (
55 | 'nadamw',
56 | 'nadampw',
57 | 'adamw_generic',
58 | 'adapropw',
59 | 'kitchen_sink',
60 | 'bias_correction',
61 | 'BiasCorrectionState',
62 | 'clip_updates',
63 | 'first_moment_ema',
64 | 'nesterov',
65 | 'nesterovpp',
66 | 'add_decayed_weights',
67 | 'polyak_averaging',
68 | 'Polyak_AveragingState',
69 | 'polyak_hb',
70 | 'precondition_by_amsgrad',
71 | 'precondition_by_layered_adaptive_rms',
72 | 'precondition_by_rms',
73 | 'precondition_by_rss',
74 | 'precondition_by_yogi',
75 | 'PreconditionByLayeredAdaptiveRMSState',
76 | 'PreconditionByRssState',
77 | 'PreconditionBySecondMomentCoordinateWiseState',
78 | 'sanitize_values',
79 | 'scale_by_adam',
80 | 'scale_by_amsgrad',
81 | 'scale_by_adaprop',
82 | 'scale_by_learning_rate',
83 | 'scale_by_nadam',
84 | 'ScaleByAdamState',
85 | 'ScaleByAMSGradState',
86 | 'unfreeze_wrapper',
87 | )
88 |
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/kitchen_sink/_src/combine.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Combine utilities."""
17 | import functools
18 | from typing import Any, NamedTuple
19 | from typing import Callable
20 | from typing import Optional
21 | from typing import Tuple
22 | from typing import Union
23 | import jax
24 | import jax.numpy as jnp
25 | import optax
26 |
27 | # TODO(dsuo): Add back grafting combinator.
28 |
29 |
30 | def join(by: Union[str, Callable[[optax.GradientTransformation, ...],
31 | optax.Updates]], *args,
32 | **kwargs) -> Callable[..., optax.GradientTransformation]:
33 | """Join multiple chains."""
34 |
35 | if by is None or by == 'chain':
36 | return lambda *args, **kwargs: optax.chain(*(args + tuple(kwargs.values())))
37 | if isinstance(by, str):
38 | if by not in combinator_registry:
39 | raise ValueError(f'Unrecognized `by` function {by}.')
40 | by_init, by_update = combinator_registry[by](*args, **kwargs)
41 |
42 | # TODO(dsuo): match docs/autocomplete with combinator args.
43 | def transform(*args, **kwargs):
44 |
45 | def init(params: optax.Params) -> optax.OptState:
46 | args_state = tuple(chain.init(params) for chain in args)
47 | kwargs_state = {
48 | name: chain.init(params) for name, chain in kwargs.items()
49 | }
50 | combinator_state = by_init(params, *args_state, **kwargs_state)
51 | return combinator_state, args_state, kwargs_state
52 |
53 | def update(
54 | updates: optax.Updates,
55 | state: optax.OptState,
56 | params: Optional[optax.Params] = None
57 | ) -> Tuple[optax.Updates, optax.OptState]:
58 | combinator_state, args_state, kwargs_state = state
59 |
60 | args_results = [
61 | chain.update(updates, state, params)
62 | for chain, state in zip(args, args_state)
63 | ]
64 | args_updates = tuple(result[0] for result in args_results)
65 | args_state = tuple(result[1] for result in args_results)
66 |
67 | kwargs_results = {
68 | name: chain.update(updates, kwargs_state[name], params)
69 | for name, chain in kwargs.items()
70 | }
71 | kwargs_updates = {
72 | name: result[0] for name, result in kwargs_results.items()
73 | }
74 | kwargs_state = {
75 | name: result[1] for name, result in kwargs_results.items()
76 | }
77 |
78 | updates, combinator_state = by_update(
79 | combinator_state, *args_updates, **kwargs_updates
80 | )
81 |
82 | return updates, (combinator_state, args_state, kwargs_state)
83 |
84 | return optax.GradientTransformation(init, update)
85 |
86 | return transform
87 |
88 |
89 | def _grafting_helper(chain, use_global_norm=False):
90 | norm = jax.tree.map(jnp.linalg.norm, chain)
91 | if use_global_norm:
92 | global_norm = jax.tree_util.tree_reduce(lambda x, y: jnp.sqrt(x**2 + y**2),
93 | norm)
94 | norm = jax.tree.map(lambda x: global_norm, norm)
95 | return norm
96 |
97 |
98 | class GraftingState(NamedTuple):
99 | """State for the Layered Adaptive RMS Preconditioner algorithm."""
100 | mag_norm: Any
101 | dir_norm: Any
102 |
103 |
104 | def combine_by_grafting(eps: float = 0.0, use_global_norm: bool = False):
105 | """Grafting combinator.
106 |
107 | Args:
108 | eps (float, optional): term added to the direction-norm denominator for
109 | numerical stability (default: 0.0)
110 | use_global_norm (bool, optional): graft global l2 norms rather than
111 | per-layer (default: False)
112 |
113 | Returns:
114 | init and update functions for the grafting combinator.
115 | """
116 |
117 | def init(params, *args, **kwargs):
118 | del args, kwargs
119 | mag_norm = jax.tree.map(lambda x: 0.0, params)
120 | dir_norm = jax.tree.map(lambda x: 0.0, params)
121 |
122 | return GraftingState(mag_norm=mag_norm, dir_norm=dir_norm)
123 |
124 | def update(state, mag_chain, dir_chain):
125 | del state
126 | mag_norm = _grafting_helper(mag_chain, use_global_norm=use_global_norm)
127 | dir_norm = _grafting_helper(dir_chain, use_global_norm=use_global_norm)
128 |
129 | updates = jax.tree.map(
130 | lambda dir, dirn, magn: dir / (dirn + eps) * magn,
131 | dir_chain,
132 | dir_norm,
133 | mag_norm,
134 | )
135 |
136 | return updates, GraftingState(mag_norm=mag_norm, dir_norm=dir_norm)
137 |
138 | return init, update
139 |
140 |
141 | def combine_by_sum():
142 | """Sum combinator.
143 |
144 | Returns:
145 | init and update functions for the sum combinator.
146 | """
147 |
148 | def init(params, *args, **kwargs):
149 | del args, kwargs, params
150 | return optax.EmptyState()
151 |
152 | def update(state, *args, **kwargs):
153 | args = args + tuple(kwargs.values())
154 | return functools.reduce(
155 | lambda x, y: jax.tree.map(lambda i, j: i + j, x, y), args), state
156 |
157 | return init, update
158 |
159 |
160 | combinator_registry = {
161 | 'grafting': combine_by_grafting,
162 | 'sum': combine_by_sum,
163 | }
164 |
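# Illustrative sketch (not part of the original file): grafting the per-layer
# magnitude of SGD updates onto the direction of Adam updates via the registry
# above. The keyword names must match the grafting combinator's update
# signature (mag_chain, dir_chain).
import jax.numpy as jnp
import optax

grafted = join('grafting', eps=1e-16)(
    mag_chain=optax.sgd(learning_rate=1.0),
    dir_chain=optax.scale_by_adam())
params = {'w': jnp.ones((2,))}
state = grafted.init(params)
updates, state = grafted.update({'w': jnp.ones((2,))}, state, params)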
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/kitchen_sink/_src/core.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Modularizing optimization ideas.
17 |
18 | This project seeks to take ideas in optimization (e.g., scale decay,
19 | momentum) and understand when and how they are effective, along with
20 | some insight into why.
21 | """
22 |
23 | from typing import Any, Dict, Optional
24 |
25 | from init2winit.optimizer_lib.kitchen_sink._src import utils
26 | from init2winit.optimizer_lib.kitchen_sink._src.combine import join
27 | from init2winit.optimizer_lib.kitchen_sink._src.mask import mask_registry
28 | from init2winit.optimizer_lib.kitchen_sink._src.transform import transformation_registry
29 | import ml_collections
30 | import optax
31 |
32 |
33 | # TODO(dsuo): document config syntax.
34 |
35 |
36 | def _get_mask(x):
37 | """Find a mask in a given element."""
38 | if 'mask' in x:
39 | mask = x['mask']
40 | elif 'mask' in x.get('hps', {}):
41 | mask = x['hps']['mask']
42 | else:
43 | mask = None
44 |
45 | if mask in mask_registry:
46 | mask = mask_registry[mask]
47 |
48 | return mask
49 |
50 |
51 | def _kitchen_sink_helper(config):
52 | """Recursively chain and join `optax.GradientTransformation`s."""
53 |
54 | if utils.is_leaf(config):
55 | if 'by' in config:
56 | raise KeyError(f'Leaf {config} should not have key `by`.')
57 |
58 | el = config['element']
59 | if el not in transformation_registry:
60 | raise ValueError(f'Transformation {el} not found.')
61 | hps = config.get('hps', {})
62 | tx = transformation_registry[el](**hps)
63 |
64 | else:
65 | if 'hps' in config:
66 | raise KeyError(f'Config {config} should not have key `hps`.')
67 |
68 | to_join = config.get('join', {})
69 | for key, val in to_join.items():
70 | to_join[key] = _kitchen_sink_helper(val)
71 |
72 | # Defaults to `None`, which chains child components together.
73 | by = config.get('by')
74 | by_kwargs = config.get('by_kwargs', {})
75 |
76 | tx = join(by, **by_kwargs)(**to_join)
77 |
78 | mask = _get_mask(config)
79 | if mask is not None:
80 | tx = optax.masked(tx, mask)
81 |
82 | return tx
83 |
84 |
85 | def kitchen_sink(config: Dict[str, Any],
86 | learning_rate: Optional[float] = None) -> optax.GradientTransformation:
87 | """Runs a list of GradientTransforms in parallel and combines.
88 |
89 | Args:
90 | config: dictionary configuring an optimizer.
91 | learning_rate: learning rate that gets injected.
92 |
93 | Returns:
94 | optax.GradientTransformation
95 | """
96 | # Cast to dict in case we have an ml_collections.ConfigDict.
97 |
98 | if isinstance(config, ml_collections.ConfigDict):
99 | config = config.to_dict()
100 | elif not isinstance(config, dict):
101 | raise ValueError(
102 | 'Kitchen Sink configuration needs to be a config dict or a python dict')
103 |
104 | # Syntactic sugar. If we have an implied chain, make it explicitly a chain.
105 | if all([str(i) in config for i in range(len(config))]):
106 | config = {'join': config}
107 |
108 | # Handle `one_minus_` hps, if any.
109 | config = utils.map_element(utils.handle_one_minus, config)
110 |
111 | # Apply learning rate to any existing scale_by_learning_rate
112 | if learning_rate is not None:
113 | config = utils.apply_and_maybe_scale_by_learning_rate(config, learning_rate)
114 |
115 | return utils.unfreeze_wrapper(*_kitchen_sink_helper(config))
116 |
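# Illustrative sketch (not part of the original file): an implied chain of two
# transformations. The element names and their hp names are assumed to be
# registered in transformation_registry; `one_minus_b1` is rewritten to
# `b1=0.9` by handle_one_minus before the chain is built.
config = {
    '0': {'element': 'scale_by_adam', 'hps': {'one_minus_b1': 0.1}},
    '1': {'element': 'scale_by_learning_rate', 'hps': {'learning_rate': 0.1}},
}
tx = kitchen_sink(config)  # an optax.GradientTransformation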
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/kitchen_sink/_src/mask.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Mask utilities."""
17 | import flax
18 |
19 |
20 | def create_mask(fn):
21 | """Creates a mask that maps fn over the leaves of a dict.
22 |
23 | Args:
24 | fn: function to apply, taking k (a tuple of node names on the path to the
25 | leaf) and v (the leaf value).
26 |
27 | Returns:
28 | mask: function that takes dict and returns mapped dict
29 | """
30 |
31 | def mask(data):
32 | flattened_dict = flax.traverse_util.flatten_dict(data)
33 | return flax.traverse_util.unflatten_dict(
34 | {k: fn(k, v) for k, v in flattened_dict.items()})
35 |
36 | return mask
37 |
38 |
39 | def create_weight_decay_mask():
40 | return create_mask(
41 | lambda p, _: 'bias' not in p and not p[-2].startswith('BatchNorm'))
42 |
43 | mask_registry = {
44 | 'bias_bn': create_weight_decay_mask(),
45 | }
46 |
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/kitchen_sink/_src/test_mask.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for utils."""
17 |
18 | from typing import Sequence
19 |
20 | from absl.testing import absltest
21 | import chex
22 | import flax
23 | import flax.linen as nn
24 | from init2winit.optimizer_lib.kitchen_sink._src.mask import create_mask
25 | from init2winit.optimizer_lib.kitchen_sink._src.mask import create_weight_decay_mask
26 | import jax
27 | import jax.numpy as jnp
28 | import optax
29 |
30 | # pylint:disable=duplicate-key
31 |
32 |
33 | class Foo(nn.Module):
34 | """Dummy model."""
35 |
36 | train: bool
37 | filters: int
38 |
39 | @nn.compact
40 | def __call__(self, x):
41 | x = nn.Conv(self.filters, (1, 1), use_bias=False, dtype=jnp.float32)(x)
42 | x = nn.BatchNorm(
43 | use_running_average=not self.train,
44 | momentum=0.9,
45 | epsilon=1e-5,
46 | dtype=jnp.float32)(
47 | x)
48 | return x
49 |
50 |
51 | class Bar(nn.Module):
52 | """Dummy model."""
53 |
54 | features: Sequence[int]
55 |
56 | @nn.compact
57 | def __call__(self, inputs):
58 | x = inputs
59 | for i, feat in enumerate(self.features):
60 | x = nn.Dense(feat, use_bias=True, name=f'layers_{i}')(x)
61 | if i != len(self.features) - 1:
62 | x = nn.relu(x)
63 | return x
64 |
65 |
66 | class CreateMaskTest(chex.TestCase):
67 | """Test masking."""
68 |
69 | def test_simple(self):
70 | """Check if the leaf key is `a`."""
71 | mask = create_mask(lambda path, _: path[-1] == 'a')
72 | data = {'a': 4, 'b': {'a': 5, 'c': 1}, 'c': {'a': {'b': 1}}}
73 |
74 | truth = {'a': True, 'b': {'a': True, 'c': False}, 'c': {'a': {'b': False}}}
75 |
76 | chex.assert_equal(mask(data), truth)
77 |
78 |
79 | class CreateWeightDecayMaskTest(chex.TestCase):
80 | """Test weight decay mask."""
81 |
82 | def test_simple(self):
83 | """Check that the correct tags are removed."""
84 | mask = create_weight_decay_mask()
85 | data = {
86 | 'bias': {
87 | 'b': 4
88 | },
89 | 'bias': {
90 | 'BatchNorm_0': 4,
91 | 'bias': 5,
92 | 'a': 0
93 | },
94 | 'BatchNorm_0': {
95 | 'b': 4
96 | },
97 | 'a': {
98 | 'b': {
99 | 'BatchNorm_0': 0,
100 | 'bias': 0
101 | },
102 | 'c': 0
103 | }
104 | }
105 | truth = {
106 | 'bias': {
107 | 'b': False
108 | },
109 | 'bias': {
110 | 'BatchNorm_0': False,
111 | 'bias': False,
112 | 'a': False
113 | },
114 | 'BatchNorm_0': {
115 | 'b': False
116 | },
117 | 'a': {
118 | 'b': {
119 | 'BatchNorm_0': True,
120 | 'bias': False
121 | },
122 | 'c': True
123 | }
124 | }
125 |
126 | chex.assert_equal(mask(data), truth)
127 |
128 | @chex.variants(with_jit=True, without_jit=True)
129 | def test_batch(self):
130 | """Test that batch layer is indeed ignored.
131 |
132 | Code taken from: https://github.com/google/flax/issues/932
133 | """
134 | key = jax.random.PRNGKey(0)
135 | x = jnp.ones((5, 4, 4, 3))
136 | y = jax.random.uniform(key, (5, 4, 4, 7))
137 |
138 | foo_vars = flax.core.unfreeze(Foo(filters=7, train=True).init(key, x))
139 | tx = optax.masked(optax.adam(1e-7), create_weight_decay_mask())
140 |
141 | @self.variant
142 | def train_step(params, x, y):
143 | y1, new_batch_stats = Foo(
144 | filters=7, train=True).apply(
145 | params, x, mutable=['batch_stats'])
146 |
147 | return jnp.abs(y - y1).sum(), new_batch_stats
148 |
149 | state = self.variant(tx.init)(foo_vars['params'])
150 | grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, x, y)
151 | updates, state = self.variant(tx.update)(dict(grads['params']), state)
152 |
153 | chex.assert_trees_all_close(updates['BatchNorm_0'],
154 | grads['params']['BatchNorm_0'])
155 |
156 | @chex.variants(with_jit=True, without_jit=True)
157 | def test_bias(self):
158 | """Test that biases are ignored."""
159 | key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2)
160 | x = jax.random.uniform(key1, (4, 4))
161 |
162 | model = Bar(features=[3, 4, 5])
163 | params = flax.core.unfreeze(model.init(key2, x))
164 | y = jax.random.uniform(key1, model.apply(params, x).shape)
165 |
166 | tx = optax.masked(optax.adam(1e-7), create_weight_decay_mask())
167 | state = tx.init(params)
168 |
169 | def loss(params, x, y):
170 | pred = model.apply(params, x)
171 | return jnp.abs(pred - y).sum()
172 |
173 | grads = jax.grad(loss)(params, x, y)
174 | updates, state = self.variant(tx.update)(dict(grads), state)
175 |
176 | for i in range(3):
177 | chex.assert_trees_all_close(grads['params'][f'layers_{i}']['bias'],
178 | updates['params'][f'layers_{i}']['bias'])
179 |
180 |
181 | if __name__ == '__main__':
182 | absltest.main()
183 |
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/kitchen_sink/_src/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Optimizer utilities."""
17 |
18 | import copy
19 | import operator
20 |
21 | from absl import logging
22 | import flax
23 | import jax
24 | import jax.numpy as jnp
25 | import optax
26 |
27 |
28 | def total_tree_sum(pytree):
29 | """Compute the overall sum of a pytree."""
30 | sums = jax.tree.map(jnp.sum, pytree)
31 | return jax.tree_util.tree_reduce(operator.add, sums, 0)
32 |
33 |
34 | def tree_norm_sql2(pytree):
35 | """Compute the param-wise squared L2 norm of a pytree."""
36 | return jax.tree.map(lambda x: jnp.linalg.norm(x.reshape(-1)) ** 2, pytree)
37 |
38 |
39 | def total_tree_norm_sql2(pytree):
40 | """Compute the overall squared L2 norm of a pytree."""
41 | sql2_norms = tree_norm_sql2(pytree)
42 | return jax.tree_util.tree_reduce(operator.add, sql2_norms, 0)
43 |
44 |
45 | def is_leaf(x):
46 | return isinstance(x, dict) and 'element' in x
47 |
48 |
49 | def map_element(fn, config, true_leaf_fn=None):
50 | if not isinstance(config, dict):
51 | if true_leaf_fn is not None:
52 | return true_leaf_fn(config)
53 | else:
54 | return config
55 | elif 'element' in config:
56 | return fn(config)
57 | else:
58 | return {k: map_element(fn, v, true_leaf_fn) for k, v in config.items()}
59 |
60 |
61 | def unfreeze_wrapper(init_fn, update_fn):
62 | """Freeze/unfreeze params."""
63 |
64 | # NOTE(dsuo): We use plain dicts internally due to this issue
65 | # https://github.com/deepmind/optax/issues/160.
66 | def wrapped_init_fn(params):
67 | return init_fn(flax.core.unfreeze(params))
68 |
69 | def wrapped_update_fn(updates, state, params=None):
70 | new_updates, state = update_fn(
71 | flax.core.unfreeze(updates), state,
72 | None if params is None else flax.core.unfreeze(params))
73 |
74 | if isinstance(updates, flax.core.FrozenDict):
75 | new_updates = flax.core.freeze(new_updates)
76 |
77 | return new_updates, state
78 |
79 | return optax.GradientTransformation(wrapped_init_fn, wrapped_update_fn)
80 |
81 |
82 | def handle_one_minus(x):
83 | if 'hps' in x:
84 | for hp in copy.deepcopy(x['hps']).keys():
85 | if 'one_minus_' in hp:
86 | x['hps'][hp.replace('one_minus_', '')] = 1 - x['hps'][hp]
87 | del x['hps'][hp]
88 | return x
89 |
90 |
91 | def apply_and_maybe_scale_by_learning_rate(config, learning_rate):
92 | """Apply learning rate and possibly scale by learning rate."""
93 |
94 | def is_scale_by_lr(x):
95 | return not isinstance(x, str) and x['element'] == 'scale_by_learning_rate'
96 |
97 | def contains_lr_as_param(x):
98 | return not isinstance(x, str) and x.get(
99 | 'hps', None) and 'learning_rate' in x['hps']
100 |
101 | def update_leaf(x):
102 | if contains_lr_as_param(x):
103 | x['hps']['learning_rate'] = learning_rate
104 | return x
105 | return x
106 |
107 | scaled = map_element(is_scale_by_lr, config, true_leaf_fn=lambda x: False)
108 | num_scaled = jax.tree_util.tree_reduce(lambda x, y: x + y, scaled, 0)
109 |
110 | if num_scaled == 0:
111 | return {
112 | 'join': {
113 | '0': config,
114 | '1': {
115 | 'element': 'scale_by_learning_rate',
116 | 'hps': {
117 | 'learning_rate': learning_rate
118 | }
119 | }
120 | }
121 | }
122 | elif num_scaled == 1:
123 | return map_element(update_leaf, config)
124 | else:
125 | logging.warning('Kitchen Sink configuration has more than one '
126 | 'scale_by_learning_rate. Please double check config')
127 | return config
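
# Illustrative sketch (not part of the original file): how `one_minus_` hps are
# rewritten, and how a learning rate is injected when the config contains no
# scale_by_learning_rate element.
leaf = {'element': 'scale_by_adam', 'hps': {'one_minus_b1': 0.1}}
leaf = handle_one_minus(leaf)
# -> {'element': 'scale_by_adam', 'hps': {'b1': 0.9}}

wrapped = apply_and_maybe_scale_by_learning_rate(leaf, 0.1)
# `wrapped` joins the original leaf with an added scale_by_learning_rate
# element carrying learning_rate=0.1.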
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/linalg/README.md:
--------------------------------------------------------------------------------
1 | Linear algebra package for non-diagonal preconditioning.
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/linalg/low_rank_root_update_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test for computing the sqrt and inverse sqrt of a matrix."""
17 |
18 | from typing import Tuple
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | from init2winit.optimizer_lib.linalg import low_rank_root_update
23 | import jax
24 | import numpy as np
25 | import scipy.stats
26 |
27 |
28 | def _small_perturbation(n: int, gamma: float,
29 | rng: np.random.RandomState) -> np.ndarray:
30 | """Returns a vector of absolute values ofnormally distributed values with standard deviation gamma."""
31 | s = gamma*np.abs(rng.normal(size=n))
32 | return s
33 |
34 |
35 | def _random_singular_values(n: int, gamma: float,
36 | rng: np.random.RandomState) -> np.ndarray:
37 | """Returns n random singular values in [γ, 1]."""
38 | s = gamma**rng.random((n,)) # log of singular values uniformly distributed
39 | if n > 0:
40 | s[0] = gamma
41 | if n > 1:
42 | s[1] = 1
43 | return s
44 |
45 |
46 | def _random_svd(
47 | n: int, gamma: float,
48 | rng: np.random.RandomState) -> Tuple[np.ndarray, np.ndarray]:
49 | """Returns a random SVD decomposition with singular values in [γ, 1]."""
50 | # sample a uniformly random orthogonal matrix.
51 | v = scipy.stats.ortho_group.rvs(n, random_state=rng)
52 | s = _random_singular_values(n, gamma, rng)
53 | return s, v
54 |
55 |
56 | @jax.jit
57 | def _update_sqrt(x, ix, g):
58 | ra_size = np.where(
59 | x.shape[-1] < 64, x.shape[-1], np.minimum(x.shape[-1] // 12, 64)
60 | )
61 | rank_array = np.zeros(ra_size)
62 | return low_rank_root_update.low_rank_root_update(
63 | x, ix, g, rank_array, 1e-6, 2
64 | )
65 |
66 |
67 | class InvSquareRootTest(parameterized.TestCase):
68 |
69 | @parameterized.named_parameters(
70 | { # pylint:disable=g-complex-comprehension
71 | 'testcase_name': f'n={n}',
72 | 'n': n, # pylint: disable=undefined-variable
73 | 'p': p, # pylint: disable=undefined-variable
74 | }
75 | for n in [2, 31]
76 | for p in [2]
77 | )
78 | def test_random_matrix(self, n, p):
79 | rng = np.random.RandomState(seed=42)
80 |
81 | sigma = 1e-2 # smallest singular value of test matrix
82 | s, v = _random_svd(n, sigma, rng)
83 | s = s.astype(np.float64)
84 | v = v.astype(np.float64)
85 | q = _small_perturbation(n, 1e-4, rng)
86 | q = np.diag(q.astype(np.float64))
87 | a_sqrt = (v * s**(1 / p)) @ v.T
88 | a_isqrt = (v * s**(-1 / p)) @ v.T
89 | exact = (v * (s + q**2)**(-1 / p)) @ v.T
90 | ans = _update_sqrt(a_sqrt, a_isqrt, q)[1]
91 | ans = np.array(ans).astype(np.float64)
92 | error = np.linalg.norm(ans - exact, 2) / np.linalg.norm(exact, 2)
93 | kappa = 1 / p / sigma
94 | expected_error = 3 * kappa * np.finfo(np.float32).eps
95 | self.assertLessEqual(error, expected_error)
96 |
97 |
98 | if __name__ == '__main__':
99 | jax.config.update('jax_enable_x64', False)
100 | absltest.main()
101 |
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/linalg/paterson_stockmeyer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Paterson-Stockmeyer method for polynomial evaluation."""
17 | from typing import Any, Callable, List, Sequence, TypeVar
18 |
19 | import numpy as np
20 |
21 | T = TypeVar('T')
22 |
23 |
24 | def _powers(x: T, n: int, product: Callable[[T, T], T]) -> List[T]:
25 | """Returns the list [x, x², ..., xⁿ]."""
26 | xp = [None] * (n + 1)
27 | xp[1] = x
28 | for j in range(2, n + 1):
29 | # To reduce round-off, compute xʲ as the result of O(log j) multiplies
30 | xp[j] = product(xp[j // 2], xp[(j + 1) // 2])
31 | return xp[1:]
32 |
33 |
34 | def polynomial_no_constant(a: Sequence[Any], x: T, product: Callable[[T, T],
35 | T]) -> T:
36 | """Paterson-Stockmeyer evaluation of a[0] x + a[1] x² + ... + a[n-1] xⁿ.
37 |
38 | A variant of the Paterson-Stockmeyer method for polynomial evaluation
39 | presented in [2], which avoids using the multiplicative identity (x⁰). The
40 | algorithm uses only ⌈2√n⌉ - 2 multiplications instead of n - 1, making it
41 | especially suitable when multiplications are expensive, e.g., for matrices.
42 | The reduced number of multiplications is accomplished by grouping the terms as
43 |
44 | (a[0] x + a[1] x² + ... + a[ s-1] xˢ) +
45 | xˢ (a[s] x + a[s+1] x² + ... + a[2s-1] xˢ) +
46 | (xˢ)² (a[2s] x + a[2s+1] x² + ... + a[3s-1] xˢ) +
47 | ...
48 |
49 | with s = ⌈√n⌉. The powers up to xˢ are precomputed with s - 1 multiplications,
50 | allowing all the (at most) degree s polynomials in parentheses above to be
51 | evaluated. These are then combined using Horner's rule with ⌈n/s⌉ - 1
52 | subsequent multiplications.
53 |
54 | [1] Michael S. Paterson and Larry J. Stockmeyer, "On the number of nonscalar
55 | multiplications necessary to evaluate polynomials," SIAM J. Comput., 2
56 | (1973), pp. 60–66.
57 |
58 | [2] M. Fasi, "Optimality of the Paterson-Stockmeyer method for evaluating
59 | matrix polynomials and rational matrix functions," Linear Algebra Appl.,
60 | 574 (2019), pp. 182–200.
61 |
62 | Args:
63 | a: Polynomial coefficients. a[j] is the coefficient of xʲ⁺¹.
64 | x: Argument to evaluate the polynomial at.
65 | product: Multiplication function.
66 |
67 | Returns:
68 | The polynomial a[0] x + a[1] x² + ... + a[n-1] xⁿ .
69 |
70 | Raises:
71 | ValueError if `a` is empty.
72 | """
73 | n = len(a)
74 | if n == 0:
75 | raise ValueError('polynomial_no_constant: coefficients empty.')
76 | s = int(np.ceil(np.sqrt(n)))
77 | xp = _powers(x, s, product)
78 | inner = lambda alpha: sum([cj * xj for (cj, xj) in zip(alpha, xp)])
79 | inner_poly = lambda i: inner(a[s * i:min(n, s * (i + 1))])
80 | i = (n + s - 1) // s - 1
81 | y = inner_poly(i)
82 | for i in reversed(range(i)):
83 | y = inner_poly(i) + product(xp[s - 1], y)
84 | return y
85 |
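# Illustrative sketch (not part of the original file): evaluating the matrix
# polynomial 2X + 3X² + X³ with only two matrix multiplications.
import numpy as np

x = np.diag([2.0, 3.0])
y = polynomial_no_constant([2.0, 3.0, 1.0], x, np.matmul)
# Diagonal entries follow 2d + 3d² + d³: 24.0 for d=2 and 60.0 for d=3.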
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/linalg/root_selector.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Depending on a variable, selects whether to use exact or approximate root function."""
17 |
18 | import functools
19 | from typing import Tuple
20 |
21 | import chex
22 | from init2winit.optimizer_lib.linalg import low_rank_root_update
23 | from init2winit.optimizer_lib.linalg import pth_inv_root_rmn
24 | from jax import lax
25 | import numpy as np
26 |
27 |
28 | def root_selector(
29 | x: chex.Array,
30 | sx: chex.Array,
31 | isx: chex.Array,
32 | up: chex.Array,
33 | p: int,
34 | eps: float,
35 | exact_root: bool,
36 | rank_estimate: int,
37 | block_krylov_dim_multiplier: int = 2,
38 | stable_iter: bool = False,
39 | unroll: bool = False,
40 | verbose: bool = True,
41 | ) -> Tuple[chex.Array, chex.Array]:
42 | """Returns |X|^{1/p} and |X|⁻¹ᐟᵖ.
43 |
44 | Args:
45 | x: Input matrix must be SPD with eigenvalues >= float32 epsilon.
46 | sx: Old sqrt of the matrix X.
47 | isx: Old inverse sqrt of the matrix X.
48 | up: Update to the matrix of the form update @ update.T
49 | p: Exponent.
50 | eps: small constant to avoid numerical issues with Lyapunov solver.
51 | exact_root: If True, compute the root exactly; otherwise use the low-rank update approximation.
52 | rank_estimate: Rank estimate of the update.
53 | block_krylov_dim_multiplier: Multiplier for the block krylov dimension over
54 | the rank estimate.
55 | stable_iter: Whether to use the stable iteration for the inner loop.
56 | unroll: Whether to unroll the loop over iterations.
57 | verbose: Whether to log some information about the iteration, including the
58 | coefficients `a` for each iteration.
59 |
60 | Returns:
61 | Approximations of |X|¹ᐟᵖ and |X|⁻¹ᐟᵖ.
62 | """
63 | f_er = functools.partial(
64 | pth_inv_root_rmn.pth_inv_root_rmn,
65 | fast_root=True,
66 | precision="float32",
67 | stable_iter=stable_iter,
68 | unroll=unroll,
69 | verbose=verbose,
70 | )
71 |
72 | def _exact_root():
73 | return f_er(x, p)
74 |
75 | rank_array = np.zeros(np.where(
76 | x.shape[-1] < 64,
77 | x.shape[-1],
78 | np.minimum(x.shape[-1] // 8, rank_estimate),
79 | ))
80 | f_ar = functools.partial(
81 | low_rank_root_update.low_rank_root_update,
82 | rank_array=rank_array,
83 | eps=eps,
84 | block_krylov_dim_multiplier=block_krylov_dim_multiplier,
85 | verbose=verbose,
86 | )
87 | def _approx_root():
88 | return f_ar(sx, isx, up)
89 |
90 | return lax.cond(exact_root, _exact_root, _approx_root)
91 |
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/online_newton_step.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Optimizers for online newton step algorithms."""
17 |
18 | from typing import NamedTuple
19 |
20 | from init2winit.optimizer_lib import kitchen_sink
21 | from init2winit.optimizer_lib import utils
22 | import jax
23 | import jax.numpy as jnp
24 | import optax
25 |
26 |
27 | def diag_ons(learning_rate,
28 | weight_decay: float = 0.0,
29 | b1: float = 0.9,
30 | b2: float = 0.999,
31 | eps: float = 1e-8):
32 | """The diagonal version of Online Newton Step with flexible updates.
33 |
34 | Args:
35 | learning_rate: A fixed global scaling factor.
36 | weight_decay: weight decay.
37 | b1: Exponential decay rate to track the first moment of past gradients.
38 | b2: Exponential decay rate to track the second moment of past gradients.
39 | eps: A small constant applied to denominator outside of the square root (as
40 | in the Adam paper) to avoid dividing by zero when rescaling.
41 |
42 | Returns:
43 | The corresponding `GradientTransformation`.
44 | """
45 | if b1 == 1.0 and b2 == 1.0:
46 | # Diag ONS without momentum and second moment decay
47 | return optax.chain(
48 | kitchen_sink.precondition_by_rss(eps=eps, power=1.0),
49 | optax.add_decayed_weights(weight_decay), optax.scale(learning_rate))
50 | elif b1 == 1.0 and b2 != 1.0:
51 | # Diag ONS without momentum but with second moment decay
52 | return optax.chain(
53 | kitchen_sink.precondition_by_rms(
54 | decay=b2, eps=eps, eps_root=0.0, power=1.0),
55 | optax.add_decayed_weights(weight_decay), optax.scale(learning_rate))
56 | elif b1 != 1.0 and b2 != 1.0:
57 | # Diag ONS with momentum and second moment decay
58 | return optax.chain(
59 | kitchen_sink.scale_by_adam(b1, b2, eps, eps_root=0.0, power=1.0),
60 | optax.add_decayed_weights(weight_decay), optax.scale(learning_rate))
61 |
62 |
63 | def last_layer_transformation(last_layer_optimizer, base_lr,
64 | last_layer_base_lr, learning_rate):
65 | """Use an optimizer while scaling by a different learning rate."""
66 |
67 | return optax.chain(last_layer_optimizer,
68 | optax.scale(learning_rate * last_layer_base_lr / base_lr))
69 |
70 |
71 | def sherman_morrison(a_inv, u, alpha):
72 | """Given A^-1, compute (A + alpha * u u^T)^-1 using Sherman-Morrison."""
73 | denom = 1 + alpha * u.T @ a_inv @ u
74 | numer = alpha * jnp.outer(a_inv @ u, u) @ a_inv
75 |
76 | return a_inv - numer / denom
77 |
78 |
79 | class OnlineNewtonState(NamedTuple):
80 | """State holding the sum of gradient squares to date."""
81 | inv_hessian: optax.Updates
82 |
83 |
84 | def full_matrix_ons(alpha, initial_accumulator_value=0.1):
85 | """A full Online Newton Step transformation."""
86 |
87 | def init_fn(params):
88 | raveled_params, _ = jax.flatten_util.ravel_pytree(params)
89 | initial_hessian = jnp.diag(
90 | jnp.full_like(raveled_params, 1. / initial_accumulator_value))
91 |
92 | return OnlineNewtonState(inv_hessian=initial_hessian)
93 |
94 | def update_fn(updates, state, params=None):
95 | del params
96 |
97 | raveled_updates, unravel = jax.flatten_util.ravel_pytree(updates)
98 | new_hessian = sherman_morrison(state.inv_hessian, raveled_updates, alpha)
99 | new_updates = unravel(new_hessian @ raveled_updates)
100 |
101 | return new_updates, OnlineNewtonState(inv_hessian=new_hessian)
102 |
103 | return optax.GradientTransformation(init_fn, update_fn)
104 |
105 |
106 | def online_newton_step(learning_rate, alpha, weight_decay):
107 | r"""An optimizer that does full matrix preconditioning."""
108 |
109 | return optax.chain(
110 | optax.add_decayed_weights(weight_decay), full_matrix_ons(alpha),
111 | optax.sgd(learning_rate))
112 |
113 |
114 | def multiple_optimizer(last_layer_name, network_optimizer, last_layer_optimizer,
115 | last_layer_base_lr, base_lr):
116 | """Use a different optimizer for the last layer."""
117 |
118 | def get_select_fn(layer_name):
119 | """Get a function that selects the specified layer as last layer."""
120 |
121 | def select_layer(tree):
122 | return {k: ('ll' if k == layer_name else 'net') for k, v in tree.items()}
123 |
124 | return select_layer
125 |
126 | return kitchen_sink.unfreeze_wrapper(*optax.multi_transform(
127 | {
128 | 'net':
129 | network_optimizer,
130 | # Scale the learning rate of the last layer to match
131 | # last_layer_base_lr
132 | 'll':
133 | utils.static_inject_hyperparams(last_layer_transformation)
134 | (last_layer_optimizer,
135 | base_lr,
136 | last_layer_base_lr,
137 | learning_rate=0.0),
138 | },
139 | get_select_fn(last_layer_name)))
140 |
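# Illustrative sketch (not part of the original file): a quick numerical check
# of the Sherman-Morrison identity used by full_matrix_ons.
import jax.numpy as jnp
import numpy as np

a = 2.0 * jnp.eye(3)
a_inv = jnp.linalg.inv(a)
u = jnp.array([1.0, 2.0, 3.0])
alpha = 0.5
lhs = sherman_morrison(a_inv, u, alpha)
rhs = jnp.linalg.inv(a + alpha * jnp.outer(u, u))
np.testing.assert_allclose(np.asarray(lhs), np.asarray(rhs), atol=1e-5)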
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/samuel.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of the SAMUEL optimizer.
17 |
18 | Paper: https://arxiv.org/pdf/2203.01400.pdf
19 | """
20 |
21 | import copy
22 | from typing import Any
23 | from typing import Dict
24 | from typing import List
25 | from typing import NamedTuple
26 |
27 | from init2winit.optimizer_lib.utils import static_inject_hyperparams
28 | import jax
29 | import jax.numpy as jnp
30 | import optax
31 |
32 |
33 | class SamuelState(NamedTuple):
34 | inner_state: NamedTuple
35 | expert_weights: jnp.ndarray
36 | key: jnp.ndarray
37 | current_expert: int = 0
38 | step: int = 0
39 |
40 |
41 | def samuel(
42 | optimizers: List[str],
43 | hps: List[Dict[str, Any]],
44 | mw_etas: jnp.ndarray,
45 | seed: int = 0,
46 | train_loss: float = 0.0,
47 | learning_rate: float = 0.0,
48 | ):
49 | """Samuel optimizer.
50 |
51 | NOTES
52 | - This implementation assumes each host is an expert (i.e., holds a copy of
53 | the model). As a consequence, we must modify the input and training
54 | pipelines to forgo data parallelism across hosts and limit to only the
55 | local devices available to a given host.
56 | - We synchronize after each batch. This is not always necessary and can be
57 | a point of future performance optimization.
58 |
59 | TODO(dsuo): add LR schedules for optimizers.
60 |
61 | Args:
62 | optimizers: list of strings indicating optax optimizers.
63 | hps: list of hps for each optimizer.
64 | mw_etas: list of multiplicative weight etas.
65 | seed: initial jax random seed.
66 | train_loss: train loss to be injected at update time.
67 | learning_rate: for compatibility, but ignored for now.
68 |
69 | Returns:
70 | samuel optimizer
71 | """
72 | del learning_rate
73 |
74 | num_experts = len(optimizers)
75 | mw_etas = jnp.array(mw_etas)
76 |
77 | if num_experts != jax.process_count():
78 | raise ValueError(
79 | 'This implementation of SAMUEL requires the number of optimizers to be '
80 | 'equal to the number of hosts (one host per expert).')
81 |
82 | optimizer = optimizers[jax.process_index()]
83 | hps = hps[jax.process_index()]
84 |
85 | optimizer = getattr(optax, optimizer)(**hps)
86 |
87 | def init_fn(params):
88 | return SamuelState(
89 | inner_state=optimizer.init(params),
90 | expert_weights=jnp.repeat(mw_etas, num_experts, axis=1),
91 | # NOTE(dsuo): each init gives the same key, but given the changing
92 | # params for each model, this is not an issue.
93 | key=jax.random.PRNGKey(seed),
94 | )
95 |
96 | def update_fn(updates, state, params):
97 | del params
98 |
99 | key, subkey = jax.random.split(state.key)
100 |
101 | # Compute updates based on inner optimizer
102 | updates, inner_state = optimizer.update(updates, state.inner_state)
103 |
104 | prob = state.expert_weights.sum(axis=0) / state.expert_weights.sum()
105 |
106 | # NOTE(dsuo): we rely on jax determinism for each host to behave the same.
107 | current_expert = jax.random.choice(subkey, jnp.arange(prob.size), p=prob)
108 |
109 | # Synchronize train_losses across hosts.
110 | # NOTE(dsuo): since we are already inside a pmap, we can't use
111 | # jax.experimental.multihost_utils.
112 | # NOTE(dsuo): train_losses is of shape (jax.process_count(),).
113 | train_losses = jax.lax.all_gather(train_loss, 'batch').reshape(
114 | jax.process_count(), jax.local_device_count())[:, 0]
115 |
116 | # Compute loss regret and update expert weights.
117 | loss_regret = train_losses.at[current_expert].get() - train_losses
118 | expert_weights = state.expert_weights * jnp.exp(mw_etas * loss_regret)
119 |
120 | state = SamuelState(
121 | inner_state=inner_state,
122 | expert_weights=expert_weights,
123 | key=key,
124 | current_expert=current_expert,
125 | step=state.step + 1,
126 | )
127 | return updates, state
128 |
129 | return optax.GradientTransformation(init_fn, update_fn)
130 |
131 |
132 | def from_hparams(opt_hparams):
133 | """Create SAMUEL optimizer from init2winit."""
134 | opt_hparams_optimizers = opt_hparams['optimizers']
135 |
136 | optimizers = []
137 | hps = []
138 | index = 0
139 | while str(index) in opt_hparams_optimizers:
140 | hparams = opt_hparams_optimizers[str(index)]
141 | hp = hparams.get('hps', {})
142 |
143 | for h in copy.deepcopy(hp).keys():
144 | if 'one_minus_' in h:
145 | hp[h.replace('one_minus_', '')] = 1 - hp[h]
146 | del hp[h]
147 |
148 | optimizers.append(hparams.get('optimizer'))
149 | hps.append(hp)
150 | index += 1
151 |
152 | return static_inject_hyperparams(samuel)(
153 | optimizers=optimizers, hps=hps, **opt_hparams['args'])
154 |
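# Illustrative sketch (not part of the original file): the hparams structure
# from_hparams expects; the exact 'args' contents (mw_etas, seed) are
# assumptions. Actually building the optimizer requires one host per expert,
# so the call is left commented out.
opt_hparams = {
    'optimizers': {
        '0': {'optimizer': 'sgd', 'hps': {'learning_rate': 0.1}},
        '1': {'optimizer': 'adam', 'hps': {'learning_rate': 1e-3,
                                           'one_minus_b1': 0.1}},
    },
    'args': {'mw_etas': [[0.1], [1.0]], 'seed': 0},
}
# samuel_tx = from_hparams(opt_hparams)  # needs jax.process_count() == 2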
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/sharpness_aware_minimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of Sharpness Aware Minimization (SAM).
17 |
18 | This implementation is still being actively evaluated by the MLCommons Training
19 | Algorithms Benchmark, so it should not be used (yet).
20 |
21 | Paper: https://arxiv.org/abs/2010.01412
22 | Code: https://github.com/google-research/sam
23 | """
24 |
25 | from typing import Optional
26 |
27 | from init2winit.model_lib import model_utils
28 |
29 | import jax
30 | import jax.numpy as jnp
31 | import optax
32 |
33 | _GRAD_CLIP_EPS = 1e-6
34 |
35 |
36 | # Copied from the official SAM GitHub repository. Note how it doesn't add an
37 | # epsilon to the gradient norm before normalizing the gradients.
38 | def dual_vector(y: jnp.ndarray) -> jnp.ndarray:
39 | """Returns the solution of max_x y^T x s.t.
40 |
41 | ||x||_2 <= 1.
42 |
43 | Args:
44 | y: A pytree of numpy ndarray, vector y in the equation above.
45 | """
46 | gradient_norm = jnp.sqrt(
47 | sum([jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)]))
48 | normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y)
49 | return normalized_gradient
50 |
51 |
52 | def sharpness_aware_minimization(
53 | rho: float,
54 | grad_clip: Optional[float],
55 | base_opt_init_fn,
56 | base_opt_update_fn,
57 | ) -> optax.GradientTransformation:
58 | """Implementation of Sharpness Aware Minimization (SAM).
59 |
60 | Paper: https://arxiv.org/abs/2010.01412
61 | Code: https://github.com/google-research/sam
62 |
63 | References:
64 | Foret et al, 2021: https://arxiv.org/abs/2010.01412
65 | Args:
66 | rho: The size of the neighborhood for the sharpness aware minimization
67 | gradient updates. Defaults to 0.1.
68 | grad_clip: The optional value to clip the updates by. Defaults to None.
69 | base_opt_init_fn: The initialization function for the base optimizer used to
70 | generate updates given the total gradient.
71 | base_opt_update_fn: The update function for the base optimizer used to
72 | generate updates given the total gradient.
73 |
74 | Returns:
75 | The corresponding `GradientTransformation`.
76 | """
77 |
78 | def init_fn(params):
79 | return base_opt_init_fn(params)
80 |
81 | def update_fn(updates, state, grad_fn_params_tuple):
82 | (grad_fn, params) = grad_fn_params_tuple
83 |
 84 |     # Updates here have been averaged across devices in Trainer before being
 85 |     # sent to the optimizer. We obtain gradients computed on the noised
 86 |     # parameters following the same order of operations that Trainer applies
 87 |     # to the original gradients, and clip with the same 1e-6 epsilon that is
 88 |     # used when clipping the gradients.
89 | updates = dual_vector(updates)
90 | noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, params,
91 | updates)
92 | _, updates = grad_fn(noised_params)
93 |
94 | updates_norm = jnp.sqrt(model_utils.l2_regularization(updates, 0))
95 | if grad_clip:
96 | scaled_updates = jax.tree.map(
97 | lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates)
98 | updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates,
99 | lambda _: updates, None)
100 | updates, state = base_opt_update_fn(updates, state, params)
101 |
102 | return updates, state
103 |
104 | return optax.GradientTransformation(init_fn, update_fn)
105 |
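# A minimal sketch (assumed wiring, not verified against the Trainer) showing
# how sharpness_aware_minimization composes with a base optax optimizer on a
# toy quadratic loss: the averaged gradients are passed as `updates`, and the
# (grad_fn, params) tuple is threaded through the optax params slot.
base = optax.sgd(learning_rate=0.1)
sam_opt = sharpness_aware_minimization(
    rho=0.1,
    grad_clip=None,
    base_opt_init_fn=base.init,
    base_opt_update_fn=base.update,
)
toy_params = {'w': jnp.ones(3)}
toy_loss = lambda p: jnp.sum(p['w'] ** 2)
grad_fn = jax.value_and_grad(toy_loss)  # Returns (loss, grads), as update_fn expects.
opt_state = sam_opt.init(toy_params)
_, grads = grad_fn(toy_params)          # Unperturbed gradients (pre-averaged in Trainer).
updates, opt_state = sam_opt.update(grads, opt_state, (grad_fn, toy_params))
toy_params = optax.apply_updates(toy_params, updates)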
--------------------------------------------------------------------------------
/init2winit/optimizer_lib/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for utils."""
17 |
18 | from absl.testing import absltest
19 | import chex
20 | from init2winit.optimizer_lib import optimizers
21 | from init2winit.optimizer_lib import utils
22 | import jax
23 | import jax.numpy as jnp
24 | from ml_collections.config_dict import ConfigDict
25 | import optax
26 |
27 |
28 | # pylint:disable=duplicate-key
29 |
30 |
31 | class ExtractFieldTest(chex.TestCase):
32 | """Test the extract_field() function."""
33 |
34 | def test_adam(self):
35 | init_fn, update_fn = optimizers.get_optimizer(
36 | ConfigDict({
37 | 'optimizer': 'adam',
38 | 'l2_decay_factor': None,
39 | 'batch_size': 50,
40 | 'total_accumulated_batch_size': 100, # Use gradient accumulation.
41 | 'opt_hparams': {
42 | 'beta1': 0.9,
43 | 'beta2': 0.999,
44 | 'epsilon': 1e-7,
45 | 'weight_decay': 0.0,
46 | },
47 | })
48 | )
49 | del update_fn
50 | optimizer_state = init_fn({'foo': jnp.ones(10)})
51 | # Test that we can extract 'count'.
52 | chex.assert_type(utils.extract_field(optimizer_state, 'count'), int)
53 | # Test that we can extract 'nu'.
54 | chex.assert_shape(utils.extract_field(optimizer_state, 'nu')['foo'], (10,))
55 | # Test that we can extract 'mu'.
56 | chex.assert_shape(utils.extract_field(optimizer_state, 'mu')['foo'], (10,))
57 |     # Test that attempting to extract a nonexistent field "abc" returns None.
58 | chex.assert_equal(utils.extract_field(optimizer_state, 'abc'), None)
59 |
60 |
61 | class GradientAggregationDecoratorTest(chex.TestCase):
62 | """Test the requires_gradient_aggregation() decorator."""
63 |
64 | def test_no_aggregation(self):
65 | """Tests behavior with the decorator."""
66 |
67 | @utils.no_cross_device_gradient_aggregation
68 | def dummy_update_fn(updates, state, params):
69 | del updates, state, params
70 |
71 | self.assertFalse(utils.requires_gradient_aggregation(dummy_update_fn))
72 |
73 | def test_with_aggregation(self):
74 | """Tests the default behavior."""
75 |
76 | def dummy_update_fn(updates, state, params):
77 | del updates, state, params
78 |
79 | self.assertTrue(utils.requires_gradient_aggregation(dummy_update_fn))
80 |
81 |
82 | class OverwriteHparamNamesTest(chex.TestCase):
83 | """Test the overwrite_hparam_names() function."""
84 |
85 | def test_overwrite_hparams_names(self):
86 | init_params = jnp.array([1.0, 2.0, 3.0])
87 |
88 | def fun(x):
89 | return 0.5 * jnp.sum(x**2)
90 |
 91 |     # With the learning rate left at 0.0, the optimizer stays stuck at the
 92 |     # current params.
93 | opt = optax.inject_hyperparams(optax.sgd)(learning_rate=0.0)
94 |
95 | state = opt.init(init_params)
96 |
97 | @jax.jit
98 | def step(params, state):
99 | grad = jax.grad(fun)(params)
100 | updates, state = opt.update(grad, state)
101 | params = optax.apply_updates(params, updates)
102 | return params, state
103 |
104 | params = init_params
105 | for _ in range(5):
106 | params, state = step(params, state)
107 |
108 | norm_diff = jnp.linalg.norm(init_params - params)
109 | self.assertEqual(norm_diff, 0.0)
110 |
111 |     # If we set the learning rate via the renamed 'lr' hyperparameter, we descend.
112 | opt = optax.inject_hyperparams(optax.sgd)(learning_rate=0.0)
113 | opt = utils.overwrite_hparam_names(opt, learning_rate='lr')
114 |
115 | state = opt.init(init_params)
116 | state = optax.tree_utils.tree_set(state, lr=0.5)
117 |
118 | params = init_params
119 | for i in range(5):
120 | state = optax.tree_utils.tree_set(state, lr=1 / (i + 2))
121 | params, state = step(params, state)
122 | lr = optax.tree_utils.tree_get(state, 'lr')
123 | self.assertEqual(lr, 1 / (i + 2))
124 |
125 | self.assertLessEqual(fun(params), fun(init_params))
126 |
127 |
128 | class AppendHparamName(chex.TestCase):
129 | """Test the append_hparam_name() function."""
130 |
131 | def test_append_hparam_name(self):
132 | init_params = jnp.array([1.0, 2.0, 3.0])
133 |
134 | def fun(x):
135 | return 0.5 * jnp.sum(x**2)
136 |
137 | opt = optax.inject_hyperparams(optax.sgd)(learning_rate=0.5)
138 | new_opt = utils.append_hparam_name(opt, 'foo')
139 |
140 | # Test that we can access and set the new hparam
141 | state = new_opt.init(init_params)
142 | state = optax.tree_utils.tree_set(state, foo=2.)
143 | foo = optax.tree_utils.tree_get(state, 'foo')
144 | self.assertEqual(foo, 2.0)
145 |
146 | # Test that the optimizer runs without any issue
147 | @jax.jit
148 | def step(params, state):
149 | grad = jax.grad(fun)(params)
150 | updates, state = new_opt.update(grad, state)
151 | params = optax.apply_updates(params, updates)
152 | return params, state
153 |
154 | params = init_params
155 | for _ in range(5):
156 | params, state = step(params, state)
157 |
158 | self.assertLessEqual(fun(params), fun(init_params))
159 |
160 |
161 | if __name__ == '__main__':
162 | absltest.main()
163 |
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/README.md:
--------------------------------------------------------------------------------
1 | # README
2 |
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Utility functions for loading parquet files in notebooks."""
17 |
18 | import io
19 | from typing import List, Optional
20 |
21 | from etils import epath
22 | import utils as i2wutils # local file import
23 | import pandas as pd
24 |
25 |
26 | def load_parquet_file(
27 | path: str,
28 | file_name: Optional[str] = None,
29 | *,
30 | sort_by: str = 'score',
31 | ascending: bool = True,
32 | include_provenance: bool = False,
33 | ) -> pd.DataFrame:
34 | """Load a single parquet file and return it as a sorted DataFrame.
35 |
36 | Args:
37 | path: Directory path string
38 | file_name (optional): File name string (default: 'results.parquet')
39 | sort_by: Column to sort by (default: 'score')
40 | ascending: Sort order (default: True)
41 | include_provenance: Whether to include the provenance of the data in the
42 | DataFrame (default: False). If set, adds a column 'provenance' to the
43 | DataFrame with the path of the file.
44 |
45 | Returns:
46 | pandas DataFrame
47 | """
48 |
49 | if file_name:
50 | path = epath.Path(path) / file_name
51 | else:
52 | path = epath.Path(path)
53 |
54 | # Read the file
55 | with path.open('rb') as in_f:
56 | buf = io.BytesIO(in_f.read())
57 | df = pd.read_parquet(buf)
58 |
59 | # Sort if the column exists
60 | if sort_by in df.columns:
61 | df.sort_values(by=sort_by, ascending=ascending, inplace=True)
62 |
63 | if include_provenance:
64 | df['provenance'] = str(path)
65 | return df
66 |
67 |
68 | def load_all_parquet_files(
69 | paths: List[str],
70 | file_name: Optional[str] = None,
71 | *,
72 | sort_by: str = 'score',
73 | ascending: bool = True,
74 | include_provenance: bool = False,
75 | num_workers: int = 50,
76 | ) -> pd.DataFrame:
77 | """Load and merge all parquet files from different paths.
78 |
79 | Args:
80 | paths: List of directory paths.
81 | file_name (optional): File name string (default: 'results.parquet')
82 | sort_by: Column to sort by (default: 'score').
83 | ascending: Sort order (default: True).
84 | include_provenance: Whether to include the provenance of the data in the
85 | DataFrame (default: False). If set, adds a column 'provenance' to the
86 | DataFrame with the path of the file each row came from.
87 | num_workers: Number of workers to use for parallel loading (default: 50).
88 |
89 | Returns:
90 | Merged pandas DataFrame.
91 | """
92 | shared_kwargs = {
93 | 'file_name': file_name,
94 | 'sort_by': sort_by,
95 | 'ascending': ascending,
96 | 'include_provenance': include_provenance,
97 | }
98 | kwargs_list = []
99 | for path in paths:
100 | kwargs_list.append({
101 | 'path': path,
102 | **shared_kwargs,
103 | })
104 | dfs = i2wutils.run_in_parallel(load_parquet_file, kwargs_list, num_workers)
105 | if dfs:
106 | # Concat will ignore empty DataFrames properly.
107 | merged_df = pd.concat(dfs, ignore_index=True)
108 | return merged_df
109 | else:
110 | return pd.DataFrame()
111 |
112 |
113 | # TODO(gdahl): delete this function once we are happy with the parallel version.
114 | def load_all_parquet_files_sequentially(
115 | paths: List[str],
116 | file_name: Optional[str] = None,
117 | *,
118 | sort_by: str = 'score',
119 | ascending: bool = True,
120 | include_provenance: bool = False,
121 | ) -> pd.DataFrame:
122 | """Load and merge all parquet files from different paths.
123 |
124 | Args:
125 | paths: List of directory paths.
126 | file_name (optional): File name string (default: 'results.parquet')
127 | sort_by: Column to sort by (default: 'score').
128 | ascending: Sort order (default: True).
129 | include_provenance: Whether to include the provenance of the data in the
130 | DataFrame (default: False). If set, adds a column 'provenance' to the
131 | DataFrame with the path of the file each row came from.
132 |
133 | Returns:
134 | Merged pandas DataFrame.
135 | """
136 | dfs = []
137 |
138 | for path in paths:
139 | df = load_parquet_file(
140 | path,
141 | file_name,
142 | sort_by=sort_by,
143 | ascending=ascending,
144 | include_provenance=include_provenance,
145 | )
146 | if not df.empty:
147 | dfs.append(df)
148 |
149 | if dfs:
150 | merged_df = pd.concat(dfs, ignore_index=True)
151 | return merged_df
152 | else:
153 | return pd.DataFrame()
154 |
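# A minimal usage sketch of load_all_parquet_files; the directory paths and the
# 'score' column are placeholders, not real experiment outputs.
example_df = load_all_parquet_files(
    ['/path/to/experiment_0', '/path/to/experiment_1'],
    file_name='results.parquet',
    sort_by='score',
    include_provenance=True,
    num_workers=8,
)
print(example_df[['score', 'provenance']].head())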
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/scheduler/constant_schedule_family.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Constant schedule family implementation using optax."""
17 |
18 | from typing import Dict
19 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family
20 | import numpy as np
21 |
22 |
23 | class ConstantScheduleFamily(base_schedule_family.WarmupScheduleFamily):
24 | """Constant learning rate schedule with configurable warmup."""
25 |
26 | def list_schedule_parameter_keys(self) -> list[str]:
27 | return ['p.warmup_steps']
28 |
29 | def get_schedule(
30 | self,
31 | schedule_param: Dict[str, float],
32 | base_lr: float,
33 | ) -> np.ndarray:
34 | """Generate constant learning rate schedule with warmup.
35 |
36 | Args:
37 | schedule_param: Dictionary containing schedule parameters.
38 | base_lr: Base learning rate.
39 |
40 | Returns:
41 | np.ndarray: Array of learning rates for each training step.
42 | """
43 | self.validate_param(schedule_param)
44 | schedule = np.zeros(self.total_steps)
45 | warmup_config = self.schedule_family_config.get('warmup_config', {})
46 |
47 | warmup_steps = int(schedule_param.get('p.warmup_steps', 0))
48 | if warmup_steps > 0:
49 | # Warmup phase
50 | warmup_fn = self.get_warmup_fn(self.warmup_type)
51 | for step in range(warmup_steps):
52 | schedule[step] = warmup_fn(
53 | step, warmup_steps, base_lr, **warmup_config
54 | )
55 |
56 | # Constant phase
57 | schedule[warmup_steps:] = base_lr
58 | else:
59 | # No warmup, constant throughout
60 | schedule[:] = base_lr
61 |
62 | return schedule
63 |
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/scheduler/cosine_schedule_family.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Cosine schedule family implementation using optax.
17 |
18 | Our code is based on the optax implementation of cosine decay schedule as listed
19 | here:
20 | https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.cosine_decay_schedule
21 |
22 | schedule_family_config: Dictionary containing configuration such as:
23 | total_steps: Maximum number of training updates.
24 |
25 | schedule_params:
26 | warmup_steps: Number of warmup steps.
27 | alpha: Decay factor.
28 | """
29 |
30 | from typing import Any, Dict
31 |
32 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family
33 | import numpy as np
34 |
35 |
36 | class CosineScheduleFamily(base_schedule_family.WarmupScheduleFamily):
37 | """Cosine schedule with configurable warmup methods."""
38 |
39 | def validate_config(self, config: Dict[str, Any]) -> None:
40 | """Validate configuration parameters."""
41 |
42 | if 'alpha' not in config:
43 | raise ValueError('alpha must be specified in config')
44 | if config['alpha'] < 0.0:
45 | raise ValueError('alpha must be non-negative')
46 |
47 | if (
48 | 'warmup_type' in config
49 | and config['warmup_type'] not in base_schedule_family.WARMUP_TYPES
50 | ):
51 | raise ValueError(
52 | 'warmup_type must be one of linear, cosine, exponential, or'
53 | ' polynomial'
54 | )
55 |
56 | def validate_param(
57 | self, schedule_param: base_schedule_family.ScheduleParams
58 | ) -> bool:
59 | """Validate schedule parameters."""
60 | super().validate_param(schedule_param)
61 |
62 | required_params = {'p.exponent'}
63 | missing_params = required_params - set(schedule_param.keys())
64 | if missing_params:
65 | raise ValueError(f'Missing required parameters: {missing_params}')
66 |
67 | if not isinstance(schedule_param['p.exponent'], (int, float)):
68 | raise ValueError('exponent must be a number')
69 | if not (0.0 <= schedule_param['p.exponent']):
70 |       raise ValueError('exponent must be non-negative')
71 |
72 | return True
73 |
74 | def cosine_decay(
75 | self,
76 | step: int,
77 | decay_steps: int,
78 | base_lr: float,
79 | alpha: float,
80 | exponent: float,
81 | ) -> float:
82 | """Cosine decay from base_lr to alpha * base_lr.
83 |
84 | Args:
85 | step: Current training step
86 | decay_steps: Number of decay steps
87 | base_lr: Base learning rate
88 | alpha: Decay factor
89 | exponent: Power to raise the cosine decay to
90 |
91 | Returns:
92 | float: Learning rate at the current step
93 | """
94 | progress = step / decay_steps
95 | cosine = (1 + np.cos(np.pi * progress)) / 2
96 | decayed = (1 - alpha) * (cosine**exponent) + alpha
97 | return base_lr * decayed
98 |
99 | def list_schedule_parameter_keys(self) -> list[str]:
100 | return ['p.warmup_steps', 'p.exponent']
101 |
102 | def get_schedule(
103 | self,
104 | schedule_param: Dict[str, float],
105 | base_lr: float,
106 | ) -> np.ndarray:
107 | """Generate learning rate schedule based on parameters.
108 |
109 | Args:
110 | schedule_param: Dictionary containing schedule parameters
111 | base_lr: Base learning rate
112 |
113 | Returns:
114 | np.ndarray: Array of learning rates for each training step
115 | """
116 | # self.validate_param(schedule_param)
117 |
118 | alpha = self.schedule_family_config['alpha']
119 | warmup_steps = int(schedule_param.get('p.warmup_steps', 0))
120 | exponent = schedule_param.get('p.exponent', 1.0)
121 |
122 | schedule = np.zeros(self.total_steps)
123 | warmup_fn = self.get_warmup_fn(self.warmup_type)
124 | warmup_config = self.schedule_family_config.get('warmup_config', {})
125 |
126 | # Warmup phase
127 | for step in range(warmup_steps):
128 | schedule[step] = warmup_fn(step, warmup_steps, base_lr, **warmup_config)
129 |
130 | # Decay phase
131 | decay_steps = self.total_steps - warmup_steps
132 | for step in range(warmup_steps, self.total_steps):
133 | decay_step = step - warmup_steps
134 | schedule[step] = self.cosine_decay(
135 | decay_step, decay_steps, base_lr, alpha, exponent
136 | )
137 |
138 | return schedule
139 |
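# A standalone numeric illustration of cosine_decay above (arbitrary values,
# not from any experiment): with base_lr=1.0, alpha=0.1 and exponent=1.0 the
# rate falls from base_lr at the start of the decay phase to alpha * base_lr
# at the end.
base_lr, alpha, exponent, decay_steps = 1.0, 0.1, 1.0, 100
for step in (0, 50, 100):
  progress = step / decay_steps
  cosine = (1 + np.cos(np.pi * progress)) / 2
  print(step, base_lr * ((1 - alpha) * cosine**exponent + alpha))
# -> 0 1.0, 50 0.55, 100 0.1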
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/scheduler/cosine_standard_schedule_family.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Cosine schedule family implementation with fixed exponent of 1.
17 |
18 | The only difference with cosine_schedule_family.py is that the exponent is
19 | fixed to 1.0.
20 |
21 | schedule_family_config: Dictionary containing configuration such as:
22 | total_steps: Maximum number of training updates.
23 |
24 | schedule_params:
25 | warmup_steps: Number of warmup steps.
26 | alpha: Decay factor.
27 | """
28 |
29 | from typing import Dict
30 |
31 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family
32 | from init2winit.projects.optlrschedule.scheduler import cosine_schedule_family
33 | import numpy as np
34 |
35 |
36 | class CosineStandardScheduleFamily(cosine_schedule_family.CosineScheduleFamily):
37 | """Cosine schedule with configurable warmup methods and exponent of 1.0."""
38 |
39 | def validate_param(
40 | self, schedule_param: base_schedule_family.ScheduleParams
41 | ) -> bool:
42 | """Validate schedule parameters."""
43 | return base_schedule_family.WarmupScheduleFamily.validate_param(
44 | self, schedule_param
45 | )
46 |
47 | def list_schedule_parameter_keys(self) -> list[str]:
48 | return ['p.warmup_steps']
49 |
50 | def get_schedule(
51 | self,
52 | schedule_param: Dict[str, float],
53 | base_lr: float,
54 | ) -> np.ndarray:
55 | """Generate learning rate schedule based on parameters.
56 |
57 | Args:
58 | schedule_param: Dictionary containing schedule parameters
59 | base_lr: Base learning rate
60 |
61 | Returns:
62 | np.ndarray: Array of learning rates for each training step
63 | """
64 |
65 | alpha = self.schedule_family_config['alpha']
66 | warmup_steps = int(schedule_param.get('p.warmup_steps', 0))
67 |
68 | schedule = np.zeros(self.total_steps)
69 | warmup_fn = self.get_warmup_fn(self.warmup_type)
70 | warmup_config = self.schedule_family_config.get('warmup_config', {})
71 |
72 | # Warmup phase
73 | for step in range(warmup_steps):
74 | schedule[step] = warmup_fn(step, warmup_steps, base_lr, **warmup_config)
75 |
76 | # Decay phase
77 | decay_steps = self.total_steps - warmup_steps
78 | for step in range(warmup_steps, self.total_steps):
79 | decay_step = step - warmup_steps
80 | schedule[step] = self.cosine_decay(
81 | decay_step, decay_steps, base_lr, alpha, 1.0
82 | )
83 |
84 | return schedule
85 |
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/scheduler/rex_schedule_family.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """REX schedule family implementation.
17 |
18 | We generalize the REX schedule by introducing a beta parameter and removing
19 | the coefficient of 1/2 in the denominator.
20 |
21 | arxiv: https://arxiv.org/pdf/2107.04197 (MLSys2022)
22 | github: https://github.com/IvanVassi/REX_LR/blob/main/lr_scheduler.py
23 |
24 | The difference between this implementation and the original REX schedule is
25 | that the beta parameter is provided as a schedule parameter.
26 |
27 | Original REX schedule:
28 | progress = t/T
29 | beta = 0.9 (in code) or 1 (in paper)
30 | return (1 - progress) / ((1 - progress * beta)/2 + 1/2)
31 |
32 | Our Generalized REX schedule:
33 | progress = t/T
34 | beta = [0, inf]
35 | alpha = 1 - beta
36 | return (1 - progress) / (1 - progress * alpha)
37 | """
38 |
39 | from typing import Any
40 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family
41 | import numpy as np
42 |
43 |
44 | class RexScheduleFamily(base_schedule_family.WarmupScheduleFamily):
45 | """REX schedule implementation with configurable warmup using NumPy.
46 |
47 | The default values for max_val and min_val in schedule_config are 1 and 0,
48 | respectively,
49 | and the learning rate is scaled by base_lr. The beta parameter (default 0.9)
50 | is provided as a schedule parameter.
51 | """
52 |
53 | def validate_param(self, schedule_param: dict[str, Any]) -> bool:
54 | """Validate schedule parameters."""
55 | super().validate_param(schedule_param)
56 | required_params = {
57 | 'p.warmup_steps',
58 | 'p.beta',
59 |     }  # p.max_val and p.min_val are optional with defaults.
60 | missing_params = required_params - set(schedule_param.keys())
61 | if missing_params:
62 | raise ValueError(
63 | f'Missing required schedule parameters: {missing_params}'
64 | )
65 |
66 | return True
67 |
68 | def list_schedule_parameter_keys(self) -> list[str]:
69 | return ['p.warmup_steps', 'p.beta']
70 |
71 | def rex_decay(self, progress: np.ndarray, beta: float = 0.9) -> np.ndarray:
72 | """Compute the REX decay multiplier.
73 |
74 | Args:
75 | progress: The progress in the decay phase (range from 0 to 1). Measured
76 | from the beginning of the decay phase.
77 | beta: The beta parameter controlling the shape of the decay.
78 |
79 | Returns:
80 | np.ndarray: The REX decay multiplier for each progress value.
81 | """
82 | alpha = 1 - beta
83 | return (1 - progress) / (1 - progress * alpha)
84 |
85 | def get_schedule(
86 | self, schedule_param: dict[str, Any], base_lr: float
87 | ) -> np.ndarray:
88 | """Generate the learning rate schedule for all training steps.
89 |
90 | Args:
91 | schedule_param: Dictionary of schedule parameters (e.g.,
92 | {'p.warmup_steps': 100, 'p.beta': 0.9}).
93 | base_lr: Base learning rate, which is used to scale the schedule.
94 |
95 | Returns:
96 | np.ndarray: An array of learning rates for each training step.
97 | """
98 | # Validate parameters (optional if called externally).
99 | self.validate_param(schedule_param)
100 |
101 | warmup_steps = int(schedule_param['p.warmup_steps'])
102 | beta = schedule_param.get('p.beta', 0.9)
103 |
104 | schedule = np.zeros(self.total_steps)
105 | warmup_fn = self.get_warmup_fn(self.warmup_type)
106 | warmup_config = self.schedule_family_config.get('warmup_config', {})
107 |
108 | # Warmup phase: compute learning rate values during warmup.
109 | # Use base_lr as the target learning rate during warmup.
110 | warmup_lr = []
111 | for step in range(warmup_steps):
112 | lr = warmup_fn(step, warmup_steps, base_lr, **warmup_config)
113 | warmup_lr.append(lr)
114 | warmup_lr = np.array(warmup_lr)
115 | schedule[:warmup_steps] = warmup_lr
116 |
117 | # Decay phase:
118 | decay_steps = self.total_steps - warmup_steps
119 |
120 | # Compute normalized progress (0 to 1) for the decay phase.
121 | steps = np.arange(decay_steps)
122 | progress = steps / decay_steps # progress increases from 0 to 1
123 | # Apply the REX decay function with the specified beta.
124 | decay_multiplier = self.rex_decay(progress, beta=beta)
125 | # Compute learning rate for the decay phase:
126 | # scale by base_lr.
127 | decay_lr = base_lr * decay_multiplier
128 |
129 | schedule[warmup_steps:] = decay_lr
130 | return schedule
131 |
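# A standalone numeric illustration of rex_decay above (arbitrary values): with
# beta = 0.9, alpha = 1 - beta = 0.1, the multiplier stays slightly above a
# linear decay.
example_progress = np.array([0.0, 0.5, 0.9])
print((1 - example_progress) / (1 - example_progress * 0.1))
# -> [1.0, ~0.526, ~0.110]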
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/scheduler/schedule_families.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Schedule families for learning rate schedules."""
17 |
18 | from typing import Any
19 |
20 | from init2winit.projects.optlrschedule.scheduler import constant_schedule_family
21 | from init2winit.projects.optlrschedule.scheduler import cosine_schedule_family
22 | from init2winit.projects.optlrschedule.scheduler import cosine_standard_schedule_family
23 | from init2winit.projects.optlrschedule.scheduler import rex_schedule_family
24 | from init2winit.projects.optlrschedule.scheduler import smooth_nonmonotonic_schedule_family
25 | from init2winit.projects.optlrschedule.scheduler import sqrt_schedule_family
26 | from init2winit.projects.optlrschedule.scheduler import twopointslinear_schedule_family
27 | from init2winit.projects.optlrschedule.scheduler import twopointsspline_schedule_family
28 |
29 | SCHEDULE_FAMILIES = {
30 | 'cosine': cosine_schedule_family.CosineScheduleFamily,
31 | 'cosine_standard': (
32 | cosine_standard_schedule_family.CosineStandardScheduleFamily
33 | ),
34 | 'constant': constant_schedule_family.ConstantScheduleFamily,
35 | 'twopointsspline': (
36 | twopointsspline_schedule_family.TwoPointSplineScheduleFamily
37 | ),
38 | 'twopointslinear': (
39 | twopointslinear_schedule_family.TwoPointLinearScheduleFamily
40 | ),
41 | 'sqrt': sqrt_schedule_family.SqrtScheduleFamily,
42 | 'smoothnonmonotonic': (
43 | smooth_nonmonotonic_schedule_family.TwoPointSplineSmoothNonmonoticScheduleFamily
44 | ),
45 | 'rex': rex_schedule_family.RexScheduleFamily,
46 | }
47 |
48 |
49 | def get_schedule_family_class(
50 | schedule_type: str,
51 | ) -> type[Any]:
52 | """Get schedule family class for a given schedule type."""
53 | try:
54 | return SCHEDULE_FAMILIES[schedule_type]
55 | except KeyError as e:
56 | raise ValueError(f'Unsupported schedule type: {schedule_type}') from e
57 |
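# A minimal lookup sketch; instantiation is omitted here because the
# constructor signature lives in base_schedule_family (not shown in this file).
family_cls = get_schedule_family_class('rex')
assert family_cls is rex_schedule_family.RexScheduleFamily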
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/scheduler/smooth_nonmonotonic_schedule_family.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Jax implementation of Smooth Non-monotonic Schedule family.
17 |
18 | schedule_family_config = {
19 | total_steps: Total number of training steps
20 | }
21 |
22 | schedule_params = {
23 | 'y_start': Boundary value of initial learning rate
24 | 'y_end': Boundary value of final learning rate
25 | 'x_peak': Peak position of learning rate in horizontal axis
26 | 'y1': Learning rate at first control point
27 |     'delta_x1': Relative distance of first control point from start
28 |     'y2': Learning rate at second control point
29 |     'delta_x2': Relative distance of second control point from first control
30 |     point
31 | }
32 | """
33 |
34 | from typing import Dict, Tuple
35 | from absl import logging
36 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family
37 | import numpy as np
38 | from scipy import interpolate
39 |
40 |
41 | class TwoPointSplineSmoothNonmonoticScheduleFamily(
42 | base_schedule_family.BaseScheduleFamily
43 | ):
44 | """Non-monotonic learning rate scheduler with arbitrary peak placement."""
45 |
46 | def validate_param(self, params: Dict[str, float]) -> None:
47 | """Validate schedule parameters."""
48 | required_params = {
49 | 'p.y_start',
50 | 'p.y_end', # Boundary values
51 | 'p.x_peak', # Peak position
52 | 'p.y1',
53 | 'p.delta_x1', # First normal point
54 | 'p.y2',
55 | 'p.delta_x2', # Second normal point
56 | }
57 | if not all(k in params for k in required_params):
58 | raise ValueError(f'Missing parameters. Required: {required_params}')
59 |
60 | # Validate ranges
61 | for param, value in params.items():
62 | if param.startswith('p.delta_x') and not (0 < value < 1):
63 | raise ValueError(f'{param} must be in (0, 1), got {value}')
64 | if param.startswith('p.y') and not (0 <= value <= 1):
65 | raise ValueError(f'{param} must be in [0, 1], got {value}')
66 | if param == 'p.x_peak' and not (0 < value < 1):
67 | raise ValueError(f'x_peak must be in (0, 1), got {value}')
68 |
69 | def _compute_control_points(
70 | self, params: Dict[str, float]
71 | ) -> Tuple[np.ndarray, np.ndarray]:
72 | """Compute control points using stick-breaking procedure.
73 |
74 | Points can be placed in any order, with peak at any position.
75 |
76 | Args:
77 | params: Dictionary of schedule parameters.
78 |
79 | Returns:
80 | Tuple of control points (x, y).
81 | """
82 | # First point using delta_x1
83 | x1 = params['p.delta_x1']
84 |
85 | # Second point using delta_x2 of remaining space
86 | remaining_space = 1.0 - x1
87 | x2 = x1 + remaining_space * params['p.delta_x2']
88 |
89 | x_points = np.array([0.0, x1, x2, params['p.x_peak'], 1.0])
90 | y_points = np.array(
91 | [params['p.y_start'],
92 | params['p.y1'],
93 | params['p.y2'],
94 | 1.0,
95 | params['p.y_end']]
96 | )
97 | order = np.argsort(x_points)
98 | x_points = x_points[order]
99 | y_points = y_points[order]
100 |
101 | # Ensure uniqueness of x coordinates
102 | unique_indices = np.unique(x_points, return_index=True)[1]
103 | if len(unique_indices) < len(x_points):
104 | logging.warning(
105 | 'Found duplicates in x_points. Reducing from %d to %d unique points.',
106 | len(x_points),
107 | len(unique_indices),
108 | )
109 | x_points = x_points[unique_indices]
110 | y_points = y_points[unique_indices]
111 |
112 | return x_points, y_points
113 |
114 | def list_schedule_parameter_keys(self) -> list[str]:
115 | return [
116 | 'p.y_start',
117 | 'p.y_end',
118 | 'p.x_peak',
119 | 'p.y1',
120 | 'p.delta_x1',
121 | 'p.y2',
122 | 'p.delta_x2',
123 | ]
124 |
125 | def get_schedule(
126 | self, params: Dict[str, float], base_lr: float
127 | ) -> np.ndarray:
128 | """Generate learning rate schedule."""
129 | self.validate_param(params)
130 |
131 | # Compute control points
132 | x_points, y_points = self._compute_control_points(params)
133 | x_steps = x_points * self.total_steps
134 | spline = interpolate.PchipInterpolator(x_steps, y_points)
135 | steps = np.arange(self.total_steps)
136 | lr_array = spline(steps) * base_lr
137 |
138 | return lr_array
139 |
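# A standalone numeric illustration of the stick-breaking placement used in
# _compute_control_points above (arbitrary values): delta_x1 places the first
# point directly, and delta_x2 takes a fraction of the space that remains.
delta_x1, delta_x2 = 0.2, 0.5
x1 = delta_x1                    # 0.2
x2 = x1 + (1.0 - x1) * delta_x2  # 0.2 + 0.8 * 0.5 = 0.6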
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/scheduler/sqrt_schedule_family.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Sqrt decay schedule family implementation using optax.
17 |
18 | Our code is based on the optax implementation of sqrt decay decay schedule
19 |
20 | schedule_family_config: Dictionary containing configuration such as:
21 | total_steps: Maximum number of training updates.
22 |
23 | schedule_params:
24 | warmup_steps: Number of warmup steps.
25 | alpha: Decay factor.
26 | """
27 |
28 | from typing import Any, Dict
29 |
30 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family
31 | import numpy as np
32 |
33 |
34 | class SqrtScheduleFamily(base_schedule_family.WarmupScheduleFamily):
35 | """Sqrt decay schedule with configurable warmup methods."""
36 |
37 | def validate_config(self, config: Dict[str, Any]) -> None:
38 | """Validate configuration parameters."""
39 | super().validate_config(config)
40 |
41 | if (
42 | 'warmup_type' in config
43 | and config['warmup_type'] not in base_schedule_family.WARMUP_TYPES
44 | ):
45 | raise ValueError(
46 | 'warmup_type must be one of linear, cosine, exponential, or'
47 | ' polynomial'
48 | )
49 |
50 | def validate_param(
51 | self, schedule_param: base_schedule_family.ScheduleParams
52 | ) -> bool:
53 | """Validate schedule parameters."""
54 | super().validate_param(schedule_param)
55 |
56 | required_params = {'p.alpha'}
57 | missing_params = required_params - set(schedule_param.keys())
58 | if missing_params:
59 | raise ValueError(f'Missing required parameters: {missing_params}')
60 |
61 |     if not isinstance(schedule_param['p.alpha'], (int, float)):
62 | raise ValueError('alpha must be a number')
63 | if not (0.0 <= schedule_param['p.alpha'] <= 1.0):
64 | raise ValueError('alpha must be in the range [0.0, 1.0]')
65 |
66 | return True
67 |
68 | def sqrt_decay(self, x: float, alpha: float) -> float:
69 | """Sqrt decay function.
70 |
71 | Args:
72 | x: Normalized progress (value between 0 and 1).
73 | alpha: Decay factor.
74 |
75 | Returns:
76 | float: Decay multiplier at the current progress.
77 | """
78 | return np.sqrt(1 - x**2) ** alpha
79 |
80 | def list_schedule_parameter_keys(self) -> list[str]:
81 | return ['p.warmup_steps', 'p.alpha']
82 |
83 | def get_schedule(
84 | self,
85 | schedule_param: Dict[str, float],
86 | base_lr: float,
87 | ) -> np.ndarray:
88 | """Generate learning rate schedule based on parameters.
89 |
90 | Args:
91 | schedule_param: Dictionary containing schedule parameters
92 | base_lr: Base learning rate
93 |
94 | Returns:
95 | np.ndarray: Array of learning rates for each training step
96 | """
97 | warmup_steps = int(schedule_param.get('p.warmup_steps', 0))
98 | alpha = schedule_param.get('p.alpha', 1.0)
99 |
100 | schedule = np.zeros(self.total_steps)
101 | warmup_fn = self.get_warmup_fn(self.warmup_type)
102 | warmup_config = self.schedule_family_config.get('warmup_config', {})
103 |
104 | # Warmup phase
105 | for step in range(warmup_steps):
106 | schedule[step] = warmup_fn(step, warmup_steps, base_lr, **warmup_config)
107 |
108 | # Decay phase
109 | decay_steps = self.total_steps - warmup_steps
110 | decay_base_lr = schedule[warmup_steps - 1] if warmup_steps > 0 else base_lr
111 |
112 | for step in range(warmup_steps, self.total_steps):
113 | decay_step = step - warmup_steps
114 | normalized_progress = (
115 | decay_step / decay_steps
116 | ) # Normalized progress (0 to 1)
117 | decay_multiplier = self.sqrt_decay(normalized_progress, alpha)
118 | schedule[step] = decay_base_lr * decay_multiplier
119 |
120 | return schedule
121 |
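# A standalone numeric illustration of sqrt_decay above (arbitrary values): the
# multiplier traces a quarter circle from 1 at x=0 down to 0 at x=1, raised to
# the power alpha.
for x in (0.0, 0.5, 1.0):
  print(x, np.sqrt(1 - x**2) ** 1.0)
# -> 0.0 1.0, 0.5 ~0.866, 1.0 0.0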
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/search_algorithm/search_algorithms.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Search algorithms for schedule parameters search."""
17 |
18 | from typing import Any
19 |
20 | from init2winit.projects.optlrschedule.search_algorithm import (
21 | coordinate_descent_search,
22 | )
23 | from init2winit.projects.optlrschedule.search_algorithm import grid_search
24 | from init2winit.projects.optlrschedule.search_algorithm import random_search
25 |
26 |
27 | SEARCH_ALGORITHMS = {
28 | 'random': random_search.RandomSearch,
29 | 'grid': grid_search.GridSearch,
30 | 'coordinate_descent': coordinate_descent_search.CoordinateDescentSearch,
31 | }
32 |
33 |
34 | def get_search_algorithm_class(search_type: str) -> type[Any]:
35 | """Get search algorithm class for a given search type.
36 |
37 | Args:
38 | search_type: The type of search algorithm to get.
39 |
40 | Returns:
41 | The class of the search algorithm.
42 |
43 | Raises:
44 | ValueError: If the search type is not found.
45 | """
46 | try:
47 | return SEARCH_ALGORITHMS[search_type]
48 | except KeyError as e:
49 | raise ValueError(
50 | f'Search type {search_type} not found in {SEARCH_ALGORITHMS.keys()}'
51 | ) from e
52 |
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/workload/workloads.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Workload classes for different tasks."""
17 |
18 | from typing import Any
19 |
20 | from init2winit.projects.optlrschedule.workload import cifar10_cnn
21 | from init2winit.projects.optlrschedule.workload import linear_regression
22 | from init2winit.projects.optlrschedule.workload import wikitext103_transformer
23 |
24 |
25 | WORKLOADS = {
26 | 'cifar10_cnn': cifar10_cnn.Cifar10Training,
27 | 'wikitext103': wikitext103_transformer.Wikitext103Transformer,
28 | 'linear_regression': linear_regression.LinearRegression,
29 | }
30 |
31 |
32 | def get_workload_class(workload_name: str) -> type[Any]:
33 | """Get workload class for a given workload name.
34 |
35 | Args:
36 | workload_name: The name of the workload to get.
37 |
38 | Returns:
39 | The class of the workload.
40 |
41 | Raises:
42 | ValueError: If the workload name is not found.
43 | """
44 | if workload_name not in WORKLOADS:
45 | raise ValueError(f'Unsupported workload: {workload_name}')
46 | return WORKLOADS[workload_name]
47 |
--------------------------------------------------------------------------------
/init2winit/shared_test_utilities.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Shared utilities for unit tests."""
17 |
18 | import functools
19 |
20 | import jax
21 | import jax.numpy as jnp
22 |
23 |
24 | def pytree_equal(tree1, tree2):
25 | try:
26 | equal_tree = jax.tree_util.tree_map(jnp.array_equal, tree1, tree2)
27 | return jax.tree_util.tree_reduce(lambda x, y: x and y, equal_tree)
28 | # The tree_utils will raise TypeErrors if structures don't match.
29 | except TypeError:
30 | return False
31 |
32 |
33 | def pytree_allclose(tree1, tree2, rtol=1e-5):
34 | try:
35 | allclose = functools.partial(jnp.allclose, rtol=rtol)
36 | equal_tree = jax.tree_util.tree_map(allclose, tree1, tree2)
37 | return jax.tree_util.tree_reduce(lambda x, y: x and y, equal_tree)
38 | # The tree_utils will raise TypeErrors if structures don't match.
39 | except TypeError:
40 | return False
41 |
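# A minimal usage sketch of the helpers above: identical trees compare equal,
# and nearly identical trees pass the allclose check at the default rtol.
_example_a = {'w': jnp.ones(3)}
_example_b = {'w': jnp.ones(3) + 1e-7}
assert pytree_equal(_example_a, _example_a)
assert pytree_allclose(_example_a, _example_b)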
--------------------------------------------------------------------------------
/init2winit/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for utils.py.
17 |
18 | """
19 |
20 | import os
21 | import shutil
22 | import tempfile
23 |
24 | from absl.testing import absltest
25 | from absl.testing import parameterized
26 | from init2winit import checkpoint
27 | from init2winit import utils
28 | import jax.numpy as jnp
29 | import numpy as np
30 |
31 |
32 | def _identity(i):
33 | return i
34 |
35 |
36 | def _fn_that_always_fails(arg):
37 | del arg
38 | raise ValueError('I always fail.')
39 |
40 |
41 | class UtilsTest(parameterized.TestCase):
42 | """Tests for utils.py."""
43 |
44 | def setUp(self):
45 | super(UtilsTest, self).setUp()
46 | self.test_dir = tempfile.mkdtemp()
47 |
48 | def tearDown(self):
49 | shutil.rmtree(self.test_dir)
50 | super(UtilsTest, self).tearDown()
51 |
52 | @parameterized.named_parameters(
53 | dict(
54 | testcase_name='empty list of args',
55 | num_workers=1,
56 | input_list_dict=[],
57 | expected=[],
58 | ),
59 | dict(
60 | testcase_name='one worker, nonempty list',
61 | num_workers=1,
62 | input_list_dict=[dict(i=k) for k in range(1, 10)],
63 | expected=list(range(1, 10)),
64 | ),
65 | dict(
66 | testcase_name='fewer workers than jobs, nonempty list',
67 | num_workers=3,
68 | input_list_dict=[dict(i=k) for k in range(1, 10)],
69 | expected=list(range(1, 10)),
70 | ),
71 | dict(
72 | testcase_name='more workers than jobs, nonempty list',
73 | num_workers=20,
74 | input_list_dict=[dict(i=k) for k in range(1, 10)],
75 | expected=list(range(1, 10)),
76 | ),
77 | )
78 | def testRunInParallel(self, input_list_dict, num_workers, expected):
79 | """Test running successful fns in parallel, originally from mlbileschi."""
80 | actual = utils.run_in_parallel(_identity, input_list_dict, num_workers)
81 | self.assertEqual(actual, expected)
82 |
83 | def testRunInParallelOnFailingFn(self):
84 | """Test running failing fns in parallel, originally from mlbileschi."""
85 | with self.assertRaisesRegex(ValueError, 'I always fail.'):
86 | utils.run_in_parallel(_fn_that_always_fails, [dict(arg='hi')], 10)
87 |
88 | def testAppendPytree(self):
89 | """Test appending and loading pytrees."""
90 | pytrees = [{'a': i} for i in range(10)]
91 | pytree_path = os.path.join(self.test_dir, 'pytree.ckpt')
92 | logger = utils.MetricLogger(pytree_path=pytree_path)
93 |
94 | for pytree in pytrees:
95 | logger.append_pytree(pytree)
96 |
97 | latest = checkpoint.load_latest_checkpoint(pytree_path, prefix='')
98 | saved_pytrees = latest['pytree'] if latest else []
99 | self.assertEqual(
100 | pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))])
101 |
102 | def testArrayAppend(self):
103 | """Test appending to an array."""
104 | np.testing.assert_allclose(
105 | utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4]))
106 | np.testing.assert_allclose(
107 | utils.array_append(jnp.array([[1, 2], [3, 4]]), jnp.array([5, 6])),
108 | jnp.array([[1, 2], [3, 4], [5, 6]]))
109 |
110 | def testTreeNormSqL2(self):
111 | """Test computing the squared L2 norm of a pytree."""
112 | pytree = {'foo': jnp.ones(10), 'baz': jnp.ones(20)}
113 | self.assertEqual(utils.tree_norm_sql2(pytree), {'foo': 10.0, 'baz': 20.0})
114 | self.assertEqual(utils.total_tree_norm_sql2(pytree), 30.0)
115 |
116 | def testTreeSum(self):
117 | """Test computing the sum of a pytree."""
118 | pytree = {'foo': 2*jnp.ones(10), 'baz': jnp.ones(20)}
119 | self.assertEqual(utils.total_tree_sum(pytree), 40)
120 |
121 | if __name__ == '__main__':
122 | absltest.main()
123 |
124 |
--------------------------------------------------------------------------------
/init2winit/testdata/wikitext_tokenizer_fake_data.txt:
--------------------------------------------------------------------------------
1 | = = Lulu = =
2 |
3 |
4 | Lulu is a Goldendoodle , which is a mixed breed of the Poodle and the Golden Retriever breeds . Her posture is very poodle-like , similar to her father Lucas . Her mother Luna was also a Goldendoodle , although she carried herself much more like a Golden Retriever than a Poodle .
--------------------------------------------------------------------------------
/init2winit/tools/inspect_dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Main file for the init2winit project.
17 |
18 | """
19 |
20 | import os
21 | import sys
22 |
23 | from absl import app
24 | from absl import flags
25 | from absl import logging
26 | from init2winit import hyperparameters
27 | from init2winit.dataset_lib import datasets
28 | import jax
29 | import tensorflow as tf
30 |
31 | # Don't let TF see the GPU, because all we use it for is tf.data loading.
32 | tf.config.experimental.set_visible_devices([], 'GPU')
33 |
34 | # Enable flax xprof trace labelling.
35 | os.environ['FLAX_PROFILE'] = 'true'
36 |
37 | flags.DEFINE_string('dataset', None, 'Which dataset to inspect')
38 | flags.DEFINE_string('model', None, 'Which model to use')
39 | flags.DEFINE_integer('batch_size', None,
40 | 'Number of examples to retrieve in 1 batch')
41 | flags.DEFINE_integer('num_batches', None, 'Number of batches to retrieve')
42 |
43 | FLAGS = flags.FLAGS
44 |
45 |
46 | def main(unused_argv):
47 | if jax.process_index() == 0:
48 | logging.info('argv:\n%s', ' '.join(sys.argv))
49 | logging.info('device_count: %d', jax.device_count())
50 | logging.info('num_hosts : %d', jax.process_count())
51 | logging.info('host_id : %d', jax.process_index())
52 |
53 | if FLAGS.batch_size is None or FLAGS.batch_size <= 0:
54 | raise ValueError("""FLAGS.batch_size value is invalid,
55 | expected a positive non-zero integer.""")
56 |
57 | if FLAGS.dataset is None:
58 | raise ValueError("""FLAGS.dataset value is invalid,
59 | expected a non-empty string describing dataset name.""")
60 |
61 | batch_size = FLAGS.batch_size
62 | num_batches = FLAGS.num_batches
63 | dataset_name = FLAGS.dataset
64 | model_name = FLAGS.model
65 | initializer_name = 'noop'
66 |
67 | hparam_overrides = {
68 | 'batch_size': batch_size,
69 | }
70 |
71 | hps = hyperparameters.build_hparams(
72 | model_name=model_name,
73 | initializer_name=initializer_name,
74 | dataset_name=dataset_name,
75 | hparam_file=None,
76 | hparam_overrides=hparam_overrides)
77 |
78 | rng = jax.random.PRNGKey(0)
79 | rng, data_rng = jax.random.split(rng)
80 |
81 | dataset = datasets.get_dataset(FLAGS.dataset)(data_rng, batch_size,
82 | batch_size, hps)
83 | train_iter = dataset.train_iterator_fn()
84 |
85 | for i in range(num_batches):
86 | batch = next(train_iter)
87 | logging.info('train batch_num = %d, batch = %r', i, batch)
88 |
89 | for batch in dataset.valid_epoch(num_batches):
90 | logging.info('validation batch = %r', batch)
91 |
92 |
93 | if __name__ == '__main__':
94 | app.run(main)
95 |
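# A hypothetical invocation (the dataset and model names below are assumptions
# and must match entries registered in datasets.py / models.py; --num_batches
# has no default, so it should be passed explicitly):
#
#   python3 -m init2winit.tools.inspect_dataset \
#     --dataset=mnist --model=fully_connected --batch_size=8 --num_batches=2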
--------------------------------------------------------------------------------
/init2winit/trainer_lib/trainers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Trainers for init2winit."""
17 |
18 | from init2winit.trainer_lib import trainer
19 |
20 | _ALL_TRAINERS = {
21 | 'standard': trainer.Trainer,
22 | }
23 |
24 |
25 | def get_trainer_cls(trainer_name):
26 | """Maps trainer name to a Trainer class."""
27 | try:
28 | return _ALL_TRAINERS[trainer_name]
29 | except KeyError:
30 | raise ValueError('Unrecognized trainer: {}'.format(trainer_name)) from None
31 |
32 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """init2winit.
2 |
3 | See more details in the
4 | [`README.md`](https://github.com/google/init2winit).
5 | """
6 |
7 | from setuptools import find_packages
8 | from setuptools import setup
9 |
10 | setup(
11 | name='init2winit',
12 | version='0.0.1',
13 | description='init2winit',
14 | author='init2winit Team',
15 | author_email='znado@google.com',
16 | url='http://github.com/google/init2winit',
17 | license='Apache 2.0',
18 | packages=find_packages(),
19 | install_requires=[
20 | 'absl-py>=0.8.1',
21 | 'clu',
22 | 'flax',
23 | 'jax',
24 | 'jax-bitempered-loss',
25 | 'jraph',
26 | 'ml_collections',
27 | 'numpy>=1.7',
28 | 'optax',
29 | 'optax-shampoo',
30 | 'pandas',
31 | 'sentencepiece',
32 | 'tensorboard',
33 | 'tensorflow-datasets',
34 | 'tensorflow-text==2.5.0-rc0',
35 | 'tensorflow==2.5.0',
36 | ],
37 | extras_require={},
38 | classifiers=[
39 | 'Development Status :: 3 - Alpha',
40 | 'Intended Audience :: Developers',
41 | 'Intended Audience :: Science/Research',
42 | 'License :: OSI Approved :: Apache Software License',
43 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
44 | ],
45 | keywords='jax machine learning',
46 | )
47 |
--------------------------------------------------------------------------------