├── .gitignore ├── LICENSE ├── README.md ├── config └── default.yaml ├── configs.py ├── data └── hellaswag_val.jsonl ├── dataset.py ├── main.py ├── main_multihost.py ├── model ├── __init__.py └── model.py ├── optimizers ├── __init__.py ├── adam.py ├── schedule_free.py └── tearfree │ ├── grafting.py │ ├── grafting_test.py │ ├── momentum.py │ ├── momentum_test.py │ ├── optimizer.py │ ├── optimizer_smoke_test.py │ ├── optimizer_test.py │ ├── praxis_shim.py │ ├── reallocation.py │ ├── reallocation_test.py │ ├── reshaper.py │ ├── reshaper_test.py │ ├── second_order.py │ ├── shampoo.py │ ├── shampoo_test.py │ ├── sketchy.py │ └── sketchy_test.py ├── requirements.txt ├── scripts ├── 125M.sh ├── 125M_mh_tpu.sh ├── 350M_mh_tpu.sh ├── delete_tpu_lockfile.sh ├── free_tpus.sh └── test.sh ├── sharding.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | /.idea 3 | wandb 4 | *.pyc 5 | .DS_Store 6 | scratch 7 | nohup.out -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Evan Walters 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llm-jax 2 | 3 | Pretrain a SmolLM-style language model on the fineweb-edu dataset. A 350M-parameter model can reach 51% HellaSwag accuracy in only 250B tokens using the PSGD Kron optimizer and architecture improvements. 4 | 5 | Includes several optimizers: PSGD Kron, AdamW, Shampoo, CASPR, and schedule-free. Any optimizer can be wrapped in 6 | schedule-free; see configs.py for more details. 7 | 8 | Only set up for pretraining right now; inference, conversion to PyTorch, and uploading to the Hugging Face Hub are in progress. 9 | 10 | Checkpoints are saved to out_dir; set the same experiment name to resume. 11 | 12 | Set --profile to profile training to TensorBoard; the TensorBoard directory is /profile. 13 | 14 | See configs.py for other settings and all hyperparameters. 15 | 16 | This repo is made possible by [Google's TRC program](https://sites.research.google/trc/about/). 17 | 18 | Started with [this repo, credit to @jenkspt](https://github.com/jenkspt/gpt-jax). 
Also pulled some tools 19 | from [big_vision](https://github.com/google-research/big_vision) to add FSDP sharding. 20 | 21 | Shoutout to @Grad62304977 for sharing model tips to improve training stability. 22 | 23 | 24 | ## Install 25 | 26 | Clone llm-jax 27 | ```shell 28 | git clone https://github.com/evanatyourservice/llm-jax.git 29 | ``` 30 | 31 | Install python dependencies TPU 32 | ```shell 33 | cd llm-jax && pip install -U pip && pip install -U -r requirements.txt && pip install --force-reinstall --upgrade --no-cache-dir 'jax[tpu]' 'jaxlib' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && pip install 'numpy<2' 34 | ``` 35 | 36 | Install python dependencies GPU 37 | ```shell 38 | cd llm-jax && pip install -U pip && pip install -r requirements.txt && pip install --force-reinstall --upgrade --no-cache-dir 'jax[cuda12]' && pip install 'numpy<2' 39 | ``` 40 | 41 | 42 | ## Run 43 | 44 | See examples in /scripts like `scripts/125M_mh_tpu.sh`. 45 | 46 | create TPU using queued-resources 47 | ```shell 48 | gcloud compute tpus queued-resources create node-4 --node-id node-4 --project distributedmuzerojax --zone us-central2-b --accelerator-type v4-16 --runtime-version tpu-ubuntu2204-base --scopes https://www.googleapis.com/auth/cloud-platform 49 | ``` 50 | -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | """Example tyro YAML 2 | 3 | Unused but can be loaded with default=tyro.from_yaml(TrainConfig, yaml_filepath) in main.py 4 | or by adding filepath logic to get_default_config in configs.py. 5 | 6 | For example in get_default_config you could grab a yaml path from an env variable and return 7 | tyro.from_yaml(TrainConfig, env_var_yaml_path). Then set the env variable in a script. 8 | """ 9 | !dataclass:TrainConfig 10 | hellaswag_eval_interval: 500 11 | checkpoint_interval: 1000 12 | batch_size: 128 13 | train_steps: 100000 14 | compute_dtype: float32 15 | params_dtype: float32 16 | optimizer: !dataclass:OptimizerConfig 17 | type: "adamw" 18 | learning_rate: 0.001 19 | warmup_steps: 1000 20 | weight_decay: 0.1 21 | wandb: !dataclass:WandbConfig 22 | mode: online 23 | model: !dataclass:ModelConfig 24 | block_size: 2048 25 | scan_layers: False 26 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from datetime import datetime 3 | from typing import Optional 4 | 5 | 6 | @dataclass(frozen=True) 7 | class ModelConfig: 8 | """Default model config for 125M. 9 | 10 | Attributes: 11 | block_size: Block size. 12 | vocab_size: Vocabulary size. 13 | num_layers: Number of layers. 14 | num_heads: Number of attention heads. 15 | num_kv_heads: Number of key-value heads. 16 | head_dim: Head dimension. 17 | num_embeds: Number of embeddings. 18 | hidden_dim: Hidden dimension. 19 | rope_theta: Rotary embedding theta. 20 | scan_layers: Whether to scan layers. 21 | remat: Whether to use remat. Should be used if scanning layers. 22 | remat_everything: Whether to remat everything, otherwise only use 23 | `checkpoint_dots_with_no_batch_dims`. 24 | min_size_to_shard_mb: Minimum size of shards to create. 
25 | """ 26 | 27 | block_size: int = 2048 28 | vocab_size: int = 32768 29 | num_layers: int = 30 30 | num_heads: int = 9 31 | num_kv_heads: int = 3 32 | head_dim: int = 576 // 9 33 | num_embeds: int = 576 34 | hidden_dim: int = 1536 35 | rope_theta: float = 1000000.0 36 | scan_layers: bool = False 37 | remat: bool = False 38 | remat_everything: bool = False 39 | min_size_to_shard_mb: int = 0.1 40 | 41 | 42 | @dataclass(frozen=True) 43 | class OptimizerConfig: 44 | """Optimizer configuration. 45 | 46 | Attributes: 47 | type: Optimizer type, one of ["adamw", "kron", "shampoo", "caspr"] 48 | schedule_free: Whether to wrap optimizer in schedule-free. 49 | learning_rate: Learning rate. 50 | warmup_steps: Warmup steps. 51 | flat_lr: Whether to use a flat learning rate or decay linearly to 0.05x. 52 | weight_decay: Weight decay. 53 | b1: Beta 1. 54 | b2: Beta 2. 55 | eps: Epsilon. 56 | nesterov: Whether to use nesterov momentum. 57 | preconditioner_update_probability: Probability of updating the 58 | preconditioner in PSGD. Default for PSGD kron is 0.03. 59 | max_size_triangular: Max dim size for preconditioner to be triangular 60 | in PSGD. 61 | memory_save_mode: Memory save mode for kron, one of 62 | [None, "one_diag", "all_diag"]. 63 | preconditioner_dtype: Dtype of the preconditioner in PSGD. 64 | lax_map_scanned_layers: Whether to use lax.map for scanned layers instead 65 | of vmap. Useful for large models (>1B) to save memory. 66 | lax_map_batch_size: Batch size for lax.map, see jax docs for more info. 67 | """ 68 | 69 | type: str = "kron" 70 | schedule_free: bool = False 71 | learning_rate: float = 0.001 72 | warmup_steps: int = 1000 73 | flat_lr: bool = False 74 | weight_decay: float = 0.1 75 | b1: float = 0.9 76 | b2: float = 0.95 77 | eps: float = 1e-8 78 | nesterov: bool = False 79 | preconditioner_update_probability: float = 0.03 80 | max_size_triangular: int = 8192 81 | memory_save_mode: Optional[str] = None 82 | preconditioner_dtype: str = "float32" 83 | lax_map_scanned_layers: bool = False 84 | lax_map_batch_size: int = 8 85 | 86 | 87 | @dataclass(frozen=True) 88 | class WandbConfig: 89 | """Wandb logging configuration.""" 90 | 91 | entity: str = "" 92 | project: str = "llm-jax" 93 | mode: str = "online" 94 | 95 | 96 | date_and_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 97 | 98 | 99 | @dataclass(frozen=True) 100 | class TrainConfig: 101 | """Training configuration. 102 | 103 | Attributes: 104 | experiment_name: Name of the experiment. 105 | out_dir: Output directory for checkpoints (can be gcs path). 106 | attempt_to_load_checkpoint: Whether to attempt to load a checkpoint. 107 | only_print_model: Whether to only print the model then quit. 108 | hellaswag_eval_interval: Interval to evaluate hellaswag. 109 | checkpoint_interval: Interval to save checkpoints. 110 | checkpoint_milestone: Milestone to save checkpoints. 111 | keep_checkpoints: Number of historical checkpoints to keep. 112 | batch_size: Batch size. 113 | train_steps: Total number of training iterations. 114 | gradient_accumulation_steps: Number of gradient accumulation steps. 115 | compute_dtype: Compute dtype. 116 | params_dtype: Params dtype. 117 | profile: Whether to profile the training to tensorboard. 118 | n_profile_steps: Number of steps to profile. 119 | optimizer: Optimizer config. 120 | wandb: Wandb logging config. 121 | model: Model config. 
122 | """ 123 | 124 | seed: int = 10 125 | experiment_name: str = f"run_{date_and_time}" 126 | out_dir: str = "gs://optimizertesting/llm-jax" 127 | attempt_to_load_checkpoint: bool = True 128 | only_print_model: bool = False 129 | hellaswag_eval_interval: int = 1000 130 | checkpoint_interval: int = 1000 131 | keep_checkpoints: int = 1 132 | checkpoint_milestone: int = 25000 133 | batch_size: int = 256 134 | train_steps: int = 50000 135 | gradient_accumulation_steps: int = 1 136 | compute_dtype: str = "float32" 137 | params_dtype: str = "float32" 138 | profile: bool = False 139 | n_profile_steps: int = 5 140 | optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) 141 | wandb: WandbConfig = field(default_factory=WandbConfig) 142 | model: ModelConfig = field(default_factory=ModelConfig) 143 | 144 | assert ( 145 | hellaswag_eval_interval % 100 == 0 146 | ), "Hellaswag_eval_interval must be a multiple of 100" 147 | 148 | 149 | def get_default_config() -> TrainConfig: 150 | return TrainConfig() 151 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Optional 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | import jax 8 | import tensorflow as tf 9 | import datasets 10 | from datasets import load_dataset, IterableDataset 11 | import datasets.config 12 | from transformers import AutoTokenizer 13 | 14 | from utils import ( 15 | make_fsarray_from_local_slice, 16 | prefetch_iterator, 17 | threadstart_iterator, 18 | write_note, 19 | ) 20 | 21 | 22 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 23 | datasets.config.STREAMING_READ_MAX_RETRIES = 17280 # 17280 * 5 = 1 day 24 | datasets.config.STREAMING_READ_RETRY_INTERVAL = 5 25 | 26 | OPTIONS = tf.data.Options() 27 | OPTIONS.deterministic = True 28 | OPTIONS.threading.private_threadpool_size = 48 29 | OPTIONS.threading.max_intra_op_parallelism = 1 30 | # Stop a whole bunch of magic stuff that eats up all RAM: 31 | OPTIONS.experimental_optimization.inject_prefetch = False 32 | 33 | 34 | TOKENIZER = "mistralai/Mistral-7B-v0.3" 35 | 36 | 37 | def prepare_hellaswag( 38 | batch_size: int, 39 | block_size: int, 40 | flat_devices, 41 | tf_prefetch: int = 2, 42 | device_prefetch: int = 0, 43 | ): 44 | """Read file and tokenize the hellaswag dataset.""" 45 | write_note("preparing hellaswag") 46 | 47 | tokenizer = AutoTokenizer.from_pretrained( 48 | TOKENIZER, trust_remote_code=True, use_fast=True 49 | ) 50 | 51 | all_data = [] 52 | all_beginning_lengths = [] 53 | all_seq_lengths = [] 54 | all_labels = [] 55 | with open("data/hellaswag_val.jsonl", "r") as f: 56 | # iterate over lines and tokenize 57 | for line in tqdm(f, total=10042): 58 | item = json.loads(line) 59 | 60 | context = item["ctx"] 61 | endings = item["endings"] 62 | correct_end = item["label"] 63 | 64 | beginning_length = len(tokenizer(context)["input_ids"]) 65 | 66 | data_to_concat = [] 67 | beginning_lengths_to_concat = [] 68 | seq_lengths_to_concat = [] 69 | for ending in endings: 70 | output = tokenizer(context + " " + ending)["input_ids"] 71 | output_len = len(output) 72 | 73 | # pad to block_size 74 | if output_len < block_size: 75 | output = output + [tokenizer.eos_token_id] * ( 76 | block_size - output_len 77 | ) 78 | # max length is block_size 79 | output = output[:block_size] 80 | 81 | data_to_concat.append(output) 82 | beginning_lengths_to_concat.append(beginning_length) 83 | 
seq_lengths_to_concat.append(output_len) 84 | 85 | all_data.append(np.array(data_to_concat, dtype=np.uint16)) 86 | all_beginning_lengths.append( 87 | np.array(beginning_lengths_to_concat, dtype=np.int32) 88 | ) 89 | all_seq_lengths.append(np.array(seq_lengths_to_concat, dtype=np.int32)) 90 | all_labels.append(int(correct_end)) 91 | 92 | all_data = np.array(all_data, dtype=np.uint16) 93 | all_beginning_lengths = np.array(all_beginning_lengths, dtype=np.int32) 94 | all_seq_lengths = np.array(all_seq_lengths, dtype=np.int32) 95 | all_labels = np.array(all_labels, dtype=np.int32) 96 | 97 | ds = tf.data.Dataset.from_tensor_slices( 98 | (all_data, all_beginning_lengths, all_seq_lengths, all_labels) 99 | ) 100 | ds = ds.shard(jax.process_count(), jax.process_index()) 101 | ds = ds.repeat() 102 | 103 | ds = ds.batch( 104 | batch_size // jax.process_count(), 105 | drop_remainder=True, 106 | num_parallel_calls=tf.data.AUTOTUNE, 107 | ) 108 | 109 | ds = ds.with_options(OPTIONS) 110 | ds = ds.prefetch(tf_prefetch) 111 | ds = ds.as_numpy_iterator() 112 | ds = iter(ds) 113 | # ds = threadstart_iterator(ds) 114 | ds = ( 115 | jax.tree.map(lambda x: make_fsarray_from_local_slice(x, flat_devices), elem) 116 | for elem in ds 117 | ) 118 | if device_prefetch > 0: 119 | ds = prefetch_iterator(ds, device_prefetch) 120 | return ds 121 | 122 | 123 | def fineweb_edu_dataset( 124 | batch_size: int, 125 | block_size: int, 126 | flat_devices, 127 | fineweb_edu_name: Optional[str] = None, 128 | tf_prefetch: int = 5, 129 | device_prefetch: int = 0, 130 | ): 131 | """Load the fineweb-edu dataset.""" 132 | platform = jax.devices()[0].platform 133 | # use /dev/shm if on a TPU vm for more space 134 | if platform == "tpu": 135 | cache_dir = "/dev/shm/huggingface_cache" 136 | else: 137 | cache_dir = None 138 | 139 | tokenizer = AutoTokenizer.from_pretrained( 140 | TOKENIZER, trust_remote_code=True, use_fast=True 141 | ) 142 | 143 | def gen(): 144 | hf_ds: IterableDataset = load_dataset( 145 | "HuggingFaceFW/fineweb-edu", 146 | split="train", 147 | name=fineweb_edu_name, 148 | cache_dir=cache_dir, 149 | streaming=True, 150 | ) 151 | 152 | def tokenize(example): 153 | # mistral tokenizer adds bos token to beginning 154 | tokenized = tokenizer(example)["input_ids"] 155 | # cap tokenized lengths to 10 * block_size to prevent too much 156 | # similarity between blocks in a batch or group of batches 157 | tokenized = [t[: 10 * block_size] for t in tokenized] 158 | return {"tokens": tokenized} 159 | 160 | hf_ds = hf_ds.map(tokenize, input_columns="text", batched=True, batch_size=128) 161 | 162 | hf_ds = hf_ds.with_format("numpy") 163 | 164 | for example in hf_ds: 165 | yield example["tokens"].astype(np.uint16) 166 | 167 | ds = tf.data.Dataset.from_generator( 168 | gen, output_signature=tf.TensorSpec(shape=(None,), dtype=tf.uint16) 169 | ) 170 | ds = ds.shuffle(128) # shuffle dataset examples 171 | ds = ds.unbatch() 172 | ds = ds.batch(block_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE) 173 | ds = ds.shuffle(20 * 1024) # shuffle blocks 174 | ds = ds.batch( 175 | batch_size // jax.process_count(), 176 | drop_remainder=True, 177 | num_parallel_calls=tf.data.AUTOTUNE, 178 | ) 179 | ds = ds.with_options(OPTIONS) 180 | ds = ds.prefetch(tf_prefetch) 181 | ds = ds.as_numpy_iterator() 182 | ds = iter(ds) 183 | # ds = threadstart_iterator(ds) 184 | ds = ( 185 | jax.tree.map(lambda x: make_fsarray_from_local_slice(x, flat_devices), elem) 186 | for elem in ds 187 | ) 188 | if device_prefetch > 0: 189 | ds = 
prefetch_iterator(ds, device_prefetch) 190 | return ds 191 | 192 | 193 | # fineweb-edu has 96 shards 194 | _fw_shard_names = [ 195 | "CC-MAIN-2024-10", 196 | "CC-MAIN-2023-50", 197 | "CC-MAIN-2023-40", 198 | "CC-MAIN-2023-23", 199 | "CC-MAIN-2023-14", 200 | "CC-MAIN-2023-06", 201 | "CC-MAIN-2022-49", 202 | "CC-MAIN-2022-40", 203 | "CC-MAIN-2022-33", 204 | "CC-MAIN-2022-27", 205 | "CC-MAIN-2022-21", 206 | "CC-MAIN-2022-05", 207 | "CC-MAIN-2021-49", 208 | "CC-MAIN-2021-43", 209 | "CC-MAIN-2021-39", 210 | "CC-MAIN-2021-31", 211 | "CC-MAIN-2021-25", 212 | "CC-MAIN-2021-21", 213 | "CC-MAIN-2021-17", 214 | "CC-MAIN-2021-10", 215 | "CC-MAIN-2021-04", 216 | "CC-MAIN-2020-50", 217 | "CC-MAIN-2020-45", 218 | "CC-MAIN-2020-40", 219 | "CC-MAIN-2020-34", 220 | "CC-MAIN-2020-29", 221 | "CC-MAIN-2020-24", 222 | "CC-MAIN-2020-16", 223 | "CC-MAIN-2020-10", 224 | "CC-MAIN-2020-05", 225 | "CC-MAIN-2019-51", 226 | "CC-MAIN-2019-47", 227 | "CC-MAIN-2019-43", 228 | "CC-MAIN-2019-39", 229 | "CC-MAIN-2019-35", 230 | "CC-MAIN-2019-30", 231 | "CC-MAIN-2019-26", 232 | "CC-MAIN-2019-22", 233 | "CC-MAIN-2019-18", 234 | "CC-MAIN-2019-13", 235 | "CC-MAIN-2019-09", 236 | "CC-MAIN-2019-04", 237 | "CC-MAIN-2018-51", 238 | "CC-MAIN-2018-47", 239 | "CC-MAIN-2018-43", 240 | "CC-MAIN-2018-39", 241 | "CC-MAIN-2018-34", 242 | "CC-MAIN-2018-30", 243 | "CC-MAIN-2018-26", 244 | "CC-MAIN-2018-22", 245 | "CC-MAIN-2018-17", 246 | "CC-MAIN-2018-13", 247 | "CC-MAIN-2018-09", 248 | "CC-MAIN-2018-05", 249 | "CC-MAIN-2017-51", 250 | "CC-MAIN-2017-47", 251 | "CC-MAIN-2017-43", 252 | "CC-MAIN-2017-39", 253 | "CC-MAIN-2017-34", 254 | "CC-MAIN-2017-30", 255 | "CC-MAIN-2017-26", 256 | "CC-MAIN-2017-22", 257 | "CC-MAIN-2017-17", 258 | "CC-MAIN-2017-13", 259 | "CC-MAIN-2017-09", 260 | "CC-MAIN-2017-04", 261 | "CC-MAIN-2016-50", 262 | "CC-MAIN-2016-44", 263 | "CC-MAIN-2016-40", 264 | "CC-MAIN-2016-36", 265 | "CC-MAIN-2016-30", 266 | "CC-MAIN-2016-26", 267 | "CC-MAIN-2016-22", 268 | "CC-MAIN-2016-18", 269 | "CC-MAIN-2016-07", 270 | "CC-MAIN-2015-48", 271 | "CC-MAIN-2015-40", 272 | "CC-MAIN-2015-35", 273 | "CC-MAIN-2015-32", 274 | "CC-MAIN-2015-27", 275 | "CC-MAIN-2015-22", 276 | "CC-MAIN-2015-18", 277 | "CC-MAIN-2015-14", 278 | "CC-MAIN-2015-11", 279 | "CC-MAIN-2015-06", 280 | "CC-MAIN-2014-52", 281 | "CC-MAIN-2014-49", 282 | "CC-MAIN-2014-42", 283 | "CC-MAIN-2014-41", 284 | "CC-MAIN-2014-35", 285 | "CC-MAIN-2014-23", 286 | "CC-MAIN-2014-15", 287 | "CC-MAIN-2014-10", 288 | "CC-MAIN-2013-48", 289 | "CC-MAIN-2013-20", 290 | ] 291 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | 3 | from configs import TrainConfig, get_default_config 4 | from train import main 5 | 6 | 7 | if __name__ == "__main__": 8 | config = tyro.cli(TrainConfig, default=get_default_config(), use_underscores=True) 9 | main(config) 10 | -------------------------------------------------------------------------------- /main_multihost.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | import jax 3 | 4 | from configs import TrainConfig, get_default_config 5 | from train import main 6 | 7 | 8 | if __name__ == "__main__": 9 | jax.distributed.initialize() 10 | 11 | config = tyro.cli(TrainConfig, default=get_default_config(), use_underscores=True) 12 | main(config) 13 | -------------------------------------------------------------------------------- /model/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/evanatyourservice/llm-jax/fe296a2f8d628889c59e8cb442bc3b1f84277607/model/__init__.py -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jax.sharding import Mesh, NamedSharding as NS, PartitionSpec as P 7 | import flax.linen as nn 8 | 9 | from configs import ModelConfig 10 | 11 | 12 | init_fn = lambda dim: nn.initializers.normal(jnp.sqrt(2 / (5 * dim))) 13 | wang_fn = lambda dim, n_layers: nn.initializers.normal(2 / n_layers / jnp.sqrt(dim)) 14 | constrain = lambda x, mesh, spec: jax.lax.with_sharding_constraint(x, NS(mesh, spec)) 15 | 16 | 17 | class RMSNorm(nn.Module): 18 | """RMSNorm layer. 19 | 20 | Upcasts to float32 and back.""" 21 | 22 | @nn.compact 23 | def __call__(self, x): 24 | var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) 25 | normed_inputs = x * jax.lax.rsqrt(var + 1e-06) 26 | normed_inputs = normed_inputs.astype(x.dtype) 27 | 28 | scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1])) 29 | scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1)) 30 | normed_inputs = normed_inputs * (1 + scale) 31 | return normed_inputs 32 | 33 | 34 | class Embedder(nn.Module): 35 | """Embedder module.""" 36 | 37 | vocab_size: int 38 | embed_dim: int 39 | mesh: Mesh 40 | 41 | def setup(self): 42 | self.embedding = self.param( 43 | "embedding", init_fn(self.embed_dim), (self.vocab_size, self.embed_dim) 44 | ) 45 | 46 | def encode(self, x: jax.Array) -> jax.Array: 47 | x = jnp.take(self.embedding, x, axis=0) 48 | if self.mesh is not None: 49 | x = constrain(x, self.mesh, P("fsdp")) 50 | x *= jnp.sqrt(self.embed_dim).astype(x.dtype) 51 | return x 52 | 53 | def decode(self, x: jax.Array) -> jax.Array: 54 | x = jnp.dot(x, self.embedding.T) 55 | if self.mesh is not None: 56 | x = constrain(x, self.mesh, P("fsdp")) 57 | x = jnp.tanh(x / 30) * 30 58 | return x 59 | 60 | 61 | def _get_large_negative(dtype): 62 | dtype_max = jnp.finfo(dtype).max 63 | return jnp.asarray(-0.7 * dtype_max, dtype=dtype) 64 | 65 | 66 | def _get_causal_mask(T, S): 67 | mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) 68 | return mask[None, None, :, :] 69 | 70 | 71 | def _dot_product_attention_core(query, key, value): 72 | head_dim = query.shape[-1] 73 | query *= jax.lax.rsqrt(jnp.array(head_dim, dtype=jnp.float32)).astype(query.dtype) 74 | logits = jnp.einsum("BTNH,BSNH->BNTS", query, key) 75 | logits = jnp.tanh(logits / 50) * 50 76 | causal_mask = _get_causal_mask(logits.shape[-2], logits.shape[-1]) 77 | logits = jnp.where(causal_mask, logits, _get_large_negative(logits.dtype)) 78 | probs = jax.nn.softmax(logits.astype(jnp.float32)).astype(logits.dtype) 79 | encoded = jnp.einsum("BNTS,BSNH->BTNH", probs, value) 80 | return encoded 81 | 82 | 83 | def _sine_table(features, length, min_timescale=1.0, max_timescale=10000.0): 84 | fraction = jnp.arange(0, features, 2, dtype=jnp.float32) / features 85 | timescale = min_timescale * (max_timescale / min_timescale) ** fraction 86 | rotational_frequency = 1.0 / timescale 87 | # Must use high precision einsum here, bfloat16 rounding is catastrophic. 
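# The outer product of position indices and inverse timescales gives the rotary phase angles; the table is duplicated so both halves of the feature dimension share the same frequencies for _rotate_half.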
88 | sinusoid_inp = jnp.einsum( 89 | "i,j->ij", 90 | jnp.arange(length), 91 | rotational_frequency, 92 | precision=jax.lax.Precision.HIGHEST, 93 | ) 94 | sinusoid_inp = jnp.concatenate([sinusoid_inp, sinusoid_inp], axis=-1) 95 | return jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp) 96 | 97 | 98 | def _rotate_half(x): 99 | x1, x2 = jnp.split(x, 2, axis=-1) 100 | x = jnp.concatenate([-x2, x1], axis=-1) 101 | return x 102 | 103 | 104 | def _apply_rotary_embedding(q, k, cos, sin): 105 | # come in as (B, T, K, G, H) and (B, T, K, H) 106 | qlen = q.shape[-4] 107 | klen = k.shape[-3] 108 | 109 | qcos = jnp.expand_dims(cos[:qlen, :], range(len(q.shape) - 2)) 110 | qsin = jnp.expand_dims(sin[:qlen, :], range(len(q.shape) - 2)) 111 | kcos = jnp.expand_dims(cos[:klen, :], range(len(k.shape) - 2)) 112 | ksin = jnp.expand_dims(sin[:klen, :], range(len(k.shape) - 2)) 113 | 114 | qcos = jnp.swapaxes(qcos, -2, -4) 115 | qsin = jnp.swapaxes(qsin, -2, -4) 116 | kcos = jnp.swapaxes(kcos, -2, -3) 117 | ksin = jnp.swapaxes(ksin, -2, -3) 118 | 119 | # done in float32 120 | out_q = q * qcos + _rotate_half(q) * qsin 121 | out_k = k * kcos + _rotate_half(k) * ksin 122 | 123 | return out_q.astype(q.dtype), out_k.astype(k.dtype) 124 | 125 | 126 | class Attention(nn.Module): 127 | """Multi-head attention with RoPE and GQA. 128 | 129 | Upcasts to float32 and back for softmax.""" 130 | 131 | num_heads: int 132 | num_kv_heads: int 133 | head_dim: int 134 | rope_theta: float 135 | n_layers: int 136 | mesh: Mesh 137 | 138 | @nn.compact 139 | def __call__(self, x): 140 | B, T, C = x.shape 141 | N = self.num_heads 142 | K = self.num_kv_heads 143 | G = N // K 144 | H = self.head_dim 145 | 146 | q_params = self.param("q_kernel", init_fn(C), (C, N * H)) 147 | k_params = self.param("k_kernel", init_fn(C), (C, K * H)) 148 | v_params = self.param("v_kernel", init_fn(C), (C, K * H)) 149 | out_params = self.param("out_kernel", wang_fn(N * H, self.n_layers), (N * H, C)) 150 | 151 | q = jnp.dot(x, q_params) 152 | k = jnp.dot(x, k_params) 153 | v = jnp.dot(x, v_params) 154 | 155 | q = jnp.reshape(q, (B, T, K, G, H)) 156 | k = jnp.reshape(k, (B, T, K, H)) 157 | v = jnp.reshape(v, (B, T, K, H)) 158 | 159 | sin, cos = _sine_table(H, T, max_timescale=self.rope_theta) 160 | q, k = _apply_rotary_embedding(q, k, cos, sin) 161 | 162 | vmapped_fn = jax.vmap( 163 | _dot_product_attention_core, in_axes=(3, None, None), out_axes=3 164 | ) 165 | encoded = vmapped_fn(q, k, v) 166 | encoded = jnp.reshape(encoded, (B, T, N * H)) 167 | out = jnp.dot(encoded, out_params) 168 | if self.mesh is not None: 169 | out = constrain(out, self.mesh, P("fsdp")) 170 | return RMSNorm()(out) # normformer 171 | 172 | 173 | class MLP(nn.Module): 174 | hidden_dim: int 175 | n_layers: int 176 | mesh: Mesh 177 | 178 | @nn.compact 179 | def __call__(self, x): 180 | C = x.shape[-1] 181 | 182 | gate_kernel = self.param("gate_kernel", init_fn(C), (C, self.hidden_dim)) 183 | up_kernel = self.param("up_kernel", init_fn(C), (C, self.hidden_dim)) 184 | down_kernel = self.param( 185 | "down_kernel", wang_fn(self.hidden_dim, self.n_layers), (self.hidden_dim, C) 186 | ) 187 | 188 | gate = jnp.dot(x, gate_kernel) 189 | gate = nn.silu(gate) 190 | 191 | up = jnp.dot(x, up_kernel) 192 | x = gate * up 193 | 194 | x = RMSNorm()(x) # normformer 195 | 196 | down = jnp.dot(x, down_kernel) 197 | if self.mesh is not None: 198 | down = constrain(down, self.mesh, P("fsdp")) 199 | return down 200 | 201 | 202 | class Block(nn.Module): 203 | """Transformer block.""" 204 | 205 | num_heads: int 206 | 
num_kv_heads: int 207 | head_dim: int 208 | hidden_dim: int 209 | rope_theta: float 210 | n_layers: int 211 | mesh: Mesh 212 | use_scan: bool = False 213 | 214 | @nn.compact 215 | def __call__(self, x): 216 | attn_layer = Attention( 217 | self.num_heads, 218 | self.num_kv_heads, 219 | self.head_dim, 220 | self.rope_theta, 221 | self.n_layers, 222 | self.mesh, 223 | ) 224 | x += attn_layer(RMSNorm()(x)) 225 | x += MLP(self.hidden_dim, self.n_layers, self.mesh)(RMSNorm()(x)) 226 | if self.use_scan: 227 | return (x, None) 228 | return x 229 | 230 | 231 | class Transformer(nn.Module): 232 | config: ModelConfig 233 | mesh: Mesh = None 234 | using_grad_accum: bool = False 235 | 236 | @nn.compact 237 | def __call__(self, tokens): 238 | remat_policy = None 239 | if not self.config.remat_everything: 240 | remat_policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims 241 | 242 | if self.config.remat: 243 | embedder = nn.remat( 244 | Embedder, prevent_cse=not self.using_grad_accum, policy=remat_policy 245 | )( 246 | self.config.vocab_size, 247 | self.config.num_embeds, 248 | self.mesh, 249 | ) 250 | else: 251 | embedder = Embedder( 252 | self.config.vocab_size, 253 | self.config.num_embeds, 254 | self.mesh, 255 | ) 256 | 257 | x = embedder.encode(tokens) 258 | 259 | if self.config.remat: 260 | prevent_cse = True 261 | if self.using_grad_accum or self.config.scan_layers: 262 | prevent_cse = False 263 | BlockModule = nn.remat(Block, prevent_cse=prevent_cse, policy=remat_policy) 264 | else: 265 | BlockModule = Block 266 | 267 | if self.config.scan_layers: 268 | x, _ = nn.scan( 269 | BlockModule, 270 | variable_axes={True: 0}, 271 | split_rngs={True: True}, 272 | length=self.config.num_layers, 273 | )( 274 | self.config.num_heads, 275 | self.config.num_kv_heads, 276 | self.config.head_dim, 277 | self.config.hidden_dim, 278 | self.config.rope_theta, 279 | self.config.num_layers, 280 | self.mesh, 281 | use_scan=True, 282 | )( 283 | x 284 | ) 285 | else: 286 | for _ in range(self.config.num_layers): 287 | x = BlockModule( 288 | self.config.num_heads, 289 | self.config.num_kv_heads, 290 | self.config.head_dim, 291 | self.config.hidden_dim, 292 | self.config.rope_theta, 293 | self.config.num_layers, 294 | self.mesh, 295 | )(x) 296 | 297 | x = RMSNorm()(x) 298 | logits = embedder.decode(x) 299 | return logits 300 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evanatyourservice/llm-jax/fe296a2f8d628889c59e8cb442bc3b1f84277607/optimizers/__init__.py -------------------------------------------------------------------------------- /optimizers/adam.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any, Callable, Union 2 | 3 | import chex 4 | import jax 5 | import jax.numpy as jnp 6 | from optax._src import base, numerics, utils, transform, combine 7 | from optax import tree_utils as otu 8 | 9 | 10 | def scale_by_adam( 11 | b1: float = 0.9, 12 | b2: float = 0.999, 13 | eps: float = 1e-8, 14 | eps_root: float = 0.0, 15 | mu_dtype: Optional[chex.ArrayDType] = None, 16 | nesterov: bool = False, 17 | ) -> base.GradientTransformation: 18 | """Same as optax version but doesn't create momentum buffer if b1 == 0.""" 19 | mu_dtype = utils.canonicalize_dtype(mu_dtype) 20 | 21 | def init_fn(params): 22 | if b1 > 0: 23 | mu = otu.tree_zeros_like(params, dtype=mu_dtype) 24 | 
else: 25 | mu = None 26 | nu = otu.tree_zeros_like(params) 27 | state = transform.ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) 28 | 29 | # Calculate sizes for nu (preconditioner) and mu (momentum) 30 | nu_n_elements = sum(leaf.size for leaf in jax.tree.leaves(nu)) 31 | nu_size_MB = sum( 32 | leaf.size * leaf.dtype.itemsize / (2**20) for leaf in jax.tree.leaves(nu) 33 | ) 34 | if jax.process_index() == 0: 35 | print( 36 | f"Adam Preconditioner (nu) size: {nu_n_elements} elements, {nu_size_MB:.2f} MB" 37 | ) 38 | if mu is not None: 39 | mu_n_elements = sum(leaf.size for leaf in jax.tree.leaves(mu)) 40 | mu_size_MB = sum( 41 | leaf.size * leaf.dtype.itemsize / (2**20) 42 | for leaf in jax.tree.leaves(mu) 43 | ) 44 | if jax.process_index() == 0: 45 | print( 46 | f"Adam Momentum (mu) size: {mu_n_elements} elements, {mu_size_MB:.2f} MB" 47 | ) 48 | 49 | return state 50 | 51 | def update_fn(updates, state, params=None): 52 | del params 53 | count_inc = numerics.safe_int32_increment(state.count) 54 | if b1 > 0: 55 | mu = otu.tree_update_moment(updates, state.mu, b1, 1) 56 | if nesterov: 57 | mu_hat = jax.tree.map( 58 | lambda m, g: b1 * m + (1 - b1) * g, 59 | otu.tree_bias_correction( 60 | mu, b1, numerics.safe_int32_increment(count_inc) 61 | ), 62 | otu.tree_bias_correction(updates, b1, count_inc), 63 | ) 64 | else: 65 | mu_hat = otu.tree_bias_correction(mu, b1, count_inc) 66 | mu = otu.tree_cast(mu, mu_dtype) 67 | else: 68 | mu = None 69 | mu_hat = updates 70 | nu = otu.tree_update_moment_per_elem_norm(updates, state.nu, b2, 2) 71 | nu_hat = otu.tree_bias_correction(nu, b2, count_inc) 72 | updates = jax.tree.map( 73 | lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat 74 | ) 75 | return updates, transform.ScaleByAdamState(count=count_inc, mu=mu, nu=nu) 76 | 77 | return base.GradientTransformation(init_fn, update_fn) 78 | 79 | 80 | def adamw( 81 | learning_rate: base.ScalarOrSchedule, 82 | b1: float = 0.9, 83 | b2: float = 0.999, 84 | eps: float = 1e-8, 85 | eps_root: float = 0.0, 86 | mu_dtype: Optional[Any] = None, 87 | weight_decay: float = 1e-4, 88 | mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, 89 | nesterov: bool = False, 90 | ) -> base.GradientTransformation: 91 | return combine.chain( 92 | scale_by_adam( 93 | b1=b1, 94 | b2=b2, 95 | eps=eps, 96 | eps_root=eps_root, 97 | mu_dtype=mu_dtype, 98 | nesterov=nesterov, 99 | ), 100 | transform.add_decayed_weights(weight_decay, mask), 101 | transform.scale_by_learning_rate(learning_rate), 102 | ) 103 | -------------------------------------------------------------------------------- /optimizers/schedule_free.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Schedule-Free wrapper for faster training & removes the need for lr decay.""" 16 | 17 | from typing import NamedTuple, Optional 18 | 19 | import chex 20 | import jax 21 | import jax.numpy as jnp 22 | from optax._src import alias 23 | from optax._src import base 24 | from optax._src import combine 25 | from optax._src import transform 26 | from optax.schedules import _schedule 27 | from optax.transforms import _adding 28 | 29 | 30 | class ScheduleFreeState(NamedTuple): 31 | b1: chex.Array 32 | weight_sum: chex.Array 33 | step_count: chex.Array 34 | max_lr: chex.Array 35 | base_optimizer_state: base.OptState 36 | z: base.Params 37 | 38 | 39 | def schedule_free_eval_params(state: ScheduleFreeState, params: base.Params): 40 | return jax.tree_util.tree_map( 41 | lambda yi, zi: (yi - (1.0 - state.b1) * zi) / state.b1, params, state.z 42 | ) 43 | 44 | 45 | def schedule_free( 46 | base_optimizer: base.GradientTransformation, 47 | learning_rate: base.ScalarOrSchedule, 48 | b1: float = 0.9, 49 | weight_lr_power: float = 2.0, 50 | state_dtype=jnp.float32, 51 | ) -> base.GradientTransformationExtraArgs: 52 | base_optimizer = base.with_extra_args_support(base_optimizer) 53 | 54 | def init_fn(params: base.Params) -> ScheduleFreeState: 55 | if b1 == 0: 56 | raise ValueError( 57 | "The current implementation of schedule_free requires b1 > 0." 58 | ) 59 | z = jax.tree_util.tree_map(lambda t: t.astype(state_dtype), params) 60 | 61 | # Calculate and print size for z 62 | z_n_elements = sum(leaf.size for leaf in jax.tree.leaves(z)) 63 | z_size_MB = sum( 64 | leaf.size * leaf.dtype.itemsize / (2**20) for leaf in jax.tree.leaves(z) 65 | ) 66 | if jax.process_index() == 0: 67 | print( 68 | f"Schedule-Free Z buffer size: {z_n_elements} elements, " 69 | f"{z_size_MB:.2f} MB" 70 | ) 71 | 72 | return ScheduleFreeState( 73 | b1=jnp.array(b1, dtype=jnp.float32), 74 | weight_sum=jnp.zeros([], dtype=jnp.float32), 75 | step_count=jnp.ones([], dtype=jnp.int32), 76 | max_lr=jnp.zeros([], dtype=jnp.float32), 77 | base_optimizer_state=base_optimizer.init(params), 78 | z=z, 79 | ) 80 | 81 | def update_fn( 82 | grads: base.Updates, 83 | state: ScheduleFreeState, 84 | params: Optional[base.Params] = None, 85 | **extra_args, 86 | ): 87 | lr = learning_rate 88 | if callable(learning_rate): 89 | lr = learning_rate(state.step_count) 90 | max_lr = jnp.maximum(state.max_lr, lr) 91 | 92 | next_step_count = state.step_count + 1 93 | 94 | weight = max_lr**weight_lr_power 95 | next_total_weight = state.weight_sum + weight 96 | # We add this to avoid NaNs in the case of a small learning rate. 97 | ck = jnp.where( 98 | jnp.logical_or(jnp.isnan(weight), jnp.isnan(next_total_weight)), 99 | jnp.full(weight.shape, jnp.nan), 100 | jnp.nan_to_num(weight / next_total_weight, nan=0.0, posinf=jnp.inf), 101 | ) 102 | 103 | base_updates, next_base_optimizer_state = base_optimizer.update( 104 | grads, state.base_optimizer_state, params, **extra_args 105 | ) 106 | z = jax.tree_util.tree_map( 107 | lambda pi, ui: jnp.asarray(pi + ui).astype(jnp.asarray(pi).dtype), 108 | state.z, 109 | base_updates, 110 | ) 111 | 112 | # Important: recompute x to both save memory and maintain accurate x seq 113 | # especially if y is modified by another transform wrapped on top. 
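# Params passed in here are y = b1 * x + (1 - b1) * z, so the previous x can be recovered exactly from (params, z) rather than being carried in the optimizer state.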
114 | prev_x = jax.tree_util.tree_map( 115 | lambda yi, zi: (yi - (1.0 - b1) * zi) / b1, params, state.z 116 | ) 117 | 118 | x = jax.tree_util.tree_map(lambda xi, zi: (1.0 - ck) * xi + ck * zi, prev_x, z) 119 | new_params = jax.tree_util.tree_map( 120 | lambda xi, zi: b1 * xi + (1.0 - b1) * zi, x, z 121 | ) 122 | updates = jax.tree_util.tree_map(lambda npi, pi: npi - pi, new_params, params) 123 | 124 | next_state = ScheduleFreeState( 125 | b1=jnp.array(b1, dtype=jnp.float32), 126 | weight_sum=next_total_weight, 127 | step_count=next_step_count, 128 | max_lr=max_lr, 129 | base_optimizer_state=next_base_optimizer_state, 130 | z=z, 131 | ) 132 | 133 | return updates, next_state 134 | 135 | return base.GradientTransformationExtraArgs(init_fn, update_fn) 136 | 137 | 138 | def schedule_free_sgd( 139 | learning_rate: float = 1.0, 140 | *, 141 | warmup_steps: int = 0, 142 | b1: float = 0.9, 143 | weight_decay: float = 0.0, 144 | weight_lr_power: float = 2.0, 145 | state_dtype=jnp.float32, 146 | ) -> base.GradientTransformationExtraArgs: 147 | if warmup_steps > 0: 148 | learning_rate = _schedule.warmup_constant_schedule( 149 | init_value=0, peak_value=learning_rate, warmup_steps=warmup_steps 150 | ) 151 | optimizer = alias.sgd(learning_rate) 152 | if weight_decay > 0: 153 | optimizer = combine.chain(_adding.add_decayed_weights(weight_decay), optimizer) 154 | return schedule_free( 155 | optimizer, 156 | learning_rate=learning_rate, 157 | b1=b1, 158 | weight_lr_power=weight_lr_power, 159 | state_dtype=state_dtype, 160 | ) 161 | 162 | 163 | def schedule_free_adamw( 164 | learning_rate: float = 0.0025, 165 | *, 166 | warmup_steps: int = 0, 167 | b1: float = 0.9, 168 | b2: float = 0.999, 169 | eps: float = 1e-8, 170 | weight_decay: float = 0.0, 171 | weight_lr_power: float = 2.0, 172 | state_dtype=jnp.float32, 173 | ) -> base.GradientTransformationExtraArgs: 174 | if warmup_steps > 0: 175 | learning_rate = _schedule.warmup_constant_schedule( 176 | init_value=0, peak_value=learning_rate, warmup_steps=warmup_steps 177 | ) 178 | # The following is the same as adamw, but with the momentum term removed. 179 | optimizer = combine.chain( 180 | transform.scale_by_rms( 181 | decay=b2, eps=eps, eps_in_sqrt=False, bias_correction=True 182 | ), 183 | _adding.add_decayed_weights(weight_decay), 184 | transform.scale_by_learning_rate(learning_rate), 185 | ) 186 | return schedule_free( 187 | optimizer, 188 | learning_rate=learning_rate, 189 | b1=b1, 190 | weight_lr_power=weight_lr_power, 191 | state_dtype=state_dtype, 192 | ) 193 | -------------------------------------------------------------------------------- /optimizers/tearfree/grafting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Grafting norm adjustment (https://openreview.net/forum?id=FpKgG31Z_i9).""" 16 | 17 | import copy 18 | import dataclasses 19 | import enum 20 | import functools 21 | from typing import Any, NamedTuple 22 | 23 | import chex 24 | from flax import struct 25 | import jax 26 | from jax import numpy as jnp 27 | import optax 28 | from optimizers.tearfree import praxis_shim 29 | 30 | 31 | @enum.unique 32 | class GraftingType(enum.Enum): 33 | """Different grafting types.""" 34 | 35 | NONE = "none" 36 | SGD = "sgd" 37 | RMSPROP = "rmsprop" 38 | ADAFACTOR = "adafactor" 39 | 40 | 41 | @dataclasses.dataclass 42 | class Options: 43 | """Grafting configuration to change norms for updates. 44 | 45 | A grafting update is computed as if it was running alongside the 46 | tearfree optimizer. Its norm is used for updates. During the initial 47 | few steps before preconditioning is applied, the grafting update 48 | is used entirely. 49 | 50 | Note that the grafting optimizer is ignorant of both weight decay, 51 | learning rate, and the momentum. 52 | 53 | Attributes: 54 | grafting_type: Which optimizer to use for grafting updates. 55 | second_moment_decay: Second moment accumulator decay. For ADA-Factor, this 56 | value must be bounded between (0, 1). For RMSProp, the second moment 57 | accumulator becomes sum if set to 1 (i.e., Adagrad), should be in (0, 1]. 58 | Must be 0 if unused (e.g., for SGD/NONE). 59 | start_preconditioning_step: When to start applying preconditioning. 60 | epsilon: Avoids divide by zero in RMSProp and ADA-Factor by adding this term 61 | to the expression `(epsilon + acc)^(-1/2)` when taking the inverse square 62 | root of the accumulator; should be non-negative. 63 | skip_preconditioning_any_dim_gt: Skip second-order preconditioning if any 64 | dimension of a tensor is greater than this value (only apply grafting 65 | update). Argument ignored if NONE grafting. 66 | skip_preconditioning_rank1: Skip preconditioning the tensor if the rank is 1 67 | or less. Argument ignored if NONE grafting. 68 | min_dim_size_to_factor: (Applies to ADA-Factor Only.) Only factor the 69 | statistics if two array dimensions have at least this size. 70 | multiply_by_parameter_scale: (Applies to ADA-Factor Only.) If True, then 71 | scale learning_rate by parameter norm. If False, provided learning_rate is 72 | absolute step size. 73 | clipping_threshold: (Applies to ADA-Factor Only.) Clipping 74 | threshold. Must be >= 1. 75 | """ 76 | 77 | grafting_type: GraftingType = GraftingType.RMSPROP 78 | second_moment_decay: float = 0.999 79 | start_preconditioning_step: int = 0 80 | epsilon: float = 1e-8 81 | skip_preconditioning_any_dim_gt: int = 8192 82 | skip_preconditioning_rank1: bool = True 83 | min_dim_size_to_factor: int = 128 84 | multiply_by_parameter_scale: float = True 85 | clipping_threshold: float = 1.0 86 | 87 | 88 | def graft( 89 | options: Options, direction: praxis_shim.ShardedGradientTransformation 90 | ) -> praxis_shim.ShardedGradientTransformation: 91 | """Generate the grafting update from options and direction update. 92 | 93 | Args: 94 | options: The grafting options. 95 | direction: A sharded gradient transformation which determines the direction 96 | of the update (grafting, if applied, changes the norm of this update). 97 | 98 | Returns: 99 | The wrapped transformation which applies either the grafting update 100 | directly or the `direction` update with grafting norm, depending on the 101 | current step or whether to apply grafting/preconditioning at all. 
102 | """ 103 | _validate(options) 104 | 105 | if options.grafting_type == GraftingType.NONE: 106 | return direction 107 | 108 | if options.grafting_type == GraftingType.SGD: 109 | return _graft_with(direction, _sgd(), options) 110 | 111 | if options.grafting_type == GraftingType.RMSPROP: 112 | return _graft_with(direction, _rmsprop(options), options) 113 | 114 | if options.grafting_type == GraftingType.ADAFACTOR: 115 | return _graft_with(direction, _adafactor(options), options) 116 | # check options for validity (SGD/none and no 2nd moment, appropriate range) 117 | # test to check sharded gradient transform is otherwise identical to 118 | # praxis' 119 | raise NotImplementedError 120 | 121 | 122 | def _validate(options: Options): 123 | """Raise ValueError if the options have an invalid specification.""" 124 | if options.grafting_type in [GraftingType.RMSPROP, GraftingType.ADAFACTOR]: 125 | if options.epsilon < 0: 126 | raise ValueError( 127 | "epsilon ({}) should be non-negative".format(options.epsilon) 128 | ) 129 | if options.grafting_type == GraftingType.RMSPROP: 130 | if not (0 < options.second_moment_decay <= 1.0): 131 | raise ValueError( 132 | "second_moment_decay ({}) not in (0, 1] for graft ({})".format( 133 | options.second_moment_decay, options.grafting_type 134 | ) 135 | ) 136 | if options.grafting_type == GraftingType.ADAFACTOR: 137 | if not (0 < options.second_moment_decay < 1.0): 138 | raise ValueError( 139 | "second_moment_decay ({}) not in (0, 1) for graft ({})".format( 140 | options.second_moment_decay, options.grafting_type 141 | ) 142 | ) 143 | if not (0 < options.min_dim_size_to_factor): 144 | raise ValueError( 145 | "min_dim_size_to_factor ({}) should be positive for graft ({})".format( 146 | options.min_dim_size_to_factor, options.grafting_type 147 | ) 148 | ) 149 | if options.clipping_threshold < 1: 150 | raise ValueError( 151 | "clipping_threshold ({}) should be >= 1 for graft ({})".format( 152 | options.clipping_threshold, options.grafting_type 153 | ) 154 | ) 155 | 156 | 157 | def _sgd() -> praxis_shim.ShardedGradientTransformation: 158 | """Create SGD sharded gradient transform.""" 159 | grad_transform = optax.identity() 160 | return praxis_shim.ShardedGradientTransformation( 161 | grad_transform.init, grad_transform.update, optax.EmptyState 162 | ) 163 | 164 | 165 | def _adafactor(options: Options) -> praxis_shim.ShardedGradientTransformation: 166 | """Create AdaFactor sharded gradient transform.""" 167 | tx = [ 168 | optax.adafactor( 169 | min_dim_size_to_factor=options.min_dim_size_to_factor, 170 | decay_rate=options.second_moment_decay, 171 | multiply_by_parameter_scale=options.multiply_by_parameter_scale, 172 | eps=options.epsilon, 173 | clipping_threshold=options.clipping_threshold, 174 | ) 175 | ] 176 | # Sign flip: optax.adafactor uses descent direction in updates. 177 | tx.append(optax.scale(-1)) 178 | grad_transform = optax.chain(*tx) 179 | 180 | def _adafactor_pspec_fn(params_unused): 181 | del params_unused 182 | raise NotImplementedError 183 | 184 | return praxis_shim.ShardedGradientTransformation( 185 | grad_transform.init, grad_transform.update, _adafactor_pspec_fn 186 | ) 187 | 188 | 189 | # Dummy wrapper for better state pretty printing, to identify what parameters 190 | # are for. 
191 | class RMSPropAccumulator(NamedTuple): 192 | """State holding the sum/ema of gradient squares so far.""" 193 | 194 | acc: optax.Updates 195 | 196 | 197 | def _rmsprop(options: Options) -> praxis_shim.ShardedGradientTransformation: 198 | """Create RMSProp sharded gradient transform.""" 199 | 200 | def init_fn(params): 201 | acc = jax.tree.map(jnp.zeros_like, params) 202 | return RMSPropAccumulator(acc=acc) 203 | 204 | def update_fn(updates, state, params=None): 205 | del params 206 | 207 | # CHANGED: rmsprop normalized from distributed shampoo 208 | # normalizes grads layer-wise 209 | update_norms = jax.tree.map(jnp.linalg.norm, updates) 210 | update_norms = jax.tree.map(lambda x: jnp.where(x > 0.0, x, 1.0), update_norms) 211 | updates = jax.tree.map(lambda g, n: g / n, updates, update_norms) 212 | 213 | def ema(prev, new): 214 | second_moment_decay = options.second_moment_decay 215 | snew = jnp.square(new) 216 | if second_moment_decay == 1.0: 217 | return snew + prev 218 | else: 219 | return snew * (1 - second_moment_decay) + second_moment_decay * prev 220 | 221 | new_state = RMSPropAccumulator(jax.tree.map(ema, state.acc, updates)) 222 | epsilon = options.epsilon 223 | new_updates = jax.tree.map( 224 | lambda g, acc: g * jax.lax.rsqrt(acc + epsilon), updates, new_state.acc 225 | ) 226 | return new_updates, new_state 227 | 228 | def init_partition_spec_fn(mdl_params): 229 | def _opt_state_sharding_spec(var_hparams): 230 | s_var_hparams = copy.deepcopy(var_hparams) 231 | s_var_hparams.init = None 232 | return s_var_hparams 233 | 234 | mdl_sharding = jax.tree.map(_opt_state_sharding_spec, mdl_params) 235 | return RMSPropAccumulator(acc=mdl_sharding) 236 | 237 | return praxis_shim.ShardedGradientTransformation( 238 | init=init_fn, update=update_fn, init_partition_spec=init_partition_spec_fn 239 | ) 240 | 241 | 242 | class GraftingState(NamedTuple): 243 | """State holding the count for grafting.""" 244 | 245 | count: jax.Array 246 | direction: optax.OptState 247 | norm: optax.OptState 248 | 249 | 250 | def _graft_with( 251 | direction: praxis_shim.ShardedGradientTransformation, 252 | norm: praxis_shim.ShardedGradientTransformation, 253 | options: Options, 254 | ) -> praxis_shim.ShardedGradientTransformation: 255 | """Created a maybe-grafted update from a base update and a graft one.""" 256 | 257 | start_preconditioning_step = options.start_preconditioning_step 258 | mask = functools.partial(_mask_skipped, options) 259 | 260 | def init_fn(params): 261 | return GraftingState( 262 | count=jnp.zeros([], jnp.int32), 263 | direction=direction.init(mask(params)), 264 | norm=norm.init(params), 265 | ) 266 | 267 | def update_fn(updates, state, params=None): 268 | base_updates, base_state = direction.update( 269 | mask(updates), state.direction, mask(params) 270 | ) 271 | graft_updates, graft_state = norm.update(updates, state.norm, params) 272 | new_state = GraftingState( 273 | count=state.count + 1, direction=base_state, norm=graft_state 274 | ) 275 | 276 | def maybe_graft(graft_upd, base): 277 | if _masked(base): 278 | return graft_upd 279 | assert graft_upd.shape == base.shape 280 | 281 | base_norm = jnp.linalg.norm(base) 282 | multiplier = jnp.where( 283 | base_norm > 0.0, jnp.linalg.norm(graft_upd) / base_norm, 0.0 284 | ) 285 | return jnp.where( 286 | state.count >= start_preconditioning_step, base * multiplier, graft_upd 287 | ) 288 | 289 | new_updates = jax.tree.map( 290 | maybe_graft, graft_updates, base_updates, is_leaf=_masked 291 | ) 292 | return new_updates, new_state 293 | 294 | def 
init_partition_spec_fn(mdl_params): 295 | count_pspec = praxis_shim.WeightHParams( 296 | shape=[], 297 | init=None, 298 | dtype=jnp.int32, 299 | collections=None, 300 | tensor_split_dims_mapping=[], 301 | ) 302 | return dict( 303 | count=count_pspec, 304 | direction=direction.init_partition_spec(mdl_params), 305 | norm=norm.init_partition_spec(mdl_params), 306 | ) 307 | 308 | return praxis_shim.ShardedGradientTransformation( 309 | init_fn, update_fn, init_partition_spec_fn 310 | ) 311 | 312 | 313 | @struct.dataclass 314 | class _GraftMask: 315 | """Helper tuple which masks out params before preconditioning.""" 316 | 317 | pass 318 | 319 | 320 | def _mask_skipped(options: Options, tree: chex.ArrayTree) -> chex.ArrayTree: 321 | """Masks out arrays to which preconditioning should not be applied.""" 322 | 323 | def _maybe_mask(x: jax.Array): 324 | if options.skip_preconditioning_rank1 and x.ndim <= 1: 325 | return _GraftMask() 326 | if any(s > options.skip_preconditioning_any_dim_gt for s in x.shape): 327 | return _GraftMask() 328 | return x 329 | 330 | return jax.tree.map(_maybe_mask, tree) 331 | 332 | 333 | def _masked(tree_node: Any) -> bool: 334 | """Returns whether a tree node has been masked out.""" 335 | return isinstance(tree_node, _GraftMask) 336 | -------------------------------------------------------------------------------- /optimizers/tearfree/grafting_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for grafting implementations.""" 16 | 17 | import functools 18 | import itertools 19 | from typing import Sequence 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | import jax 24 | from jax import numpy as jnp 25 | import numpy as np 26 | import optax 27 | from optimizers.tearfree import grafting 28 | from optimizers.tearfree import praxis_shim 29 | 30 | 31 | def _minustwo() -> praxis_shim.ShardedGradientTransformation: 32 | """Generate a direction-reversing gradient transformation.""" 33 | update = functools.partial(jax.tree.map, lambda x: -2 * x) 34 | return praxis_shim.ShardedGradientTransformation( 35 | lambda _: optax.EmptyState, lambda u, s, _: (update(u), s), optax.EmptyState 36 | ) 37 | 38 | 39 | def _make_invalid_cases() -> Sequence[dict[str, ...]]: 40 | """Generate invalid cases which should throw.""" 41 | return [ 42 | { 43 | "testcase_name": "rmsprop_0", 44 | "invalid_options": grafting.Options( 45 | grafting.GraftingType.RMSPROP, 46 | second_moment_decay=0.0, 47 | start_preconditioning_step=0, 48 | ), 49 | }, 50 | { 51 | "testcase_name": "rmsprop_neg", 52 | "invalid_options": grafting.Options( 53 | grafting.GraftingType.RMSPROP, 54 | second_moment_decay=-1.0, 55 | start_preconditioning_step=0, 56 | ), 57 | }, 58 | { 59 | "testcase_name": "rmsprop_eps_neg", 60 | "invalid_options": grafting.Options( 61 | grafting.GraftingType.RMSPROP, 62 | epsilon=-1.0, 63 | start_preconditioning_step=0, 64 | ), 65 | }, 66 | { 67 | "testcase_name": "adafactor_0", 68 | "invalid_options": grafting.Options( 69 | grafting.GraftingType.ADAFACTOR, 70 | second_moment_decay=-1.0, 71 | start_preconditioning_step=0, 72 | ), 73 | }, 74 | { 75 | "testcase_name": "adafactor_neg", 76 | "invalid_options": grafting.Options( 77 | grafting.GraftingType.ADAFACTOR, 78 | second_moment_decay=-1.0, 79 | start_preconditioning_step=0, 80 | ), 81 | }, 82 | { 83 | "testcase_name": "adafactor_not_less_than_1", 84 | "invalid_options": grafting.Options( 85 | grafting.GraftingType.ADAFACTOR, 86 | second_moment_decay=1.0, 87 | start_preconditioning_step=0, 88 | ), 89 | }, 90 | { 91 | "testcase_name": "adafactor_eps_neg", 92 | "invalid_options": grafting.Options( 93 | grafting.GraftingType.ADAFACTOR, 94 | epsilon=-1.0, 95 | start_preconditioning_step=0, 96 | ), 97 | }, 98 | { 99 | "testcase_name": "adafactor_min_size_0", 100 | "invalid_options": grafting.Options( 101 | grafting.GraftingType.ADAFACTOR, 102 | min_dim_size_to_factor=0, 103 | start_preconditioning_step=0, 104 | ), 105 | }, 106 | { 107 | "testcase_name": "adafactor_min_size_neg", 108 | "invalid_options": grafting.Options( 109 | grafting.GraftingType.ADAFACTOR, 110 | min_dim_size_to_factor=-1, 111 | start_preconditioning_step=0, 112 | ), 113 | }, 114 | { 115 | "testcase_name": "adafactor_clip_less_than_1", 116 | "invalid_options": grafting.Options( 117 | grafting.GraftingType.ADAFACTOR, 118 | clipping_threshold=0.5, 119 | start_preconditioning_step=0, 120 | ), 121 | }, 122 | ] 123 | 124 | 125 | class GraftingTest(parameterized.TestCase): 126 | """Basic test for grafting praxis_shim implementations.""" 127 | 128 | def _check_equal(self, expected_tx, actual_tx, nsteps, shape=(3,)): 129 | rng = jax.random.PRNGKey(0) 130 | rng, key = jax.random.split(rng) 131 | params = jax.random.normal(key, shape) 132 | expected_state = expected_tx.init(params) 133 | actual_state = actual_tx.init(params) 134 | 135 | for i in range(nsteps): 136 | rng, key = jax.random.split(rng) 137 | grad = jax.random.normal(key, shape) 138 
| expected_grad, expected_state = expected_tx.update( 139 | grad, expected_state, params 140 | ) 141 | actual_grad, actual_state = actual_tx.update(grad, actual_state, params) 142 | np.testing.assert_allclose(expected_grad, actual_grad, err_msg=i) 143 | 144 | def test_no_graft(self): 145 | """Check that no graft behaves exactly as the base transform.""" 146 | options = grafting.Options( 147 | grafting.GraftingType.NONE, 148 | 0.0, 149 | start_preconditioning_step=0, 150 | skip_preconditioning_rank1=False, 151 | ) 152 | grafted = grafting.graft(options, _minustwo()) 153 | nsteps = 4 154 | self._check_equal(_minustwo(), grafted, nsteps) 155 | 156 | def _check_norm_direction( 157 | self, norm_tx, direction_tx, actual_tx, nsteps, start_precond_step, shape=(3,) 158 | ): 159 | rng = jax.random.PRNGKey(0) 160 | rng, key = jax.random.split(rng) 161 | params = jax.random.normal(key, shape) 162 | state = actual_tx.init(params) 163 | norm_state = norm_tx.init(params) 164 | direction_state = norm_tx.init(params) 165 | 166 | for i in range(nsteps): 167 | rng, key = jax.random.split(rng) 168 | grad = jax.random.normal(key, shape) 169 | actual_grad, state = actual_tx.update(grad, state, params) 170 | 171 | norm_grad, norm_state = norm_tx.update(grad, norm_state, params) 172 | direction_grad, direction_state = direction_tx.update( 173 | grad, direction_state, params 174 | ) 175 | 176 | if i >= start_precond_step: 177 | direction_norm = jnp.linalg.norm(direction_grad) 178 | actual_norm = jnp.linalg.norm(actual_grad) 179 | norm_norm = jnp.linalg.norm(norm_grad) 180 | direction_grad_unit = direction_grad / direction_norm 181 | actual_grad_unit = actual_grad / actual_norm 182 | np.testing.assert_allclose( 183 | direction_grad_unit, actual_grad_unit, rtol=1e-6 184 | ) 185 | np.testing.assert_allclose(actual_norm, norm_norm, rtol=1e-6) 186 | else: 187 | np.testing.assert_allclose(norm_grad, actual_grad) 188 | 189 | def _norm_tx(self, options): 190 | if options.grafting_type == grafting.GraftingType.SGD: 191 | return grafting._sgd() 192 | if options.grafting_type == grafting.GraftingType.RMSPROP: 193 | return grafting._rmsprop(options) 194 | if options.grafting_type == grafting.GraftingType.ADAFACTOR: 195 | return grafting._adafactor(options) 196 | raise ValueError("unsupported grafting type " + str(options.grafting_type)) 197 | 198 | @parameterized.parameters( 199 | itertools.product([0, 1, 2], ["sgd", "rmsprop", "adafactor"], [(3,), (3, 2)]) 200 | ) 201 | def test_norm_direction(self, step, graft, shape): 202 | """Validate initial graft update, then switch to its norm.""" 203 | options = grafting.Options( 204 | grafting.GraftingType(graft), 205 | 0.9 if (graft == "rmsprop" or graft == "adafactor") else 0.0, 206 | start_preconditioning_step=step, 207 | skip_preconditioning_rank1=len(shape) > 1, 208 | min_dim_size_to_factor=1, 209 | ) 210 | grafted = grafting.graft(options, _minustwo()) 211 | nsteps = 4 212 | norm_tx = self._norm_tx(options) 213 | self._check_norm_direction(norm_tx, _minustwo(), grafted, nsteps, step, shape) 214 | 215 | @parameterized.parameters({"shape": s} for s in [tuple(), (3,), (5,), (5, 2)]) 216 | def test_skip(self, shape): 217 | """Make sure we skip preconditioning if out-of-bounds.""" 218 | options = grafting.Options( 219 | start_preconditioning_step=2, 220 | skip_preconditioning_any_dim_gt=4, 221 | skip_preconditioning_rank1=True, 222 | ) 223 | grafted = grafting.graft(options, _minustwo()) 224 | nsteps = 4 225 | norm_tx = self._norm_tx(options) 226 | self._check_equal(norm_tx, 
grafted, nsteps, shape) 227 | 228 | @parameterized.named_parameters(_make_invalid_cases()) 229 | def test_invalid(self, invalid_options): 230 | with self.assertRaises(ValueError): 231 | grafting.graft(invalid_options, _minustwo()) 232 | 233 | 234 | if __name__ == "__main__": 235 | absltest.main() 236 | -------------------------------------------------------------------------------- /optimizers/tearfree/momentum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Momentum configuration and transform.""" 16 | 17 | import copy 18 | import dataclasses 19 | from typing import Any, NamedTuple, Union, Optional 20 | 21 | import jax 22 | import jax.tree_util as jtu 23 | import optax 24 | from optax._src import base, utils 25 | import optax.tree_utils as otu 26 | from optimizers.tearfree import praxis_shim 27 | 28 | 29 | @dataclasses.dataclass 30 | class Options: 31 | """Configuration dataclass for momentum. 32 | 33 | Notably, this class contains weight decay parameters. Why? 34 | 35 | In classical convex literature, Nesterov acceleration applied to gradient 36 | descent can be viewed as "revising" the last iterate's momentum based on 37 | the gradient we observe immediately after taking a momentum "gamble" 38 | (see viz, https://stats.stackexchange.com/a/191727). 39 | 40 | To maintain this interpretation exactly, we would need to go against 41 | the grain on how weight decay is implemented. Momentum must be the last* 42 | gradient transformation applied to the iterate, which would require the 43 | weight decay to be applied to the update before it's used to change 44 | the velocity (momentum's state, the first moment). 45 | 46 | In particular, AdamW and Adafactor suggest direct weight downscaling, 47 | excluding weight decay from the velocity accumulation. 48 | 49 | As a result, the true meaning of Nesterov acceleration here is better 50 | understood literally, described in its parameter doc. 51 | 52 | *Technically, some optimizers include the learning rate in the update used to 53 | update the velocity (e.g., Adafactor), but others apply the learning rate 54 | scaling last, after momentum (e.g., Adam). We can recover the former from the 55 | latter by dividing the decay by the root of the learning rate, so this 56 | particular "gradient transformation" shouldn't be viewed as affecting 57 | the Nesterov interpretation, up to tuning constants. 58 | 59 | Attributs: 60 | ema: If true, momentum is computed as an exponential moving 61 | average: `velocity(t+1) = decay * velocity(t) + (1 - decay) * update(t)` 62 | If false, then uses "trace" accumulation for momentum: 63 | `velocity(t+1) = decay * velocity(t) + update(t)`. Note that if the 64 | updates were the same (they aren't) then these would be the same up to a 65 | factor of `(1 - decay)`. This corresponds to distributed_shampoo argument 66 | `moving_average_for_momentum`. 
67 | nesterov: Toggle for Nesterov acceleration. If false, then the new 68 | update `update'(t+1)` simply equals `velocity(t+1)`. If true, then 69 | `update'(t+1) = maybe_decay * update(t) + decay * velocity(t+1)`, where 70 | `maybe_decay` is `(1 - decay)` if `ema` and 1 otherwise. 71 | momentum_decay: The decay referred to in `ema` and `nesterov` formulas. 72 | weight_decay: Add `weight_decay * x(t)` to the `update(t)` value, where 73 | `x(t)` is the value of the current parameters. 74 | weight_decay_after_momentum: Whether weight decay addition is performed 75 | after the momentum transformation. 76 | momentum_dtype: str, `float32` or `bfloat16`, dtype of momentum buffer. 77 | """ 78 | 79 | ema: bool = True 80 | nesterov: bool = False 81 | momentum_decay: Optional[float] = 0.9 82 | weight_decay: float = 1e-4 83 | weight_decay_after_momentum: bool = True 84 | momentum_dtype: str = "float32" 85 | 86 | 87 | State = Union[optax.MaskedNode, optax.TraceState] 88 | 89 | 90 | def apply(options: Options) -> praxis_shim.ShardedGradientTransformation: 91 | """Generate the momentum update from options.""" 92 | _validate(options) 93 | 94 | momentum_transforms = [] 95 | if options.momentum_decay: 96 | if options.ema: 97 | momentum_transforms.append(optax.scale(1 - options.momentum_decay)) 98 | momentum_transforms.append( 99 | _sharded_trace( 100 | options.momentum_decay, options.nesterov, options.momentum_dtype 101 | ) 102 | ) 103 | 104 | wd_transforms = [optax.add_decayed_weights(options.weight_decay)] * ( 105 | options.weight_decay > 0.0 106 | ) 107 | 108 | if options.weight_decay_after_momentum: 109 | transforms = momentum_transforms + wd_transforms 110 | else: 111 | transforms = wd_transforms + momentum_transforms 112 | 113 | return praxis_shim.sharded_chain(*transforms) 114 | 115 | 116 | def _validate(options: Options): 117 | """Raise ValueError if options are invalid.""" 118 | if options.momentum_decay is not None: 119 | if not (0 <= options.momentum_decay <= 1): 120 | raise ValueError( 121 | "momentum_decay ({}) must be in [0, 1]".format(options.momentum_decay) 122 | ) 123 | 124 | if not (options.weight_decay >= 0): 125 | raise ValueError("weight_decay ({}) must be >= 0".format(options.weight_decay)) 126 | 127 | 128 | def _sharded_trace( 129 | momentum: float, nesterov: bool, accumulator_dtype: str 130 | ) -> praxis_shim.ShardedGradientTransformation: 131 | """Extend optax's trace to allow sharding.""" 132 | trace_transform = trace(momentum, nesterov, accumulator_dtype=accumulator_dtype) 133 | 134 | def init_pspec_fn(mdl_params): 135 | def _opt_state_sharding_spec(var_hparams): 136 | s_var_hparams = copy.deepcopy(var_hparams) 137 | s_var_hparams.init = None 138 | return s_var_hparams 139 | 140 | mdl_sharding = jax.tree.map(_opt_state_sharding_spec, mdl_params) 141 | return TraceState(trace=mdl_sharding) 142 | 143 | return praxis_shim.ShardedGradientTransformation( 144 | trace_transform.init, trace_transform.update, init_pspec_fn 145 | ) 146 | 147 | 148 | class TraceState(NamedTuple): 149 | """Holds an aggregation of past updates.""" 150 | 151 | trace: base.Params 152 | 153 | 154 | def trace( 155 | decay: float, nesterov: bool = False, accumulator_dtype: Optional[Any] = None 156 | ) -> base.GradientTransformation: 157 | """Compute a trace of past updates. 158 | 159 | Note: `trace` and `ema` have very similar but distinct updates; 160 | `trace = decay * trace + t`, while `ema = decay * ema + (1-decay) * t`. 161 | Both are frequently found in the optimization literature. 
162 | 163 | Args: 164 | decay: Decay rate for the trace of past updates. 165 | nesterov: Whether to use Nesterov momentum. 166 | accumulator_dtype: Optional `dtype` to be used for the accumulator; if 167 | `None` then the `dtype` is inferred from `params` and `updates`. 168 | 169 | Returns: 170 | A `GradientTransformation` object. 171 | """ 172 | 173 | accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) 174 | 175 | def init_fn(params): 176 | trace = otu.tree_zeros_like(params, dtype=accumulator_dtype) 177 | 178 | # Calculate and print size for trace 179 | trace_n_elements = sum(leaf.size for leaf in jax.tree.leaves(trace)) 180 | trace_size_MB = sum( 181 | leaf.size * leaf.dtype.itemsize / (2**20) for leaf in jax.tree.leaves(trace) 182 | ) 183 | if jax.process_index() == 0: 184 | print(f"Momentum size: {trace_n_elements} elements, {trace_size_MB:.2f} MB") 185 | 186 | return TraceState(trace=trace) 187 | 188 | def update_fn(updates, state, params=None): 189 | del params 190 | f = lambda g, t: g + decay * t 191 | new_trace = jtu.tree_map(f, updates, state.trace) 192 | updates = jtu.tree_map(f, updates, new_trace) if nesterov else new_trace 193 | new_trace = otu.tree_cast(new_trace, accumulator_dtype) 194 | return updates, TraceState(trace=new_trace) 195 | 196 | return base.GradientTransformation(init_fn, update_fn) 197 | -------------------------------------------------------------------------------- /optimizers/tearfree/momentum_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
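The `trace` vs. EMA distinction documented in `momentum.py` above is easiest to see numerically. The following is a minimal standalone sketch (not part of the repo; the decay value and constant updates are illustrative), mirroring the two formulas from the `trace()` docstring:

```python
# Illustrative comparison of "trace" vs. EMA momentum accumulation,
# following the formulas documented in momentum.trace().
import jax.numpy as jnp

decay = 0.9
updates = [jnp.array(1.0)] * 3  # three identical unit updates

trace_acc = jnp.array(0.0)
ema_acc = jnp.array(0.0)
for u in updates:
    trace_acc = decay * trace_acc + u            # trace: grows toward u / (1 - decay) for constant updates
    ema_acc = decay * ema_acc + (1 - decay) * u  # ema: stays on the scale of u

print(trace_acc)  # ~2.71
print(ema_acc)    # ~0.271, i.e. (1 - decay) * trace_acc when every update is identical
```

For identical updates the two accumulators differ exactly by the `(1 - decay)` factor, as the docstring notes; for real (varying) gradients they do not.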
14 | 15 | """Tests for momentum implementation.""" 16 | 17 | import itertools 18 | from typing import Sequence 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | import jax 23 | from jax import numpy as jnp 24 | import numpy as np 25 | import optax 26 | from optimizers.tearfree import momentum 27 | 28 | 29 | def _make_no_state_cases() -> Sequence[dict[str, ...]]: 30 | bools = [False, True] 31 | cases = [] 32 | for ema, nesterov, wd, wd_after in itertools.product( 33 | bools, bools, [0.0, 0.9], bools 34 | ): 35 | momentum_decay = 0.0 36 | options = momentum.Options(ema, nesterov, momentum_decay, wd, wd_after) 37 | cases.append({"options": options}) 38 | return cases 39 | 40 | 41 | def _make_invalid_cases() -> Sequence[dict[str, ...]]: 42 | """Generate invalid cases which should throw.""" 43 | return [ 44 | { 45 | "testcase_name": "momentum_neg", 46 | "invalid_options": momentum.Options(momentum_decay=-1.0), 47 | }, 48 | { 49 | "testcase_name": "wd_neg", 50 | "invalid_options": momentum.Options(weight_decay=-0.1), 51 | }, 52 | { 53 | "testcase_name": "momentum_large", 54 | "invalid_options": momentum.Options(momentum_decay=1.1), 55 | }, 56 | ] 57 | 58 | 59 | class MomentumTest(parameterized.TestCase): 60 | """Basic test for momentum implementation.""" 61 | 62 | def _unroll(self, tx, n, extract=False, wd=0): 63 | """Generate states and grad updates n times.""" 64 | rng = jax.random.PRNGKey(0) 65 | params = jnp.ones((3,)) 66 | grads = jax.random.normal(rng, (n, 3)) + wd * params 67 | init = tx.init(params) 68 | 69 | def scan(state, grad): 70 | new_grad, new_state = tx.update(grad, state, params) 71 | return new_state, (new_state, new_grad) 72 | 73 | _, (states, out_grad) = jax.lax.scan(scan, init, grads) 74 | if not extract: 75 | return out_grad 76 | return self._extract_velocity(states), out_grad, grads 77 | 78 | def _check_equal(self, expected_tx, actual_tx, nsteps): 79 | expected_grads = self._unroll(expected_tx, nsteps) 80 | actual_grads = self._unroll(actual_tx, nsteps) 81 | np.testing.assert_allclose(expected_grads, actual_grads) 82 | 83 | @parameterized.parameters(0.1, 0.9, 0.99) 84 | def test_ema(self, decay): 85 | """Check that we simulate ema decay.""" 86 | options = momentum.Options(ema=True, nesterov=False, momentum_decay=decay) 87 | nsteps = 4 88 | actual = momentum.apply(options) 89 | expected = optax.ema(decay, debias=False) 90 | self._check_equal(expected, actual, nsteps) 91 | 92 | def _extract_velocity(self, state): 93 | """Asserts only velocity state exists, extracts it.""" 94 | flat = jax.tree_util.tree_flatten(state)[0] 95 | self.assertLen(flat, 1) 96 | return flat[0] 97 | 98 | @parameterized.parameters(itertools.product([False, True], repeat=2)) 99 | def test_wd_before_momentum(self, ema, nesterov): 100 | options = momentum.Options( 101 | ema=ema, nesterov=nesterov, momentum_decay=0.9, weight_decay=0.0 102 | ) 103 | nsteps = 4 104 | tx = momentum.apply(options) 105 | expected_grads = self._unroll(tx, nsteps, wd=0.1) 106 | options = momentum.Options( 107 | ema=ema, 108 | nesterov=nesterov, 109 | momentum_decay=0.9, 110 | weight_decay=0.1, 111 | weight_decay_after_momentum=False, 112 | ) 113 | tx = momentum.apply(options) 114 | actual_grads = self._unroll(tx, nsteps) 115 | np.testing.assert_allclose(expected_grads, actual_grads) 116 | 117 | @parameterized.parameters(itertools.product([False, True], repeat=2)) 118 | def test_basic(self, ema, decay_after): 119 | wd = 0.1 if decay_after else 0.0 120 | if decay_after: 121 | return 122 | 
decay = 0.9 123 | options = momentum.Options( 124 | ema=ema, 125 | nesterov=True, 126 | momentum_decay=decay, 127 | weight_decay=wd, 128 | weight_decay_after_momentum=True, 129 | ) 130 | tx = momentum.apply(options) 131 | v, g, ig = self._unroll(tx, 2, extract=True) 132 | 133 | ev = jnp.zeros((3,)) 134 | factor = (1 - decay) if ema else 1.0 135 | ev += factor * ig[0] 136 | self.assertSequenceAlmostEqual(v[0], ev, msg=v) 137 | expected_grad = decay * ev + factor * ig[0] 138 | expected_grad += jnp.ones((3,)) * wd 139 | self.assertSequenceAlmostEqual(g[0], expected_grad) 140 | 141 | ev = ev * decay + factor * ig[1] 142 | self.assertSequenceAlmostEqual(v[1], ev, delta=1e-6) 143 | expected_grad = decay * ev + factor * ig[1] 144 | expected_grad += jnp.ones((3,)) * wd 145 | self.assertSequenceAlmostEqual(g[1], expected_grad, delta=1e-6) 146 | 147 | @parameterized.parameters(_make_no_state_cases()) 148 | def test_no_state(self, options): 149 | """Ensure no state is created when decay is 0.0.""" 150 | assert options.momentum_decay == 0.0 151 | tx = momentum.apply(options) 152 | state = tx.init(jnp.zeros((3,))) 153 | flat = jax.tree_util.tree_flatten(state)[0] 154 | self.assertEmpty(flat) 155 | 156 | @parameterized.named_parameters(_make_invalid_cases()) 157 | def test_invalid(self, invalid_options): 158 | with self.assertRaises(ValueError): 159 | momentum.apply(invalid_options) 160 | 161 | 162 | if __name__ == "__main__": 163 | absltest.main() 164 | -------------------------------------------------------------------------------- /optimizers/tearfree/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tearfree optimizer implementation. 16 | 17 | OOM making your eyes water? Try the Tearfree Shampoo optimizer. 18 | 19 | This module handles logic for 20 | 21 | 1. Statistics/preconditioner update frequency 22 | 2. Applying momentum 23 | 3. Combining grafting and preconditioning updates, applying grafting 24 | 4. Typical update procedures, like learning rate, momentum, etc. 25 | """ 26 | 27 | import dataclasses 28 | from typing import Union 29 | 30 | import chex 31 | import optax 32 | from optimizers.tearfree import grafting 33 | from optimizers.tearfree import momentum 34 | from optimizers.tearfree import praxis_shim 35 | from optimizers.tearfree import second_order 36 | 37 | 38 | @dataclasses.dataclass 39 | class TearfreeOptions: 40 | """Configuration dataclass for tearfree optimizer. 41 | 42 | Attributes: 43 | grafting_options: Grafting options to modify update norm (see 44 | `grafting.Options`). 45 | second_order_options: Second-order statistics tracking options (see 46 | `second_order.Options`). 47 | momentum_options: Momentum options (see `momentum.Options`). 
48 | """ 49 | 50 | grafting_options: grafting.Options = dataclasses.field( 51 | default_factory=grafting.Options 52 | ) 53 | second_order_options: second_order.Options = dataclasses.field( 54 | default_factory=second_order.Options 55 | ) 56 | momentum_options: momentum.Options = dataclasses.field( 57 | default_factory=momentum.Options 58 | ) 59 | 60 | 61 | def tearfree( 62 | learning_rate: Union[chex.Numeric, optax.Schedule], options: TearfreeOptions 63 | ) -> praxis_shim.ShardedGradientTransformation: 64 | """Tearfree optimizer, supports pjit and jit. 65 | 66 | Preconditioned, grafted updates with momentum. 67 | 68 | One key difference in the logic is to only use a single momentum between 69 | the graft and preconditioned update. `distributed_shampoo` keeps a separate 70 | `diagonal_momentum` buffer, but never uses it after preconditioning is 71 | active (it is not used to adjust the grafting norm). This implies (1) 72 | we save memory (only one momentum buffer), (2) we are identical to 73 | `distributed_shampoo` if there is no warmup or no preconditioning 74 | (`options.start_preconditioning_step` is inf or 0). 75 | 76 | Args: 77 | learning_rate: The learning rate value or schedule. Learning rate is 78 | "decoupled", i.e., we always apply it last to the update (after weight 79 | decay, after momentum, etc.). 80 | options: Tearfree optimizer options. 81 | 82 | Returns: 83 | The sharded gradient transformation corresponding to an updated, 84 | preconditioned gradient, times the negative learning rate. 85 | """ 86 | 87 | second_order_tx = second_order.apply(options.second_order_options) 88 | graft_tx = grafting.graft(options.grafting_options, second_order_tx) 89 | momentum_tx = momentum.apply(options.momentum_options) 90 | if callable(learning_rate): 91 | lr_tx = optax.scale_by_schedule(lambda x: -1.0 * learning_rate(x)) 92 | else: 93 | lr_tx = optax.scale(-1.0 * learning_rate) 94 | return praxis_shim.sharded_chain(graft_tx, momentum_tx, lr_tx) 95 | -------------------------------------------------------------------------------- /optimizers/tearfree/optimizer_smoke_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Smoke tests for tearfree. 16 | 17 | The smoke test uses CPU-based sharding to verify that, under a variety of 18 | settings, (1) the optimizer results in finite, not-nan gradients and (2) 19 | distributed computation options don't change the math. 
20 | """ 21 | 22 | import copy 23 | from typing import Sequence, Union 24 | 25 | from absl.testing import absltest 26 | from absl.testing import parameterized 27 | import chex 28 | import jax 29 | from jax import numpy as jnp 30 | import numpy as np 31 | import optax 32 | from optimizers.tearfree import grafting 33 | from optimizers.tearfree import momentum 34 | from optimizers.tearfree import optimizer 35 | from optimizers.tearfree import second_order 36 | from optimizers.tearfree import shampoo 37 | from optimizers.tearfree import sketchy 38 | 39 | 40 | def _make_distributed_equality_cases() -> list[dict[str, ...]]: 41 | """Make test cases of options for optimizer checks.""" 42 | cases = [] 43 | 44 | # Basic options exercise all of shampoo, grafting after the first step. 45 | basic_options = optimizer.TearfreeOptions( 46 | grafting_options=grafting.Options( 47 | grafting_type=grafting.GraftingType.RMSPROP, 48 | second_moment_decay=0.9, 49 | epsilon=1e-5, 50 | start_preconditioning_step=1, 51 | skip_preconditioning_any_dim_gt=4096, 52 | skip_preconditioning_rank1=False, 53 | ), 54 | second_order_options=second_order.Options( 55 | second_order_type=second_order.SecondOrderType.SHAMPOO, 56 | shampoo_options=shampoo.Options( 57 | block_size=1024, 58 | update_preconditioners_freq=1, 59 | update_statistics_freq=1, 60 | second_moment_decay=0.9, 61 | ), 62 | merge_dims=4096, 63 | ), 64 | momentum_options=momentum.Options( 65 | ema=True, 66 | nesterov=True, 67 | momentum_decay=0.5, 68 | weight_decay=0.0, 69 | weight_decay_after_momentum=True, 70 | ), 71 | ) 72 | 73 | basic_case = { 74 | "testcase_name": "basic", 75 | "nsteps": 3, 76 | "options": basic_options, 77 | "lr": 0.1, 78 | "shape": (4,), 79 | } 80 | cases.append(basic_case) 81 | 82 | case = copy.deepcopy(basic_case) 83 | case["lr"] = lambda x: 0.1 / (x + 1) 84 | case["testcase_name"] = "schedule" 85 | cases.append(case) 86 | 87 | case = copy.deepcopy(basic_case) 88 | second_order_options = case["options"].second_order_options 89 | second_order_options.second_order_type = second_order.SecondOrderType.SKETCHY 90 | second_order_options.shampoo_options = None 91 | second_order_options.sketchy_options = sketchy.Options() 92 | case["testcase_name"] = "sketchy" 93 | cases.append(case) 94 | 95 | case = copy.deepcopy(case) 96 | case["testcase_name"] += "_notrunc_lowrank" 97 | sketchy_options = case["options"].second_order_options.sketchy_options 98 | sketchy_options.truncate_numerical_noise = False 99 | sketchy_options.rank = 2 100 | cases.append(case) 101 | 102 | case = copy.deepcopy(basic_case) 103 | case["options"].grafting_options.grafting_type = grafting.GraftingType.ADAFACTOR 104 | case["testcase_name"] = "adafactor" 105 | cases.append(case) 106 | 107 | # Need to test we at least parallelize the identical-to-tensor shapes 108 | # without any blocks. 
109 | # Additional variants: 110 | # wd 111 | # wd with decay before momentum 112 | # grid of nesterov/ema 113 | # exercise merge dims 2d doing a merge 114 | # exercise merge dims 3d with only one thing merged 115 | # skip preconditioning any dim gt activating 116 | # skip preconditioning any dim gt rank1 activating 117 | # update stats/precond every 2 (6 steps) 118 | # update stats/precond every 2/4 (6 steps) 119 | 120 | # Test block-wise parallelism for Shampoo 121 | 122 | return cases 123 | 124 | 125 | class OptimizerSmokeTest(parameterized.TestCase): 126 | """Basic test for optimizer configurations.""" 127 | 128 | def _unroll(self, options, shape, transform=None, lr=0.1, n=4): 129 | """Generate states and grad updates n times.""" 130 | rng = jax.random.PRNGKey(0) 131 | params = jnp.zeros(shape) 132 | grads = jax.random.normal(rng, (n, *shape)) 133 | 134 | if transform is not None: 135 | params = transform(params) 136 | grads = jnp.stack([transform(g) for g in grads]) 137 | 138 | tx = optimizer.tearfree(lr, options) 139 | 140 | init = tx.init(params) 141 | 142 | def reduce(state, grad): 143 | new_grad, new_state = tx.update(grad, state, params) 144 | return new_state, new_grad 145 | 146 | _, out_grads = jax.lax.scan(reduce, init, grads) 147 | return out_grads 148 | 149 | @parameterized.named_parameters(_make_distributed_equality_cases()) 150 | def test_distributed_equality( 151 | self, 152 | options: optimizer.TearfreeOptions, 153 | shape: Sequence[int], 154 | lr: Union[float, optax.Schedule], 155 | nsteps: int, 156 | ) -> None: 157 | single_core = self._unroll(options, shape, lr=lr, n=nsteps) 158 | multi_core = self._unroll(options, shape, lr=lr, n=nsteps) 159 | 160 | chex.assert_tree_all_finite(single_core) 161 | np.testing.assert_allclose(single_core, multi_core) 162 | 163 | 164 | if __name__ == "__main__": 165 | absltest.main() 166 | -------------------------------------------------------------------------------- /optimizers/tearfree/optimizer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for tearfree optimizer.""" 16 | 17 | import dataclasses 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import jax 22 | from jax import numpy as jnp 23 | import numpy as np 24 | import optax 25 | from optimizers.tearfree import grafting 26 | from optimizers.tearfree import momentum 27 | from optimizers.tearfree import optimizer 28 | from optimizers.tearfree import praxis_shim 29 | from optimizers.tearfree import second_order 30 | from optimizers.tearfree import shampoo 31 | 32 | 33 | class OptimizerTest(parameterized.TestCase): 34 | """Basic test for optimizer configurations.""" 35 | 36 | def setUp(self): 37 | super().setUp() 38 | jax.config.update("jax_debug_nans", True) 39 | 40 | def _unroll(self, options, shape, transform=None, lr=0.1, n=4): 41 | """Generate states and grad updates n times.""" 42 | rng = jax.random.PRNGKey(0) 43 | params = jnp.zeros(shape) 44 | grads = jax.random.normal(rng, (n, *shape)) 45 | 46 | if transform is not None: 47 | params = transform(params) 48 | grads = jnp.stack([transform(g) for g in grads]) 49 | 50 | if isinstance(options, optimizer.TearfreeOptions): 51 | tx = optimizer.tearfree(lr, options) 52 | else: 53 | tx = options 54 | init = tx.init(params) 55 | 56 | def reduce(state, grad): 57 | new_grad, new_state = tx.update(grad, state, params) 58 | return new_state, new_grad 59 | 60 | _, out_grads = jax.lax.scan(reduce, init, grads) 61 | return out_grads 62 | 63 | def _no_graft_no_momentum(self): 64 | return optimizer.TearfreeOptions( 65 | grafting_options=grafting.Options( 66 | grafting_type=grafting.GraftingType.NONE, 67 | second_moment_decay=0.0, 68 | skip_preconditioning_rank1=False, 69 | ), 70 | momentum_options=momentum.Options(momentum_decay=0.0), 71 | ) 72 | 73 | def test_merge_dims(self): 74 | shape = (2, 2) 75 | options = dataclasses.replace( 76 | self._no_graft_no_momentum(), 77 | second_order_options=second_order.Options(merge_dims=4), 78 | ) 79 | transform = lambda x: x.reshape(4) 80 | actual = self._unroll(options, shape) 81 | expected = self._unroll(options, shape, transform) 82 | np.testing.assert_allclose(actual.reshape(-1, 4), expected) 83 | 84 | def test_block_size(self): 85 | shape = (4,) 86 | options = dataclasses.replace( 87 | self._no_graft_no_momentum(), 88 | second_order_options=second_order.Options( 89 | shampoo_options=shampoo.Options(block_size=3) 90 | ), 91 | ) 92 | actual = self._unroll(options, shape) 93 | expected = self._unroll(options, shape) 94 | np.testing.assert_allclose(actual, expected) 95 | 96 | @parameterized.parameters( 97 | momentum.Options(), # Default is 0.9, active momentum. 
98 | momentum.Options(momentum_decay=0.0), 99 | momentum.Options(weight_decay=0.01), 100 | momentum.Options(weight_decay=0.01, weight_decay_after_momentum=False), 101 | momentum.Options(nesterov=False), 102 | momentum.Options(ema=True), 103 | momentum.Options(ema=True, nesterov=True), 104 | ) 105 | def test_momentum_no_graft(self, momentum_options): 106 | shape = (4,) 107 | options = self._no_graft_no_momentum() 108 | options.momentum_options = momentum_options 109 | tx = praxis_shim.sharded_chain( 110 | second_order.apply(options.second_order_options), 111 | momentum.apply(momentum_options), 112 | optax.scale(-0.1), 113 | ) 114 | actual = self._unroll(options, shape) 115 | expected = self._unroll(tx, shape) 116 | np.testing.assert_allclose(actual, expected) 117 | 118 | def _grafting_tx( 119 | self, grafting_options 120 | ) -> praxis_shim.ShardedGradientTransformation: 121 | id_tx = optax.identity() 122 | id_tx_shard = praxis_shim.ShardedGradientTransformation( 123 | id_tx.init, id_tx.update, lambda _: optax.EmptyState() 124 | ) 125 | return grafting.graft(grafting_options, id_tx_shard) 126 | 127 | def _grafting_tx_with_momentum(self, grafting_options, momentum_options, lr=0.1): 128 | return praxis_shim.sharded_chain( 129 | self._grafting_tx(grafting_options), 130 | momentum.apply(momentum_options), 131 | optax.scale(-lr), 132 | ) 133 | 134 | @parameterized.parameters( 135 | grafting.Options(), 136 | grafting.Options( 137 | grafting_type=grafting.GraftingType.SGD, second_moment_decay=0.0 138 | ), 139 | grafting.Options(second_moment_decay=1.0), 140 | ) 141 | def test_momentum_yes_graft(self, grafting_options): 142 | shape = (4,) 143 | nsteps = 4 144 | options = self._no_graft_no_momentum() 145 | options.momentum_options.momentum_decay = 0.9 146 | options.grafting_options = grafting_options 147 | grafting_options.start_preconditioning_step = nsteps + 1 148 | grafting_options.skip_preconditioning_rank1 = False 149 | tx = self._grafting_tx_with_momentum(grafting_options, options.momentum_options) 150 | expected = self._unroll(tx, shape, n=nsteps) 151 | actual = self._unroll(options, shape, n=nsteps) 152 | np.testing.assert_allclose(actual, expected) 153 | 154 | def _precondition_at(self, i): 155 | """Return optimizer with momentum, grafting, and start precon at step i.""" 156 | return optimizer.TearfreeOptions( 157 | grafting_options=grafting.Options( 158 | start_preconditioning_step=i, skip_preconditioning_rank1=False 159 | ) 160 | ) 161 | 162 | @parameterized.parameters( 163 | dict(shape=(1, 1, 1)), dict(shape=(1,)), dict(shape=tuple()) 164 | ) 165 | def test_scalar_is_grafting(self, shape): 166 | nsteps = 4 167 | options = self._precondition_at(2) 168 | tx = self._grafting_tx_with_momentum( 169 | options.grafting_options, options.momentum_options 170 | ) 171 | expected = self._unroll(tx, shape, n=nsteps) 172 | actual = self._unroll(options, shape, n=nsteps) 173 | np.testing.assert_allclose(actual, expected) 174 | 175 | def test_lr(self): 176 | shape = (3,) 177 | options = self._precondition_at(2) 178 | nsteps = 4 179 | 180 | def schedule(count): 181 | return (count + 1) * 0.1 182 | 183 | actual = self._unroll(options, shape, lr=schedule, n=nsteps) 184 | expected = self._unroll(options, shape, lr=0.1, n=nsteps) 185 | expected *= (jnp.arange(nsteps) + 1).reshape(-1, 1) 186 | np.testing.assert_allclose(actual, expected) 187 | 188 | 189 | if __name__ == "__main__": 190 | absltest.main() 191 | -------------------------------------------------------------------------------- 
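Between the unit tests above and the sharding shim below, it may help to see the transformation used directly. This is a minimal usage sketch following the same `init`/`update` pattern the tests exercise via `_unroll`; it assumes the repo root is on `PYTHONPATH`, and the parameter shapes, stand-in gradients, and 0.1 learning rate are illustrative:

```python
# Minimal end-to-end use of the tearfree gradient transformation.
import jax
import jax.numpy as jnp
import optax
from optimizers.tearfree import optimizer

tx = optimizer.tearfree(learning_rate=0.1, options=optimizer.TearfreeOptions())

params = {"w": jnp.zeros((8, 4)), "b": jnp.zeros((4,))}
state = tx.init(params)

grads = jax.tree.map(jnp.ones_like, params)      # stand-in gradients
updates, state = tx.update(grads, state, params)
params = optax.apply_updates(params, updates)    # updates already carry the -lr factor
```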
/optimizers/tearfree/praxis_shim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Shim interfaces for praxis, to avoid circular dependencies.""" 16 | 17 | import dataclasses 18 | from typing import Any, NamedTuple, Union 19 | 20 | import jax 21 | from jax import numpy as jnp 22 | import optax 23 | 24 | 25 | @dataclasses.dataclass(frozen=True) 26 | class ShardedGradientTransformation: 27 | """GradientTransformation that supports spmd.""" 28 | 29 | init: optax.TransformInitFn 30 | update: optax.TransformUpdateFn 31 | init_partition_spec: Any 32 | 33 | 34 | NestedHParams = Any 35 | 36 | 37 | class WeightHParams(NamedTuple): 38 | shape: list[int] 39 | init: Any 40 | dtype: jnp.dtype 41 | collections: Any 42 | tensor_split_dims_mapping: list[int] 43 | 44 | 45 | def sharded_chain( 46 | *args: Union[optax.GradientTransformation, ShardedGradientTransformation] 47 | ) -> ShardedGradientTransformation: 48 | """Chain as in praxis.optimizers.sharded_chain.""" 49 | 50 | def init_fn(params): 51 | return tuple(fn.init(params) for fn in args) 52 | 53 | def update_fn(updates, state, params=None): 54 | if len(args) != len(state): 55 | raise ValueError( 56 | "The number of updates and states has to be the same in " 57 | f"sharded chain. got {len(args)=}, {len(state)=}" 58 | ) 59 | 60 | new_state = [] 61 | for s, fn in zip(state, args): 62 | updates, new_s = fn.update(updates, s, params) 63 | # Some of the new states may have None instead of optax.MaskedNode. 64 | new_s = jax.tree.map( 65 | lambda x: optax.MaskedNode() if x is None else x, 66 | new_s, 67 | is_leaf=lambda x: x is None, 68 | ) 69 | new_state.append(new_s) 70 | return updates, tuple(new_state) 71 | 72 | def init_partition_spec_fn(mdl_vars): 73 | partition_specs = [] 74 | for fn in args: 75 | init_partition_spec = getattr(fn, "init_partition_spec", None) 76 | if callable(init_partition_spec): 77 | nmap = init_partition_spec(mdl_vars) 78 | partition_specs.append(nmap) 79 | else: 80 | # Raise ValueError as we are attempting to sharded_chain an optimizer 81 | # that does not have an `init_partition_spec` method defined. 82 | raise ValueError( 83 | "Attempting to use an optimizer in sharded_chain that " 84 | "does not have an init_partition_spec." 85 | ) 86 | return optax.MaskedState(inner_state=tuple(partition_specs)) 87 | 88 | return ShardedGradientTransformation( 89 | init=init_fn, update=update_fn, init_partition_spec=init_partition_spec_fn 90 | ) 91 | -------------------------------------------------------------------------------- /optimizers/tearfree/reallocation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Sketchy memory reallocation across layers based on checkpoint info.""" 16 | 17 | import concurrent 18 | import copy 19 | import os 20 | from typing import Any, Optional 21 | 22 | from absl import app 23 | from absl import flags 24 | from flax.training import checkpoints as flax_chpts 25 | from jax import numpy as jnp 26 | 27 | 28 | def load_checkpoints(file_dir: str) -> list[Any]: 29 | """Load checkpoints from the directory where checkpoints are saved.""" 30 | files = [] 31 | 32 | for f in os.listdir(file_dir): 33 | if f.startswith("ckpt_"): 34 | v = int(f[len("ckpt_") :]) 35 | files.append((v, f)) 36 | files.sort() 37 | return files 38 | 39 | 40 | def create_state(file_dir: str, idx: list[int]): 41 | """Create states from selected checkpoints.""" 42 | 43 | files = load_checkpoints(file_dir) 44 | 45 | def extract_state(args): 46 | _, prefix = args 47 | restored = flax_chpts.restore_checkpoint(file_dir, target=None, prefix=prefix) 48 | state = restored["optimizer_state"] 49 | if "base_state" in state: 50 | state = state["base_state"] 51 | return state, prefix 52 | 53 | with concurrent.futures.ThreadPoolExecutor() as tpe: 54 | states, _ = zip(*tpe.map(extract_state, [files[i] for i in idx])) 55 | 56 | return tuple(states) 57 | 58 | 59 | def layers_and_axes(sketches: dict[str, Any]): 60 | """List names for all of the layers.""" 61 | 62 | def extract_paths(sketches, parent_key="", paths=None): 63 | if paths is None: 64 | paths = set() 65 | 66 | for key, value in sketches.items(): 67 | new_key = parent_key + "/" + key if parent_key else key 68 | if isinstance(value, dict): 69 | extract_paths(value, new_key, paths) 70 | else: 71 | paths.add(parent_key) 72 | return paths 73 | 74 | all_layer_names = extract_paths(sketches) 75 | layer_names = {name for name in all_layer_names if name[-2] == "/"} 76 | axes_set = {name[-1] for name in all_layer_names if name[-2] == "/"} 77 | return layer_names, len(axes_set) 78 | 79 | 80 | def create_groups( 81 | sketches: dict[str, Any], layer_names: set[str] 82 | ) -> dict[int, list[str]]: 83 | """Create groups for layers based on their dimensions.""" 84 | 85 | group_dict = {} 86 | 87 | for name in layer_names: 88 | dirs = name.split("/") 89 | carry = sketches 90 | for d in dirs: 91 | carry = carry[d] 92 | if "dim" in carry: 93 | key = carry["dim"] 94 | else: 95 | key = carry["eigvecs"].shape[0] 96 | if key in group_dict: 97 | group_dict[key].append(name) 98 | else: 99 | group_dict[key] = [name] 100 | 101 | return group_dict 102 | 103 | 104 | def score_fn( 105 | states, rule, layer_names: set[str], running_average=False 106 | ) -> dict[str, float]: 107 | """Calculate scores for each layer.""" 108 | 109 | feasible_rules = [ 110 | "ggt_intrinsic_rank", 111 | "ggt_trace", 112 | "tail_rho", 113 | "sketch_intrinsic_rank", 114 | "sketch_trace", 115 | ] 116 | 117 | if rule not in feasible_rules: 118 | raise NotImplementedError() 119 | 120 | if rule.startswith("ggt"): 121 | target = "ema_ggt" 122 | elif rule.startswith("sketch"): 123 | target = "eigvals" 124 | else: 125 | target = "tail" 126 | ops_dict = { 127 | 
"ggt_intrinsic_rank": lambda x: jnp.trace(x) / jnp.linalg.norm(x, 2), 128 | "ggt_trace": jnp.trace, 129 | "tail_rho": lambda x: x, 130 | "sketch_intrinsic_rank": ( 131 | lambda x: jnp.sum(x) / jnp.max(x) if jnp.sum(x) else 0 132 | ), 133 | "sketch_trace": jnp.sum, 134 | } 135 | if running_average: 136 | sketches = [ 137 | st["inner_state"]["0"]["direction"]["1"]["sketches"] for st in states 138 | ] 139 | else: 140 | sketches = [states[-1]["inner_state"]["0"]["direction"]["1"]["sketches"]] 141 | 142 | len_sketches = len(sketches) 143 | score_dict = {} 144 | for name in layer_names: 145 | dirs = name.split("/") 146 | current_target = copy.deepcopy(sketches) 147 | for i in range(len_sketches): 148 | ct = current_target[i] 149 | for d in dirs: 150 | ct = ct[d] 151 | current_target[i] = ct[target] 152 | score_dict[name] = jnp.mean( 153 | jnp.array([ops_dict[rule](ct) for ct in current_target]) 154 | ) 155 | return score_dict # pytype: disable=bad-return-type # jnp-type 156 | 157 | 158 | def create_redist_dict( 159 | file_dir: str, 160 | idx: list[int], 161 | rule: str, 162 | running_average: bool, 163 | sketchy_rank: int, 164 | states: Optional[str] = None, 165 | ): 166 | """Create dictionary of reallocated memory to each layers.""" 167 | if not states: 168 | states = create_state(file_dir, idx) 169 | sketches = states[-1]["inner_state"]["0"]["direction"]["1"]["sketches"] 170 | layer_names, num_axes = layers_and_axes(sketches) 171 | group_dict = create_groups(sketches, layer_names) 172 | score_dict = score_fn(states, rule, layer_names, running_average) 173 | 174 | def create_redist(): 175 | res = {} 176 | for p in list(score_dict): 177 | dirs = p.split("/")[:-2] 178 | cur = res 179 | for d in dirs[:-1]: 180 | cur = cur.setdefault(d, {}) 181 | cur[dirs[-1]] = [0] * num_axes 182 | return res 183 | 184 | def alloc_fn(redist, group, realloc_dict): 185 | for key in group: 186 | dirs, axes_id = key.split("/")[:-2], int(key.split("/")[-1]) 187 | carry = redist 188 | for d in dirs: 189 | carry = carry[d] 190 | carry[axes_id] = realloc_dict[key] 191 | return redist 192 | 193 | def rd(x): 194 | return int(x // 1) + 1 195 | 196 | def grp_info(dim): 197 | group = group_dict[dim] 198 | group_size = len(group) 199 | group_resource = group_size * sketchy_rank 200 | return group, group_size, group_resource 201 | 202 | def is_outlier(score, total_score, total_resource, dim): 203 | unit_rsc = total_resource / total_score if total_score else 0.0 204 | allocated_rsc = rd(score * unit_rsc) - 1 205 | return allocated_rsc > dim 206 | 207 | redist_dict = create_redist() 208 | 209 | for dim in group_dict: 210 | group, group_size, group_resource = grp_info(dim) 211 | assert group_resource >= group_size, (group_resource, group_size) 212 | group_resource -= group_size 213 | total_score = sum(score_dict[key] for key in group) 214 | sorted_scores = sorted( 215 | [(key, score_dict[key]) for key in group], key=lambda x: x[1], reverse=True 216 | ) 217 | realloc = {} 218 | for pair in sorted_scores: 219 | if is_outlier(pair[1], total_score, group_resource, dim - 1): 220 | realloc.update({pair[0]: dim}) 221 | group_resource -= dim - 1 222 | total_score -= pair[1] 223 | else: 224 | unit_rsc = group_resource / total_score if total_score else 0.0 225 | realloc.update({pair[0]: rd(pair[1] * unit_rsc)}) 226 | group_resource -= rd(pair[1] * unit_rsc) - 1 227 | total_score -= pair[1] 228 | 229 | for key in realloc: 230 | assert realloc[key] <= dim, (key, realloc[key], dim) 231 | 232 | allocated = sum(realloc.values()) 233 | _, _, 
group_resource = grp_info(dim) 234 | assert allocated <= group_resource, (group_resource, allocated) 235 | 236 | if allocated < group_resource: 237 | extra = group_resource - allocated 238 | for key, _ in sorted_scores: 239 | realloc[key] = min(realloc[key] + 1, dim) 240 | extra = extra - 1 if realloc[key] + 1 < dim else extra 241 | if extra <= 0: 242 | break 243 | 244 | redist_dict = alloc_fn(redist_dict, group, realloc) 245 | 246 | return redist_dict 247 | 248 | 249 | _DIR = flags.DEFINE_string("dir", "", "directory with checkpoints, must be set") 250 | 251 | _IDX = flags.DEFINE_multi_integer( 252 | "idx", -1, "indices of checkpoints to anlayze, default last checkpoint" 253 | ) 254 | 255 | _RULE = flags.DEFINE_string( 256 | "rule", "sketch_trace", "statistics to reallocate based on, default sketch trace" 257 | ) 258 | 259 | _AVG = flags.DEFINE_bool( 260 | "avg", False, "whether to use running average of the statistics, default False" 261 | ) 262 | 263 | 264 | _RANK = flags.DEFINE_integer( 265 | "rank", 256, "rellocation base per-layer resource, default 256" 266 | ) 267 | 268 | 269 | def _validate_flags(): 270 | """Raise errors if flags are improperly set.""" 271 | if not _DIR.value: 272 | raise ValueError("--dir must be set") 273 | return 0 274 | 275 | 276 | def main(argv: ...): 277 | del argv 278 | _validate_flags() 279 | args = [_DIR.value, _IDX.value, _RULE.value, _AVG.value, _RANK.value] 280 | return create_redist_dict(*args) 281 | 282 | 283 | if __name__ == "__main__": 284 | app.run(main) 285 | -------------------------------------------------------------------------------- /optimizers/tearfree/reallocation_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
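The allocation step in `create_redist_dict` above is essentially proportional sharing of a per-group rank budget, with ceiling rounding (`rd`) and a cap at the sketch dimension. The toy sketch below isolates just that arithmetic; it is standalone, uses made-up scores, and omits the outlier branch handled in the real code:

```python
# Toy version of the proportional rank allocation in create_redist_dict:
# each layer in a group receives ~score/total_score of the group's budget,
# rounded up with rd() and capped at the sketch dimension.
def rd(x):
    return int(x // 1) + 1

scores = {"layer_a": 30.0, "layer_b": 10.0, "layer_c": 5.0}
dim, per_layer_rank = 64, 16
budget = len(scores) * per_layer_rank - len(scores)  # reserve 1 rank per layer, as in the module
total = sum(scores.values())

alloc = {}
for name, s in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
    unit = budget / total if total else 0.0
    alloc[name] = min(rd(s * unit), dim)
    budget -= rd(s * unit) - 1
    total -= s

print(alloc)  # {'layer_a': 31, 'layer_b': 11, 'layer_c': 6} -- sums to 3 * 16
```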
14 | 15 | """Simple test case for memory reallocation function.""" 16 | 17 | import json 18 | import os 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | from jax import numpy as jnp 23 | from optimizers.tearfree import reallocation 24 | 25 | 26 | def dict_almost_equal(dict1, dict2, delta=1): 27 | """Helper function.""" 28 | for key, value in dict1.items(): 29 | assert key in dict2, key 30 | if isinstance(value, dict): 31 | dict_almost_equal(value, dict2[key], delta) 32 | else: 33 | for i in range(len(value)): 34 | assert jnp.abs(value[i] - dict2[key][i]) <= delta 35 | 36 | 37 | class ReallocationTest(parameterized.TestCase): 38 | 39 | def test_create_redist_dict(self): 40 | chpt_path = "" 41 | data_dir = os.path.join(os.path.dirname(__file__), "reallocation_test_data") 42 | realloc_path = os.path.join(data_dir, "gnn_realloc.json") 43 | states_path = os.path.join(data_dir, "states.json") 44 | with open(states_path, "r") as f: 45 | states = tuple(json.load(f)) 46 | sketches = states[-1]["inner_state"]["0"]["direction"]["1"]["sketches"] 47 | for layer in sketches: 48 | tmp = sketches[layer]["kernel"]["axes"] 49 | for axes in tmp: 50 | tmp[axes]["eigvals"] = jnp.array( 51 | tmp[axes]["eigvals"], dtype=jnp.float32 52 | ) 53 | states[-1]["inner_state"]["0"]["direction"]["1"]["sketches"][layer][ 54 | "kernel" 55 | ]["axes"] = tmp 56 | realloc_result = reallocation.create_redist_dict( 57 | chpt_path, [-1], "sketch_trace", False, 256, states 58 | ) 59 | with open(realloc_path, "r") as f: 60 | realloc_dict = json.load(f) 61 | 62 | dict_almost_equal(realloc_result, realloc_dict, delta=1) 63 | 64 | 65 | if __name__ == "__main__": 66 | absltest.main() 67 | -------------------------------------------------------------------------------- /optimizers/tearfree/reshaper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Parameter reshaping module.""" 16 | 17 | import dataclasses 18 | import functools 19 | import numpy as np 20 | 21 | import jax 22 | from jax import numpy as jnp 23 | import optax 24 | 25 | 26 | @dataclasses.dataclass 27 | class Options: 28 | """Parameter reshaping options. 29 | 30 | Attributes: 31 | merge_dims: Collapse dimensions smaller than this number left-to-right, 32 | e.g., [3, 1, 5, 2, 2] becomes [3, 5, 4] with `merge_dims = 4`. Notice 33 | ordering, [2, 3, 2] becomes [6, 2] with `merge_dims = 6`, not its reverse. 34 | block_size: If nonzero, pads all dimensions larger than the block size to a 35 | multiple of the block size. 
36 | """ 37 | 38 | merge_dims: int = 8192 39 | block_size: int = 256 40 | 41 | 42 | @dataclasses.dataclass 43 | class _Shapes: 44 | """Shape container.""" 45 | 46 | original_shape: list[int] 47 | merged_shape: list[int] 48 | padded_shape: list[int] 49 | 50 | 51 | def _derive_shapes(options: Options, param: jax.Array) -> _Shapes: 52 | """Derive desired shapes from options.""" 53 | merged = _merge_small_dims(param.shape, options.merge_dims) 54 | if merged == [1]: 55 | return _Shapes( 56 | original_shape=list(param.shape), merged_shape=[], padded_shape=[] 57 | ) 58 | if options.block_size == 0: 59 | padded = merged 60 | else: 61 | padded = [] 62 | for s in merged: 63 | if s >= options.block_size: 64 | s = (s + options.block_size - 1) // options.block_size 65 | s *= options.block_size 66 | padded.append(s) 67 | return _Shapes( 68 | original_shape=list(param.shape), merged_shape=merged, padded_shape=padded 69 | ) 70 | 71 | 72 | def merge(options: Options) -> optax.GradientTransformation: 73 | """Merge and maybe pad gradients, leaving params alone.""" 74 | 75 | if options.merge_dims < 2: 76 | raise ValueError( 77 | "merge_dims ({}) must be at least 2".format(options.merge_dims) 78 | ) 79 | 80 | if options.block_size < 2 and options.block_size != 0: 81 | raise ValueError( 82 | "block_size ({}) must be at least 2 (or 0 to disable)".format( 83 | options.block_size 84 | ) 85 | ) 86 | 87 | def _merge(update: jax.Array, shapes: _Shapes) -> jax.Array: 88 | assert list(update.shape) == shapes.original_shape, (update.shape, shapes) 89 | merged = update.reshape(shapes.merged_shape) 90 | padding = [(0, p - m) for p, m in zip(shapes.padded_shape, shapes.merged_shape)] 91 | if padding and options.block_size > 0: 92 | return jnp.pad(merged, padding) 93 | return merged 94 | 95 | def update( 96 | updates: optax.Updates, state: optax.MaskedNode, params: optax.Params 97 | ) -> tuple[optax.Updates, optax.MaskedNode]: 98 | shapes = jax.tree.map(functools.partial(_derive_shapes, options), params) 99 | new_updates = jax.tree.map(_merge, updates, shapes) 100 | return new_updates, state 101 | 102 | return optax.GradientTransformation(lambda _: optax.MaskedNode(), update) 103 | 104 | 105 | def unmerge(options: Options) -> optax.GradientTransformation: 106 | """Unmerge and unpad gradients, leaving params alone.""" 107 | 108 | def _unmerge(update: jax.Array, shapes: _Shapes) -> jax.Array: 109 | assert list(update.shape) == shapes.padded_shape, (update.shape, shapes) 110 | if options.block_size == 0: 111 | merged = update 112 | else: 113 | merged = update[tuple(slice(0, m) for m in shapes.merged_shape)] 114 | return merged.reshape(shapes.original_shape) 115 | 116 | def update( 117 | updates: optax.Updates, state: optax.MaskedNode, params: optax.Params 118 | ) -> tuple[optax.Updates, optax.MaskedNode]: 119 | shapes = jax.tree.map(functools.partial(_derive_shapes, options), params) 120 | new_updates = jax.tree.map(_unmerge, updates, shapes) 121 | return new_updates, state 122 | 123 | return optax.GradientTransformation(lambda _: optax.MaskedNode(), update) 124 | 125 | 126 | def _merge_small_dims(shape_to_merge, max_dim): 127 | """Merge small dimensions. 128 | 129 | If there are some small dimensions, we collapse them: 130 | e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024 131 | [1, 2, 768, 1, 2048] --> [2, 768, 2048] 132 | 133 | Args: 134 | shape_to_merge: Shape to merge small dimensions. 135 | max_dim: Maximal dimension of output shape used in merging. 136 | 137 | Returns: 138 | Merged shape. 
139 | """ 140 | if shape_to_merge and np.all(np.array(shape_to_merge) == 1): 141 | return [1] 142 | 143 | resulting_shape = [] 144 | product = 1 145 | for d in shape_to_merge: 146 | if product * d <= max_dim: 147 | product *= d 148 | else: 149 | if product > 1: 150 | resulting_shape.append(product) 151 | product = d 152 | if product > 1: 153 | resulting_shape.append(product) 154 | return resulting_shape 155 | -------------------------------------------------------------------------------- /optimizers/tearfree/reshaper_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for momentum implementation.""" 16 | 17 | from typing import Sequence 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import jax 22 | from jax import numpy as jnp 23 | import numpy as np 24 | from optimizers.tearfree import reshaper 25 | 26 | 27 | def _make_invalid_cases() -> Sequence[dict[str, ...]]: 28 | """Generate invalid cases which should throw.""" 29 | return [ 30 | { 31 | "testcase_name": "smallblock", 32 | "invalid_options": reshaper.Options(block_size=1), 33 | }, 34 | { 35 | "testcase_name": "smallmerge", 36 | "invalid_options": reshaper.Options(merge_dims=0), 37 | }, 38 | ] 39 | 40 | 41 | def _make_expected_shape_cases() -> Sequence[dict[str, ...]]: 42 | cases = [ 43 | {"in_shape": [4], "merge": 2, "block": 3, "out_shape": [6]}, 44 | {"in_shape": [3], "merge": 2, "block": 3, "out_shape": [3]}, 45 | {"in_shape": [1, 3, 1], "merge": 2, "block": 3, "out_shape": [3]}, 46 | {"in_shape": [1, 3, 1], "merge": 3, "block": 3, "out_shape": [3]}, 47 | {"in_shape": [1, 3, 1], "merge": 3, "block": 4, "out_shape": [3]}, 48 | {"in_shape": [1, 3, 1, 2], "merge": 2, "block": 3, "out_shape": [3, 2]}, 49 | {"in_shape": [4, 1, 5], "merge": 2, "block": 3, "out_shape": [6, 6]}, 50 | {"in_shape": [1], "merge": 2, "block": 2, "out_shape": []}, 51 | {"in_shape": [1, 1, 1], "merge": 2, "block": 2, "out_shape": []}, 52 | {"in_shape": [1, 1, 1], "merge": 2, "block": 2, "out_shape": []}, 53 | {"in_shape": [3, 1, 5, 2, 2], "merge": 4, "block": 10, "out_shape": [3, 5, 4]}, 54 | {"in_shape": [2, 3, 2], "merge": 6, "block": 10, "out_shape": [6, 2]}, 55 | ] 56 | for case in cases[:]: 57 | if all(i <= case["block"] for i in case["in_shape"]): 58 | block0 = case.copy() 59 | block0["block"] = 0 60 | cases.append(block0) 61 | return cases 62 | 63 | 64 | class ReshaperTest(parameterized.TestCase): 65 | """Basic test for shampoo implementation.""" 66 | 67 | @parameterized.named_parameters(_make_invalid_cases()) 68 | def test_invalid(self, invalid_options): 69 | with self.assertRaises(ValueError): 70 | reshaper.merge(invalid_options) 71 | 72 | @parameterized.parameters(_make_expected_shape_cases()) 73 | def test_expected_shape(self, in_shape, merge, block, out_shape): 74 | options = reshaper.Options(merge_dims=merge, 
block_size=block) 75 | init_fn, update_fn = reshaper.merge(options) 76 | init = jnp.zeros(in_shape) 77 | out, _ = update_fn(init, init_fn(None), init) 78 | self.assertSequenceEqual(out.shape, out_shape) 79 | 80 | @parameterized.parameters(_make_expected_shape_cases()) 81 | def test_inversion(self, in_shape, merge, block, out_shape): 82 | del out_shape 83 | options = reshaper.Options(merge_dims=merge, block_size=block) 84 | init_fn, update_fn = reshaper.merge(options) 85 | init = jax.random.normal(jax.random.PRNGKey(0), in_shape) 86 | out, _ = update_fn(init, init_fn(None), init) 87 | init_fn, update_fn = reshaper.unmerge(options) 88 | recover, _ = update_fn(out, init_fn(None), init) 89 | np.testing.assert_array_equal(init, recover) 90 | 91 | def test_tree(self): 92 | shapes = {"w": [[{"b": (3, 2)}]], "z": (1, 2, 1)} 93 | init = jax.tree.map(jnp.zeros, shapes, is_leaf=lambda x: isinstance(x, tuple)) 94 | options = reshaper.Options(merge_dims=2, block_size=2) 95 | init_fn, update_fn = reshaper.merge(options) 96 | out, _ = update_fn(init, init_fn(None), init) 97 | out_shapes = jax.tree.map(lambda x: tuple(x.shape), out) 98 | expected_shapes = {"w": [[{"b": (4, 2)}]], "z": (2,)} 99 | 100 | self.assertEqual(out_shapes, expected_shapes) 101 | 102 | 103 | if __name__ == "__main__": 104 | absltest.main() 105 | -------------------------------------------------------------------------------- /optimizers/tearfree/second_order.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Various strategies for tracking second order statistics.""" 16 | 17 | import dataclasses 18 | import enum 19 | from typing import Optional 20 | 21 | import optax 22 | from optimizers.tearfree import praxis_shim 23 | from optimizers.tearfree import reshaper 24 | from optimizers.tearfree import shampoo 25 | from optimizers.tearfree import sketchy 26 | 27 | 28 | @enum.unique 29 | class SecondOrderType(enum.Enum): 30 | """Different second order covariance tracking methods.""" 31 | 32 | SHAMPOO = "shampoo" 33 | SKETCHY = "sketchy" 34 | 35 | 36 | @dataclasses.dataclass 37 | class Options: 38 | """Toggle which second order statistics to track. 39 | 40 | Attributes: 41 | merge_dims: Merges small dimensions, see `reshaper.Options.merge_dims`. 42 | second_order_type: Which optimizer to use for grafting updates. 43 | shampoo_options: Options for blocked shampoo. 44 | sketchy_options: Options for Sketchy. 
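    For example (illustrative): the default configuration reuses
    shampoo_options.block_size when reshaping, whereas
    Options(second_order_type=SecondOrderType.SKETCHY,
    sketchy_options=sketchy.Options()) reshapes with block size 0,
    i.e. merging without blocking (see _reshaper_options below).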
45 | """ 46 | 47 | merge_dims: int = 8192 48 | second_order_type: SecondOrderType = SecondOrderType.SHAMPOO 49 | shampoo_options: Optional[shampoo.Options] = dataclasses.field( 50 | default_factory=shampoo.Options 51 | ) 52 | sketchy_options: Optional[sketchy.Options] = None 53 | 54 | 55 | def apply(options: Options) -> praxis_shim.ShardedGradientTransformation: 56 | """Generate the second order update from options.""" 57 | reshaper_options = _reshaper_options(options) 58 | merge_tx = reshaper.merge(reshaper_options) 59 | precond_tx = _update_stats_and_precondition(options) 60 | 61 | def wrap_init(params: optax.Params): 62 | reshaped_params, _ = merge_tx.update(params, merge_tx.init(params), params) 63 | return precond_tx.init(reshaped_params) 64 | 65 | # TODO(vladf): later, we'll need to wrap pspec as well. 66 | wrapped_precond_tx = praxis_shim.ShardedGradientTransformation( 67 | wrap_init, precond_tx.update, precond_tx.init_partition_spec 68 | ) 69 | 70 | return praxis_shim.sharded_chain( 71 | merge_tx, wrapped_precond_tx, reshaper.unmerge(reshaper_options) 72 | ) 73 | 74 | 75 | def _reshaper_options(options: Options) -> reshaper.Options: 76 | if options.second_order_type == SecondOrderType.SHAMPOO: 77 | assert options.shampoo_options 78 | block_size = options.shampoo_options.block_size 79 | return reshaper.Options(options.merge_dims, block_size) 80 | elif options.second_order_type == SecondOrderType.SKETCHY: 81 | return reshaper.Options(options.merge_dims, 0) 82 | else: 83 | raise ValueError( 84 | "unknown second order type {}".format(options.second_order_type) 85 | ) 86 | 87 | 88 | def _update_stats_and_precondition( 89 | options: Options, 90 | ) -> praxis_shim.ShardedGradientTransformation: 91 | if options.second_order_type == SecondOrderType.SHAMPOO: 92 | assert options.shampoo_options 93 | return shampoo.apply(options.shampoo_options) 94 | elif options.second_order_type == SecondOrderType.SKETCHY: 95 | assert options.sketchy_options 96 | return sketchy.apply(options.sketchy_options) 97 | else: 98 | raise ValueError( 99 | "unknown second order type {}".format(options.second_order_type) 100 | ) 101 | -------------------------------------------------------------------------------- /optimizers/tearfree/shampoo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Shampoo second-order statistics preconditioning.""" 16 | 17 | import dataclasses 18 | import functools 19 | import math 20 | import string 21 | from typing import Iterator, NamedTuple, Optional, Sequence 22 | 23 | import chex 24 | import jax 25 | from jax import numpy as jnp 26 | import optax 27 | from optimizers.tearfree import praxis_shim 28 | from utils import write_note 29 | 30 | 31 | @dataclasses.dataclass 32 | class Options: 33 | """Shampoo covariance approximation options. 34 | 35 | See https://arxiv.org/abs/2002.09018. 
36 | 37 | Attributes: 38 | use_CASPR_variant: If true, uses the CASPR variant of Shampoo. 39 | block_size: Determines the block size for Shampoo's block-diagonal gradient 40 | covariance approximation. 41 | update_preconditioners_freq: Number of steps between preconditioner updates. 42 | update_statistics_freq: Number of steps between statistics updates. 43 | second_moment_decay: Decay rate for second moment exponential moving 44 | average. If 1.0 then sums. 45 | """ 46 | 47 | use_CASPR_variant: bool = False 48 | block_size: int = 256 49 | update_preconditioners_freq: int = 20 50 | update_statistics_freq: int = 1 51 | second_moment_decay: float = 0.999 52 | # TODO(vladf): 53 | # lb, rb, b sharding: maybe later? 54 | # spmd_mesh_axis_names: Sequence[str] = () maybe later? 55 | 56 | 57 | def apply(options: Options) -> praxis_shim.ShardedGradientTransformation: 58 | """Return gradient transform for (blocked) shampoo preconditioning.""" 59 | 60 | _validate(options) 61 | 62 | # raise if no unit dims (must be scalar) 63 | 64 | # not intentional constants from distributed shampoo 65 | # exponent_override = 0 66 | # matrix_epsilon = 0 and eigh = True 67 | 68 | return praxis_shim.ShardedGradientTransformation( 69 | functools.partial(_init, options), 70 | functools.partial(_update, options), 71 | functools.partial(_pspec, options), 72 | ) 73 | 74 | 75 | class _AxesBlocks(NamedTuple): 76 | """Represents statistics or preconditioner matrices for a single tensor. 77 | 78 | Maintains the second-order statistics for gradients, to be used in second 79 | order optimization. 80 | 81 | There are two key matrices that are maintained per axis, when not using any 82 | blocking approximations. 83 | 84 | For each axis `i` with length `d_i`, we track the `(d_i, d_i)` covariances 85 | and their inverse p-th roots, where `p` is twice the rank of the gradients 86 | whose covariances we're tracking, excluding unit dimensions. 87 | 88 | A covariance for a higher-order tensor's i-th axis is recovered from the outer 89 | product of contracting all but the axis i. E.g., for an order-3 tensor G, the 90 | covariance for its 0th dimension is `C_{ij} = sum_{k,l} G_{ikl}G_{jkl}`. 91 | 92 | Since these matrices can be quite large for tensors with large dimensions, we 93 | introduce an approximation which stores just block diagonal components of 94 | these matrices. This corresponds to the full covariance statistics of the 95 | disjoint partitions of these tensors, with large dimensions blocked by a 96 | provided block size. 97 | 98 | Below, we refer to all large dimensions as those at least equal to the block 99 | size. 100 | 101 | When storing block diagonal matrices, we store them as `N` blocks of size `B`. 102 | 103 | Attributes: 104 | stats: A list of length equal to the tensor's ndim with blocks of matrices 105 | of shape [N, B, B] where B is an axis-specific block size, but at most 106 | block_size. 107 | roots: Same shape as stats, inverse p-th roots for statistics, where p is 108 | the twice the rank of the original parameter to optimize. 109 | """ 110 | 111 | stats: list[jax.Array] 112 | roots: list[jax.Array] 113 | 114 | 115 | class _ShampooState(NamedTuple): 116 | # A scalar int32 for step count. 117 | count: jax.Array 118 | # A tree of the same shape as the params of _AxesBlocks leaves of f32. 119 | blocks: chex.ArrayTree 120 | 121 | 122 | @dataclasses.dataclass(frozen=True) 123 | class _BlocksMetadata: 124 | """Metadata for _AxesBlocks to track indexing information. 
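    For example (illustrative), a parameter of shape [3, 20, 25, 4] with
    block_size=5 is described by block_sizes=[3, 5, 5, 4], large_axes=[1, 2],
    blocks_per_large_axis=[4, 5], num_blocks=20, and blocks_axis=1 (matching
    the _blockify example below).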
125 | 126 | Attributes: 127 | block_sizes: Per-dimension statistics & preconditioner block sizes, length 128 | is equal to rank of tensor we have metadata for. 129 | num_blocks: `N`, the total number of blocks. 130 | debug_name: A way to refer to this parameter in debug messages. 131 | large_block_size: The block size originally specified in the shampoo 132 | options, this is the minimum size for a dimension to be considered large. 133 | param_shape: The shape of the original parameter we're blocking. 134 | large_axes: Axes with dimension at least `large_block_size` in the original 135 | tensor. 136 | blocks_per_large_axis: Number of blocks in the corresponding axis for each 137 | axis in large_axes. 138 | blocks_axis: The axis of the the `N` in the blocked tensor (see _blockify). 139 | """ 140 | 141 | block_sizes: list[int] 142 | num_blocks: int 143 | debug_name: str 144 | large_block_size: int 145 | param_shape: list[int] 146 | large_axes: list[int] 147 | blocks_per_large_axis: list[int] 148 | blocks_axis: int 149 | 150 | 151 | def _blocks_metadata( 152 | options: Options, param_shape: Sequence[int], debug: str 153 | ) -> _BlocksMetadata: 154 | """Generate the blocks metadata for a parameter.""" 155 | dims = [min(dim, options.block_size) for dim in param_shape] 156 | large_axes = [i for i, d in enumerate(param_shape) if d >= options.block_size] 157 | blocks_per_large_axis = [param_shape[i] // options.block_size for i in large_axes] 158 | num_blocks = math.prod(blocks_per_large_axis + [1]) 159 | 160 | return _BlocksMetadata( 161 | block_sizes=dims, 162 | num_blocks=num_blocks, 163 | debug_name=debug, 164 | large_block_size=options.block_size, 165 | large_axes=large_axes, 166 | param_shape=list(param_shape), 167 | blocks_per_large_axis=blocks_per_large_axis, 168 | blocks_axis=min(large_axes, default=0), 169 | ) 170 | 171 | 172 | def _validate(options: Options) -> None: 173 | """Raise ValueError if options are invalid.""" 174 | if options.block_size <= 1: 175 | raise ValueError(f"block_size ({options.block_size}) must be >1") 176 | 177 | if options.update_preconditioners_freq <= 0: 178 | raise ValueError( 179 | "update_preconditioners_freq ({}) must be positive".format( 180 | options.update_preconditioners_freq 181 | ) 182 | ) 183 | 184 | if options.update_statistics_freq <= 0: 185 | raise ValueError( 186 | "update_statistics_freq ({}) must be positive".format( 187 | options.update_statistics_freq 188 | ) 189 | ) 190 | 191 | if not (0 <= options.second_moment_decay <= 1): 192 | raise ValueError( 193 | f"second_moment_decay ({options.second_moment_decay}) " 194 | "should be in [0, 1]" 195 | ) 196 | 197 | 198 | def _init(options: Options, params: optax.Params) -> _ShampooState: 199 | """Initialize stats to 0 and preconditioners to identity.""" 200 | 201 | def make_blocks(path: ..., param: jax.Array) -> _AxesBlocks: 202 | if any(dim == 1 for dim in param.shape): 203 | raise ValueError( 204 | "param {} shape ({}) has unit dimensions".format(path, param.shape) 205 | ) 206 | 207 | if sum(dim >= options.block_size for dim in param.shape) > 2: 208 | raise ValueError( 209 | "param {} shape ({}) has >2 large dims for block size {}".format( 210 | path, param.shape, options.block_size 211 | ) 212 | ) 213 | 214 | if any( 215 | dim % options.block_size != 0 216 | for dim in param.shape 217 | if dim >= options.block_size 218 | ): 219 | raise ValueError( 220 | "param {} shape ({}) has large dims indivisible by block size {}".format( 221 | path, param.shape, options.block_size 222 | ) 223 | ) 224 | 225 | 
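        # Illustrative example: a (512, 256) param with block_size=256 has
        # num_blocks=2 and block_sizes=[256, 256], so each of its two axes gets
        # zero statistics of shape (2, 256, 256) and per-block identity
        # preconditioners of the same shape.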
meta = _blocks_metadata(options, param.shape, str(path)) 226 | n = meta.num_blocks 227 | dims = meta.block_sizes 228 | stats = [jnp.zeros((n, d, d)) for d in dims] 229 | precond = [jnp.eye(d) * jnp.ones((n, 1, 1)) for d in dims] 230 | return _AxesBlocks(stats, precond) 231 | 232 | blocks = jax.tree_util.tree_map_with_path(make_blocks, params) 233 | 234 | # Calculate and print sizes for stats and preconditioners together 235 | total_n_elements = sum(leaf.size for leaf in jax.tree.leaves(blocks)) 236 | total_size_MB = sum( 237 | leaf.size * leaf.dtype.itemsize / (2**20) for leaf in jax.tree.leaves(blocks) 238 | ) 239 | if jax.process_index() == 0: 240 | print( 241 | f"Shampoo Stats and Preconditioners size: {total_n_elements} elements, " 242 | f"{total_size_MB:.2f} MB" 243 | ) 244 | 245 | return _ShampooState(count=jnp.zeros([], jnp.int32), blocks=blocks) 246 | 247 | 248 | def _pspec( 249 | options: Options, params: praxis_shim.NestedHParams 250 | ) -> praxis_shim.NestedHParams: 251 | """Generate sharding specification for shampoo state.""" 252 | count_pspec = praxis_shim.WeightHParams( 253 | shape=[], 254 | init=None, 255 | dtype=jnp.int32, 256 | collections=None, 257 | tensor_split_dims_mapping=[], 258 | ) 259 | 260 | def make_blocks_pspec( 261 | path: ..., param: praxis_shim.WeightHParams 262 | ) -> praxis_shim.NestedHParams: 263 | meta = _blocks_metadata(options, param.shape, str(path)) 264 | num_blocks = meta.num_blocks 265 | dims = meta.block_sizes 266 | replicated = functools.partial( 267 | praxis_shim.WeightHParams, 268 | init=None, 269 | dtype=jnp.float32, 270 | collections=None, 271 | tensor_split_dims_mapping=[-1, -1, -1], 272 | ) 273 | stats = [replicated((num_blocks, d, d)) for d in dims] 274 | precond = stats 275 | return dict(stats=stats, roots=precond) 276 | 277 | return dict( 278 | count=count_pspec, 279 | blocks=jax.tree_util.tree_map_with_path( 280 | make_blocks_pspec, params, is_leaf=lambda x: hasattr(x, "shape") 281 | ), 282 | ) 283 | 284 | 285 | def _update( 286 | options: Options, 287 | updates: optax.Updates, 288 | state: _ShampooState, 289 | params: Optional[optax.Params] = None, 290 | ) -> tuple[optax.Updates, _ShampooState]: 291 | """Update internal shampoo stats and precondition gradients.""" 292 | del params 293 | meta = jax.tree_util.tree_map_with_path( 294 | lambda path, x: _blocks_metadata(options, x.shape, str(path)), updates 295 | ) 296 | blocks = state.blocks 297 | blockified_updates = jax.tree.map(_blockify, updates, meta) 298 | is_block = lambda x: isinstance(x, _AxesBlocks) 299 | 300 | stats_updated_blocks = functools.partial( 301 | jax.tree.map, 302 | functools.partial(_update_block_stats, options.second_moment_decay), 303 | blockified_updates, 304 | blocks, 305 | meta, 306 | is_leaf=is_block, 307 | ) 308 | should_update_stats = (state.count % options.update_statistics_freq) == 0 309 | blocks = jax.lax.cond(should_update_stats, stats_updated_blocks, lambda: blocks) 310 | 311 | precond_updated_blocks = functools.partial( 312 | jax.tree.map, _update_block_precond, blocks, meta, is_leaf=is_block 313 | ) 314 | should_update_precond = (state.count % options.update_preconditioners_freq) == 0 315 | blocks = jax.lax.cond(should_update_precond, precond_updated_blocks, lambda: blocks) 316 | new_state = _ShampooState(count=state.count + 1, blocks=blocks) 317 | if options.use_CASPR_variant: 318 | write_note("CASPR operations:") 319 | new_updates = jax.tree.map( 320 | _precondition_blocks_caspr, 321 | blockified_updates, 322 | blocks, 323 | meta, 324 | 
is_leaf=is_block, 325 | ) 326 | else: 327 | write_note("Shampoo operations:") 328 | new_updates = jax.tree.map( 329 | _precondition_blocks_shampoo, 330 | blockified_updates, 331 | blocks, 332 | meta, 333 | is_leaf=is_block, 334 | ) 335 | new_updates = jax.tree.map(_deblockify, new_updates, meta) 336 | 337 | return new_updates, new_state 338 | 339 | 340 | def _blockify(x: jax.Array, meta: _BlocksMetadata) -> jax.Array: 341 | """Reshape the update such that it is blocked along large dimensions. 342 | 343 | Inserts the `N` dimension dimension right on the first axis in 344 | `meta.large_axes`, which is the `meta.blocks_axis` in the returned tensor. 345 | 346 | Shifts all original axes in `x` that are on or after `meta.blocks_axis` 347 | (including what was originally the first axis in `meta.large_axes`) forward 348 | by one. All large axes will now be of length equal to the largest block size. 349 | 350 | In the case that there's no blocking, we put a dummy blocks axis in axis 0 351 | with dimension 1, so the handling of the original axes is the same as the 352 | blocked cases. 353 | 354 | For example: 355 | - Suppose block size is 5 and x is shaped [3, 20, 25, 4]. Then 356 | x has two large axes (1 and 2). The resulting blocked value will be 357 | shaped [3, (4*5), 5, 5, 4], with the large axes being converted 358 | to block size and a new 20-dimensional axis with the product of the 359 | 4 blocks for the original axis 1 and 5 blocks for the original axis 2. 360 | All other axes are kept the same. Note meta.blocks_axis precedes the 361 | large axes' new locations, so it's set to 1. 362 | - Suppose block size is 5 and x is shaped [5, 2]. The result will be 363 | [1, 5, 2], with the first dimension corresponding to the single block at 364 | axis 0. 365 | - Suppose block size is 5 and x is [3, 4]. There are no large axes, but to 366 | get rid of edge cases we still add a meta.blocks_axis at axis 0 with a 367 | single block [1, 3, 4]. 368 | - Suppose block size is 5 and x is [15, 2, 10]. We'll return 369 | [(3*2), 5, 2, 5], following the same rules as before. Note the large 370 | axes stay in place. 371 | 372 | Args: 373 | x: Input to block. 374 | meta: Metadata about the input. 375 | 376 | Returns: 377 | A blocked version of the input and the dimension with the number of 378 | blocks. 379 | """ 380 | assert list(x.shape) == meta.param_shape, (x.shape, meta.param_shape) 381 | 382 | if not meta.large_axes: 383 | # Just create a unit N/blocks axis. 384 | return jnp.expand_dims(x, meta.blocks_axis) 385 | 386 | if len(meta.large_axes) == 1: 387 | # Block the only large axis. 388 | before, after = _split_exclusively(x.shape, meta.large_axes) 389 | new_shape = before + [meta.num_blocks, meta.large_block_size] + after 390 | return x.reshape(new_shape) 391 | 392 | assert len(meta.large_axes) == 2, meta.large_axes 393 | 394 | # Extract the blocks from both large axes. 395 | l_blocks, r_blocks = meta.blocks_per_large_axis 396 | before, middle, after = _split_exclusively(x.shape, meta.large_axes) 397 | stitch = lambda l, r: before + l + middle + r + after 398 | split_blocked_shape = stitch( 399 | [l_blocks, meta.large_block_size], [r_blocks, meta.large_block_size] 400 | ) 401 | split_blocked_x = x.reshape(split_blocked_shape) 402 | 403 | # Move over the blocks from the right axis next to the left one. 
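    # Illustrative, continuing the docstring example: [3, 20, 25, 4] with
    # block size 5 was reshaped above to [3, 4, 5, 5, 5, 4]; the transpose
    # below puts the 5 right-axis block counts next to the 4 left-axis block
    # counts, and the final reshape merges them into the single 20-long blocks
    # axis, giving [3, 20, 5, 5, 4].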
404 | perm = list(range(len(split_blocked_shape))) 405 | l_blocks_ix = len(before) 406 | r_blocks_ix = len(before) + 2 + len(middle) 407 | perm.pop(r_blocks_ix) 408 | perm.insert(l_blocks_ix + 1, r_blocks_ix) 409 | adjacent_blocked_x = jnp.transpose(split_blocked_x, perm) 410 | 411 | # Transpose the previous sharding too. 412 | new_shape = stitch( 413 | [meta.num_blocks, meta.large_block_size], [meta.large_block_size] 414 | ) 415 | reshaped = adjacent_blocked_x.reshape(new_shape) 416 | assert l_blocks_ix == meta.blocks_axis 417 | 418 | return reshaped 419 | 420 | 421 | def _deblockify(blocked_x: jax.Array, meta: _BlocksMetadata) -> jax.Array: 422 | """Invert _blockify().""" 423 | if not meta.large_axes: 424 | return jnp.squeeze(blocked_x, meta.blocks_axis) 425 | 426 | if len(meta.large_axes) == 1: 427 | return blocked_x.reshape(meta.param_shape) 428 | 429 | assert len(meta.large_axes) == 2 430 | 431 | # Re-split the blocks axis. 432 | assert blocked_x.shape[meta.blocks_axis] == meta.num_blocks 433 | before, after = _split_exclusively(blocked_x.shape, [meta.blocks_axis]) 434 | split_blocks_shape = before + meta.blocks_per_large_axis + after 435 | split_blocked_x = blocked_x.reshape(split_blocks_shape) 436 | 437 | # Move the right large axis blocks back in front of their axis. 438 | perm = list(range(len(split_blocked_x.shape))) 439 | # In blocked_x: 440 | # [..., blocks axis, left block, ..., right block, ...] 441 | # ^blocks_axis ^large_axes[1] + 1 442 | # In split_blocked_x: 443 | # [..., left blocks, right blocks, left block, ..., right block, ...] 444 | # ^blocks_axis ^large_axes[1] + 2 445 | r_blocks_ix = meta.blocks_axis + 1 446 | r_blocks_val = perm.pop(r_blocks_ix) 447 | # After pop: 448 | # [..., left blocks, left block, ..., right block, ...] 
449 | # ^blocks_axis ^large_axes[1] + 1 450 | r_blocked_axis_ix = meta.large_axes[1] + 1 451 | perm.insert(r_blocked_axis_ix, r_blocks_val) 452 | split_blocked_x = jnp.transpose(split_blocked_x, perm) 453 | 454 | reshaped = jnp.reshape(split_blocked_x, meta.param_shape) 455 | return reshaped 456 | 457 | 458 | def _update_block_stats( 459 | second_moment_decay: float, 460 | update: jax.Array, 461 | block: _AxesBlocks, 462 | meta: _BlocksMetadata, 463 | ) -> _AxesBlocks: 464 | """Update covariance statistics given a blocked gradient.""" 465 | 466 | new_stats = [] 467 | with jax.named_scope("ShampooStats"): 468 | for axis, cov in enumerate(block.stats): 469 | all_axes = list(range(len(meta.param_shape))) 470 | all_axes.remove(axis) 471 | 472 | dot_all = functools.partial(jnp.tensordot, axes=(all_axes, all_axes)) 473 | batched_tensordot = jax.vmap(dot_all, in_axes=meta.blocks_axis, out_axes=0) 474 | new_cov = batched_tensordot(update, update) 475 | new_stats.append(_ema_update(cov, new_cov, second_moment_decay)) 476 | 477 | return _AxesBlocks(stats=new_stats, roots=block.roots) 478 | 479 | 480 | def _pth_inv_root(p: int, cov: jax.Array) -> jax.Array: 481 | """Calculate a batch of p-th inverse roots.""" 482 | eps = 1e-6 483 | w, v = jnp.linalg.eigh(cov) 484 | mask = w <= eps * jnp.max(w) 485 | half = jnp.where(mask, 1.0, w) ** (-0.5 / p) 486 | half = jnp.where(mask, 0.0, half) 487 | half_v = jnp.expand_dims(half, -2) * v 488 | return jnp.einsum("bik,bjk->bij", half_v, half_v) 489 | 490 | 491 | def _update_block_precond(block: _AxesBlocks, meta: _BlocksMetadata) -> _AxesBlocks: 492 | """Update preconditioners.""" 493 | # p=2 works better 494 | p = 2 # len(meta.param_shape) * 2 495 | 496 | with jax.named_scope("PthInvRoot"): 497 | new_roots = list(map(functools.partial(_pth_inv_root, p), block.stats)) 498 | 499 | return _AxesBlocks(roots=new_roots, stats=block.stats) 500 | 501 | 502 | def _precondition_blocks_shampoo( 503 | update: jax.Array, blocks: _AxesBlocks, meta: _BlocksMetadata 504 | ) -> jax.Array: 505 | """Precondition blocked gradients.""" 506 | it = _einsum_letters(meta) 507 | blocks_axis_letter = next(it) 508 | 509 | # Contract along the innermost axis of each preconditioner, 510 | # making the other equal-length axis the output. 511 | contraction_letters = [next(it) for _ in meta.param_shape] 512 | output_letters = [next(it) for _ in meta.param_shape] 513 | preconditioners = blocks.roots 514 | preconditioner_inputs = [ 515 | blocks_axis_letter + o + c for c, o in zip(contraction_letters, output_letters) 516 | ] 517 | 518 | blocked_input = contraction_letters[:] 519 | blocked_input.insert(meta.blocks_axis, blocks_axis_letter) 520 | blocked_input = "".join(blocked_input) 521 | blocked_output = output_letters[:] 522 | blocked_output.insert(meta.blocks_axis, blocks_axis_letter) 523 | blocked_output = "".join(blocked_output) 524 | 525 | # Build up the einsum equation and invoke it. 
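    # Illustrative example: for a 2-D parameter the blocked update has shape
    # [N, B, B] and the generated formula is "abc,adb,aec->ade", i.e. per block
    # P0 @ G @ P1 (the roots are symmetric), batched over the blocks axis "a".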
526 | inputs = ",".join([blocked_input] + preconditioner_inputs) 527 | formula = inputs + "->" + blocked_output 528 | with jax.named_scope("PreconditionShampoo"): 529 | write_note(f"{formula} {update.shape} {[x.shape for x in preconditioners]}") 530 | return jnp.einsum(formula, update, *preconditioners) 531 | 532 | 533 | def _precondition_blocks_caspr( 534 | update: jax.Array, blocks: _AxesBlocks, meta: _BlocksMetadata 535 | ) -> jax.Array: 536 | """Precondition blocked gradients.""" 537 | it = _einsum_letters(meta) 538 | blocks_axis_letter = next(it) 539 | 540 | # Contract along the innermost axis of each preconditioner, 541 | # making the other equal-length axis the output. 542 | contraction_letters = [next(it) for _ in meta.param_shape] 543 | output_letters = [next(it) for _ in meta.param_shape] 544 | preconditioners = blocks.roots 545 | preconditioner_inputs = [ 546 | blocks_axis_letter + o + c for c, o in zip(contraction_letters, output_letters) 547 | ] 548 | 549 | blocked_input = contraction_letters[:] 550 | blocked_input.insert(meta.blocks_axis, blocks_axis_letter) 551 | blocked_input = "".join(blocked_input) 552 | 553 | preconditioner_outputs = [] 554 | for i, o in enumerate(output_letters): 555 | p_o = list(blocked_input) 556 | p_o[i + 1] = o 557 | preconditioner_outputs.append("".join(p_o)) 558 | 559 | formulas = [ 560 | ",".join([blocked_input] + [p_in]) + "->" + p_out 561 | for p_in, p_out in zip(preconditioner_inputs, preconditioner_outputs) 562 | ] 563 | 564 | with jax.named_scope("PreconditionCASPR"): 565 | write_note(f"{formulas} {update.shape} {[x.shape for x in preconditioners]}") 566 | for i in range(2): 567 | to_sum = [] 568 | for f, p in zip(formulas, preconditioners): 569 | to_sum.append(jnp.einsum(f, update, p)) 570 | update = sum(to_sum) 571 | return update 572 | 573 | 574 | def _split_exclusively(ls: Sequence[int], splits: Sequence[int]) -> list[list[int]]: 575 | """Returns possibly-empty segments between sorted split points in ls.""" 576 | assert all( 577 | l < r for l, r in zip(splits, splits[1:]) 578 | ), f"splits {splits} must be distinct ascending" 579 | assert all( 580 | 0 <= i < len(ls) for i in splits 581 | ), f"splits {splits} must index into list {ls} of length {len(ls)}" 582 | splits = [-1] + list(splits) + [len(ls)] 583 | return [list(ls[l + 1 : r]) for l, r in zip(splits, splits[1:])] 584 | 585 | 586 | def _einsum_letters(meta: _BlocksMetadata) -> Iterator[str]: 587 | for c in string.ascii_letters: 588 | yield c 589 | 590 | raise ValueError( 591 | f"shape {meta.param_shape} too high-dimensional for {meta.debug_name}" 592 | ) 593 | 594 | 595 | def _ema_update(old: jax.Array, new: jax.Array, decay: float) -> jax.Array: 596 | if decay == 1.0: 597 | return old + new 598 | return old * decay + new * (1 - decay) 599 | -------------------------------------------------------------------------------- /optimizers/tearfree/shampoo_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for momentum implementation.""" 16 | 17 | import itertools 18 | from typing import Sequence 19 | 20 | from absl import logging 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | import jax 24 | from jax import numpy as jnp 25 | import numpy as np 26 | from optimizers.tearfree import shampoo 27 | 28 | 29 | def _make_invalid_cases() -> Sequence[dict[str, ...]]: 30 | """Generate invalid cases which should throw.""" 31 | return [ 32 | { 33 | "testcase_name": "block_size0", 34 | "invalid_options": shampoo.Options(block_size=0), 35 | }, 36 | { 37 | "testcase_name": "precond0", 38 | "invalid_options": shampoo.Options(update_preconditioners_freq=0), 39 | }, 40 | { 41 | "testcase_name": "stats0", 42 | "invalid_options": shampoo.Options(update_statistics_freq=0), 43 | }, 44 | { 45 | "testcase_name": "decay_neg", 46 | "invalid_options": shampoo.Options(second_moment_decay=-0.1), 47 | }, 48 | { 49 | "testcase_name": "decay_large", 50 | "invalid_options": shampoo.Options(second_moment_decay=1.1), 51 | }, 52 | { 53 | "testcase_name": "block_size1", 54 | "invalid_options": shampoo.Options(block_size=1), 55 | }, 56 | ] 57 | 58 | 59 | def _make_blockify_deblockify_cases() -> Sequence[dict[str, ...]]: 60 | shapes_blocks = [ 61 | (tuple(), 2, "scalar"), 62 | ((5,), 6, "1d_0large"), 63 | ((5,), 5, "1d_1large"), 64 | ((4,), 2, "1d_1large_moreblocks"), 65 | ((2, 3), 6, "2d_0large"), 66 | ((2, 3), 3, "2d_1large"), 67 | ((2, 2), 2, "2d_2large"), 68 | ((4, 4), 2, "2d_2large_moreblocks"), 69 | ((2, 3, 3, 2), 4, "highdim_0large"), 70 | ((2, 3, 2, 2), 3, "highdim_1large"), 71 | ((2, 2 * 3, 2, 2), 3, "highdim_1large_moreblocks"), 72 | ((2, 3, 3, 2), 3, "highdim_2large_together"), 73 | ((2, 3, 2, 3), 3, "highdim_2large_separate"), 74 | ] 75 | 76 | cases = [] 77 | for shape, block_size, name in shapes_blocks: 78 | cases.append(dict(shape=shape, block_size=block_size, testcase_name=name)) 79 | return cases 80 | 81 | 82 | class ShampooTest(parameterized.TestCase): 83 | """Basic test for shampoo implementation.""" 84 | 85 | def setUp(self): 86 | super().setUp() 87 | jax.config.update("jax_debug_nans", True) 88 | 89 | def _unroll(self, options, n, shape): 90 | """Generate states and grad updates n times.""" 91 | rng = jax.random.PRNGKey(0) 92 | params = jnp.zeros(shape) 93 | grads = jax.random.normal(rng, (n, *shape)) 94 | return self._unroll_concrete(options, params, grads) 95 | 96 | def _unroll_concrete(self, options, params, grads): 97 | """Unrolls with provided params and grads.""" 98 | tx = shampoo.apply(options) 99 | init = tx.init(params) 100 | 101 | def reduce(state, grad): 102 | new_grad, new_state = tx.update(grad, state, params) 103 | return new_state, new_grad 104 | 105 | _, out_grads = jax.lax.scan(reduce, init, grads) 106 | return grads, out_grads 107 | 108 | @parameterized.parameters( 109 | {"shape": (1, 2, 1)}, 110 | {"shape": (1, 1, 3, 1, 2, 1)}, 111 | {"shape": (2, 1, 3, 2)}, 112 | {"shape": (1, 1)}, 113 | {"shape": (1,)}, 114 | ) 115 | def test_unit_dims_raise(self, shape): 116 | """Assert raises if unit dimensions are present.""" 117 | with self.assertRaises(ValueError): 118 | self._unroll(shampoo.Options(), 1, shape) 119 | 120 | def test_scalars(self): 121 | """Validate scalar parameters aren't preconditioned.""" 122 | grads, out_grads = self._unroll(shampoo.Options(), 2, tuple()) 123 | np.testing.assert_allclose(grads, out_grads) 124 | 
125 | def _root(self, x, p): 126 | """Computes the matrix root x**(-1/(2*p)).""" 127 | return shampoo._pth_inv_root(p * 2, x[np.newaxis, ...])[0] 128 | 129 | @parameterized.parameters(1, 2) 130 | def test_basic(self, ndim): 131 | """Check basic numerical example without blocking or decay.""" 132 | options = shampoo.Options(second_moment_decay=1.0) 133 | shape = (2,) * ndim 134 | nsteps = 2 135 | grads, out_grads = self._unroll(options, 2, shape) 136 | 137 | l, r = 0, 0 138 | for i in range(nsteps): 139 | if len(shape) == 1: 140 | l += np.multiply.outer(grads[i], grads[i]) 141 | elif len(shape) == 2: 142 | l += grads[i].dot(grads[i].T) 143 | r += grads[i].T.dot(grads[i]) 144 | 145 | pl, pr = self._root(l, len(shape)), r 146 | if len(shape) == 2: 147 | pr = self._root(r, len(shape)) 148 | 149 | pg = pl.dot(grads[i]) 150 | if len(shape) == 2: 151 | pg = pg.dot(pr) 152 | 153 | np.testing.assert_allclose(pg, out_grads[i], rtol=1e-3) 154 | 155 | def test_basic_block(self): 156 | """Check basic numerical example with blocking.""" 157 | options = shampoo.Options(second_moment_decay=1.0, block_size=2) 158 | shape = (4,) 159 | nsteps = 2 160 | 161 | # Don't use unroll here to allow state-printing. 162 | rng = jax.random.PRNGKey(0) 163 | params = jnp.zeros(shape) 164 | grads = jax.random.normal(rng, (nsteps, *shape)) 165 | 166 | tx = shampoo.apply(options) 167 | state = tx.init(params) 168 | logging.info("init state: %s", state) 169 | 170 | b0, b1 = 0, 0 171 | for i in range(nsteps): 172 | out_grad, state = tx.update(grads[i], state, params) 173 | logging.info("state @ %s: %s", i, state) 174 | g0, g1 = grads[i][:2], grads[i][2:] 175 | b0 += np.multiply.outer(g0, g0) 176 | b1 += np.multiply.outer(g1, g1) 177 | p0, p1 = self._root(b0, len(shape)), self._root(b1, len(shape)) 178 | logging.info("g0 %s g1 %s", g0, g1) 179 | logging.info("b0 %s b1 %s", b0, b1) 180 | logging.info("p0 %s p1 %s", p0, p1) 181 | pg = np.concatenate([p0.dot(g0), p1.dot(g1)], axis=0) 182 | np.testing.assert_allclose(pg, out_grad, rtol=1e-3) 183 | 184 | @parameterized.named_parameters(_make_invalid_cases()) 185 | def test_invalid(self, invalid_options): 186 | with self.assertRaises(ValueError): 187 | shampoo.apply(invalid_options) 188 | 189 | @parameterized.named_parameters(_make_blockify_deblockify_cases()) 190 | def test_blockify_deblockify(self, shape, block_size): 191 | rng = jax.random.PRNGKey(0) 192 | x = jax.random.normal(rng, shape) 193 | options = shampoo.Options(block_size=block_size) 194 | meta = shampoo._blocks_metadata(options, x.shape, debug="") 195 | bx = shampoo._blockify(x, meta) 196 | dx = shampoo._deblockify(bx, meta) 197 | self.assertSequenceEqual(dx.shape, x.shape) 198 | np.testing.assert_array_equal(x, dx) 199 | 200 | @parameterized.parameters( 201 | [{"decay": d, "last": b} for d, b in itertools.product([0, 0.8], [False, True])] 202 | ) 203 | def test_basic_ema(self, decay, last): 204 | """Tests EMA accumulation in stats.""" 205 | z = jnp.zeros((2,)) 206 | g = jnp.array([0.5, -0.5]) 207 | 208 | if last: 209 | seq = jnp.stack([z, z, g]) 210 | one = jnp.stack([g]) 211 | expected_decay = 1 - decay 212 | else: 213 | seq = jnp.stack([g, z, z, g]) 214 | one = jnp.stack([g]) 215 | expected_decay = (1 - decay) * (decay**3 + 1) 216 | 217 | decayed = shampoo.Options(second_moment_decay=decay) 218 | no_decay = shampoo.Options(second_moment_decay=1.0) 219 | 220 | last = self._unroll_concrete(decayed, z, seq)[1][-1] 221 | last_no_decay = self._unroll_concrete(no_decay, z, one)[1][-1] 222 | last_no_decay /= 
np.sqrt(expected_decay) 223 | np.testing.assert_allclose(last, last_no_decay, rtol=1e-3) 224 | 225 | @parameterized.named_parameters(_make_blockify_deblockify_cases()) 226 | def test_blocks_equality(self, shape, block_size): 227 | rng = jax.random.PRNGKey(0) 228 | nsteps = 3 229 | grads = jax.random.normal(rng, (nsteps, *shape)) 230 | options = shampoo.Options(block_size=block_size) 231 | 232 | meta = shampoo._blocks_metadata(options, shape, debug="") 233 | grads_for_each_block = [[] for _ in range(meta.num_blocks)] 234 | for grad in grads: 235 | bgrad = shampoo._blockify(grad, meta) 236 | for i in range(meta.num_blocks): 237 | grads_for_each_block[i].append(jnp.take(bgrad, i, meta.blocks_axis)) 238 | last_grad = [] 239 | unblocked_options = shampoo.Options(block_size=1 + block_size) 240 | for block in grads_for_each_block: 241 | block = jnp.stack(block) 242 | block_grads, block_out_grads = self._unroll_concrete( 243 | unblocked_options, block[0], block 244 | ) 245 | del block_grads 246 | last_grad.append(block_out_grads[-1]) 247 | 248 | expected = jnp.stack(last_grad, axis=meta.blocks_axis) 249 | expected = shampoo._deblockify(expected, meta) 250 | actual = self._unroll_concrete(options, grads[0], grads)[1] 251 | np.testing.assert_allclose(expected, actual[-1]) 252 | 253 | def test_stats_freq(self): 254 | rng = jax.random.PRNGKey(0) 255 | grads = jax.random.normal(rng, (9, 3)) 256 | options = shampoo.Options(update_statistics_freq=3) 257 | _, out_grads = self._unroll_concrete(options, grads[0], grads) 258 | options = shampoo.Options(update_statistics_freq=1) 259 | _, out_grads_skip = self._unroll_concrete(options, grads[0], grads[::3]) 260 | np.testing.assert_allclose(out_grads[::3], out_grads_skip) 261 | 262 | def test_precond_freq(self): 263 | rng = jax.random.PRNGKey(0) 264 | rng, key = jax.random.split(rng) 265 | freq = 5 266 | grads = jax.random.normal(rng, (freq * 2, 3)) 267 | 268 | rng1, rng2 = jax.random.split(key, 2) 269 | seq1 = jnp.arange(freq, dtype=int) 270 | seq2 = jnp.copy(seq1) 271 | # Shuffle within groups of 272 | shuffled = jnp.take(grads, jnp.concatenate([seq1, seq2 + freq]), axis=0) 273 | 274 | grads = jnp.concatenate([jnp.zeros((1, 3)), grads]) 275 | shuffled = jnp.concatenate([jnp.zeros((1, 3)), shuffled]) 276 | 277 | options = shampoo.Options( 278 | update_preconditioners_freq=freq, second_moment_decay=1 279 | ) 280 | _, out_grads = self._unroll_concrete(options, grads[0], grads) 281 | _, out_grads_shuf = self._unroll_concrete(options, grads[0], shuffled) 282 | np.testing.assert_allclose(out_grads, out_grads_shuf) 283 | 284 | def test_tree(self): 285 | shape = (3, 2) 286 | n = 4 287 | options = shampoo.Options() 288 | rng = jax.random.PRNGKey(0) 289 | params = jnp.zeros(shape) 290 | grads = jax.random.normal(rng, (n, *shape)) 291 | _, out_grads = self._unroll_concrete(options, params, grads) 292 | 293 | params = {"w": [{"b": params}]} 294 | grads = {"w": [{"b": grads}]} 295 | _, actual_out_grads = self._unroll_concrete(options, params, grads) 296 | 297 | np.testing.assert_allclose(out_grads, actual_out_grads["w"][0]["b"]) 298 | 299 | 300 | if __name__ == "__main__": 301 | absltest.main() 302 | -------------------------------------------------------------------------------- /optimizers/tearfree/sketchy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Sketchy low-rank second-order statistics preconditioning.""" 16 | 17 | import dataclasses 18 | import functools 19 | from typing import Any, NamedTuple, Optional, Union 20 | 21 | from absl import logging 22 | import chex 23 | import jax 24 | from jax import numpy as jnp 25 | import optax 26 | from optimizers.tearfree import praxis_shim 27 | 28 | 29 | @dataclasses.dataclass 30 | class Options: 31 | """Sketchy covariance approximation options. 32 | 33 | See https://arxiv.org/abs/2302.03764. 34 | 35 | Attributes: 36 | epsilon: The diagonal positive perturbation added to preconditioner before 37 | inversion. 38 | rank: The sketch size to use for FD sketch for each tensor's dimension. 39 | relative_epsilon: Whether to scale epsilon by the top singular value of the 40 | covariance or not. 41 | second_moment_decay: Exponential moving average for second-order statistics 42 | tracking. If 1.0 then sums. 43 | update_freq: Number of steps between covariance updates. 44 | add_ggt: whether to store the exponentially moving GGT in the states. Set 45 | to TRUE to save the exponentially moving GGT. 46 | memory_alloc: optional dictionary to indicate reallocation of memory used 47 | in sketching the covariance matrix. 48 | ekfac_svd: whether to use the ekfac_svd precondtioner instead of Sketchy, 49 | default setting to FALSE. 50 | linear_approx_tail: whether to use the approximately linear relationship 51 | between log(eigval) and log(eigval_rank) to calculate tail 52 | """ 53 | 54 | epsilon: float = 1e-7 55 | rank: int = 128 56 | relative_epsilon: bool = True 57 | second_moment_decay: float = 0.999 58 | update_freq: int = 1 59 | add_ggt: bool = False 60 | memory_alloc: Optional[dict[str, Any]] = None 61 | ekfac_svd: bool = False 62 | linear_approx_tail: bool = False 63 | 64 | 65 | def apply(options: Options) -> praxis_shim.ShardedGradientTransformation: 66 | """Return gradient transform for (blocked) shampoo preconditioning.""" 67 | 68 | _validate(options) 69 | 70 | return praxis_shim.ShardedGradientTransformation( 71 | functools.partial(_init, options), 72 | functools.partial(_update, options), 73 | functools.partial(_pspec, options), 74 | ) 75 | 76 | 77 | class _AxisState(NamedTuple): 78 | """Contains the covariance sketch state for each tensor dimension.""" 79 | 80 | eigvecs: jax.Array 81 | # These refer to the eigenvalues of the running *square root* of of the 82 | # covariance. 83 | eigvals: jax.Array 84 | # Analogously, but the -(1/(2*ndim)) root of the covariance, where ndim 85 | # is total tensor rank. 86 | inv_eigvals: jax.Array 87 | # The tail, however, tracks the cumulative escaped mass, which is the sum 88 | # of eigenvalues of the full gradient covariance which were subtracted out. 89 | tail: jax.Array 90 | # Analogously to inv_eigvals, the -(1/(2*ndim))-th root. 
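    # (i.e. of (tail + epsilon)); in _precondition it scales the gradient
    # component that falls outside the tracked low-rank eigenspace.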
91 | inv_tail: jax.Array 92 | # Add additional optional state to store ema GGT 93 | ema_ggt: Union[optax.MaskedNode, jax.Array] 94 | # Save svd result to perform ekfac preconditioning 95 | svd_result_u: Union[optax.MaskedNode, jax.Array] 96 | # Save svd result to perform ekfac preconditioning, this corresponeds to the 97 | # inv_eigvals using Sketchy 98 | svd_result_s: Union[optax.MaskedNode, jax.Array] 99 | # Analogous to inv_tail in the Sketchy case, but up to time t-1 100 | inv_prev_tail: Union[optax.MaskedNode, jax.Array] 101 | 102 | 103 | class _TensorState(NamedTuple): 104 | """Per-tensor state contains a list of axis states for each dimension.""" 105 | 106 | axes: list[_AxisState] 107 | 108 | 109 | class _SketchyState(NamedTuple): 110 | # A scalar int32 for step count. 111 | count: jax.Array 112 | # A tree of the same shape as the params of _TensorState leaves of f32. 113 | sketches: chex.ArrayTree 114 | 115 | 116 | def _validate(options: Options) -> None: 117 | """Raise ValueError if options are invalid.""" 118 | 119 | if options.update_freq <= 0: 120 | raise ValueError( 121 | "update_freq ({}) must be positive".format(options.update_freq) 122 | ) 123 | 124 | if not (0 <= options.second_moment_decay <= 1): 125 | raise ValueError( 126 | f"second_moment_decay ({options.second_moment_decay}) " 127 | "should be in [0, 1]" 128 | ) 129 | 130 | if options.rank <= 0: 131 | raise ValueError(f"rank ({options.rank}) must be at least 1") 132 | 133 | 134 | def _path_to_key(path: Any) -> str: 135 | concat_path = "" 136 | for dickey in path: 137 | if hasattr(dickey, "key"): 138 | concat_path = concat_path + "/" + dickey.key 139 | elif hasattr(dickey, "idx"): 140 | concat_path = concat_path + "/" + str(dickey.idx) 141 | else: 142 | raise ValueError("no key or idx found") 143 | return concat_path 144 | 145 | 146 | def _locate_path( 147 | path: ..., dictionary: dict[str, Any] 148 | ) -> Union[dict[str, Any], list[int]]: 149 | """Locate a path in a dictionary.""" 150 | 151 | carry = dictionary 152 | for p in path: 153 | if not hasattr(p, "key") and not hasattr(p, "idx"): 154 | raise ValueError("no key or idx found") 155 | carry = carry[p.key] if hasattr(p, "key") else carry[p.idx] 156 | assert isinstance(carry, list), type(carry) 157 | return carry 158 | 159 | 160 | def _init(options: Options, params: optax.Params) -> _SketchyState: 161 | """Inititialize sketch.""" 162 | 163 | def _tensor_state(path: ..., param: jax.Array) -> _TensorState: 164 | axes = [] 165 | axes_idx = 0 166 | add_ggt = options.add_ggt 167 | memory_alloc = options.memory_alloc 168 | ekfac = options.ekfac_svd 169 | total_dim_prod = jnp.prod(jnp.array(param.shape)) 170 | for d in param.shape: 171 | if d == 1: 172 | raise ValueError( 173 | "param {} shape ({}) has unit dimensions".format(path, param.shape) 174 | ) 175 | if memory_alloc: 176 | k = min(d, _locate_path(path, memory_alloc)[axes_idx]) 177 | logging.info( 178 | "custom rank: %d rank allocated to %s", k, _path_to_key(path) 179 | ) 180 | else: 181 | logging.warning( 182 | "memory_alloc not found for path %s, using global rank %d", 183 | _path_to_key(path), 184 | options.rank, 185 | ) 186 | k = min(d, options.rank) 187 | others_dim_prod = int(total_dim_prod / d) if d else 0 188 | m = min(d, k + others_dim_prod) 189 | axes_idx += 1 190 | 191 | axes.append( 192 | _AxisState( 193 | eigvecs=jnp.zeros((d, k)), 194 | eigvals=jnp.zeros((k,)), 195 | inv_eigvals=jnp.zeros((k,)), 196 | tail=jnp.zeros(tuple()), 197 | inv_tail=jnp.zeros(tuple()), 198 | ema_ggt=jnp.zeros((d, d)) if 
add_ggt else optax.MaskedNode(), 199 | svd_result_u=jnp.zeros((d, m)) if ekfac else optax.MaskedNode(), 200 | svd_result_s=jnp.zeros((m,)) if ekfac else optax.MaskedNode(), 201 | inv_prev_tail=jnp.zeros(tuple()) if ekfac else optax.MaskedNode(), 202 | ) 203 | ) 204 | return _TensorState(axes) 205 | 206 | return _SketchyState( 207 | count=jnp.zeros([], jnp.int32), 208 | sketches=jax.tree_util.tree_map_with_path(_tensor_state, params), 209 | ) 210 | 211 | 212 | def _pspec( 213 | options: Options, params: praxis_shim.NestedHParams 214 | ) -> praxis_shim.NestedHParams: 215 | """Generate sharding specification for sketchy state.""" 216 | 217 | count_pspec = praxis_shim.WeightHParams( 218 | shape=[], 219 | init=None, 220 | dtype=jnp.int32, 221 | collections=None, 222 | tensor_split_dims_mapping=[], 223 | ) 224 | 225 | def _tensor_pspec( 226 | path: ..., param: praxis_shim.WeightHParams 227 | ) -> praxis_shim.NestedHParams: 228 | 229 | total_dim_prod = jnp.prod(jnp.array(param.shape)) 230 | 231 | def _replicated(shape): 232 | return praxis_shim.WeightHParams( 233 | shape=list(shape), 234 | init=None, 235 | dtype=jnp.float32, 236 | collections=None, 237 | tensor_split_dims_mapping=[-1] * len(shape), 238 | ) 239 | 240 | def _make_axis_state(d: int, axes_idx: int): 241 | memory_alloc = options.memory_alloc 242 | ekfac = options.ekfac_svd 243 | if memory_alloc: 244 | k = min(d, _locate_path(path, memory_alloc)[axes_idx]) 245 | else: 246 | k = min(d, options.rank) 247 | add_ggt = options.add_ggt 248 | others_dim_prod = int(total_dim_prod / d) if d else 0 249 | m = min(d, k + others_dim_prod) 250 | axes_idx += 1 251 | return dict( 252 | eigvecs=_replicated((d, k)), 253 | eigvals=_replicated((k,)), 254 | inv_eigvals=_replicated((k,)), 255 | tail=_replicated(tuple()), 256 | inv_tail=_replicated(tuple()), 257 | ema_ggt=_replicated((d, d)) if add_ggt else optax.MaskedNode(), 258 | svd_result_u=_replicated((d, m)) if ekfac else optax.MaskedNode(), 259 | svd_result_s=_replicated((m,)) if ekfac else optax.MaskedNode(), 260 | inv_prev_tail=_replicated(tuple()) if ekfac else optax.MaskedNode(), 261 | ) 262 | 263 | return dict( 264 | axes=[ 265 | _make_axis_state(d, axes_idx) 266 | for d, axes_idx in zip(param.shape, range(len(param.shape))) 267 | ] 268 | ) 269 | 270 | return dict( 271 | count=count_pspec, 272 | sketches=jax.tree_util.tree_map_with_path( 273 | _tensor_pspec, params, is_leaf=lambda x: hasattr(x, "shape") 274 | ), 275 | ) 276 | 277 | 278 | def _update( 279 | options: Options, 280 | updates: optax.Updates, 281 | state: _SketchyState, 282 | params: Optional[optax.Params] = None, 283 | ) -> tuple[optax.Updates, _SketchyState]: 284 | """Update internal shampoo stats and precondition gradients.""" 285 | del params 286 | sketches = state.sketches 287 | is_tensor_state = lambda x: isinstance(x, _TensorState) 288 | 289 | should_update_stats = (state.count % options.update_freq) == 0 290 | updated_sketches = functools.partial( 291 | jax.tree_util.tree_map_with_path, 292 | functools.partial(_update_sketches, options), 293 | updates, 294 | sketches, 295 | is_leaf=is_tensor_state, 296 | ) 297 | 298 | if not options.ekfac_svd: 299 | new_sketches = jax.lax.cond( 300 | should_update_stats, updated_sketches, lambda: sketches 301 | ) 302 | 303 | else: 304 | # when using ekfac_svd, need to call updated_sketches() every iteration 305 | # since the preconditioner needs to be updated every iteration 306 | # even when the sketches are updated every update_freq iterations 307 | updated_preconditioner_only = 
functools.partial( 308 | jax.tree_util.tree_map_with_path, 309 | lambda p, u, s: _update_sketches(options, p, u, s, False), 310 | updates, 311 | sketches, 312 | is_leaf=is_tensor_state, 313 | ) 314 | new_sketches = jax.lax.cond( 315 | should_update_stats, updated_sketches, updated_preconditioner_only 316 | ) 317 | 318 | new_updates = jax.tree_util.tree_map_with_path( 319 | functools.partial(_precondition, options), 320 | updates, 321 | new_sketches, 322 | is_leaf=is_tensor_state, 323 | ) 324 | return new_updates, _SketchyState(count=state.count + 1, sketches=new_sketches) 325 | 326 | 327 | def _update_sketches( 328 | options: Options, 329 | path: ..., 330 | update: jax.Array, 331 | sketches: _TensorState, 332 | update_sketches: bool = True, 333 | ) -> _TensorState: 334 | """Update sketched covariance statistics given a gradient.""" 335 | new_axes = [] 336 | for dim, axis_state in enumerate(sketches.axes): 337 | with jax.named_scope("UpdateSketchDim{}".format(dim)): 338 | new_axes.append( 339 | _update_axis(options, dim, path, update, axis_state, update_sketches) 340 | ) 341 | return _TensorState(new_axes) 342 | 343 | 344 | def _precondition( 345 | options: Options, path: ..., update: jax.Array, sketches: _TensorState 346 | ) -> jax.Array: 347 | """Precondition gradients.""" 348 | g = update 349 | original_shape = g.shape 350 | roll = tuple(range(1, g.ndim)) + (0,) 351 | memory_alloc = options.memory_alloc 352 | ekfac = options.ekfac_svd 353 | for dim, axis_state in enumerate(sketches.axes): 354 | with jax.named_scope("SketchPreconditionDim{}".format(dim)): 355 | # Rotate g during this loop; keep the axis to precondition first. 356 | d = original_shape[dim] 357 | assert g.shape[0] == d 358 | if memory_alloc: 359 | k = min(d, _locate_path(path, memory_alloc)[dim]) 360 | else: 361 | k = min(d, options.rank) 362 | assert list(axis_state.eigvecs.shape) == [d, k] 363 | eigvecs = axis_state.eigvecs if not ekfac else axis_state.svd_result_u 364 | lowrank_basis = jnp.tensordot(g, eigvecs, axes=[[0], [0]]) 365 | lowrank_component = jnp.tensordot( 366 | lowrank_basis, eigvecs, axes=[[g.ndim - 1], [1]] 367 | ) 368 | g = jnp.transpose(g, axes=roll) 369 | complement = g - lowrank_component 370 | inv_eigvals = ( 371 | axis_state.inv_eigvals if not ekfac else axis_state.svd_result_s 372 | ) 373 | scaled_basis = lowrank_basis * inv_eigvals 374 | scaled_lowrank_component = jnp.tensordot( 375 | scaled_basis, eigvecs, axes=[[g.ndim - 1], [1]] 376 | ) 377 | g = scaled_lowrank_component 378 | inv_tail = axis_state.inv_tail if not ekfac else axis_state.inv_prev_tail 379 | g += inv_tail * complement 380 | return g 381 | 382 | 383 | # pylint: disable = g-long-lambda 384 | def _update_axis( 385 | options: Options, 386 | dim: int, 387 | path: ..., 388 | update: jax.Array, 389 | axis_state: _AxisState, 390 | update_sketches: bool = True, 391 | ) -> _AxisState: 392 | """Perform an FD update for statistics.""" 393 | # _low_rank_root 394 | d = update.shape[dim] 395 | memory_alloc = options.memory_alloc 396 | if memory_alloc: 397 | k = min(d, _locate_path(path, memory_alloc)[dim]) 398 | else: 399 | k = min(d, options.rank) 400 | 401 | sketch_dk = axis_state.eigvecs 402 | assert sketch_dk.shape == (d, k), (sketch_dk.shape, d, k, update.shape, dim) 403 | 404 | sketch_dk *= axis_state.eigvals[jnp.newaxis, :] 405 | all_but_dim = [i for i in range(update.ndim) if i != dim] 406 | g_dm = update.transpose([dim] + all_but_dim).reshape(d, -1) 407 | decay = jnp.sqrt(options.second_moment_decay) 408 | 409 | # This implementation 
uses only O(|gradient size|) memory because 410 | # full_matrices is False, but may be slow. Consider LOBPCG instead. 411 | updated = jnp.concatenate([sketch_dk * decay, g_dm], axis=1) 412 | # This dimensionality reduction with QR is a mathematical no-op but required 413 | # to avoid b/286607548. 414 | updated = jnp.linalg.qr(updated.T, mode="r").T 415 | 416 | def _safe_svd(x): 417 | # Wrap with a nan check due to hangs per b/286654608. 418 | svd = lambda y: jnp.linalg.svd(y, full_matrices=False)[:2] 419 | 420 | def _all_nan(y): 421 | m = min(y.shape) 422 | u = jnp.full((d, m), jnp.nan, jnp.float32) 423 | s = jnp.full((m,), jnp.nan, jnp.float32) 424 | return u, s 425 | 426 | return jax.lax.cond(jnp.isfinite(x).all(), svd, _all_nan, x) 427 | 428 | u, s = _safe_svd(updated) 429 | assert u.shape[0] == d 430 | assert u.shape[1] >= k 431 | 432 | cutoff = jnp.maximum(s[k], 0.0) if k < len(s) else 0.0 433 | top_eigs = jnp.maximum(s[:k], 0.0) 434 | deflated = jnp.sqrt(jnp.maximum(0.0, top_eigs - cutoff)) * jnp.sqrt( 435 | top_eigs + cutoff 436 | ) 437 | if options.linear_approx_tail and d > k: 438 | num_points = (k + 1) // 2 439 | assert num_points > 0 440 | ranks = jnp.arange(1, num_points + 1) 441 | vals = axis_state.eigvals[:num_points] 442 | assert ranks.shape == vals.shape 443 | sample_cov = jnp.cov(ranks, vals) 444 | s_x, s_xy = sample_cov[0, 0], sample_cov[0, 1] 445 | slope = jax.lax.cond(s_x > 0, lambda: s_xy / (s_x**2), lambda: 0.0) 446 | intercept = jnp.mean(vals) - slope * jnp.mean(ranks) 447 | log_ranks = jnp.log(jnp.arange(k + 1, d + 1)) 448 | fitted_vals = slope * log_ranks + intercept 449 | tail = jnp.exp(jax.scipy.special.logsumexp(fitted_vals * 2)) / (d - k) 450 | undeflated = jnp.square(jnp.maximum(top_eigs, 0.0)) 451 | else: 452 | tail = axis_state.tail * decay + cutoff**2 453 | # Avoid numerical error from the sqrt computation and from subtracting 454 | # and re-adding cutoff^2 (mathematically, undeflated == deflated^2 + tail). 
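    # Illustrative check of that identity: deflated_i**2 = top_i**2 - cutoff**2
    # (for top_i >= cutoff) and tail = tail_prev * decay + cutoff**2, so
    # deflated_i**2 + tail == top_i**2 + tail_prev * decay, which is exactly
    # the undeflated value computed directly below.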
455 | undeflated = jnp.square(jnp.maximum(top_eigs, 0.0)) + axis_state.tail * decay 456 | eigvecs = u[:, :k] 457 | 458 | mask = deflated > 0 459 | 460 | alpha = jnp.asarray(-1.0 / (2 * update.ndim), dtype=jnp.float32) 461 | eigvecs *= mask 462 | if options.relative_epsilon and options.epsilon > 0: 463 | eps = jnp.max(undeflated) * options.epsilon 464 | else: 465 | eps = options.epsilon 466 | inv_eigvals = jnp.where(mask, (undeflated + eps) ** alpha, 0.0) 467 | eigvals = deflated * mask 468 | inv_tail = jnp.where(tail > 0, (tail + eps) ** alpha, 0.0) 469 | 470 | if options.add_ggt: 471 | ema_ggt = axis_state.ema_ggt * decay + g_dm.dot(g_dm.T) * (1 - decay) 472 | else: 473 | ema_ggt = axis_state.ema_ggt 474 | 475 | if options.ekfac_svd: 476 | assert u.shape[1] <= d 477 | prev_tail = axis_state.tail 478 | undeflated_ekfac = jnp.square(jnp.maximum(s, 0.0)) + prev_tail * decay 479 | svd_result_u = u 480 | svd_result_s = jnp.where( 481 | undeflated_ekfac > 0, (undeflated_ekfac + eps) ** alpha, 0.0 482 | ) 483 | inv_prev_tail = axis_state.inv_tail 484 | else: 485 | svd_result_u = axis_state.svd_result_u 486 | svd_result_s = axis_state.svd_result_s 487 | inv_prev_tail = axis_state.inv_prev_tail 488 | 489 | res = _AxisState( 490 | eigvecs, 491 | eigvals, 492 | inv_eigvals, 493 | tail, 494 | inv_tail, 495 | ema_ggt, 496 | svd_result_u, 497 | svd_result_s, 498 | inv_prev_tail, 499 | ) 500 | 501 | return jax.lax.cond( 502 | update_sketches, 503 | lambda: res, 504 | lambda: res._replace( 505 | eigvecs=axis_state.eigvecs, 506 | eigvals=axis_state.eigvals, 507 | inv_eigvals=axis_state.inv_eigvals, 508 | tail=axis_state.tail, 509 | inv_tail=axis_state.inv_tail, 510 | ), 511 | ) 512 | -------------------------------------------------------------------------------- /optimizers/tearfree/sketchy_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Tests for sketchy implementation."""
16 | 
17 | import itertools
18 | from typing import Sequence
19 | 
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | import jax
23 | from jax import numpy as jnp
24 | import numpy as np
25 | from optimizers.tearfree import shampoo
26 | from optimizers.tearfree import sketchy
27 | 
28 | 
29 | def _make_invalid_cases() -> Sequence[dict[str, ...]]:
30 |     """Generate invalid cases which should throw."""
31 |     return [
32 |         {"testcase_name": "freq0", "invalid_options": sketchy.Options(update_freq=0)},
33 |         {
34 |             "testcase_name": "decay_neg",
35 |             "invalid_options": sketchy.Options(second_moment_decay=-0.1),
36 |         },
37 |         {
38 |             "testcase_name": "decay_large",
39 |             "invalid_options": sketchy.Options(second_moment_decay=1.1),
40 |         },
41 |     ]
42 | 
43 | 
44 | class SketchyTest(parameterized.TestCase):
45 |     """Basic test for sketchy implementation."""
46 | 
47 |     def setUp(self):
48 |         super().setUp()
49 |         jax.config.update("jax_debug_nans", True)
50 | 
51 |     @parameterized.parameters(
52 |         {"shape": (1, 2, 1)},
53 |         {"shape": (1, 1, 3, 1, 2, 1)},
54 |         {"shape": (2, 1, 3, 2)},
55 |         {"shape": (1, 1)},
56 |         {"shape": (1,)},
57 |     )
58 |     def test_unit_dims_raise(self, shape):
59 |         """Assert raises if unit dimensions are present."""
60 |         with self.assertRaises(ValueError):
61 |             self._unroll(sketchy.apply(sketchy.Options()), 1, shape)
62 | 
63 |     @parameterized.named_parameters(_make_invalid_cases())
64 |     def test_invalid(self, invalid_options):
65 |         with self.assertRaises(ValueError):
66 |             sketchy.apply(invalid_options)
67 | 
68 |     def _make_null_state(self, d, k) -> sketchy._AxisState:
69 |         return sketchy._init(sketchy.Options(rank=k), jnp.zeros((d,))).sketches.axes[0]
70 | 
71 |     def _make_eye_state(self, d, eigs, tail, ndim) -> sketchy._AxisState:
72 |         k = len(eigs)
73 |         state = self._make_null_state(d, k)
74 |         state = state._replace(eigvecs=jnp.eye(d, k))
75 |         state = state._replace(eigvals=state.eigvals + jnp.asarray(eigs))
76 |         state = state._replace(tail=tail)
77 |         if tail > 0:
78 |             state = state._replace(inv_tail=tail ** (-1 / (2 * ndim)))
79 |         mask = state.eigvals > 0
80 |         ie = jnp.where(mask, (state.tail + state.eigvals**2), 1.0) ** (-1 / (2 * ndim))
81 |         ie *= mask
82 |         state = state._replace(inv_eigvals=ie)
83 |         return state
84 | 
85 |     def _no_decay_options(self, sketch_size, epsilon=0.0):
86 |         return sketchy.Options(rank=sketch_size, second_moment_decay=1, epsilon=epsilon)
87 | 
88 |     @parameterized.parameters(range(1, 5))
89 |     def test_dynamic_exponent(self, ndim):
90 |         """Test that exponent for various gradient ndim is correct."""
91 |         size = 4
92 |         prev = self._make_eye_state(size, [0], 0.0, ndim)
93 |         grad = np.zeros([size] * ndim, np.float32)
94 |         grad[(0,) * ndim] = 2**ndim
95 |         ret = sketchy._update_axis(self._no_decay_options(1), 0, "", grad, prev)
96 |         self.assertAlmostEqual(ret.inv_eigvals, 1 / 2, delta=1e-6)
97 | 
98 |         prev = self._make_eye_state(size, [2**ndim], 0.0, ndim)
99 |         grad = np.zeros([size] * ndim, np.float32)
100 |         ret = sketchy._update_axis(self._no_decay_options(1), 0, "", grad, prev)
101 |         self.assertAlmostEqual(ret.inv_eigvals, 1 / 2, delta=1e-6)
102 | 
103 |     def test_epsilon(self):
104 |         """Test that epsilon is properly calculated."""
105 |         size = 4
106 |         ndim = 2
107 |         prev = self._make_eye_state(size, [0], 4.0, ndim)
108 |         grad = np.zeros([size] * ndim, np.float32)
109 |         grad[(0,) * ndim] = 2
110 |         options = self._no_decay_options(1, epsilon=1e-3)
111 |         ret = sketchy._update_axis(options, 0, "", grad, prev)
112 | 
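        # Working through the numbers: the gradient contributes 2**2 == 4 along
        # this axis and, with second_moment_decay == 1, the previous tail of 4.0
        # carries over, so the top undeflated eigenvalue is 4 + 4 == 8 and the
        # tail stays 4. With relative epsilon in effect here,
        # eps == max(undeflated) * 1e-3 == 8e-3, and the exponent is
        # -1 / (2 * ndim) == -1/4, giving the ((4 + 4) * 1.001) ** (-1 / 4) and
        # (4 * 1.001) ** (-1 / 4) values asserted below (the tail check needs
        # delta=1e-3 because the eps actually added there is 8e-3, not 4e-3).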
self.assertAlmostEqual( 113 | ret.inv_eigvals[0], ((4 + 4) * 1.001) ** (-1 / 4), delta=1e-3, msg=ret 114 | ) 115 | self.assertAlmostEqual(ret.inv_tail, (4 * 1.001) ** (-1 / 4), delta=1e-3) 116 | options.relative_epsilon = False 117 | ret = sketchy._update_axis(options, 0, "", grad, prev) 118 | self.assertAlmostEqual( 119 | ret.inv_eigvals[0], (4 + 4 + 0.001) ** (-1 / 4), delta=1e-6 120 | ) 121 | self.assertAlmostEqual(ret.inv_tail, (4 + 0.001) ** (-1 / 4), delta=1e-3) 122 | 123 | def _make_rand_state(self, size, eigs, tail, ndim): 124 | rng = np.random.default_rng(1234) 125 | b = rng.standard_normal(size=[size, size]) 126 | b = b.dot(b.T) 127 | _, v = np.linalg.eigh(b) 128 | state = self._make_eye_state(size, eigs, tail, ndim) 129 | state = state._replace(eigvecs=v[:, : len(eigs)]) 130 | return state 131 | 132 | # pylint: disable=g-long-lambda 133 | def test_realloc(self): 134 | """Test the memory reallocation functions properly.""" 135 | dim, nsteps = 8, 3 136 | memory_dict = {"a": [2], "b": [[6], [8]], "c": {"d": [4], "e": [8]}} 137 | tx = sketchy.apply(sketchy.Options(memory_alloc=memory_dict)) 138 | shape = jax.tree.map( 139 | lambda x: (dim,), 140 | memory_dict, 141 | is_leaf=lambda x: isinstance(x, list) 142 | and all(not isinstance(y, list) for y in x), 143 | ) 144 | grads_tree, updates = self._unroll(tx, nsteps, shape, None, True) 145 | emw_run = jax.tree.map( 146 | lambda k, sp, grad: self._unroll( 147 | tx=sketchy.apply(sketchy.Options(rank=k[0])), 148 | n=nsteps, 149 | shape=sp, 150 | grads=grad, 151 | ), 152 | memory_dict, 153 | shape, 154 | grads_tree, 155 | is_leaf=lambda x: isinstance(x, list) 156 | and all(not isinstance(y, list) for y in x), 157 | ) 158 | jax.tree.map(np.testing.assert_allclose, updates, emw_run) 159 | 160 | # test ekfac resulting preconditioned gradient on random values are finite 161 | def test_ekfac(self): 162 | tx = sketchy.apply(sketchy.Options(ekfac_svd=True)) 163 | nsteps = 3 164 | shape = (4, 5) 165 | updated_grads = self._unroll(tx, nsteps, shape) 166 | assert np.all(np.isfinite(updated_grads)) 167 | 168 | # test the preconditioned gradient from linearly approximated tails are finite 169 | def test_linear_approx(self): 170 | tx = sketchy.apply(sketchy.Options(rank=6, linear_approx_tail=True)) 171 | nsteps = 3 172 | shape = (10, 10) 173 | updated_grads = self._unroll(tx, nsteps, shape) 174 | assert np.all(np.isfinite(updated_grads)) 175 | 176 | # test covariance-adding equality from FD 177 | # with rand initial state, and with zero 178 | # 179 | # Do it under ndim 1 2 or 3 (choose random axis for higher dims) 180 | 181 | @parameterized.parameters( 182 | itertools.product( 183 | [1, 2, 3], 184 | [0.1, 0.9, 1.0], 185 | ["zero", "id", "rand"], 186 | [0.0, 1.0], 187 | [False, True], 188 | ) 189 | ) 190 | def test_basic(self, ndim, decay, init, tail, last_axis): 191 | """Validate low rank returned matrix.""" 192 | d = 3 193 | k = 2 194 | rng = np.random.default_rng(1234) 195 | 196 | # Make other dims slightly larger 197 | shape = [d + i for i in range(ndim)] 198 | if last_axis: 199 | shape = shape[::-1] 200 | grad = rng.standard_normal(size=shape) 201 | 202 | if last_axis: 203 | grad_2d = grad.reshape(-1, d) 204 | added_cov = grad_2d.T.dot(grad_2d) 205 | else: 206 | grad_2d = grad.reshape(d, -1) 207 | added_cov = grad_2d.dot(grad_2d.T) 208 | top_added_eig = np.linalg.eigvalsh(added_cov).max() 209 | # Test out one eig above, one below. 
210 | eigs = np.array([top_added_eig * 4, top_added_eig / 4]) 211 | 212 | if init == "zero": 213 | prev = self._make_null_state(d, k) 214 | elif init == "id": 215 | prev = self._make_eye_state(d, eigs, tail, ndim) 216 | else: 217 | assert init == "rand", init 218 | prev = self._make_rand_state(d, eigs, tail, ndim) 219 | 220 | options = sketchy.Options(second_moment_decay=decay, rank=k, epsilon=0.0) 221 | dim = ndim - 1 if last_axis else 0 222 | updated = sketchy._update_axis(options, dim, "", grad, prev) 223 | 224 | if updated.tail > 0: 225 | self.assertAlmostEqual(updated.tail ** (-1 / (2 * ndim)), updated.inv_tail) 226 | else: 227 | self.assertAlmostEqual(updated.inv_tail, 0) 228 | self.assertAlmostEqual(updated.tail, 0) 229 | 230 | ie = updated.inv_eigvals 231 | e = updated.eigvals**2 + updated.tail 232 | mask = updated.eigvals > 0 233 | expected_ie = mask * np.where(mask, e, 1.0) ** (-1 / (2 * ndim)) 234 | delta = 1e-5 * min(expected_ie.max(), ie.max()) 235 | self.assertSequenceAlmostEqual(expected_ie, ie, delta=delta) 236 | 237 | def _make_cov(sketch: sketchy._AxisState, add_tail=True): 238 | # Note eigvals refer to the *root* singular values, so squaring as 239 | # we do below recovers covariance. 240 | eigvals = np.sqrt(add_tail * sketch.tail + np.square(sketch.eigvals)) 241 | half = sketch.eigvecs * eigvals 242 | complement = np.eye(d) - sketch.eigvecs.dot(sketch.eigvecs.T) 243 | tail = complement * sketch.tail if add_tail else 0.0 244 | return half.dot(half.T) + tail 245 | 246 | self.assertGreaterEqual(updated.tail, prev.tail * decay) 247 | 248 | prev_cov = _make_cov(prev) 249 | new_cov = _make_cov(updated) 250 | pd_eigs = np.linalg.eigvalsh(new_cov - decay * prev_cov) 251 | # Validate positive definiteness up to numerical error. 252 | self.assertGreaterEqual(pd_eigs.min(), -pd_eigs.max() * 1e-4) 253 | 254 | prev_no_tail = _make_cov(prev, add_tail=False) 255 | w2, v2 = np.linalg.eigh(decay * prev_no_tail + added_cov) 256 | w2 = np.maximum(0, w2 - w2[d - k - 1]) 257 | half = v2 * jnp.sqrt(w2) 258 | expected_cov = half.dot(half.T) 259 | actual_cov = _make_cov(updated, add_tail=False) 260 | np.testing.assert_allclose(expected_cov, actual_cov, rtol=1e-3) 261 | 262 | def _unroll(self, tx, n, shape, grads=None, return_grads=False): 263 | """Generate states and grad updates n times.""" 264 | rng = jax.random.PRNGKey(0) 265 | params = jax.tree.map( 266 | jnp.zeros, 267 | shape, 268 | is_leaf=lambda x: isinstance(x, tuple) 269 | and all(isinstance(y, int) for y in x), 270 | ) 271 | if grads is None: 272 | grads = jax.tree.map( 273 | lambda sp: jax.random.normal(rng, (n, *sp)), 274 | shape, 275 | is_leaf=lambda x: isinstance(x, tuple) 276 | and all(isinstance(y, int) for y in x), 277 | ) 278 | 279 | init = tx.init(params) 280 | 281 | def reduce(state, grad): 282 | new_grad, new_state = tx.update(grad, state, params) 283 | return new_state, new_grad 284 | 285 | _, out_grads = jax.lax.scan(reduce, init, grads) 286 | if return_grads: 287 | return grads, out_grads 288 | return out_grads 289 | 290 | def test_reduction_to_shampoo(self): 291 | tx = sketchy.apply(sketchy.Options(second_moment_decay=0.99, epsilon=0.0)) 292 | shampoo_tx = shampoo.apply(shampoo.Options(second_moment_decay=0.99)) 293 | # Choose a shape well below sketchy rank & shampoo block size. 
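        # A sketch of the arithmetic behind the factor of 10 below, assuming
        # Shampoo folds each new statistic in with weight (1 - decay) while
        # Sketchy accumulates it at full weight (per the comment above
        # `shampoo_run`): Sketchy's per-axis second moment is then larger by
        # 1 / (1 - decay); each axis is raised to the -1 / (2 * ndim) == -1/4
        # power and there are two axes, so the preconditioned grads differ by
        # (1 - decay) ** (1 / 2) == 0.1 for decay == 0.99.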
294 | shape = (4, 5) 295 | nsteps = 3 296 | sketchy_run = self._unroll(tx, nsteps, shape) 297 | # Shampoo 2nd moment is computed as (1 - decay) * update + decay * update 298 | # so we must adjust the preconditioned grad by a factor sqrt(1/(1-decay)). 299 | shampoo_run = self._unroll(shampoo_tx, nsteps, shape) / 10 300 | np.testing.assert_allclose(shampoo_run, sketchy_run, rtol=3e-3, atol=2e-4) 301 | 302 | 303 | if __name__ == "__main__": 304 | absltest.main() 305 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tyro 2 | wandb 3 | flax 4 | tqdm 5 | numpy 6 | datasets 7 | tensorflow 8 | transformers 9 | optax 10 | einops 11 | psgd-jax 12 | chex 13 | google-cloud-storage 14 | orbax-checkpoint 15 | sentencepiece 16 | torchinfo 17 | tensorboard-plugin-profile 18 | gcsfs 19 | -------------------------------------------------------------------------------- /scripts/125M.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | EXPERIMENT=run_$(date +%Y-%m-%d_%H-%M-%S) 4 | echo $EXPERIMENT 5 | 6 | python3 main.py \ 7 | --experiment_name=$EXPERIMENT \ 8 | --out_dir=gs://optimizertesting/llm-jax \ 9 | --attempt_to_load_checkpoint \ 10 | --hellaswag_eval_interval=1000 \ 11 | --checkpoint_interval=1000 \ 12 | --train_steps=10000 \ 13 | --batch_size=512 \ 14 | --gradient_accumulation_steps=1 \ 15 | --compute_dtype=bfloat16 \ 16 | --params_dtype=float32 \ 17 | --profile \ 18 | --model.scan_layers \ 19 | --model.remat \ 20 | --model.no_remat_everything \ 21 | --optimizer.type=kron \ 22 | --optimizer.learning_rate=0.001 \ 23 | --optimizer.flat_lr \ 24 | --optimizer.warmup_steps=1000 \ 25 | --optimizer.b1=0.95 \ 26 | --optimizer.weight_decay=0.1 -------------------------------------------------------------------------------- /scripts/125M_mh_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # script for multihost tpu (set for v4-16, increase settings for larger vms) 3 | # usage: bash scripts/125M_mh_tpu.sh 4 | 5 | WANDB_API_KEY=$1 6 | HF_TOKEN=$2 7 | 8 | EXPERIMENT=run_$(date +%Y-%m-%d_%H-%M-%S) 9 | echo $EXPERIMENT 10 | 11 | gcloud compute tpus tpu-vm ssh --zone "us-central2-b" "tpu_vm_name" --project "project_name" --worker=all --command "bash -c \" 12 | export WANDB_API_KEY=$WANDB_API_KEY 13 | export HF_TOKEN=$HF_TOKEN 14 | export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true" 15 | cd llm-jax 16 | nohup python3 main_multihost.py \ 17 | --experiment_name=$EXPERIMENT \ 18 | --out_dir=gs://optimizertesting/llm-jax \ 19 | --attempt_to_load_checkpoint \ 20 | --hellaswag_eval_interval=1000 \ 21 | --checkpoint_interval=1000 \ 22 | --train_steps=100000 \ 23 | --batch_size=512 \ 24 | --gradient_accumulation_steps=1 \ 25 | --compute_dtype=bfloat16 \ 26 | --params_dtype=float32 \ 27 | --profile \ 28 | --model.scan_layers \ 29 | --model.remat \ 30 | --model.no_remat_everything \ 31 | --optimizer.type=kron \ 32 | --optimizer.learning_rate=0.001 \ 33 | --optimizer.warmup_steps=1000 \ 34 | --optimizer.b1=0.95 \ 35 | --optimizer.weight_decay=0.1 \ 36 | --optimizer.preconditioner_update_probability=0.03 \ 37 | --optimizer.preconditioner_dtype=float32 \ 38 | > nohup.out 2>&1 & 39 | PID=\\\$! 
40 | echo 'Background process started with PID '\\\$PID 41 | disown \\\$PID 42 | exit 43 | \"" -------------------------------------------------------------------------------- /scripts/350M_mh_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # script for multihost tpu (set for v4-16, increase settings for larger vms) 3 | # usage: bash scripts/125M_mh_tpu.sh 4 | 5 | WANDB_API_KEY=$1 6 | HF_TOKEN=$2 7 | 8 | EXPERIMENT=run_$(date +%Y-%m-%d_%H-%M-%S) 9 | echo $EXPERIMENT 10 | 11 | gcloud compute tpus tpu-vm ssh --zone "us-central2-b" "tpu_vm_name" --project "project_name" --worker=all --command "bash -c \" 12 | export WANDB_API_KEY=$WANDB_API_KEY 13 | export HF_TOKEN=$HF_TOKEN 14 | export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true" 15 | cd llm-jax 16 | nohup python3 main_multihost.py \ 17 | --experiment_name=$EXPERIMENT \ 18 | --out_dir=gs://optimizertesting/llm-jax \ 19 | --attempt_to_load_checkpoint \ 20 | --hellaswag_eval_interval=1000 \ 21 | --checkpoint_interval=1000 \ 22 | --train_steps=250000 \ 23 | --batch_size=512 \ 24 | --gradient_accumulation_steps=1 \ 25 | --compute_dtype=bfloat16 \ 26 | --profile \ 27 | --model.num_layers=32 \ 28 | --model.num_heads=15 \ 29 | --model.num_kv_heads=5 \ 30 | --model.head_dim=64 \ 31 | --model.num_embeds=960 \ 32 | --model.hidden_dim=2560 \ 33 | --model.scan_layers \ 34 | --model.remat \ 35 | --model.no_remat_everything \ 36 | --optimizer.type=kron \ 37 | --optimizer.learning_rate=0.001 \ 38 | --optimizer.warmup_steps=1000 \ 39 | --optimizer.b1=0.95 \ 40 | --optimizer.weight_decay=0.1 \ 41 | --optimizer.preconditioner_update_probability=0.05 \ 42 | > nohup.out 2>&1 & 43 | PID=\\\$! 44 | echo 'Background process started with PID '\\\$PID 45 | disown \\\$PID 46 | exit 47 | \"" -------------------------------------------------------------------------------- /scripts/delete_tpu_lockfile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # sudo bash scripts/delete_tpu_lockfile.sh 3 | 4 | sudo rm -rf /tmp/libtpu_lockfile 5 | sudo rm -rf /tmp/tpu_logs -------------------------------------------------------------------------------- /scripts/free_tpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Function to print usage 3 | usage() { 4 | echo "Usage: $0 --project [--tpu_name ] [--zone ] [--verbose]" 5 | echo "The --project flag is mandatory." 6 | echo "If --zone is not specified, us-central2-b will be used by default." 7 | exit 1 8 | } 9 | # Parse command line arguments 10 | ZONE="us-central2-b" 11 | VERBOSE=false 12 | PROJECT="" 13 | while [[ "$#" -gt 0 ]]; do 14 | case $1 in 15 | --tpu_name) PROVIDED_TPU_NAME="$2"; shift ;; 16 | --zone) ZONE="$2"; shift ;; 17 | --project) PROJECT="$2"; shift ;; 18 | --verbose) VERBOSE=true ;; 19 | *) usage ;; 20 | esac 21 | shift 22 | done 23 | # Check if project is provided 24 | if [ -z "$PROJECT" ]; then 25 | echo "Error: --project flag is mandatory" 26 | usage 27 | fi 28 | # Function to find an available TPU 29 | find_tpu() { 30 | local found_tpu=$(gcloud compute tpus list --project $PROJECT --zone=$ZONE --format="value(name)" --limit=1 2>/dev/null) 31 | if [ $? -ne 0 ]; then 32 | echo "Error: Failed to list TPUs. Please check your gcloud configuration and permissions." 
33 | exit 1 34 | fi 35 | echo $found_tpu 36 | } 37 | # Set TPU_NAME 38 | if [ -z "$PROVIDED_TPU_NAME" ]; then 39 | TPU_NAME=$(find_tpu) 40 | if [ -z "$TPU_NAME" ]; then 41 | echo "No TPU found in zone $ZONE. Please make sure you have an active TPU or specify a TPU name." 42 | exit 1 43 | fi 44 | echo "No TPU name provided. Using automatically found TPU: $TPU_NAME" 45 | else 46 | TPU_NAME=$PROVIDED_TPU_NAME 47 | echo "Using provided TPU name: $TPU_NAME" 48 | fi 49 | # Verify TPU existence 50 | if ! gcloud compute tpus describe $TPU_NAME --project $PROJECT --zone=$ZONE &>/dev/null; then 51 | echo "Error: TPU '$TPU_NAME' not found in zone $ZONE. Please check the TPU name and zone." 52 | exit 1 53 | fi 54 | # Print selected TPU, zone, and project 55 | echo "Using TPU '$TPU_NAME' in zone '$ZONE' for project '$PROJECT'" 56 | # Function to run command on all TPU VM workers 57 | run_on_all_workers() { 58 | local command="$1" 59 | local output=$(gcloud compute tpus tpu-vm ssh $TPU_NAME --project $PROJECT --zone=$ZONE --worker=all --command "$command" 2>&1) 60 | if [ $? -ne 0 ]; then 61 | echo "Error: Failed to execute command on TPU workers. Please check your TPU status and permissions." 62 | exit 1 63 | fi 64 | echo "$output" 65 | } 66 | # Kill processes using /dev/accel0 67 | echo "Checking for processes using /dev/accel0..." 68 | kill_output=$(run_on_all_workers "sudo lsof -t /dev/accel0 | xargs -r sudo kill -9") 69 | if [ -n "$kill_output" ]; then 70 | echo "Processes killed on TPU workers:" 71 | echo "$kill_output" 72 | else 73 | echo "No processes found using /dev/accel0" 74 | fi 75 | # Print system information if verbose flag is set 76 | if [ "$VERBOSE" = true ]; then 77 | echo "Printing system information..." 78 | system_info=$(run_on_all_workers "uname -a && lscpu") 79 | echo "$system_info" 80 | fi 81 | echo "All operations completed." -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export XLA_FLAGS="--xla_force_host_platform_device_count=2" 4 | 5 | python3 main.py \ 6 | --out_dir=/Users/evanwalters/llm_testing \ 7 | --no_attempt_to_load_checkpoint \ 8 | --compute_dtype=float32 \ 9 | --params_dtype=float32 \ 10 | --model.min_size_to_shard_mb=0 \ 11 | --train_steps=1000 \ 12 | --hellaswag_eval_interval=20 \ 13 | --checkpoint_interval=20 \ 14 | --batch_size=4 \ 15 | --gradient_accumulation_steps=2 \ 16 | --profile \ 17 | --wandb.mode=offline \ 18 | --optimizer.type=kron \ 19 | --optimizer.schedule_free \ 20 | --optimizer.learning_rate=0.001 \ 21 | --optimizer.flat_lr \ 22 | --optimizer.warmup_steps=20 \ 23 | --optimizer.preconditioner_dtype=float32 \ 24 | --optimizer.no_lax_map_scanned_layers \ 25 | --optimizer.lax_map_batch_size=1 \ 26 | --model.block_size=64 \ 27 | --model.num_layers=2 \ 28 | --model.num_heads=4 \ 29 | --model.num_embeds=8 \ 30 | --model.head_dim=4 \ 31 | --model.hidden_dim=8 \ 32 | --model.num_kv_heads=2 \ 33 | --model.scan_layers -------------------------------------------------------------------------------- /sharding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Base sharding functions from big vision changed for our nets and optimizers.""" 16 | import numpy as np 17 | 18 | import jax 19 | from jax.sharding import NamedSharding, PartitionSpec as P 20 | import flax.linen as nn 21 | 22 | from utils import tree_flatten_with_names, write_note 23 | 24 | 25 | def infer_sharding(params, mesh, op): 26 | """Infer sharding spec for the given parameters. 27 | 28 | Return a sharding tree and a spec tree. 29 | """ 30 | x_with_names, tree_def = tree_flatten_with_names(params) 31 | names = tree_def.unflatten(list(zip(*x_with_names))[0]) 32 | 33 | specs = jax.tree.map(lambda x: (None,) * x.ndim, params) 34 | 35 | specs = jax.tree.map( 36 | lambda x, name, spec: op(spec, mesh, name, x), 37 | params, 38 | names, 39 | specs, 40 | # Preconditioners for PSGD and tearfree shampoo kept in lists 41 | is_leaf=lambda v: isinstance(v, nn.Partitioned) or isinstance(v, list), 42 | ) 43 | 44 | # Two-level tree_map to prevent it from doing traversal inside the spec. 45 | specs = jax.tree.map(lambda _, spec: P(*spec), nn.unbox(params), specs) 46 | sharding = jax.tree.map(lambda spec: NamedSharding(mesh, spec), specs) 47 | return sharding, specs 48 | 49 | 50 | def fsdp_sharding(axis, min_size_to_shard_mb=1): 51 | """Simple FSDP sharding rules.""" 52 | # TODO consider not overwriting already sharded dims 53 | axis = axis if isinstance(axis, str) else tuple(axis) 54 | axis_tuple = axis if isinstance(axis, tuple) else (axis,) 55 | 56 | def _update_spec(cur_spec, mesh, name, x): 57 | axis_size = np.prod([mesh.shape[a] for a in axis_tuple]) 58 | 59 | if isinstance(x, list): 60 | # Preconditioners for PSGD and tearfree shampoo kept in lists 61 | precond_specs = [] 62 | shard_dim = -2 # first precond matrix dim 63 | for precond in x: 64 | shape = precond.shape 65 | new_sharding = [None for _ in shape] 66 | if ( 67 | np.prod(shape) * precond.dtype.itemsize 68 | >= min_size_to_shard_mb * (2**20) 69 | and len(shape) > 1 70 | and shape[shard_dim] % axis_size == 0 71 | ): 72 | new_sharding[shard_dim] = axis 73 | print(f"sharding {name}:{shape} to {new_sharding}") 74 | precond_specs.append(tuple(new_sharding)) 75 | return precond_specs 76 | 77 | shape = x.shape 78 | 79 | # Partitioning rules, simple FSDP 80 | # indexed backwards from last dim for friendliness to scanned leading dims 81 | if ( 82 | np.prod(shape) * x.dtype.itemsize >= min_size_to_shard_mb * (2**20) 83 | and len(shape) > 1 84 | ): 85 | new_sharding = [None for _ in shape] 86 | if "scale" in name or "bias" in name: 87 | pass 88 | elif any(s in name for s in ["embedding", "out_kernel", "down_kernel"]): 89 | # shard these on last dim (-1) 90 | if shape[-1] % axis_size == 0: 91 | new_sharding[-1] = axis 92 | print(f"sharding {name}:{shape} to {new_sharding}") 93 | return tuple(new_sharding) 94 | else: 95 | print( 96 | f"WARNING: Parameter {name}:{shape} is not sharded because " 97 | f"last dimension is not divisible by axis size {axis_size}. " 98 | "Consider changing last dim to be divisible by axis size." 
99 |                     )
100 |             elif any(
101 |                 s in name
102 |                 for s in [
103 |                     "q_kernel",
104 |                     "k_kernel",
105 |                     "v_kernel",
106 |                     "gate_kernel",
107 |                     "up_kernel",
108 |                 ]
109 |             ):
110 |                 # shard these on first dim (-2)
111 |                 if shape[-2] % axis_size == 0:
112 |                     new_sharding[-2] = axis
113 |                     print(f"sharding {name}:{shape} to {new_sharding}")
114 |                     return tuple(new_sharding)
115 |                 else:
116 |                     print(
117 |                         f"WARNING: Parameter {name}:{shape} is not sharded because "
118 |                         f"first dimension is not divisible by axis size {axis_size}. "
119 |                         "Consider changing first dim to be divisible by axis size."
120 |                     )
121 |             else:
122 |                 # If not explicitly sharded above, infer here by partitioning
123 |                 # along largest axis that is divisible and not taken.
124 |                 idx = np.argsort(shape)[::-1]
125 |                 for i in idx:
126 |                     if shape[i] % axis_size == 0:
127 |                         if cur_spec[i] is None:
128 |                             new_sharding[i] = axis
129 |                             print(f"sharding {name}:{shape} to {new_sharding}")
130 |                             return tuple(new_sharding)
131 | 
132 |         write_note(
133 |             f"Parameter {name}:{shape} not sharded because did not meet rules "
134 |             f"or already occupied by other sharding rules: {cur_spec}"
135 |         )
136 |         return cur_spec
137 | 
138 |     return _update_spec
139 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import itertools
3 | from multiprocessing.pool import ThreadPool
4 | from typing import Mapping
5 | import dataclasses
6 | import numpy as np
7 | 
8 | import jax
9 | import flax
10 | 
11 | 
12 | def _traverse_with_names(tree, with_inner_nodes=False):
13 |     """Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val)."""
14 |     if dataclasses.is_dataclass(tree):
15 |         tree = flax.serialization.to_state_dict(tree)
16 |     # Don't output the non-leaf nodes. If the optimizer doesn't have a state,
17 |     # the tree leaves can be Nones, which were interpreted as leaves by this
18 |     # function but not by the other functions (like jax.tree.map).
19 |     if tree is None:
20 |         return
21 |     elif isinstance(tree, Mapping):
22 |         keys = sorted(tree.keys())
23 |         for key in keys:
24 |             for path, v in _traverse_with_names(tree[key], with_inner_nodes):
25 |                 yield (key + "/" + path).rstrip("/"), v
26 |         if with_inner_nodes:
27 |             yield "", tree
28 |     elif isinstance(tree, (list, tuple)):
29 |         for idx in range(len(tree)):
30 |             for path, v in _traverse_with_names(tree[idx], with_inner_nodes):
31 |                 yield (str(idx) + "/" + path).rstrip("/"), v
32 |         if with_inner_nodes:
33 |             yield "", tree
34 |     else:
35 |         yield "", tree
36 | 
37 | 
38 | def tree_flatten_with_names(tree):
39 |     """Populates tree_flatten with leaf names.
40 | 
41 |     This function populates the output of tree_flatten with leaf names, using a
42 |     custom traversal that produces names. The custom traversal does
43 |     NOT have to traverse the tree in the same order as jax, as we take care of
44 |     automatically aligning jax's and custom traversals.
45 | 
46 |     Args:
47 |       tree: python tree.
48 | 
49 |     Returns:
50 |       A list of values with names: [(name, value), ...]
51 |     """
52 |     vals, tree_def = jax.tree.flatten(tree)
53 | 
54 |     # "Fake" token tree that is used to track jax internal tree traversal and
55 |     # adjust our custom tree traversal to be compatible with it.
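    # Concretely: `tokens` below are the leaf positions in jax's flatten order.
    # Unflattening them and re-traversing with the name-producing traversal
    # yields, for each leaf in *custom* order, its name and its jax position
    # (`perm`). `np.argsort(perm)` maps each jax position back to the
    # custom-order slot holding its name, so the returned list pairs every
    # jax-ordered value with the correct name.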
56 |     tokens = range(len(vals))
57 |     token_tree = tree_def.unflatten(tokens)
58 |     val_names, perm = zip(*_traverse_with_names(token_tree))
59 |     inv_perm = np.argsort(perm)
60 | 
61 |     # Custom traversal should visit the same number of leaves.
62 |     assert len(val_names) == len(vals)
63 | 
64 |     return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def
65 | 
66 | 
67 | def make_fsarray_from_local_slice(local_slice, global_devices):
68 |     """Create a fully-sharded global device array from local host arrays.
69 | 
70 |     Args:
71 |       local_slice: Something convertible to a numpy array (eg also TF tensors)
72 |         that is this host's slice of the global array.
73 |       global_devices: The list of global devices. Needed for consistent ordering.
74 | 
75 |     Returns:
76 |       The global on-device array which consists of all local slices stacked
77 |       together in the order consistent with the devices.
78 |     """
79 |     mesh = jax.sharding.Mesh(global_devices, ("devices",))
80 |     sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("devices"))
81 |     local_ds = mesh.local_devices
82 | 
83 |     x = np.asarray(memoryview(local_slice))  # No-copy: http://(internal link)
84 |     xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds)
85 | 
86 |     global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:])
87 |     return jax.make_array_from_single_device_arrays(global_shape, sharding, xs)
88 | 
89 | 
90 | def threadstart_iterator(it):
91 |     """Starts an iterator right away in a background thread."""
92 |     # We already want to "start" the iterator in order to start the underlying
93 |     # dataset prefetch mechanisms, so here we get the first element. But we don't
94 |     # want to lose it from training, so we yield that one afterward.
95 |     # (internal link)
96 |     pool = ThreadPool(processes=1)
97 |     first_ex_promise = pool.apply_async(lambda: next(it))
98 | 
99 |     yield first_ex_promise.get()
100 |     yield from it
101 | 
102 | 
103 | def prefetch_iterator(it, n):
104 |     """Runs iterator `it` ahead for `n` steps. Adapted from flax."""
105 |     if not n:
106 |         yield from it
107 |         return
108 |     queue = collections.deque()
109 | 
110 |     def enqueue(n_steps):  # Enqueues *up to* `n` elements from the iterator.
111 |         for data in itertools.islice(it, n_steps):
112 |             # Prefetching will parallelize any processing that happens in a different
113 |             # thread (like `jax.device_put()`), but it will be of no use for
114 |             # processing that happens in the same thread.
115 |             queue.append(data)
116 | 
117 |     enqueue(n)  # Fill up the buffer.
118 |     while queue:
119 |         yield queue.popleft()
120 |         enqueue(1)
121 | 
122 | 
123 | def tree_broadcast(prefix, target):
124 |     """Broadcasts a prefix tree to a full tree.
125 | 
126 |     Input-output examples:
127 |     1. prefix: {"x": 10, "y": 20}
128 |        target: {"x": {"a": 1, "b": 2}, "y": 3}
129 | 
130 |        Result: {"x": {"a": 10, "b": 10}, "y": 20}
131 | 
132 |     2. prefix: 100
133 |        target: {"x": {"a": 1, "b": 2}, "y": 3}
134 | 
135 |        Result: {"x": {"a": 100, "b": 100}, "y": 100}
136 | 
137 |     3. prefix: {"x": 10}
138 |        target: {"x": {"a": 1, "b": 2}, "y": 3}
139 | 
140 |        Result: ValueError
141 | 
142 |     Args:
143 |       prefix: prefix pytree.
144 |       target: broadcast target for a prefix tree.
145 | 
146 |     Returns:
147 |       prefix tree broadcasted to a target tree.
148 | """ 149 | 150 | def _broadcast(leaf, subtree): 151 | return jax.tree.map(lambda _: leaf, subtree) 152 | 153 | return jax.tree.map(_broadcast, prefix, target) 154 | 155 | 156 | def reshard(tree, shardings): 157 | """Take an arbitrarily* sharded pytree and shard it according to `shardings`. 158 | 159 | This is a no-op for tree elements which are already sharded as requested. 160 | 161 | *Arrays that are fully addressable (for example, CPU arrays) are assumed to be 162 | identical (i.e. replicated) across hosts. 163 | 164 | *It does not work if an element of `tree` is not fully-addressable, unless its 165 | sharding is already consistent with the target sharding. 166 | If this is needed, please ping lbeyer@ or akolesnikov@. 167 | 168 | Args: 169 | tree: a pytree of arrays. 170 | shardings: a (prefix) pytree of jax array shardings. 171 | Returns: 172 | A pytree of global jax arrays that follows provided shardings. 173 | """ 174 | 175 | def _make_global_arr(x, shard, shape): 176 | # Avoid unnecessary copies and transfers: 177 | if hasattr(x, "sharding") and x.sharding.is_equivalent_to( 178 | shard, len(shape) 179 | ): # pylint: disable=line-too-long 180 | return x 181 | if not getattr(x, "is_fully_addressable", True): 182 | raise RuntimeError( 183 | "Trying to reshard a non-fully-addressable array. " 184 | "Please see the doc-comment for detailed explanation." 185 | ) 186 | x = jax.device_get(x) # Might be on local devices. 187 | xs = [ 188 | jax.device_put(x[s], device=d) 189 | for d, s in shard.addressable_devices_indices_map(shape).items() 190 | ] 191 | return jax.make_array_from_single_device_arrays(shape, shard, xs) 192 | 193 | shapes = jax.tree.map(np.shape, tree) 194 | shardings = tree_broadcast(shardings, tree) 195 | return jax.tree.map(_make_global_arr, tree, shardings, shapes) 196 | 197 | 198 | def write_note(note: str): 199 | if jax.process_index() == 0: 200 | print(note) 201 | 202 | 203 | def check_dtypes(orig_dtype_tree, current_dtype_tree): 204 | """Pass in two trees of dtypes and check if they are the same.""" 205 | assert orig_dtype_tree == current_dtype_tree, ( 206 | f"dtype mismatch:\n" 207 | f"Before: {orig_dtype_tree}\n" 208 | f"After: {current_dtype_tree}\n" 209 | ) 210 | 211 | 212 | def count_params(params) -> int: 213 | return sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(params)) 214 | 215 | 216 | def get_step(state) -> int: 217 | return jax.device_get(state.step).item() 218 | --------------------------------------------------------------------------------