├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── hessian ├── model_debugger.py ├── model_debugger_callback.py ├── precondition.py ├── test_model_debugger.py └── test_precondition.py ├── init2winit ├── __init__.py ├── base_callback.py ├── callbacks.py ├── checkpoint.py ├── dataset_lib │ ├── __init__.py │ ├── autoaugment.py │ ├── criteo_terabyte_dataset.py │ ├── data_selectors.py │ ├── data_utils.py │ ├── datasets.py │ ├── fake_dataset.py │ ├── fastmri_dataset.py │ ├── image_preprocessing.py │ ├── imagenet_dataset.py │ ├── imagenet_preprocessing.py │ ├── librispeech.py │ ├── librispeech_input_pipeline.py │ ├── lm1b_input_pipeline_v2.py │ ├── lm1b_v2.py │ ├── mlperf_imagenet_dataset.py │ ├── mlperf_input_pipeline.py │ ├── mt_pipeline.py │ ├── mt_pipeline_test.py │ ├── mt_tokenizer.py │ ├── nanodo_c4.py │ ├── nanodo_data_loader_shared.py │ ├── nanodo_fineweb_edu.py │ ├── nqm_noise.py │ ├── ogbg_molpcba.py │ ├── pg19.py │ ├── protein_vocab.py │ ├── proteins.py │ ├── small_image_datasets.py │ ├── spm_tokenizer.py │ ├── test_data_utils.py │ ├── test_datasets.py │ ├── test_ogbg_molpcba.py │ ├── test_small_image_datasets.py │ ├── test_wikitext_tokenizer.py │ ├── translate_wmt.py │ ├── wikitext103.py │ ├── wikitext103_input_pipeline.py │ ├── wikitext103_spm.py │ ├── wikitext2.py │ ├── wikitext2_input_pipeline.py │ └── wikitext_tokenizer.py ├── gradient_statistics_callback.py ├── hyperparameters.py ├── init_lib │ ├── __init__.py │ ├── initializers.py │ ├── meta_init.py │ ├── sparse_init.py │ └── test_initializers.py ├── main.py ├── model_lib │ ├── __init__.py │ ├── adabelief_densenet.py │ ├── adabelief_resnet.py │ ├── adabelief_vgg.py │ ├── attention.py │ ├── autoencoder.py │ ├── base_model.py │ ├── binarize_layers.py │ ├── conformer.py │ ├── convolutional_autoencoder.py │ ├── deepspeech.py │ ├── dlrm.py │ ├── fully_connected.py │ ├── gnn.py │ ├── librispeech_preprocessor.py │ ├── local_attention_transformer.py │ ├── losses.py │ ├── lstm.py │ ├── lstm_lm.py │ ├── max_pooling_cnn.py │ ├── metrics.py │ ├── mlperf_resnet.py │ ├── model_utils.py │ ├── models.py │ ├── nanodo.py │ ├── normalization.py │ ├── nqm.py │ ├── partition_tree.py │ ├── resnet.py │ ├── simple_cnn.py │ ├── spectrum_augmenter.py │ ├── test_local_attention_transformer.py │ ├── test_losses.py │ ├── test_metrics.py │ ├── test_models.py │ ├── test_normalization.py │ ├── transformer_lm.py │ ├── transformer_stu_lm.py │ ├── transformer_stu_tensordot_lm.py │ ├── unet.py │ ├── vit.py │ ├── wide_resnet.py │ ├── xformer_translate.py │ ├── xformer_translate_binary.py │ └── xformer_translate_mlc_variant.py ├── mt_eval │ ├── decode.py │ ├── eval_utils.py │ ├── inference.py │ ├── main.py │ └── mt_callback.py ├── optimizer_lib │ ├── __init__.py │ ├── factor_sam.py │ ├── gradient_accumulator.py │ ├── kitchen_sink │ │ ├── __init__.py │ │ └── _src │ │ │ ├── alias.py │ │ │ ├── combine.py │ │ │ ├── core.py │ │ │ ├── mask.py │ │ │ ├── preconditioner.py │ │ │ ├── test_core.py │ │ │ ├── test_mask.py │ │ │ ├── test_preconditioner.py │ │ │ ├── test_transform.py │ │ │ ├── transform.py │ │ │ └── utils.py │ ├── linalg │ │ ├── README.md │ │ ├── low_rank_root_update.py │ │ ├── low_rank_root_update_test.py │ │ ├── paterson_stockmeyer.py │ │ ├── pth_inv_root_rmn.py │ │ ├── pth_inv_root_rmn_coefficients.py │ │ ├── pth_inv_root_rmn_test.py │ │ └── root_selector.py │ ├── muon.py │ ├── online_newton_step.py │ ├── optimizers.py │ ├── pax_adafactor.py │ ├── samuel.py │ ├── search_subspace.py │ ├── sharpness_aware_minimization.py │ ├── sla.py │ ├── sla_test.py │ ├── 
test_gradient_accumulator.py │ ├── test_optimizers.py │ ├── test_search_subspace.py │ ├── test_utils.py │ └── utils.py ├── projects │ └── optlrschedule │ │ ├── README.md │ │ ├── notebook_utils │ │ ├── pandas_util.py │ │ ├── parquet_util.py │ │ ├── plot_util.py │ │ ├── schedule_util.py │ │ └── test_pandas_util.py │ │ ├── run_search_decoupled.py │ │ ├── scheduler │ │ ├── base_schedule_family.py │ │ ├── constant_schedule_family.py │ │ ├── cosine_schedule_family.py │ │ ├── cosine_standard_schedule_family.py │ │ ├── piecewise_schedule_family.py │ │ ├── rex_schedule_family.py │ │ ├── schedule_families.py │ │ ├── smooth_nonmonotonic_schedule_family.py │ │ ├── sqrt_schedule_family.py │ │ ├── test_scheduler.py │ │ ├── twopointslinear_schedule_family.py │ │ └── twopointsspline_schedule_family.py │ │ ├── search_algorithm │ │ ├── coordinate_descent_search.py │ │ ├── grid_search.py │ │ ├── random_search.py │ │ ├── search_algorithms.py │ │ └── test_search_algorithm.py │ │ └── workload │ │ ├── base_workload.py │ │ ├── cifar10_cnn.py │ │ ├── datasets │ │ └── wikitext_103.py │ │ ├── linear_regression.py │ │ ├── optimizers.py │ │ ├── test_workload.py │ │ ├── wikitext103_transformer.py │ │ └── workloads.py ├── references.md ├── schedules.py ├── shared_test_utilities.py ├── test_checkpoint.py ├── test_hyperparameters.py ├── test_schedules.py ├── test_training_metrics_grabber.py ├── test_utils.py ├── testdata │ └── wikitext_tokenizer_fake_data.txt ├── tools │ └── inspect_dataset.py ├── trainer_lib │ ├── base_trainer.py │ ├── test_trainer.py │ ├── trainer.py │ ├── trainer_utils.py │ └── trainers.py ├── training_metrics_grabber.py └── utils.py └── setup.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # init2winit 2 | 3 | A Jax/Flax codebase for running deterministic, scalable, and well-documented deep learning experiments, with a particular emphasis on neural network initialization, optimization, and tuning experiments. 4 | 5 | There is not yet a stable version (nor an official release of this library). 6 | All APIs are subject to change. 7 | 8 | This is a research project, not an official Google product. 
9 | 10 | 11 | ## Installation 12 | The current development version requires Python 3.6-3.8. 13 | 14 | To install the latest development version inside a virtual environment, run 15 | 16 | ``` 17 | python3 -m venv env-i2w 18 | source env-i2w/bin/activate 19 | pip install --upgrade pip 20 | pip install "git+https://github.com/google/init2winit.git#egg=init2winit" 21 | pip install --upgrade jax jaxlib==0.1.66+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html 22 | ``` 23 | 24 | where `cuda111` corresponds to the installed version of CUDA. For more Jax install information see the [Jax README](https://github.com/google/jax#installation). 25 | 26 | ## Usage 27 | 28 | An example MNIST experiment can be run with the following command: 29 | 30 | ```sh 31 | python3 main.py \ 32 | --experiment_dir=/tmp/test_mnist \ 33 | --model=fully_connected \ 34 | --dataset=mnist \ 35 | --num_train_steps=10 36 | ``` 37 | 38 | For local debugging we recommend using the `fake` dataset: 39 | 40 | ```sh 41 | python3 main.py \ 42 | --experiment_dir=/tmp/test_fake \ 43 | --num_train_steps=10 \ 44 | --dataset=fake \ 45 | --hparam_overrides='{"input_shape": [28, 28, 1], "output_shape": [10]}' 46 | ``` 47 | 48 | The `hparam_overrides` accepts a serialized JSON object with hyperparameter names/values to use. See the flags in `main.py` for more information on possible configurations. 49 | 50 | See the [`dataset_lib`](https://github.com/google/init2winit/tree/master/init2winit/dataset_lib) and [`model_lib`](https://github.com/google/init2winit/tree/master/init2winit/model_lib) directories for currently implemented datasets and models. 51 | 52 | 53 | ## Citing 54 | To cite this repository: 55 | 56 | ```bibtex 57 | @software{init2winit2021github, 58 | author = {Justin M. Gilmer and George E. Dahl and Zachary Nado and Priya Kasimbeg and Sourabh Medapati}, 59 | title = {{init2winit}: a JAX codebase for initialization, optimization, and tuning research}, 60 | url = {http://github.com/google/init2winit}, 61 | version = {0.0.2}, 62 | year = {2023}, 63 | } 64 | ``` 65 | 66 | For a list of references to the models, datasets, and techniques implemented in this codebase, see [`references.md`](https://github.com/google/init2winit/tree/master/init2winit/references.md). 67 | 68 | 69 | ## Contributors 70 | Contributors (past and present): 71 | 72 | - Ankush Garg 73 | - Behrooz Ghorbani 74 | - Cheolmin Kim 75 | - David Cardoze 76 | - George E. Dahl 77 | - Justin M. Gilmer 78 | - Michal Badura 79 | - Priya Kasimbeg 80 | - Rohan Anil 81 | - Sourabh Medapati 82 | - Sneha Kudugunta 83 | - Varun Godbole 84 | - Zachary Nado 85 | - Vlad Feinberg 86 | - Derrick Xin 87 | - Naman Agarwal 88 | - Daniel Suo 89 | - Bilal Khan 90 | - Jeremy Cohen 91 | - Kacper Krasowiak 92 | 93 | -------------------------------------------------------------------------------- /hessian/model_debugger_callback.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Callback which runs the model debugger.""" 17 | 18 | import functools 19 | import os 20 | 21 | import flax 22 | from init2winit import utils 23 | from init2winit.dataset_lib import data_utils 24 | from init2winit.hessian import model_debugger 25 | import jax 26 | import jax.numpy as jnp 27 | 28 | DEFAULT_CONFIG = { 29 | 'name': 'model_debugger', 30 | } 31 | 32 | 33 | def get_grad(params, 34 | batch, 35 | rng, 36 | batch_stats=None, 37 | module_flags=None, 38 | training_cost=None): 39 | """Single step of the training loop. 40 | 41 | Args: 42 | params: the Flax param pytree. batch norm statistics. 43 | batch: the per-device batch of data to process. 44 | rng: the RNG used for calling the model. Assumes the step and device index 45 | has already been folded in. 46 | batch_stats: Same as in trainer.py 47 | module_flags: Used in the skip analysis. 48 | training_cost: a function used to calculate the training objective that will 49 | be differentiated to generate updates. Takes (`params`, `batch_stats`, 50 | `batch`, `rng`) as inputs. 51 | 52 | Returns: 53 | Gradient of the given loss. 54 | """ 55 | 56 | if module_flags is not None: 57 | kwargs = {'module_flags': module_flags} 58 | else: 59 | kwargs = {} 60 | def opt_cost(params): 61 | return training_cost( 62 | params, 63 | batch, 64 | batch_stats=batch_stats, 65 | dropout_rng=rng, 66 | **kwargs) 67 | 68 | grad_fn = jax.value_and_grad(opt_cost, has_aux=True) 69 | _, grad = grad_fn(params) 70 | 71 | grad = jax.lax.pmean(grad, axis_name='batch') 72 | return grad 73 | 74 | 75 | class ModelDebugCallback: 76 | """Used to run the hessian eval in the trainer binary.""" 77 | 78 | def __init__(self, model, params, batch_stats, optimizer_state, 79 | optimizer_update_fn, dataset, hps, callback_config, train_dir, 80 | rng): 81 | del hps 82 | del params 83 | del optimizer_state 84 | del optimizer_update_fn 85 | checkpoint_dir = os.path.join(train_dir, 'checkpoints') 86 | # copy batch_stats as we close over it, and it gets modified. 87 | self.dataset = dataset 88 | checkpoint_dir = os.path.join(train_dir, 'checkpoints') 89 | pytree_path = os.path.join(checkpoint_dir, 'debugger') 90 | logger = utils.MetricLogger(pytree_path=pytree_path) 91 | 92 | get_act_stats_fn = model_debugger.create_forward_pass_stats_fn( 93 | model.apply_on_batch, 94 | capture_activation_norms=True, 95 | sown_collection_names=callback_config.get('sown_collection_names')) 96 | batch_stats = jax.tree.map(lambda x: x[:][0], batch_stats) 97 | grad_fn = functools.partial( 98 | get_grad, 99 | batch_stats=batch_stats, 100 | training_cost=model.training_cost, 101 | ) 102 | debugger = model_debugger.ModelDebugger( 103 | use_pmap=True, 104 | forward_pass=get_act_stats_fn, 105 | metrics_logger=logger, 106 | grad_fn=grad_fn, 107 | skip_flags=callback_config.get('skip_flags'), 108 | skip_groups=callback_config.get('skip_groups')) 109 | # pmap functions for the training loop 110 | # in_axes = (params = 0, batch_stats = 0, batch = 0, step = None, 111 | # lr = None, rng = None, local_device_index = 0, training_metrics_grabber=0, 112 | # training_metrics_grabber, training_cost ) 113 | # Also, we can donate buffers for 'optimizer', 'batch_stats', 114 | # 'batch' and 'training_metrics_grabber' for update's pmapped computation. 
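    # Store everything run_eval() will reuse on the callback: the debugger,
    # the metrics logger, the dataset, a single sharded batch drawn from the
    # training iterator, and a replicated RNG.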
115 | self.debugger = debugger 116 | self.logger = logger 117 | self.dataset = dataset 118 | 119 | batch = next(dataset.train_iterator_fn()) 120 | self.batch = data_utils.shard(batch) 121 | 122 | self.batch_rng = flax.jax_utils.replicate(rng) 123 | 124 | def run_eval(self, params, batch_stats, optimizer_state, global_step): 125 | """Runs ModelDebugger.full_eval on the given params. 126 | 127 | Note, the full lanczos tridiagonal matrix is saved via the logger to 128 | train_dir/checkpoints/config['name']. 129 | 130 | Args: 131 | params: Replicated model parameter tree. 132 | batch_stats: Replicated batch_stats from the trainer. 133 | optimizer_state: Replicated optimizer state from the trainer. 134 | global_step: Current training step. 135 | 136 | Returns: 137 | Max eigenvalue of the loss (full tridiag is saved to disk). 138 | """ 139 | del optimizer_state 140 | del batch_stats 141 | p_norms = jax.tree.map(lambda x: jnp.linalg.norm(x[0].reshape(-1))**2, 142 | params) 143 | 144 | self.debugger.full_eval( 145 | step=global_step, 146 | params=params, 147 | param_norms_sql2=p_norms, 148 | batch=self.batch, 149 | rng=self.batch_rng) 150 | 151 | return {} 152 | -------------------------------------------------------------------------------- /init2winit/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /init2winit/base_callback.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Infrastructure for arbitrary code to be run during training. 17 | 18 | Callbacks can be stateful, and in trainer are meant to be called as follows: 19 | 20 | 21 | callback_builder = callbacks.get_callback(config['callback_name']) 22 | callback = callback_builder(model, params, batch_stats, optimizer_state, 23 | dataset, hps, config, train_dir, rng, mesh) 24 | 25 | callback_metrics = callback.run_eval(params, batch_stats, 26 | optimizer_state, global_step). 27 | 28 | We require that the config has a field 'callback_name', which the trainer 29 | uses to determine which callbacks to run. 
The dictionary, callback_metrics 30 | should be scalar valued, and will be automatically added to the existing trainer 31 | scalar metrics. 32 | """ 33 | 34 | # TODO(gilmer) Add serialization so that we can checkpoint callback state. 35 | 36 | 37 | class BaseCallBack: 38 | """Base callback to specify the required API.""" 39 | 40 | def __init__(self, model, params, batch_stats, optimizer_state, 41 | optimizer_update_fn, dataset, hps, callback_config, train_dir, 42 | rng, mesh): 43 | """Defines the API for callback construction.""" 44 | pass 45 | 46 | def run_eval(self, params, batch_stats, optimizer_state, global_step): 47 | """Define the API for running the callback during eval. 48 | 49 | Args: 50 | params: Replicated params from the trainer. 51 | batch_stats: Replicated batch_stats from the trainer. 52 | optimizer_state: Replicated optimizer state from the trainer. 53 | global_step: Current training step. 54 | 55 | Returns: 56 | A dictionary of scalar metrics. Note, any existing metric returned by 57 | trainer.evaluate are forbidden, e.g. including 'train/ce_loss' will 58 | resort in trainer throwing an exception. 59 | """ 60 | raise NotImplementedError('Subclasses must implement run_eval().') 61 | -------------------------------------------------------------------------------- /init2winit/callbacks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Registry for the available callbacks.""" 17 | 18 | from init2winit import gradient_statistics_callback 19 | from init2winit.hessian import model_debugger_callback 20 | from init2winit.mt_eval import mt_callback 21 | 22 | 23 | _ALL_CALLBACKS = { 24 | 'mt': mt_callback.MTEvaluationCallback, 25 | 'model_debugger': model_debugger_callback.ModelDebugCallback, 26 | 'gradient_statistics': ( 27 | gradient_statistics_callback.GradientStatisticsCallback 28 | ), 29 | } 30 | 31 | 32 | def get_callback(callback_name): 33 | """Get the corresponding callback builder based on the callback_name. 34 | 35 | Args: 36 | callback_name: (str) e.g. mt. 37 | 38 | Returns: 39 | Callback builder class. 40 | Raises: 41 | ValueError if callback_name is unrecognized. 42 | """ 43 | try: 44 | return _ALL_CALLBACKS[callback_name] 45 | except KeyError: 46 | raise ValueError('Unrecognized callback name: {}'.format(callback_name)) 47 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/data_selectors.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Definitions for supported data selection functions.""" 17 | 18 | 19 | def noop( 20 | dataset_iterator, 21 | optimizer_state, 22 | params, 23 | batch_stats, 24 | hps, 25 | global_step, 26 | constant_base_rng): 27 | """An example no-op data selector that just yields the next batch. 28 | 29 | Args: 30 | dataset_iterator: the (preprocessed, batched, prefetched) dataset iterator. 31 | optimizer_state: the current optimizer state. 32 | params: the model parameters. 33 | batch_stats: the model batch statistics. 34 | hps: the experiment hyperparameters. 35 | global_step: the current global step. 36 | constant_base_rng: the RNG used for the experiment. IMPORTANT NOTE: this 37 | will be constant for all calls to this function, in order to get a unique 38 | RNG each time we need to do 39 | `rng = jax.random.fold_in(constant_base_rng, global_step)`. 40 | 41 | Yields: 42 | A batch of data. 43 | """ 44 | del optimizer_state 45 | del params 46 | del batch_stats 47 | del hps 48 | del global_step 49 | del constant_base_rng 50 | yield from dataset_iterator 51 | 52 | 53 | def data_echoing( 54 | dataset_iterator, 55 | optimizer_state, 56 | params, 57 | batch_stats, 58 | hps, 59 | global_step, 60 | constant_base_rng): 61 | """An example data echoing selector. 62 | 63 | Args: 64 | dataset_iterator: the (preprocessed, batched, prefetched) dataset iterator. 65 | optimizer_state: the current optimizer state. 66 | params: the model parameters. 67 | batch_stats: the model batch statistics. 68 | hps: the experiment hyperparameters. 69 | global_step: the current global step. 70 | constant_base_rng: the RNG used for the experiment. IMPORTANT NOTE: this 71 | will be constant for all calls to this function, in order to get a unique 72 | RNG each time we need to do 73 | `rng = jax.random.fold_in(constant_base_rng, global_step)`. 74 | 75 | Yields: 76 | A batch of data. 
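    Each batch is yielded `hps.num_data_echoes` times in a row before the
    next batch is drawn from the underlying iterator.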
77 | """ 78 | del optimizer_state 79 | del params 80 | del batch_stats 81 | del global_step 82 | del constant_base_rng 83 | for x in dataset_iterator: 84 | for _ in range(hps.num_data_echoes): 85 | yield x 86 | 87 | 88 | ALL_SELECTORS = { 89 | 'noop': noop, 90 | 'data_echoing': data_echoing, 91 | } 92 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/fake_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Fake image input pipeline. Returns the same batch of ones over and over.""" 17 | import copy 18 | 19 | from init2winit.dataset_lib import data_utils 20 | import jax 21 | import jax.numpy as jnp 22 | from ml_collections.config_dict import config_dict 23 | import numpy as np 24 | 25 | 26 | TRAIN_IMAGES = 1281167 27 | EVAL_IMAGES = 50000 28 | 29 | 30 | NUM_CLASSES = 1000 31 | IMAGE_SIZE = 224 32 | 33 | 34 | DEFAULT_HPARAMS = config_dict.ConfigDict(dict( 35 | input_shape=(224, 224, 3), 36 | output_shape=(NUM_CLASSES,), 37 | train_size=TRAIN_IMAGES, 38 | valid_size=EVAL_IMAGES)) 39 | 40 | METADATA = { 41 | 'apply_one_hot_in_loss': False, 42 | } 43 | 44 | 45 | def get_fake_batch(hps): 46 | """Generate batches of images of all ones and one-hot labels.""" 47 | batch_size = hps.batch_size 48 | input_shape = hps.input_shape 49 | num_classes = hps.output_shape[0] 50 | train_input_shape = (batch_size, *input_shape) 51 | images = jnp.ones(train_input_shape, dtype=jnp.float32) 52 | labels = jax.nn.one_hot( 53 | np.zeros((batch_size,)), num_classes, dtype=jnp.int32) 54 | batch = { 55 | 'inputs': images, 56 | 'targets': labels, 57 | 'weights': jnp.ones(batch_size, dtype=images.dtype), 58 | } 59 | return batch 60 | 61 | 62 | def get_fake(shuffle_rng, batch_size, eval_batch_size, hps=None): 63 | """Data generators for imagenet.""" 64 | del shuffle_rng 65 | per_host_batch_size = batch_size // jax.process_count() 66 | per_host_eval_batch_size = eval_batch_size // jax.process_count() 67 | 68 | train_hps = copy.copy(hps) 69 | train_hps.unlock() 70 | train_hps.batch_size = per_host_batch_size 71 | fake_train_batch = get_fake_batch(train_hps) 72 | 73 | test_hps = copy.copy(hps) 74 | test_hps.unlock() 75 | test_hps.batch_size = per_host_eval_batch_size 76 | fake_test_batch = get_fake_batch(test_hps) 77 | 78 | def train_iterator_fn(): 79 | while True: 80 | yield fake_train_batch 81 | 82 | def valid_epoch(epoch, num_batches=None): 83 | del num_batches 84 | del epoch 85 | # Note that we do // beacuse we do not support partial batching for the fake 86 | # dataset. 87 | for _ in range(hps.valid_size // eval_batch_size): 88 | yield fake_test_batch 89 | 90 | # pylint: disable=unreachable 91 | def eval_train_epoch(*args, **kwargs): 92 | del args 93 | del kwargs 94 | return 95 | yield # This yield is needed to make this a valid (null) iterator. 
96 | # pylint: enable=unreachable 97 | # pylint: disable=unreachable 98 | 99 | def test_epoch(*args, **kwargs): 100 | del args 101 | del kwargs 102 | return 103 | yield # This yield is needed to make this a valid (null) iterator. 104 | # pylint: enable=unreachable 105 | 106 | return data_utils.Dataset( 107 | train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) 108 | 109 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/librispeech.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """LM1B input pipeline.""" 17 | 18 | import itertools 19 | 20 | from init2winit.dataset_lib import data_utils 21 | from init2winit.dataset_lib import librispeech_input_pipeline 22 | from init2winit.dataset_lib.data_utils import Dataset 23 | import jax 24 | from ml_collections.config_dict import config_dict 25 | import numpy as np 26 | 27 | MAX_INPUT_LENGTH = 320000 28 | MAX_TARGET_LENGTH = 256 29 | VOCAB_SIZE = 1024 30 | 31 | DEFAULT_HPARAMS = config_dict.ConfigDict( 32 | dict( 33 | max_input_length=MAX_INPUT_LENGTH, 34 | max_target_length=MAX_TARGET_LENGTH, 35 | train_split='train_clean100+train_clean360+train_other500', 36 | eval_split='dev_clean+dev_other', 37 | test_split='test_clean', 38 | input_shape=[(MAX_INPUT_LENGTH,), (MAX_INPUT_LENGTH,)], 39 | output_shape=(-1, VOCAB_SIZE), 40 | train_size=281241, 41 | tokenizer_vocab_path='', 42 | tokenizer_type='SPM')) 43 | 44 | METADATA = {'apply_one_hot_in_loss': False} 45 | 46 | 47 | def _batch_to_dict(batch): 48 | batch_np = data_utils.tf_to_numpy(batch) 49 | return batch_np 50 | 51 | 52 | def get_librispeech(shuffle_rng, batch_size, eval_batch_size=None, hps=None): 53 | """Wrapper to conform to the general dataset API.""" 54 | process_count = jax.process_count() 55 | if batch_size % process_count != 0: 56 | raise ValueError('process_count={} must divide batch_size={}.'.format( 57 | process_count, batch_size)) 58 | 59 | per_host_batch_size = batch_size // process_count 60 | if eval_batch_size is None: 61 | eval_batch_size = batch_size 62 | 63 | if eval_batch_size % process_count != 0: 64 | raise ValueError('process_count={} must divide eval_batch_size={}.'.format( 65 | process_count, eval_batch_size)) 66 | per_host_eval_batch_size = eval_batch_size // process_count 67 | 68 | return _get_librispeech(hps, per_host_batch_size, per_host_eval_batch_size, 69 | shuffle_rng) 70 | 71 | 72 | def _get_librispeech(hps, per_host_batch_size, per_host_eval_batch_size, 73 | shuffle_rng): 74 | """Data generators for lm1b.""" 75 | n_devices = jax.local_device_count() 76 | if per_host_batch_size % n_devices != 0: 77 | raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format( 78 | n_devices, per_host_batch_size)) 79 | 80 | if per_host_eval_batch_size % n_devices != 0: 81 | raise ValueError( 82 | 'n_devices={} 
must divide per_host_eval_batch_size={}.'.format( 83 | n_devices, per_host_eval_batch_size)) 84 | 85 | train_ds, eval_ds, test_ds = librispeech_input_pipeline.get_librispeech_datasets( 86 | hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng) 87 | 88 | def train_iterator_fn(): 89 | for batch in iter(train_ds): 90 | yield _batch_to_dict(batch) 91 | 92 | def eval_train_epoch(num_batches=None): 93 | eval_train_iter = iter(train_ds) 94 | for batch in itertools.islice(eval_train_iter, num_batches): 95 | yield _batch_to_dict(batch) 96 | 97 | def valid_epoch(num_batches=None): 98 | valid_iter = iter(eval_ds) 99 | for batch in itertools.islice(valid_iter, num_batches): 100 | batch = _batch_to_dict(batch) 101 | yield data_utils.maybe_pad_batch( 102 | batch, desired_batch_size=per_host_eval_batch_size, padding_value=1.0) 103 | 104 | def test_epoch(num_batches=None): 105 | test_iter = iter(test_ds) 106 | for batch in itertools.islice(test_iter, num_batches): 107 | batch = _batch_to_dict(batch) 108 | yield data_utils.maybe_pad_batch( 109 | batch, desired_batch_size=per_host_eval_batch_size, padding_value=1.0) 110 | 111 | # pylint: enable=unreachable 112 | return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) 113 | 114 | 115 | def get_fake_batch(hps): 116 | return { 117 | 'inputs': 118 | np.ones((hps.batch_size, hps.max_input_length), 119 | dtype=hps.model_dtype), 120 | 'input_paddings': 121 | np.ones((hps.batch_size, hps.max_input_length), 122 | dtype=hps.model_dtype), 123 | 'targets': 124 | np.ones((hps.batch_size, hps.max_target_length), 125 | dtype=hps.model_dtype), 126 | 'target_paddings': 127 | np.ones((hps.batch_size, hps.max_target_length), 128 | dtype=hps.model_dtype), 129 | } 130 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/mlperf_imagenet_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ImageNet input pipeline with MLPerf preprocessing.""" 17 | 18 | import itertools 19 | 20 | from init2winit.dataset_lib import data_utils 21 | from init2winit.dataset_lib import imagenet_dataset 22 | from init2winit.dataset_lib import mlperf_input_pipeline 23 | import jax 24 | from ml_collections.config_dict import config_dict 25 | import numpy as np 26 | import tensorflow.compat.v2 as tf 27 | 28 | 29 | DEFAULT_HPARAMS = config_dict.ConfigDict(dict( 30 | input_shape=(224, 224, 3), 31 | output_shape=(1000,), 32 | train_size=1281167, 33 | valid_size=50000, 34 | test_size=10000, # ImageNet-v2. 35 | use_imagenetv2_test=True)) 36 | 37 | METADATA = { 38 | 'apply_one_hot_in_loss': False, 39 | } 40 | 41 | 42 | def get_mlperf_imagenet(rng, 43 | batch_size, 44 | eval_batch_size, 45 | hps=None): 46 | """Data generators for imagenet. 
47 | 48 | Args: 49 | rng: RNG seed that is split into a shuffle seed and a seed that is folded 50 | into a per-example seed. 51 | batch_size: the *global* batch size used for training. 52 | eval_batch_size: the *global* batch size used for evaluation. 53 | hps: the hparams for the experiment, only required field is valid_size. 54 | 55 | Returns: 56 | A data_utils.Dataset for the MLPerf version of ImageNet. 57 | """ 58 | if batch_size % jax.device_count() != 0: 59 | raise ValueError( 60 | 'Require batch_size % jax.device_count(), received ' 61 | 'batch_size={}, device_count={}.'.format( 62 | batch_size, jax.device_count())) 63 | if eval_batch_size % jax.device_count() != 0: 64 | raise ValueError( 65 | 'Require eval_batch_size % jax.device_count(), received ' 66 | 'eval_batch_size={}, device_count={}.'.format( 67 | eval_batch_size, jax.device_count())) 68 | host_batch_size = batch_size // jax.process_count() 69 | eval_host_batch_size = eval_batch_size // jax.process_count() 70 | 71 | max_eval_steps = hps.valid_size // eval_batch_size + 1 72 | 73 | input_dtype = tf.bfloat16 74 | shuffle_buffer_size = 16384 75 | 76 | train_ds = mlperf_input_pipeline.load_split( 77 | host_batch_size, 78 | dtype=input_dtype, 79 | split='train', 80 | rng=rng, 81 | shuffle_size=shuffle_buffer_size) 82 | 83 | eval_train_ds = mlperf_input_pipeline.load_split( 84 | host_batch_size, 85 | dtype=input_dtype, 86 | split='eval_train', 87 | rng=rng, 88 | shuffle_size=shuffle_buffer_size) 89 | 90 | eval_ds = mlperf_input_pipeline.load_split( 91 | eval_host_batch_size, 92 | dtype=input_dtype, 93 | split='validation', 94 | rng=rng, 95 | shuffle_size=shuffle_buffer_size) 96 | 97 | # We do not have TFRecords of ImageNet-v2 in the same format as the 98 | # train/validation splits above, so we reuse the same test split from the 99 | # non-MLPerf pipeline. 100 | test_ds = None 101 | if hps.use_imagenetv2_test: 102 | test_ds = imagenet_dataset.load_split( 103 | eval_host_batch_size, 104 | 'test', 105 | hps=hps, 106 | image_size=224, 107 | tfds_dataset_name='imagenet_v2/matched-frequency') 108 | 109 | # We cannot use tfds.as_numpy because this calls tensor.numpy() which does an 110 | # additional copy of the tensor, instead we call tensor._numpy() below. 111 | def train_iterator_fn(): 112 | return data_utils.iterator_as_numpy(iter(train_ds)) 113 | 114 | def eval_train_epoch(num_batches=None): 115 | if num_batches is None: 116 | num_batches = 0 117 | eval_train_iter = iter(eval_train_ds) 118 | np_iter = data_utils.iterator_as_numpy( 119 | itertools.islice(eval_train_iter, num_batches)) 120 | for batch in np_iter: 121 | yield data_utils.maybe_pad_batch(batch, eval_host_batch_size) 122 | 123 | def valid_epoch(num_batches=None): 124 | if num_batches is None: 125 | num_batches = max_eval_steps 126 | valid_iter = iter(eval_ds) 127 | np_iter = data_utils.iterator_as_numpy( 128 | itertools.islice(valid_iter, num_batches)) 129 | for batch in np_iter: 130 | yield data_utils.maybe_pad_batch(batch, eval_host_batch_size) 131 | 132 | def test_epoch(num_batches=None): 133 | if test_ds: 134 | test_iter = iter(test_ds) 135 | np_iter = data_utils.iterator_as_numpy( 136 | itertools.islice(test_iter, num_batches)) 137 | for batch in np_iter: 138 | yield data_utils.maybe_pad_batch(batch, eval_host_batch_size) 139 | else: 140 | # pylint: disable=unreachable 141 | return 142 | yield # This yield is needed to make this a valid (null) iterator. 
143 | # pylint: enable=unreachable 144 | 145 | return data_utils.Dataset( 146 | train_iterator_fn, 147 | eval_train_epoch, 148 | valid_epoch, 149 | test_epoch) 150 | 151 | 152 | def get_fake_batch(hps): 153 | return { 154 | 'inputs': 155 | np.ones((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype), 156 | 'targets': 157 | np.ones((hps.batch_size, *hps.output_shape), dtype=hps.model_dtype), 158 | 'weights': 159 | np.ones((hps.batch_size,), dtype=hps.model_dtype), 160 | } 161 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/nanodo_data_loader_shared.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Common utils for nanodo data loaders. 17 | 18 | Implementation based on: 19 | https://github.com/google-deepmind/nanodo/blob/main/nanodo/data.py 20 | """ 21 | 22 | from collections.abc import Mapping, Sequence 23 | import dataclasses 24 | import enum 25 | from typing import Iterable, Iterator, Union 26 | 27 | import grain.python as grain 28 | import jax 29 | import jax.numpy as jnp 30 | 31 | import numpy as np 32 | 33 | import sentencepiece as spm 34 | 35 | 36 | PAD_ID = 0 37 | 38 | 39 | class Preprocess(enum.Enum): 40 | NOAM_PACKED = 1 41 | PADDED = 2 42 | 43 | 44 | ### pure python helpers for use with grain ### 45 | # Need this because we can't pickle SentencePieceProcessor object 46 | class SPTokenizer: 47 | """Wrapper class for SentencePiece tokenizer.""" 48 | 49 | def __init__(self, vocab_path): 50 | self._tokenizer = None 51 | self._vocab_path = vocab_path 52 | 53 | def get_tokenizer(self) -> spm.SentencePieceProcessor: 54 | if not self._tokenizer: 55 | self._tokenizer = get_py_tokenizer(self._vocab_path) 56 | return self._tokenizer 57 | 58 | 59 | class SentencePieceByteTokenizer(spm.SentencePieceProcessor): 60 | """A simple Byte level tokenizer.""" 61 | 62 | def eos_id(self) -> int: 63 | return 1 64 | 65 | def bos_id(self) -> int: 66 | return 2 67 | 68 | def pad_id(self) -> int: 69 | return PAD_ID 70 | 71 | def GetPieceSize(self) -> int: 72 | return 256 73 | 74 | # pylint: disable=invalid-name 75 | def EncodeAsIds(self, text: Union[bytes, str]) -> list[int]: 76 | if isinstance(text, str): 77 | return list(bytes(text, 'utf-8')) 78 | if isinstance(text, bytes): 79 | return [int(x) for x in text] 80 | raise ValueError(f'Invalid text: {text} type={type(text)}') 81 | 82 | def DecodeIds(self, ids: Iterable[int]) -> str: 83 | return bytes(ids).decode('utf-8') 84 | # pylint: enable=invalid-name 85 | 86 | 87 | def py_tokenize( 88 | features: Mapping[str, str], 89 | spt: SPTokenizer, 90 | pad_len: int | None = None, 91 | pad_id: int = PAD_ID, 92 | ) -> Sequence[int]: 93 | """Tokenizes text into ids, optionally pads or truncates to pad_len.""" 94 | text = features['text'] 95 | tokenizer = spt.get_tokenizer() 96 | bos_id = tokenizer.bos_id() 97 
| eos_id = tokenizer.eos_id() 98 | ids = tokenizer.EncodeAsIds(text) 99 | 100 | ids.insert(0, bos_id) 101 | ids.append(eos_id) 102 | if pad_len is not None: 103 | if len(ids) < pad_len: 104 | ids.extend([pad_id] * (pad_len - len(ids))) 105 | elif len(ids) > pad_len: 106 | ids = ids[:pad_len] 107 | return ids 108 | 109 | 110 | @dataclasses.dataclass 111 | class NoamPack: 112 | """Pygrain operation for tokenizing and Noam packing text.""" 113 | 114 | context_size: int 115 | 116 | def __call__( 117 | self, idseq_iterator: Iterator[grain.Record] 118 | ) -> Iterator[grain.Record]: 119 | packed_ids = [] 120 | for input_record in idseq_iterator: 121 | start = 0 122 | while start < len(input_record.data): 123 | rem_data = input_record.data[start:] 124 | if len(packed_ids) + len(rem_data) < self.context_size: 125 | packed_ids.extend(rem_data) # use rest of example, move-on 126 | break 127 | else: 128 | take = self.context_size - len(packed_ids) 129 | packed_ids.extend(rem_data[:take]) 130 | last_record_key = input_record.metadata.remove_record_key() 131 | yield grain.Record( 132 | last_record_key, np.array(packed_ids, dtype=np.int32) 133 | ) 134 | start += take 135 | packed_ids = [] 136 | # Drop remainder for simplicity. 137 | # We lose the rest of the example on restore. 138 | 139 | 140 | # pylint: disable=invalid-name 141 | 142 | 143 | def get_py_tokenizer(path: str) -> spm.SentencePieceProcessor: 144 | if not path: 145 | # byte tokenizer shortcut 146 | return SentencePieceByteTokenizer() 147 | sp = spm.SentencePieceProcessor() 148 | sp.Load(path) 149 | assert sp.pad_id() == PAD_ID 150 | assert sp.eos_id() != -1 151 | assert sp.bos_id() != -1 152 | return sp 153 | 154 | 155 | def get_in_out( 156 | in_BxL: jax.Array, 157 | pad_id: int = PAD_ID, 158 | ) -> tuple[jax.Array, jax.Array, jax.Array]: 159 | """Returns input, output, and weights for a batch of examples.""" 160 | # Assumes input of the form for eval. 161 | x_BxL = in_BxL 162 | y_BxL = jnp.pad( 163 | in_BxL[:, 1:], 164 | ((0, 0), (0, 1)), 165 | mode='constant', 166 | constant_values=pad_id, 167 | ) 168 | weights_BxL = jnp.where(y_BxL != pad_id, 1, 0).astype(jnp.float32) 169 | 170 | return x_BxL, y_BxL, weights_BxL 171 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/nqm_noise.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Data generators for init2winit.""" 17 | 18 | from init2winit.dataset_lib import data_utils 19 | import jax.random 20 | from ml_collections.config_dict import config_dict 21 | import numpy as np 22 | 23 | 24 | NQM_HPARAMS = config_dict.ConfigDict( 25 | dict( 26 | train_size=1e10, 27 | valid_size=0, 28 | test_size=0, 29 | input_shape=(100,), # This determines the dimension. 
30 | output_shape=(1,), 31 | )) 32 | NQM_METADATA = { 33 | 'apply_one_hot_in_loss': False, 34 | } 35 | 36 | 37 | def get_nqm_noise(shuffle_rng, batch_size, eval_batch_size, hps=None): 38 | """Returns the noise seed for the nqm model. 39 | 40 | NOTE: This dataset is only meant to be used with the nqm model. 41 | This just generates isotropic Gaussian noise of the desired dimension. 42 | The nqm model will then multiple this noise by a matrix D, with the properly 43 | that D^T D = C. This yields noise with gradient covariance C. 44 | 45 | Args: 46 | shuffle_rng: Not used. 47 | batch_size: The global train batch size, used to determine the batch size 48 | yielded from train_epoch(). 49 | eval_batch_size: Not used. 50 | hps: Hparams object. We only refer to hps.input_shape to determine the 51 | dimension of the noise. 52 | Returns: 53 | train_epoch, eval_train_epoch, valid_epoch, test_epoch: three generators. 54 | Only train_epoch is used. 55 | """ 56 | del eval_batch_size 57 | 58 | per_host_batch_size = batch_size // jax.process_count() 59 | # Should train_rng / eval_rng possibly have different seeds? 60 | seed = data_utils.convert_jax_to_tf_random_seed(shuffle_rng) 61 | train_rng = np.random.RandomState(seed=seed) 62 | eval_rng = np.random.RandomState(seed=seed) 63 | 64 | def train_iterator_fn(): 65 | while True: 66 | yield { 67 | 'inputs': train_rng.normal( 68 | size=(per_host_batch_size, *hps.input_shape) 69 | ) 70 | } 71 | 72 | def eval_train_epoch(num_batches): 73 | for _ in range(num_batches): 74 | yield { 75 | 'inputs': eval_rng.normal( 76 | size=(per_host_batch_size, *hps.input_shape) 77 | ) 78 | } 79 | 80 | # pylint: disable=unreachable 81 | def valid_epoch(*args, **kwargs): 82 | del args 83 | del kwargs 84 | return 85 | yield # This yield is needed to make this a valid (null) iterator. 86 | # pylint: enable=unreachable 87 | 88 | # pylint: disable=unreachable 89 | def test_epoch(*args, **kwargs): 90 | del args 91 | del kwargs 92 | return 93 | yield # This yield is needed to make this a valid (null) iterator. 94 | 95 | # pylint: enable=unreachable 96 | 97 | return data_utils.Dataset( 98 | train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch 99 | ) 100 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/test_data_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Unit tests for datasets.py.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from init2winit.dataset_lib import data_utils 21 | import numpy as np 22 | 23 | desired_batch_size = 23 24 | batch_axes = [0, 0, 3, 2] 25 | test_names = ['default', 'NHWC', 'HWCN', 'HWNC'] 26 | image_formats = [None, 'NHWC', 'HWCN', 'HWNC'] 27 | batch_size = 13 28 | width = 11 29 | num_channels = 3 30 | input_shapes = [ 31 | (batch_size, width, width, num_channels), 32 | (batch_size, width, width, num_channels), 33 | (width, width, num_channels, batch_size), 34 | (width, width, batch_size, num_channels), 35 | ] 36 | test_parameters = zip(test_names, image_formats, batch_axes, input_shapes) 37 | 38 | 39 | class DataUtilsTest(parameterized.TestCase): 40 | """Unit tests for datasets.py.""" 41 | 42 | @parameterized.named_parameters(*test_parameters) 43 | def test_padding(self, image_format, batch_axis, input_shape): 44 | """Test that the shape is the expected padded shape.""" 45 | batch = {'inputs': np.ones(input_shape)} 46 | padded_batch = data_utils.maybe_pad_batch( 47 | batch, desired_batch_size, image_format) 48 | expected_shapes = list(input_shape) 49 | expected_shapes[batch_axis] = desired_batch_size 50 | self.assertEqual(padded_batch['inputs'].shape, tuple(expected_shapes)) 51 | self.assertEqual(padded_batch['weights'].shape, (desired_batch_size,)) 52 | 53 | def test_padding_seq2seq(self): 54 | """Test padding for sequence-to-sequence models.""" 55 | input_len_max = 25 56 | input_len_true = 22 # true input_seq_length for each example in batch. 57 | target_len_max = 25 58 | target_len_true = 21 # true target_seq_length for each example in batch. 59 | 60 | inputs_shape = (batch_size, input_len_max) 61 | targets_shape = (batch_size, target_len_max) 62 | batch = {'inputs': np.ones(inputs_shape), 'targets': np.ones(targets_shape)} 63 | batch['inputs'][:, input_len_true:] = 0 # zero-pad extra inputs tokens 64 | batch['targets'][:, target_len_true:] = 0 # zero-pad extra targets tokens 65 | expected_inputs_shape = (desired_batch_size, input_len_max) 66 | expected_targets_shape = (desired_batch_size, target_len_max) 67 | expected_weights_shape = (desired_batch_size, target_len_max) 68 | padded_batch = data_utils.maybe_pad_batch( 69 | batch, desired_batch_size, data_format=None, mask_key='targets') 70 | self.assertEqual(padded_batch['inputs'].shape, expected_inputs_shape) 71 | self.assertEqual(padded_batch['targets'].shape, expected_targets_shape) 72 | self.assertEqual(padded_batch['weights'].shape, expected_weights_shape) 73 | 74 | batch_pad = desired_batch_size - batch_size 75 | expected_weights_array = np.ones((desired_batch_size, target_len_max)) 76 | # pad at batch axis 77 | expected_weights_array[-batch_pad:] = 0 78 | # # pad at sequence_len axis 79 | expected_weights_array[:, target_len_true:] = 0 80 | self.assertTrue( 81 | np.array_equal(padded_batch['weights'], expected_weights_array)) 82 | 83 | 84 | if __name__ == '__main__': 85 | absltest.main() 86 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/test_small_image_datasets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for init2winit.dataset_lib.small_image_datasets.""" 17 | 18 | import itertools 19 | 20 | from absl.testing import absltest 21 | from init2winit.dataset_lib import small_image_datasets 22 | from jax import random 23 | from ml_collections.config_dict import config_dict 24 | 25 | 26 | class SmallImageDatasetsTest(absltest.TestCase): 27 | """Unit tests for small_image_datasets.py.""" 28 | 29 | def test_cifar10(self): 30 | """Test example generation in CIFAR10 is reproducible.""" 31 | dataset = small_image_datasets.get_cifar10( 32 | random.PRNGKey(0), 1, 1, 33 | config_dict.ConfigDict( 34 | dict( 35 | flip_probability=0.5, 36 | alpha=1.0, 37 | crop_num_pixels=4, 38 | use_mixup=True, 39 | train_size=45000, 40 | valid_size=5000, 41 | test_size=10000, 42 | include_example_keys=True, 43 | input_shape=(32, 32, 3), 44 | output_shape=(10,)))) 45 | 46 | examples = itertools.islice(dataset.valid_epoch(), 10) 47 | example_keys = [ 48 | example['example_key'][0].decode('utf-8') for example in examples 49 | ] 50 | self.assertEqual(example_keys, [ 51 | 'cifar10-train.array_record-00000-of-00001__45000', 52 | 'cifar10-train.array_record-00000-of-00001__45001', 53 | 'cifar10-train.array_record-00000-of-00001__45002', 54 | 'cifar10-train.array_record-00000-of-00001__45003', 55 | 'cifar10-train.array_record-00000-of-00001__45004', 56 | 'cifar10-train.array_record-00000-of-00001__45005', 57 | 'cifar10-train.array_record-00000-of-00001__45006', 58 | 'cifar10-train.array_record-00000-of-00001__45007', 59 | 'cifar10-train.array_record-00000-of-00001__45008', 60 | 'cifar10-train.array_record-00000-of-00001__45009', 61 | ]) 62 | 63 | 64 | if __name__ == '__main__': 65 | absltest.main() 66 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/test_wikitext_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for init2winit.dataset_lib.wikitext103.""" 17 | 18 | from absl.testing import absltest 19 | from init2winit.dataset_lib import wikitext_tokenizer 20 | import tensorflow as tf 21 | 22 | 23 | 24 | class TestWikitextTokenizer(absltest.TestCase): 25 | """Unit tests for wikitext103.py.""" 26 | 27 | def test_tokenizer_vocab_size(self): 28 | """Test vocab size. 29 | 30 | Vocab size should be number of unique words in text file + 2 for the 31 | and tokens. 
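    (The "+ 2" accounts for the two special tokens the tokenizer adds on top
    of the words that appear in the file.)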
32 | """ 33 | # Get number of unique tokens from tokenizer. 34 | text_dataset = tf.data.TextLineDataset(file_name) 35 | 36 | tokenizer = wikitext_tokenizer.Tokenizer() 37 | tokenizer.train(text_dataset) 38 | 39 | num_unique_tokens = len(tokenizer.dictionary.idx2word) 40 | 41 | # Get number of unique words from fake data. 42 | with open(file_name, 'r') as f: 43 | data = '' 44 | for line in f: 45 | # Not removing this would count tokens like '\n\n' and '\n' while the 46 | # TextLineDataset strips them. 47 | line = line.strip('\n') 48 | data = data + line 49 | 50 | words = data.split(' ') 51 | num_unique_words = len(set(words)) 52 | 53 | self.assertEqual(num_unique_tokens, num_unique_words + 2) 54 | 55 | if __name__ == '__main__': 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/wikitext103.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Module containing hyperparameters, metadata and dataset getter for Wikitext-103 dataset.""" 17 | 18 | import itertools 19 | 20 | from init2winit.dataset_lib import data_utils 21 | from init2winit.dataset_lib import wikitext103_input_pipeline as input_pipeline 22 | from init2winit.dataset_lib import wikitext2_input_pipeline 23 | import jax 24 | from ml_collections.config_dict import config_dict 25 | import numpy as np 26 | 27 | PAD_ID = wikitext2_input_pipeline.PAD_ID 28 | Dataset = data_utils.Dataset 29 | 30 | VOCAB_SIZE = 267735 31 | 32 | DEFAULT_HPARAMS = config_dict.ConfigDict( 33 | dict( 34 | sequence_length=128, 35 | max_target_length=128, 36 | max_eval_target_length=128, 37 | eval_sequence_length=128, 38 | input_shape=(128,), 39 | output_shape=(input_pipeline.WORD_VOCAB_SIZE,), 40 | train_size=800210, # Number of sequences. 41 | tokenizer='word', 42 | tokenizer_vocab_path=None, 43 | vocab_size=input_pipeline.WORD_VOCAB_SIZE, 44 | )) 45 | 46 | 47 | METADATA = { 48 | 'apply_one_hot_in_loss': True, 49 | 'shift_inputs': True, 50 | 'causal': True, 51 | 'pad_token': -1, 52 | } 53 | 54 | 55 | def add_weights_to_batch(batch, pad_id: int = PAD_ID): 56 | """Add weights for the input values so that paddings have 0 weight. 57 | 58 | Args: 59 | batch: Batch represented by dict containing 'inputs' and 'targets'. 60 | pad_id: Value for 'inputs' that will have weight 0. 61 | 62 | Returns: 63 | batch with weights 64 | """ 65 | batch['weights'] = np.where(batch['inputs'] == pad_id, 0.0, 1.0) 66 | return batch 67 | 68 | 69 | def get_wikitext103( 70 | shuffle_rng, 71 | batch_size: int, 72 | eval_batch_size: int = None, 73 | hps: config_dict.ConfigDict = None, 74 | pad_id: int = PAD_ID) -> Dataset: 75 | """Returns Wikitext-103 Dataset. 
76 | 77 | Args: 78 | shuffle_rng: jax.random.PRNGKey 79 | batch_size: training batch size 80 | eval_batch_size: validation batch size 81 | hps: Hyper parameters 82 | pad_id: Value for 'inputs' that will have weight 0. 83 | 84 | Returns: 85 | Dataset 86 | 87 | Raises: 88 | ValueError: If batch_size is not divisible by jax process count. 89 | ValueError: If eval_batch_size is not divisible by jax process count. 90 | """ 91 | process_count = jax.process_count() 92 | 93 | if batch_size % process_count != 0: 94 | raise ValueError( 95 | 'process_count={} must divide batch_size={}.'.format( 96 | process_count, batch_size)) 97 | 98 | if eval_batch_size is None: 99 | eval_batch_size = batch_size 100 | 101 | if eval_batch_size % process_count != 0: 102 | raise ValueError( 103 | 'process_count={} must divide batch_size={}.'.format( 104 | process_count, batch_size)) 105 | 106 | train_dataset, eval_train_dataset, valid_dataset, test_dataset = ( 107 | input_pipeline.get_wikitext103_dataset( 108 | hps, 109 | train_batch_size=batch_size, 110 | valid_batch_size=eval_batch_size, 111 | test_batch_size=eval_batch_size, 112 | shuffle_seed=data_utils.convert_jax_to_tf_random_seed(shuffle_rng), 113 | ) 114 | ) 115 | 116 | def train_iterator_fn(): 117 | for batch in train_dataset: 118 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch), pad_id) 119 | 120 | def eval_train_epoch(num_batches=None): 121 | for batch in itertools.islice(iter(eval_train_dataset), num_batches): 122 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch), pad_id) 123 | 124 | def valid_epoch(num_batches=None): 125 | for batch in itertools.islice(iter(valid_dataset), num_batches): 126 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch), pad_id) 127 | 128 | def test_epoch(num_batches=None): 129 | for batch in itertools.islice(iter(test_dataset), num_batches): 130 | yield add_weights_to_batch(data_utils.tf_to_numpy(batch), pad_id) 131 | 132 | return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, 133 | test_epoch) 134 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/wikitext103_spm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Module containing hyperparameters, metadata and dataset getter for Wikitext-103 dataset.""" 17 | 18 | import functools 19 | from init2winit.dataset_lib import wikitext103 20 | from init2winit.dataset_lib import wikitext103_input_pipeline 21 | from ml_collections.config_dict import config_dict 22 | 23 | SPM_TOKENIZER_VOCAB_SIZE = wikitext103_input_pipeline.SPM_TOKENIZER_VOCAB_SIZE 24 | SPM_TOKENIZER_VOCAB_PATH = wikitext103_input_pipeline.SPM_TOKENIZER_VOCAB_PATH 25 | PAD_ID = -1 26 | get_wikitext103 = functools.partial(wikitext103.get_wikitext103, pad_id=PAD_ID) 27 | 28 | DEFAULT_HPARAMS = config_dict.ConfigDict( 29 | dict( 30 | sequence_length=128, 31 | max_target_length=128, 32 | max_eval_target_length=128, 33 | eval_sequence_length=128, 34 | input_shape=(128,), 35 | output_shape=(SPM_TOKENIZER_VOCAB_SIZE,), 36 | tokenizer='sentencepiece', 37 | tokenizer_vocab_path=SPM_TOKENIZER_VOCAB_PATH, 38 | vocab_size=SPM_TOKENIZER_VOCAB_SIZE, 39 | train_size=800210, # TODO(kasimbeg): Update this 40 | ) 41 | ) 42 | 43 | 44 | METADATA = { 45 | 'apply_one_hot_in_loss': True, 46 | 'shift_inputs': True, 47 | 'causal': True, 48 | 'pad_token': -1, 49 | } 50 | -------------------------------------------------------------------------------- /init2winit/dataset_lib/wikitext2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Module containing hyperparameters, metadata and dataset getter for Wikitext-2 dataset.""" 17 | 18 | import itertools 19 | 20 | from init2winit.dataset_lib import data_utils 21 | from init2winit.dataset_lib import wikitext2_input_pipeline as input_pipeline 22 | from init2winit.dataset_lib.data_utils import Dataset 23 | from init2winit.dataset_lib.wikitext2_input_pipeline import PAD_ID 24 | import jax 25 | from ml_collections.config_dict import config_dict 26 | import numpy as np 27 | 28 | VOCAB_SIZE = 33278 29 | 30 | DEFAULT_HPARAMS = config_dict.ConfigDict( 31 | dict( 32 | sequence_length=34, 33 | max_target_length=34, 34 | max_eval_target_length=34, 35 | input_shape=(34,), 36 | output_shape=(VOCAB_SIZE,), 37 | vocab_size=VOCAB_SIZE, 38 | # TODO(kasimbeg) : add vocab path after seperating out tokenizer 39 | # vocab_path=None, 40 | train_size=59676 # Number of sequences. 41 | )) 42 | 43 | METADATA = { 44 | 'apply_one_hot_in_loss': True, 45 | 'shift_inputs': True, 46 | 'causal': True, 47 | } 48 | 49 | 50 | def add_weights_to_batch(batch, pad_id: int = PAD_ID): 51 | """Add weights for the input values so that paddings have 0 weight. 52 | 53 | Args: 54 | batch: Batch represented by dict containing 'inputs' and 'targets'. 55 | pad_id: Value for 'inputs' that will have weight 0. 
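
  For example (an illustrative sketch), if batch['inputs'] contains
  [[17, 52, PAD_ID]], the added batch['weights'] would be [[1.0, 1.0, 0.0]].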

  Returns:
    batch with weights
  """
  batch['weights'] = np.where(batch['inputs'] == pad_id, 0.0, 1.0)
  return batch


def get_wikitext2(
    data_rng,
    batch_size: int,
    eval_batch_size: int = None,
    hps: config_dict.ConfigDict = None,) -> Dataset:
  """Returns Wikitext-2 Dataset.

  Args:
    data_rng: jax.random.PRNGKey
    batch_size: training batch size
    eval_batch_size: validation batch size
    hps: Hyper parameters

  Returns:
    Dataset

  Raises:
    ValueError: If batch_size is not divisible by jax process count.
    ValueError: If eval_batch_size is not divisible by jax process count.
  """
  process_count = jax.process_count()

  if batch_size % process_count != 0:
    raise ValueError(
        'process_count={} must divide batch_size={}.'.format(
            process_count, batch_size))

  if eval_batch_size is None:
    eval_batch_size = batch_size

  if eval_batch_size % process_count != 0:
    raise ValueError(
        'process_count={} must divide eval_batch_size={}.'.format(
            process_count, eval_batch_size))

  train_dataset, eval_train_dataset, valid_dataset, test_dataset = input_pipeline.get_wikitext2_dataset(
      hps,
      train_batch_size=batch_size,
      valid_batch_size=eval_batch_size,
      test_batch_size=eval_batch_size,
      shuffle_seed=data_utils.convert_jax_to_tf_random_seed(data_rng),
  )

  def train_iterator_fn():
    for batch in train_dataset:
      yield add_weights_to_batch(data_utils.tf_to_numpy(batch))

  def eval_train_epoch(num_batches=None):
    for batch in itertools.islice(iter(eval_train_dataset), num_batches):
      yield add_weights_to_batch(data_utils.tf_to_numpy(batch))

  def valid_epoch(num_batches=None):
    for batch in itertools.islice(iter(valid_dataset), num_batches):
      yield add_weights_to_batch(data_utils.tf_to_numpy(batch))

  def test_epoch(num_batches=None):
    for batch in itertools.islice(iter(test_dataset), num_batches):
      yield add_weights_to_batch(data_utils.tf_to_numpy(batch))

  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                 test_epoch)
--------------------------------------------------------------------------------
/init2winit/dataset_lib/wikitext_tokenizer.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2025 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains Tokenizer class for word level tokenization.

Note that the current tokenization workflow is not yet optimized for time and
memory.
"""

import tensorflow as tf

# Special tokens for end-of-sentence and out-of-vocabulary words.
EOS_TOKEN = b'<eos>'
UNKNOWN_TOKEN = b'<unk>'


class _Dictionary:
  """Dictionary contains word-to-id mappings and id-to-word mappings.
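
  For example (an illustrative sketch), calling add_word(b'the') and then
  add_word(b'cat') on a fresh dictionary leaves word2idx as
  {b'the': 1, b'cat': 2} and idx2word as [b'the', b'cat'], i.e. ids start
  at 1.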
31 | 32 | Attributes: 33 | word2idx: dict containing key-values where keys are words and values are 34 | tokens. 35 | idx2word: list where the index of each word in the list is the token value. 36 | """ 37 | 38 | def __init__(self): 39 | self.word2idx = {} 40 | self.idx2word = [] 41 | 42 | def add_word(self, word): 43 | if word not in self.word2idx: 44 | self.idx2word.append(word) 45 | # Start the first token idx at 1, because 0 is reserved for special tokens 46 | # e.g. for padding and masking 47 | self.word2idx[word] = len(self.idx2word) 48 | return self.word2idx[word] 49 | 50 | def __len__(self): 51 | return len(self.idx2word) 52 | 53 | 54 | class Tokenizer: 55 | """Tokenizer object for word level tokenization from words to unique ids. 56 | 57 | Attributes: 58 | dictionary: Dictionary containing word-to-id and id-to-word mappings 59 | lookup_table: tf.lookup.StaticHashTable for looking up token ids from words 60 | """ 61 | 62 | def __init__(self): 63 | self.dictionary = _Dictionary() 64 | 65 | def train(self, dataset: tf.data.TextLineDataset): 66 | """Trains a Tokenizer from a TextLineDataset.""" 67 | # Add words to the dictionary 68 | self.dictionary.add_word(UNKNOWN_TOKEN) # add default unknown token 69 | for line in dataset: 70 | words = line.numpy().split() + [EOS_TOKEN] 71 | for word in words: 72 | self.dictionary.add_word(word) 73 | # Make static vocabulary table for tf.data style tokenization 74 | self.lookup_table = tf.lookup.StaticHashTable( 75 | tf.lookup.KeyValueTensorInitializer( 76 | tf.constant(list(self.dictionary.word2idx.keys()), dtype=tf.string), 77 | tf.constant( 78 | list(self.dictionary.word2idx.values()), dtype=tf.int32 79 | ), 80 | ), 81 | default_value=self.dictionary.word2idx[UNKNOWN_TOKEN], 82 | ) 83 | 84 | def tokenize(self, input_tensor: tf.Tensor) -> tf.Tensor: 85 | """Tokenizes a tensor of UTF-8 strings. 86 | 87 | Args: 88 | input_tensor: A `RaggedTensor` or `Tensor` of UTF-8 strings with any 89 | shape. 90 | 91 | Returns: 92 | A `RaggedTensor` or `Tensor` of tokenized text. The returned shape is 93 | the shape of the input tensor. 94 | """ 95 | eos_tensor = tf.constant([EOS_TOKEN], dtype=tf.string) 96 | input_tensor_split = tf.strings.split(input_tensor) 97 | input_tensor_extended = tf.concat([input_tensor_split, eos_tensor], axis=-1) 98 | return self.lookup_table.lookup(input_tensor_extended) 99 | -------------------------------------------------------------------------------- /init2winit/gradient_statistics_callback.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Callback for computing gradient statistics given set of params. 
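
Concretely, the callback below re-computes the gradient at the current
parameters for a configurable number of training batches and accumulates the
per-parameter mean and standard deviation of those mini-batch gradients,
which are saved to disk with the checkpointing utilities.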
17 | """ 18 | 19 | import functools 20 | import itertools 21 | import os 22 | 23 | import flax.linen as nn 24 | from init2winit import base_callback 25 | from init2winit import checkpoint 26 | from init2winit.dataset_lib import data_utils 27 | import jax 28 | import jax.numpy as jnp 29 | 30 | 31 | class GradientStatisticsCallback(base_callback.BaseCallBack): 32 | """Runs evals on MT models with datasets/params different than in training.""" 33 | 34 | def __init__(self, 35 | model, 36 | params, 37 | batch_stats, 38 | optimizer_state, 39 | optimizer_update_fn, 40 | dataset, 41 | hps, 42 | callback_config, 43 | train_dir, 44 | rng, 45 | mesh): 46 | del optimizer_state 47 | del optimizer_update_fn 48 | 49 | self.dataset = dataset 50 | self.model = model 51 | self.hps = hps 52 | self.callback_config = callback_config 53 | self.rng = rng 54 | self.save_path = os.path.join(train_dir, 'gradient_statistics/') 55 | self.mesh = mesh 56 | 57 | self.num_batches_in_training_epoch = ( 58 | self.hps.train_size // self.hps.batch_size 59 | ) 60 | if callback_config is not None: 61 | if 'num_batches_in_training_epoch' in callback_config.keys(): 62 | self.num_batches_in_training_epoch = callback_config[ 63 | 'num_batches_in_training_epoch' 64 | ] 65 | 66 | self.num_updates = 0 67 | 68 | def update(params, batch, batch_stats, dropout_rng): 69 | def opt_cost(params): 70 | return self.model.training_cost( 71 | params, 72 | batch=batch, 73 | batch_stats=batch_stats, 74 | dropout_rng=dropout_rng, 75 | ) 76 | 77 | grad_fn = jax.value_and_grad(opt_cost, has_aux=True) 78 | _, grad = grad_fn(params) 79 | 80 | return grad 81 | 82 | params_sharding = jax.tree_util.tree_map( 83 | lambda x: x.sharding, params 84 | ) 85 | batch_stats_sharding = nn.get_sharding(batch_stats, self.mesh) 86 | 87 | self.jitted_update = jax.jit( 88 | update, 89 | in_shardings=( 90 | params_sharding, 91 | jax.sharding.NamedSharding( 92 | self.mesh, jax.sharding.PartitionSpec('devices')), 93 | batch_stats_sharding, 94 | None 95 | ), 96 | out_shardings=(params_sharding) 97 | ) 98 | 99 | def run_eval(self, params, batch_stats, optimizer_state, global_step): 100 | """Computes gradient statistics from mini batches over full training data. 
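
    For each parameter tensor the statistics are accumulated from N mini-batch
    gradients g_i as running sums (a restatement of the loop below):

      grad_mean = (1 / N) * sum_i g_i
      grad_std = sqrt((1 / N) * sum_i g_i ** 2 - grad_mean ** 2)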
101 | """ 102 | del optimizer_state 103 | train_iter = itertools.islice( 104 | self.dataset.train_iterator_fn(), self.num_batches_in_training_epoch 105 | ) 106 | 107 | grad_sum = jax.tree.map(jnp.zeros_like, params) 108 | grad_squared_sum = jax.tree.map(jnp.zeros_like, params) 109 | self.num_updates = 0 110 | 111 | make_global_array_fn = functools.partial( 112 | data_utils.make_global_array, mesh=self.mesh 113 | ) 114 | 115 | for batch in train_iter: 116 | sharded_batch = jax.tree_util.tree_map(make_global_array_fn, batch) 117 | grads = self.jitted_update(params, sharded_batch, batch_stats, self.rng) 118 | 119 | grad_sum = jax.tree_util.tree_map( 120 | lambda g_sum, g: g_sum + g, grad_sum, grads 121 | ) 122 | 123 | grad_squared_sum = jax.tree_util.tree_map( 124 | lambda g_squared, g: g_squared + g**2, grad_squared_sum, grads 125 | ) 126 | 127 | self.num_updates += 1 128 | 129 | grad_mean = jax.tree_util.tree_map( 130 | lambda g_sum: g_sum / self.num_updates, grad_sum 131 | ) 132 | grad_std = jax.tree_util.tree_map( 133 | lambda g_squared, g_mean: jnp.sqrt( # pylint: disable=g-long-lambda 134 | g_squared / self.num_updates - g_mean**2 135 | ), 136 | grad_squared_sum, 137 | grad_mean, 138 | ) 139 | 140 | state = dict( 141 | grad_std=jax.device_get(grad_std), 142 | grad_mean=jax.device_get(grad_mean), 143 | step=global_step 144 | ) 145 | 146 | checkpoint.save_checkpoint( 147 | self.save_path, 148 | step=global_step, 149 | state=state, 150 | prefix='measurement_', 151 | max_to_keep=None) 152 | 153 | return {} 154 | -------------------------------------------------------------------------------- /init2winit/init_lib/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /init2winit/init_lib/initializers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Registry for the available initializers we can test. 
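
For example, get_initializer('sparse_init') returns the sparse_init function
registered in _ALL_INITIALIZERS below, and get_initializer_hparams returns the
matching default hyperparameter ConfigDict.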
17 | 18 | API of an initializer: 19 | new_params = init(loss, init_params, hps, num_outputs, input_shape) 20 | 21 | TODO(gilmer, gdahl, schsam): The API of an initializer should in general be 22 | aware of the moments of the data. Currently we assume that all input coordinates 23 | are iid standard normal distributions. 24 | """ 25 | 26 | from init2winit.init_lib import meta_init 27 | from init2winit.init_lib import sparse_init 28 | from ml_collections.config_dict import config_dict 29 | 30 | 31 | # This function is meant to match the general API of an initializer 32 | # pylint: disable=unused-argument 33 | def noop( 34 | loss_fn=None, 35 | flax_module=None, 36 | params=None, 37 | hps=None, 38 | input_shape=None, 39 | output_shape=None, 40 | rng_key=None, 41 | metrics_logger=None, 42 | ): 43 | """No-op init.""" 44 | return params 45 | # pylint: enable=unused-argument 46 | 47 | DEFAULT_HPARAMS = config_dict.ConfigDict() 48 | 49 | _ALL_INITIALIZERS = { 50 | 'noop': (noop, DEFAULT_HPARAMS), 51 | 'meta_init': (meta_init.meta_init, meta_init.DEFAULT_HPARAMS), 52 | 'sparse_init': (sparse_init.sparse_init, sparse_init.DEFAULT_HPARAMS), 53 | } 54 | 55 | 56 | def get_initializer(initializer_name): 57 | """Get the corresponding initializer function based on the initializer string. 58 | 59 | API of an initializer: 60 | init_fn, hparams = get_initializer(init) 61 | new_params, final_l = init_fn(loss, init_params, hps, 62 | num_outputs, input_shape) 63 | 64 | Args: 65 | initializer_name: (str) e.g. default. 66 | 67 | Returns: 68 | initializer 69 | Raises: 70 | ValueError if model is unrecognized. 71 | """ 72 | try: 73 | return _ALL_INITIALIZERS[initializer_name][0] 74 | except KeyError: 75 | raise ValueError('Unrecognized initializer: {}'.format(initializer_name)) 76 | 77 | 78 | def get_initializer_hparams(initializer_name): 79 | """Get the corresponding hyperparameters based on the initializer string. 80 | 81 | Args: 82 | initializer_name: (str) e.g. default. 83 | 84 | Returns: 85 | hps 86 | Raises: 87 | ValueError if model is unrecognized. 88 | """ 89 | try: 90 | return _ALL_INITIALIZERS[initializer_name][1] 91 | except KeyError: 92 | raise ValueError('Unrecognized initializer: {}'.format(initializer_name)) 93 | -------------------------------------------------------------------------------- /init2winit/init_lib/sparse_init.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Defines the SparseInit initializer. 17 | 18 | This initializer limits the number of non-zero incoming and outgoing connection 19 | weights. For more information, see Section 5 of (Martens, 2010), which can be 20 | found at https://www.cs.toronto.edu/~jmartens/docs/Deep_HessianFree.pdf. 
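
Concretely, for each dense kernel the initializer below randomly selects
`non_zero_connection_weights` outgoing weights for every input unit and
`non_zero_connection_weights` incoming weights for every output unit, zeroes
every weight that was not selected, and sets biases to 0.5 for tanh hidden
layers (0.0 otherwise).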
21 | """ 22 | 23 | from flax.core import frozen_dict 24 | from flax.core import unfreeze 25 | import jax 26 | from ml_collections.config_dict import config_dict 27 | import numpy as np 28 | 29 | DEFAULT_HPARAMS = config_dict.ConfigDict(dict(non_zero_connection_weights=15,)) 30 | 31 | 32 | def sparse_init(loss_fn, 33 | flax_module, 34 | params, 35 | hps, 36 | input_shape, 37 | output_shape, 38 | rng_key, 39 | metrics_logger=None, 40 | log_every=10): 41 | """Implements SparseInit initializer. 42 | 43 | Args: 44 | loss_fn: Loss function. 45 | flax_module: Flax nn.Module class. 46 | params: The dict of model parameters. 47 | hps: HParam object. Required hparams are meta_learning_rate, 48 | meta_batch_size, meta_steps, and epsilon. 49 | input_shape: Must agree with batch[0].shape[1:]. 50 | output_shape: Must agree with batch[1].shape[1:]. 51 | rng_key: jax.PRNGKey, used to seed all randomness. 52 | metrics_logger: Instance of utils.MetricsLogger 53 | log_every: Print meta loss every k steps. 54 | 55 | Returns: 56 | A Flax model with sparse initialization. 57 | """ 58 | 59 | del flax_module, loss_fn, input_shape, output_shape, metrics_logger, log_every 60 | 61 | params = unfreeze(params) 62 | activation_functions = hps.activation_function 63 | num_hidden_layers = len(hps.hid_sizes) 64 | if isinstance(hps.activation_function, str): 65 | activation_functions = [hps.activation_function] * num_hidden_layers 66 | for i, key in enumerate(params): 67 | num_units_in, num_units_out = params[key]['kernel'].shape 68 | 69 | mask = np.full((num_units_in, num_units_out), True, dtype=bool) 70 | 71 | # Restrict the number of non-zero weights from input units. 72 | rng_key, *rng_keys_in = jax.random.split(rng_key, num_units_in + 1) 73 | for k in range(num_units_in): 74 | if num_units_out > hps.non_zero_connection_weights: 75 | non_zero_units_out = jax.random.choice( 76 | rng_keys_in[k], num_units_out, (hps.non_zero_connection_weights,), 77 | replace=False) 78 | mask[k, non_zero_units_out] = False 79 | else: 80 | mask[k, :] = False 81 | 82 | # Restrict the number of non-zero weights to output units. 83 | rng_key, *rng_keys_out = jax.random.split(rng_key, num_units_out + 1) 84 | for k in range(num_units_out): 85 | if num_units_in > hps.non_zero_connection_weights: 86 | non_zero_units_in = jax.random.choice( 87 | rng_keys_out[k], num_units_in, (hps.non_zero_connection_weights,), 88 | replace=False) 89 | mask[non_zero_units_in, k] = False 90 | else: 91 | mask[:, k] = False 92 | params[key]['kernel'] = params[key]['kernel'].at[mask].set(0.0) 93 | 94 | if i < num_hidden_layers and activation_functions[i] == 'tanh': 95 | params[key]['bias'] = params[key]['bias'].at[:].set(0.5) 96 | else: 97 | params[key]['bias'] = params[key]['bias'].at[:].set(0.0) 98 | return frozen_dict.freeze(params) 99 | -------------------------------------------------------------------------------- /init2winit/model_lib/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /init2winit/model_lib/adabelief_vgg.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Flax implementation of Adabelief VGG. 17 | 18 | This module ports the Adabelief implemetation of VGG to Flax. The 19 | Adabelief paper and github can be found here: 20 | 21 | https://arxiv.org/abs/2010.07468 22 | 23 | https://github.com/juntang-zhuang/Adabelief-Optimizer/blob/update_0.2.0/PyTorch_Experiments/classification_cifar10/models/vgg.py 24 | 25 | The original VGGNet paper can be found here: 26 | 27 | https://arxiv.org/abs/1409.1556 28 | """ 29 | 30 | import functools 31 | 32 | from flax import linen as nn 33 | from init2winit.model_lib import base_model 34 | from init2winit.model_lib import model_utils 35 | import jax.numpy as jnp 36 | from ml_collections.config_dict import config_dict 37 | 38 | 39 | DEFAULT_HPARAMS = config_dict.ConfigDict( 40 | dict( 41 | num_layers=11, # Must be one of [11, 13, 16, 19] 42 | layer_rescale_factors={}, 43 | lr_hparams={ 44 | 'schedule': 'constant', 45 | 'base_lr': 0.2, 46 | }, 47 | normalizer='none', 48 | optimizer='momentum', 49 | opt_hparams={ 50 | 'momentum': 0.9, 51 | }, 52 | batch_size=128, 53 | l2_decay_factor=0.0001, 54 | l2_decay_rank_threshold=2, 55 | label_smoothing=None, 56 | rng_seed=-1, 57 | use_shallue_label_smoothing=False, 58 | model_dtype='float32', 59 | grad_clip=None, 60 | )) 61 | 62 | 63 | def classifier(x, num_outputs, dropout_rate, deterministic): 64 | """Implements the classification portion of the network.""" 65 | 66 | x = nn.Dropout(rate=dropout_rate, deterministic=deterministic)(x) 67 | x = nn.Dense(512)(x) 68 | x = nn.relu(x) 69 | x = nn.Dropout(rate=dropout_rate, deterministic=deterministic)(x) 70 | x = nn.Dense(512)(x) 71 | x = nn.relu(x) 72 | x = nn.Dense(num_outputs)(x) 73 | return x 74 | 75 | 76 | def features(x, num_layers, normalizer, dtype, train): 77 | """Implements the feature extraction portion of the network.""" 78 | 79 | layers = _layer_size_options[num_layers] 80 | conv = functools.partial(nn.Conv, use_bias=False, dtype=dtype) 81 | maybe_normalize = model_utils.get_normalizer(normalizer, train) 82 | for l in layers: 83 | if l == 'M': 84 | x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) 85 | else: 86 | x = conv(features=l, kernel_size=(3, 3), padding=((1, 1), (1, 1)))(x) 87 | x = 
maybe_normalize()(x) 88 | x = nn.relu(x) 89 | return x 90 | 91 | 92 | class VGG(nn.Module): 93 | """Adabelief VGG.""" 94 | num_layers: int 95 | num_outputs: int 96 | normalizer: str = 'none' 97 | dtype: str = 'float32' 98 | 99 | @nn.compact 100 | def __call__(self, x, train): 101 | x = features(x, self.num_layers, self.normalizer, self.dtype, train) 102 | x = jnp.reshape(x, (x.shape[0], -1)) 103 | x = classifier( 104 | x, self.num_outputs, dropout_rate=0.5, deterministic=not train) 105 | return x 106 | 107 | 108 | # Specifies the sequence of layers in the feature extraction section of the 109 | # network for a given size. 110 | # The numbers indicate the feature size of a convolutional layer, the 111 | # letter M indicates a max pooling layer. 112 | _layer_size_options = { 113 | 1: [ 114 | 8, 'M', 16, 'M', 32, 32, 'M', 64, 64, 'M', 64, 64, 'M' 115 | ], # used for testing only. 116 | 11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 117 | 13: [ 118 | 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M' 119 | ], 120 | 16: [ 121 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 122 | 512, 512, 'M' 123 | ], 124 | 19: [ 125 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 126 | 'M', 512, 512, 512, 512, 'M' 127 | ], 128 | } 129 | 130 | 131 | # pylint: disable=[missing-class-docstring] 132 | class AdaBeliefVGGModel(base_model.BaseModel): 133 | def build_flax_module(self): 134 | """Adabelief VGG.""" 135 | return VGG( 136 | num_layers=self.hps.num_layers, 137 | num_outputs=self.hps['output_shape'][-1], 138 | dtype=self.hps.model_dtype, 139 | normalizer=self.hps.normalizer) 140 | 141 | def get_fake_inputs(self, hps): 142 | """Helper method solely for the purpose of initialzing the model.""" 143 | dummy_inputs = [ 144 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype) 145 | ] 146 | return dummy_inputs 147 | -------------------------------------------------------------------------------- /init2winit/model_lib/autoencoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Fully connected autoencoder. 17 | 18 | This model builds an autoencoder using FullyConnected module. 19 | More information on the fully connected autoencoder model can be found here: 20 | 21 | https://www.cs.toronto.edu/~hinton/science.pdf 22 | 23 | """ 24 | 25 | from init2winit.model_lib import base_model 26 | from init2winit.model_lib.fully_connected import FullyConnected 27 | from jax.nn import initializers 28 | import jax.numpy as jnp 29 | from ml_collections.config_dict import config_dict 30 | 31 | # small test hparams 32 | # https://blog.keras.io/building-autoencoders-in-keras.html 33 | # for the configuration of a standard fully connected autoencoder model, 34 | # see https://www.cs.toronto.edu/~hinton/science.pdf. 
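# The default hid_sizes below form a symmetric encoder/decoder with a 32-unit
# bottleneck (128 -> 64 -> 32 -> 64 -> 128); the final output layer size is
# taken from hps.output_shape when the model is built.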
35 | DEFAULT_HPARAMS = config_dict.ConfigDict( 36 | dict( 37 | hid_sizes=[128, 64, 32, 64, 128], 38 | activation_function=['relu', 'relu', 'relu', 'relu', 'relu'], 39 | kernel_scales=[1.0] * 6, 40 | lr_hparams={ 41 | 'base_lr': 0.1, 42 | 'schedule': 'constant' 43 | }, 44 | layer_rescale_factors={}, 45 | optimizer='hessian_free', 46 | opt_hparams={ 47 | 'cg_max_iter': 250, 48 | 'cg_iter_tracking_method': 'back_tracking', 49 | 'use_line_search': True, 50 | 'init_damping': 50.0, 51 | 'damping_ub': 10 ** 2, 52 | 'damping_lb': 10 ** -6, 53 | }, 54 | batch_size=128, 55 | label_smoothing=None, 56 | rng_seed=-1, 57 | use_shallue_label_smoothing=False, 58 | model_dtype='float32', 59 | l2_decay_factor=2e-5, 60 | l2_decay_rank_threshold=1, 61 | )) 62 | 63 | 64 | class AutoEncoderModel(base_model.BaseModel): 65 | """Model class for AutoEncoder model.""" 66 | 67 | def build_flax_module(self): 68 | kernel_inits = [ 69 | initializers.normal(scale) 70 | for scale in self.hps.kernel_scales 71 | ] 72 | 73 | return FullyConnected( 74 | num_outputs=self.hps['output_shape'][-1], 75 | hid_sizes=self.hps.hid_sizes, 76 | activation_function=self.hps.activation_function, 77 | kernel_inits=kernel_inits) 78 | 79 | def get_fake_inputs(self, hps): 80 | """Helper method solely for the purpose of initialzing the model.""" 81 | dummy_inputs = [ 82 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype) 83 | ] 84 | return dummy_inputs 85 | -------------------------------------------------------------------------------- /init2winit/model_lib/convolutional_autoencoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Convolutional autoencoder. 17 | 18 | This model uses a convolutional encoder-decoder network to reconstruct input 19 | images as outputs. 
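
The encoder stacks Conv + max_pool blocks and the decoder mirrors them with
ConvTranspose blocks; the filter, kernel and window sizes are configured by
the `encoder` and `decoder` dicts in DEFAULT_HPARAMS below.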

"""

from typing import Any, Dict, Sequence

from flax import linen as nn
from init2winit.model_lib import base_model
from init2winit.model_lib import model_utils
from jax import numpy as jnp
from ml_collections.config_dict import config_dict


# small test hparams from
# https://blog.keras.io/building-autoencoders-in-keras.html
DEFAULT_HPARAMS = config_dict.ConfigDict(
    dict(
        encoder={
            'filter_sizes': [16, 8, 8],
            'kernel_sizes': [(3, 3), (3, 3), (3, 3)],
            'kernel_paddings': ['SAME', 'SAME', 'SAME'],
            'window_sizes': [(2, 2), (2, 2), (2, 2)],
            'window_paddings': ['SAME', 'SAME', 'SAME'],
            'strides': [(2, 2), (2, 2), (2, 2)],
            'activations': ['relu', 'relu', 'relu'],
        },
        decoder={
            'filter_sizes': [8, 8, 16, 1],
            'kernel_sizes': [(3, 3), (3, 3), (3, 3), (3, 3)],
            'window_sizes': [(2, 2), (2, 2), (2, 2), None],
            'paddings': ['SAME', ((1, 0), (1, 0)), 'SAME', 'SAME'],
            'activations': ['relu', 'relu', 'relu', 'id'],
        },

        activation_function='relu',
        lr_hparams={
            'base_lr': 0.02,
            'schedule': 'constant'
        },
        layer_rescale_factors={},
        optimizer='momentum',
        opt_hparams={
            'momentum': 0,
        },
        batch_size=128,
        l2_decay_factor=None,
        l2_decay_rank_threshold=0,
        label_smoothing=None,
        rng_seed=-1,
        use_shallue_label_smoothing=False,
        model_dtype='float32',
        grad_clip=None,
    ))


class ConvAutoEncoder(nn.Module):
  """Defines a convolutional autoencoder.

  The model assumes the input data has shape
  [batch_size_per_device, *input_shape] where input_shape is an image shape
  with a trailing channel dimension. The encoder applies Conv + max_pool
  blocks and the decoder applies ConvTranspose blocks, as configured by the
  `encoder` and `decoder` dicts.
80 | """ 81 | output_shape: Sequence[int] 82 | encoder: Dict[str, Any] 83 | decoder: Dict[str, Any] 84 | 85 | @nn.compact 86 | def __call__(self, x, train): 87 | del train 88 | encoder_keys = [ 89 | 'filter_sizes', 90 | 'kernel_sizes', 91 | 'kernel_paddings', 92 | 'window_sizes', 93 | 'window_paddings', 94 | 'strides', 95 | 'activations', 96 | ] 97 | if len(set(len(self.encoder[k]) for k in encoder_keys)) > 1: 98 | raise ValueError( 99 | 'The elements in encoder dict do not have the same length.') 100 | 101 | decoder_keys = [ 102 | 'filter_sizes', 103 | 'kernel_sizes', 104 | 'window_sizes', 105 | 'paddings', 106 | 'activations', 107 | ] 108 | if len(set(len(self.decoder[k]) for k in decoder_keys)) > 1: 109 | raise ValueError( 110 | 'The elements in decoder dict do not have the same length.') 111 | 112 | # encoder 113 | for i in range(len(self.encoder['filter_sizes'])): 114 | x = nn.Conv( 115 | self.encoder['filter_sizes'][i], 116 | self.encoder['kernel_sizes'][i], 117 | padding=self.encoder['kernel_paddings'][i])(x) 118 | x = model_utils.ACTIVATIONS[self.encoder['activations'][i]](x) 119 | x = nn.max_pool( 120 | x, self.encoder['window_sizes'][i], 121 | strides=self.encoder['strides'][i], 122 | padding=self.encoder['window_paddings'][i]) 123 | 124 | # decoder 125 | for i in range(len(self.decoder['filter_sizes'])): 126 | x = nn.ConvTranspose( 127 | self.decoder['filter_sizes'][i], 128 | self.decoder['kernel_sizes'][i], 129 | self.decoder['window_sizes'][i], 130 | padding=self.decoder['paddings'][i])(x) 131 | x = model_utils.ACTIVATIONS[self.decoder['activations'][i]](x) 132 | return x 133 | 134 | 135 | # pylint: disable=[missing-class-docstring] 136 | class ConvAutoEncoderModel(base_model.BaseModel): 137 | 138 | def build_flax_module(self): 139 | return ConvAutoEncoder( 140 | output_shape=self.hps.output_shape, 141 | encoder=self.hps.encoder, 142 | decoder=self.hps.decoder) 143 | 144 | def get_fake_inputs(self, hps): 145 | """Helper method solely for the purpose of initialzing the model.""" 146 | dummy_inputs = [ 147 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype) 148 | ] 149 | return dummy_inputs 150 | -------------------------------------------------------------------------------- /init2winit/model_lib/fully_connected.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Simple fully connected feedforward neural network classifier.""" 17 | import copy 18 | from typing import Any, Tuple 19 | 20 | from flax import linen as nn 21 | from init2winit.model_lib import base_model 22 | from init2winit.model_lib import model_utils 23 | from jax.nn import initializers 24 | import jax.numpy as jnp 25 | from ml_collections.config_dict import config_dict 26 | 27 | 28 | # small hparams used for unit tests 29 | DEFAULT_HPARAMS = config_dict.ConfigDict( 30 | dict( 31 | hid_sizes=[20, 10], 32 | kernel_scales=[1.0, 1.0, 1.0], 33 | lr_hparams={ 34 | 'base_lr': 0.1, 35 | 'schedule': 'constant' 36 | }, 37 | layer_rescale_factors={}, 38 | optimizer='momentum', 39 | opt_hparams={ 40 | 'momentum': 0.9, 41 | }, 42 | batch_size=128, 43 | total_accumulated_batch_size=None, 44 | activation_function='relu', 45 | l2_decay_factor=.0005, 46 | l2_decay_rank_threshold=2, 47 | label_smoothing=None, 48 | rng_seed=-1, 49 | use_shallue_label_smoothing=False, 50 | model_dtype='float32', 51 | grad_clip=None, 52 | )) 53 | 54 | 55 | class FullyConnected(nn.Module): 56 | """Defines a fully connected neural network. 57 | 58 | The model assumes the input data has shape 59 | [batch_size_per_device, *input_shape] where input_shape may be of arbitrary 60 | rank. The model flatten the input before applying a dense layer. 61 | """ 62 | num_outputs: int 63 | hid_sizes: Tuple[int] 64 | activation_function: Any 65 | kernel_inits: Tuple[model_utils.Initializer] 66 | bias_init: model_utils.Initializer = initializers.zeros 67 | 68 | @nn.compact 69 | def __call__(self, x, train): 70 | del train 71 | if not isinstance(self.activation_function, str): 72 | if len(self.activation_function) != len(self.hid_sizes): 73 | raise ValueError( 74 | 'The number of activation functions must be equal to the number ' 75 | 'of hidden layers') 76 | activation_function = copy.deepcopy(self.activation_function) 77 | else: 78 | activation_function = [self.activation_function] * len(self.hid_sizes) 79 | 80 | x = jnp.reshape(x, (x.shape[0], -1)) 81 | for i, (num_hid, init) in enumerate( 82 | zip(self.hid_sizes, self.kernel_inits[:-1])): 83 | x = nn.Dense(num_hid, kernel_init=init, bias_init=self.bias_init)(x) 84 | x = model_utils.ACTIVATIONS[activation_function[i]](x) 85 | x = nn.Dense( 86 | self.num_outputs, 87 | kernel_init=self.kernel_inits[-1], 88 | bias_init=self.bias_init)(x) 89 | return x 90 | 91 | 92 | # pylint: disable=missing-class-docstring 93 | class FullyConnectedModel(base_model.BaseModel): 94 | """Model class for fully connected model.""" 95 | 96 | def build_flax_module(self): 97 | kernel_inits = [ 98 | initializers.variance_scaling(scale, 'fan_in', 'truncated_normal') 99 | for scale in self.hps.kernel_scales 100 | ] 101 | return FullyConnected( 102 | num_outputs=self.hps['output_shape'][-1], 103 | hid_sizes=tuple(self.hps.hid_sizes), 104 | activation_function=self.hps.activation_function, 105 | kernel_inits=tuple(kernel_inits)) 106 | 107 | def get_fake_inputs(self, hps): 108 | """Helper method solely for the purpose of initialzing the model.""" 109 | dummy_inputs = [ 110 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype) 111 | ] 112 | return dummy_inputs 113 | -------------------------------------------------------------------------------- /init2winit/model_lib/max_pooling_cnn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Max pooling convnet classifier. 17 | 18 | This model can be used to implement the 3c3d architecture from: 19 | https://github.com/fsschneider/DeepOBS/blob/master/deepobs/tensorflow/testproblems/_3c3d.py 20 | """ 21 | from typing import Any, Sequence 22 | 23 | from flax import linen as nn 24 | from init2winit.model_lib import base_model 25 | from init2winit.model_lib import model_utils 26 | from jax.nn import initializers 27 | import jax.numpy as jnp 28 | 29 | from ml_collections.config_dict import config_dict 30 | 31 | 32 | # small hparams used for unit tests 33 | DEFAULT_HPARAMS = config_dict.ConfigDict(dict( 34 | num_filters=[64, 96, 128], 35 | kernel_sizes=[5, 3, 3], 36 | kernel_paddings=['VALID', 'VALID', 'SAME'], 37 | window_sizes=[3, 3, 3], 38 | window_paddings=['SAME', 'SAME', 'SAME'], 39 | strides=[2, 2, 2], 40 | num_dense_units=[512, 256], 41 | lr_hparams={ 42 | 'base_lr': 0.001, 43 | 'schedule': 'constant' 44 | }, 45 | layer_rescale_factors={}, 46 | optimizer='momentum', 47 | opt_hparams={ 48 | 'momentum': 0.9, 49 | }, 50 | batch_size=128, 51 | activation_fn='relu', 52 | normalizer='none', 53 | l2_decay_factor=.0005, 54 | l2_decay_rank_threshold=2, 55 | label_smoothing=None, 56 | rng_seed=-1, 57 | use_shallue_label_smoothing=False, 58 | model_dtype='float32', 59 | grad_clip=None, 60 | total_accumulated_batch_size=None, 61 | )) 62 | 63 | 64 | class MaxPoolingCNN(nn.Module): 65 | """Defines a CNN model with max pooling. 66 | 67 | The model assumes the input shape is [batch, H, W, C]. 
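
  With the default hyperparameters above (conv blocks with 64/96/128 filters
  and dense layers of 512 and 256 units) this is intended to correspond to the
  3c3d test problem referenced in the module docstring.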
68 | """ 69 | num_outputs: int 70 | num_filters: Sequence[int] 71 | kernel_sizes: Sequence[int] 72 | kernel_paddings: Sequence[str] 73 | window_sizes: Sequence[int] 74 | window_paddings: Sequence[str] 75 | strides: Sequence[int] 76 | num_dense_units: int 77 | activation_fn: Any 78 | normalizer: str = 'none' 79 | kernel_init: model_utils.Initializer = initializers.lecun_normal() 80 | bias_init: model_utils.Initializer = initializers.zeros 81 | 82 | @nn.compact 83 | def __call__(self, x, train): 84 | maybe_normalize = model_utils.get_normalizer(self.normalizer, train) 85 | iterator = zip( 86 | self.num_filters, self.kernel_sizes, self.kernel_paddings, 87 | self.window_sizes, self.window_paddings, self.strides) 88 | for num_filters, kernel_size, kernel_padding, window_size, window_padding, stride in iterator: 89 | x = nn.Conv( 90 | num_filters, (kernel_size, kernel_size), (1, 1), 91 | padding=kernel_padding, 92 | kernel_init=self.kernel_init, 93 | bias_init=self.bias_init)(x) 94 | x = model_utils.ACTIVATIONS[self.activation_fn](x) 95 | x = maybe_normalize()(x) 96 | x = nn.max_pool( 97 | x, 98 | window_shape=(window_size, window_size), 99 | strides=(stride, stride), 100 | padding=window_padding) 101 | x = jnp.reshape(x, (x.shape[0], -1)) 102 | for num_units in self.num_dense_units: 103 | x = nn.Dense( 104 | num_units, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) 105 | x = model_utils.ACTIVATIONS[self.activation_fn](x) 106 | x = maybe_normalize()(x) 107 | x = nn.Dense( 108 | self.num_outputs, 109 | kernel_init=self.kernel_init, 110 | bias_init=self.bias_init)(x) 111 | return x 112 | 113 | 114 | class MaxPoolingCNNModel(base_model.BaseModel): 115 | """Model class for MaxPooling CNNModel.""" 116 | 117 | def build_flax_module(self): 118 | """CNN with a set of conv layers with max pooling followed by fully connected layers.""" 119 | return MaxPoolingCNN( 120 | num_outputs=self.hps['output_shape'][-1], 121 | num_filters=self.hps.num_filters, 122 | kernel_sizes=self.hps.kernel_sizes, 123 | kernel_paddings=self.hps.kernel_paddings, 124 | window_sizes=self.hps.window_sizes, 125 | window_paddings=self.hps.window_paddings, 126 | strides=self.hps.strides, 127 | num_dense_units=self.hps.num_dense_units, 128 | activation_fn=self.hps.activation_fn, 129 | normalizer=self.hps.normalizer) 130 | 131 | def get_fake_inputs(self, hps): 132 | """Helper method solely for the purpose of initialzing the model.""" 133 | dummy_inputs = [ 134 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype) 135 | ] 136 | return dummy_inputs 137 | -------------------------------------------------------------------------------- /init2winit/model_lib/partition_tree.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Contains functions to partition a parameter pytree.""" 17 | 18 | 19 | def outer_key(x): 20 | return x[0] 21 | 22 | 23 | def create_partition_flat_params_fn(key_map): 24 | """Partitions a flattened pytree according to the provided key_map. 25 | 26 | Subsets are determined by the kep_map which hashes the flattened model 27 | parameter keys into disjount groups. For example, if the flattened param 28 | tree is 29 | 30 | {('a', 'b'): 1.0, 31 | ('a', 'c'): 1.0, 32 | ('d', 'b'): 2.0} 33 | 34 | And we partition on the out_key then the output is 35 | {'a': {('a', 'b'): 1.0, ('a', 'c'): 1.0} 36 | 'd': {('d', 'b'): 2.0}}. 37 | 38 | Args: 39 | key_map: Maps a tuple of strings to a hashable value. 40 | 41 | Returns: 42 | partition_flat_params, a function which returns a partitioned param 43 | dictionary. 44 | """ 45 | def partition_flat_params(flat_params): 46 | subparam_groups = {} 47 | for tup in flat_params: 48 | mapped_key = key_map(tup) 49 | if mapped_key not in subparam_groups: 50 | subparam_groups[mapped_key] = {} 51 | subparam_groups[mapped_key][tup] = flat_params[tup] 52 | 53 | return subparam_groups 54 | return partition_flat_params 55 | 56 | 57 | registry = { 58 | 'outer_key': create_partition_flat_params_fn(outer_key), 59 | } 60 | 61 | 62 | def get_param_partition_fn(name): 63 | return registry[name] 64 | 65 | 66 | # Used in test_model_debugger.py 67 | def get_test_group(params): 68 | del params 69 | return ['B_0/C_0', 'C_0'] 70 | 71 | 72 | skip_analysis_registry = { 73 | 'test_group': get_test_group, 74 | } 75 | 76 | 77 | def get_skip_analysis_fn(name): 78 | return skip_analysis_registry[name] 79 | 80 | -------------------------------------------------------------------------------- /init2winit/model_lib/simple_cnn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Simple convnet classifier.""" 17 | from typing import Sequence 18 | 19 | from flax import linen as nn 20 | from init2winit.model_lib import base_model 21 | from init2winit.model_lib import model_utils 22 | from jax.nn import initializers 23 | import jax.numpy as jnp 24 | 25 | from ml_collections.config_dict import config_dict 26 | 27 | 28 | # small hparams used for unit tests 29 | DEFAULT_HPARAMS = config_dict.ConfigDict(dict( 30 | num_filters=[20, 10], 31 | kernel_sizes=[3, 3], 32 | lr_hparams={ 33 | 'base_lr': 0.001, 34 | 'schedule': 'constant' 35 | }, 36 | layer_rescale_factors={}, 37 | optimizer='momentum', 38 | opt_hparams={ 39 | 'momentum': 0.9, 40 | }, 41 | batch_size=128, 42 | activation_function='relu', 43 | l2_decay_factor=.0005, 44 | l2_decay_rank_threshold=2, 45 | label_smoothing=None, 46 | rng_seed=-1, 47 | use_shallue_label_smoothing=False, 48 | model_dtype='float32', 49 | )) 50 | 51 | 52 | class SimpleCNN(nn.Module): 53 | """Defines a simple CNN model. 
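
  Each entry of num_filters/kernel_sizes adds one Conv + activation block; the
  resulting features are flattened and passed to a single Dense layer with
  num_outputs units.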
54 | 55 | The model assumes the input shape is [batch, H, W, C]. 56 | """ 57 | num_outputs: int 58 | num_filters: Sequence[int] 59 | kernel_sizes: Sequence[int] 60 | activation_function: int 61 | kernel_init: model_utils.Initializer = initializers.lecun_normal() 62 | bias_init: model_utils.Initializer = initializers.zeros 63 | 64 | @nn.compact 65 | def __call__(self, x, train): 66 | for num_filters, kernel_size in zip(self.num_filters, self.kernel_sizes): 67 | x = nn.Conv( 68 | num_filters, (kernel_size, kernel_size), (1, 1), 69 | kernel_init=self.kernel_init, 70 | bias_init=self.bias_init)(x) 71 | x = model_utils.ACTIVATIONS[self.activation_function](x) 72 | x = jnp.reshape(x, (x.shape[0], -1)) 73 | x = nn.Dense( 74 | self.num_outputs, 75 | kernel_init=self.kernel_init, 76 | bias_init=self.bias_init)(x) 77 | return x 78 | 79 | 80 | class SimpleCNNModel(base_model.BaseModel): 81 | """Model class for Simple CNN Model.""" 82 | 83 | def build_flax_module(self): 84 | """Simple CNN with a set of conv layers followed by fully connected layers.""" 85 | return SimpleCNN( 86 | num_outputs=self.hps['output_shape'][-1], 87 | num_filters=self.hps.num_filters, 88 | kernel_sizes=self.hps.kernel_sizes, 89 | activation_function=self.hps.activation_function) 90 | 91 | def get_fake_inputs(self, hps): 92 | """Helper method solely for the purpose of initialzing the model.""" 93 | dummy_inputs = [ 94 | jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype) 95 | ] 96 | return dummy_inputs 97 | -------------------------------------------------------------------------------- /init2winit/mt_eval/eval_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """BLEU calculation utilities.""" 17 | 18 | import os 19 | import pathlib 20 | 21 | from init2winit import checkpoint 22 | import jax 23 | import sacrebleu 24 | from tensorflow.io import gfile 25 | 26 | exists = gfile.exists 27 | glob = gfile.glob 28 | 29 | 30 | def compute_bleu_from_predictions(predictions, references, language_code, name): 31 | """Computes BLEU score given predictions and references.""" 32 | sacrebleu_tokenizer = 'zh' if language_code == 'zh' else sacrebleu.DEFAULT_TOKENIZER 33 | bleu_score = sacrebleu.corpus_bleu( 34 | predictions, [references], tokenize=sacrebleu_tokenizer).score 35 | return {name: bleu_score} 36 | 37 | 38 | def get_eval_fpath(ckpt_dir, ckpt_step, eval_split): 39 | output_dir = str(pathlib.Path(ckpt_dir).parents[0]) 40 | return os.path.join(output_dir, 'bleu_' + eval_split + '_' + str(ckpt_step)) 41 | 42 | 43 | def load_evals(ckpt_dir, ckpt_step, eval_split): 44 | """Loads results if already available, else return None.""" 45 | ckpt_eval_fpath = get_eval_fpath(ckpt_dir, ckpt_step, eval_split) 46 | if not exists(ckpt_eval_fpath): 47 | return None 48 | else: 49 | with gfile.GFile(ckpt_eval_fpath, 'r') as f: 50 | bleu_score = f.readlines()[-1] 51 | return float(bleu_score.strip()) 52 | 53 | 54 | def save_evals(ckpt_dir, ckpt_step, eval_split, bleu_score): 55 | ckpt_eval_fpath = get_eval_fpath(ckpt_dir, ckpt_step, eval_split) 56 | with gfile.GFile(ckpt_eval_fpath, 'w') as f: 57 | f.write(str(bleu_score)) 58 | 59 | 60 | def _load_checkpoint(checkpoint_path, params): 61 | """Load model (and batch stats) from checkpoint.""" 62 | target = dict( 63 | params=params, 64 | global_step=-1, 65 | preemption_count=0, 66 | sum_train_cost=0.0) 67 | ckpt = checkpoint.load_checkpoint( 68 | checkpoint_path, 69 | target=target, 70 | ) 71 | params = ckpt['params'] 72 | return params 73 | 74 | 75 | def average_checkpoints(checkpoint_paths, params): 76 | """Averages a set of checkpoints in input checkpoints.""" 77 | assert len(checkpoint_paths) >= 1 78 | # Sum parameters of separate models together. 79 | params = _load_checkpoint(checkpoint_paths[0], params) 80 | for checkpoint_path in checkpoint_paths[1:]: 81 | params_update = _load_checkpoint( 82 | checkpoint_path, params 83 | ) 84 | # TODO(dxin): Make this averaging process more numerically stable. 85 | params = jax.tree.map(lambda x, y: x + y, params, params_update) 86 | 87 | # Average checkpoints. 88 | params = jax.tree.map(lambda x: x / float(len(checkpoint_paths)), params) 89 | return params 90 | 91 | 92 | def get_checkpoints_in_range(checkpoint_dir, lower_bound, upper_bound): 93 | """Get checkpoint paths in step range [lower_bound, upper_bound].""" 94 | checkpoint_paths = [] 95 | for checkpoint_path in glob(os.path.join(checkpoint_dir, 'ckpt_*')): 96 | ckpt_step = int(checkpoint_path.split('_')[-1]) 97 | if ckpt_step >= lower_bound and ckpt_step <= upper_bound: 98 | checkpoint_paths.append(os.path.join(checkpoint_dir, checkpoint_path)) 99 | return checkpoint_paths 100 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
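
  On the steps where SAM fires (every k-th update), the transformation below
  uses the already-averaged gradient g to form (a restatement of sam_update
  above):

    g_hat = g / ||g||
    g_sam = grad(loss)(params + rho * g_hat)
    update = g_hat + alpha * k * (g_sam - g_hat)

  and passes the result to the base optimizer; on all other steps the gradient
  is passed through unchanged.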
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/factor_sam.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Efficient implementation of Sharpness Aware Minimization (SAM). 17 | 18 | Applies SAM learning rule every k steps, and factorizes the perturbation 19 | radius and the regularization strength. 20 | """ 21 | 22 | import functools 23 | from typing import Optional 24 | 25 | from init2winit.model_lib import model_utils 26 | import jax 27 | import jax.numpy as jnp 28 | import optax 29 | 30 | _GRAD_CLIP_EPS = 1e-6 31 | 32 | 33 | def normalize_vector(y: jnp.ndarray) -> jnp.ndarray: 34 | """Returns unit norm version of original pytree. 35 | 36 | Args: 37 | y: A pytree of numpy ndarray, vector y in the equation above. 38 | """ 39 | gradient_norm = jnp.sqrt( 40 | sum([jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)])) 41 | normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) 42 | return normalized_gradient 43 | 44 | 45 | def clean_update(updates, state, unused_grad_fn_params_tuple): 46 | """Returns the clean update function.""" 47 | return updates, state 48 | 49 | 50 | def sam_update( 51 | updates, 52 | state, 53 | grad_fn_params_tuple, 54 | rho=0.1, 55 | alpha=1.0, 56 | ): 57 | """SAM update function.""" 58 | (grad_fn, params) = grad_fn_params_tuple 59 | updates = normalize_vector(updates) 60 | noised_params = jax.tree_util.tree_map( 61 | lambda p, u: p + rho * u, params, updates 62 | ) 63 | _, sam_updates = grad_fn(noised_params) 64 | 65 | # Regularizer gradient - difference between SAM and clean updates. 66 | sam_updates = jax.tree.map(lambda x, y: x - y, sam_updates, updates) 67 | 68 | # Rescale and apply regularizer 69 | updates = jax.tree.map(lambda x, y: x + alpha * y, updates, sam_updates) 70 | return updates, state 71 | 72 | 73 | def sharpness_aware_minimization( 74 | rho: float, 75 | alpha: float, 76 | k: int, 77 | grad_clip: Optional[float], 78 | base_opt_init_fn, 79 | base_opt_update_fn, 80 | ) -> optax.GradientTransformation: 81 | """Implementation of Sharpness Aware Minimization (SAM). 
82 | 83 | Paper: https://arxiv.org/abs/2010.01412 84 | Code: https://github.com/google-research/sam 85 | 86 | References: 87 | Foret et al, 2021: https://arxiv.org/abs/2010.01412 88 | Args: 89 | rho: The size of the neighborhood for the sharpness aware minimization 90 | gradient updates. Defaults to 0.1. 91 | alpha: Additional scaling factor for regularization strength. 92 | k: Period on which to apply SAM. Regularization strength is scaled by k. 93 | grad_clip: The optional value to clip the updates by. Defaults to None. 94 | base_opt_init_fn: The initialization function for the base optimizer used to 95 | generate updates given the total gradient. 96 | base_opt_update_fn: The update function for the base optimizer used to 97 | generate updates given the total gradient. 98 | 99 | Returns: 100 | The corresponding `GradientTransformation`. 101 | """ 102 | 103 | def init_fn(params): 104 | return base_opt_init_fn(params) 105 | 106 | # TODO(thetish): Implement version which applies SAM before averaging over 107 | # devices. 108 | def update_fn(updates, state, grad_fn_params_tuple): 109 | # Updates here have been averaged across devices in Trainer before being 110 | # sent to the optimizer. 111 | (_, params) = grad_fn_params_tuple 112 | 113 | # Update function in between SAM steps. 114 | intermediate_update_fn = clean_update 115 | 116 | # Sam update. Scale alpha by k to keep optimal rho independent of k. 117 | alpha_eff = alpha * k 118 | sam_update_fn = functools.partial(sam_update, rho=rho, alpha=alpha_eff) 119 | updates, state = jax.lax.cond( # Apply SAM every k steps. 120 | state.count % k == 0, 121 | sam_update_fn, 122 | intermediate_update_fn, 123 | updates, 124 | state, 125 | grad_fn_params_tuple, 126 | ) 127 | # Clipping 128 | if grad_clip: 129 | updates_norm = jnp.sqrt(model_utils.l2_regularization(updates, 0)) 130 | scaled_updates = jax.tree.map( 131 | lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) 132 | updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, 133 | lambda _: updates, None) 134 | # TODO(thetish): Explore different order for base optimizer and SAM. For 135 | # example, in Adam preconditioning the SAM perturbation is helpful. 136 | return base_opt_update_fn(updates, state, params) # Apply base optimizer 137 | 138 | return optax.GradientTransformation(init_fn, update_fn) 139 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/kitchen_sink/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
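As a rough usage sketch of the factored SAM transformation above (not an official recipe): the choice of base optimizer, the hyperparameter values, and the way the trainer supplies `(grad_fn, params)` are all assumptions for illustration. Note that `update_fn` reads `state.count`, so the base optimizer's state is assumed to expose a `count` field (true of `optax.scale_by_adam()`), and `grad_fn` is assumed to return `(loss, grads)` as produced by `jax.value_and_grad`.

import jax
import optax
from init2winit.optimizer_lib import factor_sam

base = optax.scale_by_adam()
opt = factor_sam.sharpness_aware_minimization(
    rho=0.05, alpha=1.0, k=2, grad_clip=None,
    base_opt_init_fn=base.init, base_opt_update_fn=base.update)

state = opt.init(params)                 # params: the model pytree (assumed to exist)
grad_fn = jax.value_and_grad(loss_fn)    # loss_fn: a hypothetical scalar loss of the params
loss, grads = grad_fn(params)
updates, state = opt.update(grads, state, (grad_fn, params))
new_params = optax.apply_updates(params, updates)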
15 | 16 | """Kitchen Sink: decomposing optimizers in JAX.""" 17 | 18 | from init2winit.optimizer_lib.kitchen_sink._src.alias import adamw_generic 19 | from init2winit.optimizer_lib.kitchen_sink._src.alias import adapropw 20 | from init2winit.optimizer_lib.kitchen_sink._src.alias import nadampw 21 | from init2winit.optimizer_lib.kitchen_sink._src.alias import nadamw 22 | from init2winit.optimizer_lib.kitchen_sink._src.core import kitchen_sink 23 | from init2winit.optimizer_lib.kitchen_sink._src.transform import add_decayed_weights 24 | from init2winit.optimizer_lib.kitchen_sink._src.transform import bias_correction 25 | from init2winit.optimizer_lib.kitchen_sink._src.transform import BiasCorrectionState 26 | from init2winit.optimizer_lib.kitchen_sink._src.transform import clip_updates 27 | from init2winit.optimizer_lib.kitchen_sink._src.transform import first_moment_ema 28 | from init2winit.optimizer_lib.kitchen_sink._src.transform import nesterov 29 | from init2winit.optimizer_lib.kitchen_sink._src.transform import nesterovpp 30 | from init2winit.optimizer_lib.kitchen_sink._src.transform import polyak_averaging 31 | from init2winit.optimizer_lib.kitchen_sink._src.transform import Polyak_AveragingState 32 | from init2winit.optimizer_lib.kitchen_sink._src.transform import polyak_hb 33 | from init2winit.optimizer_lib.kitchen_sink._src.transform import precondition_by_amsgrad 34 | from init2winit.optimizer_lib.kitchen_sink._src.transform import precondition_by_layered_adaptive_rms 35 | from init2winit.optimizer_lib.kitchen_sink._src.transform import precondition_by_rms 36 | from init2winit.optimizer_lib.kitchen_sink._src.transform import precondition_by_rss 37 | from init2winit.optimizer_lib.kitchen_sink._src.transform import precondition_by_yogi 38 | from init2winit.optimizer_lib.kitchen_sink._src.transform import PreconditionByLayeredAdaptiveRMSState 39 | from init2winit.optimizer_lib.kitchen_sink._src.transform import PreconditionByRssState 40 | from init2winit.optimizer_lib.kitchen_sink._src.transform import PreconditionBySecondMomentCoordinateWiseState 41 | from init2winit.optimizer_lib.kitchen_sink._src.transform import sanitize_values 42 | from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_adam 43 | from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_adaprop 44 | from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_amsgrad 45 | from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_learning_rate 46 | from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_nadam 47 | from init2winit.optimizer_lib.kitchen_sink._src.transform import ScaleByAdamState 48 | from init2winit.optimizer_lib.kitchen_sink._src.transform import ScaleByAMSGradState 49 | from init2winit.optimizer_lib.kitchen_sink._src.utils import unfreeze_wrapper 50 | 51 | 52 | __version__ = '0.0.1' 53 | 54 | __all__ = ( 55 | 'nadamw', 56 | 'nadampw', 57 | 'adamw_generic', 58 | 'adapropw', 59 | 'kitchen_sink', 60 | 'bias_correction', 61 | 'BiasCorrectionState', 62 | 'clip_updates', 63 | 'first_moment_ema', 64 | 'nesterov', 65 | 'nesterovpp', 66 | 'add_decayed_weights', 67 | 'polyak_averaging', 68 | 'Polyak_AveragingState', 69 | 'polyak_hb', 70 | 'precondition_by_amsgrad', 71 | 'precondition_by_layered_adaptive_rms', 72 | 'precondition_by_rms', 73 | 'precondition_by_rss', 74 | 'precondition_by_yogi', 75 | 'PreconditionByLayeredAdaptiveRMSState', 76 | 'PreconditionByRssState', 77 | 'PreconditionBySecondMomentCoordinateWiseState', 78 | 
'sanitize_values', 79 | 'scale_by_adam', 80 | 'scale_by_amsgrad', 81 | 'scale_by_adaprop', 82 | 'scale_by_learning_rate', 83 | 'scale_by_nadam', 84 | 'ScaleByAdamState', 85 | 'ScaleByAMSGradState', 86 | 'unfreeze_wrapper', 87 | ) 88 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/kitchen_sink/_src/combine.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Combine utilities.""" 17 | import functools 18 | from typing import Any, NamedTuple 19 | from typing import Callable 20 | from typing import Optional 21 | from typing import Tuple 22 | from typing import Union 23 | import jax 24 | import jax.numpy as jnp 25 | import optax 26 | 27 | # TODO(dsuo): Add back grafting combinator. 28 | 29 | 30 | def join(by: Union[str, Callable[[optax.GradientTransformation, ...], 31 | optax.Updates]], *args, 32 | **kwargs) -> Callable[..., optax.GradientTransformation]: 33 | """Join multiple chains.""" 34 | 35 | if by is None or by == 'chain': 36 | return lambda *args, **kwargs: optax.chain(*(args + tuple(kwargs.values()))) 37 | if isinstance(by, str): 38 | if by not in combinator_registry: 39 | raise ValueError(f'Unrecognized `by` function {by}.') 40 | by_init, by_update = combinator_registry[by](*args, **kwargs) 41 | 42 | # TODO(dsuo): match docs/autocomplete with combinator args. 
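  # The transformation returned below keeps a three-part state:
  # (combinator_state, one state per positional chain, one state per keyword
  # chain). Every child chain sees the same incoming updates, and the
  # combinator's `by_update` merges the per-chain updates into a single update
  # pytree.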
43 | def transform(*args, **kwargs): 44 | 45 | def init(params: optax.Params) -> optax.OptState: 46 | args_state = tuple(chain.init(params) for chain in args) 47 | kwargs_state = { 48 | name: chain.init(params) for name, chain in kwargs.items() 49 | } 50 | combinator_state = by_init(params, *args_state, **kwargs_state) 51 | return combinator_state, args_state, kwargs_state 52 | 53 | def update( 54 | updates: optax.Updates, 55 | state: optax.OptState, 56 | params: Optional[optax.Params] = None 57 | ) -> Tuple[optax.Updates, optax.OptState]: 58 | combinator_state, args_state, kwargs_state = state 59 | 60 | args_results = [ 61 | chain.update(updates, state, params) 62 | for chain, state in zip(args, args_state) 63 | ] 64 | args_updates = tuple(result[0] for result in args_results) 65 | args_state = tuple(result[1] for result in args_results) 66 | 67 | kwargs_results = { 68 | name: chain.update(updates, kwargs_state[name], params) 69 | for name, chain in kwargs.items() 70 | } 71 | kwargs_updates = { 72 | name: result[0] for name, result in kwargs_results.items() 73 | } 74 | kwargs_state = { 75 | name: result[1] for name, result in kwargs_results.items() 76 | } 77 | 78 | updates, combinator_state = by_update( 79 | combinator_state, *args_updates, **kwargs_updates 80 | ) 81 | 82 | return updates, (combinator_state, args_state, kwargs_state) 83 | 84 | return optax.GradientTransformation(init, update) 85 | 86 | return transform 87 | 88 | 89 | def _grafting_helper(chain, use_global_norm=False): 90 | norm = jax.tree.map(jnp.linalg.norm, chain) 91 | if use_global_norm: 92 | global_norm = jax.tree_util.tree_reduce(lambda x, y: jnp.sqrt(x**2 + y**2), 93 | norm) 94 | norm = jax.tree.map(lambda x: global_norm, norm) 95 | return norm 96 | 97 | 98 | class GraftingState(NamedTuple): 99 | """State for the grafting combinator.""" 100 | mag_norm: Any 101 | dir_norm: Any 102 | 103 | 104 | def combine_by_grafting(eps: float = 0.0, use_global_norm: bool = False): 105 | """Grafting combinator. 106 | 107 | Args: 108 | eps (float, optional): term added to the direction-norm denominator for 109 | numerical stability (default: 0.0) 110 | use_global_norm (bool, optional): graft global l2 norms rather than 111 | per-layer (default: False) 112 | 113 | Returns: 114 | updates in the shape of params. 115 | """ 116 | 117 | def init(params, *args, **kwargs): 118 | del args, kwargs 119 | mag_norm = jax.tree.map(lambda x: 0.0, params) 120 | dir_norm = jax.tree.map(lambda x: 0.0, params) 121 | 122 | return GraftingState(mag_norm=mag_norm, dir_norm=dir_norm) 123 | 124 | def update(state, mag_chain, dir_chain): 125 | del state 126 | mag_norm = _grafting_helper(mag_chain, use_global_norm=use_global_norm) 127 | dir_norm = _grafting_helper(dir_chain, use_global_norm=use_global_norm) 128 | 129 | updates = jax.tree.map( 130 | lambda dir, dirn, magn: dir / (dirn + eps) * magn, 131 | dir_chain, 132 | dir_norm, 133 | mag_norm, 134 | ) 135 | 136 | return updates, GraftingState(mag_norm=mag_norm, dir_norm=dir_norm) 137 | 138 | return init, update 139 | 140 | 141 | def combine_by_sum(): 142 | """Sum combinator. 143 | 144 | Returns: 145 | updates in the shape of params.
146 | """ 147 | 148 | def init(params, *args, **kwargs): 149 | del args, kwargs, params 150 | return optax.EmptyState() 151 | 152 | def update(state, *args, **kwargs): 153 | args = args + tuple(kwargs.values()) 154 | return functools.reduce( 155 | lambda x, y: jax.tree.map(lambda i, j: i + j, x, y), args), state 156 | 157 | return init, update 158 | 159 | 160 | combinator_registry = { 161 | 'grafting': combine_by_grafting, 162 | 'sum': combine_by_sum, 163 | } 164 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/kitchen_sink/_src/core.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Modularizing optimization ideas. 17 | 18 | This project seeks to take ideas in optimization (e.g., scale decay, 19 | momentum) and understand when and how they are effective, along with some 20 | insight into why. 21 | """ 22 | 23 | from typing import Any, Dict 24 | 25 | from init2winit.optimizer_lib.kitchen_sink._src import utils 26 | from init2winit.optimizer_lib.kitchen_sink._src.combine import join 27 | from init2winit.optimizer_lib.kitchen_sink._src.mask import mask_registry 28 | from init2winit.optimizer_lib.kitchen_sink._src.transform import transformation_registry 29 | import ml_collections 30 | import optax 31 | 32 | 33 | # TODO(dsuo): document config syntax. 34 | 35 | 36 | def _get_mask(x): 37 | """Find a mask in a given element.""" 38 | if 'mask' in x: 39 | mask = x['mask'] 40 | elif 'mask' in x.get('hps', {}): 41 | mask = x['hps']['mask'] 42 | else: 43 | mask = None 44 | 45 | if mask in mask_registry: 46 | mask = mask_registry[mask] 47 | 48 | return mask 49 | 50 | 51 | def _kitchen_sink_helper(config): 52 | """Recursively chain and join `optax.GradientTransformation`s.""" 53 | 54 | if utils.is_leaf(config): 55 | if 'by' in config: 56 | raise KeyError(f'Leaf {config} should not have key `by`.') 57 | 58 | el = config['element'] 59 | if el not in transformation_registry: 60 | raise ValueError(f'Transformation {el} not found.') 61 | hps = config.get('hps', {}) 62 | tx = transformation_registry[el](**hps) 63 | 64 | else: 65 | if 'hps' in config: 66 | raise KeyError(f'Config {config} should not have key `hps`.') 67 | 68 | to_join = config.get('join', {}) 69 | for key, val in to_join.items(): 70 | to_join[key] = _kitchen_sink_helper(val) 71 | 72 | # Defaults to `None`, which chains child components together. 73 | by = config.get('by') 74 | by_kwargs = config.get('by_kwargs', {}) 75 | 76 | tx = join(by, **by_kwargs)(**to_join) 77 | 78 | mask = _get_mask(config) 79 | if mask is not None: 80 | tx = optax.masked(tx, mask) 81 | 82 | return tx 83 | 84 | 85 | def kitchen_sink(config: Dict[str, Any], 86 | learning_rate: float = None) -> optax.GradientTransformation: 87 | """Runs a list of GradientTransforms in parallel and combines.
88 | 89 | Args: 90 | config: dictionary configuring an optimizer. 91 | learning_rate: learning rate that gets injected. 92 | 93 | Returns: 94 | optax.GradientTransform 95 | """ 96 | # Cast to dict in case we have an ml_collections.ConfigDict. 97 | 98 | if isinstance(config, ml_collections.ConfigDict): 99 | config = config.to_dict() 100 | elif not isinstance(config, dict): 101 | raise ValueError( 102 | 'Kitchen Sink configuration needs to be a config dict or a python dict') 103 | 104 | # Syntactic sugar. If we have an implied chain, make it explicitly a chain. 105 | if all([str(i) in config for i in range(len(config))]): 106 | config = {'join': config} 107 | 108 | # Handle `one_minus_` hps, if any. 109 | config = utils.map_element(utils.handle_one_minus, config) 110 | 111 | # Apply learning rate to any existing scale_by_learning_rate 112 | if learning_rate is not None: 113 | config = utils.apply_and_maybe_scale_by_learning_rate(config, learning_rate) 114 | 115 | return utils.unfreeze_wrapper(*_kitchen_sink_helper(config)) 116 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/kitchen_sink/_src/mask.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Mask utilities.""" 17 | import flax 18 | 19 | 20 | def create_mask(fn): 21 | """Creates a mask that maps fn over the leaves of a dict. 22 | 23 | Args: 24 | fn: function to apply taking - k: Tuple containing nodes (strings) in path 25 | to the leaf - v: The leaf 26 | 27 | Returns: 28 | mask: function that takes dict and returns mapped dict 29 | """ 30 | 31 | def mask(data): 32 | flattened_dict = flax.traverse_util.flatten_dict(data) 33 | return flax.traverse_util.unflatten_dict( 34 | {k: fn(k, v) for k, v in flattened_dict.items()}) 35 | 36 | return mask 37 | 38 | 39 | def create_weight_decay_mask(): 40 | return create_mask( 41 | lambda p, _: 'bias' not in p and not p[-2].startswith('BatchNorm')) 42 | 43 | mask_registry = { 44 | 'bias_bn': create_weight_decay_mask(), 45 | } 46 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/kitchen_sink/_src/test_mask.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
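As a hedged illustration of the config syntax `kitchen_sink` accepts (pieced together from `_kitchen_sink_helper` above, since the syntax is not yet documented): a leaf carries an 'element' naming an entry in `transformation_registry`, plus optional 'hps' and 'mask'; a non-leaf uses 'join' (optionally with 'by' and 'by_kwargs') to combine its children. The transformation names and hyperparameters below are assumptions and must exist in the installed registry.

from init2winit.optimizer_lib.kitchen_sink import kitchen_sink

config = {
    'join': {
        '0': {'element': 'precondition_by_rms',
              'hps': {'one_minus_decay': 0.001}},  # rewritten to decay=0.999 by handle_one_minus
        '1': {'element': 'first_moment_ema',
              'hps': {'decay': 0.9}},
    },
    # No 'by' given, so the two children are simply chained.
}
tx = kitchen_sink(config, learning_rate=0.1)  # appends a scale_by_learning_rate step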
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for utils.""" 17 | 18 | from typing import Sequence 19 | 20 | from absl.testing import absltest 21 | import chex 22 | import flax 23 | import flax.linen as nn 24 | from init2winit.optimizer_lib.kitchen_sink._src.mask import create_mask 25 | from init2winit.optimizer_lib.kitchen_sink._src.mask import create_weight_decay_mask 26 | import jax 27 | import jax.numpy as jnp 28 | import optax 29 | 30 | # pylint:disable=duplicate-key 31 | 32 | 33 | class Foo(nn.Module): 34 | """Dummy model.""" 35 | 36 | train: bool 37 | filters: int 38 | 39 | @nn.compact 40 | def __call__(self, x): 41 | x = nn.Conv(self.filters, (1, 1), use_bias=False, dtype=jnp.float32)(x) 42 | x = nn.BatchNorm( 43 | use_running_average=not self.train, 44 | momentum=0.9, 45 | epsilon=1e-5, 46 | dtype=jnp.float32)( 47 | x) 48 | return x 49 | 50 | 51 | class Bar(nn.Module): 52 | """Dummy model.""" 53 | 54 | features: Sequence[int] 55 | 56 | @nn.compact 57 | def __call__(self, inputs): 58 | x = inputs 59 | for i, feat in enumerate(self.features): 60 | x = nn.Dense(feat, use_bias=True, name=f'layers_{i}')(x) 61 | if i != len(self.features) - 1: 62 | x = nn.relu(x) 63 | return x 64 | 65 | 66 | class CreateMaskTest(chex.TestCase): 67 | """Test masking.""" 68 | 69 | def test_simple(self): 70 | """Check if the leaf key is `a`.""" 71 | mask = create_mask(lambda path, _: path[-1] == 'a') 72 | data = {'a': 4, 'b': {'a': 5, 'c': 1}, 'c': {'a': {'b': 1}}} 73 | 74 | truth = {'a': True, 'b': {'a': True, 'c': False}, 'c': {'a': {'b': False}}} 75 | 76 | chex.assert_equal(mask(data), truth) 77 | 78 | 79 | class CreateWeightDecayMaskTest(chex.TestCase): 80 | """Test weight decay mask.""" 81 | 82 | def test_simple(self): 83 | """Check that the correct tags are removed.""" 84 | mask = create_weight_decay_mask() 85 | data = { 86 | 'bias': { 87 | 'b': 4 88 | }, 89 | 'bias': { 90 | 'BatchNorm_0': 4, 91 | 'bias': 5, 92 | 'a': 0 93 | }, 94 | 'BatchNorm_0': { 95 | 'b': 4 96 | }, 97 | 'a': { 98 | 'b': { 99 | 'BatchNorm_0': 0, 100 | 'bias': 0 101 | }, 102 | 'c': 0 103 | } 104 | } 105 | truth = { 106 | 'bias': { 107 | 'b': False 108 | }, 109 | 'bias': { 110 | 'BatchNorm_0': False, 111 | 'bias': False, 112 | 'a': False 113 | }, 114 | 'BatchNorm_0': { 115 | 'b': False 116 | }, 117 | 'a': { 118 | 'b': { 119 | 'BatchNorm_0': True, 120 | 'bias': False 121 | }, 122 | 'c': True 123 | } 124 | } 125 | 126 | chex.assert_equal(mask(data), truth) 127 | 128 | @chex.variants(with_jit=True, without_jit=True) 129 | def test_batch(self): 130 | """Test that batch layer is indeed ignored. 
131 | 132 | Code taken from: https://github.com/google/flax/issues/932 133 | """ 134 | key = jax.random.PRNGKey(0) 135 | x = jnp.ones((5, 4, 4, 3)) 136 | y = jax.random.uniform(key, (5, 4, 4, 7)) 137 | 138 | foo_vars = flax.core.unfreeze(Foo(filters=7, train=True).init(key, x)) 139 | tx = optax.masked(optax.adam(1e-7), create_weight_decay_mask()) 140 | 141 | @self.variant 142 | def train_step(params, x, y): 143 | y1, new_batch_stats = Foo( 144 | filters=7, train=True).apply( 145 | params, x, mutable=['batch_stats']) 146 | 147 | return jnp.abs(y - y1).sum(), new_batch_stats 148 | 149 | state = self.variant(tx.init)(foo_vars['params']) 150 | grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, x, y) 151 | updates, state = self.variant(tx.update)(dict(grads['params']), state) 152 | 153 | chex.assert_trees_all_close(updates['BatchNorm_0'], 154 | grads['params']['BatchNorm_0']) 155 | 156 | @chex.variants(with_jit=True, without_jit=True) 157 | def test_bias(self): 158 | """Test that biases are ignored.""" 159 | key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2) 160 | x = jax.random.uniform(key1, (4, 4)) 161 | 162 | model = Bar(features=[3, 4, 5]) 163 | params = flax.core.unfreeze(model.init(key2, x)) 164 | y = jax.random.uniform(key1, model.apply(params, x).shape) 165 | 166 | tx = optax.masked(optax.adam(1e-7), create_weight_decay_mask()) 167 | state = tx.init(params) 168 | 169 | def loss(params, x, y): 170 | pred = model.apply(params, x) 171 | return jnp.abs(pred - y).sum() 172 | 173 | grads = jax.grad(loss)(params, x, y) 174 | updates, state = self.variant(tx.update)(dict(grads), state) 175 | 176 | for i in range(3): 177 | chex.assert_trees_all_close(grads['params'][f'layers_{i}']['bias'], 178 | updates['params'][f'layers_{i}']['bias']) 179 | 180 | 181 | if __name__ == '__main__': 182 | absltest.main() 183 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/kitchen_sink/_src/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Optimizer utilities.""" 17 | 18 | import copy 19 | import operator 20 | 21 | from absl import logging 22 | import flax 23 | import jax 24 | import jax.numpy as jnp 25 | import optax 26 | 27 | 28 | def total_tree_sum(pytree): 29 | """Compute the overall sum of a pytree.""" 30 | sums = jax.tree.map(jnp.sum, pytree) 31 | return jax.tree_util.tree_reduce(operator.add, sums, 0) 32 | 33 | 34 | def tree_norm_sql2(pytree): 35 | """Compute the param-wise squared L2 norm of a pytree.""" 36 | return jax.tree.map(lambda x: jnp.linalg.norm(x.reshape(-1)) ** 2, pytree) 37 | 38 | 39 | def total_tree_norm_sql2(pytree): 40 | """Compute the overall squared L2 norm of a pytree.""" 41 | sql2_norms = tree_norm_sql2(pytree) 42 | return jax.tree_util.tree_reduce(operator.add, sql2_norms, 0) 43 | 44 | 45 | def is_leaf(x): 46 | return isinstance(x, dict) and 'element' in x 47 | 48 | 49 | def map_element(fn, config, true_leaf_fn=None): 50 | if not isinstance(config, dict): 51 | if true_leaf_fn is not None: 52 | return true_leaf_fn(config) 53 | else: 54 | return config 55 | elif 'element' in config: 56 | return fn(config) 57 | else: 58 | return {k: map_element(fn, v, true_leaf_fn) for k, v in config.items()} 59 | 60 | 61 | def unfreeze_wrapper(init_fn, update_fn): 62 | """Freeze/unfreeze params.""" 63 | 64 | # NOTE(dsuo): We use plain dicts internally due to this issue 65 | # https://github.com/deepmind/optax/issues/160. 66 | def wrapped_init_fn(params): 67 | return init_fn(flax.core.unfreeze(params)) 68 | 69 | def wrapped_update_fn(updates, state, params=None): 70 | new_updates, state = update_fn( 71 | flax.core.unfreeze(updates), state, 72 | None if params is None else flax.core.unfreeze(params)) 73 | 74 | if isinstance(updates, flax.core.FrozenDict): 75 | new_updates = flax.core.freeze(new_updates) 76 | 77 | return new_updates, state 78 | 79 | return optax.GradientTransformation(wrapped_init_fn, wrapped_update_fn) 80 | 81 | 82 | def handle_one_minus(x): 83 | if 'hps' in x: 84 | for hp in copy.deepcopy(x['hps']).keys(): 85 | if 'one_minus_' in hp: 86 | x['hps'][hp.replace('one_minus_', '')] = 1 - x['hps'][hp] 87 | del x['hps'][hp] 88 | return x 89 | 90 | 91 | def apply_and_maybe_scale_by_learning_rate(config, learning_rate): 92 | """Apply learning rate and possibly scale by learning rate.""" 93 | 94 | def is_scale_by_lr(x): 95 | return not isinstance(x, str) and x['element'] == 'scale_by_learning_rate' 96 | 97 | def contains_lr_as_param(x): 98 | return not isinstance(x, str) and x.get( 99 | 'hps', None) and 'learning_rate' in x['hps'] 100 | 101 | def update_leaf(x): 102 | if contains_lr_as_param(x): 103 | x['hps']['learning_rate'] = learning_rate 104 | return x 105 | return x 106 | 107 | scaled = map_element(is_scale_by_lr, config, true_leaf_fn=lambda x: False) 108 | num_scaled = jax.tree_util.tree_reduce(lambda x, y: x + y, scaled, 0) 109 | 110 | if num_scaled == 0: 111 | return { 112 | 'join': { 113 | '0': config, 114 | '1': { 115 | 'element': 'scale_by_learning_rate', 116 | 'hps': { 117 | 'learning_rate': learning_rate 118 | } 119 | } 120 | } 121 | } 122 | elif num_scaled == 1: 123 | return map_element(update_leaf, config) 124 | else: 125 | logging.warning('Kitchen Sink configuration has more than one ' 126 | 'scale_by_learning_rate. 
Please double check config') 127 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/linalg/README.md: -------------------------------------------------------------------------------- 1 | Linear algebra package for non-diagonal preconditioning. -------------------------------------------------------------------------------- /init2winit/optimizer_lib/linalg/low_rank_root_update_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Test for computing the sqrt and inverse sqrt of a matrix.""" 17 | 18 | from typing import Tuple 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | from init2winit.optimizer_lib.linalg import low_rank_root_update 23 | import jax 24 | import numpy as np 25 | import scipy.stats 26 | 27 | 28 | def _small_perturbation(n: int, gamma: float, 29 | rng: np.random.RandomState) -> np.ndarray: 30 | """Returns a vector of absolute values of normally distributed values with standard deviation gamma.""" 31 | s = gamma*np.abs(rng.normal(size=n)) 32 | return s 33 | 34 | 35 | def _random_singular_values(n: int, gamma: float, 36 | rng: np.random.RandomState) -> np.ndarray: 37 | """Returns n random singular values in [γ, 1].""" 38 | s = gamma**rng.random((n,)) # log of singular values uniformly distributed 39 | if n > 0: 40 | s[0] = gamma 41 | if n > 1: 42 | s[1] = 1 43 | return s 44 | 45 | 46 | def _random_svd( 47 | n: int, gamma: float, 48 | rng: np.random.RandomState) -> Tuple[np.ndarray, np.ndarray]: 49 | """Returns a random SVD decomposition with singular values in [γ, 1].""" 50 | # sample a uniformly random orthogonal matrix.
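  # (scipy.stats.ortho_group samples from the Haar distribution on the
  # orthogonal group, so the eigenbasis of the test matrix is uniformly random.)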
51 | v = scipy.stats.ortho_group.rvs(n, random_state=rng) 52 | s = _random_singular_values(n, gamma, rng) 53 | return s, v 54 | 55 | 56 | @jax.jit 57 | def _update_sqrt(x, ix, g): 58 | ra_size = np.where( 59 | x.shape[-1] < 64, x.shape[-1], np.minimum(x.shape[-1] // 12, 64) 60 | ) 61 | rank_array = np.zeros(ra_size) 62 | return low_rank_root_update.low_rank_root_update( 63 | x, ix, g, rank_array, 1e-6, 2 64 | ) 65 | 66 | 67 | class InvSquareRootTest(parameterized.TestCase): 68 | 69 | @parameterized.named_parameters( 70 | { # pylint:disable=g-complex-comprehension 71 | 'testcase_name': f'n={n}', 72 | 'n': n, # pylint: disable=undefined-variable 73 | 'p': p, # pylint: disable=undefined-variable 74 | } 75 | for n in [2, 31] 76 | for p in [2] 77 | ) 78 | def test_random_matrix(self, n, p): 79 | rng = np.random.RandomState(seed=42) 80 | 81 | sigma = 1e-2 # smallest singular value of test matrix 82 | s, v = _random_svd(n, sigma, rng) 83 | s = s.astype(np.float64) 84 | v = v.astype(np.float64) 85 | q = _small_perturbation(n, 1e-4, rng) 86 | q = np.diag(q.astype(np.float64)) 87 | a_sqrt = (v * s**(1 / p)) @ v.T 88 | a_isqrt = (v * s**(-1 / p)) @ v.T 89 | exact = (v * (s + q**2)**(-1 / p)) @ v.T 90 | ans = _update_sqrt(a_sqrt, a_isqrt, q)[1] 91 | ans = np.array(ans).astype(np.float64) 92 | error = np.linalg.norm(ans - exact, 2) / np.linalg.norm(exact, 2) 93 | kappa = 1 / p / sigma 94 | expected_error = 3 * kappa * np.finfo(np.float32).eps 95 | self.assertLessEqual(error, expected_error) 96 | 97 | 98 | if __name__ == '__main__': 99 | absltest.main() 100 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/linalg/paterson_stockmeyer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Paterson-Stockmeyer method for polynomial evaluation.""" 17 | from typing import Any, Callable, List, Sequence, TypeVar 18 | 19 | import numpy as np 20 | 21 | T = TypeVar('T') 22 | 23 | 24 | def _powers(x: T, n: int, product: Callable[[T, T], T]) -> List[T]: 25 | """Returns the list [x, x², ..., xⁿ].""" 26 | xp = [None] * (n + 1) 27 | xp[1] = x 28 | for j in range(2, n + 1): 29 | # To reduce round-off, compute xʲ as the result of O(log j) multiplies 30 | xp[j] = product(xp[j // 2], xp[(j + 1) // 2]) 31 | return xp[1:] 32 | 33 | 34 | def polynomial_no_constant(a: Sequence[Any], x: T, product: Callable[[T, T], 35 | T]) -> T: 36 | """Paterson-Stockmeyer evaluation of a[0] x + a[1] x² + ... + a[n-1] xⁿ. 37 | 38 | A variant of the Paterson-Stockmeyer method for polynomial evaluation 39 | presented in [2], which avoids using the multiplicative identity (x⁰). The 40 | algorithm uses only ⌈2√n⌉ - 2 multiplications instead of n - 1, making it 41 | especially suitable when multiplications are expensive, e.g., for matrices.
42 | The reduced number of multiplications is accomplished by grouping the terms as 43 | 44 | (a[0] x + a[1] x² + ... + a[ s-1] xˢ) + 45 | xˢ (a[s] x + a[s+1] x² + ... + a[2s-1] xˢ) + 46 | (xˢ)² (a[2s] x + a[2s+1] x² + ... + a[3s-1] xˢ) + 47 | ... 48 | 49 | with s = ⌈√n⌉. The powers up to xˢ are precomputed with s - 1 multiplications, 50 | allowing all the (at most) degree s polynomials in parentheses above to be 51 | evaluated. These are then combined using Horner's rule with ⌈n/s⌉ - 1 52 | subsequent multiplications. 53 | 54 | [1] Michael S. Paterson and Larry J. Stockmeyer, "On the number of nonscalar 55 | multiplications necessary to evaluate polynomials," SIAM J. Comput., 2 56 | (1973), pp. 60–66. 57 | 58 | [2] M. Fasi, "Optimality of the Paterson-Stockmeyer method for evaluating 59 | matrix polynomials and rational matrix functions," Linear Algebra Appl., 60 | 574 (2019), pp. 182–200. 61 | 62 | Args: 63 | a: Polynomial coefficients. a[j] is the coefficient of xʲ⁺¹. 64 | x: Argument to evaluate the polynomial at. 65 | product: Multiplication function. 66 | 67 | Returns: 68 | The polynomial a[0] x + a[1] x² + ... + a[n-1] xⁿ . 69 | 70 | Raises: 71 | ValueError if `a` is empty. 72 | """ 73 | n = len(a) 74 | if n == 0: 75 | raise ValueError('polynomial_no_constant: coefficients empty.') 76 | s = int(np.ceil(np.sqrt(n))) 77 | xp = _powers(x, s, product) 78 | inner = lambda alpha: sum([cj * xj for (cj, xj) in zip(alpha, xp)]) 79 | inner_poly = lambda i: inner(a[s * i:min(n, s * (i + 1))]) 80 | i = (n + s - 1) // s - 1 81 | y = inner_poly(i) 82 | for i in reversed(range(i)): 83 | y = inner_poly(i) + product(xp[s - 1], y) 84 | return y 85 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/linalg/root_selector.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Depending on a variable, selects whether to use exact or approximate root function.""" 17 | 18 | import functools 19 | from typing import Tuple 20 | 21 | import chex 22 | from init2winit.optimizer_lib.linalg import low_rank_root_update 23 | from init2winit.optimizer_lib.linalg import pth_inv_root_rmn 24 | from jax import lax 25 | import numpy as np 26 | 27 | 28 | def root_selector( 29 | x: chex.Array, 30 | sx: chex.Array, 31 | isx: chex.Array, 32 | up: chex.Array, 33 | p: int, 34 | eps: float, 35 | exact_root: bool, 36 | rank_estimate: int, 37 | block_krylov_dim_multiplier: int = 2, 38 | stable_iter: bool = False, 39 | unroll: bool = False, 40 | verbose: bool = True, 41 | ) -> Tuple[chex.Array, chex.Array]: 42 | """Returns |X|^{1/p} and |X|⁻¹ᐟᵖ. 43 | 44 | Args: 45 | x: Input matrix must be SPD with eigenvalues >= float32 epsilon. 46 | sx: Old sqrt of the matrix X. 47 | isx: Old inverse sqrt of the matrix X. 48 | up: Update to the matrix of the form update @ update.T 49 | p: Exponent. 
50 | eps: small constant to avoid numerical issues with Lyapunov solver. 51 | exact_root: If True, compute the root exactly; otherwise, apply the 52 | low-rank update to the previous roots `sx` and `isx`. 53 | rank_estimate: Rank estimate of the update. 54 | block_krylov_dim_multiplier: Multiplier for the block Krylov dimension over 55 | the rank estimate. 56 | stable_iter: Whether to use the stable iteration for the inner loop. 57 | unroll: Whether to unroll the loop over iterations. 58 | verbose: Whether to log some information about the iteration, including the 59 | coefficients `a` for each iteration. 60 | 61 | Returns: 62 | Approximations of |X|^{1/p} and |X|⁻¹ᐟᵖ. 63 | """ 64 | f_er = functools.partial( 65 | pth_inv_root_rmn.pth_inv_root_rmn, 66 | fast_root=True, 67 | precision="float32", 68 | stable_iter=stable_iter, 69 | unroll=unroll, 70 | verbose=verbose, 71 | ) 72 | 73 | def _exact_root(): 74 | return f_er(x, p) 75 | 76 | rank_array = np.zeros(np.where( 77 | x.shape[-1] < 64, 78 | x.shape[-1], 79 | np.minimum(x.shape[-1] // 8, rank_estimate), 80 | )) 81 | f_ar = functools.partial( 82 | low_rank_root_update.low_rank_root_update, 83 | rank_array=rank_array, 84 | eps=eps, 85 | block_krylov_dim_multiplier=block_krylov_dim_multiplier, 86 | verbose=verbose, 87 | ) 88 | def _approx_root(): 89 | return f_ar(sx, isx, up) 90 | 91 | return lax.cond(exact_root, _exact_root, _approx_root) 92 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/online_newton_step.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Optimizers for online Newton step algorithms.""" 17 | 18 | from typing import NamedTuple 19 | 20 | from init2winit.optimizer_lib import kitchen_sink 21 | from init2winit.optimizer_lib import utils 22 | import jax 23 | import jax.numpy as jnp 24 | import optax 25 | 26 | 27 | def diag_ons(learning_rate, 28 | weight_decay: float = 0.0, 29 | b1: float = 0.9, 30 | b2: float = 0.999, 31 | eps: float = 1e-8): 32 | """The diagonal version of Online Newton Step with flexible updates. 33 | 34 | Args: 35 | learning_rate: A fixed global scaling factor. 36 | weight_decay: weight decay. 37 | b1: Exponential decay rate to track the first moment of past gradients. 38 | b2: Exponential decay rate to track the second moment of past gradients. 39 | eps: A small constant applied to denominator outside of the square root (as 40 | in the Adam paper) to avoid dividing by zero when rescaling. 41 | 42 | Returns: 43 | The corresponding `GradientTransformation`.
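  Example (illustrative only; the hyperparameter values are assumptions, not
  recommended settings):

    tx = diag_ons(learning_rate=0.1, weight_decay=0.0, b1=1.0, b2=0.999)
    state = tx.init(params)  # params: a pytree of model parameters
    updates, state = tx.update(grads, state, params)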
44 | """ 45 | if b1 == 1.0 and b2 == 1.0: 46 | # Diag ONS without momentum and second moment decay 47 | return optax.chain( 48 | kitchen_sink.precondition_by_rss(eps=eps, power=1.0), 49 | optax.add_decayed_weights(weight_decay), optax.scale(learning_rate)) 50 | elif b1 == 1.0 and b2 != 1.0: 51 | # Diag ONS without momentum but with second moment decay 52 | return optax.chain( 53 | kitchen_sink.precondition_by_rms( 54 | decay=b2, eps=eps, eps_root=0.0, power=1.0), 55 | optax.add_decayed_weights(weight_decay), optax.scale(learning_rate)) 56 | elif b1 != 1.0 and b2 != 1.0: 57 | # Diag ONS with momentum and second moment decay 58 | return optax.chain( 59 | kitchen_sink.scale_by_adam(b1, b2, eps, eps_root=0.0, power=1.0), 60 | optax.add_decayed_weights(weight_decay), optax.scale(learning_rate)) 61 | 62 | 63 | def last_layer_transformation(last_layer_optimizer, base_lr, 64 | last_layer_base_lr, learning_rate): 65 | """Use an optimizer while scaling by a different learning rate.""" 66 | 67 | return optax.chain(last_layer_optimizer, 68 | optax.scale(learning_rate * last_layer_base_lr / base_lr)) 69 | 70 | 71 | def sherman_morrison(a_inv, u, alpha): 72 | """Given A^-1, compute (A + alpha * u u^T)^-1 using Sherman-Morrison.""" 73 | denom = 1 + alpha * u.T @ a_inv @ u 74 | numer = alpha * jnp.outer(a_inv @ u, u) @ a_inv 75 | 76 | return a_inv - numer / denom 77 | 78 | 79 | class OnlineNewtonState(NamedTuple): 80 | """State holding the sum of gradient squares to date.""" 81 | inv_hessian: optax.Updates 82 | 83 | 84 | def full_matrix_ons(alpha, initial_accumulator_value=0.1): 85 | """A full Online Newton Step transformation.""" 86 | 87 | def init_fn(params): 88 | raveled_params, _ = jax.flatten_util.ravel_pytree(params) 89 | initial_hessian = jnp.diag( 90 | jnp.full_like(raveled_params, 1. 
/ initial_accumulator_value)) 91 | 92 | return OnlineNewtonState(inv_hessian=initial_hessian) 93 | 94 | def update_fn(updates, state, params=None): 95 | del params 96 | 97 | raveled_updates, unravel = jax.flatten_util.ravel_pytree(updates) 98 | new_hessian = sherman_morrison(state.inv_hessian, raveled_updates, alpha) 99 | new_updates = unravel(new_hessian @ raveled_updates) 100 | 101 | return new_updates, OnlineNewtonState(inv_hessian=new_hessian) 102 | 103 | return optax.GradientTransformation(init_fn, update_fn) 104 | 105 | 106 | def online_newton_step(learning_rate, alpha, weight_decay): 107 | r"""An optimizer that does full matrix preconditioning.""" 108 | 109 | return optax.chain( 110 | optax.add_decayed_weights(weight_decay), full_matrix_ons(alpha), 111 | optax.sgd(learning_rate)) 112 | 113 | 114 | def multiple_optimizer(last_layer_name, network_optimizer, last_layer_optimizer, 115 | last_layer_base_lr, base_lr): 116 | """Use a different optimizer for the last layer.""" 117 | 118 | def get_select_fn(layer_name): 119 | """Get a function that selects the specified layer as last layer.""" 120 | 121 | def select_layer(tree): 122 | return {k: ('ll' if k == layer_name else 'net') for k, v in tree.items()} 123 | 124 | return select_layer 125 | 126 | return kitchen_sink.unfreeze_wrapper(*optax.multi_transform( 127 | { 128 | 'net': 129 | network_optimizer, 130 | # Scale the learning rate of the last layer according to match 131 | # last_layer_base_lr 132 | 'll': 133 | utils.static_inject_hyperparams(last_layer_transformation) 134 | (last_layer_optimizer, 135 | base_lr, 136 | last_layer_base_lr, 137 | learning_rate=0.0), 138 | }, 139 | get_select_fn(last_layer_name))) 140 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/samuel.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implementation of the SAMUEL optimizer. 17 | 18 | Paper: https://arxiv.org/pdf/2203.01400.pdf 19 | """ 20 | 21 | import copy 22 | from typing import Any 23 | from typing import Dict 24 | from typing import List 25 | from typing import NamedTuple 26 | 27 | from init2winit.optimizer_lib.utils import static_inject_hyperparams 28 | import jax 29 | import jax.numpy as jnp 30 | import optax 31 | 32 | 33 | class SamuelState(NamedTuple): 34 | inner_state: NamedTuple 35 | expert_weights: jnp.ndarray 36 | key: jnp.ndarray 37 | current_expert: int = 0 38 | step: int = 0 39 | 40 | 41 | def samuel( 42 | optimizers: List[str], 43 | hps: List[Dict[str, Any]], 44 | mw_etas: jnp.ndarray, 45 | seed: int = 0, 46 | train_loss: float = 0.0, 47 | learning_rate: float = 0.0, 48 | ): 49 | """Samuel optimizer. 50 | 51 | NOTES 52 | - This implementation assumes each host is an expert (i.e., holds a copy of 53 | the model). 
As a consequence, we must modify the input and training 54 | pipelines to forgo data parallelism across hosts and limit to only the 55 | local devices available to a given host. 56 | - We synchronize after each batch. This is not always necessary and can be 57 | a point of future performance optimization. 58 | 59 | TODO(dsuo): add LR schedules for optimizers. 60 | 61 | Args: 62 | optimizers: list of strings indicating optax optimizers. 63 | hps: list of hps for each optimizer. 64 | mw_etas: list of multiplicative weight etas. 65 | seed: initial jax random seed. 66 | train_loss: train loss to be injected at update time. 67 | learning_rate: for compatibility, but ignored for now. 68 | 69 | Returns: 70 | samuel optimizer 71 | """ 72 | del learning_rate 73 | 74 | num_experts = len(optimizers) 75 | mw_etas = jnp.array(mw_etas) 76 | 77 | if num_experts != jax.process_count(): 78 | raise ValueError( 79 | 'This implementation of SAMUEL requires the number of optimizers to be ' 80 | 'equal to the number of hosts (one host per expert).') 81 | 82 | optimizer = optimizers[jax.process_index()] 83 | hps = hps[jax.process_index()] 84 | 85 | optimizer = getattr(optax, optimizer)(**hps) 86 | 87 | def init_fn(params): 88 | return SamuelState( 89 | inner_state=optimizer.init(params), 90 | expert_weights=jnp.repeat(mw_etas, num_experts, axis=1), 91 | # NOTE(dsuo): each init gives the same key, but given the changing 92 | # params for each model, this is not an issue. 93 | key=jax.random.PRNGKey(seed), 94 | ) 95 | 96 | def update_fn(updates, state, params): 97 | del params 98 | 99 | key, subkey = jax.random.split(state.key) 100 | 101 | # Compute updates based on inner optimizer 102 | updates, inner_state = optimizer.update(updates, state.inner_state) 103 | 104 | prob = state.expert_weights.sum(axis=0) / state.expert_weights.sum() 105 | 106 | # NOTE(dsuo): we rely on jax determinism for each host to behave the same. 107 | current_expert = jax.random.choice(subkey, jnp.arange(prob.size), p=prob) 108 | 109 | # Synchronize train_losses across hosts. 110 | # NOTE(dsuo): since we are already inside a pmap, we can't use 111 | # jax.experimental.multihost_utils. 112 | # NOTE(dsuo): train_losses is of shape (jax.process_count(),). 113 | train_losses = jax.lax.all_gather(train_loss, 'batch').reshape( 114 | jax.process_count(), jax.local_device_count())[:, 0] 115 | 116 | # Compute loss regret and update expert weights.
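    # (Experts whose loss is lower than the currently selected expert's get a
    # positive regret, so exp(mw_etas * regret) > 1 and their weights grow; each
    # row of expert_weights corresponds to one candidate eta in mw_etas.)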
117 | loss_regret = train_losses.at[current_expert].get() - train_losses 118 | expert_weights = state.expert_weights * jnp.exp(mw_etas * loss_regret) 119 | 120 | state = SamuelState( 121 | inner_state=inner_state, 122 | expert_weights=expert_weights, 123 | key=key, 124 | current_expert=current_expert, 125 | step=state.step + 1, 126 | ) 127 | return updates, state 128 | 129 | return optax.GradientTransformation(init_fn, update_fn) 130 | 131 | 132 | def from_hparams(opt_hparams): 133 | """Create SAMUEL optimizer from init2winit.""" 134 | opt_hparams_optimizers = opt_hparams['optimizers'] 135 | 136 | optimizers = [] 137 | hps = [] 138 | index = 0 139 | while str(index) in opt_hparams_optimizers: 140 | hparams = opt_hparams_optimizers[str(index)] 141 | hp = hparams.get('hps', {}) 142 | 143 | for h in copy.deepcopy(hp).keys(): 144 | if 'one_minus_' in h: 145 | hp[h.replace('one_minus_', '')] = 1 - hp[h] 146 | del hp[h] 147 | 148 | optimizers.append(hparams.get('optimizer')) 149 | hps.append(hp) 150 | index += 1 151 | 152 | return static_inject_hyperparams(samuel)( 153 | optimizers=optimizers, hps=hps, **opt_hparams['args']) 154 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/sharpness_aware_minimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implementation of Sharpness Aware Minimization (SAM). 17 | 18 | This implementation is still being actively evaluated by the MLCommons Training 19 | Algorithms Benchmark, so it should not be used (yet). 20 | 21 | Paper: https://arxiv.org/abs/2010.01412 22 | Code: https://github.com/google-research/sam 23 | """ 24 | 25 | from typing import Optional 26 | 27 | from init2winit.model_lib import model_utils 28 | 29 | import jax 30 | import jax.numpy as jnp 31 | import optax 32 | 33 | _GRAD_CLIP_EPS = 1e-6 34 | 35 | 36 | # Copied from the official SAM GitHub repository. Note how it doesn't add an 37 | # epsilon to the gradient norm before normalizing the gradients. 38 | def dual_vector(y: jnp.ndarray) -> jnp.ndarray: 39 | """Returns the solution of max_x y^T x s.t. 40 | 41 | ||x||_2 <= 1. 42 | 43 | Args: 44 | y: A pytree of numpy ndarray, vector y in the equation above. 45 | """ 46 | gradient_norm = jnp.sqrt( 47 | sum([jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)])) 48 | normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) 49 | return normalized_gradient 50 | 51 | 52 | def sharpness_aware_minimization( 53 | rho: float, 54 | grad_clip: Optional[float], 55 | base_opt_init_fn, 56 | base_opt_update_fn, 57 | ) -> optax.GradientTransformation: 58 | """Implementation of Sharpness Aware Minimization (SAM). 
59 | 60 | Paper: https://arxiv.org/abs/2010.01412 61 | Code: https://github.com/google-research/sam 62 | 63 | References: 64 | Foret et al, 2021: https://arxiv.org/abs/2010.01412 65 | Args: 66 | rho: The size of the neighborhood for the sharpness aware minimization 67 | gradient updates. Defaults to 0.1. 68 | grad_clip: The optional value to clip the updates by. Defaults to None. 69 | base_opt_init_fn: The initialization function for the base optimizer used to 70 | generate updates given the total gradient. 71 | base_opt_update_fn: The update function for the base optimizer used to 72 | generate updates given the total gradient. 73 | 74 | Returns: 75 | The corresponding `GradientTransformation`. 76 | """ 77 | 78 | def init_fn(params): 79 | return base_opt_init_fn(params) 80 | 81 | def update_fn(updates, state, grad_fn_params_tuple): 82 | (grad_fn, params) = grad_fn_params_tuple 83 | 84 | # Updates here have been averaged across devices in Trainer before being 85 | # sent to the optimizer. We obtain gradients computed on the noised 86 | # parameters in the same order as how Trainer does on the original 87 | # gradients and with the same 1e-6 epsilon that is used when clipping the 88 | # gradients. 89 | updates = dual_vector(updates) 90 | noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, params, 91 | updates) 92 | _, updates = grad_fn(noised_params) 93 | 94 | updates_norm = jnp.sqrt(model_utils.l2_regularization(updates, 0)) 95 | if grad_clip: 96 | scaled_updates = jax.tree.map( 97 | lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) 98 | updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, 99 | lambda _: updates, None) 100 | updates, state = base_opt_update_fn(updates, state, params) 101 | 102 | return updates, state 103 | 104 | return optax.GradientTransformation(init_fn, update_fn) 105 | -------------------------------------------------------------------------------- /init2winit/optimizer_lib/test_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for utils.""" 17 | 18 | from absl.testing import absltest 19 | import chex 20 | from init2winit.optimizer_lib import optimizers 21 | from init2winit.optimizer_lib import utils 22 | import jax 23 | import jax.numpy as jnp 24 | from ml_collections.config_dict import ConfigDict 25 | import optax 26 | 27 | 28 | # pylint:disable=duplicate-key 29 | 30 | 31 | class ExtractFieldTest(chex.TestCase): 32 | """Test the extract_field() function.""" 33 | 34 | def test_adam(self): 35 | init_fn, update_fn = optimizers.get_optimizer( 36 | ConfigDict({ 37 | 'optimizer': 'adam', 38 | 'l2_decay_factor': None, 39 | 'batch_size': 50, 40 | 'total_accumulated_batch_size': 100, # Use gradient accumulation. 
41 |             'opt_hparams': {
42 |                 'beta1': 0.9,
43 |                 'beta2': 0.999,
44 |                 'epsilon': 1e-7,
45 |                 'weight_decay': 0.0,
46 |             },
47 |         })
48 |     )
49 |     del update_fn
50 |     optimizer_state = init_fn({'foo': jnp.ones(10)})
51 |     # Test that we can extract 'count'.
52 |     chex.assert_type(utils.extract_field(optimizer_state, 'count'), int)
53 |     # Test that we can extract 'nu'.
54 |     chex.assert_shape(utils.extract_field(optimizer_state, 'nu')['foo'], (10,))
55 |     # Test that we can extract 'mu'.
56 |     chex.assert_shape(utils.extract_field(optimizer_state, 'mu')['foo'], (10,))
57 |     # Test that attempting to extract a nonexistent field "abc" returns None.
58 |     chex.assert_equal(utils.extract_field(optimizer_state, 'abc'), None)
59 | 
60 | 
61 | class GradientAggregationDecoratorTest(chex.TestCase):
62 |   """Test the requires_gradient_aggregation() decorator."""
63 | 
64 |   def test_no_aggregation(self):
65 |     """Tests behavior with the decorator."""
66 | 
67 |     @utils.no_cross_device_gradient_aggregation
68 |     def dummy_update_fn(updates, state, params):
69 |       del updates, state, params
70 | 
71 |     self.assertFalse(utils.requires_gradient_aggregation(dummy_update_fn))
72 | 
73 |   def test_with_aggregation(self):
74 |     """Tests the default behavior."""
75 | 
76 |     def dummy_update_fn(updates, state, params):
77 |       del updates, state, params
78 | 
79 |     self.assertTrue(utils.requires_gradient_aggregation(dummy_update_fn))
80 | 
81 | 
82 | class OverwriteHparamNamesTest(chex.TestCase):
83 |   """Test the overwrite_hparam_names() function."""
84 | 
85 |   def test_overwrite_hparams_names(self):
86 |     init_params = jnp.array([1.0, 2.0, 3.0])
87 | 
88 |     def fun(x):
89 |       return 0.5 * jnp.sum(x**2)
90 | 
91 |     # With the learning rate left at 0.0, the optimizer never moves away from
92 |     # the initial params.
93 |     opt = optax.inject_hyperparams(optax.sgd)(learning_rate=0.0)
94 | 
95 |     state = opt.init(init_params)
96 | 
97 |     @jax.jit
98 |     def step(params, state):
99 |       grad = jax.grad(fun)(params)
100 |       updates, state = opt.update(grad, state)
101 |       params = optax.apply_updates(params, updates)
102 |       return params, state
103 | 
104 |     params = init_params
105 |     for _ in range(5):
106 |       params, state = step(params, state)
107 | 
108 |     norm_diff = jnp.linalg.norm(init_params - params)
109 |     self.assertEqual(norm_diff, 0.0)
110 | 
111 |     # If we set the learning rate via 'lr', the optimizer descends as expected.
112 |     opt = optax.inject_hyperparams(optax.sgd)(learning_rate=0.0)
113 |     opt = utils.overwrite_hparam_names(opt, learning_rate='lr')
114 | 
115 |     state = opt.init(init_params)
116 |     state = optax.tree_utils.tree_set(state, lr=0.5)
117 | 
118 |     params = init_params
119 |     for i in range(5):
120 |       state = optax.tree_utils.tree_set(state, lr=1 / (i + 2))
121 |       params, state = step(params, state)
122 |       lr = optax.tree_utils.tree_get(state, 'lr')
123 |       self.assertEqual(lr, 1 / (i + 2))
124 | 
125 |     self.assertLessEqual(fun(params), fun(init_params))
126 | 
127 | 
128 | class AppendHparamName(chex.TestCase):
129 |   """Test the append_hparam_name() function."""
130 | 
131 |   def test_append_hparam_name(self):
132 |     init_params = jnp.array([1.0, 2.0, 3.0])
133 | 
134 |     def fun(x):
135 |       return 0.5 * jnp.sum(x**2)
136 | 
137 |     opt = optax.inject_hyperparams(optax.sgd)(learning_rate=0.5)
138 |     new_opt = utils.append_hparam_name(opt, 'foo')
139 | 
140 |     # Test that we can access and set the new hparam.
141 |     state = new_opt.init(init_params)
142 |     state = optax.tree_utils.tree_set(state, foo=2.)
143 | foo = optax.tree_utils.tree_get(state, 'foo') 144 | self.assertEqual(foo, 2.0) 145 | 146 | # Test that the optimizer runs without any issue 147 | @jax.jit 148 | def step(params, state): 149 | grad = jax.grad(fun)(params) 150 | updates, state = new_opt.update(grad, state) 151 | params = optax.apply_updates(params, updates) 152 | return params, state 153 | 154 | params = init_params 155 | for _ in range(5): 156 | params, state = step(params, state) 157 | 158 | self.assertLessEqual(fun(params), fun(init_params)) 159 | 160 | 161 | if __name__ == '__main__': 162 | absltest.main() 163 | -------------------------------------------------------------------------------- /init2winit/projects/optlrschedule/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | -------------------------------------------------------------------------------- /init2winit/projects/optlrschedule/notebook_utils/parquet_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Utility functions for loading parquet files in notebooks.""" 17 | 18 | import io 19 | from typing import List, Optional 20 | 21 | from etils import epath 22 | import utils as i2wutils # local file import 23 | import pandas as pd 24 | 25 | 26 | def load_parquet_file( 27 | path: str, 28 | file_name: Optional[str] = None, 29 | *, 30 | sort_by: str = 'score', 31 | ascending: bool = True, 32 | include_provenance: bool = False, 33 | ) -> pd.DataFrame: 34 | """Load a single parquet file and return it as a sorted DataFrame. 35 | 36 | Args: 37 | path: Directory path string 38 | file_name (optional): File name string (default: 'results.parquet') 39 | sort_by: Column to sort by (default: 'score') 40 | ascending: Sort order (default: True) 41 | include_provenance: Whether to include the provenance of the data in the 42 | DataFrame (default: False). If set, adds a column 'provenance' to the 43 | DataFrame with the path of the file. 44 | 45 | Returns: 46 | pandas DataFrame 47 | """ 48 | 49 | if file_name: 50 | path = epath.Path(path) / file_name 51 | else: 52 | path = epath.Path(path) 53 | 54 | # Read the file 55 | with path.open('rb') as in_f: 56 | buf = io.BytesIO(in_f.read()) 57 | df = pd.read_parquet(buf) 58 | 59 | # Sort if the column exists 60 | if sort_by in df.columns: 61 | df.sort_values(by=sort_by, ascending=ascending, inplace=True) 62 | 63 | if include_provenance: 64 | df['provenance'] = str(path) 65 | return df 66 | 67 | 68 | def load_all_parquet_files( 69 | paths: List[str], 70 | file_name: Optional[str] = None, 71 | *, 72 | sort_by: str = 'score', 73 | ascending: bool = True, 74 | include_provenance: bool = False, 75 | num_workers: int = 50, 76 | ) -> pd.DataFrame: 77 | """Load and merge all parquet files from different paths. 78 | 79 | Args: 80 | paths: List of directory paths. 
81 | file_name (optional): File name string (default: 'results.parquet') 82 | sort_by: Column to sort by (default: 'score'). 83 | ascending: Sort order (default: True). 84 | include_provenance: Whether to include the provenance of the data in the 85 | DataFrame (default: False). If set, adds a column 'provenance' to the 86 | DataFrame with the path of the file each row came from. 87 | num_workers: Number of workers to use for parallel loading (default: 50). 88 | 89 | Returns: 90 | Merged pandas DataFrame. 91 | """ 92 | shared_kwargs = { 93 | 'file_name': file_name, 94 | 'sort_by': sort_by, 95 | 'ascending': ascending, 96 | 'include_provenance': include_provenance, 97 | } 98 | kwargs_list = [] 99 | for path in paths: 100 | kwargs_list.append({ 101 | 'path': path, 102 | **shared_kwargs, 103 | }) 104 | dfs = i2wutils.run_in_parallel(load_parquet_file, kwargs_list, num_workers) 105 | if dfs: 106 | # Concat will ignore empty DataFrames properly. 107 | merged_df = pd.concat(dfs, ignore_index=True) 108 | return merged_df 109 | else: 110 | return pd.DataFrame() 111 | 112 | 113 | # TODO(gdahl): delete this function once we are happy with the parallel version. 114 | def load_all_parquet_files_sequentially( 115 | paths: List[str], 116 | file_name: Optional[str] = None, 117 | *, 118 | sort_by: str = 'score', 119 | ascending: bool = True, 120 | include_provenance: bool = False, 121 | ) -> pd.DataFrame: 122 | """Load and merge all parquet files from different paths. 123 | 124 | Args: 125 | paths: List of directory paths. 126 | file_name (optional): File name string (default: 'results.parquet') 127 | sort_by: Column to sort by (default: 'score'). 128 | ascending: Sort order (default: True). 129 | include_provenance: Whether to include the provenance of the data in the 130 | DataFrame (default: False). If set, adds a column 'provenance' to the 131 | DataFrame with the path of the file each row came from. 132 | 133 | Returns: 134 | Merged pandas DataFrame. 135 | """ 136 | dfs = [] 137 | 138 | for path in paths: 139 | df = load_parquet_file( 140 | path, 141 | file_name, 142 | sort_by=sort_by, 143 | ascending=ascending, 144 | include_provenance=include_provenance, 145 | ) 146 | if not df.empty: 147 | dfs.append(df) 148 | 149 | if dfs: 150 | merged_df = pd.concat(dfs, ignore_index=True) 151 | return merged_df 152 | else: 153 | return pd.DataFrame() 154 | -------------------------------------------------------------------------------- /init2winit/projects/optlrschedule/scheduler/constant_schedule_family.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Constant schedule family implementation using optax.""" 17 | 18 | from typing import Dict 19 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family 20 | import numpy as np 21 | 22 | 23 | class ConstantScheduleFamily(base_schedule_family.WarmupScheduleFamily): 24 | """Constant learning rate schedule with configurable warmup.""" 25 | 26 | def list_schedule_parameter_keys(self) -> list[str]: 27 | return ['p.warmup_steps'] 28 | 29 | def get_schedule( 30 | self, 31 | schedule_param: Dict[str, float], 32 | base_lr: float, 33 | ) -> np.ndarray: 34 | """Generate constant learning rate schedule with warmup. 35 | 36 | Args: 37 | schedule_param: Dictionary containing schedule parameters. 38 | base_lr: Base learning rate. 39 | 40 | Returns: 41 | np.ndarray: Array of learning rates for each training step. 42 | """ 43 | self.validate_param(schedule_param) 44 | schedule = np.zeros(self.total_steps) 45 | warmup_config = self.schedule_family_config.get('warmup_config', {}) 46 | 47 | warmup_steps = int(schedule_param.get('p.warmup_steps', 0)) 48 | if warmup_steps > 0: 49 | # Warmup phase 50 | warmup_fn = self.get_warmup_fn(self.warmup_type) 51 | for step in range(warmup_steps): 52 | schedule[step] = warmup_fn( 53 | step, warmup_steps, base_lr, **warmup_config 54 | ) 55 | 56 | # Constant phase 57 | schedule[warmup_steps:] = base_lr 58 | else: 59 | # No warmup, constant throughout 60 | schedule[:] = base_lr 61 | 62 | return schedule 63 | -------------------------------------------------------------------------------- /init2winit/projects/optlrschedule/scheduler/cosine_schedule_family.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Cosine schedule family implementation using optax. 17 | 18 | Our code is based on the optax implementation of cosine decay schedule as listed 19 | here: 20 | https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.cosine_decay_schedule 21 | 22 | schedule_family_config: Dictionary containing configuration such as: 23 | total_steps: Maximum number of training updates. 24 | 25 | schedule_params: 26 | warmup_steps: Number of warmup steps. 27 | alpha: Decay factor. 
28 | """ 29 | 30 | from typing import Any, Dict 31 | 32 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family 33 | import numpy as np 34 | 35 | 36 | class CosineScheduleFamily(base_schedule_family.WarmupScheduleFamily): 37 | """Cosine schedule with configurable warmup methods.""" 38 | 39 | def validate_config(self, config: Dict[str, Any]) -> None: 40 | """Validate configuration parameters.""" 41 | 42 | if 'alpha' not in config: 43 | raise ValueError('alpha must be specified in config') 44 | if config['alpha'] < 0.0: 45 | raise ValueError('alpha must be non-negative') 46 | 47 | if ( 48 | 'warmup_type' in config 49 | and config['warmup_type'] not in base_schedule_family.WARMUP_TYPES 50 | ): 51 | raise ValueError( 52 | 'warmup_type must be one of linear, cosine, exponential, or' 53 | ' polynomial' 54 | ) 55 | 56 | def validate_param( 57 | self, schedule_param: base_schedule_family.ScheduleParams 58 | ) -> bool: 59 | """Validate schedule parameters.""" 60 | super().validate_param(schedule_param) 61 | 62 | required_params = {'p.exponent'} 63 | missing_params = required_params - set(schedule_param.keys()) 64 | if missing_params: 65 | raise ValueError(f'Missing required parameters: {missing_params}') 66 | 67 | if not isinstance(schedule_param['p.exponent'], (int, float)): 68 | raise ValueError('exponent must be a number') 69 | if not (0.0 <= schedule_param['p.exponent']): 70 | raise ValueError('exponent must be larger than 0.0') 71 | 72 | return True 73 | 74 | def cosine_decay( 75 | self, 76 | step: int, 77 | decay_steps: int, 78 | base_lr: float, 79 | alpha: float, 80 | exponent: float, 81 | ) -> float: 82 | """Cosine decay from base_lr to alpha * base_lr. 83 | 84 | Args: 85 | step: Current training step 86 | decay_steps: Number of decay steps 87 | base_lr: Base learning rate 88 | alpha: Decay factor 89 | exponent: Power to raise the cosine decay to 90 | 91 | Returns: 92 | float: Learning rate at the current step 93 | """ 94 | progress = step / decay_steps 95 | cosine = (1 + np.cos(np.pi * progress)) / 2 96 | decayed = (1 - alpha) * (cosine**exponent) + alpha 97 | return base_lr * decayed 98 | 99 | def list_schedule_parameter_keys(self) -> list[str]: 100 | return ['p.warmup_steps', 'p.exponent'] 101 | 102 | def get_schedule( 103 | self, 104 | schedule_param: Dict[str, float], 105 | base_lr: float, 106 | ) -> np.ndarray: 107 | """Generate learning rate schedule based on parameters. 
108 | 109 | Args: 110 | schedule_param: Dictionary containing schedule parameters 111 | base_lr: Base learning rate 112 | 113 | Returns: 114 | np.ndarray: Array of learning rates for each training step 115 | """ 116 | # self.validate_param(schedule_param) 117 | 118 | alpha = self.schedule_family_config['alpha'] 119 | warmup_steps = int(schedule_param.get('p.warmup_steps', 0)) 120 | exponent = schedule_param.get('p.exponent', 1.0) 121 | 122 | schedule = np.zeros(self.total_steps) 123 | warmup_fn = self.get_warmup_fn(self.warmup_type) 124 | warmup_config = self.schedule_family_config.get('warmup_config', {}) 125 | 126 | # Warmup phase 127 | for step in range(warmup_steps): 128 | schedule[step] = warmup_fn(step, warmup_steps, base_lr, **warmup_config) 129 | 130 | # Decay phase 131 | decay_steps = self.total_steps - warmup_steps 132 | for step in range(warmup_steps, self.total_steps): 133 | decay_step = step - warmup_steps 134 | schedule[step] = self.cosine_decay( 135 | decay_step, decay_steps, base_lr, alpha, exponent 136 | ) 137 | 138 | return schedule 139 | -------------------------------------------------------------------------------- /init2winit/projects/optlrschedule/scheduler/cosine_standard_schedule_family.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Cosine schedule family implementation with fixed exponent of 1. 17 | 18 | The only difference with cosine_schedule_family.py is that the exponent is 19 | fixed to 1.0. 20 | 21 | schedule_family_config: Dictionary containing configuration such as: 22 | total_steps: Maximum number of training updates. 23 | 24 | schedule_params: 25 | warmup_steps: Number of warmup steps. 26 | alpha: Decay factor. 27 | """ 28 | 29 | from typing import Dict 30 | 31 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family 32 | from init2winit.projects.optlrschedule.scheduler import cosine_schedule_family 33 | import numpy as np 34 | 35 | 36 | class CosineStandardScheduleFamily(cosine_schedule_family.CosineScheduleFamily): 37 | """Cosine schedule with configurable warmup methods and exponent of 1.0.""" 38 | 39 | def validate_param( 40 | self, schedule_param: base_schedule_family.ScheduleParams 41 | ) -> bool: 42 | """Validate schedule parameters.""" 43 | return base_schedule_family.WarmupScheduleFamily.validate_param( 44 | self, schedule_param 45 | ) 46 | 47 | def list_schedule_parameter_keys(self) -> list[str]: 48 | return ['p.warmup_steps'] 49 | 50 | def get_schedule( 51 | self, 52 | schedule_param: Dict[str, float], 53 | base_lr: float, 54 | ) -> np.ndarray: 55 | """Generate learning rate schedule based on parameters. 
56 | 
57 |     Args:
58 |       schedule_param: Dictionary containing schedule parameters
59 |       base_lr: Base learning rate
60 | 
61 |     Returns:
62 |       np.ndarray: Array of learning rates for each training step
63 |     """
64 | 
65 |     alpha = self.schedule_family_config['alpha']
66 |     warmup_steps = int(schedule_param.get('p.warmup_steps', 0))
67 | 
68 |     schedule = np.zeros(self.total_steps)
69 |     warmup_fn = self.get_warmup_fn(self.warmup_type)
70 |     warmup_config = self.schedule_family_config.get('warmup_config', {})
71 | 
72 |     # Warmup phase
73 |     for step in range(warmup_steps):
74 |       schedule[step] = warmup_fn(step, warmup_steps, base_lr, **warmup_config)
75 | 
76 |     # Decay phase
77 |     decay_steps = self.total_steps - warmup_steps
78 |     for step in range(warmup_steps, self.total_steps):
79 |       decay_step = step - warmup_steps
80 |       schedule[step] = self.cosine_decay(
81 |           decay_step, decay_steps, base_lr, alpha, 1.0
82 |       )
83 | 
84 |     return schedule
85 | 
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/scheduler/rex_schedule_family.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """REX schedule family implementation.
17 | 
18 | We generalize the REX schedule by introducing a beta parameter and removing the
19 | coefficient of 1/2 in the denominator.
20 | 
21 | arxiv: https://arxiv.org/pdf/2107.04197 (MLSys2022)
22 | github: https://github.com/IvanVassi/REX_LR/blob/main/lr_scheduler.py
23 | 
24 | The difference between this implementation and the original REX schedule is
25 | that the beta parameter is provided as a schedule parameter.
26 | 
27 | Original REX schedule:
28 |   progress = t/T
29 |   beta = 0.9 (in code) or 1 (in paper)
30 |   return (1 - progress) / ((1 - progress * beta)/2 + 1/2)
31 | 
32 | Our Generalized REX schedule:
33 |   progress = t/T
34 |   beta in [0, inf)
35 |   alpha = 1 - beta
36 |   return (1 - progress) / (1 - progress * alpha)
37 | """
38 | 
39 | from typing import Any
40 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family
41 | import numpy as np
42 | 
43 | 
44 | class RexScheduleFamily(base_schedule_family.WarmupScheduleFamily):
45 |   """REX schedule implementation with configurable warmup using NumPy.
46 | 
47 |   The default values for max_val and min_val in schedule_config are 1 and 0,
48 |   respectively,
49 |   and the learning rate is scaled by base_lr. The beta parameter (default 0.9)
50 |   is provided as a schedule parameter.
51 |   """
52 | 
53 |   def validate_param(self, schedule_param: dict[str, Any]) -> bool:
54 |     """Validate schedule parameters."""
55 |     super().validate_param(schedule_param)
56 |     required_params = {
57 |         'p.warmup_steps',
58 |         'p.beta',
59 |     }  # p.max_val and p.min_val are optional with defaults.
60 | missing_params = required_params - set(schedule_param.keys()) 61 | if missing_params: 62 | raise ValueError( 63 | f'Missing required schedule parameters: {missing_params}' 64 | ) 65 | 66 | return True 67 | 68 | def list_schedule_parameter_keys(self) -> list[str]: 69 | return ['p.warmup_steps', 'p.beta'] 70 | 71 | def rex_decay(self, progress: np.ndarray, beta: float = 0.9) -> np.ndarray: 72 | """Compute the REX decay multiplier. 73 | 74 | Args: 75 | progress: The progress in the decay phase (range from 0 to 1). Measured 76 | from the beginning of the decay phase. 77 | beta: The beta parameter controlling the shape of the decay. 78 | 79 | Returns: 80 | np.ndarray: The REX decay multiplier for each progress value. 81 | """ 82 | alpha = 1 - beta 83 | return (1 - progress) / (1 - progress * alpha) 84 | 85 | def get_schedule( 86 | self, schedule_param: dict[str, Any], base_lr: float 87 | ) -> np.ndarray: 88 | """Generate the learning rate schedule for all training steps. 89 | 90 | Args: 91 | schedule_param: Dictionary of schedule parameters (e.g., 92 | {'p.warmup_steps': 100, 'p.beta': 0.9}). 93 | base_lr: Base learning rate, which is used to scale the schedule. 94 | 95 | Returns: 96 | np.ndarray: An array of learning rates for each training step. 97 | """ 98 | # Validate parameters (optional if called externally). 99 | self.validate_param(schedule_param) 100 | 101 | warmup_steps = int(schedule_param['p.warmup_steps']) 102 | beta = schedule_param.get('p.beta', 0.9) 103 | 104 | schedule = np.zeros(self.total_steps) 105 | warmup_fn = self.get_warmup_fn(self.warmup_type) 106 | warmup_config = self.schedule_family_config.get('warmup_config', {}) 107 | 108 | # Warmup phase: compute learning rate values during warmup. 109 | # Use base_lr as the target learning rate during warmup. 110 | warmup_lr = [] 111 | for step in range(warmup_steps): 112 | lr = warmup_fn(step, warmup_steps, base_lr, **warmup_config) 113 | warmup_lr.append(lr) 114 | warmup_lr = np.array(warmup_lr) 115 | schedule[:warmup_steps] = warmup_lr 116 | 117 | # Decay phase: 118 | decay_steps = self.total_steps - warmup_steps 119 | 120 | # Compute normalized progress (0 to 1) for the decay phase. 121 | steps = np.arange(decay_steps) 122 | progress = steps / decay_steps # progress increases from 0 to 1 123 | # Apply the REX decay function with the specified beta. 124 | decay_multiplier = self.rex_decay(progress, beta=beta) 125 | # Compute learning rate for the decay phase: 126 | # scale by base_lr. 127 | decay_lr = base_lr * decay_multiplier 128 | 129 | schedule[warmup_steps:] = decay_lr 130 | return schedule 131 | -------------------------------------------------------------------------------- /init2winit/projects/optlrschedule/scheduler/schedule_families.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | r"""Schedule families for learning rate schedules.""" 17 | 18 | from typing import Any 19 | 20 | from init2winit.projects.optlrschedule.scheduler import constant_schedule_family 21 | from init2winit.projects.optlrschedule.scheduler import cosine_schedule_family 22 | from init2winit.projects.optlrschedule.scheduler import cosine_standard_schedule_family 23 | from init2winit.projects.optlrschedule.scheduler import rex_schedule_family 24 | from init2winit.projects.optlrschedule.scheduler import smooth_nonmonotonic_schedule_family 25 | from init2winit.projects.optlrschedule.scheduler import sqrt_schedule_family 26 | from init2winit.projects.optlrschedule.scheduler import twopointslinear_schedule_family 27 | from init2winit.projects.optlrschedule.scheduler import twopointsspline_schedule_family 28 | 29 | SCHEDULE_FAMILIES = { 30 | 'cosine': cosine_schedule_family.CosineScheduleFamily, 31 | 'cosine_standard': ( 32 | cosine_standard_schedule_family.CosineStandardScheduleFamily 33 | ), 34 | 'constant': constant_schedule_family.ConstantScheduleFamily, 35 | 'twopointsspline': ( 36 | twopointsspline_schedule_family.TwoPointSplineScheduleFamily 37 | ), 38 | 'twopointslinear': ( 39 | twopointslinear_schedule_family.TwoPointLinearScheduleFamily 40 | ), 41 | 'sqrt': sqrt_schedule_family.SqrtScheduleFamily, 42 | 'smoothnonmonotonic': ( 43 | smooth_nonmonotonic_schedule_family.TwoPointSplineSmoothNonmonoticScheduleFamily 44 | ), 45 | 'rex': rex_schedule_family.RexScheduleFamily, 46 | } 47 | 48 | 49 | def get_schedule_family_class( 50 | schedule_type: str, 51 | ) -> type[Any]: 52 | """Get schedule family class for a given schedule type.""" 53 | try: 54 | return SCHEDULE_FAMILIES[schedule_type] 55 | except KeyError as e: 56 | raise ValueError(f'Unsupported schedule type: {schedule_type}') from e 57 | -------------------------------------------------------------------------------- /init2winit/projects/optlrschedule/scheduler/smooth_nonmonotonic_schedule_family.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Jax implementation of Smooth Non-monotonic Schedule family. 
17 | 
18 | schedule_family_config = {
19 |   total_steps: Total number of training steps
20 | }
21 | 
22 | schedule_params = {
23 |   'y_start': Boundary value of initial learning rate
24 |   'y_end': Boundary value of final learning rate
25 |   'x_peak': Peak position of learning rate on the horizontal axis
26 |   'y1': Learning rate at first control point
27 |   'delta_x1': Relative distance of first control point from start
28 |   'y2': Learning rate at second control point
29 |   'delta_x2': Relative distance of second control point from first control
30 |   point
31 | }
32 | """
33 | 
34 | from typing import Dict, Tuple
35 | from absl import logging
36 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family
37 | import numpy as np
38 | from scipy import interpolate
39 | 
40 | 
41 | class TwoPointSplineSmoothNonmonoticScheduleFamily(
42 |     base_schedule_family.BaseScheduleFamily
43 | ):
44 |   """Non-monotonic learning rate scheduler with arbitrary peak placement."""
45 | 
46 |   def validate_param(self, params: Dict[str, float]) -> None:
47 |     """Validate schedule parameters."""
48 |     required_params = {
49 |         'p.y_start',
50 |         'p.y_end',  # Boundary values
51 |         'p.x_peak',  # Peak position
52 |         'p.y1',
53 |         'p.delta_x1',  # First normal point
54 |         'p.y2',
55 |         'p.delta_x2',  # Second normal point
56 |     }
57 |     if not all(k in params for k in required_params):
58 |       raise ValueError(f'Missing parameters. Required: {required_params}')
59 | 
60 |     # Validate ranges
61 |     for param, value in params.items():
62 |       if param.startswith('p.delta_x') and not (0 < value < 1):
63 |         raise ValueError(f'{param} must be in (0, 1), got {value}')
64 |       if param.startswith('p.y') and not (0 <= value <= 1):
65 |         raise ValueError(f'{param} must be in [0, 1], got {value}')
66 |       if param == 'p.x_peak' and not (0 < value < 1):
67 |         raise ValueError(f'x_peak must be in (0, 1), got {value}')
68 | 
69 |   def _compute_control_points(
70 |       self, params: Dict[str, float]
71 |   ) -> Tuple[np.ndarray, np.ndarray]:
72 |     """Compute control points using stick-breaking procedure.
73 | 
74 |     Points can be placed in any order, with peak at any position.
75 | 
76 |     Args:
77 |       params: Dictionary of schedule parameters.
78 | 
79 |     Returns:
80 |       Tuple of control points (x, y).
81 |     """
82 |     # First point using delta_x1
83 |     x1 = params['p.delta_x1']
84 | 
85 |     # Second point using delta_x2 of remaining space
86 |     remaining_space = 1.0 - x1
87 |     x2 = x1 + remaining_space * params['p.delta_x2']
88 | 
89 |     x_points = np.array([0.0, x1, x2, params['p.x_peak'], 1.0])
90 |     y_points = np.array(
91 |         [params['p.y_start'],
92 |          params['p.y1'],
93 |          params['p.y2'],
94 |          1.0,
95 |          params['p.y_end']]
96 |     )
97 |     order = np.argsort(x_points)
98 |     x_points = x_points[order]
99 |     y_points = y_points[order]
100 | 
101 |     # Ensure uniqueness of x coordinates
102 |     unique_indices = np.unique(x_points, return_index=True)[1]
103 |     if len(unique_indices) < len(x_points):
104 |       logging.warning(
105 |           'Found duplicates in x_points. Reducing from %d to %d unique points.',
106 |           len(x_points),
107 |           len(unique_indices),
108 |       )
109 |       x_points = x_points[unique_indices]
110 |       y_points = y_points[unique_indices]
111 | 
112 |     return x_points, y_points
113 | 
114 |   def list_schedule_parameter_keys(self) -> list[str]:
115 |     return [
116 |         'p.y_start',
117 |         'p.y_end',
118 |         'p.x_peak',
119 |         'p.y1',
120 |         'p.delta_x1',
121 |         'p.y2',
122 |         'p.delta_x2',
123 |     ]
124 | 
125 |   def get_schedule(
126 |       self, params: Dict[str, float], base_lr: float
127 |   ) -> np.ndarray:
128 |     """Generate learning rate schedule."""
129 |     self.validate_param(params)
130 | 
131 |     # Compute control points
132 |     x_points, y_points = self._compute_control_points(params)
133 |     x_steps = x_points * self.total_steps
134 |     spline = interpolate.PchipInterpolator(x_steps, y_points)
135 |     steps = np.arange(self.total_steps)
136 |     lr_array = spline(steps) * base_lr
137 | 
138 |     return lr_array
139 | 
--------------------------------------------------------------------------------
/init2winit/projects/optlrschedule/scheduler/sqrt_schedule_family.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The init2winit Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Sqrt decay schedule family implementation using optax.
17 | 
18 | Our code is based on the optax implementation of the sqrt decay schedule.
19 | 
20 | schedule_family_config: Dictionary containing configuration such as:
21 |   total_steps: Maximum number of training updates.
22 | 
23 | schedule_params:
24 |   warmup_steps: Number of warmup steps.
25 |   alpha: Decay factor.
26 | """ 27 | 28 | from typing import Any, Dict 29 | 30 | from init2winit.projects.optlrschedule.scheduler import base_schedule_family 31 | import numpy as np 32 | 33 | 34 | class SqrtScheduleFamily(base_schedule_family.WarmupScheduleFamily): 35 | """Sqrt decay schedule with configurable warmup methods.""" 36 | 37 | def validate_config(self, config: Dict[str, Any]) -> None: 38 | """Validate configuration parameters.""" 39 | super().validate_config(config) 40 | 41 | if ( 42 | 'warmup_type' in config 43 | and config['warmup_type'] not in base_schedule_family.WARMUP_TYPES 44 | ): 45 | raise ValueError( 46 | 'warmup_type must be one of linear, cosine, exponential, or' 47 | ' polynomial' 48 | ) 49 | 50 | def validate_param( 51 | self, schedule_param: base_schedule_family.ScheduleParams 52 | ) -> bool: 53 | """Validate schedule parameters.""" 54 | super().validate_param(schedule_param) 55 | 56 | required_params = {'p.alpha'} 57 | missing_params = required_params - set(schedule_param.keys()) 58 | if missing_params: 59 | raise ValueError(f'Missing required parameters: {missing_params}') 60 | 61 | if not isinstance(schedule_param['p.alpha'], (float)): 62 | raise ValueError('alpha must be a number') 63 | if not (0.0 <= schedule_param['p.alpha'] <= 1.0): 64 | raise ValueError('alpha must be in the range [0.0, 1.0]') 65 | 66 | return True 67 | 68 | def sqrt_decay(self, x: float, alpha: float) -> float: 69 | """Sqrt decay function. 70 | 71 | Args: 72 | x: Normalized progress (value between 0 and 1). 73 | alpha: Decay factor. 74 | 75 | Returns: 76 | float: Decay multiplier at the current progress. 77 | """ 78 | return np.sqrt(1 - x**2) ** alpha 79 | 80 | def list_schedule_parameter_keys(self) -> list[str]: 81 | return ['p.warmup_steps', 'p.alpha'] 82 | 83 | def get_schedule( 84 | self, 85 | schedule_param: Dict[str, float], 86 | base_lr: float, 87 | ) -> np.ndarray: 88 | """Generate learning rate schedule based on parameters. 89 | 90 | Args: 91 | schedule_param: Dictionary containing schedule parameters 92 | base_lr: Base learning rate 93 | 94 | Returns: 95 | np.ndarray: Array of learning rates for each training step 96 | """ 97 | warmup_steps = int(schedule_param.get('p.warmup_steps', 0)) 98 | alpha = schedule_param.get('p.alpha', 1.0) 99 | 100 | schedule = np.zeros(self.total_steps) 101 | warmup_fn = self.get_warmup_fn(self.warmup_type) 102 | warmup_config = self.schedule_family_config.get('warmup_config', {}) 103 | 104 | # Warmup phase 105 | for step in range(warmup_steps): 106 | schedule[step] = warmup_fn(step, warmup_steps, base_lr, **warmup_config) 107 | 108 | # Decay phase 109 | decay_steps = self.total_steps - warmup_steps 110 | decay_base_lr = schedule[warmup_steps - 1] if warmup_steps > 0 else base_lr 111 | 112 | for step in range(warmup_steps, self.total_steps): 113 | decay_step = step - warmup_steps 114 | normalized_progress = ( 115 | decay_step / decay_steps 116 | ) # Normalized progress (0 to 1) 117 | decay_multiplier = self.sqrt_decay(normalized_progress, alpha) 118 | schedule[step] = decay_base_lr * decay_multiplier 119 | 120 | return schedule 121 | -------------------------------------------------------------------------------- /init2winit/projects/optlrschedule/search_algorithm/search_algorithms.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Search algorithms for schedule parameters search.""" 17 | 18 | from typing import Any 19 | 20 | from init2winit.projects.optlrschedule.search_algorithm import ( 21 | coordinate_descent_search, 22 | ) 23 | from init2winit.projects.optlrschedule.search_algorithm import grid_search 24 | from init2winit.projects.optlrschedule.search_algorithm import random_search 25 | 26 | 27 | SEARCH_ALGORITHMS = { 28 | 'random': random_search.RandomSearch, 29 | 'grid': grid_search.GridSearch, 30 | 'coordinate_descent': coordinate_descent_search.CoordinateDescentSearch, 31 | } 32 | 33 | 34 | def get_search_algorithm_class(search_type: str) -> type[Any]: 35 | """Get search algorithm class for a given search type. 36 | 37 | Args: 38 | search_type: The type of search algorithm to get. 39 | 40 | Returns: 41 | The class of the search algorithm. 42 | 43 | Raises: 44 | ValueError: If the search type is not found. 45 | """ 46 | try: 47 | return SEARCH_ALGORITHMS[search_type] 48 | except KeyError as e: 49 | raise ValueError( 50 | f'Search type {search_type} not found in {SEARCH_ALGORITHMS.keys()}' 51 | ) from e 52 | -------------------------------------------------------------------------------- /init2winit/projects/optlrschedule/workload/workloads.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Workload classes for different tasks.""" 17 | 18 | from typing import Any 19 | 20 | from init2winit.projects.optlrschedule.workload import cifar10_cnn 21 | from init2winit.projects.optlrschedule.workload import linear_regression 22 | from init2winit.projects.optlrschedule.workload import wikitext103_transformer 23 | 24 | 25 | WORKLOADS = { 26 | 'cifar10_cnn': cifar10_cnn.Cifar10Training, 27 | 'wikitext103': wikitext103_transformer.Wikitext103Transformer, 28 | 'linear_regression': linear_regression.LinearRegression, 29 | } 30 | 31 | 32 | def get_workload_class(workload_name: str) -> type[Any]: 33 | """Get workload class for a given workload name. 34 | 35 | Args: 36 | workload_name: The name of the workload to get. 37 | 38 | Returns: 39 | The class of the workload. 40 | 41 | Raises: 42 | ValueError: If the workload name is not found. 
43 | """ 44 | if workload_name not in WORKLOADS: 45 | raise ValueError(f'Unsupported workload: {workload_name}') 46 | return WORKLOADS[workload_name] 47 | -------------------------------------------------------------------------------- /init2winit/shared_test_utilities.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Shared utilities for unit tests.""" 17 | 18 | import functools 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | 24 | def pytree_equal(tree1, tree2): 25 | try: 26 | equal_tree = jax.tree_util.tree_map(jnp.array_equal, tree1, tree2) 27 | return jax.tree_util.tree_reduce(lambda x, y: x and y, equal_tree) 28 | # The tree_utils will raise TypeErrors if structures don't match. 29 | except TypeError: 30 | return False 31 | 32 | 33 | def pytree_allclose(tree1, tree2, rtol=1e-5): 34 | try: 35 | allclose = functools.partial(jnp.allclose, rtol=rtol) 36 | equal_tree = jax.tree_util.tree_map(allclose, tree1, tree2) 37 | return jax.tree_util.tree_reduce(lambda x, y: x and y, equal_tree) 38 | # The tree_utils will raise TypeErrors if structures don't match. 39 | except TypeError: 40 | return False 41 | -------------------------------------------------------------------------------- /init2winit/test_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for utils.py. 
17 | 18 | """ 19 | 20 | import os 21 | import shutil 22 | import tempfile 23 | 24 | from absl.testing import absltest 25 | from absl.testing import parameterized 26 | from init2winit import checkpoint 27 | from init2winit import utils 28 | import jax.numpy as jnp 29 | import numpy as np 30 | 31 | 32 | def _identity(i): 33 | return i 34 | 35 | 36 | def _fn_that_always_fails(arg): 37 | del arg 38 | raise ValueError('I always fail.') 39 | 40 | 41 | class UtilsTest(parameterized.TestCase): 42 | """Tests for utils.py.""" 43 | 44 | def setUp(self): 45 | super(UtilsTest, self).setUp() 46 | self.test_dir = tempfile.mkdtemp() 47 | 48 | def tearDown(self): 49 | shutil.rmtree(self.test_dir) 50 | super(UtilsTest, self).tearDown() 51 | 52 | @parameterized.named_parameters( 53 | dict( 54 | testcase_name='empty list of args', 55 | num_workers=1, 56 | input_list_dict=[], 57 | expected=[], 58 | ), 59 | dict( 60 | testcase_name='one worker, nonempty list', 61 | num_workers=1, 62 | input_list_dict=[dict(i=k) for k in range(1, 10)], 63 | expected=list(range(1, 10)), 64 | ), 65 | dict( 66 | testcase_name='fewer workers than jobs, nonempty list', 67 | num_workers=3, 68 | input_list_dict=[dict(i=k) for k in range(1, 10)], 69 | expected=list(range(1, 10)), 70 | ), 71 | dict( 72 | testcase_name='more workers than jobs, nonempty list', 73 | num_workers=20, 74 | input_list_dict=[dict(i=k) for k in range(1, 10)], 75 | expected=list(range(1, 10)), 76 | ), 77 | ) 78 | def testRunInParallel(self, input_list_dict, num_workers, expected): 79 | """Test running successful fns in parallel, originally from mlbileschi.""" 80 | actual = utils.run_in_parallel(_identity, input_list_dict, num_workers) 81 | self.assertEqual(actual, expected) 82 | 83 | def testRunInParallelOnFailingFn(self): 84 | """Test running failing fns in parallel, originally from mlbileschi.""" 85 | with self.assertRaisesRegex(ValueError, 'I always fail.'): 86 | utils.run_in_parallel(_fn_that_always_fails, [dict(arg='hi')], 10) 87 | 88 | def testAppendPytree(self): 89 | """Test appending and loading pytrees.""" 90 | pytrees = [{'a': i} for i in range(10)] 91 | pytree_path = os.path.join(self.test_dir, 'pytree.ckpt') 92 | logger = utils.MetricLogger(pytree_path=pytree_path) 93 | 94 | for pytree in pytrees: 95 | logger.append_pytree(pytree) 96 | 97 | latest = checkpoint.load_latest_checkpoint(pytree_path, prefix='') 98 | saved_pytrees = latest['pytree'] if latest else [] 99 | self.assertEqual( 100 | pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))]) 101 | 102 | def testArrayAppend(self): 103 | """Test appending to an array.""" 104 | np.testing.assert_allclose( 105 | utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4])) 106 | np.testing.assert_allclose( 107 | utils.array_append(jnp.array([[1, 2], [3, 4]]), jnp.array([5, 6])), 108 | jnp.array([[1, 2], [3, 4], [5, 6]])) 109 | 110 | def testTreeNormSqL2(self): 111 | """Test computing the squared L2 norm of a pytree.""" 112 | pytree = {'foo': jnp.ones(10), 'baz': jnp.ones(20)} 113 | self.assertEqual(utils.tree_norm_sql2(pytree), {'foo': 10.0, 'baz': 20.0}) 114 | self.assertEqual(utils.total_tree_norm_sql2(pytree), 30.0) 115 | 116 | def testTreeSum(self): 117 | """Test computing the sum of a pytree.""" 118 | pytree = {'foo': 2*jnp.ones(10), 'baz': jnp.ones(20)} 119 | self.assertEqual(utils.total_tree_sum(pytree), 40) 120 | 121 | if __name__ == '__main__': 122 | absltest.main() 123 | 124 | -------------------------------------------------------------------------------- 
/init2winit/testdata/wikitext_tokenizer_fake_data.txt: -------------------------------------------------------------------------------- 1 | = = Lulu = = 2 | 3 | 4 | Lulu is a Goldendoodle , which is a mixed breed of the Poodle and the Golden Retriever breeds . Her posture is very poodle-like , similar to her father Lucas . Her mother Luna was also a Goldendoodle , although she carried herself much more like a Golden Retriever than a Poodle . -------------------------------------------------------------------------------- /init2winit/tools/inspect_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Main file for the init2winit project. 17 | 18 | """ 19 | 20 | import os 21 | import sys 22 | 23 | from absl import app 24 | from absl import flags 25 | from absl import logging 26 | from init2winit import hyperparameters 27 | from init2winit.dataset_lib import datasets 28 | import jax 29 | import tensorflow as tf 30 | 31 | # Don't let TF see the GPU, because all we use it for is tf.data loading. 32 | tf.config.experimental.set_visible_devices([], 'GPU') 33 | 34 | # Enable flax xprof trace labelling. 
35 | os.environ['FLAX_PROFILE'] = 'true' 36 | 37 | flags.DEFINE_string('dataset', None, 'Which dataset to inspect') 38 | flags.DEFINE_string('model', None, 'Which model to use') 39 | flags.DEFINE_integer('batch_size', None, 40 | 'Number of examples to retrieve in 1 batch') 41 | flags.DEFINE_integer('num_batches', None, 'Number of batches to retrieve') 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | 46 | def main(unused_argv): 47 | if jax.process_index() == 0: 48 | logging.info('argv:\n%s', ' '.join(sys.argv)) 49 | logging.info('device_count: %d', jax.device_count()) 50 | logging.info('num_hosts : %d', jax.process_count()) 51 | logging.info('host_id : %d', jax.process_index()) 52 | 53 | if FLAGS.batch_size is None or FLAGS.batch_size <= 0: 54 | raise ValueError("""FLAGS.batch_size value is invalid, 55 | expected a positive non-zero integer.""") 56 | 57 | if FLAGS.dataset is None: 58 | raise ValueError("""FLAGS.dataset value is invalid, 59 | expected a non-empty string describing dataset name.""") 60 | 61 | batch_size = FLAGS.batch_size 62 | num_batches = FLAGS.num_batches 63 | dataset_name = FLAGS.dataset 64 | model_name = FLAGS.model 65 | initializer_name = 'noop' 66 | 67 | hparam_overrides = { 68 | 'batch_size': batch_size, 69 | } 70 | 71 | hps = hyperparameters.build_hparams( 72 | model_name=model_name, 73 | initializer_name=initializer_name, 74 | dataset_name=dataset_name, 75 | hparam_file=None, 76 | hparam_overrides=hparam_overrides) 77 | 78 | rng = jax.random.PRNGKey(0) 79 | rng, data_rng = jax.random.split(rng) 80 | 81 | dataset = datasets.get_dataset(FLAGS.dataset)(data_rng, batch_size, 82 | batch_size, hps) 83 | train_iter = dataset.train_iterator_fn() 84 | 85 | for i in range(num_batches): 86 | batch = next(train_iter) 87 | logging.info('train batch_num = %d, batch = %r', i, batch) 88 | 89 | for batch in dataset.valid_epoch(num_batches): 90 | logging.info('validation batch = %r', batch) 91 | 92 | 93 | if __name__ == '__main__': 94 | app.run(main) 95 | -------------------------------------------------------------------------------- /init2winit/trainer_lib/trainers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The init2winit Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Trainers for init2winit.""" 17 | 18 | from init2winit.trainer_lib import trainer 19 | 20 | _ALL_TRAINERS = { 21 | 'standard': trainer.Trainer, 22 | } 23 | 24 | 25 | def get_trainer_cls(trainer_name): 26 | """Maps trainer name to a Trainer class.""" 27 | try: 28 | return _ALL_TRAINERS[trainer_name] 29 | except KeyError: 30 | raise ValueError('Unrecognized trainer: {}'.format(trainer_name)) from None 31 | 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """init2winit. 2 | 3 | See more details in the 4 | [`README.md`](https://github.com/google/init2winit). 
5 | """ 6 | 7 | from setuptools import find_packages 8 | from setuptools import setup 9 | 10 | setup( 11 | name='init2winit', 12 | version='0.0.1', 13 | description='init2winit', 14 | author='init2winit Team', 15 | author_email='znado@google.com', 16 | url='http://github.com/google/init2winit', 17 | license='Apache 2.0', 18 | packages=find_packages(), 19 | install_requires=[ 20 | 'absl-py>=0.8.1', 21 | 'clu', 22 | 'flax', 23 | 'jax', 24 | 'jax-bitempered-loss', 25 | 'jraph', 26 | 'ml_collections', 27 | 'numpy>=1.7', 28 | 'optax', 29 | 'optax-shampoo', 30 | 'pandas', 31 | 'sentencepiece', 32 | 'tensorboard', 33 | 'tensorflow-datasets', 34 | 'tensorflow-text==2.5.0-rc0', 35 | 'tensorflow==2.5.0', 36 | ], 37 | extras_require={}, 38 | classifiers=[ 39 | 'Development Status :: 3 - Alpha', 40 | 'Intended Audience :: Developers', 41 | 'Intended Audience :: Science/Research', 42 | 'License :: OSI Approved :: Apache Software License', 43 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 44 | ], 45 | keywords='jax machine learning', 46 | ) 47 | --------------------------------------------------------------------------------
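
The optlrschedule schedule families dumped above all reduce to simple closed-form decay rules. As a quick illustration (not a file from this repository), the sketch below re-implements, with plain NumPy, the generalized REX and cosine decay formulas documented in the docstrings of `rex_schedule_family.py` and `cosine_schedule_family.py`; it skips warmup and does not construct the schedule-family classes themselves, since their base-class constructor is not shown in this dump.

```python
# Illustrative sketch only; re-implements the documented decay formulas with
# plain NumPy rather than calling the init2winit schedule-family classes.
import numpy as np


def generalized_rex_schedule(total_steps, base_lr, beta=0.9):
  """lr(t) = base_lr * (1 - p) / (1 - p * (1 - beta)), with p = t / T."""
  progress = np.arange(total_steps) / total_steps
  alpha = 1.0 - beta
  return base_lr * (1.0 - progress) / (1.0 - progress * alpha)


def cosine_schedule(total_steps, base_lr, alpha=0.0, exponent=1.0):
  """lr(t) = base_lr * ((1 - alpha) * ((1 + cos(pi * p)) / 2)**exponent + alpha)."""
  progress = np.arange(total_steps) / total_steps
  cosine = (1.0 + np.cos(np.pi * progress)) / 2.0
  return base_lr * ((1.0 - alpha) * cosine**exponent + alpha)


if __name__ == '__main__':
  rex = generalized_rex_schedule(total_steps=1000, base_lr=0.1, beta=0.9)
  cos = cosine_schedule(total_steps=1000, base_lr=0.1, alpha=0.0, exponent=1.0)
  print(rex[0], rex[-1])  # Starts at base_lr, decays toward 0.
  print(cos[0], cos[-1])  # Starts at base_lr, decays toward alpha * base_lr.
```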