├── LICENSE ├── MARS ├── model.py ├── opt.py ├── optimizers │ ├── adamw.py │ ├── adopt.py │ ├── mars.py │ └── muon.py ├── train_CNN.py ├── train_CV.py ├── train_adamw.py ├── train_adamw_fw.py ├── train_mars.py ├── train_mars_fw.py ├── train_muon.py └── utils │ ├── configurator.py │ ├── cv_utils.py │ └── model_CNN.py ├── README.md ├── assets ├── MARS-AdamW.png ├── MARS-Lion.png ├── MARS-Shampoo.png ├── MARS.png ├── ShampooH.png ├── cifar100_test_acc.png ├── cifar100_test_loss.png ├── cifar10_test_acc.png ├── cifar10_test_loss.png ├── fineweb_hella.png ├── small_train.png ├── small_val.png ├── time_large.png ├── time_medium.png ├── time_small.png ├── val_large.png ├── val_medium.png ├── val_small.jpg ├── val_small.png ├── xl_train.png └── xl_val.png ├── config ├── train_gpt2_large_adamw.py ├── train_gpt2_large_mars.py ├── train_gpt2_large_muon.py ├── train_gpt2_medium_adamw.py ├── train_gpt2_medium_mars.py ├── train_gpt2_medium_muon.py ├── train_gpt2_small_adamw.py ├── train_gpt2_small_mars.py ├── train_gpt2_small_muon.py ├── train_gpt2_xl_adamw.py └── train_gpt2_xl_mars.py ├── data └── openwebtext │ └── prepare.py └── scripts ├── run_CNN.sh ├── run_CV.sh ├── run_adamw_large.sh ├── run_adamw_medium.sh ├── run_adamw_small.sh ├── run_adamw_small_fw.sh ├── run_adamw_xl_fw.sh ├── run_mars_large.sh ├── run_mars_medium.sh ├── run_mars_small.sh ├── run_mars_small_fw.sh ├── run_mars_xl_fw.sh ├── run_muon_large.sh ├── run_muon_medium.sh └── run_muon_small.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MARS/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/Liuhong99/Sophia/blob/main/model.py 3 | """ 4 | 5 | import math 6 | import inspect 7 | from dataclasses import dataclass 8 | from optimizers.adamw import AdamW 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | from optimizers.mars import MARS 13 | 14 | optimizer_dict = {'adamw': torch.optim.AdamW, 15 | 'adamw_ours': AdamW, 16 | 'mars': MARS, 17 | } 18 | 19 | # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) 20 | def new_gelu(x): 21 | """ 22 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 23 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 24 | """ 25 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 26 | 27 | class LayerNorm(nn.Module): 28 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 29 | 30 | def __init__(self, ndim, bias): 31 | super().__init__() 32 | self.weight = nn.Parameter(torch.ones(ndim)) 33 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 34 | 35 | def forward(self, input): 36 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 37 | 38 | class CausalSelfAttention(nn.Module): 39 | 40 | def __init__(self, config, idx_layer): 41 | super().__init__() 42 | assert config.n_embd % config.n_head == 0 43 | # key, query, value projections for all heads, but in a batch 44 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 45 | # output projection 46 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 47 | # regularization 48 | self.attn_dropout = nn.Dropout(config.dropout) 49 | self.resid_dropout = nn.Dropout(config.dropout) 50 | self.n_head = config.n_head 51 | self.n_embd = config.n_embd 52 | self.dropout = config.dropout 53 | self.idx_layer = idx_layer 54 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx 55 | 56 | # causal mask to ensure that attention is only applied to the left in the input sequence 57 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 58 | .view(1, 1, config.block_size, config.block_size)) 59 | 60 | def forward(self, x): 61 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 62 | 63 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 64 | q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) 65 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 66 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 67 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 68 | 69 | if self.scale_attn_by_inverse_layer_idx: 70 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)) / float(self.idx_layer + 1)) 71 | else: 72 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 73 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 74 | att = F.softmax(att, dim=-1) 75 | att = self.attn_dropout(att) 76 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 77 | y = y.transpose(1, 2).contiguous().view(B, T, C) # 
re-assemble all head outputs side by side 78 | 79 | # output projection 80 | y = self.resid_dropout(self.c_proj(y)) 81 | return y 82 | 83 | class MLP(nn.Module): 84 | 85 | def __init__(self, config): 86 | super().__init__() 87 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 88 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 89 | self.dropout = nn.Dropout(config.dropout) 90 | 91 | def forward(self, x): 92 | x = self.c_fc(x) 93 | x = new_gelu(x) 94 | x = self.c_proj(x) 95 | x = self.dropout(x) 96 | return x 97 | 98 | class Block(nn.Module): 99 | 100 | def __init__(self, config, idx_layer): 101 | super().__init__() 102 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 103 | self.attn = CausalSelfAttention(config, idx_layer) 104 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 105 | self.mlp = MLP(config) 106 | 107 | def forward(self, x): 108 | x = x + self.attn(self.ln_1(x)) 109 | x = x + self.mlp(self.ln_2(x)) 110 | return x 111 | 112 | @dataclass 113 | class GPTConfig: 114 | block_size: int = 1024 115 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency, 50304 116 | n_layer: int = 12 117 | n_head: int = 12 118 | n_embd: int = 768 119 | dropout: float = 0.0 120 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 121 | scale_attn_by_inverse_layer_idx: bool = False 122 | 123 | 124 | class GPT(nn.Module): 125 | 126 | def __init__(self, config): 127 | super().__init__() 128 | assert config.vocab_size is not None 129 | assert config.block_size is not None 130 | self.config = config 131 | 132 | self.transformer = nn.ModuleDict(dict( 133 | wte = nn.Embedding(config.vocab_size, config.n_embd), 134 | wpe = nn.Embedding(config.block_size, config.n_embd), 135 | drop = nn.Dropout(config.dropout), 136 | h = nn.ModuleList([Block(config, idx_layer) for idx_layer in range(config.n_layer)]), 137 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 138 | )) 139 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 140 | # with weight tying when using torch.compile() some warnings get generated: 141 | # "UserWarning: functional_call was passed multiple values for tied weights. 142 | # This behavior is deprecated and will be an error in future versions" 143 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 144 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 145 | 146 | # init all weights 147 | self.apply(self._init_weights) 148 | # apply special scaled init to the residual projections, per GPT-2 paper 149 | for pn, p in self.named_parameters(): 150 | if pn.endswith('c_proj.weight'): 151 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 152 | 153 | # report number of parameters 154 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 155 | 156 | def get_num_params(self, non_embedding=True): 157 | """ 158 | Return the number of parameters in the model. 159 | For non-embedding count (default), the position embeddings get subtracted. 160 | The token embeddings would too, except due to the parameter sharing these 161 | params are actually used as weights in the final layer, so we include them. 
162 | """ 163 | n_params = sum(p.numel() for p in self.parameters()) 164 | if non_embedding: 165 | n_params -= self.transformer.wpe.weight.numel() 166 | return n_params 167 | 168 | def _init_weights(self, module): 169 | if isinstance(module, nn.Linear): 170 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 171 | if module.bias is not None: 172 | torch.nn.init.zeros_(module.bias) 173 | elif isinstance(module, nn.Embedding): 174 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 175 | 176 | def forward(self, idx, targets=None): 177 | device = idx.device 178 | b, t = idx.size() 179 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 180 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 181 | 182 | # forward the GPT model itself 183 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 184 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) 185 | x = self.transformer.drop(tok_emb + pos_emb) 186 | for block in self.transformer.h: 187 | x = block(x) 188 | x = self.transformer.ln_f(x) 189 | 190 | if targets is not None: 191 | # if we are given some desired targets also calculate the loss 192 | if not isinstance(targets, int): 193 | logits = self.lm_head(x) 194 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 195 | else: 196 | logits = self.lm_head(x) 197 | loss = None 198 | else: 199 | # inference-time mini-optimization: only forward the lm_head on the very last position 200 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 201 | loss = None 202 | 203 | return logits, loss 204 | 205 | def crop_block_size(self, block_size): 206 | # model surgery to decrease the block size if necessary 207 | # e.g. 
we may load the GPT2 pretrained model checkpoint (block size 1024) 208 | # but want to use a smaller block size for some smaller, simpler model 209 | assert block_size <= self.config.block_size 210 | self.config.block_size = block_size 211 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 212 | for block in self.transformer.h: 213 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 214 | 215 | @classmethod 216 | def from_pretrained(cls, model_type, override_args=None): 217 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 218 | override_args = override_args or {} # default to empty dict 219 | # only dropout can be overridden see more notes below 220 | assert all(k == 'dropout' for k in override_args) 221 | from transformers import GPT2LMHeadModel 222 | print("loading weights from pretrained gpt: %s" % model_type) 223 | 224 | # n_layer, n_head and n_embd are determined from model_type 225 | config_args = { 226 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 227 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 228 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 229 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 230 | }[model_type] 231 | print("forcing vocab_size=50257, block_size=1024, bias=True") 232 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 233 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 234 | config_args['bias'] = True # always True for GPT model checkpoints 235 | # we can override the dropout rate, if desired 236 | if 'dropout' in override_args: 237 | print(f"overriding dropout rate to {override_args['dropout']}") 238 | config_args['dropout'] = override_args['dropout'] 239 | # create a from-scratch initialized minGPT model 240 | config = GPTConfig(**config_args) 241 | model = GPT(config) 242 | sd = model.state_dict() 243 | sd_keys = sd.keys() 244 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 245 | 246 | # init a huggingface/transformers model 247 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 248 | sd_hf = model_hf.state_dict() 249 | 250 | # copy while ensuring all of the parameters are aligned and match in names and shapes 251 | sd_keys_hf = sd_hf.keys() 252 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 253 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 254 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 255 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 256 | # this means that we have to transpose these weights when we import them 257 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 258 | for k in sd_keys_hf: 259 | if any(k.endswith(w) for w in transposed): 260 | # special treatment for the Conv1D weights we need to transpose 261 | assert sd_hf[k].shape[::-1] == sd[k].shape 262 | with torch.no_grad(): 263 | sd[k].copy_(sd_hf[k].t()) 264 | else: 265 | # vanilla copy over the other parameters 266 | assert sd_hf[k].shape == sd[k].shape 267 | with torch.no_grad(): 268 | sd[k].copy_(sd_hf[k]) 269 | 270 | return model 271 | 272 | def configure_optimizers(self, optimizer_name, weight_decay, learning_rate, betas, device_type, 273 | 
other_para_config=None): 274 | """ 275 | This long function is unfortunately doing something very simple and is being very defensive: 276 | We are separating out all parameters of the model into two buckets: those that will experience 277 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 278 | We are then returning the PyTorch optimizer object. 279 | """ 280 | 281 | # separate out all parameters to those that will and won't experience regularizing weight decay 282 | decay = set() 283 | no_decay = set() 284 | whitelist_weight_modules = (torch.nn.Linear, ) 285 | blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding) 286 | for mn, m in self.named_modules(): 287 | for pn, p in m.named_parameters(): 288 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 289 | # random note: because named_modules and named_parameters are recursive 290 | # we will see the same tensors p many many times. but doing it this way 291 | # allows us to know which parent module any tensor p belongs to... 292 | if pn.endswith('bias'): 293 | # all biases will not be decayed 294 | no_decay.add(fpn) 295 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 296 | # weights of whitelist modules will be weight decayed 297 | decay.add(fpn) 298 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 299 | # weights of blacklist modules will NOT be weight decayed 300 | no_decay.add(fpn) 301 | 302 | # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they 303 | # will appear in the no_decay and decay sets respectively after the above. 304 | # In addition, because named_parameters() doesn't return duplicates, it 305 | # will only return the first occurrence, key'd by 'transformer.wte.weight', below. 306 | # so let's manually remove 'lm_head.weight' from decay set. This will include 307 | # this tensor into optimization via transformer.wte.weight only, and not decayed. 308 | decay.remove('lm_head.weight') 309 | 310 | # validate that we considered every parameter 311 | param_dict = {pn: p for pn, p in self.named_parameters()} 312 | inter_params = decay & no_decay 313 | union_params = decay | no_decay 314 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 315 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ 316 | % (str(param_dict.keys() - union_params), ) 317 | 318 | # create the pytorch optimizer object 319 | optim_groups = [ 320 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, 321 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 322 | ] 323 | 324 | opt_func = optimizer_dict[optimizer_name] 325 | if optimizer_name == 'adamw': 326 | # new PyTorch nightly has a new 'fused' option for AdamW that is much faster 327 | use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters) 328 | print(f"using fused AdamW: {use_fused}") 329 | extra_args = dict(fused=True) if use_fused else dict() 330 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas, **extra_args) 331 | elif optimizer_name == 'adamw_ours': 332 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas) 333 | elif optimizer_name == 'mars': 334 | if other_para_config is None: 335 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas) 336 | else: 337 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas, **other_para_config) 338 | else: 339 | raise ValueError('Invalid optimizer.') 340 | return optimizer 341 | 342 | def estimate_mfu(self, fwdbwd_per_iter, dt): 343 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 344 | # first estimate the number of flops we do per iteration. 345 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 346 | N = self.get_num_params() 347 | cfg = self.config 348 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 349 | flops_per_token = 6*N + 12*L*H*Q*T 350 | flops_per_fwdbwd = flops_per_token * T 351 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 352 | # express our flops throughput as ratio of A100 bfloat16 peak flops 353 | flops_achieved = flops_per_iter * (1.0/dt) # per second 354 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 355 | mfu = flops_achieved / flops_promised 356 | return mfu 357 | 358 | @torch.no_grad() 359 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 360 | """ 361 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 362 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 363 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
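For example (a minimal sketch, not part of this file's code path; it assumes model is a GPT instance already on device, and that token id 0 is an acceptable start token for your tokenizer):

    model.eval()
    idx = torch.zeros((1, 1), dtype=torch.long, device=device)
    out = model.generate(idx, max_new_tokens=100, temperature=0.8, top_k=200)  # out has shape (1, 1 + 100)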
364 | """ 365 | for _ in range(max_new_tokens): 366 | # if the sequence context is growing too long we must crop it at block_size 367 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 368 | # forward the model to get the logits for the index in the sequence 369 | logits, _ = self(idx_cond) 370 | # pluck the logits at the final step and scale by desired temperature 371 | logits = logits[:, -1, :] / temperature 372 | # optionally crop the logits to only the top k options 373 | if top_k is not None: 374 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 375 | logits[logits < v[:, [-1]]] = -float('Inf') 376 | # apply softmax to convert logits to (normalized) probabilities 377 | probs = F.softmax(logits, dim=-1) 378 | # sample from the distribution 379 | idx_next = torch.multinomial(probs, num_samples=1) 380 | # append sampled index to the running sequence and continue 381 | idx = torch.cat((idx, idx_next), dim=1) 382 | 383 | return idx 384 | -------------------------------------------------------------------------------- /MARS/opt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import collections 3 | """ 4 | Adapted from askerlee@github: https://github.com/KellerJordan/modded-nanogpt/issues/9 5 | """ 6 | def separate_params(param_groups): 7 | param_groups_2d = [] 8 | param_groups_non2d = [] 9 | total_param_2d_count = 0 10 | total_param_non2d_count = 0 11 | 12 | 13 | # Convert iterators to lists 14 | if isinstance(param_groups, collections.abc.Iterable): 15 | param_groups = list(param_groups) 16 | 17 | # Check if param_groups is a list of dicts or list of params 18 | if (isinstance(param_groups, list) and isinstance(param_groups[0], dict)) \ 19 | or isinstance(param_groups, dict): 20 | if isinstance(param_groups, dict): 21 | param_groups = [param_groups] 22 | # param_groups is a list of dicts 23 | for group in param_groups: 24 | params_2d, params_non2d, param_2d_count, param_non2d_count = separate_params(group['params']) 25 | param_group_2d = {'params': params_2d} 26 | param_group_non2d = {'params': params_non2d} 27 | # Copy the group dict and replace the 'params' key with the separated params 28 | for k in group.keys(): 29 | if k != 'params': 30 | param_group_2d[k] = group[k] 31 | param_group_non2d[k] = group[k] 32 | 33 | param_groups_2d.append(param_group_2d) 34 | param_groups_non2d.append(param_group_non2d) 35 | total_param_2d_count += param_2d_count 36 | total_param_non2d_count += param_non2d_count 37 | 38 | return param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count 39 | 40 | elif isinstance(param_groups, list) and isinstance(param_groups[0], torch.Tensor): 41 | params_2d = [] 42 | params_non2d = [] 43 | param_group = param_groups 44 | # param_group is a list of param tensors 45 | for param in param_group: 46 | if param.ndim >= 2: 47 | params_2d.append(param) 48 | else: 49 | params_non2d.append(param) 50 | return params_2d, params_non2d, len(params_2d), len(params_non2d) 51 | else: 52 | breakpoint() 53 | 54 | ''' 55 | # CombinedOptimizer is now a torch.optim.Optimizer, compatible with pytorch lightning. 
56 | # Original Example: 57 | optimizer = CombinedOptimizer([ 58 | torch.optim.AdamW(self.lm_head.parameters(), lr=learning_rate, betas=betas, weight_decay=0, fused=True), 59 | OrthogonalNesterov(self.transformer.h.parameters(), lr=0.1*learning_rate, momentum=0.95) 60 | ]) 61 | # Refactored Example: 62 | optimizer = CombinedOptimizer(\ 63 | self.parameters(), 64 | [OrthogonalNesterov, torch.optim.AdamW], 65 | [{'lr': 0.1*learning_rate, 'momentum': 0.95}, 66 | {'lr': learning_rate, 'betas': betas, 'weight_decay': 0, 'fused': True} 67 | ]) 68 | ''' 69 | 70 | class CombinedOptimizer(torch.optim.Optimizer): 71 | def __init__(self, params, optimizer_types, configs, raw_model = False): 72 | # Separate 2D and non-2D parameters. 73 | # If params is a list of tensors, then each of param_groups_2d and param_groups_non2d 74 | # will be a list of tensors. 75 | # If params is a list of dicts, then each of param_groups_2d and param_groups_non2d 76 | # will be a list of dicts. 77 | # If params is a dict, then each of param_groups_2d and param_groups_non2d will 78 | # be a list of dicts containing only one dict. 79 | if raw_model: 80 | params_others = list(params.transformer.h.parameters()) 81 | param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count \ 82 | = separate_params(params_others) 83 | param_groups_non2d.extend(list(params.lm_head.parameters())) 84 | total_param_non2d_count += 2 85 | else: 86 | param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count \ 87 | = separate_params(params) 88 | param_groups_2d_non2d = (param_groups_non2d, param_groups_2d) 89 | print(f"Total 2D params: {total_param_2d_count}, Total non-2D params: {total_param_non2d_count}") 90 | 91 | assert len(optimizer_types) == len(configs) == 2 92 | self.optimizers = [ optimizer_types[i](param_groups_2d_non2d[i], **configs[i]) for i in range(2) ] 93 | self.param_groups = [pg for opt in self.optimizers for pg in opt.param_groups] 94 | self.base_lrs = [opt.param_groups[0]['lr'] for opt in self.optimizers] 95 | # Combine the state dicts of all opt in self.optimizers into a single dict 96 | self.state = {k: v for opt in self.optimizers for k, v in opt.state.items()} 97 | # Initially all states are empty. So no point to print their counts. 98 | # Only use the defaults of the OrthogonalNesterov optimizer 99 | self.defaults = self.optimizers[0].defaults 100 | 101 | def step(self, *args, **kwargs): 102 | for opt in self.optimizers: 103 | opt.step(*args, **kwargs) 104 | 105 | def zero_grad(self, **kwargs): 106 | for opt in self.optimizers: 107 | opt.zero_grad(**kwargs) 108 | 109 | def scale_lrs(self, lr_scale): 110 | for base_lr, opt in zip(self.base_lrs, self.optimizers): 111 | opt.param_groups[0]['lr'] = base_lr * lr_scale 112 | 113 | def state_dict(self): 114 | return [opt.state_dict() for opt in self.optimizers] -------------------------------------------------------------------------------- /MARS/optimizers/adamw.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | # from megatron.optimizer.l2_norm import l2_norm 5 | 6 | def exists(val): 7 | return val is not None 8 | 9 | 10 | class AdamW(Optimizer): 11 | """Implements Adam algorithm. 12 | 13 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 
14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | 27 | .. _Adam\: A Method for Stochastic Optimization: 28 | https://arxiv.org/abs/1412.6980 29 | .. _On the Convergence of Adam and Beyond: 30 | https://openreview.net/forum?id=ryQu7f-RZ 31 | """ 32 | 33 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 34 | weight_decay=0, amsgrad=False): 35 | if not 0.0 <= lr: 36 | raise ValueError("Invalid learning rate: {}".format(lr)) 37 | if not 0.0 <= eps: 38 | raise ValueError("Invalid epsilon value: {}".format(eps)) 39 | if not 0.0 <= betas[0] < 1.0: 40 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 41 | if not 0.0 <= betas[1] < 1.0: 42 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 43 | defaults = dict(lr=lr, betas=betas, eps=eps, 44 | weight_decay=weight_decay, amsgrad=amsgrad) 45 | super(AdamW, self).__init__(params, defaults) 46 | self.eps = eps 47 | 48 | def __setstate__(self, state): 49 | super(AdamW, self).__setstate__(state) 50 | for group in self.param_groups: 51 | group.setdefault('amsgrad', False) 52 | 53 | @torch.no_grad() 54 | def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): 55 | """Performs a single optimization step. 56 | 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 60 | """ 61 | if any(p is not None for p in [grads, output_params, scale, grad_norms]): 62 | raise RuntimeError('FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.') 63 | 64 | loss = None 65 | if exists(closure): 66 | with torch.enable_grad(): 67 | loss = closure() 68 | real_update = 0 69 | real_update_wo_lr = 0 70 | 71 | for group in self.param_groups: 72 | for p in filter(lambda p: exists(p.grad), group['params']): 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 78 | amsgrad = group['amsgrad'] 79 | 80 | state = self.state[p] 81 | #print('----- starting a parameter state', state.keys(), 'Length of state', len(state)) 82 | # State initialization 83 | if len(state) == 0: 84 | state['step'] = 0 85 | # Exponential moving average of gradient values 86 | state['exp_avg'] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state['exp_avg_sq'] = torch.zeros_like(p.data) 89 | if amsgrad: 90 | # Maintains max of all exp. moving avg. of sq. grad. 
values 91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 92 | 93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 94 | if amsgrad: 95 | max_exp_avg_sq = state['max_exp_avg_sq'] 96 | beta1, beta2 = group['betas'] 97 | 98 | if 'step' in state: 99 | state['step'] += 1 100 | else: 101 | state['step'] = 1 102 | 103 | # Decay the first and second moment running average coefficient 104 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 105 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 106 | if amsgrad: 107 | # Maintains the maximum of all 2nd moment running avg. till now 108 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 109 | # Use the max. for normalizing running avg. of gradient 110 | denom = max_exp_avg_sq.sqrt().add_(self.eps) 111 | else: 112 | denom = exp_avg_sq.sqrt().add_(self.eps) 113 | 114 | bias_correction1 = 1 - beta1 ** state['step'] 115 | bias_correction2 = 1 - beta2 ** state['step'] 116 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 117 | 118 | # p.data.addcdiv_(-step_size, exp_avg, denom) 119 | real_update_tmp = -step_size * torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom) 120 | real_update_wo_lr_tmp = torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom) 121 | 122 | p.data.add_(real_update_tmp) 123 | return loss 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /MARS/optimizers/mars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # SPDX-License-Identifier: Apache-2.0 3 | import math 4 | import torch 5 | from torch.optim.optimizer import Optimizer 6 | import os 7 | import numpy as np 8 | import math 9 | # from megatron.optimizer.l2_norm import l2_norm 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def update_fn(p, grad, exp_avg, exp_avg_sq, lr, wd, beta1, beta2, last_grad, eps, amsgrad, max_exp_avg_sq, step, gamma, 16 | mars_type, is_grad_2d, optimize_1d, lr_1d_factor, betas_1d, weight_decay_1d): 17 | # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para 18 | if optimize_1d or is_grad_2d: 19 | c_t = (grad - last_grad).mul(gamma * (beta1 / (1. - beta1))).add(grad) 20 | c_t_norm = torch.norm(c_t) 21 | if c_t_norm > 1.: 22 | c_t = c_t / c_t_norm 23 | exp_avg.mul_(beta1).add_(c_t, alpha=1. - beta1) 24 | if (mars_type == "mars-adamw") or (mars_type == "mars-shampoo" and not is_grad_2d): 25 | exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2) 26 | bias_correction1 = 1.0 - beta1 ** step 27 | bias_correction2 = 1.0 - beta2 ** step 28 | if amsgrad: 29 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 30 | denom = max_exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 31 | else: 32 | denom = exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 33 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.div(denom)) 34 | elif mars_type == "mars-lion": 35 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.sign()) 36 | elif mars_type == "mars-shampoo" and is_grad_2d: 37 | factor = max(1, grad.size(0)/grad.size(1))**0.5 38 | real_update_tmp = NewtonSchulz(exp_avg.mul(1./(1.-beta1)), eps=eps).mul(factor).add(wd, p.data).mul(-lr) 39 | p.data.add_(real_update_tmp) 40 | else: 41 | beta1_1d, beta2_1d = betas_1d 42 | exp_avg.mul_(beta1_1d).add_(grad, alpha=1. 
- beta1_1d) 43 | exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1. - beta2_1d) 44 | bias_correction1 = 1.0 - beta1_1d ** step 45 | bias_correction2 = 1.0 - beta2_1d ** step 46 | if amsgrad: 47 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 48 | denom = max_exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 49 | else: 50 | denom = exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 51 | real_update_tmp = -lr * lr_1d_factor * torch.mul(p.data, weight_decay_1d).add(exp_avg.div(denom)) 52 | p.data.add_(real_update_tmp) 53 | return exp_avg, exp_avg_sq 54 | 55 | class MARS(Optimizer): 56 | def __init__(self, params, lr=3e-3, betas=(0.95, 0.99), eps=1e-8, weight_decay=0., amsgrad=False, gamma=0.025, 57 | is_approx=True, mars_type="mars-adamw", optimize_1d=False, lr_1d=3e-3, betas_1d=(0.9, 0.95), weight_decay_1d=0.1): 58 | if not 0.0 <= lr: 59 | raise ValueError("Invalid learning rate: {}".format(lr)) 60 | if not 0.0 <= eps: 61 | raise ValueError("Invalid epsilon value: {}".format(eps)) 62 | if not 0.0 <= betas[0] < 1.0: 63 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 64 | if not 0.0 <= betas[1] < 1.0: 65 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 66 | assert mars_type in ["mars-adamw", "mars-lion", "mars-shampoo"], "MARS type not supported" 67 | defaults = dict(lr=lr, betas=betas, eps=eps, 68 | weight_decay=weight_decay, amsgrad=amsgrad, 69 | mars_type=mars_type, gamma=gamma, 70 | optimize_1d=optimize_1d, weight_decay_1d=weight_decay_1d) 71 | super(MARS, self).__init__(params, defaults) 72 | self.eps = eps 73 | self.update_fn = update_fn 74 | self.lr = lr 75 | self.weight_decay=weight_decay 76 | self.amsgrad = amsgrad 77 | self.step_num = 0 78 | self.is_approx = is_approx 79 | self.gamma = gamma 80 | self.mars_type = mars_type 81 | self.optimize_1d = optimize_1d 82 | self.lr_1d_factor = lr_1d / lr 83 | self.weight_decay_1d = weight_decay_1d 84 | self.betas_1d = betas_1d 85 | 86 | @torch.no_grad() 87 | def update_last_grad(self): 88 | if not self.is_approx: 89 | for group in self.param_groups: 90 | for p in group['params']: 91 | state = self.state[p] 92 | if "last_grad" not in state: 93 | state["last_grad"] = torch.zeros_like(p) 94 | state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0) 95 | @torch.no_grad() 96 | def update_previous_grad(self): 97 | if not self.is_approx: 98 | for group in self.param_groups: 99 | #print ("para name", len(group['params']), len(group['names']), group['names']) 100 | for p in group['params']: 101 | # import pdb 102 | # pdb.set_trace() 103 | if p.grad is None: 104 | print (p, "grad is none") 105 | continue 106 | state = self.state[p] 107 | if "previous_grad" not in state: 108 | state['previous_grad'] = torch.zeros_like(p) 109 | state['previous_grad'].zero_().add_(p.grad, alpha=1.0) 110 | 111 | def __setstate__(self, state): 112 | super(MARS, self).__setstate__(state) 113 | for group in self.param_groups: 114 | group.setdefault('amsgrad', False) 115 | 116 | @torch.no_grad() 117 | def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): 118 | """Performs a single optimization step. 119 | 120 | Arguments: 121 | closure (callable, optional): A closure that reevaluates the model 122 | and returns the loss. 
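With the approximate version (is_approx=True, the default), update_previous_grad() and update_last_grad() are no-ops and step() caches the current gradient itself, so the optimizer is driven like any standard PyTorch optimizer. A minimal sketch (assuming model, optimizer and data_loader are already constructed):

    for X, Y in data_loader:
        logits, loss = model(X, Y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)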
123 | 124 | If using exact version, the example usage is as follows: 125 | previous_X, previous_Y = None, None 126 | for epoch in range(epochs): 127 | for X, Y in data_loader: 128 | if previous_X: 129 | logits, loss = model(X, Y) 130 | loss.backward() 131 | optimizer.update_previous_grad() 132 | optimizer.zero_grad(set_to_none=True) 133 | logits, loss = model(X, Y) 134 | loss.backward() 135 | optimizer.step(bs=bs) 136 | optimizer.zero_grad(set_to_none=True) 137 | optimizer.update_last_grad() 138 | iter_num += 1 139 | previous_X, previous_Y = X.clone(), Y.clone() 140 | """ 141 | if any(p is not None for p in [grads, output_params, scale, grad_norms]): 142 | raise RuntimeError('FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.') 143 | 144 | loss = None 145 | if exists(closure): 146 | with torch.enable_grad(): 147 | loss = closure() 148 | real_update = 0 149 | real_update_wo_lr = 0 150 | gamma = self.gamma 151 | # import pdb 152 | # pdb.set_trace() 153 | for group in self.param_groups: 154 | for p in filter(lambda p: exists(p.grad), group['params']): 155 | if p.grad is None: 156 | continue 157 | grad = p.grad.data 158 | if grad.is_sparse: 159 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 160 | amsgrad = group['amsgrad'] 161 | 162 | state = self.state[p] 163 | #('----- starting a parameter state', state.keys(), 'Length of state', len(state)) 164 | # State initialization 165 | if len(state) <= 1: 166 | state['step'] = 0 167 | # Exponential moving average of gradient values 168 | state['exp_avg'] = torch.zeros_like(p.data) 169 | # Last Gradient 170 | state['last_grad'] = torch.zeros_like(p) 171 | #state['previous_grad'] = torch.zeros_like(p) 172 | # Exponential moving average of squared gradient values 173 | state['exp_avg_sq'] = torch.zeros_like(p.data) 174 | if amsgrad: 175 | # Maintains max of all exp. moving avg. of sq. grad. 
values 176 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 177 | # import pdb 178 | # pdb.set_trace() 179 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 180 | last_grad = state['last_grad'] 181 | lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas'] 182 | if amsgrad: 183 | max_exp_avg_sq = state['max_exp_avg_sq'] 184 | else: 185 | max_exp_avg_sq = 0 186 | 187 | if 'step' in state: 188 | state['step'] += 1 189 | else: 190 | state['step'] = 1 191 | step = state['step'] 192 | is_grad_2d = (len(grad.shape) == 2) 193 | exp_avg, exp_avg_sq = self.update_fn( 194 | p, 195 | grad, 196 | exp_avg, 197 | exp_avg_sq, 198 | lr, 199 | wd, 200 | beta1, 201 | beta2, 202 | last_grad, 203 | self.eps, 204 | amsgrad, 205 | max_exp_avg_sq, 206 | step, 207 | gamma, 208 | mars_type=self.mars_type, 209 | is_grad_2d=is_grad_2d, 210 | optimize_1d=self.optimize_1d, 211 | lr_1d_factor=self.lr_1d_factor, 212 | betas_1d=self.betas_1d, 213 | weight_decay_1d=self.weight_decay if self.optimize_1d else self.weight_decay_1d 214 | ) 215 | if self.is_approx: 216 | state['last_grad'] = grad 217 | self.step_num = step 218 | 219 | return loss 220 | 221 | @torch.compile 222 | def NewtonSchulz(M, steps=5, eps=1e-7): 223 | a, b, c = (3.4445, -4.7750, 2.0315) 224 | X = M.bfloat16() / (M.norm() + eps) 225 | if M.size(0) > M.size(1): 226 | X = X.T 227 | for _ in range(steps): 228 | A = X @ X.T 229 | B = A @ X 230 | X = a * X + b * B + c * A @ B 231 | if M.size(0) > M.size(1): 232 | X = X.T 233 | return X.to(M.dtype) 234 | -------------------------------------------------------------------------------- /MARS/optimizers/muon.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from KellerJordan/modded-nanogpt: https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt2.py 3 | """ 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import os 8 | 9 | def zeropower_via_svd(G, steps=None): 10 | U, S, V = G.svd() 11 | return U @ V.T 12 | 13 | @torch.compile 14 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 15 | """ 16 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 17 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 18 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 19 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 20 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 21 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model 22 | performance at all relative to UV^T, where USV^T = G is the SVD. 
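A quick standalone check of this behavior (a sketch only, not used anywhere in training; it assumes a CUDA device since the iteration runs in bfloat16):

    G = torch.randn(1024, 4096, device='cuda')
    X = zeropower_via_newtonschulz5(G, steps=10)
    # singular values of X cluster around 1 (roughly within [0.5, 1.5]) rather than matching those of G
    print(torch.linalg.svdvals(X.float()))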
23 | """ 24 | assert len(G.shape) == 2 25 | a, b, c = (3.4445, -4.7750, 2.0315) 26 | X = G.bfloat16() 27 | X /= (X.norm() + eps) # ensure top singular value <= 1 28 | if G.size(0) > G.size(1): 29 | X = X.T 30 | for _ in range(steps): 31 | A = X @ X.T 32 | B = A @ X 33 | X = a * X + b * B + c * A @ B 34 | if G.size(0) > G.size(1): 35 | X = X.T 36 | return X 37 | 38 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5) 39 | 40 | class Muon(torch.optim.Optimizer): 41 | """ 42 | Muon - MomentUm Orthogonalized by Newton-schulz 43 | 44 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 45 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 46 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 47 | the advantage that it can be stably run in bfloat16 on the GPU. 48 | 49 | Some warnings: 50 | - This optimizer assumes that all parameters passed in are 2D. 51 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D 52 | parameters; those should all be optimized by a standard method (e.g., AdamW). 53 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 54 | - We believe it is unlikely to work well for training with small batch size. 55 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 56 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). 57 | 58 | Arguments: 59 | lr: The learning rate used by the internal SGD. 60 | momentum: The momentum used by the internal SGD. 61 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 62 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') 63 | backend_steps: The number of iteration steps to use in the backend, if it is iterative. 
64 | """ 65 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, 66 | backend='newtonschulz5', backend_steps=5, weight_decay=0.): 67 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps, weight_decay=weight_decay) 68 | super().__init__(params, defaults) 69 | if 'WORLD_SIZE' in os.environ: 70 | self.world_size = int(os.environ['WORLD_SIZE']) 71 | self.rank = int(os.environ['RANK']) 72 | else: 73 | self.world_size = 1 74 | self.rank = 0 75 | 76 | def step(self): 77 | 78 | for group in self.param_groups: 79 | 80 | lr = group['lr'] 81 | weight_decay = group['weight_decay'] 82 | momentum = group['momentum'] 83 | zeropower_backend = zeropower_backends[group['backend']] 84 | 85 | # generate weight updates in distributed fashion 86 | total_params = sum(p.numel() for p in group['params']) 87 | updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16) 88 | curr_idx = 0 89 | for i, p in enumerate(group['params']): 90 | # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs 91 | if i % int(self.world_size) == int(self.rank): 92 | g = p.grad 93 | assert g is not None 94 | if g.ndim > 2: 95 | g = g.view(g.size(0), -1) 96 | state = self.state[p] 97 | if 'momentum_buffer' not in state: 98 | state['momentum_buffer'] = torch.zeros_like(g) 99 | buf = state['momentum_buffer'] 100 | buf.mul_(momentum).add_(g) 101 | if group['nesterov']: 102 | g = g.add(buf, alpha=momentum) 103 | g = zeropower_backend(g, steps=group['backend_steps']) 104 | g *= max(1, g.size(0)/g.size(1))**0.5 105 | updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten() 106 | curr_idx += p.numel() 107 | 108 | # sync updates across devices. we are not memory-constrained so can do this simple deserialization 109 | if self.world_size > 1: 110 | dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) 111 | 112 | # deserialize and apply updates 113 | curr_idx = 0 114 | for p in group['params']: 115 | g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data) 116 | p.data.mul_(1.-lr*weight_decay).add_(g, alpha=-lr) 117 | curr_idx += p.numel() -------------------------------------------------------------------------------- /MARS/train_CNN.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. 
and/or its affiliates 2 | # SPDX-License-Identifier: Apache-2.0 3 | import argparse 4 | from typing import List, Tuple, Type 5 | 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torch.nn as nn 9 | from torch.optim import Adam, AdamW 10 | from torch.optim.lr_scheduler import CosineAnnealingLR 11 | from torch.utils.data import DataLoader 12 | from torchvision import datasets, transforms 13 | import numpy as np 14 | from utils.model_CNN import Network 15 | from optimizers.adopt import ADOPT 16 | from optimizers.mars import MARS 17 | import random 18 | parser = argparse.ArgumentParser(add_help=True) 19 | parser.add_argument( 20 | "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10"], help="dataset to use" 21 | ) 22 | parser.add_argument("-b", "--batch_size", type=int, default=128, help="batch size") 23 | parser.add_argument("-e", "--epochs", type=int, default=50, help="number of epochs") 24 | parser.add_argument("--seed", type=int, default=0, help="random seed") 25 | parser.add_argument("--cpu", action="store_true", help="use cpu only") 26 | 27 | 28 | def get_datasets(dataset_name: str, batch_size: int) -> Tuple[DataLoader, DataLoader]: 29 | """Get train and test dataloaders.""" 30 | if dataset_name == "mnist": 31 | transform = transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize((0.1307,), (0.3081,)) 34 | ]) 35 | train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) 36 | test_dataset = datasets.MNIST('./data', train=False, transform=transform) 37 | elif dataset_name == "cifar10": 38 | transform_train = transforms.Compose([ 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 42 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) 43 | ]) 44 | transform_test = transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) 47 | ]) 48 | train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train) 49 | test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test) 50 | else: 51 | raise NotImplementedError(f"{dataset_name=} is not implemented.") 52 | 53 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) 54 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4) 55 | 56 | return train_loader, test_loader 57 | 58 | 59 | class WarmupCosineScheduler: 60 | """Custom learning rate scheduler with linear warmup and cosine decay.""" 61 | def __init__(self, optimizer, warmup_iters: int, total_iters: int, min_lr: float, max_lr: float): 62 | self.optimizer = optimizer 63 | self.warmup_iters = warmup_iters 64 | self.total_iters = total_iters 65 | self.min_lr = min_lr 66 | self.max_lr = max_lr 67 | self.current_iter = 0 68 | self.lr = 0 69 | 70 | def step(self): 71 | self.current_iter += 1 72 | if self.current_iter <= self.warmup_iters: 73 | lr = self.current_iter / self.warmup_iters * self.max_lr 74 | else: 75 | lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * ( 76 | np.cos((self.current_iter - self.warmup_iters) / (self.total_iters - self.warmup_iters) * 3.14159265 / 2) 77 | ).item() 78 | 79 | for param_group in self.optimizer.param_groups: 80 | param_group['lr'] = lr 81 | self.lr = lr 82 | 83 | class Trainer: 84 | """Training manager for PyTorch models.""" 85 | def __init__(self, model: nn.Module, 
optimizer: torch.optim.Optimizer, scheduler, device: torch.device): 86 | self.model = model 87 | self.optimizer = optimizer 88 | self.scheduler = scheduler 89 | self.device = device 90 | self.criterion = nn.CrossEntropyLoss() 91 | self.train_acc_trace = [] 92 | self.val_acc_trace = [] 93 | 94 | def train_epoch(self, train_loader: DataLoader) -> float: 95 | self.model.train() 96 | correct = 0 97 | total = 0 98 | 99 | for batch in train_loader: 100 | images, targets = batch[0].to(self.device), batch[1].to(self.device) 101 | 102 | self.optimizer.zero_grad() 103 | outputs = self.model(images) 104 | loss = self.criterion(outputs, targets) 105 | loss.backward() 106 | self.optimizer.step() 107 | 108 | _, predicted = outputs.max(1) 109 | total += targets.size(0) 110 | correct += predicted.eq(targets).sum().item() 111 | if self.scheduler is not None: 112 | self.scheduler.step() 113 | return 100. * correct / total 114 | 115 | def evaluate(self, test_loader: DataLoader) -> float: 116 | self.model.eval() 117 | correct = 0 118 | total = 0 119 | 120 | with torch.no_grad(): 121 | for batch in test_loader: 122 | images, targets = batch[0].to(self.device), batch[1].to(self.device) 123 | outputs = self.model(images) 124 | 125 | _, predicted = outputs.max(1) 126 | total += targets.size(0) 127 | correct += predicted.eq(targets).sum().item() 128 | 129 | return 100. * correct / total 130 | 131 | def train(self, train_loader: DataLoader, test_loader: DataLoader, epochs: int): 132 | for epoch in range(epochs): 133 | train_acc = self.train_epoch(train_loader) 134 | val_acc = self.evaluate(test_loader) 135 | 136 | self.train_acc_trace.append(train_acc) 137 | self.val_acc_trace.append(val_acc) 138 | 139 | # if self.scheduler is not None: 140 | # self.scheduler.step() 141 | 142 | print(f"Epoch {epoch+1}/{epochs} - Train Acc: {train_acc:.2f}% - Val Acc: {val_acc:.2f}%") 143 | 144 | 145 | def get_optimizers(model: nn.Module, opt_name, args): 146 | """Configure optimizers and schedulers.""" 147 | total_steps = 50_000 // args.batch_size * args.epochs 148 | n_warmup = int(total_steps * 0.10) # % of total steps 149 | weight_decay = 1e-4 150 | max_lr = 6e-4 151 | min_lr = 1e-6 152 | 153 | if opt_name == "Adam": 154 | # Adam 155 | adam = Adam(model.parameters(), lr=max_lr) 156 | adam_scheduler = WarmupCosineScheduler( 157 | adam, n_warmup, total_steps, min_lr, max_lr 158 | ) 159 | optimizer = (adam, adam_scheduler, "Adam") 160 | 161 | elif opt_name == "AdamW": 162 | # AdamW 163 | adamw = AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay) 164 | adamw_scheduler = WarmupCosineScheduler( 165 | adamw, n_warmup, total_steps, min_lr, max_lr 166 | ) 167 | optimizer = (adamw, adamw_scheduler, "AdamW") 168 | elif opt_name == "ADOPT": 169 | # ADOPT 170 | adopt = ADOPT(model.parameters(), lr=max_lr, weight_decay=weight_decay) 171 | adopt_scheduler = WarmupCosineScheduler( 172 | adopt, n_warmup, total_steps, min_lr, max_lr 173 | ) 174 | optimizer = (adopt, adopt_scheduler, "ADOPT") 175 | elif opt_name == "MARS": 176 | # MARS 177 | mars = MARS(model.parameters(), lr=3e-3, weight_decay=weight_decay, optimize_1d=False) 178 | mars_scheduler = WarmupCosineScheduler( 179 | mars, n_warmup, total_steps, min_lr, 3e-3 180 | ) 181 | optimizer = (mars, mars_scheduler, "MARS") 182 | return optimizer 183 | 184 | 185 | def plot_results(results: List[List[float]], optimizer_names: List[str], args): 186 | """Plot training results.""" 187 | fig, ax = plt.subplots(figsize=(5.5, 3.5)) 188 | colors = ["#74add1", "#1730bd", "#1a9850", "#001c01"] 
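# note (descriptive comment, not in the original file): colors are assigned positionally,
# so with the default opt_names in main() the curves are colored in the order
# Adam, AdamW, ADOPT, MARS.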
189 | 190 | for i, acc in enumerate(results): 191 | ax.plot(range(1, len(acc) + 1), acc, label=optimizer_names[i], lw=2, color=colors[i]) 192 | 193 | ax.set_title(f"{args.dataset.upper()} (val)", loc="left") 194 | ax.set_xlabel("Epoch", fontsize="medium") 195 | ax.set_ylabel("Accuracy (%)", fontsize="medium") 196 | 197 | ax.legend(ncols=2, columnspacing=0.8, fontsize="medium") 198 | ax.grid(alpha=0.2) 199 | 200 | ax.set_ylim(90 if args.dataset == "mnist" else 70) 201 | acc_min, acc_max = ax.get_ylim() 202 | ax.set_yticks(torch.linspace(acc_min, acc_max, 5).int().tolist()) 203 | ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True)) 204 | 205 | fig.tight_layout() 206 | fig.savefig( 207 | f"./compare-{args.dataset}-blank.png", 208 | dpi=300, 209 | bbox_inches="tight", 210 | ) 211 | plt.show() 212 | 213 | 214 | def main(args): 215 | # Set random seed and device 216 | torch.manual_seed(args.seed) 217 | device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu") 218 | 219 | # Get dataloaders 220 | train_loader, test_loader = get_datasets(args.dataset, args.batch_size) 221 | # Model configuration 222 | model_config = { 223 | "n_inputs": (3, 32, 32) if args.dataset == "cifar10" else (1, 28, 28), 224 | "conv_layers_list": [ 225 | {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 226 | {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 227 | {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 228 | ], 229 | "n_hiddens_list": [512], 230 | "n_outputs": 10, 231 | "dropout": 0.2, 232 | } 233 | 234 | results = [] 235 | optimizer_names = [] 236 | # Train with different optimizers 237 | opt_names = ["Adam", "AdamW", "ADOPT", "MARS"] 238 | for opt_name in opt_names: 239 | print(opt_name) 240 | torch.manual_seed(args.seed) 241 | model = Network(**model_config).to(device) 242 | optimizer, scheduler, name = get_optimizers(model, opt_name, args) 243 | trainer = Trainer(model, optimizer, scheduler, device) 244 | trainer.train(train_loader, test_loader, args.epochs) 245 | results.append(trainer.val_acc_trace) 246 | optimizer_names.append(name) 247 | 248 | plot_results(results, optimizer_names, args) 249 | 250 | 251 | if __name__ == "__main__": 252 | args = parser.parse_args() 253 | main(args) -------------------------------------------------------------------------------- /MARS/train_CV.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from uclaml/Padam: https://github.com/uclaml/Padam/blob/master/run_cnn_test_cifar10.py 3 | """ 4 | import numpy as np 5 | import os 6 | import argparse 7 | import json 8 | from tqdm import tqdm 9 | 10 | parser = argparse.ArgumentParser(description='PyTorch Training') 11 | parser.add_argument( 12 | "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10", "cifar100"], help="dataset to use" 13 | ) 14 | parser.add_argument( 15 | "--scheduler", type=str, default="multistep", choices=["multistep", "cosine", "constant"], help="scheduler to use" 16 | ) 17 | parser.add_argument("--train_bsz", type=int, default=128, help="training batch size") 18 | parser.add_argument("--eval_bsz", type=int, default=100, help="eval batch size") 19 | parser.add_argument("--seed", type=int, default=0, help="random seed") 20 | parser.add_argument("--cpu", action="store_true", help="use cpu only") 21 | parser.add_argument("--cuda", type=str, default="0", help="device to use") 22 | 23 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 24 
| parser.add_argument('--adamw_lr', default=0.003, type=float, help='learning rate for adamw') 25 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 26 | parser.add_argument('--optim', '-m', type=str, choices=["adam", "adamw", "mars", "muon"], default='mars', help='optimization method, default: mars') 27 | parser.add_argument('--net', '-n', type=str, default="resnet18", help='network architecture, choosing from "simple_cnn" or torchvision models. default: resnet18') 28 | parser.add_argument('--wd', default=0., type=float, help='weight decay') 29 | parser.add_argument('--Nepoch', default=200, type=int, help='number of epochs') 30 | parser.add_argument('--beta1', default=0.9, type=float, help='beta1') 31 | parser.add_argument('--beta2', default=0.999, type=float, help='beta2') 32 | parser.add_argument('--wandb', action='store_true', help='use wandb') 33 | parser.add_argument('--save_dir', type=str, default="./checkpoint", help='save directory') 34 | parser.add_argument('--wandb_name', type=str, default="None", help='wandb run name') 35 | 36 | 37 | args = parser.parse_args() 38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda 39 | if args.wandb: 40 | import wandb 41 | if args.wandb_name == "None": 42 | wandb.init(project="CV", name=args.dataset+"_"+args.net+"_"+args.optim+"_"+str(args.lr), config=args) 43 | else: 44 | wandb.init(project="CV", name=args.wandb_name, config=args) 45 | 46 | import torch 47 | import torch.nn as nn 48 | import torch.optim as optim 49 | import torch.backends.cudnn as cudnn 50 | from utils.cv_utils import get_datasets, get_scheduler, get_model 51 | use_cuda = torch.cuda.is_available() and not args.cpu 52 | 53 | os.environ['PYTHONHASHSEED'] = str(args.seed) 54 | np.random.seed(args.seed) 55 | torch.manual_seed(args.seed) 56 | torch.cuda.manual_seed(args.seed) 57 | torch.cuda.manual_seed_all(args.seed) 58 | 59 | trainloader, testloader = get_datasets(args.dataset, args.train_bsz, args.eval_bsz) 60 | if args.resume: 61 | # Load checkpoint. 62 | print('==> Resuming from checkpoint..') 63 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
64 | checkpoint = torch.load(f'./checkpoint/{args.net}_{args.dataset}_'+args.optim) 65 | model = checkpoint['model'] 66 | start_epoch = checkpoint['epoch'] 67 | train_losses = checkpoint['train_losses'] 68 | test_losses = checkpoint['test_losses'] 69 | train_errs = checkpoint['train_errs'] 70 | test_errs = checkpoint['test_errs'] 71 | else: 72 | print('==> Building model..') 73 | 74 | model = get_model(args) 75 | 76 | if use_cuda: 77 | model.cuda() 78 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 79 | cudnn.benchmark = True 80 | 81 | 82 | criterion = nn.CrossEntropyLoss() 83 | 84 | betas = (args.beta1, args.beta2) 85 | from optimizers.mars import MARS 86 | from optimizers.muon import Muon 87 | from opt import CombinedOptimizer 88 | from optimizers.adamw import AdamW 89 | if args.optim == 'adam': 90 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas) 91 | elif args.optim == 'adamw': 92 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas) 93 | elif args.optim == 'muon': 94 | optimizer = CombinedOptimizer(model.parameters(), [AdamW, Muon], [{'lr': args.adamw_lr, 'betas': betas, 'weight_decay': args.wd}, 95 | {'lr': args.lr, 'weight_decay': 0.}]) 96 | elif args.optim == 'mars': 97 | optimizer = MARS(model.parameters(), lr=args.lr, weight_decay = args.wd, lr_1d=args.adamw_lr) 98 | 99 | scheduler = get_scheduler(optimizer, args) 100 | best_acc = 0 # best test accuracy 101 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 102 | train_errs = [] 103 | test_errs = [] 104 | train_losses = [] 105 | test_losses = [] 106 | acc_list = [] 107 | t_bar = tqdm(total=len(trainloader)) 108 | t_bar2 = tqdm(total=len(testloader)) 109 | for epoch in range(start_epoch+1, args.Nepoch+1): 110 | 111 | scheduler.step() 112 | # print ('\nEpoch: %d' % epoch, ' Learning rate:', scheduler.get_lr()) 113 | model.train() # Training 114 | 115 | train_loss = 0 116 | correct_train = 0 117 | total_train = 0 118 | print(scheduler.get_lr()) 119 | t_bar.reset() 120 | for batch_idx, (inputs, targets) in enumerate(trainloader): 121 | if use_cuda: 122 | inputs, targets = inputs.cuda(), targets.cuda() 123 | 124 | optimizer.zero_grad() 125 | outputs = model(inputs) 126 | loss = criterion(outputs, targets) 127 | loss.backward() 128 | optimizer.step() 129 | 130 | train_loss += loss.item() 131 | _, predicted = torch.max(outputs.data, 1) 132 | total_train += targets.size(0) 133 | correct_train += predicted.eq(targets.data).cpu().sum().item() 134 | 135 | t_bar.update(1) 136 | t_bar.set_description('Epoch: %d | Loss: %.3f | Acc: %.3f%% ' % (epoch, train_loss/(batch_idx+1), 100.0/total_train*(correct_train))) 137 | t_bar.refresh() 138 | train_losses.append(train_loss/(batch_idx+1)) 139 | train_errs.append(1 - correct_train/total_train) 140 | 141 | model.eval() # Testing 142 | 143 | test_loss = 0 144 | correct = 0 145 | total = 0 146 | t_bar2.reset() 147 | for batch_idx, (inputs, targets) in enumerate(testloader): 148 | if use_cuda: 149 | inputs, targets = inputs.cuda(), targets.cuda() 150 | outputs = model(inputs) 151 | loss = criterion(outputs, targets) 152 | 153 | test_loss += loss.item() 154 | _, predicted = torch.max(outputs.data, 1) 155 | total += targets.size(0) 156 | correct += predicted.eq(targets.data).cpu().sum().item() 157 | 158 | t_bar2.update(1) 159 | t_bar2.set_description('Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc)) 160 | 
t_bar2.refresh() 161 | test_errs.append(1 - correct/total) 162 | test_losses.append(test_loss/(batch_idx+1)) 163 | if args.wandb: 164 | wandb.log({"epoch": epoch, 165 | "train_loss": train_loss/(batch_idx+1), 166 | "train_acc": 100.0/total_train*(correct_train), 167 | "test_loss": test_loss/(batch_idx+1), 168 | "test_acc": 100.0/total*(correct), 169 | "lr": scheduler.get_lr()[0]}, step=epoch) 170 | # Save checkpoint 171 | acc = 100.0/total*(correct) 172 | if acc > best_acc: 173 | if not os.path.isdir('checkpoint'): 174 | os.mkdir('checkpoint') 175 | state = { 176 | 'model': model, 177 | 'epoch': epoch, 178 | } 179 | # torch.save(state, './checkpoint/cnn_cifar10_' + args.optim) 180 | torch.save(state, os.path.join(args.save_dir, "-".join([args.net, args.dataset, args.optim, str(args.lr).replace(".", "_")])+".pth")) 181 | best_acc = acc 182 | t_bar2.set_description('Model Saved! | Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc)) 183 | t_bar2.refresh() 184 | acc_list.append(acc) 185 | -------------------------------------------------------------------------------- /MARS/train_adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/Liuhong99/Sophia/blob/main/train_adam.py 3 | """ 4 | import os 5 | import time 6 | import math 7 | import pickle 8 | from contextlib import nullcontext 9 | 10 | import numpy as np 11 | import torch 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from torch.distributed import init_process_group, destroy_process_group 14 | 15 | from model import GPTConfig, GPT 16 | import sys 17 | from ast import literal_eval 18 | # ----------------------------------------------------------------------------- 19 | # default config values designed to train a gpt2 (124M) on OpenWebText 20 | # I/O 21 | out_dir = 'out' 22 | eval_interval = 2000 23 | log_interval = 1 24 | eval_iters = 200 25 | eval_only = False # if True, script exits right after the first eval 26 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 27 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 28 | # wandb logging 29 | wandb_log = False # disabled by default 30 | wandb_project = 'mars' 31 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 32 | # data 33 | dataset = 'openwebtext' 34 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes 35 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 36 | block_size = 1024 37 | # model 38 | n_layer = 12 39 | n_head = 12 40 | n_embd = 768 41 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 42 | bias = False # do we use bias inside LayerNorm and Linear layers? 43 | # optimizer 44 | optimizer_name = 'adamw' 45 | learning_rate = 6e-4 # max learning rate 46 | max_iters = 600000 # total number of training iterations 47 | weight_decay = 1e-1 48 | beta1 = 0.9 49 | beta2 = 0.95 50 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 51 | interval = 10 52 | variant = 4 53 | schedule='cosine' 54 | # learning rate decay settings 55 | decay_lr = True # whether to decay the learning rate 56 | warmup_iters = 2000 # how many steps to warm up for 57 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 58 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 59 | # DDP settings 60 | backend = 'nccl' # 'nccl', 'gloo', etc. 
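# Rough arithmetic for the defaults above (a sketch assuming the intended 8-GPU DDP run;
# this value is not computed anywhere in the script):
#   tokens per optimizer step = batch_size * block_size * gradient_accumulation_steps * n_gpus
#                             = 12 * 1024 * 5 * 8 = 491,520
# On a single GPU, gradient_accumulation_steps is multiplied by 8 further below,
# so the effective batch size stays the same.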
61 | # system 62 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 63 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will automatically use a GradScaler 64 | compile = True # use PyTorch 2.0 to compile the model to be faster 65 | scale_attn_by_inverse_layer_idx = True 66 | # ----------------------------------------------------------------------------- 67 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 68 | for arg in sys.argv[1:]: 69 | if '=' not in arg: 70 | # assume it's the name of a config file 71 | assert not arg.startswith('--') 72 | config_file = arg 73 | print(f"Overriding config with {config_file}:") 74 | with open(config_file) as f: 75 | print(f.read()) 76 | exec(open(config_file).read()) 77 | else: 78 | # assume it's a --key=value argument 79 | assert arg.startswith('--') 80 | key, val = arg.split('=') 81 | key = key[2:] 82 | if key in globals(): 83 | try: 84 | # attempt to eval it (e.g. if bool, number, etc.) 85 | attempt = literal_eval(val) 86 | except (SyntaxError, ValueError): 87 | # if that goes wrong, just use the string 88 | attempt = val 89 | # ensure the types match ok 90 | assert type(attempt) == type(globals()[key]) 91 | # cross fingers 92 | print(f"Overriding: {key} = {attempt}") 93 | globals()[key] = attempt 94 | else: 95 | raise ValueError(f"Unknown config key: {key}") 96 | 97 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 98 | # ----------------------------------------------------------------------------- 99 | 100 | # various inits, derived attributes, I/O setup 101 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 102 | if ddp: 103 | init_process_group(backend=backend) 104 | ddp_rank = int(os.environ['RANK']) 105 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 106 | device = f'cuda:{ddp_local_rank}' 107 | torch.cuda.set_device(device) 108 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
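# Illustrative sketch (hypothetical launch command, not taken from this repo's scripts):
#   torchrun --standalone --nproc_per_node=8 train_adamw.py config/train_gpt2_small_adamw.py
# torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for each of the 8 processes, so exactly
# one process (RANK 0) becomes master_process and handles logging and checkpointing.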
109 | seed_offset = ddp_rank # each process gets a different seed 110 | else: 111 | # if not ddp, we are running on a single gpu, and one process 112 | master_process = True 113 | seed_offset = 0 114 | gradient_accumulation_steps *= 8 # simulate 8 gpus 115 | 116 | if master_process: 117 | os.makedirs(out_dir, exist_ok=True) 118 | torch.manual_seed(5000 + seed_offset) 119 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 120 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 121 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 122 | # note: float16 data type will automatically use a GradScaler 123 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 124 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype) 125 | 126 | # poor man's data loader 127 | data_dir = os.path.join('data', dataset) 128 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 129 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 130 | def get_batch(split): 131 | data = train_data if split == 'train' else val_data 132 | ix = torch.randint(len(data) - block_size, (batch_size,)) 133 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 134 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 135 | if device_type == 'cuda': 136 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 137 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 138 | else: 139 | x, y = x.to(device), y.to(device) 140 | return x, y 141 | 142 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 143 | iter_num = 0 144 | best_val_loss = 1e9 145 | 146 | # attempt to derive vocab_size from the dataset 147 | meta_path = os.path.join(data_dir, 'meta.pkl') 148 | meta_vocab_size = None 149 | if os.path.exists(meta_path): 150 | with open(meta_path, 'rb') as f: 151 | meta = pickle.load(f) 152 | meta_vocab_size = meta['vocab_size'] 153 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 154 | 155 | # model init 156 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 157 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line 158 | if init_from == 'scratch': 159 | # init a new model from scratch 160 | print("Initializing a new model from scratch") 161 | # determine the vocab size we'll use for from-scratch training 162 | if meta_vocab_size is None: 163 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 164 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 165 | gptconf = GPTConfig(**model_args) 166 | model = GPT(gptconf) 167 | elif init_from == 'resume': 168 | print(f"Resuming training from {out_dir}") 169 | # resume training from a checkpoint. 170 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 171 | checkpoint = torch.load(ckpt_path, map_location=device) 172 | checkpoint_model_args = checkpoint['model_args'] 173 | # force these config attributes to be equal otherwise we can't even resume training 174 | # the rest of the attributes (e.g. 
dropout) can stay as desired from command line 175 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 176 | model_args[k] = checkpoint_model_args[k] 177 | # create the model 178 | gptconf = GPTConfig(**model_args) 179 | model = GPT(gptconf) 180 | state_dict = checkpoint['model'] 181 | # fix the keys of the state dictionary :( 182 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 183 | unwanted_prefix = '_orig_mod.' 184 | for k,v in list(state_dict.items()): 185 | if k.startswith(unwanted_prefix): 186 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 187 | model.load_state_dict(state_dict) 188 | iter_num = checkpoint['iter_num'] 189 | best_val_loss = checkpoint['best_val_loss'] 190 | elif init_from.startswith('gpt2'): 191 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 192 | # initialize from OpenAI GPT-2 weights 193 | override_args = dict(dropout=dropout) 194 | model = GPT.from_pretrained(init_from, override_args) 195 | # read off the created config params, so we can store them into checkpoint correctly 196 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 197 | model_args[k] = getattr(model.config, k) 198 | # crop down the model block size if desired, using model surgery 199 | if block_size < model.config.block_size: 200 | model.crop_block_size(block_size) 201 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 202 | model.to(device) 203 | 204 | # initialize a GradScaler. If enabled=False scaler is a no-op 205 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 206 | 207 | # optimizer 208 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), device_type) 209 | if init_from == 'resume': 210 | optimizer.load_state_dict(checkpoint['optimizer']) 211 | del state_dict 212 | del checkpoint 213 | # compile the model 214 | if compile: 215 | print("compiling the model... 
(takes a ~minute)") 216 | unoptimized_model = model 217 | model = torch.compile(model) # requires PyTorch 2.0 218 | 219 | # wrap model into DDP container 220 | if ddp: 221 | model = DDP(model, device_ids=[ddp_local_rank]) 222 | 223 | # helps estimate an arbitrarily accurate loss over either split using many batches 224 | @torch.no_grad() 225 | def estimate_loss(): 226 | out = {} 227 | model.eval() 228 | for split in ['train', 'val']: 229 | losses = torch.zeros(eval_iters) 230 | for k in range(eval_iters): 231 | X, Y = get_batch(split) 232 | with ctx: 233 | logits, loss = model(X, Y) 234 | losses[k] = loss.item() 235 | out[split] = losses.mean() 236 | model.train() 237 | return out 238 | 239 | # learning rate decay scheduler (cosine with warmup) 240 | def get_lr(it, schedule='cosine'): 241 | #ing rate schedule {schedule}") 242 | # 1) linear warmup for warmup_iters steps 243 | if it < warmup_iters: 244 | return learning_rate * it / warmup_iters 245 | # 2) if it > lr_decay_iters, return min learning rate 246 | if it > lr_decay_iters: 247 | return min_lr 248 | # 3) in between, use cosine decay down to min learning rate 249 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 250 | assert 0 <= decay_ratio <= 1 251 | if schedule=='cosine': 252 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 253 | elif schedule=='exp': 254 | coeff = np.power(0.9, 100 * decay_ratio) 255 | return min_lr + coeff * (learning_rate - min_lr) 256 | 257 | # logging 258 | if wandb_log and master_process: 259 | import wandb 260 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 261 | 262 | # training loop 263 | X, Y = get_batch('train') # fetch the very first batch 264 | t0 = time.time() 265 | local_iter_num = 0 # number of iterations in the lifetime of this process 266 | raw_model = model.module if ddp else model # unwrap DDP container if needed 267 | running_mfu = -1.0 268 | clip_time = 0 269 | while True: 270 | 271 | # determine and set the learning rate for this iteration 272 | lr = get_lr(iter_num, schedule=schedule) if decay_lr else learning_rate 273 | for param_group in optimizer.param_groups: 274 | param_group['lr'] = lr 275 | 276 | # evaluate the loss on train/val sets and write checkpoints 277 | if iter_num % eval_interval == 0 and master_process: 278 | losses = estimate_loss() 279 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 280 | if wandb_log: 281 | wandb.log({ 282 | "iter": iter_num, 283 | "train/loss": losses['train'], 284 | "val/loss": losses['val'], 285 | "lr": lr, 286 | "mfu": running_mfu*100, # convert to percentage 287 | }, step=iter_num) 288 | if losses['val'] < best_val_loss or always_save_checkpoint: 289 | best_val_loss = losses['val'] 290 | if iter_num > 0: 291 | checkpoint = { 292 | 'model': raw_model.state_dict(), 293 | 'optimizer': optimizer.state_dict(), 294 | 'model_args': model_args, 295 | 'iter_num': iter_num, 296 | 'best_val_loss': best_val_loss, 297 | 'config': config, 298 | } 299 | print(f"saving checkpoint to {out_dir}") 300 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 301 | if iter_num % (eval_interval * 5) == 0: 302 | checkpoint = { 303 | 'model': raw_model.state_dict(), 304 | 'optimizer': optimizer.state_dict(), 305 | 'model_args': model_args, 306 | 'iter_num': iter_num, 307 | 'best_val_loss': best_val_loss, 308 | 'config': config, 309 | } 310 | print(f"saving checkpoint to {out_dir}") 311 | torch.save(checkpoint, os.path.join(out_dir, 
'ckpt_'+str(iter_num)+'.pt')) 312 | if iter_num == 0 and eval_only: 313 | break 314 | 315 | # forward backward update, with optional gradient accumulation to simulate larger batch size 316 | # and using the GradScaler if data type is float16 317 | for micro_step in range(gradient_accumulation_steps): 318 | if ddp: 319 | # in DDP training we only need to sync gradients at the last micro step. 320 | # the official way to do this is with model.no_sync() context manager, but 321 | # I really dislike that this bloats the code and forces us to repeat code 322 | # looking at the source of that context manager, it just toggles this variable 323 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 324 | with ctx: 325 | logits, loss = model(X, Y) 326 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 327 | X, Y = get_batch('train') 328 | # backward pass, with gradient scaling if training in fp16 329 | scaler.scale(loss).backward() 330 | # clip the gradient 331 | if grad_clip != 0.0: 332 | scaler.unscale_(optimizer) 333 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 334 | if total_norm.item() > grad_clip: 335 | clip_time += 1 336 | # step the optimizer and scaler if training in fp16 337 | scaler.step(optimizer) 338 | scaler.update() 339 | # flush the gradients as soon as we can, no need for this memory anymore 340 | optimizer.zero_grad(set_to_none=True) 341 | 342 | # timing and logging 343 | t1 = time.time() 344 | dt = t1 - t0 345 | t0 = t1 346 | if iter_num % log_interval == 0 and master_process: 347 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point 348 | if local_iter_num >= 5: # let the training loop settle a bit 349 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 350 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 351 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 352 | params = [] 353 | for (name, p) in model.named_parameters(): 354 | params.append(p) 355 | total_param_norm = 0 356 | for p in params: 357 | param_norm = p.data.norm(2) 358 | total_param_norm += param_norm.item() ** 2 359 | total_param_norm = total_param_norm ** 0.5 360 | momentum_norm = 0 361 | LL = len(optimizer.state_dict()['state']) 362 | for jj in range(LL): 363 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2 364 | momentum_norm = torch.sqrt(momentum_norm).item() 365 | if wandb_log: 366 | wandb.log({ 367 | "iter": iter_num, 368 | "train/loss": lossf, 369 | "lr": lr, 370 | "param_norm": total_param_norm, 371 | "momentum_norm" : momentum_norm, 372 | "train/clip_rate": clip_time / (iter_num + 1) 373 | }, step=iter_num) 374 | iter_num += 1 375 | local_iter_num += 1 376 | 377 | # termination conditions 378 | if iter_num > max_iters: 379 | break 380 | 381 | if ddp: 382 | destroy_process_group() 383 | -------------------------------------------------------------------------------- /MARS/train_adamw_fw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import pickle 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from torch.distributed import init_process_group, destroy_process_group 11 | 12 | from model import GPTConfig, GPT 13 | import sys 14 | from ast import literal_eval 15 | # 
----------------------------------------------------------------------------- 16 | # default config values designed to train a gpt2 (124M) on OpenWebText 17 | # I/O 18 | out_dir = 'out' 19 | eval_interval = 2000 20 | log_interval = 1 21 | eval_iters = 200 22 | eval_only = False # if True, script exits right after the first eval 23 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 24 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 25 | # wandb logging 26 | wandb_log = False # disabled by default 27 | wandb_project = 'mars' 28 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 29 | # data 30 | dataset = 'fineweb-edu100B' 31 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes 32 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 33 | block_size = 1024 34 | # model 35 | n_layer = 12 36 | n_head = 12 37 | n_embd = 768 38 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 39 | bias = False # do we use bias inside LayerNorm and Linear layers? 40 | # optimizer 41 | optimizer_name = 'adamw' 42 | learning_rate = 6e-4 # max learning rate 43 | max_iters = 600000 # total number of training iterations 44 | weight_decay = 1e-1 45 | beta1 = 0.9 46 | beta2 = 0.95 47 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 48 | interval = 10 49 | variant = 4 50 | schedule='cosine' 51 | # learning rate decay settings 52 | decay_lr = True # whether to decay the learning rate 53 | warmup_iters = 2000 # how many steps to warm up for 54 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 55 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 56 | # DDP settings 57 | backend = 'nccl' # 'nccl', 'gloo', etc. 58 | # system 59 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 60 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 61 | compile = True # use PyTorch 2.0 to compile the model to be faster 62 | scale_attn_by_inverse_layer_idx = True 63 | # ----------------------------------------------------------------------------- 64 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 65 | for arg in sys.argv[1:]: 66 | if '=' not in arg: 67 | # assume it's the name of a config file 68 | assert not arg.startswith('--') 69 | config_file = arg 70 | print(f"Overriding config with {config_file}:") 71 | with open(config_file) as f: 72 | print(f.read()) 73 | exec(open(config_file).read()) 74 | else: 75 | # assume it's a --key=value argument 76 | assert arg.startswith('--') 77 | key, val = arg.split('=') 78 | key = key[2:] 79 | if key in globals(): 80 | try: 81 | # attempt to eval it it (e.g. if bool, number, or etc) 82 | attempt = literal_eval(val) 83 | except (SyntaxError, ValueError): 84 | # if that goes wrong, just use the string 85 | attempt = val 86 | # ensure the types match ok 87 | assert type(attempt) == type(globals()[key]) 88 | # cross fingers 89 | print(f"Overriding: {key} = {attempt}") 90 | globals()[key] = attempt 91 | else: 92 | raise ValueError(f"Unknown config key: {key}") 93 | 94 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 95 | # ----------------------------------------------------------------------------- 96 | 97 | # various inits, derived attributes, I/O setup 98 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 
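# RANK is exported by launchers such as torchrun, so its presence is what distinguishes
# a multi-process DDP launch from a plain `python train_adamw_fw.py` run.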
99 | if ddp: 100 | init_process_group(backend=backend) 101 | ddp_rank = int(os.environ['RANK']) 102 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 103 | device = f'cuda:{ddp_local_rank}' 104 | torch.cuda.set_device(device) 105 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 106 | seed_offset = ddp_rank # each process gets a different seed 107 | else: 108 | # if not ddp, we are running on a single gpu, and one process 109 | master_process = True 110 | seed_offset = 0 111 | gradient_accumulation_steps *= 8 # simulate 8 gpus 112 | 113 | if master_process: 114 | os.makedirs(out_dir, exist_ok=True) 115 | torch.manual_seed(5000 + seed_offset) 116 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 117 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 118 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 119 | # note: float16 data type will automatically use a GradScaler 120 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 121 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype) 122 | 123 | # poor man's data loader 124 | data_dir = os.path.join('data', dataset) 125 | train_file_list = list(filter(lambda x: x.endswith('.bin') and x.startswith('fineweb_train'), os.listdir(data_dir))) 126 | train_data_list = [np.memmap(os.path.join(data_dir, file), dtype=np.uint16, mode='r') for file in train_file_list] 127 | val_data = np.memmap(os.path.join(data_dir, 'fineweb_val_000000.bin'), dtype=np.uint16, mode='r') 128 | import random 129 | random.seed(5000 + seed_offset) 130 | def get_batch(split): 131 | if split == 'train': 132 | data = random.choice(train_data_list) 133 | else: 134 | data = val_data 135 | offset = 512 136 | ix = torch.randint(len(data) - block_size - offset, (batch_size,)) 137 | x = torch.stack([torch.from_numpy((data[offset+i:offset+i+block_size]).astype(np.int64)) for i in ix]) 138 | y = torch.stack([torch.from_numpy((data[offset+i+1:offset+i+1+block_size]).astype(np.int64)) for i in ix]) 139 | if device_type == 'cuda': 140 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 141 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 142 | else: 143 | x, y = x.to(device), y.to(device) 144 | return x, y 145 | 146 | # init these up here, can override if init_from='resume' (i.e. 
from a checkpoint) 147 | iter_num = 0 148 | best_val_loss = 1e9 149 | 150 | # attempt to derive vocab_size from the dataset 151 | meta_path = os.path.join(data_dir, 'meta.pkl') 152 | meta_vocab_size = None 153 | if os.path.exists(meta_path): 154 | with open(meta_path, 'rb') as f: 155 | meta = pickle.load(f) 156 | meta_vocab_size = meta['vocab_size'] 157 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 158 | 159 | # model init 160 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 161 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line 162 | if init_from == 'scratch': 163 | # init a new model from scratch 164 | print("Initializing a new model from scratch") 165 | # determine the vocab size we'll use for from-scratch training 166 | if meta_vocab_size is None: 167 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 168 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 169 | gptconf = GPTConfig(**model_args) 170 | model = GPT(gptconf) 171 | elif init_from == 'resume': 172 | print(f"Resuming training from {out_dir}") 173 | # resume training from a checkpoint. 174 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 175 | checkpoint = torch.load(ckpt_path, map_location=device) 176 | checkpoint_model_args = checkpoint['model_args'] 177 | # force these config attributes to be equal otherwise we can't even resume training 178 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 179 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 180 | model_args[k] = checkpoint_model_args[k] 181 | # create the model 182 | gptconf = GPTConfig(**model_args) 183 | model = GPT(gptconf) 184 | state_dict = checkpoint['model'] 185 | # fix the keys of the state dictionary :( 186 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 187 | unwanted_prefix = '_orig_mod.' 188 | for k,v in list(state_dict.items()): 189 | if k.startswith(unwanted_prefix): 190 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 191 | model.load_state_dict(state_dict) 192 | iter_num = checkpoint['iter_num'] 193 | best_val_loss = checkpoint['best_val_loss'] 194 | elif init_from.startswith('gpt2'): 195 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 196 | # initialize from OpenAI GPT-2 weights 197 | override_args = dict(dropout=dropout) 198 | model = GPT.from_pretrained(init_from, override_args) 199 | # read off the created config params, so we can store them into checkpoint correctly 200 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 201 | model_args[k] = getattr(model.config, k) 202 | # crop down the model block size if desired, using model surgery 203 | if block_size < model.config.block_size: 204 | model.crop_block_size(block_size) 205 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 206 | model.to(device) 207 | 208 | # initialize a GradScaler. 
If enabled=False scaler is a no-op 209 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 210 | 211 | # optimizer 212 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), device_type) 213 | if init_from == 'resume': 214 | optimizer.load_state_dict(checkpoint['optimizer']) 215 | del state_dict 216 | del checkpoint 217 | # compile the model 218 | if compile: 219 | print("compiling the model... (takes a ~minute)") 220 | unoptimized_model = model 221 | model = torch.compile(model) # requires PyTorch 2.0 222 | 223 | # wrap model into DDP container 224 | if ddp: 225 | model = DDP(model, device_ids=[ddp_local_rank]) 226 | 227 | # helps estimate an arbitrarily accurate loss over either split using many batches 228 | @torch.no_grad() 229 | def estimate_loss(): 230 | out = {} 231 | model.eval() 232 | for split in ['train', 'val']: 233 | losses = torch.zeros(eval_iters) 234 | for k in range(eval_iters): 235 | X, Y = get_batch(split) 236 | with ctx: 237 | logits, loss = model(X, Y) 238 | losses[k] = loss.item() 239 | out[split] = losses.mean() 240 | model.train() 241 | return out 242 | 243 | # learning rate decay scheduler (cosine with warmup) 244 | def get_lr(it, schedule='cosine'): 245 | #ing rate schedule {schedule}") 246 | # 1) linear warmup for warmup_iters steps 247 | if it < warmup_iters: 248 | return learning_rate * it / warmup_iters 249 | # 2) if it > lr_decay_iters, return min learning rate 250 | if schedule=='wsd': 251 | if it < 0.8 * max_iters: 252 | return learning_rate 253 | else: 254 | return learning_rate * (max_iters - it) / (max_iters * 0.2) 255 | if it > lr_decay_iters: 256 | return min_lr 257 | # 3) in between, use cosine decay down to min learning rate 258 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 259 | assert 0 <= decay_ratio <= 1 260 | if schedule=='cosine': 261 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 262 | elif schedule=='exp': 263 | coeff = np.power(0.9, 100 * decay_ratio) 264 | 265 | return min_lr + coeff * (learning_rate - min_lr) 266 | 267 | # logging 268 | if wandb_log and master_process: 269 | import wandb 270 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 271 | 272 | # training loop 273 | X, Y = get_batch('train') # fetch the very first batch 274 | t0 = time.time() 275 | local_iter_num = 0 # number of iterations in the lifetime of this process 276 | raw_model = model.module if ddp else model # unwrap DDP container if needed 277 | running_mfu = -1.0 278 | clip_time = 0 279 | while True: 280 | 281 | # determine and set the learning rate for this iteration 282 | lr = get_lr(iter_num, schedule=schedule) if decay_lr else learning_rate 283 | for param_group in optimizer.param_groups: 284 | param_group['lr'] = lr 285 | 286 | # evaluate the loss on train/val sets and write checkpoints 287 | if iter_num % eval_interval == 0 and master_process: 288 | losses = estimate_loss() 289 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 290 | if wandb_log: 291 | wandb.log({ 292 | "iter": iter_num, 293 | "train/loss": losses['train'], 294 | "val/loss": losses['val'], 295 | "lr": lr, 296 | "mfu": running_mfu*100, # convert to percentage 297 | }, step=iter_num) 298 | if losses['val'] < best_val_loss or always_save_checkpoint: 299 | best_val_loss = losses['val'] 300 | if iter_num > 0: 301 | checkpoint = { 302 | 'model': raw_model.state_dict(), 303 | 'optimizer': optimizer.state_dict(), 304 | 
'model_args': model_args, 305 | 'iter_num': iter_num, 306 | 'best_val_loss': best_val_loss, 307 | 'config': config, 308 | } 309 | print(f"saving checkpoint to {out_dir}") 310 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 311 | if iter_num % (eval_interval * 5) == 0: 312 | checkpoint = { 313 | 'model': raw_model.state_dict(), 314 | 'optimizer': optimizer.state_dict(), 315 | 'model_args': model_args, 316 | 'iter_num': iter_num, 317 | 'best_val_loss': best_val_loss, 318 | 'config': config, 319 | } 320 | print(f"saving checkpoint to {out_dir}") 321 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt')) 322 | if iter_num == 0 and eval_only: 323 | break 324 | 325 | # forward backward update, with optional gradient accumulation to simulate larger batch size 326 | # and using the GradScaler if data type is float16 327 | for micro_step in range(gradient_accumulation_steps): 328 | if ddp: 329 | # in DDP training we only need to sync gradients at the last micro step. 330 | # the official way to do this is with model.no_sync() context manager, but 331 | # I really dislike that this bloats the code and forces us to repeat code 332 | # looking at the source of that context manager, it just toggles this variable 333 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 334 | with ctx: 335 | logits, loss = model(X, Y) 336 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 337 | X, Y = get_batch('train') 338 | # backward pass, with gradient scaling if training in fp16 339 | scaler.scale(loss).backward() 340 | # clip the gradient 341 | if grad_clip != 0.0: 342 | scaler.unscale_(optimizer) 343 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 344 | if total_norm.item() > grad_clip: 345 | clip_time += 1 346 | # step the optimizer and scaler if training in fp16 347 | scaler.step(optimizer) 348 | scaler.update() 349 | # flush the gradients as soon as we can, no need for this memory anymore 350 | optimizer.zero_grad(set_to_none=True) 351 | 352 | # timing and logging 353 | t1 = time.time() 354 | dt = t1 - t0 355 | t0 = t1 356 | if iter_num % log_interval == 0 and master_process: 357 | lossf = loss.item() # loss as float. 
note: this is a CPU-GPU sync point 358 | if local_iter_num >= 5: # let the training loop settle a bit 359 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 360 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 361 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 362 | params = [] 363 | for (name, p) in model.named_parameters(): 364 | params.append(p) 365 | total_param_norm = 0 366 | for p in params: 367 | param_norm = p.data.norm(2) 368 | total_param_norm += param_norm.item() ** 2 369 | total_param_norm = total_param_norm ** 0.5 370 | momentum_norm = 0 371 | LL = len(optimizer.state_dict()['state']) 372 | for jj in range(LL): 373 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2 374 | momentum_norm = torch.sqrt(momentum_norm).item() 375 | if wandb_log: 376 | wandb.log({ 377 | "iter": iter_num, 378 | "train/loss": lossf, 379 | "lr": lr, 380 | "param_norm": total_param_norm, 381 | "momentum_norm" : momentum_norm, 382 | "train/clip_rate": clip_time / (iter_num + 1) 383 | }, step=iter_num) 384 | iter_num += 1 385 | local_iter_num += 1 386 | 387 | # termination conditions 388 | if iter_num > max_iters: 389 | break 390 | 391 | if ddp: 392 | destroy_process_group() 393 | -------------------------------------------------------------------------------- /MARS/train_mars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import pickle 5 | from contextlib import nullcontext 6 | from collections import deque 7 | 8 | import numpy as np 9 | import torch 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from torch.distributed import init_process_group, destroy_process_group 12 | 13 | from model import GPTConfig, GPT 14 | import sys 15 | from ast import literal_eval 16 | # ----------------------------------------------------------------------------- 17 | # default config values designed to train a gpt2 (124M) on OpenWebText 18 | # I/O 19 | data_path = "./data" 20 | out_dir = 'out' 21 | eval_interval = 2000 22 | log_interval = 1 23 | eval_iters = 200 24 | eval_only = False # if True, script exits right after the first eval 25 | always_save_checkpoint = False # if True, always save a checkpoint after each eval 26 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 27 | # wandb logging 28 | wandb_log = False # disabled by default 29 | wandb_project = 'owt' 30 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 31 | # data 32 | dataset = 'openwebtext' 33 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes 34 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 35 | initial_steps = 100 36 | block_size = 1024 37 | # model 38 | n_layer = 12 39 | n_head = 12 40 | n_embd = 768 41 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 42 | bias = False # do we use bias inside LayerNorm and Linear layers? 
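# Back-of-the-envelope parameter count for this config (a rough sketch, not computed here):
#   per block: attention 4*n_embd^2 + MLP 8*n_embd^2 = 12*768^2 ≈ 7.08M
#   12 blocks ≈ 85M, token embedding 50304*768 ≈ 38.6M, position embedding 1024*768 ≈ 0.8M
#   total ≈ 124M with a weight-tied LM head, i.e. the GPT-2 (124M) scale named at the top of this file.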
43 | # optimizer 44 | optimizer_name = 'mars' 45 | learning_rate = 6e-4 # max learning rate 46 | max_iters = 600000 # total number of training iterations 47 | weight_decay = 1e-1 48 | beta1 = 0.95 49 | beta2 = 0.99 50 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 51 | interval = 10 52 | variant = 4 53 | # learning rate decay settings 54 | decay_lr = True # whether to decay the learning rate 55 | warmup_iters = 2000 # how many steps to warm up for 56 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 57 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 58 | # DDP settings 59 | backend = 'nccl' # 'nccl', 'gloo', etc. 60 | # system 61 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 62 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 63 | compile = True # use PyTorch 2.0 to compile the model to be faster 64 | scale_attn_by_inverse_layer_idx = True 65 | # learning rate schedule 66 | schedule='cosine' 67 | scheme='exact' 68 | gamma=0.025 69 | lr_1d=3e-3 70 | is_approx=True 71 | mars_type="mars-adamw" 72 | optimize_1d=False 73 | weight_decay_1d=0.1 74 | betas_1d=(0.9, 0.95) 75 | # ----------------------------------------------------------------------------- 76 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 77 | for arg in sys.argv[1:]: 78 | if '=' not in arg: 79 | # assume it's the name of a config file 80 | assert not arg.startswith('--') 81 | config_file = arg 82 | print(f"Overriding config with {config_file}:") 83 | with open(config_file) as f: 84 | print(f.read()) 85 | exec(open(config_file).read()) 86 | else: 87 | # assume it's a --key=value argument 88 | assert arg.startswith('--') 89 | key, val = arg.split('=') 90 | key = key[2:] 91 | if key in globals(): 92 | try: 93 | # attempt to eval it it (e.g. if bool, number, or etc) 94 | attempt = literal_eval(val) 95 | except (SyntaxError, ValueError): 96 | # if that goes wrong, just use the string 97 | attempt = val 98 | # ensure the types match ok 99 | assert type(attempt) == type(globals()[key]) 100 | # cross fingers 101 | print(f"Overriding: {key} = {attempt}") 102 | globals()[key] = attempt 103 | else: 104 | raise ValueError(f"Unknown config key: {key}") 105 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 106 | # ----------------------------------------------------------------------------- 107 | 108 | # various inits, derived attributes, I/O setup 109 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 110 | if ddp: 111 | init_process_group(backend=backend) 112 | ddp_rank = int(os.environ['RANK']) 113 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 114 | device = f'cuda:{ddp_local_rank}' 115 | torch.cuda.set_device(device) 116 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
117 | seed_offset = ddp_rank # each process gets a different seed 118 | else: 119 | # if not ddp, we are running on a single gpu, and one process 120 | master_process = True 121 | seed_offset = 0 122 | gradient_accumulation_steps *= 8 # simulate 8 gpus 123 | 124 | if master_process: 125 | os.makedirs(out_dir, exist_ok=True) 126 | torch.manual_seed(5000 + seed_offset) 127 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 128 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 129 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 130 | # note: float16 data type will automatically use a GradScaler 131 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 132 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype) 133 | 134 | # poor man's data loader 135 | data_dir = os.path.join(data_path, dataset) 136 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 137 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 138 | def get_batch(split): 139 | data = train_data if split == 'train' else val_data 140 | ix = torch.randint(len(data) - block_size, (batch_size,)) 141 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 142 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 143 | if device_type == 'cuda': 144 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 145 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 146 | else: 147 | x, y = x.to(device), y.to(device) 148 | return x, y 149 | 150 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 151 | iter_num = 0 152 | best_val_loss = 1e9 153 | 154 | # attempt to derive vocab_size from the dataset 155 | meta_path = os.path.join(data_dir, 'meta.pkl') 156 | meta_vocab_size = None 157 | if os.path.exists(meta_path): 158 | with open(meta_path, 'rb') as f: 159 | meta = pickle.load(f) 160 | meta_vocab_size = meta['vocab_size'] 161 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 162 | 163 | # model init 164 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 165 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line 166 | if init_from == 'scratch': 167 | # init a new model from scratch 168 | print("Initializing a new model from scratch") 169 | # determine the vocab size we'll use for from-scratch training 170 | if meta_vocab_size is None: 171 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 172 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 173 | gptconf = GPTConfig(**model_args) 174 | model = GPT(gptconf) 175 | elif init_from == 'resume': 176 | print(f"Resuming training from {out_dir}") 177 | # resume training from a checkpoint. 178 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 179 | checkpoint = torch.load(ckpt_path, map_location=device) 180 | checkpoint_model_args = checkpoint['model_args'] 181 | # force these config attributes to be equal otherwise we can't even resume training 182 | # the rest of the attributes (e.g. 
dropout) can stay as desired from command line 183 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 184 | model_args[k] = checkpoint_model_args[k] 185 | # create the model 186 | gptconf = GPTConfig(**model_args) 187 | model = GPT(gptconf) 188 | state_dict = checkpoint['model'] 189 | # fix the keys of the state dictionary :( 190 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 191 | unwanted_prefix = '_orig_mod.' 192 | for k,v in list(state_dict.items()): 193 | if k.startswith(unwanted_prefix): 194 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 195 | model.load_state_dict(state_dict) 196 | iter_num = checkpoint['iter_num'] 197 | best_val_loss = checkpoint['best_val_loss'] 198 | elif init_from.startswith('gpt2'): 199 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 200 | # initialize from OpenAI GPT-2 weights 201 | override_args = dict(dropout=dropout) 202 | model = GPT.from_pretrained(init_from, override_args) 203 | # read off the created config params, so we can store them into checkpoint correctly 204 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 205 | model_args[k] = getattr(model.config, k) 206 | # crop down the model block size if desired, using model surgery 207 | if block_size < model.config.block_size: 208 | model.crop_block_size(block_size) 209 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 210 | model.to(device) 211 | 212 | # initialize a GradScaler. If enabled=False scaler is a no-op 213 | scaler = torch.amp.GradScaler('cuda', enabled=(dtype == 'float16')) 214 | other_params = {'gamma': gamma, 'is_approx': is_approx, 'mars_type': mars_type, 'optimize_1d': optimize_1d, 215 | 'lr_1d': lr_1d, 'betas_1d': betas_1d, 'weight_decay_1d': weight_decay_1d} 216 | # optimizer 217 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), device_type, 218 | other_params) 219 | if init_from == 'resume': 220 | optimizer.load_state_dict(checkpoint['optimizer']) 221 | del state_dict 222 | del checkpoint 223 | # compile the model 224 | if compile: 225 | print("compiling the model... 
(takes a ~minute)") 226 | unoptimized_model = model 227 | model = torch.compile(model) # requires PyTorch 2.0 228 | 229 | # wrap model into DDP container 230 | if ddp: 231 | print('DDP_used') 232 | model = DDP(model, device_ids=[ddp_local_rank]) 233 | 234 | # helps estimate an arbitrarily accurate loss over either split using many batches 235 | @torch.no_grad() 236 | def estimate_loss(): 237 | out = {} 238 | model.eval() 239 | for split in ['train', 'val']: 240 | losses = torch.zeros(eval_iters) 241 | for k in range(eval_iters): 242 | X, Y = get_batch(split) 243 | with ctx: 244 | logits, loss = model(X, Y) 245 | losses[k] = loss.item() 246 | out[split] = losses.mean() 247 | model.train() 248 | return out 249 | 250 | # learning rate decay scheduler (cosine with warmup) 251 | def get_lr(it, schedule='cosine'): 252 | #ing rate schedule {schedule}") 253 | # 1) linear warmup for warmup_iters steps 254 | if it < warmup_iters: 255 | return learning_rate * it / warmup_iters 256 | # 2) if it > lr_decay_iters, return min learning rate 257 | if it > lr_decay_iters: 258 | return min_lr 259 | # 3) in between, use cosine decay down to min learning rate 260 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 261 | assert 0 <= decay_ratio <= 1 262 | if schedule=='cosine': 263 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 264 | elif schedule=='exp': 265 | coeff = np.power(0.9, 100 * decay_ratio) 266 | return min_lr + coeff * (learning_rate - min_lr) 267 | 268 | # logging 269 | if wandb_log and master_process: 270 | import wandb 271 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 272 | 273 | # training loop 274 | #X, Y = get_batch('train') # fetch the very first batch 275 | Xs=deque([]) 276 | Ys=deque([]) 277 | for micro_step in range(gradient_accumulation_steps): 278 | X, Y = get_batch('train') 279 | Xs.append(X) 280 | Ys.append(Y) 281 | t0 = time.time() 282 | local_iter_num = 0 # number of iterations in the lifetime of this process 283 | raw_model = model.module if ddp else model # unwrap DDP container if needed 284 | running_mfu = -1.0 285 | clip_time = 0 286 | while True: 287 | 288 | # determine and set the learning rate for this iteration 289 | lr = get_lr(iter_num, schedule=schedule) if decay_lr else learning_rate 290 | for param_group in optimizer.param_groups: 291 | param_group['lr'] = lr 292 | 293 | # evaluate the loss on train/val sets and write checkpoints 294 | if iter_num % eval_interval == 0 and master_process: 295 | losses = estimate_loss() 296 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 297 | if wandb_log: 298 | wandb.log({ 299 | "iter": iter_num, 300 | "train/loss": losses['train'], 301 | "val/loss": losses['val'], 302 | "lr": lr, 303 | "mfu": running_mfu*100, # convert to percentage 304 | }, step=iter_num) 305 | if losses['val'] < best_val_loss or always_save_checkpoint: 306 | best_val_loss = losses['val'] 307 | if iter_num > 0: 308 | checkpoint = { 309 | 'model': raw_model.state_dict(), 310 | 'optimizer': optimizer.state_dict(), 311 | 'model_args': model_args, 312 | 'iter_num': iter_num, 313 | 'best_val_loss': best_val_loss, 314 | 'config': config, 315 | } 316 | print(f"saving checkpoint to {out_dir}") 317 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 318 | if iter_num % (eval_interval * 5) == 0: 319 | checkpoint = { 320 | 'model': raw_model.state_dict(), 321 | 'optimizer': optimizer.state_dict(), 322 | 'model_args': model_args, 323 | 'iter_num': iter_num, 
324 | 'best_val_loss': best_val_loss, 325 | 'config': config, 326 | } 327 | print(f"saving checkpoint to {out_dir}") 328 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt')) 329 | if iter_num == 0 and eval_only: 330 | break 331 | 332 | # forward backward update, with optional gradient accumulation to simulate larger batch size 333 | # and using the GradScaler if data type is float16 334 | minibatch_size = gradient_accumulation_steps 335 | X_cur = [] 336 | Y_cur = [] 337 | ## Update datasets 338 | for micro_step in range(minibatch_size): 339 | X_cur.append(Xs.popleft()) 340 | Y_cur.append(Ys.popleft()) 341 | X, Y = get_batch('train') 342 | Xs.append(X) 343 | Ys.append(Y) 344 | ## Calculate previous gradient with future batch data first, this information should be used at the next iteration. 345 | if scheme == 'exact' and not is_approx: 346 | ### Calculate the gradient again using the new batch 347 | for micro_step in range(gradient_accumulation_steps): 348 | if ddp: 349 | # in DDP training we only need to sync gradients at the last micro step. 350 | # the official way to do this is with model.no_sync() context manager, but 351 | # I really dislike that this bloats the code and forces us to repeat code 352 | # looking at the source of that context manager, it just toggles this variable 353 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 354 | with ctx: 355 | X = Xs[micro_step] 356 | Y = Ys[micro_step] 357 | logits, loss = model(X, Y) 358 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 359 | # backward pass, with gradient scaling if training in fp16 360 | scaler.scale(loss).backward() 361 | if grad_clip != 0.0: 362 | scaler.unscale_(optimizer) 363 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 364 | if total_norm.item() > grad_clip: 365 | clip_time += 1 366 | elif (grad_clip == 0.0) and (optimizer.gamma == 0.0): 367 | scaler.unscale_(optimizer) 368 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 369 | if total_norm.item() > 1.0: 370 | clip_time += 1 371 | ### Update the previous grad of the next iteration 372 | optimizer.update_previous_grad() 373 | 374 | # flush the gradients as soon as we can, no need for this memory anymore 375 | optimizer.zero_grad(set_to_none=True) 376 | 377 | ## Calculate the gradient of the current batch 378 | for micro_step in range(minibatch_size): 379 | if ddp: 380 | # in DDP training we only need to sync gradients at the last micro step. 
381 | # the official way to do this is with model.no_sync() context manager, but 382 | # I really dislike that this bloats the code and forces us to repeat code 383 | # looking at the source of that context manager, it just toggles this variable 384 | model.require_backward_grad_sync = (micro_step == minibatch_size - 1) 385 | with ctx: 386 | X = X_cur[micro_step] 387 | Y = Y_cur[micro_step] 388 | logits, loss = model(X, Y) 389 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 390 | # backward pass, with gradient scaling if training in fp16 391 | scaler.scale(loss).backward() 392 | # clip the gradient 393 | if grad_clip != 0.0: 394 | scaler.unscale_(optimizer) 395 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 396 | if total_norm.item() > grad_clip: 397 | clip_time += 1 398 | ### First update the current value of gradient 399 | #optimizer.update_current_grad() 400 | # step the optimizer and scaler if training in fp16 401 | scaler.step(optimizer) 402 | scaler.update() 403 | ### TODO: Clean the grad 404 | optimizer.zero_grad(set_to_none=True) 405 | optimizer.update_last_grad() 406 | 407 | # timing and logging 408 | t1 = time.time() 409 | dt = t1 - t0 410 | t0 = t1 411 | if iter_num % log_interval == 0 and master_process: 412 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point 413 | if local_iter_num >= 5: # let the training loop settle a bit 414 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 415 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 416 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 417 | params = [] 418 | for (name, p) in model.named_parameters(): 419 | params.append(p) 420 | total_param_norm = 0 421 | for p in params: 422 | param_norm = p.data.norm(2) 423 | total_param_norm += param_norm.item() ** 2 424 | total_param_norm = total_param_norm ** 0.5 425 | momentum_norm = 0 426 | momentum_norm_sq = 0 427 | momentum_div = 0 428 | LL = len(optimizer.state_dict()['state']) 429 | for jj in range(LL): 430 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2 431 | momentum_norm_sq += (optimizer.state_dict()['state'][jj]['exp_avg_sq'].detach().norm(2)) ** 2 432 | momentum_norm = torch.sqrt(momentum_norm).item() 433 | momentum_norm_sq = torch.sqrt(momentum_norm_sq).item() 434 | momentum_div = momentum_norm/(np.sqrt(momentum_norm_sq)+1e-8) 435 | if wandb_log: 436 | wandb.log({ 437 | "iter": iter_num, 438 | "train/loss": lossf, 439 | "lr": lr, 440 | "param_norm": total_param_norm, 441 | "momentum_norm" : momentum_norm, 442 | "momentum_norm_sq": momentum_norm_sq, 443 | "momentum_div": momentum_div, 444 | "train/clip_rate": clip_time / (iter_num + 1) 445 | }, step=iter_num) 446 | iter_num += 1 447 | local_iter_num += 1 448 | 449 | # termination conditions 450 | if iter_num > max_iters: 451 | break 452 | 453 | if ddp: 454 | destroy_process_group() 455 | -------------------------------------------------------------------------------- /MARS/train_muon.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import pickle 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from torch.distributed import init_process_group, destroy_process_group 11 | 12 | from model import GPTConfig, GPT 13 | 
import sys 14 | from ast import literal_eval 15 | # ----------------------------------------------------------------------------- 16 | # default config values designed to train a gpt2 (124M) on OpenWebText 17 | # I/O 18 | data_path = "./data" 19 | out_dir = 'out' 20 | eval_interval = 2000 21 | log_interval = 1 22 | eval_iters = 200 23 | eval_only = False # if True, script exits right after the first eval 24 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 25 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 26 | # wandb logging 27 | wandb_log = False # disabled by default 28 | wandb_project = 'owt' 29 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 30 | # data 31 | dataset = 'openwebtext' 32 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes 33 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 34 | block_size = 1024 35 | # model 36 | n_layer = 12 37 | n_head = 12 38 | n_embd = 768 39 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 40 | bias = False # do we use bias inside LayerNorm and Linear layers? 41 | # optimizer 42 | optimizer_name = 'muon' 43 | learning_rate = 6e-4 # max learning rate 44 | muon_learning_rate = 2e-2 45 | max_iters = 600000 # total number of training iterations 46 | weight_decay = 1e-1 47 | muon_weight_decay = 0. 48 | beta1 = 0.95 49 | beta2 = 0.99 50 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 51 | interval = 10 52 | variant = 4 53 | # learning rate decay settings 54 | decay_lr = True # whether to decay the learning rate 55 | warmup_iters = 2000 # how many steps to warm up for‘ 56 | warmdown_iters = 2000 57 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 58 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 59 | # DDP settings 60 | backend = 'nccl' # 'nccl', 'gloo', etc. 61 | schedule = 'cosine' 62 | # system 63 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 64 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 65 | compile = True # use PyTorch 2.0 to compile the model to be faster 66 | scale_attn_by_inverse_layer_idx = True 67 | # ----------------------------------------------------------------------------- 68 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 69 | for arg in sys.argv[1:]: 70 | if '=' not in arg: 71 | # assume it's the name of a config file 72 | assert not arg.startswith('--') 73 | config_file = arg 74 | print(f"Overriding config with {config_file}:") 75 | with open(config_file) as f: 76 | print(f.read()) 77 | exec(open(config_file).read()) 78 | else: 79 | # assume it's a --key=value argument 80 | assert arg.startswith('--') 81 | key, val = arg.split('=') 82 | key = key[2:] 83 | if key in globals(): 84 | try: 85 | # attempt to eval it it (e.g. 
if bool, number, or etc) 86 | attempt = literal_eval(val) 87 | except (SyntaxError, ValueError): 88 | # if that goes wrong, just use the string 89 | attempt = val 90 | # ensure the types match ok 91 | assert type(attempt) == type(globals()[key]) 92 | # cross fingers 93 | print(f"Overriding: {key} = {attempt}") 94 | globals()[key] = attempt 95 | else: 96 | raise ValueError(f"Unknown config key: {key}") 97 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 98 | # ----------------------------------------------------------------------------- 99 | 100 | # various inits, derived attributes, I/O setup 101 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 102 | ddp_world_size = int(os.environ['WORLD_SIZE']) 103 | if ddp: 104 | init_process_group(backend=backend) 105 | ddp_rank = int(os.environ['RANK']) 106 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 107 | device = f'cuda:{ddp_local_rank}' 108 | torch.cuda.set_device(device) 109 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 110 | seed_offset = ddp_rank # each process gets a different seed 111 | else: 112 | # if not ddp, we are running on a single gpu, and one process 113 | master_process = True 114 | seed_offset = 0 115 | gradient_accumulation_steps *= 8 # simulate 8 gpus 116 | 117 | if master_process: 118 | os.makedirs(out_dir, exist_ok=True) 119 | torch.manual_seed(5000 + seed_offset) 120 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 121 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 122 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 123 | # note: float16 data type will automatically use a GradScaler 124 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 125 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype) 126 | 127 | # poor man's data loader 128 | data_dir = os.path.join(data_path, dataset) 129 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 130 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 131 | def get_batch(split): 132 | data = train_data if split == 'train' else val_data 133 | ix = torch.randint(len(data) - block_size, (batch_size,)) 134 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 135 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 136 | if device_type == 'cuda': 137 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 138 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 139 | else: 140 | x, y = x.to(device), y.to(device) 141 | return x, y 142 | 143 | # init these up here, can override if init_from='resume' (i.e. 
from a checkpoint) 144 | iter_num = 0 145 | best_val_loss = 1e9 146 | 147 | # attempt to derive vocab_size from the dataset 148 | meta_path = os.path.join(data_dir, 'meta.pkl') 149 | meta_vocab_size = None 150 | if os.path.exists(meta_path): 151 | with open(meta_path, 'rb') as f: 152 | meta = pickle.load(f) 153 | meta_vocab_size = meta['vocab_size'] 154 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 155 | 156 | # model init 157 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 158 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line 159 | if init_from == 'scratch': 160 | # init a new model from scratch 161 | print("Initializing a new model from scratch") 162 | # determine the vocab size we'll use for from-scratch training 163 | if meta_vocab_size is None: 164 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 165 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 166 | gptconf = GPTConfig(**model_args) 167 | model = GPT(gptconf) 168 | elif init_from == 'resume': 169 | print(f"Resuming training from {out_dir}") 170 | # resume training from a checkpoint. 171 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 172 | checkpoint = torch.load(ckpt_path, map_location=device) 173 | checkpoint_model_args = checkpoint['model_args'] 174 | # force these config attributes to be equal otherwise we can't even resume training 175 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 176 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 177 | model_args[k] = checkpoint_model_args[k] 178 | # create the model 179 | gptconf = GPTConfig(**model_args) 180 | model = GPT(gptconf) 181 | state_dict = checkpoint['model'] 182 | # fix the keys of the state dictionary :( 183 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 184 | unwanted_prefix = '_orig_mod.' 185 | for k,v in list(state_dict.items()): 186 | if k.startswith(unwanted_prefix): 187 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 188 | model.load_state_dict(state_dict) 189 | iter_num = checkpoint['iter_num'] 190 | best_val_loss = checkpoint['best_val_loss'] 191 | elif init_from.startswith('gpt2'): 192 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 193 | # initialize from OpenAI GPT-2 weights 194 | override_args = dict(dropout=dropout) 195 | model = GPT.from_pretrained(init_from, override_args) 196 | # read off the created config params, so we can store them into checkpoint correctly 197 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 198 | model_args[k] = getattr(model.config, k) 199 | # crop down the model block size if desired, using model surgery 200 | if block_size < model.config.block_size: 201 | model.crop_block_size(block_size) 202 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 203 | model.to(device) 204 | 205 | # initialize a GradScaler. 
If enabled=False scaler is a no-op 206 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 207 | 208 | # optimizer 209 | from optimizers.muon import Muon 210 | from optimizers.adamw import AdamW 211 | params = list(model.parameters()) 212 | from opt import CombinedOptimizer 213 | # optimizer1 = AdamW([p for p in params if p.ndim == 1], weight_decay=weight_decay, lr=learning_rate, betas=(beta1, beta2)) 214 | # optimizer2 = Muon([p for p in params if p.ndim == 2], lr=muon_learning_rate, rank=ddp_rank, world_size=ddp_world_size) 215 | # optimizers = [optimizer1, optimizer2] 216 | optimizer = CombinedOptimizer(params, [AdamW, Muon], [{'lr': learning_rate, 'betas': (beta1, beta2), 'weight_decay': weight_decay}, 217 | {'lr': muon_learning_rate, 'weight_decay': muon_weight_decay}]) 218 | if init_from == 'resume': 219 | # for optimizer in optimizers: 220 | optimizer.load_state_dict(checkpoint['optimizer']) 221 | del state_dict 222 | del checkpoint 223 | # compile the model 224 | if compile: 225 | print("compiling the model... (takes a ~minute)") 226 | unoptimized_model = model 227 | model = torch.compile(model) # requires PyTorch 2.0 228 | 229 | # wrap model into DDP container 230 | if ddp: 231 | model = DDP(model, device_ids=[ddp_local_rank]) 232 | 233 | # helps estimate an arbitrarily accurate loss over either split using many batches 234 | @torch.no_grad() 235 | def estimate_loss(): 236 | out = {} 237 | model.eval() 238 | for split in ['train', 'val']: 239 | losses = torch.zeros(eval_iters) 240 | for k in range(eval_iters): 241 | X, Y = get_batch(split) 242 | with ctx: 243 | logits, loss = model(X, Y) 244 | losses[k] = loss.item() 245 | out[split] = losses.mean() 246 | model.train() 247 | return out 248 | 249 | # learning rate decay scheduler (cosine with warmup) 250 | def get_lr(it, schedule='cosine', base_lr=learning_rate): 251 | # 1) linear warmup for warmup_iters steps 252 | if it < warmup_iters: 253 | return base_lr * it / warmup_iters 254 | elif it < max_iters - warmdown_iters: 255 | return base_lr 256 | else: 257 | decay_ratio = (max_iters - it) / warmdown_iters 258 | return base_lr * decay_ratio 259 | 260 | # logging 261 | if wandb_log and master_process: 262 | import wandb 263 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 264 | 265 | # training loop 266 | X, Y = get_batch('train') # fetch the very first batch 267 | t0 = time.time() 268 | local_iter_num = 0 # number of iterations in the lifetime of this process 269 | raw_model = model.module if ddp else model # unwrap DDP container if needed 270 | running_mfu = -1.0 271 | clip_time = 0 272 | while True: 273 | 274 | # determine and set the learning rate for this iteration 275 | 276 | # for optimizer in optimizers: 277 | for i in range(len(optimizer.optimizers)): 278 | lr = get_lr(iter_num, schedule=schedule, base_lr=optimizer.base_lrs[i]) 279 | for param_group in optimizer.optimizers[i].param_groups: 280 | param_group['lr'] = lr 281 | 282 | # evaluate the loss on train/val sets and write checkpoints 283 | if iter_num % eval_interval == 0 and master_process: 284 | losses = estimate_loss() 285 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 286 | if wandb_log: 287 | wandb.log({ 288 | "iter": iter_num, 289 | "train/loss": losses['train'], 290 | "val/loss": losses['val'], 291 | "lr": lr, 292 | "mfu": running_mfu*100, # convert to percentage 293 | }, step=iter_num) 294 | if losses['val'] < best_val_loss or always_save_checkpoint: 295 | best_val_loss = 
losses['val'] 296 | if iter_num > 0: 297 | checkpoint = { 298 | 'model': raw_model.state_dict(), 299 | 'optimizer': optimizer.state_dict(), 300 | 'model_args': model_args, 301 | 'iter_num': iter_num, 302 | 'best_val_loss': best_val_loss, 303 | 'config': config, 304 | } 305 | print(f"saving checkpoint to {out_dir}") 306 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 307 | # model.save_pretrained(os.path.join(out_dir, 'ckpt.pt')) 308 | if iter_num % (eval_interval * 5) == 0: 309 | checkpoint = { 310 | 'model': raw_model.state_dict(), 311 | 'optimizer': optimizer.state_dict(), 312 | 'model_args': model_args, 313 | 'iter_num': iter_num, 314 | 'best_val_loss': best_val_loss, 315 | 'config': config, 316 | } 317 | print(f"saving checkpoint to {out_dir}") 318 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt')) 319 | # model.save_pretrained(os.path.join(out_dir, 'ckpt.pt')) 320 | if iter_num == 0 and eval_only: 321 | break 322 | 323 | # forward backward update, with optional gradient accumulation to simulate larger batch size 324 | # and using the GradScaler if data type is float16 325 | for micro_step in range(gradient_accumulation_steps): 326 | if ddp: 327 | # in DDP training we only need to sync gradients at the last micro step. 328 | # the official way to do this is with model.no_sync() context manager, but 329 | # I really dislike that this bloats the code and forces us to repeat code 330 | # looking at the source of that context manager, it just toggles this variable 331 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 332 | with ctx: 333 | logits, loss = model(X, Y) 334 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 335 | X, Y = get_batch('train') 336 | # backward pass, with gradient scaling if training in fp16 337 | scaler.scale(loss).backward() 338 | # clip the gradient 339 | if grad_clip != 0.0: 340 | scaler.unscale_(optimizer.optimizers[0]) 341 | scaler.unscale_(optimizer.optimizers[1]) 342 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 343 | if total_norm.item() > grad_clip: 344 | clip_time += 1 345 | # step the optimizer and scaler if training in fp16 346 | scaler.step(optimizer) 347 | scaler.update() 348 | # flush the gradients as soon as we can, no need for this memory anymore 349 | optimizer.zero_grad(set_to_none=True) 350 | 351 | # timing and logging 352 | t1 = time.time() 353 | dt = t1 - t0 354 | t0 = t1 355 | if iter_num % log_interval == 0 and master_process: 356 | lossf = loss.item() # loss as float. 
note: this is a CPU-GPU sync point 357 | if local_iter_num >= 5: # let the training loop settle a bit 358 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 359 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 360 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 361 | params = [] 362 | for (name, p) in model.named_parameters(): 363 | params.append(p) 364 | total_param_norm = 0 365 | for p in params: 366 | param_norm = p.data.norm(2) 367 | total_param_norm += param_norm.item() ** 2 368 | total_param_norm = total_param_norm ** 0.5 369 | momentum_norm = 0 370 | momentum_norm_sq = 0 371 | momentum_div = 0 372 | LL = len(optimizer.optimizers[0].state_dict()['state']) 373 | for jj in range(LL): 374 | momentum_norm += (optimizer.optimizers[0].state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2 375 | momentum_norm_sq += (optimizer.optimizers[0].state_dict()['state'][jj]['exp_avg_sq'].detach().norm(2)) ** 2 376 | momentum_norm = torch.sqrt(momentum_norm).item() 377 | momentum_norm_sq = torch.sqrt(momentum_norm_sq).item() 378 | momentum_div = momentum_norm/(np.sqrt(momentum_norm_sq)+1e-8) 379 | if wandb_log: 380 | wandb.log({ 381 | "iter": iter_num, 382 | "train/loss": lossf, 383 | "lr": lr, 384 | "param_norm": total_param_norm, 385 | "momentum_norm" : momentum_norm, 386 | "momentum_norm_sq": momentum_norm_sq, 387 | "momentum_div": momentum_div, 388 | "train/clip_rate": clip_time / (iter_num + 1) 389 | }, step=iter_num) 390 | iter_num += 1 391 | local_iter_num += 1 392 | 393 | # termination conditions 394 | if iter_num > max_iters: 395 | break 396 | 397 | if ddp: 398 | destroy_process_group() 399 | -------------------------------------------------------------------------------- /MARS/utils/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. 
if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /MARS/utils/cv_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import torchvision 5 | from torchvision import datasets, transforms 6 | from torch.utils.data import DataLoader 7 | 8 | def get_model(args): 9 | """ 10 | models including: 11 | - VGG16 12 | - resnet18 13 | from https://github.com/iShohei220/adopt/blob/main/adopt.py and https://github.com/uclaml/Padam/blob/master/models/resnet.py 14 | """ 15 | if args.dataset in ['mnist', 'cifar10']: 16 | num_classes = 10 17 | elif args.dataset in ['cifar100']: 18 | num_classes = 100 19 | else: 20 | raise NotImplementedError(f"{args.dataset} is not implemented.") 21 | if args.net == 'simple_cnn': 22 | from .model_CNN import Network 23 | model_config = { 24 | "n_inputs": (3, 32, 32) if args.dataset == "cifar10" else (1, 28, 28), 25 | "conv_layers_list": [ 26 | {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 27 | {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 28 | {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 29 | ], 30 | "n_hiddens_list": [512], 31 | "n_outputs": 10, 32 | "dropout": 0.2, 33 | } 34 | model = Network(**model_config) 35 | elif args.net == 'resnet18': 36 | from .model_CNN import ResNet18 37 | model = ResNet18(num_classes = num_classes) 38 | else: 39 | try: 40 | model = torchvision.models.get_model(args.net, num_classes=num_classes) 41 | except: 42 | print('Model not found') 43 | raise NotImplementedError 44 | return model 45 | 46 | def get_datasets(dataset_name: str, train_batch_size: int, eval_batch_size: int): 47 | """Get train and test dataloaders.""" 48 | print('==> Preparing data..') 49 | if dataset_name == "mnist": 50 | transform = transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.1307,), (0.3081,)) 53 | ]) 54 | train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) 55 | test_dataset = datasets.MNIST('./data', train=False, transform=transform) 56 | elif dataset_name == "cifar10": 57 | transform_train = transforms.Compose([ 58 | transforms.RandomCrop(32, padding=4), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ToTensor(), 61 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 62 | ]) 63 | transform_test = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 66 | ]) 67 | train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train) 68 | test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test) 69 | elif dataset_name == "cifar100": 70 | transform_train = transforms.Compose([ 71 | transforms.RandomCrop(32, padding=4), 72 | transforms.RandomHorizontalFlip(), 73 | transforms.ToTensor(), 74 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]), 75 | ]) 76 | transform_test = transforms.Compose([ 77 | transforms.ToTensor(), 78 | 
transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]), 79 | ]) 80 | train_dataset = datasets.CIFAR100('./data', train=True, download=True, transform=transform_train) 81 | test_dataset = datasets.CIFAR100('./data', train=False, transform=transform_test) 82 | else: 83 | raise NotImplementedError(f"{dataset_name=} is not implemented.") 84 | 85 | train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=4) 86 | test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=4) 87 | 88 | return train_loader, test_loader 89 | 90 | 91 | class WarmupCosineScheduler: 92 | """Custom learning rate scheduler with linear warmup and cosine decay.""" 93 | def __init__(self, optimizer, warmup_iters: int, total_iters: int, min_lr=0.): 94 | self.optimizer = optimizer 95 | self.warmup_iters = warmup_iters 96 | self.total_iters = total_iters 97 | self.min_lr = min_lr 98 | self.max_lr_list = [] 99 | for param_group in self.optimizer.param_groups: 100 | self.max_lr_list.append(param_group['lr']) 101 | self.current_iter = 0 102 | self.lr_list = [] 103 | for param_group in self.optimizer.param_groups: 104 | self.lr_list.append(param_group['lr']) 105 | 106 | def step(self): 107 | self.current_iter += 1 108 | lr_list = [] 109 | cnt = 0 110 | for param_group in self.optimizer.param_groups: 111 | max_lr = self.max_lr_list[cnt] 112 | if self.current_iter <= self.warmup_iters: 113 | lr = self.current_iter / self.warmup_iters * max_lr 114 | else: 115 | lr = self.min_lr + 0.5 * (max_lr - self.min_lr) * ( 116 | np.cos((self.current_iter - self.warmup_iters) / (self.total_iters - self.warmup_iters) * 3.14159265 / 2) 117 | ).item() 118 | param_group['lr'] = lr 119 | cnt += 1 120 | lr_list.append(lr) 121 | self.lr_list = lr_list 122 | def get_lr(self): 123 | lr_list = [] 124 | for param_group in self.optimizer.param_groups: 125 | lr_list.append(param_group['lr']) 126 | return lr_list 127 | 128 | class ConstantScheduler: 129 | """Constant learning rate scheduler.""" 130 | def __init__(self, optimizer, lr: float): 131 | self.optimizer = optimizer 132 | lr_list = [] 133 | for param_group in self.optimizer.param_groups: 134 | lr_list.append(lr) 135 | 136 | def step(self): 137 | pass 138 | 139 | def get_lr(self): 140 | lr_list = [] 141 | for param_group in self.optimizer.param_groups: 142 | lr_list.append(param_group['lr']) 143 | return lr_list 144 | 145 | def get_scheduler(optimizer, args): 146 | if args.scheduler == 'multistep': 147 | from torch.optim.lr_scheduler import MultiStepLR 148 | scheduler = MultiStepLR(optimizer, milestones=[args.Nepoch // 2, (args.Nepoch * 3) // 4], gamma=0.1) 149 | elif args.scheduler == 'cosine': 150 | scheduler = WarmupCosineScheduler(optimizer, warmup_iters = args.Nepoch // 10, total_iters = args.Nepoch, 151 | min_lr = 0.) 
152 | elif args.scheduler == 'constant': 153 | scheduler = ConstantScheduler(optimizer, lr = args.lr) 154 | return scheduler -------------------------------------------------------------------------------- /MARS/utils/model_CNN.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Type, Union 2 | import importlib 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | def pair(t): 9 | return t if isinstance(t, tuple) else (t, t) 10 | 11 | 12 | def get_activation(activation_f: str) -> Type: 13 | """Get PyTorch activation function by name.""" 14 | package_name = "torch.nn" 15 | module = importlib.import_module(package_name) 16 | 17 | activations = [getattr(module, attr) for attr in dir(module)] 18 | activations = [ 19 | cls for cls in activations if isinstance(cls, type) and issubclass(cls, nn.Module) 20 | ] 21 | names = [cls.__name__.lower() for cls in activations] 22 | 23 | try: 24 | index = names.index(activation_f.lower()) 25 | return activations[index] 26 | except ValueError: 27 | raise NotImplementedError(f"get_activation: {activation_f=} is not yet implemented.") 28 | 29 | 30 | def compute_padding( 31 | input_size: tuple, kernel_size: int | tuple, stride: int | tuple = 1, dilation: int | tuple = 1 32 | ) -> Tuple[int, int]: 33 | """Compute padding for 'same' convolution.""" 34 | if len(input_size) == 2: 35 | input_size = (*input_size, 1) 36 | if isinstance(kernel_size, int): 37 | kernel_size = (kernel_size, kernel_size) 38 | if isinstance(stride, int): 39 | stride = (stride, stride) 40 | if isinstance(dilation, int): 41 | dilation = (dilation, dilation) 42 | 43 | input_h, input_w, _ = input_size 44 | kernel_h, kernel_w = kernel_size 45 | stride_h, stride_w = stride 46 | dilation_h, dilation_w = dilation 47 | 48 | # Compute the effective kernel size after dilation 49 | effective_kernel_h = (kernel_h - 1) * dilation_h + 1 50 | effective_kernel_w = (kernel_w - 1) * dilation_w + 1 51 | 52 | # Compute the padding needed for same convolution 53 | pad_h = int(max((input_h - 1) * stride_h + effective_kernel_h - input_h, 0)) 54 | pad_w = int(max((input_w - 1) * stride_w + effective_kernel_w - input_w, 0)) 55 | 56 | # Compute the padding for each side 57 | pad_top = pad_h // 2 58 | pad_left = pad_w // 2 59 | 60 | return (pad_top, pad_left) 61 | 62 | 63 | class Base(nn.Module): 64 | """Base class for neural network models.""" 65 | def __init__(self, **kwargs): 66 | super().__init__() 67 | self.__dict__.update(kwargs) 68 | 69 | @property 70 | def num_params(self): 71 | return sum(p.numel() for p in self.parameters()) 72 | 73 | @property 74 | def shapes(self): 75 | return {name: p.shape for name, p in self.named_parameters()} 76 | 77 | def summary(self): 78 | print(self) 79 | print(f"Number of parameters: {self.num_params}") 80 | 81 | 82 | class Network(Base): 83 | """Fully Connected / Convolutional Neural Network 84 | 85 | Args: 86 | n_inputs (Union[List[int], Tuple[int], torch.Size]): Input shape 87 | n_outputs (int): Number of output classes 88 | conv_layers_list (List[dict], optional): List of convolutional layers. Defaults to []. 89 | n_hiddens_list (Union[List, int], optional): List of hidden units. Defaults to 0. 90 | activation_f (str, optional): Activation function. Defaults to "ReLU". 91 | dropout (float, optional): Dropout rate. Defaults to 0.0. 
92 | 93 | conv_layers_list dict keys: 94 | filters: int 95 | kernel_size: int 96 | stride: int 97 | dilation: int 98 | padding: int 99 | bias: bool 100 | batch_norm: bool 101 | repeat: int 102 | """ 103 | def __init__( 104 | self, 105 | n_inputs: Union[List[int], Tuple[int], torch.Size], 106 | n_outputs: int, 107 | conv_layers_list: List[dict] = [], 108 | n_hiddens_list: Union[List, int] = 0, 109 | activation_f: str = "ReLU", 110 | dropout: float = 0.0, 111 | ): 112 | super().__init__() 113 | 114 | if isinstance(n_hiddens_list, int): 115 | n_hiddens_list = [n_hiddens_list] 116 | 117 | if n_hiddens_list == [] or n_hiddens_list == [0]: 118 | self.n_hidden_layers = 0 119 | else: 120 | self.n_hidden_layers = len(n_hiddens_list) 121 | 122 | activation = get_activation(activation_f) 123 | 124 | # Convert n_inputs to tensor for shape calculations 125 | ni = torch.tensor(n_inputs) 126 | 127 | conv_layers = [] 128 | if conv_layers_list: 129 | for conv_layer in conv_layers_list: 130 | n_channels = int(ni[0]) 131 | 132 | padding = conv_layer.get( 133 | "padding", 134 | compute_padding( # same padding 135 | tuple(ni.tolist()), 136 | conv_layer["kernel_size"], 137 | conv_layer.get("stride", 1), 138 | conv_layer.get("dilation", 1), 139 | ), 140 | ) 141 | 142 | # Add repeated conv blocks 143 | for i in range(conv_layer.get("repeat", 1)): 144 | # Convolutional layer 145 | conv_layers.append( 146 | nn.Conv2d( 147 | n_channels if i == 0 else conv_layer["filters"], 148 | conv_layer["filters"], 149 | conv_layer["kernel_size"], 150 | stride=conv_layer.get("stride", 1), 151 | padding=padding, 152 | dilation=conv_layer.get("dilation", 1), 153 | bias=conv_layer.get("bias", True), 154 | ) 155 | ) 156 | 157 | # Activation 158 | conv_layers.append(activation()) 159 | 160 | # Optional batch norm 161 | if conv_layer.get("batch_norm"): 162 | conv_layers.append(nn.BatchNorm2d(conv_layer["filters"])) 163 | 164 | # Max pooling after each conv block 165 | conv_layers.append(nn.MaxPool2d(2, stride=2)) 166 | 167 | # Optional dropout 168 | if dropout > 0: 169 | conv_layers.append(nn.Dropout(dropout)) 170 | 171 | # Update input shape for next layer 172 | ni = torch.cat([torch.tensor([conv_layer["filters"]]), ni[1:] // 2]) 173 | 174 | self.conv = nn.Sequential(*conv_layers) 175 | 176 | # Fully connected layers 177 | ni = int(torch.prod(ni)) 178 | fcn_layers = [] 179 | if self.n_hidden_layers > 0: 180 | for _, n_units in enumerate(n_hiddens_list): 181 | fcn_layers.extend([ 182 | nn.Linear(ni, n_units), 183 | activation() 184 | ]) 185 | if dropout > 0: 186 | fcn_layers.append(nn.Dropout(dropout)) 187 | ni = n_units 188 | 189 | self.fcn = nn.Sequential(*fcn_layers) 190 | self.output = nn.Linear(ni, n_outputs) 191 | 192 | def forward(self, x: torch.Tensor) -> torch.Tensor: 193 | x = self.conv(x) 194 | x = x.view(x.size(0), -1) 195 | x = self.fcn(x) 196 | return self.output(x) 197 | 198 | '''ResNet in PyTorch. 199 | 200 | For Pre-activation ResNet, see 'preact_resnet.py'. 201 | 202 | Reference: 203 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 204 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 205 | ''' 206 | 207 | 208 | 209 | class BasicBlock(nn.Module): 210 | expansion = 1 211 | 212 | def __init__(self, in_planes, planes, stride=1): 213 | super(BasicBlock, self).__init__() 214 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 215 | self.bn1 = nn.BatchNorm2d(planes) 216 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 217 | self.bn2 = nn.BatchNorm2d(planes) 218 | 219 | self.shortcut = nn.Sequential() 220 | if stride != 1 or in_planes != self.expansion*planes: 221 | self.shortcut = nn.Sequential( 222 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 223 | nn.BatchNorm2d(self.expansion*planes) 224 | ) 225 | 226 | def forward(self, x): 227 | out = F.relu(self.bn1(self.conv1(x))) 228 | out = self.bn2(self.conv2(out)) 229 | out += self.shortcut(x) 230 | out = F.relu(out) 231 | return out 232 | 233 | 234 | class Bottleneck(nn.Module): 235 | expansion = 4 236 | 237 | def __init__(self, in_planes, planes, stride=1): 238 | super(Bottleneck, self).__init__() 239 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 240 | self.bn1 = nn.BatchNorm2d(planes) 241 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 242 | self.bn2 = nn.BatchNorm2d(planes) 243 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 244 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 245 | 246 | self.shortcut = nn.Sequential() 247 | if stride != 1 or in_planes != self.expansion*planes: 248 | self.shortcut = nn.Sequential( 249 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 250 | nn.BatchNorm2d(self.expansion*planes) 251 | ) 252 | 253 | def forward(self, x): 254 | out = F.relu(self.bn1(self.conv1(x))) 255 | out = F.relu(self.bn2(self.conv2(out))) 256 | out = self.bn3(self.conv3(out)) 257 | out += self.shortcut(x) 258 | out = F.relu(out) 259 | return out 260 | 261 | 262 | class ResNet(nn.Module): 263 | def __init__(self, block, num_blocks, num_classes=10): 264 | super(ResNet, self).__init__() 265 | self.in_planes = 64 266 | 267 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 268 | self.bn1 = nn.BatchNorm2d(64) 269 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 270 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 271 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 272 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 273 | self.linear = nn.Linear(512*block.expansion, num_classes) 274 | 275 | def _make_layer(self, block, planes, num_blocks, stride): 276 | strides = [stride] + [1]*(num_blocks-1) 277 | layers = [] 278 | for stride in strides: 279 | layers.append(block(self.in_planes, planes, stride)) 280 | self.in_planes = planes * block.expansion 281 | return nn.Sequential(*layers) 282 | 283 | def forward(self, x): 284 | out = F.relu(self.bn1(self.conv1(x))) 285 | out = self.layer1(out) 286 | out = self.layer2(out) 287 | out = self.layer3(out) 288 | out = self.layer4(out) 289 | out = F.avg_pool2d(out, 4) 290 | out = out.view(out.size(0), -1) 291 | out = self.linear(out) 292 | return out 293 | 294 | 295 | def ResNet18(num_classes = 10): 296 | return ResNet(BasicBlock, [2,2,2,2], num_classes = num_classes) 297 | 298 | def ResNet34(num_classes = 10): 299 | return ResNet(BasicBlock, [3,4,6,3], num_classes = num_classes) 300 | 301 | def 
ResNet50(num_classes = 10): 302 | return ResNet(Bottleneck, [3,4,6,3], num_classes = num_classes) 303 | 304 | def ResNet101(num_classes = 10): 305 | return ResNet(Bottleneck, [3,4,23,3], num_classes = num_classes) 306 | 307 | def ResNet152(num_classes = 10): 308 | return ResNet(Bottleneck, [3,8,36,3], num_classes = num_classes) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MARS: Unleashing the Power of Variance Reduction for Training Large Models 2 | 3 | This repository contains the official code for the paper [MARS: Unleashing the Power of Variance Reduction for Training Large Models](https://arxiv.org/abs/2411.10438). 4 | 5 | Authors: [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Yifeng Liu](https://scholar.google.com/citations?user=mFvOVkMAAAAJ&hl=zh-CN)\*, Shuang Wu, Xun Zhou, [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) 6 | 7 | ## 🔔 NEWS 8 | - **[05/01/2025]** Our paper is accepted by **ICML 2025** 🎉🎉. 9 | - **[02/10/2025]** Our paper is updated on ArXiv: https://arxiv.org/pdf/2411.10438v2. 10 | - **[01/12/2025]** Update scripts for reproducing GPT-2 XL results and FineWeb-Edu results. 11 | - **[01/12/2025]** Our pretraining results on FineWeb-Edu are available. GPT-2 XL reaches a Hellaswag accuracy of 56.52 in 50B tokens. 12 | - **[11/26/2024]** Vision tasks added. 13 | - **[11/18/2024]** Our code is open-sourced! 14 | - **[11/15/2024]** Our paper is released on arXiv: https://arxiv.org/abs/2411.10438. 15 | 16 | ## About MARS 17 | 18 | **MARS** (**M**ake v**A**riance **R**eduction **S**hine) is a unified optimization framework designed to address the inherent challenges of training large models. Traditional adaptive gradient methods like Adam and AdamW often suffer from high stochastic gradient variance, while variance reduction techniques have struggled to gain practical impact in deep learning. At its core, **MARS** comprises two major components: (1) a scaled stochastic recursive momentum, which provides a variance-reduced estimator of the full gradient for better gradient complexity; and (2) the preconditioned update, which approximates the second-order Newton's method for better per-iteration complexity. By combining preconditioned gradient methods with variance reduction, **MARS** achieves the best of both worlds, accelerating the search for critical points in optimization. 19 | 20 | The **MARS** framework is built on the following preconditioned variance-reduced updates 21 | 22 | $$ 23 | \mathbf{c}\_t = \nabla f(\mathbf{x}\_t, \mathbf{\xi}\_t)+\underbrace{{\color{red}\gamma_t} \frac{\beta_{1}}{1-\beta_{1}} \left(\nabla f(\mathbf{x}\_t, \mathbf{\xi}\_t)-\nabla f(\mathbf{x}\_{t-1}, \mathbf{\xi}\_t)\right)}_{\text{scaled gradient correction}} 24 | $$ 25 | 26 | $$ 27 | \tilde{\mathbf{c}}_t = \text{Clip}(\mathbf{c}_t,1) = \begin{cases} 28 | \frac{\mathbf{c}_t}{\\|\mathbf{c}_t\\|_2} & \text{if } \\|\mathbf{c}_t\\|_2 > 1,\\ 29 | \mathbf{c}_t & \text{otherwise}. 
30 | \end{cases} 31 | $$ 32 | 33 | $$ 34 | \mathbf{m}\_t = \beta_1 \mathbf{m}\_{t-1} + (1-\beta_{1})\tilde{\mathbf{c}}\_t 35 | $$ 36 | 37 | $$ 38 | \mathbf{x}\_{t+1} = \arg\min_{\mathbf{x} \in \mathbb{R}^d} \left\\{\eta_t \left\langle \mathbf{m}_t, \mathbf{x} \right\rangle + \frac{1}{2} \\|\mathbf{x} - \mathbf{x}\_t 39 | \\|\_{\mathbf{H}_t}^2\right\\} 40 | $$ 41 | 42 | Here ${\color{red}\gamma_t}$ is a scaling parameter that controls the strength of gradient correction. 43 | 44 | ### Instantiations of **MARS** 45 | 46 | Under the **MARS** framework, we provide three instantiations based on different Hessian matrix approximations: **MARS-AdamW**, **MARS-Lion**, and **MARS-Shampoo**. Please note that the hyperparameters in this framework are tuned on **MARS-AdamW**. When using other instantiations, it is essential to tune the hyperparameters—particularly the learning rates—for optimal performance. 47 | 48 | #### MARS-AdamW 49 | 50 | (Enable with `mars_type="mars-adamw"` in `mars.py`) 51 | 52 | The Hessian matrix approximation is defined as: 53 | 54 | $$ 55 | \mathbf{v}\_t =\beta_2 \mathbf{v}\_{t-1}+(1-\beta_2) \big(\nabla f(\mathbf{x}\_t, \mathbf{\xi}\_t)\big)^2 56 | $$ 57 | 58 | $$ 59 | \mathbf{H}_t := \sqrt{\text{diag}\Big(\mathbf{v}_t\Big)}\cdot \frac{1 - \beta_1^t}{\sqrt{1 - \beta_2^t}}. 60 | $$ 61 | 62 | #### MARS-Lion 63 | 64 | (Enable with `mars_type="mars-lion"` in `mars.py`) 65 | 66 | The Hessian matrix approximation is defined as: 67 | 68 | $$ 69 | \mathbf{H}_t := \sqrt{\text{diag}(\mathbf{m}_t^2)}. 70 | $$ 71 | 72 | #### MARS-Shampoo 73 | 74 | (Enable with `mars_type="mars-shampoo"` in `mars.py`) 75 | 76 | The preconditioner can be seen as an [orthogonal mapping](https://arxiv.org/abs/2409.20325) operator: 77 | 78 | $$ 79 | \mathbf{U}\_t, \mathbf{\Sigma}\_t, \mathbf{V}\_t = \text{SVD}(\mathbf{G}\_t),\qquad 80 | \mathbf{x}\_{t+1} =\mathbf{x}\_t-\eta_t\mathbf{U}_t\mathbf{V}\_t^\top. 81 | $$ 82 | 83 | In practice, we use the [Newton-Schulz iteration](https://github.com/KellerJordan/modded-nanogpt) to accelerate and approximate the solution of SVD problem. 84 | 85 | ### **Performance of MARS Compared to Baselines** 86 | 87 | #### Experiments on OpenWebText 88 | 89 | Experimental results for **MARS** are based on the **MARS-AdamW** instantiation, unless otherwise stated. In our experiments, gradients are calculated once per sample and per update (**MARS**-approx in our [paper](https://arxiv.org/abs/2411.10438)). Performing exact gradient computation with two evaluations per update, as in the exact form of **MARS**, can slightly enhance performance but at the cost of doubling the computational expense. For more details, refer to our [paper](https://arxiv.org/abs/2411.10438). 
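For intuition, here is a toy sketch of the difference between the two variants. The quadratic objective and the helper names (`mars_correction`, `grad_f`) are illustrative assumptions, not the repo's `mars.py` API; the only point is where the gradient at the previous iterate comes from.

```python
import torch

def mars_correction(grad_curr, grad_prev, gamma=0.025, beta1=0.95):
    # c_t = grad_t + gamma * beta1/(1-beta1) * (grad_t - grad_prev), then clip to unit norm
    c_t = grad_curr + gamma * beta1 / (1.0 - beta1) * (grad_curr - grad_prev)
    norm = c_t.norm()
    return c_t / norm if norm > 1.0 else c_t

# Toy objective f(x; xi) = 0.5 * ||x - xi||^2, so grad_f(x, xi) = x - xi.
def grad_f(x, xi):
    return x - xi

torch.manual_seed(0)
x_prev = torch.randn(4)                  # x_{t-1}
x_curr = x_prev - 0.1 * torch.randn(4)   # x_t
xi_prev, xi_curr = torch.randn(4), torch.randn(4)  # previous / current mini-batch

g_curr = grad_f(x_curr, xi_curr)

# MARS-approx: one gradient evaluation per update; reuse the gradient already
# computed at x_{t-1} on the *previous* batch xi_{t-1}.
c_approx = mars_correction(g_curr, grad_f(x_prev, xi_prev))

# MARS-exact: a second evaluation; recompute the gradient at x_{t-1} on the
# *current* batch xi_t, matching the formula for c_t above.
c_exact = mars_correction(g_curr, grad_f(x_prev, xi_curr))

print("approx:", c_approx)
print("exact: ", c_exact)
```

In the released training scripts this choice corresponds to the `is_approx` flag (and `scheme == 'exact'` in `MARS/train_mars.py`), where the exact variant runs an extra backward pass over the current batch and calls `optimizer.update_previous_grad()` before each update.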
90 | 91 | **MARS** consistently outperforms AdamW and the [Muon]([https://github.com/KellerJordan/modded-nanogpt](https://github.com/KellerJordan/modded-nanogpt/tree/e01b457c7c52e1cd0c592920499a016f5289a69e)) optimizers across GPT-2 models: 92 | 93 | | **GPT-2 small** | **GPT-2 medium** | **GPT-2 large** | 94 | | ------------------------------------------------ | ------------------------------------------------- | ------------------------------------------------ | 95 | | | | | 96 | 97 | | Best Val Loss | GPT-2 Small (5B tokens) | GPT-2 Medium (5B tokens) | GPT-2 Large (5B tokens) | GPT-2 Small (20B tokens) | GPT-2 Medium (20B tokens) | GPT-2 Large (20B tokens) | GPT-2 Small (50B tokens) | GPT-2 Medium (50B tokens) | GPT-2 Large (50B tokens) | 98 | | --------------------- | ----------------------- | ------------------------ | ----------------------- | ------------------------ | ------------------------- | ------------------------ | ------------------------ | ------------------------- | ------------------------ | 99 | | AdamW | 3.193 | 3.084 | 3.013 | 3.024 | 2.821 | 2.741 | 2.885 | 2.691 | 2.561 | 100 | | Muon | 3.165 | 3.009 | 2.915 | 3.006 | 2.813 | 2.691 | 2.901 | 2.688 | 2.573 | 101 | | **MARS**-exact | **3.107** | - | - | 2.980 | - | - | **2.847** | - | - | 102 | | **MARS**-approx | 3.108 | **2.969** | **2.876** | **2.981** | **2.763** | **2.647** | **2.849** | **2.636** | **2.518** | 103 | 104 | 105 | #### Efficiency of MARS 106 | 107 | The **MARS** algorithm can achieve better performance not only within the same number of training steps, but also within the same training time: 108 | 109 | | **GPT-2 small** | **GPT-2 medium** | **GPT-2 large** | 110 | | ------------------------------------------------- | -------------------------------------------------- | ------------------------------------------------- | 111 | | | | | 112 | 113 | --- 114 | 115 | #### Experiments on FineWeb-Edu 116 | 117 | Below are the training and validation loss curves for both GPT‑2 Small and GPT‑2 XL when using our MARS approach versus AdamW. As you can see, MARS often yields faster convergence and consistently lower losses across different training steps. 118 | 119 | | Model | **GPT-2 small** | **GPT-2 XL** | 120 | | ----------------------- | -------------------------------------------------------- | --------------------------------------------------------- | 121 | | **Train Loss** | | | 122 | | **Validation Loss** | | | 123 | 124 | ##### Evaluation Metrics 125 | Below, we present the evaluation metrics on the FineWeb-Edu dataset for both GPT‑2 Small and GPT‑2 XL, comparing OpenAI GPT2 baseline, AdamW, and our MARS-AdamW optimizer. 126 | 127 | 128 | 129 | **Results on GPT-2 small** 130 | 131 | MARS-AdamW shows a clear improvement over AdamW and the OpenAI baseline across multiple tasks, with the **highest average score** of 45.93 on GPT‑2 Small. 132 | | Method/Task | ARC-E | ARC-C | BoolQ | HellaSwag | OBQA | PIQA | WG | MMLU | SciQ | Avg. | 133 | |--------------|-------|-------|-------|-----------|-------|-------|-------|-------|-------|-------| 134 | | OpenAI-Comm. 
| 39.48 | 22.70 | 48.72 | 31.14 | 27.20 | 62.51 | **51.62** | 22.92 | 64.40 | 41.19 | 135 | | AdamW | 51.43 | 26.54 | 55.78 | 36.26 | 30.60 | 64.53 | 50.36 | **24.49** | **71.50** | 45.72 | 136 | | MARS-AdamW | **52.23** | **27.39** | **55.84** | **36.91** | **32.20** | **64.80** | 49.96 | 22.95 | 71.10 | **45.93** | 137 | 138 | **Results on GPT-2 XL** 139 | 140 | On GPT‑2 XL, MARS-AdamW continues to outperform AdamW across most tasks, delivering an impressive **HellaSwag accuracy of 56.52**. 141 | 142 | | Method/Task | ARC-E | ARC-C | BoolQ | HellaSwag | OBQA | PIQA | WG | MMLU | SciQ | Avg. | 143 | |--------------|-------|-------|-------|-----------|-------|-------|-------|-------|-------|-------| 144 | | OpenAI-Comm. | 51.05 | 28.50 | 61.77 | 50.89 | 32.00 | 70.51 | **58.33** | 25.24 | 76.00 | 50.48 | 145 | | AdamW | **68.22** | 38.40 | 61.13 | 53.93 | 39.00 | 72.69 | 54.78 | **25.47** | 85.30 | 55.43 | 146 | | MARS-AdamW | 66.54 | **39.85** | **63.82** | **56.52** | **41.20** | **73.34** | 56.59 | 23.86 | **86.00** | **56.41** | 147 | 148 | --- 149 | 150 | #### Experiments on Vision Tasks 151 | 152 | **MARS** can achieve better test loss and accuracy than AdamW and the [Muon]([https://github.com/KellerJordan/modded-nanogpt](https://github.com/KellerJordan/modded-nanogpt/tree/e01b457c7c52e1cd0c592920499a016f5289a69e)) optimizers on CIFAR-10 and CIFAR-100 datasets with ResNet-18 and MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1) scheduler (We display the best results for each optimizer with grid search of base learning rate within [1e-5, ..., 1e-1]): 153 | 154 | | Dataset | **CIFAR-10** | **CIFAR-100** | 155 | | ----------------------- | -------------------------------------------------------- | --------------------------------------------------------- | 156 | | **Test loss** | | | 157 | | **Test Accuracy** | | | 158 | 159 | | Best Test loss | CIFAR-10 | CIFAR-100 | 160 | | --------------------- | ---------- | ---------- | 161 | | AdamW | 0.306 | 2.608 | 162 | | Muon | 0.230 | 1.726 | 163 | | **MARS**-approx | **0.199** | **0.971** | 164 | 165 | | Best Test Accuracy (%) | CIFAR-10 | CIFAR-100 | 166 | | ---------------------- | --------------- | --------------- | 167 | | AdamW | 94.81 | 73.7 | 168 | | Muon | 95.08 | 74.64 | 169 | | **MARS**-approx | **95.29** | **76.97** | 170 | 171 | 172 | ## Training GPT-2 from Scratch: 173 | 174 | ### Install Dependencies 175 | 176 | ``` 177 | $ pip install torch==2.1.2 transformers==4.33.0 datasets tiktoken numpy==1.26.4 wandb 178 | ``` 179 | 180 | ### Data Preparation 181 | 182 | Prepare the [OpenWebText](https://huggingface.co/datasets/openwebtext) data following [nanoGPT](https://github.com/karpathy/nanoGPT/): 183 | 184 | ``` 185 | $ python data/openwebtext/prepare.py 186 | ``` 187 | 188 | ### **Start Training** 189 | 190 | To train a model using the **MARS** optimizer, run the following command: 191 | 192 | ```bash 193 | $ torchrun --standalone --nproc_per_node=8 MARS/train_mars.py config/${your_config_file} 194 | ``` 195 | 196 | This command initiates the training of a GPT-2 model on the OpenWebText dataset using the **MARS** optimizer. All relevant hyperparameters—training, model, and optimizer—are specified in the configuration file (`${your_config_file}`). These parameters can be adjusted directly in the configuration file or through the bash script. 
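As a reference for how the `--key=value` overrides behave, below is a minimal standalone sketch of the parsing rule implemented in `MARS/utils/configurator.py`; the `defaults` dictionary and the `apply_overrides` helper are illustrative names, not part of the repo. Values are parsed with `ast.literal_eval`, unknown keys are rejected, and an override must keep the type of the existing default.

```python
# Simplified illustration of the configurator's --key=value handling (assumed names).
from ast import literal_eval

defaults = {"batch_size": 12, "learning_rate": 6e-4, "wandb_log": False, "schedule": "cosine"}

def apply_overrides(argv, config):
    for arg in argv:
        assert arg.startswith("--") and "=" in arg, f"expected --key=value, got {arg}"
        key, val = arg[2:].split("=", 1)
        if key not in config:
            raise ValueError(f"Unknown config key: {key}")
        try:
            parsed = literal_eval(val)   # bools, ints, floats, tuples, ...
        except (SyntaxError, ValueError):
            parsed = val                 # fall back to a plain string
        assert type(parsed) == type(config[key]), f"type mismatch for {key}"
        config[key] = parsed
    return config

print(apply_overrides(["--batch_size=15", "--wandb_log=True"], dict(defaults)))
# {'batch_size': 15, 'learning_rate': 0.0006, 'wandb_log': True, 'schedule': 'cosine'}
```

This is why the reproduction commands below can pass flags such as `--batch_size=15 --gradient_accumulation_steps=4` after the config file: the config file sets the defaults, and the flags override them with type-checked values.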
197 |
198 | ### **Hyperparameter Details**
199 |
200 | #### **Model Hyperparameters**:
201 |
202 | - **n_layer**: Number of transformer layers; 12 for GPT-2 Small, 24 for GPT-2 Medium, 36 for GPT-2 Large
203 | - **n_head**: Number of attention heads; 12 for GPT-2 Small, 16 for GPT-2 Medium, 20 for GPT-2 Large
204 | - **n_embd**: Embedding dimension; 768 for GPT-2 Small, 1024 for GPT-2 Medium, 1280 for GPT-2 Large
205 |
206 | #### **Optimizer Hyperparameters**:
207 |
208 | - **`learning_rate`**: Learning rate for the **MARS** optimizer.
209 | - **`weight_decay`**: Weight decay for the **MARS** optimizer.
210 | - **`beta1, beta2`**: Exponential moving average coefficients.
211 |   - Default: `beta1=0.95, beta2=0.99`
212 | - **`mars_type`**: Which **MARS** variant to use.
213 |   - Options: `mars-adamw`, `mars-lion`, `mars-shampoo`
214 |   - Default: `mars-adamw`
215 | - **`optimize_1d`**: Whether **MARS** should optimize 1D parameters (e.g., layer norm parameters in GPT-2).
216 |   - If `False`, AdamW is used for the 1D parameters instead.
217 |   - Default: `False`
218 | - **`lr_1d`**: Learning rate for AdamW when **`optimize_1d`** is set to `False`.
219 | - **`betas_1d`**: Exponential moving average coefficients for that AdamW optimizer.
220 |   - Default: `(0.9, 0.95)`
221 | - **`is_approx`**: Whether to use the approximate gradient correction (**MARS**-approx) instead of the exact one (**MARS**-exact).
222 |   - Default: `True`
223 | - **`gamma`**: Scaling parameter that controls the strength of the gradient correction.
224 |   - Default: `0.025`
225 |
226 | #### **Training Hyperparameters**:
227 |
228 | - **`batch_size`**: Mini-batch size per device (for example, GPT-2 Small on an A100 GPU typically uses a batch size of 15).
229 | - **`gradient_accumulation_steps`**: Gradient accumulation steps, chosen so that the total effective batch size matches the desired scale (for example, a total batch size of 480 corresponds to $15 \times 4 \times 8$, i.e., per-device batch size $\times$ accumulation steps $\times$ number of GPUs).
230 | - **`schedule`**: Learning rate schedule.
231 |   - Default: `cosine`
232 |
233 | For more detailed hyperparameter examples, refer to:
234 |
235 | - `config/train_gpt2_small_mars.py`
236 | - `scripts/run_mars_small.sh`
237 |
238 | ---
239 |
240 | ### Reproducing Our Results
241 |
242 | #### **Reproducing GPT-2 Small (125M) Results**
243 |
244 | Train with MARS by running
245 |
246 | ```
247 | $ bash scripts/run_mars_small.sh
248 | ```
249 |
250 | or
251 |
252 | ```
253 | $ torchrun --standalone --nproc_per_node=8 \
254 | MARS/train_mars.py \
255 | config/train_gpt2_small_mars.py \
256 | --batch_size=15 \
257 | --gradient_accumulation_steps=4
258 | ```
259 |
260 | #### **Reproducing GPT-2 Medium (355M) Results**
261 |
262 | Train with MARS by running
263 |
264 | ```
265 | $ bash scripts/run_mars_medium.sh
266 | ```
267 |
268 | or
269 |
270 | ```
271 | $ torchrun --standalone --nproc_per_node=8 \
272 | MARS/train_mars.py \
273 | config/train_gpt2_medium_mars.py \
274 | --batch_size=15 \
275 | --gradient_accumulation_steps=4
276 | ```
277 |
278 | #### **Reproducing GPT-2 Large (770M) Results**
279 |
280 | Train with MARS by running
281 |
282 | ```
283 | $ bash scripts/run_mars_large.sh
284 | ```
285 |
286 | or
287 |
288 | ```
289 | $ torchrun --standalone --nproc_per_node=8 \
290 | MARS/train_mars.py \
291 | config/train_gpt2_large_mars.py \
292 | --batch_size=5 \
293 | --gradient_accumulation_steps=12
294 | ```
295 |
296 | #### **Reproducing GPT-2 XL (1.5B) Results on FineWeb-Edu**
297 | ```
298 | $ bash scripts/run_mars_xl_fw.sh
299 | ```
300 |
301 | or
302 |
303 | ```
304 | $ torchrun --standalone --nproc_per_node=8 \
305 | MARS/train_mars_fw.py \
306 | config/train_gpt2_xl_mars.py \
307 | --batch_size=5 \
308 | --gradient_accumulation_steps=12
309 | ```
310 |
311 | #### Reproducing Baseline Results
312 |
313 | To reproduce the AdamW baseline:
314 |
315 | ```
316 | bash scripts/run_adamw_{small/medium/large}.sh
317 | ```
318 | To reproduce the AdamW baseline on FineWeb-Edu:
319 | ```
320 | bash scripts/run_adamw_{small/xl}_fw.sh
321 | ```
322 |
323 | To reproduce the Muon baseline following [modded-nanogpt](https://github.com/KellerJordan/modded-nanogpt/tree/e01b457c7c52e1cd0c592920499a016f5289a69e):
324 |
325 | ```
326 | bash scripts/run_muon_{small/medium/large}.sh
327 | ```
328 |
329 | Please adjust `nproc_per_node`, `batch_size`, and `gradient_accumulation_steps` accordingly if you use a different hardware setup; make sure their product (number of GPUs × per-device batch size × accumulation steps) still equals 480.
330 |
331 | #### Hyperparameters for GPT-2 models
332 |
333 | | Model Name | Model Size | lr for AdamW | lr for Muon | lr for MARS | lr_1d for MARS | wd for AdamW | wd for Muon | wd for MARS |
334 | | :----------: | :--------: | :----------: | :---------: | :---------: | :------------: | :----------: | :---------: | :---------: |
335 | | GPT-2 small | 125M | 6e-4 | 2e-2 | 6e-3 | 3e-3 | 1e-1 | 0.0 | 1e-2 |
336 | | GPT-2 medium | 355M | 3e-4 | 1e-2 | 3e-3 | 1.5e-3 | 1e-1 | 0.0 | 1e-2 |
337 | | GPT-2 large | 770M | 2e-4 | 6.67e-3 | 2e-3 | 1e-3 | 1e-1 | 0.0 | 1e-2 |
338 | | GPT-2 xl | 1.5B | 2e-4 | - | 2e-3 | 1e-3 | 1e-1 | - | 1e-2 |
339 |
340 |
341 |
342 | ### Customized Training
343 |
344 | To build your own training pipeline on other architectures and datasets, use the following template as an example:
345 |
346 | ```python
347 | import torch
348 | import torch.nn.functional as F
349 | from mars import MARS
350 |
351 | # init the model, loss function, and input data
352 | model = Model()
353 | data_loader = ...
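# --- Added for completeness (illustrative assumptions, not part of the original template) ---
# `block_size` is the sequence length used below to turn the batch count into a token count,
# and `epochs` is the number of passes over `data_loader`; set both to match your own setup.
block_size = 1024
epochs = 1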
354 |
355 | # init the optimizer
356 | optimizer = MARS(model.parameters(), lr=1e-3, betas=(0.9, 0.95), gamma=0.025)
357 |
358 | total_bs = len(data_loader)   # number of mini-batches per epoch
359 | bs = total_bs * block_size    # value passed to MARS's step() as `bs` below
360 | k = 10                        # unused in this minimal template
361 | iter_num = -1                 # global step counter
362 |
363 | # training loop
364 | for epoch in range(epochs):
365 |     for X, Y in data_loader:
366 |         # standard forward/backward pass
367 |         logits, loss = model(X, Y)
368 |         loss.backward()
369 |         optimizer.step(bs=bs)
370 |         optimizer.zero_grad(set_to_none=True)
371 |         optimizer.update_last_grad()   # update the optimizer's stored last-gradient state, required by MARS after each step
372 |         iter_num += 1
373 |
374 | ```
375 |
376 | ## Star History
377 |
378 | [![Star History Chart](https://api.star-history.com/svg?repos=AGI-Arena/MARS&type=Date)](https://www.star-history.com/#AGI-Arena/MARS&Date)
379 |
380 | ## Citation
381 |
382 | If you find this repo useful for your research, please consider citing the paper:
383 |
384 | ```tex
385 | @article{yuan2024mars,
386 |   title={MARS: Unleashing the Power of Variance Reduction for Training Large Models},
387 |   author={Yuan, Huizhuo and Liu, Yifeng and Wu, Shuang and Zhou, Xun and Gu, Quanquan},
388 |   journal={arXiv preprint arXiv:2411.10438},
389 |   year={2024}
390 | }
391 | ```
392 |
393 | ## Acknowledgements
394 |
395 | This repo is built upon [nanoGPT](https://github.com/karpathy/nanoGPT/), [levanter](https://github.com/stanford-crfm/levanter/), and [Sophia](https://github.com/Liuhong99/Sophia); we thank the authors for their great work!
396 |
--------------------------------------------------------------------------------
/assets/MARS-AdamW.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/MARS-AdamW.png
--------------------------------------------------------------------------------
/assets/MARS-Lion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/MARS-Lion.png
--------------------------------------------------------------------------------
/assets/MARS-Shampoo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/MARS-Shampoo.png
--------------------------------------------------------------------------------
/assets/MARS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/MARS.png
--------------------------------------------------------------------------------
/assets/ShampooH.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/ShampooH.png
--------------------------------------------------------------------------------
/assets/cifar100_test_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/cifar100_test_acc.png
--------------------------------------------------------------------------------
/assets/cifar100_test_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/cifar100_test_loss.png
-------------------------------------------------------------------------------- /assets/cifar10_test_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/cifar10_test_acc.png -------------------------------------------------------------------------------- /assets/cifar10_test_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/cifar10_test_loss.png -------------------------------------------------------------------------------- /assets/fineweb_hella.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/fineweb_hella.png -------------------------------------------------------------------------------- /assets/small_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/small_train.png -------------------------------------------------------------------------------- /assets/small_val.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/small_val.png -------------------------------------------------------------------------------- /assets/time_large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/time_large.png -------------------------------------------------------------------------------- /assets/time_medium.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/time_medium.png -------------------------------------------------------------------------------- /assets/time_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/time_small.png -------------------------------------------------------------------------------- /assets/val_large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/val_large.png -------------------------------------------------------------------------------- /assets/val_medium.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/val_medium.png -------------------------------------------------------------------------------- /assets/val_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/val_small.jpg -------------------------------------------------------------------------------- /assets/val_small.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/val_small.png -------------------------------------------------------------------------------- /assets/xl_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/xl_train.png -------------------------------------------------------------------------------- /assets/xl_val.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/xl_val.png -------------------------------------------------------------------------------- /config/train_gpt2_large_adamw.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-large-adamw-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 36 10 | n_head = 20 11 | n_embd = 1280 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be 300B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'adamw' 27 | learning_rate = 2e-4 # max learning rate 28 | weight_decay = 1e-1 29 | beta1 = 0.9 30 | beta2 = 0.95 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 1e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_large_adamw_100k' 40 | -------------------------------------------------------------------------------- /config/train_gpt2_large_mars.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-large-mars-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 36 10 | n_head = 20 11 | n_embd = 1280 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be 300B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'mars' 27 | learning_rate = 2e-3 # max learning rate 28 | weight_decay = 1e-2 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | lr_1d=1e-3 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 1e-5 37 | 38 | compile = True 39 | 40 | out_dir = 'out_large_mars_100k' 41 | gamma=0.025 42 | -------------------------------------------------------------------------------- /config/train_gpt2_large_muon.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-large-muon-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | 
gradient_accumulation_steps = 12 8 | 9 | n_layer = 36 10 | n_head = 20 11 | n_embd = 1280 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be 300B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'muon' 27 | learning_rate = 1e-3 # max learning rate 28 | weight_decay = 1e-1 29 | muon_learning_rate = 6.67e-3 30 | muon_weight_decay = 0. 31 | beta1 = 0.9 32 | beta2 = 0.95 33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 34 | # learning rate decay settings 35 | decay_lr = True # whether to decay the learning rate 36 | warmup_iters = 2000 # how many steps to warm up for 37 | min_lr = 1e-5 38 | 39 | compile = True 40 | 41 | out_dir = 'out_large_muon_100k' 42 | -------------------------------------------------------------------------------- /config/train_gpt2_medium_adamw.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-medium-adamw-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 24 10 | n_head = 16 11 | n_embd = 1024 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be 300B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'adamw' 27 | learning_rate = 3e-4 # max learning rate 28 | weight_decay = 1e-1 29 | beta1 = 0.9 30 | beta2 = 0.95 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 6e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_medium_adamw_100k' 40 | -------------------------------------------------------------------------------- /config/train_gpt2_medium_mars.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-medium-mars-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 24 10 | n_head = 16 11 | n_embd = 1024 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be 300B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'mars' 27 | learning_rate = 3e-3 # max learning rate 28 | weight_decay = 1e-2 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | lr_1d=1.5e-3 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 6e-5 37 | 38 | compile = True 39 | 40 | out_dir = 'out_medium_mars_100k' 41 | gamma=0.025 42 | -------------------------------------------------------------------------------- 
/config/train_gpt2_medium_muon.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-medium-muon-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 24 10 | n_head = 16 11 | n_embd = 1024 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be 300B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'muon' 27 | learning_rate = 1.5e-3 # max learning rate 28 | weight_decay = 1e-1 29 | muon_learning_rate = 1e-2 30 | muon_weight_decay = 0. 31 | beta1 = 0.9 32 | beta2 = 0.95 33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 34 | # learning rate decay settings 35 | decay_lr = True # whether to decay the learning rate 36 | warmup_iters = 2000 # how many steps to warm up for 37 | min_lr = 6e-5 38 | 39 | compile = True 40 | 41 | out_dir = 'out_medium_muon_100k' 42 | -------------------------------------------------------------------------------- /config/train_gpt2_small_adamw.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-small-adamw-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 12 10 | n_head = 12 11 | n_embd = 768 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 100000 17 | lr_decay_iters = 100000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # optimizer 25 | optimizer_name = 'adamw' 26 | learning_rate = 6e-4 # max learning rate 27 | weight_decay = 1e-1 28 | beta1 = 0.9 29 | beta2 = 0.95 30 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 31 | # learning rate decay settings 32 | decay_lr = True # whether to decay the learning rate 33 | warmup_iters = 2000 # how many steps to warm up for 34 | min_lr = 3e-5 35 | 36 | compile = True 37 | 38 | out_dir = 'out_small_adamw_100k' 39 | -------------------------------------------------------------------------------- /config/train_gpt2_small_mars.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-small-mars-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 12 10 | n_head = 12 11 | n_embd = 768 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 100000 17 | lr_decay_iters = 100000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # optimizer 25 | optimizer_name = 'mars' 26 | learning_rate = 6e-3 # max learning rate 27 | weight_decay = 1e-2 28 | beta1 = 0.95 29 | beta2 = 0.99 30 | lr_1d=3e-3 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 3e-5 36 | 37 | 
compile = True 38 | 39 | out_dir = 'out_small_mars_100k' 40 | gamma=0.025 41 | -------------------------------------------------------------------------------- /config/train_gpt2_small_muon.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-small-muon-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 12 10 | n_head = 12 11 | n_embd = 768 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 100000 17 | lr_decay_iters = 100000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # optimizer 25 | optimizer_name = 'muon' 26 | learning_rate = 3e-3 # max learning rate, original=6e-4 27 | weight_decay = 1e-1 28 | muon_learning_rate = 2e-2 29 | muon_weight_decay = 0. 30 | beta1 = 0.9 31 | beta2 = 0.95 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 3e-5 37 | schedule = 'cosine' 38 | compile = True 39 | 40 | out_dir = 'out_small_muon_100k' 41 | -------------------------------------------------------------------------------- /config/train_gpt2_xl_adamw.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-xl-adamw-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 48 10 | n_head = 25 11 | n_embd = 1600 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be 300B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'adamw' 27 | learning_rate = 2e-4 # max learning rate 28 | weight_decay = 1e-1 29 | beta1 = 0.9 30 | beta2 = 0.95 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 1e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_large_adamw_100k' 40 | -------------------------------------------------------------------------------- /config/train_gpt2_xl_mars.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-xl-mars-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 48 10 | n_head = 25 11 | n_embd = 1600 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be 300B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'mars' 27 | learning_rate = 2e-3 # max learning rate 28 | weight_decay = 1e-2 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | lr_1d=1e-3 32 | grad_clip = 1.0 # clip 
gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 1e-5 37 | 38 | compile = True 39 | 40 | out_dir = 'out_large_mars_100k' 41 | gamma=0.025 42 | -------------------------------------------------------------------------------- /data/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 52 13 | 14 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 15 | dataset = load_dataset("openwebtext", cache_dir="nanoGPT/cache") 16 | 17 | # owt by default only contains the 'train' split, so create a test split 18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 19 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 20 | 21 | # this results in: 22 | # >>> split_dataset 23 | # DatasetDict({ 24 | # train: Dataset({ 25 | # features: ['text'], 26 | # num_rows: 8009762 27 | # }) 28 | # val: Dataset({ 29 | # features: ['text'], 30 | # num_rows: 4007 31 | # }) 32 | # }) 33 | 34 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 35 | enc = tiktoken.get_encoding("gpt2") 36 | def process(example): 37 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 38 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 39 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 40 | out = {'ids': ids, 'len': len(ids)} 41 | return out 42 | 43 | # tokenize the dataset 44 | tokenized = split_dataset.map( 45 | process, 46 | remove_columns=['text'], 47 | desc="tokenizing the splits", 48 | num_proc=num_proc, 49 | ) 50 | print('tokenization finished') 51 | # concatenate all the ids in each dataset into one large file we can use for training 52 | for split, dset in tokenized.items(): 53 | arr_len = np.sum(dset['len']) 54 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 55 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 56 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 57 | 58 | print(f"writing {filename}...") 59 | idx = 0 60 | for example in tqdm(dset): 61 | arr[idx : idx + example['len']] = example['ids'] 62 | idx += example['len'] 63 | arr.flush() 64 | 65 | # train.bin is ~17GB, val.bin ~8.5MB 66 | # train has ~9B tokens (9,035,582,198) 67 | # val has ~4M tokens (4,434,897) 68 | 69 | # to read the bin files later, e.g. 
with numpy:
70 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r')
71 |
--------------------------------------------------------------------------------
/scripts/run_CNN.sh:
--------------------------------------------------------------------------------
1 | python MARS/train_CNN.py
--------------------------------------------------------------------------------
/scripts/run_CV.sh:
--------------------------------------------------------------------------------
1 | python MARS/train_CV.py
--------------------------------------------------------------------------------
/scripts/run_adamw_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_large_adamw.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/scripts/run_adamw_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_medium_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_adamw_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_small_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_adamw_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw_fw.py \
3 | config/train_gpt2_small_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/scripts/run_adamw_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw_fw.py \
3 | config/train_gpt2_large_adamw.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/scripts/run_mars_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars.py \
3 | config/train_gpt2_large_mars.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/scripts/run_mars_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars.py \
3 | config/train_gpt2_medium_mars.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_mars_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars.py \
3 | config/train_gpt2_small_mars.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_mars_small_fw.sh:
-------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 \ 2 | MARS/train_mars_fw.py \ 3 | config/train_gpt2_small_mars.py \ 4 | --batch_size=15 \ 5 | --gradient_accumulation_steps=4 6 | -------------------------------------------------------------------------------- /scripts/run_mars_xl_fw.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 \ 2 | MARS/train_mars_fw.py \ 3 | config/train_gpt2_xl_mars.py \ 4 | --batch_size=5 \ 5 | --gradient_accumulation_steps=12 6 | -------------------------------------------------------------------------------- /scripts/run_muon_large.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 \ 2 | MARS/train_muon.py \ 3 | config/train_gpt2_large_muon.py \ 4 | --batch_size=5 \ 5 | --gradient_accumulation_steps=12 -------------------------------------------------------------------------------- /scripts/run_muon_medium.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 \ 2 | MARS/train_muon.py \ 3 | config/train_gpt2_medium_muon.py \ 4 | --batch_size=15 \ 5 | --gradient_accumulation_steps=4 -------------------------------------------------------------------------------- /scripts/run_muon_small.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 \ 2 | MARS/train_muon.py \ 3 | config/train_gpt2_small_muon.py \ 4 | --batch_size=15 \ 5 | --gradient_accumulation_steps=4 --------------------------------------------------------------------------------