├── LICENSE
├── MARS
│   ├── model.py
│   ├── opt.py
│   ├── optimizers
│   │   ├── adamw.py
│   │   ├── adopt.py
│   │   ├── mars.py
│   │   └── muon.py
│   ├── train_CNN.py
│   ├── train_CV.py
│   ├── train_adamw.py
│   ├── train_adamw_fw.py
│   ├── train_mars.py
│   ├── train_mars_fw.py
│   ├── train_muon.py
│   └── utils
│       ├── configurator.py
│       ├── cv_utils.py
│       └── model_CNN.py
├── README.md
├── assets
│   ├── MARS-AdamW.png
│   ├── MARS-Lion.png
│   ├── MARS-Shampoo.png
│   ├── MARS.png
│   ├── ShampooH.png
│   ├── cifar100_test_acc.png
│   ├── cifar100_test_loss.png
│   ├── cifar10_test_acc.png
│   ├── cifar10_test_loss.png
│   ├── fineweb_hella.png
│   ├── small_train.png
│   ├── small_val.png
│   ├── time_large.png
│   ├── time_medium.png
│   ├── time_small.png
│   ├── val_large.png
│   ├── val_medium.png
│   ├── val_small.jpg
│   ├── val_small.png
│   ├── xl_train.png
│   └── xl_val.png
├── config
│   ├── train_gpt2_large_adamw.py
│   ├── train_gpt2_large_mars.py
│   ├── train_gpt2_large_muon.py
│   ├── train_gpt2_medium_adamw.py
│   ├── train_gpt2_medium_mars.py
│   ├── train_gpt2_medium_muon.py
│   ├── train_gpt2_small_adamw.py
│   ├── train_gpt2_small_mars.py
│   ├── train_gpt2_small_muon.py
│   ├── train_gpt2_xl_adamw.py
│   └── train_gpt2_xl_mars.py
├── data
│   └── openwebtext
│       └── prepare.py
└── scripts
    ├── run_CNN.sh
    ├── run_CV.sh
    ├── run_adamw_large.sh
    ├── run_adamw_medium.sh
    ├── run_adamw_small.sh
    ├── run_adamw_small_fw.sh
    ├── run_adamw_xl_fw.sh
    ├── run_mars_large.sh
    ├── run_mars_medium.sh
    ├── run_mars_small.sh
    ├── run_mars_small_fw.sh
    ├── run_mars_xl_fw.sh
    ├── run_muon_large.sh
    ├── run_muon_medium.sh
    └── run_muon_small.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MARS/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from https://github.com/Liuhong99/Sophia/blob/main/model.py
3 | """
4 |
5 | import math
6 | import inspect
7 | from dataclasses import dataclass
8 | from optimizers.adamw import AdamW
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import functional as F
12 | from optimizers.mars import MARS
13 |
14 | optimizer_dict = {'adamw': torch.optim.AdamW,
15 | 'adamw_ours': AdamW,
16 | 'mars': MARS,
17 | }
18 |
19 | # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
20 | def new_gelu(x):
21 | """
22 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
23 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
24 | """
25 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
26 |
27 | class LayerNorm(nn.Module):
28 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
29 |
30 | def __init__(self, ndim, bias):
31 | super().__init__()
32 | self.weight = nn.Parameter(torch.ones(ndim))
33 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
34 |
35 | def forward(self, input):
36 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
37 |
38 | class CausalSelfAttention(nn.Module):
39 |
40 | def __init__(self, config, idx_layer):
41 | super().__init__()
42 | assert config.n_embd % config.n_head == 0
43 | # key, query, value projections for all heads, but in a batch
44 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
45 | # output projection
46 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
47 | # regularization
48 | self.attn_dropout = nn.Dropout(config.dropout)
49 | self.resid_dropout = nn.Dropout(config.dropout)
50 | self.n_head = config.n_head
51 | self.n_embd = config.n_embd
52 | self.dropout = config.dropout
53 | self.idx_layer = idx_layer
54 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
55 |
56 | # causal mask to ensure that attention is only applied to the left in the input sequence
57 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
58 | .view(1, 1, config.block_size, config.block_size))
59 |
60 | def forward(self, x):
61 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
62 |
63 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
64 |         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
65 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
66 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
67 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
68 |
69 | if self.scale_attn_by_inverse_layer_idx:
70 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)) / float(self.idx_layer + 1))
71 | else:
72 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
73 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
74 | att = F.softmax(att, dim=-1)
75 | att = self.attn_dropout(att)
76 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
77 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
78 |
79 | # output projection
80 | y = self.resid_dropout(self.c_proj(y))
81 | return y
82 |
83 | class MLP(nn.Module):
84 |
85 | def __init__(self, config):
86 | super().__init__()
87 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
88 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
89 | self.dropout = nn.Dropout(config.dropout)
90 |
91 | def forward(self, x):
92 | x = self.c_fc(x)
93 | x = new_gelu(x)
94 | x = self.c_proj(x)
95 | x = self.dropout(x)
96 | return x
97 |
98 | class Block(nn.Module):
99 |
100 | def __init__(self, config, idx_layer):
101 | super().__init__()
102 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
103 | self.attn = CausalSelfAttention(config, idx_layer)
104 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
105 | self.mlp = MLP(config)
106 |
107 | def forward(self, x):
108 | x = x + self.attn(self.ln_1(x))
109 | x = x + self.mlp(self.ln_2(x))
110 | return x
111 |
112 | @dataclass
113 | class GPTConfig:
114 | block_size: int = 1024
115 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency, 50304
116 | n_layer: int = 12
117 | n_head: int = 12
118 | n_embd: int = 768
119 | dropout: float = 0.0
120 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
121 | scale_attn_by_inverse_layer_idx: bool = False
122 |
123 |
124 | class GPT(nn.Module):
125 |
126 | def __init__(self, config):
127 | super().__init__()
128 | assert config.vocab_size is not None
129 | assert config.block_size is not None
130 | self.config = config
131 |
132 | self.transformer = nn.ModuleDict(dict(
133 | wte = nn.Embedding(config.vocab_size, config.n_embd),
134 | wpe = nn.Embedding(config.block_size, config.n_embd),
135 | drop = nn.Dropout(config.dropout),
136 | h = nn.ModuleList([Block(config, idx_layer) for idx_layer in range(config.n_layer)]),
137 | ln_f = LayerNorm(config.n_embd, bias=config.bias),
138 | ))
139 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
140 | # with weight tying when using torch.compile() some warnings get generated:
141 | # "UserWarning: functional_call was passed multiple values for tied weights.
142 | # This behavior is deprecated and will be an error in future versions"
143 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
144 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
145 |
146 | # init all weights
147 | self.apply(self._init_weights)
148 | # apply special scaled init to the residual projections, per GPT-2 paper
149 | for pn, p in self.named_parameters():
150 | if pn.endswith('c_proj.weight'):
151 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
152 |
153 | # report number of parameters
154 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
155 |
156 | def get_num_params(self, non_embedding=True):
157 | """
158 | Return the number of parameters in the model.
159 | For non-embedding count (default), the position embeddings get subtracted.
160 | The token embeddings would too, except due to the parameter sharing these
161 | params are actually used as weights in the final layer, so we include them.
162 | """
163 | n_params = sum(p.numel() for p in self.parameters())
164 | if non_embedding:
165 | n_params -= self.transformer.wpe.weight.numel()
166 | return n_params
167 |
168 | def _init_weights(self, module):
169 | if isinstance(module, nn.Linear):
170 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
171 | if module.bias is not None:
172 | torch.nn.init.zeros_(module.bias)
173 | elif isinstance(module, nn.Embedding):
174 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
175 |
176 | def forward(self, idx, targets=None):
177 | device = idx.device
178 | b, t = idx.size()
179 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
180 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
181 |
182 | # forward the GPT model itself
183 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
184 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
185 | x = self.transformer.drop(tok_emb + pos_emb)
186 | for block in self.transformer.h:
187 | x = block(x)
188 | x = self.transformer.ln_f(x)
189 |
190 | if targets is not None:
191 | # if we are given some desired targets also calculate the loss
192 | if not isinstance(targets, int):
193 | logits = self.lm_head(x)
194 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
195 | else:
196 | logits = self.lm_head(x)
197 | loss = None
198 | else:
199 | # inference-time mini-optimization: only forward the lm_head on the very last position
200 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
201 | loss = None
202 |
203 | return logits, loss
204 |
205 | def crop_block_size(self, block_size):
206 | # model surgery to decrease the block size if necessary
207 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
208 | # but want to use a smaller block size for some smaller, simpler model
209 | assert block_size <= self.config.block_size
210 | self.config.block_size = block_size
211 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
212 | for block in self.transformer.h:
213 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
214 |
215 | @classmethod
216 | def from_pretrained(cls, model_type, override_args=None):
217 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
218 | override_args = override_args or {} # default to empty dict
219 | # only dropout can be overridden see more notes below
220 | assert all(k == 'dropout' for k in override_args)
221 | from transformers import GPT2LMHeadModel
222 | print("loading weights from pretrained gpt: %s" % model_type)
223 |
224 | # n_layer, n_head and n_embd are determined from model_type
225 | config_args = {
226 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
227 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
228 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
229 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
230 | }[model_type]
231 | print("forcing vocab_size=50257, block_size=1024, bias=True")
232 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
233 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
234 | config_args['bias'] = True # always True for GPT model checkpoints
235 | # we can override the dropout rate, if desired
236 | if 'dropout' in override_args:
237 | print(f"overriding dropout rate to {override_args['dropout']}")
238 | config_args['dropout'] = override_args['dropout']
239 | # create a from-scratch initialized minGPT model
240 | config = GPTConfig(**config_args)
241 | model = GPT(config)
242 | sd = model.state_dict()
243 | sd_keys = sd.keys()
244 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
245 |
246 | # init a huggingface/transformers model
247 | model_hf = GPT2LMHeadModel.from_pretrained(model_type)
248 | sd_hf = model_hf.state_dict()
249 |
250 | # copy while ensuring all of the parameters are aligned and match in names and shapes
251 | sd_keys_hf = sd_hf.keys()
252 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
253 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
254 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
255 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
256 | # this means that we have to transpose these weights when we import them
257 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
258 | for k in sd_keys_hf:
259 | if any(k.endswith(w) for w in transposed):
260 | # special treatment for the Conv1D weights we need to transpose
261 | assert sd_hf[k].shape[::-1] == sd[k].shape
262 | with torch.no_grad():
263 | sd[k].copy_(sd_hf[k].t())
264 | else:
265 | # vanilla copy over the other parameters
266 | assert sd_hf[k].shape == sd[k].shape
267 | with torch.no_grad():
268 | sd[k].copy_(sd_hf[k])
269 |
270 | return model
271 |
272 | def configure_optimizers(self, optimizer_name, weight_decay, learning_rate, betas, device_type,
273 | other_para_config=None):
274 | """
275 | This long function is unfortunately doing something very simple and is being very defensive:
276 | We are separating out all parameters of the model into two buckets: those that will experience
277 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
278 | We are then returning the PyTorch optimizer object.
279 | """
280 |
281 | # separate out all parameters to those that will and won't experience regularizing weight decay
282 | decay = set()
283 | no_decay = set()
284 | whitelist_weight_modules = (torch.nn.Linear, )
285 | blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
286 | for mn, m in self.named_modules():
287 | for pn, p in m.named_parameters():
288 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
289 | # random note: because named_modules and named_parameters are recursive
290 | # we will see the same tensors p many many times. but doing it this way
291 | # allows us to know which parent module any tensor p belongs to...
292 | if pn.endswith('bias'):
293 | # all biases will not be decayed
294 | no_decay.add(fpn)
295 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
296 | # weights of whitelist modules will be weight decayed
297 | decay.add(fpn)
298 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
299 | # weights of blacklist modules will NOT be weight decayed
300 | no_decay.add(fpn)
301 |
302 | # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
303 | # will appear in the no_decay and decay sets respectively after the above.
304 | # In addition, because named_parameters() doesn't return duplicates, it
305 | # will only return the first occurrence, key'd by 'transformer.wte.weight', below.
306 | # so let's manually remove 'lm_head.weight' from decay set. This will include
307 | # this tensor into optimization via transformer.wte.weight only, and not decayed.
308 | decay.remove('lm_head.weight')
309 |
310 | # validate that we considered every parameter
311 | param_dict = {pn: p for pn, p in self.named_parameters()}
312 | inter_params = decay & no_decay
313 | union_params = decay | no_decay
314 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
315 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
316 | % (str(param_dict.keys() - union_params), )
317 |
318 | # create the pytorch optimizer object
319 | optim_groups = [
320 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
321 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
322 | ]
323 |
324 | opt_func = optimizer_dict[optimizer_name]
325 | if optimizer_name == 'adamw':
326 | # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
327 | use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
328 | print(f"using fused AdamW: {use_fused}")
329 | extra_args = dict(fused=True) if use_fused else dict()
330 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas, **extra_args)
331 | elif optimizer_name == 'adamw_ours':
332 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas)
333 | elif optimizer_name == 'mars':
334 | if other_para_config is None:
335 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas)
336 | else:
337 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas, **other_para_config)
338 | else:
339 | raise ValueError('Invalid optimizer.')
340 | return optimizer
341 |
342 | def estimate_mfu(self, fwdbwd_per_iter, dt):
343 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
344 | # first estimate the number of flops we do per iteration.
345 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
346 | N = self.get_num_params()
347 | cfg = self.config
348 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
349 | flops_per_token = 6*N + 12*L*H*Q*T
350 | flops_per_fwdbwd = flops_per_token * T
351 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
352 | # express our flops throughput as ratio of A100 bfloat16 peak flops
353 | flops_achieved = flops_per_iter * (1.0/dt) # per second
354 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
355 | mfu = flops_achieved / flops_promised
356 | return mfu
357 |
358 | @torch.no_grad()
359 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
360 | """
361 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
362 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
363 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
364 | """
365 | for _ in range(max_new_tokens):
366 | # if the sequence context is growing too long we must crop it at block_size
367 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
368 | # forward the model to get the logits for the index in the sequence
369 | logits, _ = self(idx_cond)
370 | # pluck the logits at the final step and scale by desired temperature
371 | logits = logits[:, -1, :] / temperature
372 | # optionally crop the logits to only the top k options
373 | if top_k is not None:
374 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
375 | logits[logits < v[:, [-1]]] = -float('Inf')
376 | # apply softmax to convert logits to (normalized) probabilities
377 | probs = F.softmax(logits, dim=-1)
378 | # sample from the distribution
379 | idx_next = torch.multinomial(probs, num_samples=1)
380 | # append sampled index to the running sequence and continue
381 | idx = torch.cat((idx, idx_next), dim=1)
382 |
383 | return idx
384 |
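
A minimal usage sketch of the classes above (not part of the repository source): it builds a small GPTConfig/GPT and requests a MARS optimizer through configure_optimizers, which splits parameters into decay/no-decay groups before constructing the optimizer named in optimizer_dict. The model size, hyperparameters, and the import path are illustrative assumptions.

    # Sketch only; sizes and hyperparameters are assumptions, not repository training defaults.
    import torch
    from model import GPT, GPTConfig  # assumes the working directory is MARS/

    config = GPTConfig(block_size=256, vocab_size=50304, n_layer=4, n_head=4, n_embd=256)
    model = GPT(config)
    optimizer = model.configure_optimizers(
        optimizer_name='mars', weight_decay=0.1, learning_rate=3e-3,
        betas=(0.95, 0.99), device_type='cuda' if torch.cuda.is_available() else 'cpu')

    idx = torch.randint(0, config.vocab_size, (2, 256))      # dummy token ids, shape (B, T)
    targets = torch.randint(0, config.vocab_size, (2, 256))  # next-token targets
    logits, loss = model(idx, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)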
--------------------------------------------------------------------------------
/MARS/opt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import collections
3 | """
4 | Adapted from askerlee@github: https://github.com/KellerJordan/modded-nanogpt/issues/9
5 | """
6 | def separate_params(param_groups):
7 | param_groups_2d = []
8 | param_groups_non2d = []
9 | total_param_2d_count = 0
10 | total_param_non2d_count = 0
11 |
12 |
13 | # Convert iterators to lists
14 | if isinstance(param_groups, collections.abc.Iterable):
15 | param_groups = list(param_groups)
16 |
17 | # Check if param_groups is a list of dicts or list of params
18 | if (isinstance(param_groups, list) and isinstance(param_groups[0], dict)) \
19 | or isinstance(param_groups, dict):
20 | if isinstance(param_groups, dict):
21 | param_groups = [param_groups]
22 | # param_groups is a list of dicts
23 | for group in param_groups:
24 | params_2d, params_non2d, param_2d_count, param_non2d_count = separate_params(group['params'])
25 | param_group_2d = {'params': params_2d}
26 | param_group_non2d = {'params': params_non2d}
27 | # Copy the group dict and replace the 'params' key with the separated params
28 | for k in group.keys():
29 | if k != 'params':
30 | param_group_2d[k] = group[k]
31 | param_group_non2d[k] = group[k]
32 |
33 | param_groups_2d.append(param_group_2d)
34 | param_groups_non2d.append(param_group_non2d)
35 | total_param_2d_count += param_2d_count
36 | total_param_non2d_count += param_non2d_count
37 |
38 | return param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count
39 |
40 | elif isinstance(param_groups, list) and isinstance(param_groups[0], torch.Tensor):
41 | params_2d = []
42 | params_non2d = []
43 | param_group = param_groups
44 | # param_group is a list of param tensors
45 | for param in param_group:
46 | if param.ndim >= 2:
47 | params_2d.append(param)
48 | else:
49 | params_non2d.append(param)
50 | return params_2d, params_non2d, len(params_2d), len(params_non2d)
51 | else:
52 |         raise TypeError(f'Unsupported type for param_groups: {type(param_groups)}')
53 |
54 | '''
55 | # CombinedOptimizer is now a torch.optim.Optimizer, compatible with pytorch lightning.
56 | # Original Example:
57 | optimizer = CombinedOptimizer([
58 | torch.optim.AdamW(self.lm_head.parameters(), lr=learning_rate, betas=betas, weight_decay=0, fused=True),
59 | OrthogonalNesterov(self.transformer.h.parameters(), lr=0.1*learning_rate, momentum=0.95)
60 | ])
61 | # Refactored Example:
62 | optimizer = CombinedOptimizer(\
63 | self.parameters(),
64 | [OrthogonalNesterov, torch.optim.AdamW],
65 | [{'lr': 0.1*learning_rate, 'momentum': 0.95},
66 | {'lr': learning_rate, 'betas': betas, 'weight_decay': 0, 'fused': True}
67 | ])
68 | '''
69 |
70 | class CombinedOptimizer(torch.optim.Optimizer):
71 | def __init__(self, params, optimizer_types, configs, raw_model = False):
72 | # Separate 2D and non-2D parameters.
73 | # If params is a list of tensors, then each of param_groups_2d and param_groups_non2d
74 | # will be a list of tensors.
75 | # If params is a list of dicts, then each of param_groups_2d and param_groups_non2d
76 | # will be a list of dicts.
77 | # If params is a dict, then each of param_groups_2d and param_groups_non2d will
78 | # be a list of dicts containing only one dict.
79 | if raw_model:
80 | params_others = list(params.transformer.h.parameters())
81 | param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count \
82 | = separate_params(params_others)
83 | param_groups_non2d.extend(list(params.lm_head.parameters()))
84 | total_param_non2d_count += 2
85 | else:
86 | param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count \
87 | = separate_params(params)
88 | param_groups_2d_non2d = (param_groups_non2d, param_groups_2d)
89 | print(f"Total 2D params: {total_param_2d_count}, Total non-2D params: {total_param_non2d_count}")
90 |
91 | assert len(optimizer_types) == len(configs) == 2
92 | self.optimizers = [ optimizer_types[i](param_groups_2d_non2d[i], **configs[i]) for i in range(2) ]
93 | self.param_groups = [pg for opt in self.optimizers for pg in opt.param_groups]
94 | self.base_lrs = [opt.param_groups[0]['lr'] for opt in self.optimizers]
95 | # Combine the state dicts of all opt in self.optimizers into a single dict
96 | self.state = {k: v for opt in self.optimizers for k, v in opt.state.items()}
97 | # Initially all states are empty. So no point to print their counts.
98 | # Only use the defaults of the OrthogonalNesterov optimizer
99 | self.defaults = self.optimizers[0].defaults
100 |
101 | def step(self, *args, **kwargs):
102 | for opt in self.optimizers:
103 | opt.step(*args, **kwargs)
104 |
105 | def zero_grad(self, **kwargs):
106 | for opt in self.optimizers:
107 | opt.zero_grad(**kwargs)
108 |
109 | def scale_lrs(self, lr_scale):
110 | for base_lr, opt in zip(self.base_lrs, self.optimizers):
111 | opt.param_groups[0]['lr'] = base_lr * lr_scale
112 |
113 | def state_dict(self):
114 | return [opt.state_dict() for opt in self.optimizers]
--------------------------------------------------------------------------------
/MARS/optimizers/adamw.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 | # from megatron.optimizer.l2_norm import l2_norm
5 |
6 | def exists(val):
7 | return val is not None
8 |
9 |
10 | class AdamW(Optimizer):
11 | """Implements Adam algorithm.
12 |
13 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
14 |
15 | Arguments:
16 | params (iterable): iterable of parameters to optimize or dicts defining
17 | parameter groups
18 | lr (float, optional): learning rate (default: 1e-3)
19 | betas (Tuple[float, float], optional): coefficients used for computing
20 | running averages of gradient and its square (default: (0.9, 0.999))
21 | eps (float, optional): term added to the denominator to improve
22 | numerical stability (default: 1e-8)
23 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
25 | algorithm from the paper `On the Convergence of Adam and Beyond`_
26 |
27 | .. _Adam\: A Method for Stochastic Optimization:
28 | https://arxiv.org/abs/1412.6980
29 | .. _On the Convergence of Adam and Beyond:
30 | https://openreview.net/forum?id=ryQu7f-RZ
31 | """
32 |
33 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
34 | weight_decay=0, amsgrad=False):
35 | if not 0.0 <= lr:
36 | raise ValueError("Invalid learning rate: {}".format(lr))
37 | if not 0.0 <= eps:
38 | raise ValueError("Invalid epsilon value: {}".format(eps))
39 | if not 0.0 <= betas[0] < 1.0:
40 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
41 | if not 0.0 <= betas[1] < 1.0:
42 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
43 | defaults = dict(lr=lr, betas=betas, eps=eps,
44 | weight_decay=weight_decay, amsgrad=amsgrad)
45 | super(AdamW, self).__init__(params, defaults)
46 | self.eps = eps
47 |
48 | def __setstate__(self, state):
49 | super(AdamW, self).__setstate__(state)
50 | for group in self.param_groups:
51 | group.setdefault('amsgrad', False)
52 |
53 | @torch.no_grad()
54 | def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None):
55 | """Performs a single optimization step.
56 |
57 | Arguments:
58 | closure (callable, optional): A closure that reevaluates the model
59 | and returns the loss.
60 | """
61 | if any(p is not None for p in [grads, output_params, scale, grad_norms]):
62 | raise RuntimeError('FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.')
63 |
64 | loss = None
65 | if exists(closure):
66 | with torch.enable_grad():
67 | loss = closure()
68 | real_update = 0
69 | real_update_wo_lr = 0
70 |
71 | for group in self.param_groups:
72 | for p in filter(lambda p: exists(p.grad), group['params']):
73 | if p.grad is None:
74 | continue
75 | grad = p.grad.data
76 | if grad.is_sparse:
77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
78 | amsgrad = group['amsgrad']
79 |
80 | state = self.state[p]
81 | #print('----- starting a parameter state', state.keys(), 'Length of state', len(state))
82 | # State initialization
83 | if len(state) == 0:
84 | state['step'] = 0
85 | # Exponential moving average of gradient values
86 | state['exp_avg'] = torch.zeros_like(p.data)
87 | # Exponential moving average of squared gradient values
88 | state['exp_avg_sq'] = torch.zeros_like(p.data)
89 | if amsgrad:
90 | # Maintains max of all exp. moving avg. of sq. grad. values
91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
92 |
93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
94 | if amsgrad:
95 | max_exp_avg_sq = state['max_exp_avg_sq']
96 | beta1, beta2 = group['betas']
97 |
98 | if 'step' in state:
99 | state['step'] += 1
100 | else:
101 | state['step'] = 1
102 |
103 | # Decay the first and second moment running average coefficient
104 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
105 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
106 | if amsgrad:
107 | # Maintains the maximum of all 2nd moment running avg. till now
108 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
109 | # Use the max. for normalizing running avg. of gradient
110 | denom = max_exp_avg_sq.sqrt().add_(self.eps)
111 | else:
112 | denom = exp_avg_sq.sqrt().add_(self.eps)
113 |
114 | bias_correction1 = 1 - beta1 ** state['step']
115 | bias_correction2 = 1 - beta2 ** state['step']
116 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
117 |
118 | # p.data.addcdiv_(-step_size, exp_avg, denom)
119 |                 real_update_tmp = -step_size * torch.mul(p.data, group['weight_decay']).addcdiv_(exp_avg, denom, value=1)
120 |                 real_update_wo_lr_tmp = torch.mul(p.data, group['weight_decay']).addcdiv_(exp_avg, denom, value=1)
121 |
122 | p.data.add_(real_update_tmp)
123 | return loss
124 |
125 |
126 |
127 |
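
A short usage sketch of the AdamW implementation above (not part of the repository); the toy model and hyperparameters are assumptions. In this implementation the step applies p <- p - step_size * (weight_decay * p + exp_avg / denom), i.e. the weight-decay term is folded into the same scaled update as the bias-corrected Adam direction.

    # Sketch only; model and hyperparameters are assumptions.
    import torch
    import torch.nn as nn
    from optimizers.adamw import AdamW  # assumes the working directory is MARS/

    model = nn.Linear(8, 1)
    opt = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.1)

    x, y = torch.randn(32, 8), torch.randn(32, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    opt.zero_grad()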
--------------------------------------------------------------------------------
/MARS/optimizers/mars.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2 | # SPDX-License-Identifier: Apache-2.0
3 | import math
4 | import torch
5 | from torch.optim.optimizer import Optimizer
6 | import os
7 | import numpy as np
8 | import math
9 | # from megatron.optimizer.l2_norm import l2_norm
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def update_fn(p, grad, exp_avg, exp_avg_sq, lr, wd, beta1, beta2, last_grad, eps, amsgrad, max_exp_avg_sq, step, gamma,
16 | mars_type, is_grad_2d, optimize_1d, lr_1d_factor, betas_1d, weight_decay_1d):
17 |     # optimize_1d: if True, apply the MARS update to 1-D parameters as well; otherwise fall back to an AdamW update for them
18 | if optimize_1d or is_grad_2d:
19 | c_t = (grad - last_grad).mul(gamma * (beta1 / (1. - beta1))).add(grad)
20 | c_t_norm = torch.norm(c_t)
21 | if c_t_norm > 1.:
22 | c_t = c_t / c_t_norm
23 | exp_avg.mul_(beta1).add_(c_t, alpha=1. - beta1)
24 | if (mars_type == "mars-adamw") or (mars_type == "mars-shampoo" and not is_grad_2d):
25 | exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
26 | bias_correction1 = 1.0 - beta1 ** step
27 | bias_correction2 = 1.0 - beta2 ** step
28 | if amsgrad:
29 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
30 | denom = max_exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
31 | else:
32 | denom = exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
33 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.div(denom))
34 | elif mars_type == "mars-lion":
35 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.sign())
36 | elif mars_type == "mars-shampoo" and is_grad_2d:
37 | factor = max(1, grad.size(0)/grad.size(1))**0.5
38 |                 real_update_tmp = NewtonSchulz(exp_avg.mul(1./(1.-beta1)), eps=eps).mul(factor).add(p.data, alpha=wd).mul(-lr)
39 | p.data.add_(real_update_tmp)
40 | else:
41 | beta1_1d, beta2_1d = betas_1d
42 | exp_avg.mul_(beta1_1d).add_(grad, alpha=1. - beta1_1d)
43 | exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1. - beta2_1d)
44 | bias_correction1 = 1.0 - beta1_1d ** step
45 | bias_correction2 = 1.0 - beta2_1d ** step
46 | if amsgrad:
47 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
48 | denom = max_exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
49 | else:
50 | denom = exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
51 | real_update_tmp = -lr * lr_1d_factor * torch.mul(p.data, weight_decay_1d).add(exp_avg.div(denom))
52 | p.data.add_(real_update_tmp)
53 | return exp_avg, exp_avg_sq
54 |
55 | class MARS(Optimizer):
56 | def __init__(self, params, lr=3e-3, betas=(0.95, 0.99), eps=1e-8, weight_decay=0., amsgrad=False, gamma=0.025,
57 | is_approx=True, mars_type="mars-adamw", optimize_1d=False, lr_1d=3e-3, betas_1d=(0.9, 0.95), weight_decay_1d=0.1):
58 | if not 0.0 <= lr:
59 | raise ValueError("Invalid learning rate: {}".format(lr))
60 | if not 0.0 <= eps:
61 | raise ValueError("Invalid epsilon value: {}".format(eps))
62 | if not 0.0 <= betas[0] < 1.0:
63 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
64 | if not 0.0 <= betas[1] < 1.0:
65 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
66 | assert mars_type in ["mars-adamw", "mars-lion", "mars-shampoo"], "MARS type not supported"
67 | defaults = dict(lr=lr, betas=betas, eps=eps,
68 | weight_decay=weight_decay, amsgrad=amsgrad,
69 | mars_type=mars_type, gamma=gamma,
70 | optimize_1d=optimize_1d, weight_decay_1d=weight_decay_1d)
71 | super(MARS, self).__init__(params, defaults)
72 | self.eps = eps
73 | self.update_fn = update_fn
74 | self.lr = lr
75 | self.weight_decay=weight_decay
76 | self.amsgrad = amsgrad
77 | self.step_num = 0
78 | self.is_approx = is_approx
79 | self.gamma = gamma
80 | self.mars_type = mars_type
81 | self.optimize_1d = optimize_1d
82 | self.lr_1d_factor = lr_1d / lr
83 | self.weight_decay_1d = weight_decay_1d
84 | self.betas_1d = betas_1d
85 |
86 | @torch.no_grad()
87 | def update_last_grad(self):
88 | if not self.is_approx:
89 | for group in self.param_groups:
90 | for p in group['params']:
91 | state = self.state[p]
92 | if "last_grad" not in state:
93 | state["last_grad"] = torch.zeros_like(p)
94 | state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0)
95 | @torch.no_grad()
96 | def update_previous_grad(self):
97 | if not self.is_approx:
98 | for group in self.param_groups:
99 | #print ("para name", len(group['params']), len(group['names']), group['names'])
100 | for p in group['params']:
101 | # import pdb
102 | # pdb.set_trace()
103 | if p.grad is None:
104 | print (p, "grad is none")
105 | continue
106 | state = self.state[p]
107 | if "previous_grad" not in state:
108 | state['previous_grad'] = torch.zeros_like(p)
109 | state['previous_grad'].zero_().add_(p.grad, alpha=1.0)
110 |
111 | def __setstate__(self, state):
112 | super(MARS, self).__setstate__(state)
113 | for group in self.param_groups:
114 | group.setdefault('amsgrad', False)
115 |
116 | @torch.no_grad()
117 | def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None):
118 | """Performs a single optimization step.
119 |
120 | Arguments:
121 | closure (callable, optional): A closure that reevaluates the model
122 | and returns the loss.
123 |
124 | If using exact version, the example usage is as follows:
125 | previous_X, previous_Y = None, None
126 | for epoch in range(epochs):
127 | for X, Y in data_loader:
128 |                     if previous_X is not None:
129 |                         logits, loss = model(previous_X, previous_Y)
130 | loss.backward()
131 | optimizer.update_previous_grad()
132 | optimizer.zero_grad(set_to_none=True)
133 | logits, loss = model(X, Y)
134 | loss.backward()
135 |                     optimizer.step()
136 | optimizer.zero_grad(set_to_none=True)
137 | optimizer.update_last_grad()
138 | iter_num += 1
139 | previous_X, previous_Y = X.clone(), Y.clone()
140 | """
141 | if any(p is not None for p in [grads, output_params, scale, grad_norms]):
142 | raise RuntimeError('FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.')
143 |
144 | loss = None
145 | if exists(closure):
146 | with torch.enable_grad():
147 | loss = closure()
148 | real_update = 0
149 | real_update_wo_lr = 0
150 | gamma = self.gamma
151 | # import pdb
152 | # pdb.set_trace()
153 | for group in self.param_groups:
154 | for p in filter(lambda p: exists(p.grad), group['params']):
155 | if p.grad is None:
156 | continue
157 | grad = p.grad.data
158 | if grad.is_sparse:
159 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
160 | amsgrad = group['amsgrad']
161 |
162 | state = self.state[p]
163 | #('----- starting a parameter state', state.keys(), 'Length of state', len(state))
164 | # State initialization
165 | if len(state) <= 1:
166 | state['step'] = 0
167 | # Exponential moving average of gradient values
168 | state['exp_avg'] = torch.zeros_like(p.data)
169 | # Last Gradient
170 | state['last_grad'] = torch.zeros_like(p)
171 | #state['previous_grad'] = torch.zeros_like(p)
172 | # Exponential moving average of squared gradient values
173 | state['exp_avg_sq'] = torch.zeros_like(p.data)
174 | if amsgrad:
175 | # Maintains max of all exp. moving avg. of sq. grad. values
176 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
177 | # import pdb
178 | # pdb.set_trace()
179 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
180 | last_grad = state['last_grad']
181 | lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas']
182 | if amsgrad:
183 | max_exp_avg_sq = state['max_exp_avg_sq']
184 | else:
185 | max_exp_avg_sq = 0
186 |
187 | if 'step' in state:
188 | state['step'] += 1
189 | else:
190 | state['step'] = 1
191 | step = state['step']
192 | is_grad_2d = (len(grad.shape) == 2)
193 | exp_avg, exp_avg_sq = self.update_fn(
194 | p,
195 | grad,
196 | exp_avg,
197 | exp_avg_sq,
198 | lr,
199 | wd,
200 | beta1,
201 | beta2,
202 | last_grad,
203 | self.eps,
204 | amsgrad,
205 | max_exp_avg_sq,
206 | step,
207 | gamma,
208 | mars_type=self.mars_type,
209 | is_grad_2d=is_grad_2d,
210 | optimize_1d=self.optimize_1d,
211 | lr_1d_factor=self.lr_1d_factor,
212 | betas_1d=self.betas_1d,
213 | weight_decay_1d=self.weight_decay if self.optimize_1d else self.weight_decay_1d
214 | )
215 | if self.is_approx:
216 | state['last_grad'] = grad
217 | self.step_num = step
218 |
219 | return loss
220 |
221 | @torch.compile
222 | def NewtonSchulz(M, steps=5, eps=1e-7):
223 | a, b, c = (3.4445, -4.7750, 2.0315)
224 | X = M.bfloat16() / (M.norm() + eps)
225 | if M.size(0) > M.size(1):
226 | X = X.T
227 | for _ in range(steps):
228 | A = X @ X.T
229 | B = A @ X
230 | X = a * X + b * B + c * A @ B
231 | if M.size(0) > M.size(1):
232 | X = X.T
233 | return X.to(M.dtype)
234 |
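
A drop-in usage sketch of MARS in its default approximate mode (not part of the repository); the toy model and training loop are assumptions. The comments restate the corrected gradient that update_fn builds for 2-D parameters before the AdamW-style moment updates.

    # Sketch only; model and hyperparameters are assumptions.
    import torch
    import torch.nn as nn
    from optimizers.mars import MARS  # assumes the working directory is MARS/

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
    opt = MARS(model.parameters(), lr=3e-3, betas=(0.95, 0.99), weight_decay=0.0, gamma=0.025)

    for _ in range(3):
        x, y = torch.randn(64, 16), torch.randn(64, 1)
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        # For each 2-D parameter, step() forms
        #   c_t = grad + gamma * (beta1 / (1 - beta1)) * (grad - last_grad),
        # rescales c_t to norm 1 if its norm exceeds 1, and feeds it to the
        # mars-adamw moment updates; 1-D parameters fall back to AdamW because
        # optimize_1d=False by default.
        opt.step()
        opt.zero_grad(set_to_none=True)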
--------------------------------------------------------------------------------
/MARS/optimizers/muon.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from KellerJordan/modded-nanogpt: https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt2.py
3 | """
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import os
8 |
9 | def zeropower_via_svd(G, steps=None):
10 | U, S, V = G.svd()
11 | return U @ V.T
12 |
13 | @torch.compile
14 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
15 | """
16 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
17 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
18 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
19 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
20 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
21 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
22 | performance at all relative to UV^T, where USV^T = G is the SVD.
23 | """
24 | assert len(G.shape) == 2
25 | a, b, c = (3.4445, -4.7750, 2.0315)
26 | X = G.bfloat16()
27 | X /= (X.norm() + eps) # ensure top singular value <= 1
28 | if G.size(0) > G.size(1):
29 | X = X.T
30 | for _ in range(steps):
31 | A = X @ X.T
32 | B = A @ X
33 | X = a * X + b * B + c * A @ B
34 | if G.size(0) > G.size(1):
35 | X = X.T
36 | return X
37 |
38 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
39 |
40 | class Muon(torch.optim.Optimizer):
41 | """
42 | Muon - MomentUm Orthogonalized by Newton-schulz
43 |
44 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
45 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
46 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
47 | the advantage that it can be stably run in bfloat16 on the GPU.
48 |
49 | Some warnings:
50 | - This optimizer assumes that all parameters passed in are 2D.
51 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
52 | parameters; those should all be optimized by a standard method (e.g., AdamW).
53 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
54 | - We believe it is unlikely to work well for training with small batch size.
55 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
56 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
57 |
58 | Arguments:
59 | lr: The learning rate used by the internal SGD.
60 | momentum: The momentum used by the internal SGD.
61 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
62 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
63 | backend_steps: The number of iteration steps to use in the backend, if it is iterative.
64 | """
65 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
66 | backend='newtonschulz5', backend_steps=5, weight_decay=0.):
67 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps, weight_decay=weight_decay)
68 | super().__init__(params, defaults)
69 | if 'WORLD_SIZE' in os.environ:
70 | self.world_size = int(os.environ['WORLD_SIZE'])
71 | self.rank = int(os.environ['RANK'])
72 | else:
73 | self.world_size = 1
74 | self.rank = 0
75 |
76 | def step(self):
77 |
78 | for group in self.param_groups:
79 |
80 | lr = group['lr']
81 | weight_decay = group['weight_decay']
82 | momentum = group['momentum']
83 | zeropower_backend = zeropower_backends[group['backend']]
84 |
85 | # generate weight updates in distributed fashion
86 | total_params = sum(p.numel() for p in group['params'])
87 | updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16)
88 | curr_idx = 0
89 | for i, p in enumerate(group['params']):
90 | # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs
91 | if i % int(self.world_size) == int(self.rank):
92 | g = p.grad
93 | assert g is not None
94 | if g.ndim > 2:
95 | g = g.view(g.size(0), -1)
96 | state = self.state[p]
97 | if 'momentum_buffer' not in state:
98 | state['momentum_buffer'] = torch.zeros_like(g)
99 | buf = state['momentum_buffer']
100 | buf.mul_(momentum).add_(g)
101 | if group['nesterov']:
102 | g = g.add(buf, alpha=momentum)
103 | g = zeropower_backend(g, steps=group['backend_steps'])
104 | g *= max(1, g.size(0)/g.size(1))**0.5
105 | updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
106 | curr_idx += p.numel()
107 |
108 | # sync updates across devices. we are not memory-constrained so can do this simple deserialization
109 | if self.world_size > 1:
110 | dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
111 |
112 | # deserialize and apply updates
113 | curr_idx = 0
114 | for p in group['params']:
115 | g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
116 | p.data.mul_(1.-lr*weight_decay).add_(g, alpha=-lr)
117 | curr_idx += p.numel()
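
A usage sketch (not from the repository) pairing Muon with AdamW, following the warnings in the docstring above: Muon for the 2-D weight matrices, a standard optimizer for biases and other low-dimensional parameters. It assumes a CUDA device, since step() allocates its flattened update buffer on 'cuda'; the model and hyperparameters are assumptions.

    # Sketch only; requires CUDA; model and hyperparameters are assumptions.
    import torch
    import torch.nn as nn
    from optimizers.muon import Muon  # assumes the working directory is MARS/

    assert torch.cuda.is_available()
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10)).cuda()

    matrix_params = [p for p in model.parameters() if p.ndim == 2]
    other_params = [p for p in model.parameters() if p.ndim != 2]
    muon = Muon(matrix_params, lr=0.02, momentum=0.95, nesterov=True, backend='newtonschulz5')
    adamw = torch.optim.AdamW(other_params, lr=1e-3, weight_decay=0.0)

    x = torch.randn(32, 64, device='cuda')
    y = torch.randint(0, 10, (32,), device='cuda')
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    muon.step()    # orthogonalized momentum update for the weight matrices
    adamw.step()   # ordinary AdamW step for the biases
    muon.zero_grad(set_to_none=True)
    adamw.zero_grad(set_to_none=True)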
--------------------------------------------------------------------------------
/MARS/train_CNN.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2 | # SPDX-License-Identifier: Apache-2.0
3 | import argparse
4 | from typing import List, Tuple, Type
5 |
6 | import matplotlib.pyplot as plt
7 | import torch
8 | import torch.nn as nn
9 | from torch.optim import Adam, AdamW
10 | from torch.optim.lr_scheduler import CosineAnnealingLR
11 | from torch.utils.data import DataLoader
12 | from torchvision import datasets, transforms
13 | import numpy as np
14 | from utils.model_CNN import Network
15 | from optimizers.adopt import ADOPT
16 | from optimizers.mars import MARS
17 | import random
18 | parser = argparse.ArgumentParser(add_help=True)
19 | parser.add_argument(
20 | "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10"], help="dataset to use"
21 | )
22 | parser.add_argument("-b", "--batch_size", type=int, default=128, help="batch size")
23 | parser.add_argument("-e", "--epochs", type=int, default=50, help="number of epochs")
24 | parser.add_argument("--seed", type=int, default=0, help="random seed")
25 | parser.add_argument("--cpu", action="store_true", help="use cpu only")
26 |
27 |
28 | def get_datasets(dataset_name: str, batch_size: int) -> Tuple[DataLoader, DataLoader]:
29 | """Get train and test dataloaders."""
30 | if dataset_name == "mnist":
31 | transform = transforms.Compose([
32 | transforms.ToTensor(),
33 | transforms.Normalize((0.1307,), (0.3081,))
34 | ])
35 | train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
36 | test_dataset = datasets.MNIST('./data', train=False, transform=transform)
37 | elif dataset_name == "cifar10":
38 | transform_train = transforms.Compose([
39 | transforms.RandomHorizontalFlip(),
40 | transforms.ToTensor(),
41 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
42 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
43 | ])
44 | transform_test = transforms.Compose([
45 | transforms.ToTensor(),
46 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
47 | ])
48 | train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
49 | test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test)
50 | else:
51 | raise NotImplementedError(f"{dataset_name=} is not implemented.")
52 |
53 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
54 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
55 |
56 | return train_loader, test_loader
57 |
58 |
59 | class WarmupCosineScheduler:
60 | """Custom learning rate scheduler with linear warmup and cosine decay."""
61 | def __init__(self, optimizer, warmup_iters: int, total_iters: int, min_lr: float, max_lr: float):
62 | self.optimizer = optimizer
63 | self.warmup_iters = warmup_iters
64 | self.total_iters = total_iters
65 | self.min_lr = min_lr
66 | self.max_lr = max_lr
67 | self.current_iter = 0
68 | self.lr = 0
69 |
70 | def step(self):
71 | self.current_iter += 1
72 | if self.current_iter <= self.warmup_iters:
73 | lr = self.current_iter / self.warmup_iters * self.max_lr
74 | else:
75 |             # cosine decay from max_lr down to min_lr (same form as get_lr in the GPT training scripts)
76 |             decay_ratio = (self.current_iter - self.warmup_iters) / (self.total_iters - self.warmup_iters)
77 |             lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + np.cos(np.pi * decay_ratio))
78 |
79 | for param_group in self.optimizer.param_groups:
80 | param_group['lr'] = lr
81 | self.lr = lr
82 |
83 | class Trainer:
84 | """Training manager for PyTorch models."""
85 | def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer, scheduler, device: torch.device):
86 | self.model = model
87 | self.optimizer = optimizer
88 | self.scheduler = scheduler
89 | self.device = device
90 | self.criterion = nn.CrossEntropyLoss()
91 | self.train_acc_trace = []
92 | self.val_acc_trace = []
93 |
94 | def train_epoch(self, train_loader: DataLoader) -> float:
95 | self.model.train()
96 | correct = 0
97 | total = 0
98 |
99 | for batch in train_loader:
100 | images, targets = batch[0].to(self.device), batch[1].to(self.device)
101 |
102 | self.optimizer.zero_grad()
103 | outputs = self.model(images)
104 | loss = self.criterion(outputs, targets)
105 | loss.backward()
106 | self.optimizer.step()
107 |
108 | _, predicted = outputs.max(1)
109 | total += targets.size(0)
110 | correct += predicted.eq(targets).sum().item()
111 |             if self.scheduler is not None:
112 |                 self.scheduler.step()  # step per batch: total_iters in get_optimizers counts optimizer steps, not epochs
113 | return 100. * correct / total
114 |
115 | def evaluate(self, test_loader: DataLoader) -> float:
116 | self.model.eval()
117 | correct = 0
118 | total = 0
119 |
120 | with torch.no_grad():
121 | for batch in test_loader:
122 | images, targets = batch[0].to(self.device), batch[1].to(self.device)
123 | outputs = self.model(images)
124 |
125 | _, predicted = outputs.max(1)
126 | total += targets.size(0)
127 | correct += predicted.eq(targets).sum().item()
128 |
129 | return 100. * correct / total
130 |
131 | def train(self, train_loader: DataLoader, test_loader: DataLoader, epochs: int):
132 | for epoch in range(epochs):
133 | train_acc = self.train_epoch(train_loader)
134 | val_acc = self.evaluate(test_loader)
135 |
136 | self.train_acc_trace.append(train_acc)
137 | self.val_acc_trace.append(val_acc)
138 |
139 | # if self.scheduler is not None:
140 | # self.scheduler.step()
141 |
142 | print(f"Epoch {epoch+1}/{epochs} - Train Acc: {train_acc:.2f}% - Val Acc: {val_acc:.2f}%")
143 |
144 |
145 | def get_optimizers(model: nn.Module, opt_name, args):
146 | """Configure optimizers and schedulers."""
147 |     total_steps = 50_000 // args.batch_size * args.epochs  # optimizer steps, assuming the 50k-image CIFAR-10 train split
148 | n_warmup = int(total_steps * 0.10) # % of total steps
149 | weight_decay = 1e-4
150 | max_lr = 6e-4
151 | min_lr = 1e-6
152 |
153 | if opt_name == "Adam":
154 | # Adam
155 | adam = Adam(model.parameters(), lr=max_lr)
156 | adam_scheduler = WarmupCosineScheduler(
157 | adam, n_warmup, total_steps, min_lr, max_lr
158 | )
159 | optimizer = (adam, adam_scheduler, "Adam")
160 |
161 | elif opt_name == "AdamW":
162 | # AdamW
163 | adamw = AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
164 | adamw_scheduler = WarmupCosineScheduler(
165 | adamw, n_warmup, total_steps, min_lr, max_lr
166 | )
167 | optimizer = (adamw, adamw_scheduler, "AdamW")
168 | elif opt_name == "ADOPT":
169 | # ADOPT
170 | adopt = ADOPT(model.parameters(), lr=max_lr, weight_decay=weight_decay)
171 | adopt_scheduler = WarmupCosineScheduler(
172 | adopt, n_warmup, total_steps, min_lr, max_lr
173 | )
174 | optimizer = (adopt, adopt_scheduler, "ADOPT")
175 | elif opt_name == "MARS":
176 | # MARS
177 | mars = MARS(model.parameters(), lr=3e-3, weight_decay=weight_decay, optimize_1d=False)
178 | mars_scheduler = WarmupCosineScheduler(
179 | mars, n_warmup, total_steps, min_lr, 3e-3
180 | )
181 | optimizer = (mars, mars_scheduler, "MARS")
182 | return optimizer
183 |
184 |
185 | def plot_results(results: List[List[float]], optimizer_names: List[str], args):
186 | """Plot training results."""
187 | fig, ax = plt.subplots(figsize=(5.5, 3.5))
188 | colors = ["#74add1", "#1730bd", "#1a9850", "#001c01"]
189 |
190 | for i, acc in enumerate(results):
191 | ax.plot(range(1, len(acc) + 1), acc, label=optimizer_names[i], lw=2, color=colors[i])
192 |
193 | ax.set_title(f"{args.dataset.upper()} (val)", loc="left")
194 | ax.set_xlabel("Epoch", fontsize="medium")
195 | ax.set_ylabel("Accuracy (%)", fontsize="medium")
196 |
197 | ax.legend(ncols=2, columnspacing=0.8, fontsize="medium")
198 | ax.grid(alpha=0.2)
199 |
200 | ax.set_ylim(90 if args.dataset == "mnist" else 70)
201 | acc_min, acc_max = ax.get_ylim()
202 | ax.set_yticks(torch.linspace(acc_min, acc_max, 5).int().tolist())
203 | ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
204 |
205 | fig.tight_layout()
206 | fig.savefig(
207 | f"./compare-{args.dataset}-blank.png",
208 | dpi=300,
209 | bbox_inches="tight",
210 | )
211 | plt.show()
212 |
213 |
214 | def main(args):
215 | # Set random seed and device
216 | torch.manual_seed(args.seed)
217 | device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
218 |
219 | # Get dataloaders
220 | train_loader, test_loader = get_datasets(args.dataset, args.batch_size)
221 | # Model configuration
222 | model_config = {
223 | "n_inputs": (3, 32, 32) if args.dataset == "cifar10" else (1, 28, 28),
224 | "conv_layers_list": [
225 | {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True},
226 | {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True},
227 | {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True},
228 | ],
229 | "n_hiddens_list": [512],
230 | "n_outputs": 10,
231 | "dropout": 0.2,
232 | }
233 |
234 | results = []
235 | optimizer_names = []
236 | # Train with different optimizers
237 | opt_names = ["Adam", "AdamW", "ADOPT", "MARS"]
238 | for opt_name in opt_names:
239 | print(opt_name)
240 | torch.manual_seed(args.seed)
241 | model = Network(**model_config).to(device)
242 | optimizer, scheduler, name = get_optimizers(model, opt_name, args)
243 | trainer = Trainer(model, optimizer, scheduler, device)
244 | trainer.train(train_loader, test_loader, args.epochs)
245 | results.append(trainer.val_acc_trace)
246 | optimizer_names.append(name)
247 |
248 | plot_results(results, optimizer_names, args)
249 |
250 |
251 | if __name__ == "__main__":
252 | args = parser.parse_args()
253 | main(args)
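254 | 
255 | # Example run (flag values shown are the argparse defaults above):
256 | #   python train_CNN.py --dataset cifar10 --batch_size 128 --epochs 50 --seed 0
257 | # With batch_size=128 and 50 epochs, get_optimizers uses total_steps = 50_000 // 128 * 50 = 19_500
258 | # scheduler steps, with the first 10% (1_950 steps) spent on linear warmup before the cosine decay.
259 | # Each optimizer in opt_names (Adam, AdamW, ADOPT, MARS) is trained from the same seed, and the
260 | # per-epoch validation accuracies are plotted to ./compare-<dataset>-blank.png.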
--------------------------------------------------------------------------------
/MARS/train_CV.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from uclaml/Padam: https://github.com/uclaml/Padam/blob/master/run_cnn_test_cifar10.py
3 | """
4 | import numpy as np
5 | import os
6 | import argparse
7 | import json
8 | from tqdm import tqdm
9 |
10 | parser = argparse.ArgumentParser(description='PyTorch Training')
11 | parser.add_argument(
12 | "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10", "cifar100"], help="dataset to use"
13 | )
14 | parser.add_argument(
15 | "--scheduler", type=str, default="multistep", choices=["multistep", "cosine", "constant"], help="scheduler to use"
16 | )
17 | parser.add_argument("--train_bsz", type=int, default=128, help="training batch size")
18 | parser.add_argument("--eval_bsz", type=int, default=100, help="eval batch size")
19 | parser.add_argument("--seed", type=int, default=0, help="random seed")
20 | parser.add_argument("--cpu", action="store_true", help="use cpu only")
21 | parser.add_argument("--cuda", type=str, default="0", help="device to use")
22 |
23 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
24 | parser.add_argument('--adamw_lr', default=0.003, type=float, help='learning rate for adamw')
25 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
26 | parser.add_argument('--optim', '-m', type=str, choices=["adam", "adamw", "mars", "muon"], default='mars', help='optimization method, default: mars')
27 | parser.add_argument('--net', '-n', type=str, default="resnet18", help='network architecture, chosen from "simple_cnn" or torchvision models; default: resnet18')
28 | parser.add_argument('--wd', default=0., type=float, help='weight decay')
29 | parser.add_argument('--Nepoch', default=200, type=int, help='number of epoch')
30 | parser.add_argument('--beta1', default=0.9, type=float, help='beta1')
31 | parser.add_argument('--beta2', default=0.999, type=float, help='beta2')
32 | parser.add_argument('--wandb', action='store_true', help='use wandb')
33 | parser.add_argument('--save_dir', type=str, default="./checkpoint", help='save directory')
35 | parser.add_argument('--wandb_name', type=str, default="None", help='wandb run name')
35 |
36 |
37 | args = parser.parse_args()
38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
39 | if args.wandb:
40 | import wandb
41 | if args.wandb_name == "None":
42 | wandb.init(project="CV", name=args.dataset+"_"+args.net+"_"+args.optim+"_"+str(args.lr), config=args)
43 | else:
44 | wandb.init(project="CV", name=args.wandb_name, config=args)
45 |
46 | import torch
47 | import torch.nn as nn
48 | import torch.optim as optim
49 | import torch.backends.cudnn as cudnn
50 | from utils.cv_utils import get_datasets, get_scheduler, get_model
51 | use_cuda = torch.cuda.is_available() and not args.cpu
52 |
53 | os.environ['PYTHONHASHSEED'] = str(args.seed)
54 | np.random.seed(args.seed)
55 | torch.manual_seed(args.seed)
56 | torch.cuda.manual_seed(args.seed)
57 | torch.cuda.manual_seed_all(args.seed)
58 |
59 | trainloader, testloader = get_datasets(args.dataset, args.train_bsz, args.eval_bsz)
60 | if args.resume:
61 | # Load checkpoint.
62 | print('==> Resuming from checkpoint..')
63 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
64 | checkpoint = torch.load(f'./checkpoint/{args.net}_{args.dataset}_'+args.optim)
65 | model = checkpoint['model']
66 | start_epoch = checkpoint['epoch']
67 | train_losses = checkpoint['train_losses']
68 | test_losses = checkpoint['test_losses']
69 | train_errs = checkpoint['train_errs']
70 | test_errs = checkpoint['test_errs']
71 | else:
72 | print('==> Building model..')
73 |
74 | model = get_model(args)
75 |
76 | if use_cuda:
77 | model.cuda()
78 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
79 | cudnn.benchmark = True
80 |
81 |
82 | criterion = nn.CrossEntropyLoss()
83 |
84 | betas = (args.beta1, args.beta2)
85 | from optimizers.mars import MARS
86 | from optimizers.muon import Muon
87 | from opt import CombinedOptimizer
88 | from optimizers.adamw import AdamW
89 | if args.optim == 'adam':
90 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas)
91 | elif args.optim == 'adamw':
92 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas)
93 | elif args.optim == 'muon':
94 | optimizer = CombinedOptimizer(model.parameters(), [AdamW, Muon], [{'lr': args.adamw_lr, 'betas': betas, 'weight_decay': args.wd},
95 | {'lr': args.lr, 'weight_decay': 0.}])
96 | elif args.optim == 'mars':
97 | optimizer = MARS(model.parameters(), lr=args.lr, weight_decay = args.wd, lr_1d=args.adamw_lr)
98 |
99 | scheduler = get_scheduler(optimizer, args)
100 | best_acc = 0 # best test accuracy
101 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
102 | train_errs = []
103 | test_errs = []
104 | train_losses = []
105 | test_losses = []
106 | acc_list = []
107 | t_bar = tqdm(total=len(trainloader))
108 | t_bar2 = tqdm(total=len(testloader))
109 | for epoch in range(start_epoch+1, args.Nepoch+1):
110 |
111 | scheduler.step()
112 | # print ('\nEpoch: %d' % epoch, ' Learning rate:', scheduler.get_lr())
113 | model.train() # Training
114 |
115 | train_loss = 0
116 | correct_train = 0
117 | total_train = 0
118 | print(scheduler.get_lr())
119 | t_bar.reset()
120 | for batch_idx, (inputs, targets) in enumerate(trainloader):
121 | if use_cuda:
122 | inputs, targets = inputs.cuda(), targets.cuda()
123 |
124 | optimizer.zero_grad()
125 | outputs = model(inputs)
126 | loss = criterion(outputs, targets)
127 | loss.backward()
128 | optimizer.step()
129 |
130 | train_loss += loss.item()
131 | _, predicted = torch.max(outputs.data, 1)
132 | total_train += targets.size(0)
133 | correct_train += predicted.eq(targets.data).cpu().sum().item()
134 |
135 | t_bar.update(1)
136 | t_bar.set_description('Epoch: %d | Loss: %.3f | Acc: %.3f%% ' % (epoch, train_loss/(batch_idx+1), 100.0/total_train*(correct_train)))
137 | t_bar.refresh()
138 | train_losses.append(train_loss/(batch_idx+1))
139 | train_errs.append(1 - correct_train/total_train)
140 |
141 | model.eval() # Testing
142 |
143 | test_loss = 0
144 | correct = 0
145 | total = 0
146 | t_bar2.reset()
147 | for batch_idx, (inputs, targets) in enumerate(testloader):
148 | if use_cuda:
149 | inputs, targets = inputs.cuda(), targets.cuda()
150 | outputs = model(inputs)
151 | loss = criterion(outputs, targets)
152 |
153 | test_loss += loss.item()
154 | _, predicted = torch.max(outputs.data, 1)
155 | total += targets.size(0)
156 | correct += predicted.eq(targets.data).cpu().sum().item()
157 |
158 | t_bar2.update(1)
159 | t_bar2.set_description('Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc))
160 | t_bar2.refresh()
161 | test_errs.append(1 - correct/total)
162 | test_losses.append(test_loss/(batch_idx+1))
163 | if args.wandb:
164 | wandb.log({"epoch": epoch,
165 | "train_loss": train_loss/(batch_idx+1),
166 | "train_acc": 100.0/total_train*(correct_train),
167 | "test_loss": test_loss/(batch_idx+1),
168 | "test_acc": 100.0/total*(correct),
169 | "lr": scheduler.get_lr()[0]}, step=epoch)
170 | # Save checkpoint
171 | acc = 100.0/total*(correct)
172 | if acc > best_acc:
173 |         if not os.path.isdir(args.save_dir):
174 |             os.makedirs(args.save_dir, exist_ok=True)
175 | state = {
176 | 'model': model,
177 | 'epoch': epoch,
178 | }
179 | # torch.save(state, './checkpoint/cnn_cifar10_' + args.optim)
180 | torch.save(state, os.path.join(args.save_dir, "-".join([args.net, args.dataset, args.optim, str(args.lr).replace(".", "_")])+".pth"))
181 | best_acc = acc
182 | t_bar2.set_description('Model Saved! | Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc))
183 | t_bar2.refresh()
184 | acc_list.append(acc)
185 |
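186 | # Example run (flag values shown are the argparse defaults above):
187 | #   python train_CV.py --dataset cifar10 --net resnet18 --optim mars --lr 0.1 --adamw_lr 0.003 --Nepoch 200
188 | # For --optim mars, --lr drives the MARS update and --adamw_lr is forwarded as lr_1d (the learning rate
189 | # used for 1-D parameters); for --optim muon, the AdamW component uses --adamw_lr and the Muon component
190 | # uses --lr. The best-accuracy checkpoint is saved to
191 | # <save_dir>/<net>-<dataset>-<optim>-<lr>.pth (dots in the lr value replaced by underscores).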
--------------------------------------------------------------------------------
/MARS/train_adamw.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from https://github.com/Liuhong99/Sophia/blob/main/train_adam.py
3 | """
4 | import os
5 | import time
6 | import math
7 | import pickle
8 | from contextlib import nullcontext
9 |
10 | import numpy as np
11 | import torch
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | from torch.distributed import init_process_group, destroy_process_group
14 |
15 | from model import GPTConfig, GPT
16 | import sys
17 | from ast import literal_eval
18 | # -----------------------------------------------------------------------------
19 | # default config values designed to train a gpt2 (124M) on OpenWebText
20 | # I/O
21 | out_dir = 'out'
22 | eval_interval = 2000
23 | log_interval = 1
24 | eval_iters = 200
25 | eval_only = False # if True, script exits right after the first eval
26 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
27 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
28 | # wandb logging
29 | wandb_log = False # disabled by default
30 | wandb_project = 'mars'
31 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
32 | # data
33 | dataset = 'openwebtext'
34 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes
35 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
36 | block_size = 1024
37 | # model
38 | n_layer = 12
39 | n_head = 12
40 | n_embd = 768
41 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
42 | bias = False # do we use bias inside LayerNorm and Linear layers?
43 | # optimizer
44 | optimizer_name = 'adamw'
45 | learning_rate = 6e-4 # max learning rate
46 | max_iters = 600000 # total number of training iterations
47 | weight_decay = 1e-1
48 | beta1 = 0.9
49 | beta2 = 0.95
50 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
51 | interval = 10
52 | variant = 4
53 | schedule='cosine'
54 | # learning rate decay settings
55 | decay_lr = True # whether to decay the learning rate
56 | warmup_iters = 2000 # how many steps to warm up for
57 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
58 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
59 | # DDP settings
60 | backend = 'nccl' # 'nccl', 'gloo', etc.
61 | # system
62 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
63 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
64 | compile = True # use PyTorch 2.0 to compile the model to be faster
65 | scale_attn_by_inverse_layer_idx = True
66 | # -----------------------------------------------------------------------------
67 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
68 | for arg in sys.argv[1:]:
69 | if '=' not in arg:
70 | # assume it's the name of a config file
71 | assert not arg.startswith('--')
72 | config_file = arg
73 | print(f"Overriding config with {config_file}:")
74 | with open(config_file) as f:
75 | print(f.read())
76 | exec(open(config_file).read())
77 | else:
78 | # assume it's a --key=value argument
79 | assert arg.startswith('--')
80 | key, val = arg.split('=')
81 | key = key[2:]
82 | if key in globals():
83 | try:
84 |                 # attempt to eval it (e.g. if it's a bool or a number)
85 | attempt = literal_eval(val)
86 | except (SyntaxError, ValueError):
87 | # if that goes wrong, just use the string
88 | attempt = val
89 | # ensure the types match ok
90 | assert type(attempt) == type(globals()[key])
91 | # cross fingers
92 | print(f"Overriding: {key} = {attempt}")
93 | globals()[key] = attempt
94 | else:
95 | raise ValueError(f"Unknown config key: {key}")
96 |
97 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
98 | # -----------------------------------------------------------------------------
99 |
100 | # various inits, derived attributes, I/O setup
101 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
102 | if ddp:
103 | init_process_group(backend=backend)
104 | ddp_rank = int(os.environ['RANK'])
105 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
106 | device = f'cuda:{ddp_local_rank}'
107 | torch.cuda.set_device(device)
108 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
109 | seed_offset = ddp_rank # each process gets a different seed
110 | else:
111 | # if not ddp, we are running on a single gpu, and one process
112 | master_process = True
113 | seed_offset = 0
114 | gradient_accumulation_steps *= 8 # simulate 8 gpus
115 |
116 | if master_process:
117 | os.makedirs(out_dir, exist_ok=True)
118 | torch.manual_seed(5000 + seed_offset)
119 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
120 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
121 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
122 | # note: float16 data type will automatically use a GradScaler
123 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
124 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)
125 |
126 | # poor man's data loader
127 | data_dir = os.path.join('data', dataset)
128 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
129 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
130 | def get_batch(split):
131 | data = train_data if split == 'train' else val_data
132 | ix = torch.randint(len(data) - block_size, (batch_size,))
133 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
134 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
135 | if device_type == 'cuda':
136 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
137 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
138 | else:
139 | x, y = x.to(device), y.to(device)
140 | return x, y
141 |
142 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
143 | iter_num = 0
144 | best_val_loss = 1e9
145 |
146 | # attempt to derive vocab_size from the dataset
147 | meta_path = os.path.join(data_dir, 'meta.pkl')
148 | meta_vocab_size = None
149 | if os.path.exists(meta_path):
150 | with open(meta_path, 'rb') as f:
151 | meta = pickle.load(f)
152 | meta_vocab_size = meta['vocab_size']
153 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
154 |
155 | # model init
156 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
157 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line
158 | if init_from == 'scratch':
159 | # init a new model from scratch
160 | print("Initializing a new model from scratch")
161 | # determine the vocab size we'll use for from-scratch training
162 | if meta_vocab_size is None:
163 |         print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
164 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
165 | gptconf = GPTConfig(**model_args)
166 | model = GPT(gptconf)
167 | elif init_from == 'resume':
168 | print(f"Resuming training from {out_dir}")
169 | # resume training from a checkpoint.
170 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
171 | checkpoint = torch.load(ckpt_path, map_location=device)
172 | checkpoint_model_args = checkpoint['model_args']
173 | # force these config attributes to be equal otherwise we can't even resume training
174 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
175 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
176 | model_args[k] = checkpoint_model_args[k]
177 | # create the model
178 | gptconf = GPTConfig(**model_args)
179 | model = GPT(gptconf)
180 | state_dict = checkpoint['model']
181 | # fix the keys of the state dictionary :(
182 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
183 | unwanted_prefix = '_orig_mod.'
184 | for k,v in list(state_dict.items()):
185 | if k.startswith(unwanted_prefix):
186 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
187 | model.load_state_dict(state_dict)
188 | iter_num = checkpoint['iter_num']
189 | best_val_loss = checkpoint['best_val_loss']
190 | elif init_from.startswith('gpt2'):
191 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
192 | # initialize from OpenAI GPT-2 weights
193 | override_args = dict(dropout=dropout)
194 | model = GPT.from_pretrained(init_from, override_args)
195 | # read off the created config params, so we can store them into checkpoint correctly
196 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
197 | model_args[k] = getattr(model.config, k)
198 | # crop down the model block size if desired, using model surgery
199 | if block_size < model.config.block_size:
200 | model.crop_block_size(block_size)
201 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
202 | model.to(device)
203 |
204 | # initialize a GradScaler. If enabled=False scaler is a no-op
205 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
206 |
207 | # optimizer
208 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), device_type)
209 | if init_from == 'resume':
210 | optimizer.load_state_dict(checkpoint['optimizer'])
211 | del state_dict
212 | del checkpoint
213 | # compile the model
214 | if compile:
215 | print("compiling the model... (takes a ~minute)")
216 | unoptimized_model = model
217 | model = torch.compile(model) # requires PyTorch 2.0
218 |
219 | # wrap model into DDP container
220 | if ddp:
221 | model = DDP(model, device_ids=[ddp_local_rank])
222 |
223 | # helps estimate an arbitrarily accurate loss over either split using many batches
224 | @torch.no_grad()
225 | def estimate_loss():
226 | out = {}
227 | model.eval()
228 | for split in ['train', 'val']:
229 | losses = torch.zeros(eval_iters)
230 | for k in range(eval_iters):
231 | X, Y = get_batch(split)
232 | with ctx:
233 | logits, loss = model(X, Y)
234 | losses[k] = loss.item()
235 | out[split] = losses.mean()
236 | model.train()
237 | return out
238 |
239 | # learning rate decay scheduler (cosine with warmup)
240 | def get_lr(it, schedule='cosine'):
241 |     # print(f"using learning rate schedule {schedule}")
242 | # 1) linear warmup for warmup_iters steps
243 | if it < warmup_iters:
244 | return learning_rate * it / warmup_iters
245 | # 2) if it > lr_decay_iters, return min learning rate
246 | if it > lr_decay_iters:
247 | return min_lr
248 | # 3) in between, use cosine decay down to min learning rate
249 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
250 | assert 0 <= decay_ratio <= 1
251 | if schedule=='cosine':
252 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
253 | elif schedule=='exp':
254 | coeff = np.power(0.9, 100 * decay_ratio)
255 | return min_lr + coeff * (learning_rate - min_lr)
256 |
257 | # logging
258 | if wandb_log and master_process:
259 | import wandb
260 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
261 |
262 | # training loop
263 | X, Y = get_batch('train') # fetch the very first batch
264 | t0 = time.time()
265 | local_iter_num = 0 # number of iterations in the lifetime of this process
266 | raw_model = model.module if ddp else model # unwrap DDP container if needed
267 | running_mfu = -1.0
268 | clip_time = 0
269 | while True:
270 |
271 | # determine and set the learning rate for this iteration
272 | lr = get_lr(iter_num, schedule=schedule) if decay_lr else learning_rate
273 | for param_group in optimizer.param_groups:
274 | param_group['lr'] = lr
275 |
276 | # evaluate the loss on train/val sets and write checkpoints
277 | if iter_num % eval_interval == 0 and master_process:
278 | losses = estimate_loss()
279 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
280 | if wandb_log:
281 | wandb.log({
282 | "iter": iter_num,
283 | "train/loss": losses['train'],
284 | "val/loss": losses['val'],
285 | "lr": lr,
286 | "mfu": running_mfu*100, # convert to percentage
287 | }, step=iter_num)
288 | if losses['val'] < best_val_loss or always_save_checkpoint:
289 | best_val_loss = losses['val']
290 | if iter_num > 0:
291 | checkpoint = {
292 | 'model': raw_model.state_dict(),
293 | 'optimizer': optimizer.state_dict(),
294 | 'model_args': model_args,
295 | 'iter_num': iter_num,
296 | 'best_val_loss': best_val_loss,
297 | 'config': config,
298 | }
299 | print(f"saving checkpoint to {out_dir}")
300 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
301 | if iter_num % (eval_interval * 5) == 0:
302 | checkpoint = {
303 | 'model': raw_model.state_dict(),
304 | 'optimizer': optimizer.state_dict(),
305 | 'model_args': model_args,
306 | 'iter_num': iter_num,
307 | 'best_val_loss': best_val_loss,
308 | 'config': config,
309 | }
310 | print(f"saving checkpoint to {out_dir}")
311 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt'))
312 | if iter_num == 0 and eval_only:
313 | break
314 |
315 | # forward backward update, with optional gradient accumulation to simulate larger batch size
316 | # and using the GradScaler if data type is float16
317 | for micro_step in range(gradient_accumulation_steps):
318 | if ddp:
319 | # in DDP training we only need to sync gradients at the last micro step.
320 | # the official way to do this is with model.no_sync() context manager, but
321 | # I really dislike that this bloats the code and forces us to repeat code
322 | # looking at the source of that context manager, it just toggles this variable
323 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
324 | with ctx:
325 | logits, loss = model(X, Y)
326 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
327 | X, Y = get_batch('train')
328 | # backward pass, with gradient scaling if training in fp16
329 | scaler.scale(loss).backward()
330 | # clip the gradient
331 | if grad_clip != 0.0:
332 | scaler.unscale_(optimizer)
333 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
334 | if total_norm.item() > grad_clip:
335 | clip_time += 1
336 | # step the optimizer and scaler if training in fp16
337 | scaler.step(optimizer)
338 | scaler.update()
339 | # flush the gradients as soon as we can, no need for this memory anymore
340 | optimizer.zero_grad(set_to_none=True)
341 |
342 | # timing and logging
343 | t1 = time.time()
344 | dt = t1 - t0
345 | t0 = t1
346 | if iter_num % log_interval == 0 and master_process:
347 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
348 | if local_iter_num >= 5: # let the training loop settle a bit
349 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
350 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
351 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
352 | params = []
353 | for (name, p) in model.named_parameters():
354 | params.append(p)
355 | total_param_norm = 0
356 | for p in params:
357 | param_norm = p.data.norm(2)
358 | total_param_norm += param_norm.item() ** 2
359 | total_param_norm = total_param_norm ** 0.5
360 | momentum_norm = 0
361 | LL = len(optimizer.state_dict()['state'])
362 | for jj in range(LL):
363 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2
364 | momentum_norm = torch.sqrt(momentum_norm).item()
365 | if wandb_log:
366 | wandb.log({
367 | "iter": iter_num,
368 | "train/loss": lossf,
369 | "lr": lr,
370 | "param_norm": total_param_norm,
371 | "momentum_norm" : momentum_norm,
372 | "train/clip_rate": clip_time / (iter_num + 1)
373 | }, step=iter_num)
374 | iter_num += 1
375 | local_iter_num += 1
376 |
377 | # termination conditions
378 | if iter_num > max_iters:
379 | break
380 |
381 | if ddp:
382 | destroy_process_group()
383 |
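384 | # Example launches (the config path is from this repo's config/ directory; override values are illustrative):
385 | #   torchrun --standalone --nproc_per_node=8 train_adamw.py config/train_gpt2_small_adamw.py
386 | #   python train_adamw.py config/train_gpt2_small_adamw.py --batch_size=8 --wandb_log=True
387 | # A bare positional argument is exec'd as a config file; a --key=value argument overrides the global of the
388 | # same name and must literal_eval to the same type as the default (values that fail to evaluate are kept as strings).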
--------------------------------------------------------------------------------
/MARS/train_adamw_fw.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import pickle
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from torch.distributed import init_process_group, destroy_process_group
11 |
12 | from model import GPTConfig, GPT
13 | import sys
14 | from ast import literal_eval
15 | # -----------------------------------------------------------------------------
16 | # default config values designed to train a gpt2 (124M) on OpenWebText
17 | # I/O
18 | out_dir = 'out'
19 | eval_interval = 2000
20 | log_interval = 1
21 | eval_iters = 200
22 | eval_only = False # if True, script exits right after the first eval
23 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
24 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
25 | # wandb logging
26 | wandb_log = False # disabled by default
27 | wandb_project = 'mars'
28 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
29 | # data
30 | dataset = 'fineweb-edu100B'
31 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes
32 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
33 | block_size = 1024
34 | # model
35 | n_layer = 12
36 | n_head = 12
37 | n_embd = 768
38 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
39 | bias = False # do we use bias inside LayerNorm and Linear layers?
40 | # optimizer
41 | optimizer_name = 'adamw'
42 | learning_rate = 6e-4 # max learning rate
43 | max_iters = 600000 # total number of training iterations
44 | weight_decay = 1e-1
45 | beta1 = 0.9
46 | beta2 = 0.95
47 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
48 | interval = 10
49 | variant = 4
50 | schedule='cosine'
51 | # learning rate decay settings
52 | decay_lr = True # whether to decay the learning rate
53 | warmup_iters = 2000 # how many steps to warm up for
54 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
55 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
56 | # DDP settings
57 | backend = 'nccl' # 'nccl', 'gloo', etc.
58 | # system
59 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
60 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
61 | compile = True # use PyTorch 2.0 to compile the model to be faster
62 | scale_attn_by_inverse_layer_idx = True
63 | # -----------------------------------------------------------------------------
64 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
65 | for arg in sys.argv[1:]:
66 | if '=' not in arg:
67 | # assume it's the name of a config file
68 | assert not arg.startswith('--')
69 | config_file = arg
70 | print(f"Overriding config with {config_file}:")
71 | with open(config_file) as f:
72 | print(f.read())
73 | exec(open(config_file).read())
74 | else:
75 | # assume it's a --key=value argument
76 | assert arg.startswith('--')
77 | key, val = arg.split('=')
78 | key = key[2:]
79 | if key in globals():
80 | try:
81 |                 # attempt to eval it (e.g. if it's a bool or a number)
82 | attempt = literal_eval(val)
83 | except (SyntaxError, ValueError):
84 | # if that goes wrong, just use the string
85 | attempt = val
86 | # ensure the types match ok
87 | assert type(attempt) == type(globals()[key])
88 | # cross fingers
89 | print(f"Overriding: {key} = {attempt}")
90 | globals()[key] = attempt
91 | else:
92 | raise ValueError(f"Unknown config key: {key}")
93 |
94 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
95 | # -----------------------------------------------------------------------------
96 |
97 | # various inits, derived attributes, I/O setup
98 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
99 | if ddp:
100 | init_process_group(backend=backend)
101 | ddp_rank = int(os.environ['RANK'])
102 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
103 | device = f'cuda:{ddp_local_rank}'
104 | torch.cuda.set_device(device)
105 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
106 | seed_offset = ddp_rank # each process gets a different seed
107 | else:
108 | # if not ddp, we are running on a single gpu, and one process
109 | master_process = True
110 | seed_offset = 0
111 | gradient_accumulation_steps *= 8 # simulate 8 gpus
112 |
113 | if master_process:
114 | os.makedirs(out_dir, exist_ok=True)
115 | torch.manual_seed(5000 + seed_offset)
116 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
117 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
118 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
119 | # note: float16 data type will automatically use a GradScaler
120 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
121 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)
122 |
123 | # poor man's data loader
124 | data_dir = os.path.join('data', dataset)
125 | train_file_list = list(filter(lambda x: x.endswith('.bin') and x.startswith('fineweb_train'), os.listdir(data_dir)))
126 | train_data_list = [np.memmap(os.path.join(data_dir, file), dtype=np.uint16, mode='r') for file in train_file_list]
127 | val_data = np.memmap(os.path.join(data_dir, 'fineweb_val_000000.bin'), dtype=np.uint16, mode='r')
128 | import random
129 | random.seed(5000 + seed_offset)
130 | def get_batch(split):
131 | if split == 'train':
132 | data = random.choice(train_data_list)
133 | else:
134 | data = val_data
135 | offset = 512
136 | ix = torch.randint(len(data) - block_size - offset, (batch_size,))
137 | x = torch.stack([torch.from_numpy((data[offset+i:offset+i+block_size]).astype(np.int64)) for i in ix])
138 | y = torch.stack([torch.from_numpy((data[offset+i+1:offset+i+1+block_size]).astype(np.int64)) for i in ix])
139 | if device_type == 'cuda':
140 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
141 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
142 | else:
143 | x, y = x.to(device), y.to(device)
144 | return x, y
145 |
146 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
147 | iter_num = 0
148 | best_val_loss = 1e9
149 |
150 | # attempt to derive vocab_size from the dataset
151 | meta_path = os.path.join(data_dir, 'meta.pkl')
152 | meta_vocab_size = None
153 | if os.path.exists(meta_path):
154 | with open(meta_path, 'rb') as f:
155 | meta = pickle.load(f)
156 | meta_vocab_size = meta['vocab_size']
157 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
158 |
159 | # model init
160 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
161 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line
162 | if init_from == 'scratch':
163 | # init a new model from scratch
164 | print("Initializing a new model from scratch")
165 | # determine the vocab size we'll use for from-scratch training
166 | if meta_vocab_size is None:
167 |         print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
168 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
169 | gptconf = GPTConfig(**model_args)
170 | model = GPT(gptconf)
171 | elif init_from == 'resume':
172 | print(f"Resuming training from {out_dir}")
173 | # resume training from a checkpoint.
174 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
175 | checkpoint = torch.load(ckpt_path, map_location=device)
176 | checkpoint_model_args = checkpoint['model_args']
177 | # force these config attributes to be equal otherwise we can't even resume training
178 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
179 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
180 | model_args[k] = checkpoint_model_args[k]
181 | # create the model
182 | gptconf = GPTConfig(**model_args)
183 | model = GPT(gptconf)
184 | state_dict = checkpoint['model']
185 | # fix the keys of the state dictionary :(
186 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
187 | unwanted_prefix = '_orig_mod.'
188 | for k,v in list(state_dict.items()):
189 | if k.startswith(unwanted_prefix):
190 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
191 | model.load_state_dict(state_dict)
192 | iter_num = checkpoint['iter_num']
193 | best_val_loss = checkpoint['best_val_loss']
194 | elif init_from.startswith('gpt2'):
195 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
196 | # initialize from OpenAI GPT-2 weights
197 | override_args = dict(dropout=dropout)
198 | model = GPT.from_pretrained(init_from, override_args)
199 | # read off the created config params, so we can store them into checkpoint correctly
200 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
201 | model_args[k] = getattr(model.config, k)
202 | # crop down the model block size if desired, using model surgery
203 | if block_size < model.config.block_size:
204 | model.crop_block_size(block_size)
205 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
206 | model.to(device)
207 |
208 | # initialize a GradScaler. If enabled=False scaler is a no-op
209 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
210 |
211 | # optimizer
212 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), device_type)
213 | if init_from == 'resume':
214 | optimizer.load_state_dict(checkpoint['optimizer'])
215 | del state_dict
216 | del checkpoint
217 | # compile the model
218 | if compile:
219 | print("compiling the model... (takes a ~minute)")
220 | unoptimized_model = model
221 | model = torch.compile(model) # requires PyTorch 2.0
222 |
223 | # wrap model into DDP container
224 | if ddp:
225 | model = DDP(model, device_ids=[ddp_local_rank])
226 |
227 | # helps estimate an arbitrarily accurate loss over either split using many batches
228 | @torch.no_grad()
229 | def estimate_loss():
230 | out = {}
231 | model.eval()
232 | for split in ['train', 'val']:
233 | losses = torch.zeros(eval_iters)
234 | for k in range(eval_iters):
235 | X, Y = get_batch(split)
236 | with ctx:
237 | logits, loss = model(X, Y)
238 | losses[k] = loss.item()
239 | out[split] = losses.mean()
240 | model.train()
241 | return out
242 |
243 | # learning rate decay scheduler (cosine with warmup)
244 | def get_lr(it, schedule='cosine'):
245 |     # print(f"using learning rate schedule {schedule}")
246 | # 1) linear warmup for warmup_iters steps
247 | if it < warmup_iters:
248 | return learning_rate * it / warmup_iters
249 | # 2) if it > lr_decay_iters, return min learning rate
250 | if schedule=='wsd':
251 | if it < 0.8 * max_iters:
252 | return learning_rate
253 | else:
254 | return learning_rate * (max_iters - it) / (max_iters * 0.2)
255 | if it > lr_decay_iters:
256 | return min_lr
257 | # 3) in between, use cosine decay down to min learning rate
258 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
259 | assert 0 <= decay_ratio <= 1
260 | if schedule=='cosine':
261 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
262 | elif schedule=='exp':
263 | coeff = np.power(0.9, 100 * decay_ratio)
264 |
265 | return min_lr + coeff * (learning_rate - min_lr)
266 |
267 | # logging
268 | if wandb_log and master_process:
269 | import wandb
270 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
271 |
272 | # training loop
273 | X, Y = get_batch('train') # fetch the very first batch
274 | t0 = time.time()
275 | local_iter_num = 0 # number of iterations in the lifetime of this process
276 | raw_model = model.module if ddp else model # unwrap DDP container if needed
277 | running_mfu = -1.0
278 | clip_time = 0
279 | while True:
280 |
281 | # determine and set the learning rate for this iteration
282 | lr = get_lr(iter_num, schedule=schedule) if decay_lr else learning_rate
283 | for param_group in optimizer.param_groups:
284 | param_group['lr'] = lr
285 |
286 | # evaluate the loss on train/val sets and write checkpoints
287 | if iter_num % eval_interval == 0 and master_process:
288 | losses = estimate_loss()
289 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
290 | if wandb_log:
291 | wandb.log({
292 | "iter": iter_num,
293 | "train/loss": losses['train'],
294 | "val/loss": losses['val'],
295 | "lr": lr,
296 | "mfu": running_mfu*100, # convert to percentage
297 | }, step=iter_num)
298 | if losses['val'] < best_val_loss or always_save_checkpoint:
299 | best_val_loss = losses['val']
300 | if iter_num > 0:
301 | checkpoint = {
302 | 'model': raw_model.state_dict(),
303 | 'optimizer': optimizer.state_dict(),
304 | 'model_args': model_args,
305 | 'iter_num': iter_num,
306 | 'best_val_loss': best_val_loss,
307 | 'config': config,
308 | }
309 | print(f"saving checkpoint to {out_dir}")
310 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
311 | if iter_num % (eval_interval * 5) == 0:
312 | checkpoint = {
313 | 'model': raw_model.state_dict(),
314 | 'optimizer': optimizer.state_dict(),
315 | 'model_args': model_args,
316 | 'iter_num': iter_num,
317 | 'best_val_loss': best_val_loss,
318 | 'config': config,
319 | }
320 | print(f"saving checkpoint to {out_dir}")
321 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt'))
322 | if iter_num == 0 and eval_only:
323 | break
324 |
325 | # forward backward update, with optional gradient accumulation to simulate larger batch size
326 | # and using the GradScaler if data type is float16
327 | for micro_step in range(gradient_accumulation_steps):
328 | if ddp:
329 | # in DDP training we only need to sync gradients at the last micro step.
330 | # the official way to do this is with model.no_sync() context manager, but
331 | # I really dislike that this bloats the code and forces us to repeat code
332 | # looking at the source of that context manager, it just toggles this variable
333 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
334 | with ctx:
335 | logits, loss = model(X, Y)
336 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
337 | X, Y = get_batch('train')
338 | # backward pass, with gradient scaling if training in fp16
339 | scaler.scale(loss).backward()
340 | # clip the gradient
341 | if grad_clip != 0.0:
342 | scaler.unscale_(optimizer)
343 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
344 | if total_norm.item() > grad_clip:
345 | clip_time += 1
346 | # step the optimizer and scaler if training in fp16
347 | scaler.step(optimizer)
348 | scaler.update()
349 | # flush the gradients as soon as we can, no need for this memory anymore
350 | optimizer.zero_grad(set_to_none=True)
351 |
352 | # timing and logging
353 | t1 = time.time()
354 | dt = t1 - t0
355 | t0 = t1
356 | if iter_num % log_interval == 0 and master_process:
357 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
358 | if local_iter_num >= 5: # let the training loop settle a bit
359 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
360 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
361 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
362 | params = []
363 | for (name, p) in model.named_parameters():
364 | params.append(p)
365 | total_param_norm = 0
366 | for p in params:
367 | param_norm = p.data.norm(2)
368 | total_param_norm += param_norm.item() ** 2
369 | total_param_norm = total_param_norm ** 0.5
370 | momentum_norm = 0
371 | LL = len(optimizer.state_dict()['state'])
372 | for jj in range(LL):
373 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2
374 | momentum_norm = torch.sqrt(momentum_norm).item()
375 | if wandb_log:
376 | wandb.log({
377 | "iter": iter_num,
378 | "train/loss": lossf,
379 | "lr": lr,
380 | "param_norm": total_param_norm,
381 | "momentum_norm" : momentum_norm,
382 | "train/clip_rate": clip_time / (iter_num + 1)
383 | }, step=iter_num)
384 | iter_num += 1
385 | local_iter_num += 1
386 |
387 | # termination conditions
388 | if iter_num > max_iters:
389 | break
390 |
391 | if ddp:
392 | destroy_process_group()
393 |
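394 | # Example launch (assumes data/fineweb-edu100B/ contains fineweb_train_*.bin shards plus
395 | # fineweb_val_000000.bin, as read by the data loader above; the config path is from this repo):
396 | #   torchrun --standalone --nproc_per_node=8 train_adamw_fw.py config/train_gpt2_small_adamw.py --schedule=wsd
397 | # Besides 'cosine' and 'exp', get_lr here adds a 'wsd' schedule: after warmup it holds learning_rate
398 | # constant until 80% of max_iters, then decays it linearly to zero at max_iters.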
--------------------------------------------------------------------------------
/MARS/train_mars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import pickle
5 | from contextlib import nullcontext
6 | from collections import deque
7 |
8 | import numpy as np
9 | import torch
10 | from torch.nn.parallel import DistributedDataParallel as DDP
11 | from torch.distributed import init_process_group, destroy_process_group
12 |
13 | from model import GPTConfig, GPT
14 | import sys
15 | from ast import literal_eval
16 | # -----------------------------------------------------------------------------
17 | # default config values designed to train a gpt2 (124M) on OpenWebText
18 | # I/O
19 | data_path = "./data"
20 | out_dir = 'out'
21 | eval_interval = 2000
22 | log_interval = 1
23 | eval_iters = 200
24 | eval_only = False # if True, script exits right after the first eval
25 | always_save_checkpoint = False # if True, always save a checkpoint after each eval
26 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
27 | # wandb logging
28 | wandb_log = False # disabled by default
29 | wandb_project = 'owt'
30 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
31 | # data
32 | dataset = 'openwebtext'
33 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes
34 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
35 | initial_steps = 100
36 | block_size = 1024
37 | # model
38 | n_layer = 12
39 | n_head = 12
40 | n_embd = 768
41 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
42 | bias = False # do we use bias inside LayerNorm and Linear layers?
43 | # optimizer
44 | optimizer_name = 'mars'
45 | learning_rate = 6e-4 # max learning rate
46 | max_iters = 600000 # total number of training iterations
47 | weight_decay = 1e-1
48 | beta1 = 0.95
49 | beta2 = 0.99
50 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
51 | interval = 10
52 | variant = 4
53 | # learning rate decay settings
54 | decay_lr = True # whether to decay the learning rate
55 | warmup_iters = 2000 # how many steps to warm up for
56 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
57 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
58 | # DDP settings
59 | backend = 'nccl' # 'nccl', 'gloo', etc.
60 | # system
61 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
62 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
63 | compile = True # use PyTorch 2.0 to compile the model to be faster
64 | scale_attn_by_inverse_layer_idx = True
65 | # learning rate schedule
66 | schedule='cosine'
67 | scheme='exact'
68 | gamma=0.025
69 | lr_1d=3e-3
70 | is_approx=True
71 | mars_type="mars-adamw"
72 | optimize_1d=False
73 | weight_decay_1d=0.1
74 | betas_1d=(0.9, 0.95)
75 | # -----------------------------------------------------------------------------
76 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
77 | for arg in sys.argv[1:]:
78 | if '=' not in arg:
79 | # assume it's the name of a config file
80 | assert not arg.startswith('--')
81 | config_file = arg
82 | print(f"Overriding config with {config_file}:")
83 | with open(config_file) as f:
84 | print(f.read())
85 | exec(open(config_file).read())
86 | else:
87 | # assume it's a --key=value argument
88 | assert arg.startswith('--')
89 | key, val = arg.split('=')
90 | key = key[2:]
91 | if key in globals():
92 | try:
93 |                 # attempt to eval it (e.g. if it's a bool or a number)
94 | attempt = literal_eval(val)
95 | except (SyntaxError, ValueError):
96 | # if that goes wrong, just use the string
97 | attempt = val
98 | # ensure the types match ok
99 | assert type(attempt) == type(globals()[key])
100 | # cross fingers
101 | print(f"Overriding: {key} = {attempt}")
102 | globals()[key] = attempt
103 | else:
104 | raise ValueError(f"Unknown config key: {key}")
105 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
106 | # -----------------------------------------------------------------------------
107 |
108 | # various inits, derived attributes, I/O setup
109 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
110 | if ddp:
111 | init_process_group(backend=backend)
112 | ddp_rank = int(os.environ['RANK'])
113 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
114 | device = f'cuda:{ddp_local_rank}'
115 | torch.cuda.set_device(device)
116 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
117 | seed_offset = ddp_rank # each process gets a different seed
118 | else:
119 | # if not ddp, we are running on a single gpu, and one process
120 | master_process = True
121 | seed_offset = 0
122 | gradient_accumulation_steps *= 8 # simulate 8 gpus
123 |
124 | if master_process:
125 | os.makedirs(out_dir, exist_ok=True)
126 | torch.manual_seed(5000 + seed_offset)
127 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
128 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
129 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
130 | # note: float16 data type will automatically use a GradScaler
131 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
132 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)
133 |
134 | # poor man's data loader
135 | data_dir = os.path.join(data_path, dataset)
136 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
137 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
138 | def get_batch(split):
139 | data = train_data if split == 'train' else val_data
140 | ix = torch.randint(len(data) - block_size, (batch_size,))
141 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
142 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
143 | if device_type == 'cuda':
144 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
145 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
146 | else:
147 | x, y = x.to(device), y.to(device)
148 | return x, y
149 |
150 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
151 | iter_num = 0
152 | best_val_loss = 1e9
153 |
154 | # attempt to derive vocab_size from the dataset
155 | meta_path = os.path.join(data_dir, 'meta.pkl')
156 | meta_vocab_size = None
157 | if os.path.exists(meta_path):
158 | with open(meta_path, 'rb') as f:
159 | meta = pickle.load(f)
160 | meta_vocab_size = meta['vocab_size']
161 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
162 |
163 | # model init
164 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
165 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line
166 | if init_from == 'scratch':
167 | # init a new model from scratch
168 | print("Initializing a new model from scratch")
169 | # determine the vocab size we'll use for from-scratch training
170 | if meta_vocab_size is None:
171 |         print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
172 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
173 | gptconf = GPTConfig(**model_args)
174 | model = GPT(gptconf)
175 | elif init_from == 'resume':
176 | print(f"Resuming training from {out_dir}")
177 | # resume training from a checkpoint.
178 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
179 | checkpoint = torch.load(ckpt_path, map_location=device)
180 | checkpoint_model_args = checkpoint['model_args']
181 | # force these config attributes to be equal otherwise we can't even resume training
182 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
183 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
184 | model_args[k] = checkpoint_model_args[k]
185 | # create the model
186 | gptconf = GPTConfig(**model_args)
187 | model = GPT(gptconf)
188 | state_dict = checkpoint['model']
189 | # fix the keys of the state dictionary :(
190 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
191 | unwanted_prefix = '_orig_mod.'
192 | for k,v in list(state_dict.items()):
193 | if k.startswith(unwanted_prefix):
194 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
195 | model.load_state_dict(state_dict)
196 | iter_num = checkpoint['iter_num']
197 | best_val_loss = checkpoint['best_val_loss']
198 | elif init_from.startswith('gpt2'):
199 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
200 | # initialize from OpenAI GPT-2 weights
201 | override_args = dict(dropout=dropout)
202 | model = GPT.from_pretrained(init_from, override_args)
203 | # read off the created config params, so we can store them into checkpoint correctly
204 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
205 | model_args[k] = getattr(model.config, k)
206 | # crop down the model block size if desired, using model surgery
207 | if block_size < model.config.block_size:
208 | model.crop_block_size(block_size)
209 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
210 | model.to(device)
211 |
212 | # initialize a GradScaler. If enabled=False scaler is a no-op
213 | scaler = torch.amp.GradScaler('cuda', enabled=(dtype == 'float16'))
214 | other_params = {'gamma': gamma, 'is_approx': is_approx, 'mars_type': mars_type, 'optimize_1d': optimize_1d,
215 | 'lr_1d': lr_1d, 'betas_1d': betas_1d, 'weight_decay_1d': weight_decay_1d}
216 | # optimizer
217 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), device_type,
218 | other_params)
219 | if init_from == 'resume':
220 | optimizer.load_state_dict(checkpoint['optimizer'])
221 | del state_dict
222 | del checkpoint
223 | # compile the model
224 | if compile:
225 | print("compiling the model... (takes a ~minute)")
226 | unoptimized_model = model
227 | model = torch.compile(model) # requires PyTorch 2.0
228 |
229 | # wrap model into DDP container
230 | if ddp:
231 | print('DDP_used')
232 | model = DDP(model, device_ids=[ddp_local_rank])
233 |
234 | # helps estimate an arbitrarily accurate loss over either split using many batches
235 | @torch.no_grad()
236 | def estimate_loss():
237 | out = {}
238 | model.eval()
239 | for split in ['train', 'val']:
240 | losses = torch.zeros(eval_iters)
241 | for k in range(eval_iters):
242 | X, Y = get_batch(split)
243 | with ctx:
244 | logits, loss = model(X, Y)
245 | losses[k] = loss.item()
246 | out[split] = losses.mean()
247 | model.train()
248 | return out
249 |
250 | # learning rate decay scheduler (cosine or exponential decay, with linear warmup)
251 | def get_lr(it, schedule='cosine'):
252 |     # print(f"Using learning rate schedule {schedule}")
253 | # 1) linear warmup for warmup_iters steps
254 | if it < warmup_iters:
255 | return learning_rate * it / warmup_iters
256 | # 2) if it > lr_decay_iters, return min learning rate
257 | if it > lr_decay_iters:
258 | return min_lr
259 | # 3) in between, use cosine decay down to min learning rate
260 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
261 | assert 0 <= decay_ratio <= 1
262 | if schedule=='cosine':
263 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
264 | elif schedule=='exp':
265 | coeff = np.power(0.9, 100 * decay_ratio)
266 | return min_lr + coeff * (learning_rate - min_lr)
267 |
268 | # logging
269 | if wandb_log and master_process:
270 | import wandb
271 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
272 |
273 | # training loop
274 | #X, Y = get_batch('train') # fetch the very first batch
275 | Xs=deque([])
276 | Ys=deque([])
277 | for micro_step in range(gradient_accumulation_steps):
278 | X, Y = get_batch('train')
279 | Xs.append(X)
280 | Ys.append(Y)
281 | t0 = time.time()
282 | local_iter_num = 0 # number of iterations in the lifetime of this process
283 | raw_model = model.module if ddp else model # unwrap DDP container if needed
284 | running_mfu = -1.0
285 | clip_time = 0
286 | while True:
287 |
288 | # determine and set the learning rate for this iteration
289 | lr = get_lr(iter_num, schedule=schedule) if decay_lr else learning_rate
290 | for param_group in optimizer.param_groups:
291 | param_group['lr'] = lr
292 |
293 | # evaluate the loss on train/val sets and write checkpoints
294 | if iter_num % eval_interval == 0 and master_process:
295 | losses = estimate_loss()
296 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
297 | if wandb_log:
298 | wandb.log({
299 | "iter": iter_num,
300 | "train/loss": losses['train'],
301 | "val/loss": losses['val'],
302 | "lr": lr,
303 | "mfu": running_mfu*100, # convert to percentage
304 | }, step=iter_num)
305 | if losses['val'] < best_val_loss or always_save_checkpoint:
306 | best_val_loss = losses['val']
307 | if iter_num > 0:
308 | checkpoint = {
309 | 'model': raw_model.state_dict(),
310 | 'optimizer': optimizer.state_dict(),
311 | 'model_args': model_args,
312 | 'iter_num': iter_num,
313 | 'best_val_loss': best_val_loss,
314 | 'config': config,
315 | }
316 | print(f"saving checkpoint to {out_dir}")
317 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
318 | if iter_num % (eval_interval * 5) == 0:
319 | checkpoint = {
320 | 'model': raw_model.state_dict(),
321 | 'optimizer': optimizer.state_dict(),
322 | 'model_args': model_args,
323 | 'iter_num': iter_num,
324 | 'best_val_loss': best_val_loss,
325 | 'config': config,
326 | }
327 | print(f"saving checkpoint to {out_dir}")
328 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt'))
329 | if iter_num == 0 and eval_only:
330 | break
331 |
332 | # forward backward update, with optional gradient accumulation to simulate larger batch size
333 | # and using the GradScaler if data type is float16
334 | minibatch_size = gradient_accumulation_steps
335 | X_cur = []
336 | Y_cur = []
337 | ## Update datasets
338 | for micro_step in range(minibatch_size):
339 | X_cur.append(Xs.popleft())
340 | Y_cur.append(Ys.popleft())
341 | X, Y = get_batch('train')
342 | Xs.append(X)
343 | Ys.append(Y)
344 |     ## For the exact scheme: compute the gradient at the current iterate on the future batch first; it is used as the previous gradient at the next iteration.
345 | if scheme == 'exact' and not is_approx:
346 | ### Calculate the gradient again using the new batch
347 | for micro_step in range(gradient_accumulation_steps):
348 | if ddp:
349 | # in DDP training we only need to sync gradients at the last micro step.
350 | # the official way to do this is with model.no_sync() context manager, but
351 | # I really dislike that this bloats the code and forces us to repeat code
352 | # looking at the source of that context manager, it just toggles this variable
353 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
354 | with ctx:
355 | X = Xs[micro_step]
356 | Y = Ys[micro_step]
357 | logits, loss = model(X, Y)
358 |             # (batches were already fetched into the deque above; no prefetch needed here)
359 | # backward pass, with gradient scaling if training in fp16
360 | scaler.scale(loss).backward()
361 | if grad_clip != 0.0:
362 | scaler.unscale_(optimizer)
363 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
364 | if total_norm.item() > grad_clip:
365 | clip_time += 1
366 | elif (grad_clip == 0.0) and (optimizer.gamma == 0.0):
367 | scaler.unscale_(optimizer)
368 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
369 | if total_norm.item() > 1.0:
370 | clip_time += 1
371 | ### Update the previous grad of the next iteration
372 | optimizer.update_previous_grad()
373 |
374 | # flush the gradients as soon as we can, no need for this memory anymore
375 | optimizer.zero_grad(set_to_none=True)
376 |
377 | ## Calculate the gradient of the current batch
378 | for micro_step in range(minibatch_size):
379 | if ddp:
380 | # in DDP training we only need to sync gradients at the last micro step.
381 | # the official way to do this is with model.no_sync() context manager, but
382 | # I really dislike that this bloats the code and forces us to repeat code
383 | # looking at the source of that context manager, it just toggles this variable
384 | model.require_backward_grad_sync = (micro_step == minibatch_size - 1)
385 | with ctx:
386 | X = X_cur[micro_step]
387 | Y = Y_cur[micro_step]
388 | logits, loss = model(X, Y)
389 |             # (batches were already fetched into the deque above; no prefetch needed here)
390 | # backward pass, with gradient scaling if training in fp16
391 | scaler.scale(loss).backward()
392 | # clip the gradient
393 | if grad_clip != 0.0:
394 | scaler.unscale_(optimizer)
395 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
396 | if total_norm.item() > grad_clip:
397 | clip_time += 1
398 | ### First update the current value of gradient
399 | #optimizer.update_current_grad()
400 | # step the optimizer and scaler if training in fp16
401 | scaler.step(optimizer)
402 | scaler.update()
403 | ### TODO: Clean the grad
404 | optimizer.zero_grad(set_to_none=True)
405 | optimizer.update_last_grad()
406 |
407 | # timing and logging
408 | t1 = time.time()
409 | dt = t1 - t0
410 | t0 = t1
411 | if iter_num % log_interval == 0 and master_process:
412 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
413 | if local_iter_num >= 5: # let the training loop settle a bit
414 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
415 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
416 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
417 | params = []
418 | for (name, p) in model.named_parameters():
419 | params.append(p)
420 | total_param_norm = 0
421 | for p in params:
422 | param_norm = p.data.norm(2)
423 | total_param_norm += param_norm.item() ** 2
424 | total_param_norm = total_param_norm ** 0.5
425 | momentum_norm = 0
426 | momentum_norm_sq = 0
427 | momentum_div = 0
428 | LL = len(optimizer.state_dict()['state'])
429 | for jj in range(LL):
430 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2
431 | momentum_norm_sq += (optimizer.state_dict()['state'][jj]['exp_avg_sq'].detach().norm(2)) ** 2
432 | momentum_norm = torch.sqrt(momentum_norm).item()
433 | momentum_norm_sq = torch.sqrt(momentum_norm_sq).item()
434 | momentum_div = momentum_norm/(np.sqrt(momentum_norm_sq)+1e-8)
435 | if wandb_log:
436 | wandb.log({
437 | "iter": iter_num,
438 | "train/loss": lossf,
439 | "lr": lr,
440 | "param_norm": total_param_norm,
441 | "momentum_norm" : momentum_norm,
442 | "momentum_norm_sq": momentum_norm_sq,
443 | "momentum_div": momentum_div,
444 | "train/clip_rate": clip_time / (iter_num + 1)
445 | }, step=iter_num)
446 | iter_num += 1
447 | local_iter_num += 1
448 |
449 | # termination conditions
450 | if iter_num > max_iters:
451 | break
452 |
453 | if ddp:
454 | destroy_process_group()
455 |
--------------------------------------------------------------------------------
/MARS/train_muon.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import pickle
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from torch.distributed import init_process_group, destroy_process_group
11 |
12 | from model import GPTConfig, GPT
13 | import sys
14 | from ast import literal_eval
15 | # -----------------------------------------------------------------------------
16 | # default config values designed to train a gpt2 (124M) on OpenWebText
17 | # I/O
18 | data_path = "./data"
19 | out_dir = 'out'
20 | eval_interval = 2000
21 | log_interval = 1
22 | eval_iters = 200
23 | eval_only = False # if True, script exits right after the first eval
24 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
25 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
26 | # wandb logging
27 | wandb_log = False # disabled by default
28 | wandb_project = 'owt'
29 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
30 | # data
31 | dataset = 'openwebtext'
32 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes
33 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
34 | block_size = 1024
35 | # model
36 | n_layer = 12
37 | n_head = 12
38 | n_embd = 768
39 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
40 | bias = False # do we use bias inside LayerNorm and Linear layers?
41 | # optimizer
42 | optimizer_name = 'muon'
43 | learning_rate = 6e-4 # max learning rate
44 | muon_learning_rate = 2e-2
45 | max_iters = 600000 # total number of training iterations
46 | weight_decay = 1e-1
47 | muon_weight_decay = 0.
48 | beta1 = 0.95
49 | beta2 = 0.99
50 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
51 | interval = 10
52 | variant = 4
53 | # learning rate decay settings
54 | decay_lr = True # whether to decay the learning rate
55 | warmup_iters = 2000 # how many steps to warm up for
56 | warmdown_iters = 2000
57 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
58 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
59 | # DDP settings
60 | backend = 'nccl' # 'nccl', 'gloo', etc.
61 | schedule = 'cosine'
62 | # system
63 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
64 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
65 | compile = True # use PyTorch 2.0 to compile the model to be faster
66 | scale_attn_by_inverse_layer_idx = True
67 | # -----------------------------------------------------------------------------
68 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
69 | for arg in sys.argv[1:]:
70 | if '=' not in arg:
71 | # assume it's the name of a config file
72 | assert not arg.startswith('--')
73 | config_file = arg
74 | print(f"Overriding config with {config_file}:")
75 | with open(config_file) as f:
76 | print(f.read())
77 | exec(open(config_file).read())
78 | else:
79 | # assume it's a --key=value argument
80 | assert arg.startswith('--')
81 | key, val = arg.split('=')
82 | key = key[2:]
83 | if key in globals():
84 | try:
85 |                 # attempt to eval it (e.g. if bool, number, etc.)
86 | attempt = literal_eval(val)
87 | except (SyntaxError, ValueError):
88 | # if that goes wrong, just use the string
89 | attempt = val
90 | # ensure the types match ok
91 | assert type(attempt) == type(globals()[key])
92 | # cross fingers
93 | print(f"Overriding: {key} = {attempt}")
94 | globals()[key] = attempt
95 | else:
96 | raise ValueError(f"Unknown config key: {key}")
97 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
98 | # -----------------------------------------------------------------------------
99 |
100 | # various inits, derived attributes, I/O setup
101 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
102 | ddp_world_size = int(os.environ.get('WORLD_SIZE', 1)) # defaults to 1 for single-process (non-DDP) runs
103 | if ddp:
104 | init_process_group(backend=backend)
105 | ddp_rank = int(os.environ['RANK'])
106 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
107 | device = f'cuda:{ddp_local_rank}'
108 | torch.cuda.set_device(device)
109 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
110 | seed_offset = ddp_rank # each process gets a different seed
111 | else:
112 | # if not ddp, we are running on a single gpu, and one process
113 | master_process = True
114 | seed_offset = 0
115 | gradient_accumulation_steps *= 8 # simulate 8 gpus
116 |
117 | if master_process:
118 | os.makedirs(out_dir, exist_ok=True)
119 | torch.manual_seed(5000 + seed_offset)
120 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
121 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
122 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
123 | # note: float16 data type will automatically use a GradScaler
124 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
125 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)
126 |
127 | # poor man's data loader
128 | data_dir = os.path.join(data_path, dataset)
129 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
130 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
131 | def get_batch(split):
132 | data = train_data if split == 'train' else val_data
133 | ix = torch.randint(len(data) - block_size, (batch_size,))
134 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
135 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
136 | if device_type == 'cuda':
137 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
138 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
139 | else:
140 | x, y = x.to(device), y.to(device)
141 | return x, y
142 |
143 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
144 | iter_num = 0
145 | best_val_loss = 1e9
146 |
147 | # attempt to derive vocab_size from the dataset
148 | meta_path = os.path.join(data_dir, 'meta.pkl')
149 | meta_vocab_size = None
150 | if os.path.exists(meta_path):
151 | with open(meta_path, 'rb') as f:
152 | meta = pickle.load(f)
153 | meta_vocab_size = meta['vocab_size']
154 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
155 |
156 | # model init
157 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
158 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line
159 | if init_from == 'scratch':
160 | # init a new model from scratch
161 | print("Initializing a new model from scratch")
162 | # determine the vocab size we'll use for from-scratch training
163 | if meta_vocab_size is None:
164 |         print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
165 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
166 | gptconf = GPTConfig(**model_args)
167 | model = GPT(gptconf)
168 | elif init_from == 'resume':
169 | print(f"Resuming training from {out_dir}")
170 | # resume training from a checkpoint.
171 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
172 | checkpoint = torch.load(ckpt_path, map_location=device)
173 | checkpoint_model_args = checkpoint['model_args']
174 | # force these config attributes to be equal otherwise we can't even resume training
175 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
176 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
177 | model_args[k] = checkpoint_model_args[k]
178 | # create the model
179 | gptconf = GPTConfig(**model_args)
180 | model = GPT(gptconf)
181 | state_dict = checkpoint['model']
182 | # fix the keys of the state dictionary :(
183 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
184 | unwanted_prefix = '_orig_mod.'
185 | for k,v in list(state_dict.items()):
186 | if k.startswith(unwanted_prefix):
187 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
188 | model.load_state_dict(state_dict)
189 | iter_num = checkpoint['iter_num']
190 | best_val_loss = checkpoint['best_val_loss']
191 | elif init_from.startswith('gpt2'):
192 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
193 | # initialize from OpenAI GPT-2 weights
194 | override_args = dict(dropout=dropout)
195 | model = GPT.from_pretrained(init_from, override_args)
196 | # read off the created config params, so we can store them into checkpoint correctly
197 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
198 | model_args[k] = getattr(model.config, k)
199 | # crop down the model block size if desired, using model surgery
200 | if block_size < model.config.block_size:
201 | model.crop_block_size(block_size)
202 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
203 | model.to(device)
204 |
205 | # initialize a GradScaler. If enabled=False scaler is a no-op
206 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
207 |
208 | # optimizer
209 | from optimizers.muon import Muon
210 | from optimizers.adamw import AdamW
211 | params = list(model.parameters())
212 | from opt import CombinedOptimizer
213 | # optimizer1 = AdamW([p for p in params if p.ndim == 1], weight_decay=weight_decay, lr=learning_rate, betas=(beta1, beta2))
214 | # optimizer2 = Muon([p for p in params if p.ndim == 2], lr=muon_learning_rate, rank=ddp_rank, world_size=ddp_world_size)
215 | # optimizers = [optimizer1, optimizer2]
216 | optimizer = CombinedOptimizer(params, [AdamW, Muon], [{'lr': learning_rate, 'betas': (beta1, beta2), 'weight_decay': weight_decay},
217 | {'lr': muon_learning_rate, 'weight_decay': muon_weight_decay}])
218 | if init_from == 'resume':
219 | # for optimizer in optimizers:
220 | optimizer.load_state_dict(checkpoint['optimizer'])
221 | del state_dict
222 | del checkpoint
223 | # compile the model
224 | if compile:
225 | print("compiling the model... (takes a ~minute)")
226 | unoptimized_model = model
227 | model = torch.compile(model) # requires PyTorch 2.0
228 |
229 | # wrap model into DDP container
230 | if ddp:
231 | model = DDP(model, device_ids=[ddp_local_rank])
232 |
233 | # helps estimate an arbitrarily accurate loss over either split using many batches
234 | @torch.no_grad()
235 | def estimate_loss():
236 | out = {}
237 | model.eval()
238 | for split in ['train', 'val']:
239 | losses = torch.zeros(eval_iters)
240 | for k in range(eval_iters):
241 | X, Y = get_batch(split)
242 | with ctx:
243 | logits, loss = model(X, Y)
244 | losses[k] = loss.item()
245 | out[split] = losses.mean()
246 | model.train()
247 | return out
248 |
249 | # learning rate scheduler (linear warmup, constant, then linear warmdown)
250 | def get_lr(it, schedule='cosine', base_lr=learning_rate):
251 | # 1) linear warmup for warmup_iters steps
252 | if it < warmup_iters:
253 | return base_lr * it / warmup_iters
254 | elif it < max_iters - warmdown_iters:
255 | return base_lr
256 | else:
257 | decay_ratio = (max_iters - it) / warmdown_iters
258 | return base_lr * decay_ratio
259 |
260 | # logging
261 | if wandb_log and master_process:
262 | import wandb
263 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
264 |
265 | # training loop
266 | X, Y = get_batch('train') # fetch the very first batch
267 | t0 = time.time()
268 | local_iter_num = 0 # number of iterations in the lifetime of this process
269 | raw_model = model.module if ddp else model # unwrap DDP container if needed
270 | running_mfu = -1.0
271 | clip_time = 0
272 | while True:
273 |
274 | # determine and set the learning rate for this iteration
275 |
276 | # for optimizer in optimizers:
277 | for i in range(len(optimizer.optimizers)):
278 | lr = get_lr(iter_num, schedule=schedule, base_lr=optimizer.base_lrs[i])
279 | for param_group in optimizer.optimizers[i].param_groups:
280 | param_group['lr'] = lr
281 |
282 | # evaluate the loss on train/val sets and write checkpoints
283 | if iter_num % eval_interval == 0 and master_process:
284 | losses = estimate_loss()
285 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
286 | if wandb_log:
287 | wandb.log({
288 | "iter": iter_num,
289 | "train/loss": losses['train'],
290 | "val/loss": losses['val'],
291 | "lr": lr,
292 | "mfu": running_mfu*100, # convert to percentage
293 | }, step=iter_num)
294 | if losses['val'] < best_val_loss or always_save_checkpoint:
295 | best_val_loss = losses['val']
296 | if iter_num > 0:
297 | checkpoint = {
298 | 'model': raw_model.state_dict(),
299 | 'optimizer': optimizer.state_dict(),
300 | 'model_args': model_args,
301 | 'iter_num': iter_num,
302 | 'best_val_loss': best_val_loss,
303 | 'config': config,
304 | }
305 | print(f"saving checkpoint to {out_dir}")
306 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
307 | # model.save_pretrained(os.path.join(out_dir, 'ckpt.pt'))
308 | if iter_num % (eval_interval * 5) == 0:
309 | checkpoint = {
310 | 'model': raw_model.state_dict(),
311 | 'optimizer': optimizer.state_dict(),
312 | 'model_args': model_args,
313 | 'iter_num': iter_num,
314 | 'best_val_loss': best_val_loss,
315 | 'config': config,
316 | }
317 | print(f"saving checkpoint to {out_dir}")
318 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt'))
319 | # model.save_pretrained(os.path.join(out_dir, 'ckpt.pt'))
320 | if iter_num == 0 and eval_only:
321 | break
322 |
323 | # forward backward update, with optional gradient accumulation to simulate larger batch size
324 | # and using the GradScaler if data type is float16
325 | for micro_step in range(gradient_accumulation_steps):
326 | if ddp:
327 | # in DDP training we only need to sync gradients at the last micro step.
328 | # the official way to do this is with model.no_sync() context manager, but
329 | # I really dislike that this bloats the code and forces us to repeat code
330 | # looking at the source of that context manager, it just toggles this variable
331 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
332 | with ctx:
333 | logits, loss = model(X, Y)
334 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
335 | X, Y = get_batch('train')
336 | # backward pass, with gradient scaling if training in fp16
337 | scaler.scale(loss).backward()
338 | # clip the gradient
339 | if grad_clip != 0.0:
340 | scaler.unscale_(optimizer.optimizers[0])
341 | scaler.unscale_(optimizer.optimizers[1])
342 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
343 | if total_norm.item() > grad_clip:
344 | clip_time += 1
345 | # step the optimizer and scaler if training in fp16
346 | scaler.step(optimizer)
347 | scaler.update()
348 | # flush the gradients as soon as we can, no need for this memory anymore
349 | optimizer.zero_grad(set_to_none=True)
350 |
351 | # timing and logging
352 | t1 = time.time()
353 | dt = t1 - t0
354 | t0 = t1
355 | if iter_num % log_interval == 0 and master_process:
356 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
357 | if local_iter_num >= 5: # let the training loop settle a bit
358 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
359 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
360 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
361 | params = []
362 | for (name, p) in model.named_parameters():
363 | params.append(p)
364 | total_param_norm = 0
365 | for p in params:
366 | param_norm = p.data.norm(2)
367 | total_param_norm += param_norm.item() ** 2
368 | total_param_norm = total_param_norm ** 0.5
369 | momentum_norm = 0
370 | momentum_norm_sq = 0
371 | momentum_div = 0
372 | LL = len(optimizer.optimizers[0].state_dict()['state'])
373 | for jj in range(LL):
374 | momentum_norm += (optimizer.optimizers[0].state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2
375 | momentum_norm_sq += (optimizer.optimizers[0].state_dict()['state'][jj]['exp_avg_sq'].detach().norm(2)) ** 2
376 | momentum_norm = torch.sqrt(momentum_norm).item()
377 | momentum_norm_sq = torch.sqrt(momentum_norm_sq).item()
378 | momentum_div = momentum_norm/(np.sqrt(momentum_norm_sq)+1e-8)
379 | if wandb_log:
380 | wandb.log({
381 | "iter": iter_num,
382 | "train/loss": lossf,
383 | "lr": lr,
384 | "param_norm": total_param_norm,
385 | "momentum_norm" : momentum_norm,
386 | "momentum_norm_sq": momentum_norm_sq,
387 | "momentum_div": momentum_div,
388 | "train/clip_rate": clip_time / (iter_num + 1)
389 | }, step=iter_num)
390 | iter_num += 1
391 | local_iter_num += 1
392 |
393 | # termination conditions
394 | if iter_num > max_iters:
395 | break
396 |
397 | if ddp:
398 | destroy_process_group()
399 |
--------------------------------------------------------------------------------
/MARS/utils/configurator.py:
--------------------------------------------------------------------------------
1 | """
2 | Poor Man's Configurator. Probably a terrible idea. Example usage:
3 | $ python train.py config/override_file.py --batch_size=32
4 | this will first run config/override_file.py, then override batch_size to 32
5 |
6 | The code in this file will be run as follows from e.g. train.py:
7 | >>> exec(open('configurator.py').read())
8 |
9 | So it's not a Python module, it's just shuttling this code away from train.py
10 | The code in this script then overrides the globals()
11 |
12 | I know people are not going to love this, I just really dislike configuration
13 | complexity and having to prepend config. to every single variable. If someone
14 | comes up with a better simple Python solution I am all ears.
15 | """
16 |
17 | import sys
18 | from ast import literal_eval
19 |
20 | for arg in sys.argv[1:]:
21 | if '=' not in arg:
22 | # assume it's the name of a config file
23 | assert not arg.startswith('--')
24 | config_file = arg
25 | print(f"Overriding config with {config_file}:")
26 | with open(config_file) as f:
27 | print(f.read())
28 | exec(open(config_file).read())
29 | else:
30 | # assume it's a --key=value argument
31 | assert arg.startswith('--')
32 | key, val = arg.split('=')
33 | key = key[2:]
34 | if key in globals():
35 | try:
36 |                 # attempt to eval it (e.g. if bool, number, etc.)
37 | attempt = literal_eval(val)
38 | except (SyntaxError, ValueError):
39 | # if that goes wrong, just use the string
40 | attempt = val
41 | # ensure the types match ok
42 | assert type(attempt) == type(globals()[key])
43 | # cross fingers
44 | print(f"Overriding: {key} = {attempt}")
45 | globals()[key] = attempt
46 | else:
47 | raise ValueError(f"Unknown config key: {key}")
48 |
--------------------------------------------------------------------------------
/MARS/utils/cv_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | import torchvision
5 | from torchvision import datasets, transforms
6 | from torch.utils.data import DataLoader
7 |
8 | def get_model(args):
9 |     """
10 |     Build the model specified by args.net, including:
11 |     - simple_cnn (or any torchvision model name, e.g. vgg16)
12 |     - resnet18
13 |     adapted from https://github.com/iShohei220/adopt/blob/main/adopt.py and https://github.com/uclaml/Padam/blob/master/models/resnet.py
14 |     """
15 | if args.dataset in ['mnist', 'cifar10']:
16 | num_classes = 10
17 | elif args.dataset in ['cifar100']:
18 | num_classes = 100
19 | else:
20 | raise NotImplementedError(f"{args.dataset} is not implemented.")
21 | if args.net == 'simple_cnn':
22 | from .model_CNN import Network
23 | model_config = {
24 | "n_inputs": (3, 32, 32) if args.dataset == "cifar10" else (1, 28, 28),
25 | "conv_layers_list": [
26 | {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True},
27 | {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True},
28 | {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True},
29 | ],
30 | "n_hiddens_list": [512],
31 | "n_outputs": 10,
32 | "dropout": 0.2,
33 | }
34 | model = Network(**model_config)
35 | elif args.net == 'resnet18':
36 | from .model_CNN import ResNet18
37 | model = ResNet18(num_classes = num_classes)
38 | else:
39 | try:
40 | model = torchvision.models.get_model(args.net, num_classes=num_classes)
41 |         except Exception:
42 | print('Model not found')
43 | raise NotImplementedError
44 | return model
45 |
46 | def get_datasets(dataset_name: str, train_batch_size: int, eval_batch_size: int):
47 | """Get train and test dataloaders."""
48 | print('==> Preparing data..')
49 | if dataset_name == "mnist":
50 | transform = transforms.Compose([
51 | transforms.ToTensor(),
52 | transforms.Normalize((0.1307,), (0.3081,))
53 | ])
54 | train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
55 | test_dataset = datasets.MNIST('./data', train=False, transform=transform)
56 | elif dataset_name == "cifar10":
57 | transform_train = transforms.Compose([
58 | transforms.RandomCrop(32, padding=4),
59 | transforms.RandomHorizontalFlip(),
60 | transforms.ToTensor(),
61 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
62 | ])
63 | transform_test = transforms.Compose([
64 | transforms.ToTensor(),
65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
66 | ])
67 | train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
68 | test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test)
69 | elif dataset_name == "cifar100":
70 | transform_train = transforms.Compose([
71 | transforms.RandomCrop(32, padding=4),
72 | transforms.RandomHorizontalFlip(),
73 | transforms.ToTensor(),
74 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]),
75 | ])
76 | transform_test = transforms.Compose([
77 | transforms.ToTensor(),
78 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]),
79 | ])
80 | train_dataset = datasets.CIFAR100('./data', train=True, download=True, transform=transform_train)
81 | test_dataset = datasets.CIFAR100('./data', train=False, transform=transform_test)
82 | else:
83 | raise NotImplementedError(f"{dataset_name=} is not implemented.")
84 |
85 | train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=4)
86 | test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=4)
87 |
88 | return train_loader, test_loader
89 |
90 |
91 | class WarmupCosineScheduler:
92 | """Custom learning rate scheduler with linear warmup and cosine decay."""
93 | def __init__(self, optimizer, warmup_iters: int, total_iters: int, min_lr=0.):
94 | self.optimizer = optimizer
95 | self.warmup_iters = warmup_iters
96 | self.total_iters = total_iters
97 | self.min_lr = min_lr
98 | self.max_lr_list = []
99 | for param_group in self.optimizer.param_groups:
100 | self.max_lr_list.append(param_group['lr'])
101 | self.current_iter = 0
102 | self.lr_list = []
103 | for param_group in self.optimizer.param_groups:
104 | self.lr_list.append(param_group['lr'])
105 |
106 | def step(self):
107 | self.current_iter += 1
108 | lr_list = []
109 | cnt = 0
110 | for param_group in self.optimizer.param_groups:
111 | max_lr = self.max_lr_list[cnt]
112 | if self.current_iter <= self.warmup_iters:
113 | lr = self.current_iter / self.warmup_iters * max_lr
114 | else:
115 | lr = self.min_lr + 0.5 * (max_lr - self.min_lr) * (
116 | np.cos((self.current_iter - self.warmup_iters) / (self.total_iters - self.warmup_iters) * 3.14159265 / 2)
117 | ).item()
118 | param_group['lr'] = lr
119 | cnt += 1
120 | lr_list.append(lr)
121 | self.lr_list = lr_list
122 | def get_lr(self):
123 | lr_list = []
124 | for param_group in self.optimizer.param_groups:
125 | lr_list.append(param_group['lr'])
126 | return lr_list
127 |
128 | class ConstantScheduler:
129 | """Constant learning rate scheduler."""
130 | def __init__(self, optimizer, lr: float):
131 | self.optimizer = optimizer
132 | lr_list = []
133 | for param_group in self.optimizer.param_groups:
134 | lr_list.append(lr)
135 |
136 | def step(self):
137 | pass
138 |
139 | def get_lr(self):
140 | lr_list = []
141 | for param_group in self.optimizer.param_groups:
142 | lr_list.append(param_group['lr'])
143 | return lr_list
144 |
145 | def get_scheduler(optimizer, args):
146 | if args.scheduler == 'multistep':
147 | from torch.optim.lr_scheduler import MultiStepLR
148 | scheduler = MultiStepLR(optimizer, milestones=[args.Nepoch // 2, (args.Nepoch * 3) // 4], gamma=0.1)
149 | elif args.scheduler == 'cosine':
150 | scheduler = WarmupCosineScheduler(optimizer, warmup_iters = args.Nepoch // 10, total_iters = args.Nepoch,
151 | min_lr = 0.)
152 | elif args.scheduler == 'constant':
153 | scheduler = ConstantScheduler(optimizer, lr = args.lr)
154 |     else:
155 |         raise NotImplementedError(f"{args.scheduler=} is not implemented.")
156 |     return scheduler
--------------------------------------------------------------------------------
/MARS/utils/model_CNN.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Type, Union
2 | import importlib
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | def pair(t):
9 | return t if isinstance(t, tuple) else (t, t)
10 |
11 |
12 | def get_activation(activation_f: str) -> Type:
13 | """Get PyTorch activation function by name."""
14 | package_name = "torch.nn"
15 | module = importlib.import_module(package_name)
16 |
17 | activations = [getattr(module, attr) for attr in dir(module)]
18 | activations = [
19 | cls for cls in activations if isinstance(cls, type) and issubclass(cls, nn.Module)
20 | ]
21 | names = [cls.__name__.lower() for cls in activations]
22 |
23 | try:
24 | index = names.index(activation_f.lower())
25 | return activations[index]
26 | except ValueError:
27 | raise NotImplementedError(f"get_activation: {activation_f=} is not yet implemented.")
28 |
29 |
30 | def compute_padding(
31 | input_size: tuple, kernel_size: int | tuple, stride: int | tuple = 1, dilation: int | tuple = 1
32 | ) -> Tuple[int, int]:
33 | """Compute padding for 'same' convolution."""
34 | if len(input_size) == 2:
35 | input_size = (*input_size, 1)
36 | if isinstance(kernel_size, int):
37 | kernel_size = (kernel_size, kernel_size)
38 | if isinstance(stride, int):
39 | stride = (stride, stride)
40 | if isinstance(dilation, int):
41 | dilation = (dilation, dilation)
42 |
43 | input_h, input_w, _ = input_size
44 | kernel_h, kernel_w = kernel_size
45 | stride_h, stride_w = stride
46 | dilation_h, dilation_w = dilation
47 |
48 | # Compute the effective kernel size after dilation
49 | effective_kernel_h = (kernel_h - 1) * dilation_h + 1
50 | effective_kernel_w = (kernel_w - 1) * dilation_w + 1
51 |
52 | # Compute the padding needed for same convolution
53 | pad_h = int(max((input_h - 1) * stride_h + effective_kernel_h - input_h, 0))
54 | pad_w = int(max((input_w - 1) * stride_w + effective_kernel_w - input_w, 0))
55 |
56 | # Compute the padding for each side
57 | pad_top = pad_h // 2
58 | pad_left = pad_w // 2
59 |
60 | return (pad_top, pad_left)
61 |
62 |
63 | class Base(nn.Module):
64 | """Base class for neural network models."""
65 | def __init__(self, **kwargs):
66 | super().__init__()
67 | self.__dict__.update(kwargs)
68 |
69 | @property
70 | def num_params(self):
71 | return sum(p.numel() for p in self.parameters())
72 |
73 | @property
74 | def shapes(self):
75 | return {name: p.shape for name, p in self.named_parameters()}
76 |
77 | def summary(self):
78 | print(self)
79 | print(f"Number of parameters: {self.num_params}")
80 |
81 |
82 | class Network(Base):
83 | """Fully Connected / Convolutional Neural Network
84 |
85 | Args:
86 | n_inputs (Union[List[int], Tuple[int], torch.Size]): Input shape
87 | n_outputs (int): Number of output classes
88 | conv_layers_list (List[dict], optional): List of convolutional layers. Defaults to [].
89 | n_hiddens_list (Union[List, int], optional): List of hidden units. Defaults to 0.
90 | activation_f (str, optional): Activation function. Defaults to "ReLU".
91 | dropout (float, optional): Dropout rate. Defaults to 0.0.
92 |
93 | conv_layers_list dict keys:
94 | filters: int
95 | kernel_size: int
96 | stride: int
97 | dilation: int
98 | padding: int
99 | bias: bool
100 | batch_norm: bool
101 | repeat: int
102 | """
103 | def __init__(
104 | self,
105 | n_inputs: Union[List[int], Tuple[int], torch.Size],
106 | n_outputs: int,
107 | conv_layers_list: List[dict] = [],
108 | n_hiddens_list: Union[List, int] = 0,
109 | activation_f: str = "ReLU",
110 | dropout: float = 0.0,
111 | ):
112 | super().__init__()
113 |
114 | if isinstance(n_hiddens_list, int):
115 | n_hiddens_list = [n_hiddens_list]
116 |
117 | if n_hiddens_list == [] or n_hiddens_list == [0]:
118 | self.n_hidden_layers = 0
119 | else:
120 | self.n_hidden_layers = len(n_hiddens_list)
121 |
122 | activation = get_activation(activation_f)
123 |
124 | # Convert n_inputs to tensor for shape calculations
125 | ni = torch.tensor(n_inputs)
126 |
127 | conv_layers = []
128 | if conv_layers_list:
129 | for conv_layer in conv_layers_list:
130 | n_channels = int(ni[0])
131 |
132 | padding = conv_layer.get(
133 | "padding",
134 | compute_padding( # same padding
135 | tuple(ni.tolist()),
136 | conv_layer["kernel_size"],
137 | conv_layer.get("stride", 1),
138 | conv_layer.get("dilation", 1),
139 | ),
140 | )
141 |
142 | # Add repeated conv blocks
143 | for i in range(conv_layer.get("repeat", 1)):
144 | # Convolutional layer
145 | conv_layers.append(
146 | nn.Conv2d(
147 | n_channels if i == 0 else conv_layer["filters"],
148 | conv_layer["filters"],
149 | conv_layer["kernel_size"],
150 | stride=conv_layer.get("stride", 1),
151 | padding=padding,
152 | dilation=conv_layer.get("dilation", 1),
153 | bias=conv_layer.get("bias", True),
154 | )
155 | )
156 |
157 | # Activation
158 | conv_layers.append(activation())
159 |
160 | # Optional batch norm
161 | if conv_layer.get("batch_norm"):
162 | conv_layers.append(nn.BatchNorm2d(conv_layer["filters"]))
163 |
164 | # Max pooling after each conv block
165 | conv_layers.append(nn.MaxPool2d(2, stride=2))
166 |
167 | # Optional dropout
168 | if dropout > 0:
169 | conv_layers.append(nn.Dropout(dropout))
170 |
171 | # Update input shape for next layer
172 | ni = torch.cat([torch.tensor([conv_layer["filters"]]), ni[1:] // 2])
173 |
174 | self.conv = nn.Sequential(*conv_layers)
175 |
176 | # Fully connected layers
177 | ni = int(torch.prod(ni))
178 | fcn_layers = []
179 | if self.n_hidden_layers > 0:
180 | for _, n_units in enumerate(n_hiddens_list):
181 | fcn_layers.extend([
182 | nn.Linear(ni, n_units),
183 | activation()
184 | ])
185 | if dropout > 0:
186 | fcn_layers.append(nn.Dropout(dropout))
187 | ni = n_units
188 |
189 | self.fcn = nn.Sequential(*fcn_layers)
190 | self.output = nn.Linear(ni, n_outputs)
191 |
192 | def forward(self, x: torch.Tensor) -> torch.Tensor:
193 | x = self.conv(x)
194 | x = x.view(x.size(0), -1)
195 | x = self.fcn(x)
196 | return self.output(x)
197 |
198 | '''ResNet in PyTorch.
199 |
200 | For Pre-activation ResNet, see 'preact_resnet.py'.
201 |
202 | Reference:
203 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
204 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
205 | '''
206 |
207 |
208 |
209 | class BasicBlock(nn.Module):
210 | expansion = 1
211 |
212 | def __init__(self, in_planes, planes, stride=1):
213 | super(BasicBlock, self).__init__()
214 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
215 | self.bn1 = nn.BatchNorm2d(planes)
216 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
217 | self.bn2 = nn.BatchNorm2d(planes)
218 |
219 | self.shortcut = nn.Sequential()
220 | if stride != 1 or in_planes != self.expansion*planes:
221 | self.shortcut = nn.Sequential(
222 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
223 | nn.BatchNorm2d(self.expansion*planes)
224 | )
225 |
226 | def forward(self, x):
227 | out = F.relu(self.bn1(self.conv1(x)))
228 | out = self.bn2(self.conv2(out))
229 | out += self.shortcut(x)
230 | out = F.relu(out)
231 | return out
232 |
233 |
234 | class Bottleneck(nn.Module):
235 | expansion = 4
236 |
237 | def __init__(self, in_planes, planes, stride=1):
238 | super(Bottleneck, self).__init__()
239 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
240 | self.bn1 = nn.BatchNorm2d(planes)
241 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
242 | self.bn2 = nn.BatchNorm2d(planes)
243 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
244 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
245 |
246 | self.shortcut = nn.Sequential()
247 | if stride != 1 or in_planes != self.expansion*planes:
248 | self.shortcut = nn.Sequential(
249 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
250 | nn.BatchNorm2d(self.expansion*planes)
251 | )
252 |
253 | def forward(self, x):
254 | out = F.relu(self.bn1(self.conv1(x)))
255 | out = F.relu(self.bn2(self.conv2(out)))
256 | out = self.bn3(self.conv3(out))
257 | out += self.shortcut(x)
258 | out = F.relu(out)
259 | return out
260 |
261 |
262 | class ResNet(nn.Module):
263 | def __init__(self, block, num_blocks, num_classes=10):
264 | super(ResNet, self).__init__()
265 | self.in_planes = 64
266 |
267 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
268 | self.bn1 = nn.BatchNorm2d(64)
269 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
270 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
271 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
272 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
273 | self.linear = nn.Linear(512*block.expansion, num_classes)
274 |
275 | def _make_layer(self, block, planes, num_blocks, stride):
276 | strides = [stride] + [1]*(num_blocks-1)
277 | layers = []
278 | for stride in strides:
279 | layers.append(block(self.in_planes, planes, stride))
280 | self.in_planes = planes * block.expansion
281 | return nn.Sequential(*layers)
282 |
283 | def forward(self, x):
284 | out = F.relu(self.bn1(self.conv1(x)))
285 | out = self.layer1(out)
286 | out = self.layer2(out)
287 | out = self.layer3(out)
288 | out = self.layer4(out)
289 | out = F.avg_pool2d(out, 4)
290 | out = out.view(out.size(0), -1)
291 | out = self.linear(out)
292 | return out
293 |
294 |
295 | def ResNet18(num_classes = 10):
296 | return ResNet(BasicBlock, [2,2,2,2], num_classes = num_classes)
297 |
298 | def ResNet34(num_classes = 10):
299 | return ResNet(BasicBlock, [3,4,6,3], num_classes = num_classes)
300 |
301 | def ResNet50(num_classes = 10):
302 | return ResNet(Bottleneck, [3,4,6,3], num_classes = num_classes)
303 |
304 | def ResNet101(num_classes = 10):
305 | return ResNet(Bottleneck, [3,4,23,3], num_classes = num_classes)
306 |
307 | def ResNet152(num_classes = 10):
308 | return ResNet(Bottleneck, [3,8,36,3], num_classes = num_classes)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MARS: Unleashing the Power of Variance Reduction for Training Large Models
2 |
3 | This repository contains the official code for the paper [MARS: Unleashing the Power of Variance Reduction for Training Large Models](https://arxiv.org/abs/2411.10438).
4 |
5 | Authors: [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Yifeng Liu](https://scholar.google.com/citations?user=mFvOVkMAAAAJ&hl=zh-CN)\*, Shuang Wu, Xun Zhou, [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)
6 |
7 | ## 🔔 NEWS
8 | - **[05/01/2025]** Our paper is accepted by **ICML 2025** 🎉🎉.
9 | - **[02/10/2025]** Our paper is updated on ArXiv: https://arxiv.org/pdf/2411.10438v2.
10 | - **[01/12/2025]** Update scripts for reproducing GPT-2 XL results and FineWeb-Edu results.
11 | - **[01/12/2025]** Our pretraining results on FineWeb-Edu are available. GPT-2 XL reaches a Hellaswag accuracy of 56.52 in 50B tokens.
12 | - **[11/26/2024]** Vision tasks added.
13 | - **[11/18/2024]** Our code is open-sourced!
14 | - **[11/15/2024]** Our paper is released on arXiv: https://arxiv.org/abs/2411.10438.
15 |
16 | ## About MARS
17 |
18 | **MARS** (**M**ake v**A**riance **R**eduction **S**hine) is a unified optimization framework designed to address the inherent challenges of training large models. Traditional adaptive gradient methods like Adam and AdamW often suffer from high stochastic gradient variance, while variance reduction techniques have struggled to gain practical impact in deep learning. At its core, **MARS** comprises two major components: (1) a scaled stochastic recursive momentum, which provides a variance-reduced estimator of the full gradient for better gradient complexity; and (2) the preconditioned update, which approximates the second-order Newton's method for better per-iteration complexity. By combining preconditioned gradient methods with variance reduction, **MARS** achieves the best of both worlds, accelerating the search for critical points in optimization.
19 |
20 | The **MARS** framework is built on the following preconditioned variance-reduced updates
21 |
22 | $$
23 | \mathbf{c}\_t = \nabla f(\mathbf{x}\_t, \mathbf{\xi}\_t)+\underbrace{{\color{red}\gamma_t} \frac{\beta_{1}}{1-\beta_{1}} \left(\nabla f(\mathbf{x}\_t, \mathbf{\xi}\_t)-\nabla f(\mathbf{x}\_{t-1}, \mathbf{\xi}\_t)\right)}_{\text{scaled gradient correction}}
24 | $$
25 |
26 | $$
27 | \tilde{\mathbf{c}}_t = \text{Clip}(\mathbf{c}_t,1) = \begin{cases}
28 | \frac{\mathbf{c}_t}{\\|\mathbf{c}_t\\|_2} & \text{if } \\|\mathbf{c}_t\\|_2 > 1,\\
29 | \mathbf{c}_t & \text{otherwise}.
30 | \end{cases}
31 | $$
32 |
33 | $$
34 | \mathbf{m}\_t = \beta_1 \mathbf{m}\_{t-1} + (1-\beta_{1})\tilde{\mathbf{c}}\_t
35 | $$
36 |
37 | $$
38 | \mathbf{x}\_{t+1} = \arg\min_{\mathbf{x} \in \mathbb{R}^d} \left\\{\eta_t \left\langle \mathbf{m}_t, \mathbf{x} \right\rangle + \frac{1}{2} \\|\mathbf{x} - \mathbf{x}\_t
39 | \\|\_{\mathbf{H}_t}^2\right\\}
40 | $$
41 |
42 | Here ${\color{red}\gamma_t}$ is a scaling parameter that controls the strength of gradient correction.
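
For concreteness, here is a minimal, single-tensor sketch of one MARS update written directly from the equations above. It is illustrative only (it is not the repository's `optimizers/mars.py`): the helper name `mars_step` is ours, and it omits the preconditioner $\mathbf{H}\_t$, weight decay, and bias correction.

```python
import torch

def mars_step(x, m, grad, prev_grad, lr, gamma_t, beta1=0.95):
    """One simplified MARS update on a single parameter tensor (H_t = I)."""
    # c_t = g_t + gamma_t * beta1 / (1 - beta1) * (g_t - g_{t-1})  (scaled gradient correction)
    c = grad + gamma_t * (beta1 / (1.0 - beta1)) * (grad - prev_grad)
    # clip c_t to unit l2 norm
    c_norm = c.norm()
    if c_norm > 1.0:
        c = c / c_norm
    # m_t = beta1 * m_{t-1} + (1 - beta1) * c_t
    m.mul_(beta1).add_(c, alpha=1.0 - beta1)
    # with H_t = I, the argmin update reduces to a plain momentum step
    x.add_(m, alpha=-lr)
    return x, m
```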
43 |
44 | ### Instantiations of **MARS**
45 |
46 | Under the **MARS** framework, we provide three instantiations based on different Hessian matrix approximations: **MARS-AdamW**, **MARS-Lion**, and **MARS-Shampoo**. Please note that the hyperparameters in this framework are tuned on **MARS-AdamW**. When using other instantiations, it is essential to tune the hyperparameters—particularly the learning rates—for optimal performance.
47 |
48 | #### MARS-AdamW
49 |
50 | (Enable with `mars_type="mars-adamw"` in `mars.py`)
51 |
52 | The Hessian matrix approximation is defined as:
53 |
54 | $$
55 | \mathbf{v}\_t =\beta_2 \mathbf{v}\_{t-1}+(1-\beta_2) \big(\nabla f(\mathbf{x}\_t, \mathbf{\xi}\_t)\big)^2
56 | $$
57 |
58 | $$
59 | \mathbf{H}_t := \sqrt{\text{diag}\Big(\mathbf{v}_t\Big)}\cdot \frac{1 - \beta_1^t}{\sqrt{1 - \beta_2^t}}.
60 | $$
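
Plugging this $\mathbf{H}\_t$ into the update rule for $\mathbf{x}\_{t+1}$ above recovers the familiar AdamW-style step (up to the $\epsilon$ term and weight decay used in practice):

$$
\mathbf{x}\_{t+1} = \mathbf{x}\_t - \eta_t \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{\mathbf{m}\_t}{\sqrt{\mathbf{v}\_t}}.
$$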
61 |
62 | #### MARS-Lion
63 |
64 | (Enable with `mars_type="mars-lion"` in `mars.py`)
65 |
66 | The Hessian matrix approximation is defined as:
67 |
68 | $$
69 | \mathbf{H}_t := \sqrt{\text{diag}(\mathbf{m}_t^2)}.
70 | $$
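
Since $\sqrt{\text{diag}(\mathbf{m}\_t^2)} = \text{diag}(|\mathbf{m}\_t|)$, this choice reduces the update to a Lion-style sign step:

$$
\mathbf{x}\_{t+1} = \mathbf{x}\_t - \eta_t \cdot \text{sign}(\mathbf{m}\_t).
$$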
71 |
72 | #### MARS-Shampoo
73 |
74 | (Enable with `mars_type="mars-shampoo"` in `mars.py`)
75 |
76 | The preconditioner can be seen as an [orthogonal mapping](https://arxiv.org/abs/2409.20325) operator:
77 |
78 | $$
79 | \mathbf{U}\_t, \mathbf{\Sigma}\_t, \mathbf{V}\_t = \text{SVD}(\mathbf{G}\_t),\qquad
80 | \mathbf{x}\_{t+1} =\mathbf{x}\_t-\eta_t\mathbf{U}_t\mathbf{V}\_t^\top.
81 | $$
82 |
83 | In practice, we use the [Newton-Schulz iteration](https://github.com/KellerJordan/modded-nanogpt) to efficiently approximate the solution of the SVD problem.
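
Shown below is a minimal sketch of that orthogonalization step, assuming the quintic Newton-Schulz iteration popularized by modded-nanogpt; the function name and coefficient values are illustrative, not the exact code in `optimizers/muon.py` or `optimizers/mars.py`.

```python
import torch

@torch.no_grad()
def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5, eps: float = 1e-7):
    """Approximate U V^T for G = U S V^T without computing an explicit SVD."""
    a, b, c = 3.4445, -4.7750, 2.0315      # commonly used tuned quintic coefficients
    X = G / (G.norm() + eps)               # normalize so the spectral norm is <= 1
    transposed = X.size(0) > X.size(1)
    if transposed:                         # iterate on the smaller Gram matrix
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X
```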
84 |
85 | ### **Performance of MARS Compared to Baselines**
86 |
87 | #### Experiments on OpenWebText
88 |
89 | Experimental results for **MARS** are based on the **MARS-AdamW** instantiation, unless otherwise stated. In our experiments, gradients are calculated once per sample and per update (**MARS**-approx in our [paper](https://arxiv.org/abs/2411.10438)). Performing exact gradient computation with two evaluations per update, as in the exact form of **MARS**, can slightly enhance performance but at the cost of doubling the computational expense. For more details, refer to our [paper](https://arxiv.org/abs/2411.10438).
90 |
91 | **MARS** consistently outperforms the AdamW and [Muon](https://github.com/KellerJordan/modded-nanogpt/tree/e01b457c7c52e1cd0c592920499a016f5289a69e) optimizers across GPT-2 models:
92 |
93 | | **GPT-2 small** | **GPT-2 medium** | **GPT-2 large** |
94 | | ------------------------------------------------ | ------------------------------------------------- | ------------------------------------------------ |
95 | | *(validation loss curves)* | *(validation loss curves)* | *(validation loss curves)* |
96 |
97 | | Best Val Loss | GPT-2 Small (5B tokens) | GPT-2 Medium (5B tokens) | GPT-2 Large (5B tokens) | GPT-2 Small (20B tokens) | GPT-2 Medium (20B tokens) | GPT-2 Large (20B tokens) | GPT-2 Small (50B tokens) | GPT-2 Medium (50B tokens) | GPT-2 Large (50B tokens) |
98 | | --------------------- | ----------------------- | ------------------------ | ----------------------- | ------------------------ | ------------------------- | ------------------------ | ------------------------ | ------------------------- | ------------------------ |
99 | | AdamW | 3.193 | 3.084 | 3.013 | 3.024 | 2.821 | 2.741 | 2.885 | 2.691 | 2.561 |
100 | | Muon | 3.165 | 3.009 | 2.915 | 3.006 | 2.813 | 2.691 | 2.901 | 2.688 | 2.573 |
101 | | **MARS**-exact | **3.107** | - | - | 2.980 | - | - | **2.847** | - | - |
102 | | **MARS**-approx | 3.108 | **2.969** | **2.876** | **2.981** | **2.763** | **2.647** | **2.849** | **2.636** | **2.518** |
103 |
104 |
105 | #### Efficiency of MARS
106 |
107 | The **MARS** algorithm can achieve better performance not only within the same number of training steps, but also within the same training time:
108 |
109 | | **GPT-2 small** | **GPT-2 medium** | **GPT-2 large** |
110 | | ------------------------------------------------- | -------------------------------------------------- | ------------------------------------------------- |
111 | | *(loss vs. wall-clock training time)* | *(loss vs. wall-clock training time)* | *(loss vs. wall-clock training time)* |
112 |
113 | ---
114 |
115 | #### Experiments on FineWeb-Edu
116 |
117 | Below are the training and validation loss curves for both GPT‑2 Small and GPT‑2 XL when using our MARS approach versus AdamW. As you can see, MARS often yields faster convergence and consistently lower losses across different training steps.
118 |
119 | | Model | **GPT-2 small** | **GPT-2 XL** |
120 | | ----------------------- | -------------------------------------------------------- | --------------------------------------------------------- |
121 | | **Train Loss** | *(training loss curve)* | *(training loss curve)* |
122 | | **Validation Loss** | *(validation loss curve)* | *(validation loss curve)* |
123 |
124 | ##### Evaluation Metrics
125 | Below, we present the evaluation metrics on the FineWeb-Edu dataset for both GPT‑2 Small and GPT‑2 XL, comparing the OpenAI GPT-2 baseline, AdamW, and our MARS-AdamW optimizer.
126 |
127 |
128 |
129 | **Results on GPT-2 small**
130 |
131 | MARS-AdamW shows a clear improvement over AdamW and the OpenAI baseline across multiple tasks, with the **highest average score** of 45.93 on GPT‑2 Small.
132 | | Method/Task | ARC-E | ARC-C | BoolQ | HellaSwag | OBQA | PIQA | WG | MMLU | SciQ | Avg. |
133 | |--------------|-------|-------|-------|-----------|-------|-------|-------|-------|-------|-------|
134 | | OpenAI-Comm. | 39.48 | 22.70 | 48.72 | 31.14 | 27.20 | 62.51 | **51.62** | 22.92 | 64.40 | 41.19 |
135 | | AdamW | 51.43 | 26.54 | 55.78 | 36.26 | 30.60 | 64.53 | 50.36 | **24.49** | **71.50** | 45.72 |
136 | | MARS-AdamW | **52.23** | **27.39** | **55.84** | **36.91** | **32.20** | **64.80** | 49.96 | 22.95 | 71.10 | **45.93** |
137 |
138 | **Results on GPT-2 XL**
139 |
140 | On GPT‑2 XL, MARS-AdamW continues to outperform AdamW across most tasks, reaching a **HellaSwag accuracy of 56.52%**.
141 |
142 | | Method/Task | ARC-E | ARC-C | BoolQ | HellaSwag | OBQA | PIQA | WG | MMLU | SciQ | Avg. |
143 | |--------------|-------|-------|-------|-----------|-------|-------|-------|-------|-------|-------|
144 | | OpenAI-Comm. | 51.05 | 28.50 | 61.77 | 50.89 | 32.00 | 70.51 | **58.33** | 25.24 | 76.00 | 50.48 |
145 | | AdamW | **68.22** | 38.40 | 61.13 | 53.93 | 39.00 | 72.69 | 54.78 | **25.47** | 85.30 | 55.43 |
146 | | MARS-AdamW | 66.54 | **39.85** | **63.82** | **56.52** | **41.20** | **73.34** | 56.59 | 23.86 | **86.00** | **56.41** |
147 |
148 | ---
149 |
150 | #### Experiments on Vision Tasks
151 |
152 | **MARS** also achieves better test loss and accuracy than the AdamW and [Muon](https://github.com/KellerJordan/modded-nanogpt/tree/e01b457c7c52e1cd0c592920499a016f5289a69e) optimizers on CIFAR-10 and CIFAR-100 with ResNet-18 and a `MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)` scheduler. For each optimizer we report the best result from a grid search over base learning rates in [1e-5, ..., 1e-1]; a sketch of this setup appears at the end of this subsection:
153 |
154 | | Dataset | **CIFAR-10** | **CIFAR-100** |
155 | | ----------------------- | -------------------------------------------------------- | --------------------------------------------------------- |
156 | | **Test loss** |  |  |
157 | | **Test Accuracy** |  |  |
158 |
159 | | Best Test loss | CIFAR-10 | CIFAR-100 |
160 | | --------------------- | ---------- | ---------- |
161 | | AdamW | 0.306 | 2.608 |
162 | | Muon | 0.230 | 1.726 |
163 | | **MARS**-approx | **0.199** | **0.971** |
164 |
165 | | Best Test Accuracy (%) | CIFAR-10 | CIFAR-100 |
166 | | ---------------------- | --------------- | --------------- |
167 | | AdamW | 94.81 | 73.7 |
168 | | Muon | 95.08 | 74.64 |
169 | | **MARS**-approx | **95.29** | **76.97** |
170 |
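A minimal sketch of this vision setup, assuming the `MARS` import path from the "Customized Training" template below; the 200-epoch count and the `train_one_epoch` helper are placeholders, and the base learning rate shown is just one point of the grid:

```python
import torchvision
from torch.optim.lr_scheduler import MultiStepLR
from mars import MARS

model = torchvision.models.resnet18(num_classes=10).cuda()   # num_classes=100 for CIFAR-100
optimizer = MARS(model.parameters(), lr=1e-3, betas=(0.95, 0.99), gamma=0.025)
scheduler = MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

for epoch in range(200):                 # epoch count is an assumption
    train_one_epoch(model, optimizer)    # standard CIFAR training pass (placeholder helper)
    scheduler.step()                     # decay the learning rate by 10x after epochs 100 and 150
```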
171 |
172 | ## Training GPT-2 from Scratch
173 |
174 | ### Install Dependencies
175 |
176 | ```
177 | $ pip install torch==2.1.2 transformers==4.33.0 datasets tiktoken numpy==1.26.4 wandb
178 | ```
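
A quick way to confirm that the pinned versions are the ones actually installed:

```python
import numpy, torch, transformers

print(torch.__version__)         # expected 2.1.2 (possibly with a +cu... suffix)
print(transformers.__version__)  # expected 4.33.0
print(numpy.__version__)         # expected 1.26.4
```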
179 |
180 | ### Data Preparation
181 |
182 | Prepare the [OpenWebText](https://huggingface.co/datasets/openwebtext) data following [nanoGPT](https://github.com/karpathy/nanoGPT/):
183 |
184 | ```
185 | $ python data/openwebtext/prepare.py
186 | ```
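
The script writes `train.bin` and `val.bin` (flat arrays of uint16 GPT-2 token ids) next to itself. A quick sanity check, assuming the default output location:

```python
import numpy as np

train = np.memmap('data/openwebtext/train.bin', dtype=np.uint16, mode='r')
val = np.memmap('data/openwebtext/val.bin', dtype=np.uint16, mode='r')
print(f"train tokens: {len(train):,}")  # roughly 9B
print(f"val tokens:   {len(val):,}")    # roughly 4M
```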
187 |
188 | ### **Start Training**
189 |
190 | To train a model using the **MARS** optimizer, run the following command:
191 |
192 | ```bash
193 | $ torchrun --standalone --nproc_per_node=8 MARS/train_mars.py config/${your_config_file}
194 | ```
195 |
196 | This command initiates the training of a GPT-2 model on the OpenWebText dataset using the **MARS** optimizer. All relevant hyperparameters—training, model, and optimizer—are specified in the configuration file (`${your_config_file}`). These parameters can be adjusted directly in the configuration file or through the bash script.
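
For example, starting from the GPT-2 Small configuration and overriding a few values from the command line (mirroring the `--key=value` overrides used in the scripts below; `6e-3` is simply the config's own default learning rate):

```bash
$ torchrun --standalone --nproc_per_node=8 \
    MARS/train_mars.py \
    config/train_gpt2_small_mars.py \
    --batch_size=15 \
    --gradient_accumulation_steps=4 \
    --learning_rate=6e-3
```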
197 |
198 | ### **Hyperparameter Details**
199 |
200 | #### **Model Hyperparameters**:
201 |
202 | - **`n_layer`**: Number of transformer layers: 12 for GPT-2 Small, 24 for GPT-2 Medium, 36 for GPT-2 Large
203 | - **`n_head`**: Number of attention heads: 12 for GPT-2 Small, 16 for GPT-2 Medium, 20 for GPT-2 Large
204 | - **`n_embd`**: Embedding dimension: 768 for GPT-2 Small, 1024 for GPT-2 Medium, 1280 for GPT-2 Large
205 |
206 | #### **Optimizer Hyperparameters**:
207 |
208 | - **`learning_rate`**: Learning rate for the **MARS** optimizer.
209 | - **`weight_decay`**: Weight decay for the **MARS** optimizer.
210 | - **`beta1, beta2`**: Weights for exponential moving average.
211 | - Default: `beta1=0.95, beta2=0.99`
212 | - **`mars_type`**: Type of optimizer to use:
213 | - Options: `mars-adamw`, `mars-lion`, `mars-shampoo`
214 | - Default: `mars-adamw`
215 | - **`optimize_1d`**: Whether **MARS** should optimize 1D parameters (e.g., layer norm parameters in GPT-2).
216 | - If `False`, AdamW will be used for optimizing 1D parameters.
217 | - Default: `False`
218 | - **`lr_1d`**: Learning rate for AdamW when **`optimize_1d`** is set to `False`.
219 | - **`betas_1d`**: Weights for exponential moving average in AdamW optimizer.
220 | - Default: `(0.9, 0.95)`
221 | - **`is_approx`**: Whether to use approximate gradient calculation (**MARS**-approx).
222 | - Default: `True`
223 | - **`gamma`**: The scaling parameter that controls the strength of gradient correction.
224 | - Default: 0.025
225 |
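A sketch of how the options above map onto the optimizer, using the GPT-2 Small defaults. Treating every option as a keyword argument of the `MARS` class is an assumption based on this list; see `optimizers/mars.py` for the actual signature.

```python
from mars import MARS  # import path as in the "Customized Training" template below

model = ...  # your GPT-2 module
optimizer = MARS(
    model.parameters(),
    lr=6e-3,                  # learning_rate for the matrix (2D) parameters
    weight_decay=1e-2,
    betas=(0.95, 0.99),       # beta1, beta2
    mars_type="mars-adamw",
    optimize_1d=False,        # 1D parameters (e.g. layer norms) fall back to AdamW ...
    lr_1d=3e-3,               # ... trained with this learning rate
    betas_1d=(0.9, 0.95),
    is_approx=True,           # MARS-approx: one gradient evaluation per update
    gamma=0.025,              # strength of the gradient correction
)
```
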
226 | #### **Training Hyperparameters**:
227 |
228 | - **`batch_size`**: Mini-batch size per device (for example, GPT-2 Small on an A100 GPU typically uses a batch size of 15).
229 | - **`gradient_accumulation_steps`**: Number of gradient accumulation steps, chosen so that the total effective batch size matches the desired scale (for example, a total batch size of 480 corresponds to $15 \times 4 \times 8$ GPUs).
230 | - **`schedule`**: Learning rate schedule.
231 | - Default: `cosine`
232 |
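Putting the three values together for the GPT-2 Small setup (a quick arithmetic check, not code from the repo):

```python
batch_size = 15                   # per-device mini-batch
gradient_accumulation_steps = 4
n_gpus = 8                        # nproc_per_node
block_size = 1024                 # sequence length

sequences_per_step = batch_size * gradient_accumulation_steps * n_gpus  # 480
tokens_per_step = sequences_per_step * block_size                       # 491,520
print(sequences_per_step, tokens_per_step)
```
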
233 | For more detailed hyperparameter examples, refer to:
234 |
235 | - `config/train_gpt2_small_mars.py`
236 | - `scripts/run_mars_small.sh`
237 |
238 | ---
239 |
240 | ### Reproducing Our Results
241 |
242 | #### **Reproducing GPT-2 Small (125M) Results**
243 |
244 | Train with MARS using
245 |
246 | ```
247 | $ bash scripts/run_mars_small.sh
248 | ```
249 |
250 | or
251 |
252 | ```
253 | $ torchrun --standalone --nproc_per_node=8 \
254 | MARS/train_mars.py \
255 | config/train_gpt2_small_mars.py \
256 | --batch_size=15 \
257 | --gradient_accumulation_steps=4
258 | ```
259 |
260 | #### Reproducing GPT2 Medium (355M) Results
261 |
262 | Train with MARS using
263 |
264 | ```
265 | $ bash scripts/run_mars_medium.sh
266 | ```
267 |
268 | or
269 |
270 | ```
271 | $ torchrun --standalone --nproc_per_node=8 \
272 | MARS/train_mars.py \
273 | config/train_gpt2_medium_mars.py \
274 | --batch_size=15 \
275 | --gradient_accumulation_steps=4
276 | ```
277 |
278 | #### Reproducing GPT2 Large (770M) Results
279 |
280 | Train with MARS using
281 |
282 | ```
283 | $ bash scripts/run_mars_large.sh
284 | ```
285 |
286 | or
287 |
288 | ```
289 | $ torchrun --standalone --nproc_per_node=8 \
290 | MARS/train_mars.py \
291 | config/train_gpt2_large_mars.py \
292 | --batch_size=5 \
293 | --gradient_accumulation_steps=12
294 | ```
295 |
296 | #### **Reproducing GPT-2 XL (1.5B) Results on FineWeb-Edu**
297 | ```
298 | $ bash scripts/run_mars_xl_fw.sh
299 | ```
300 |
301 | or
302 |
303 | ```
304 | $ torchrun --standalone --nproc_per_node=8 \
305 | MARS/train_mars_fw.py \
306 | config/train_gpt2_xl_mars.py \
307 | --batch_size=5 \
308 | --gradient_accumulation_steps=12
309 | ```
310 |
311 | #### Reproducing Baseline Results
312 |
313 | To reproduce the AdamW baseline:
314 |
315 | ```
316 | $ bash scripts/run_adamw_{small/medium/large}.sh
317 | ```
318 | To reproduce the AdamW baseline on FineWeb-Edu:
319 | ```
320 | $ bash scripts/run_adamw_{small/xl}_fw.sh
321 | ```
322 |
323 | To reproduce the Muon baseline following [modded-nanogpt](https://github.com/KellerJordan/modded-nanogpt/tree/e01b457c7c52e1cd0c592920499a016f5289a69e):
324 |
325 | ```
326 | $ bash scripts/run_muon_{small/medium/large}.sh
327 | ```
328 |
329 | Please adjust ``nproc_per_node``, ``batch_size``, and ``gradient_accumulation_steps`` accordingly if you use a different hardware setup; make sure their product equals 480 (see the example below).
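
For instance, on a 4-GPU machine you could keep the effective batch at 480 like this (an illustration, not a tested configuration):

```bash
$ torchrun --standalone --nproc_per_node=4 \
    MARS/train_mars.py \
    config/train_gpt2_small_mars.py \
    --batch_size=15 \
    --gradient_accumulation_steps=8   # 15 * 8 * 4 = 480
```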
330 |
331 | #### Hyperparameters for GPT-2 models
332 |
333 | | Model Name | Model Size | lr for AdamW | lr for Muon | lr for MARS | lr_1d for MARS | wd for AdamW | wd for Muon | wd for MARS |
334 | | :----------: | :--------: | :----------: | :---------: | :---------: | :------------: | :----------: | :---------: | :---------: |
335 | | GPT-2 small | 125M | 6e-4 | 2e-2 | 6e-3 | 3e-3 | 1e-1 | 0.0 | 1e-2 |
336 | | GPT-2 medium | 355M | 3e-4 | 1e-2 | 3e-3 | 1.5e-3 | 1e-1 | 0.0 | 1e-2 |
337 | | GPT-2 large | 770M | 2e-4 | 6.67e-3 | 2e-3 | 1e-3 | 1e-1 | 0.0 | 1e-2 |
338 | | GPT-2 xl | 1.5B | 2e-4 | - | 2e-3 | 1e-3 | 1e-1 | - | 1e-2 |
339 |
340 |
341 |
342 | ### Customized Training
343 |
344 | To build your own training pipeline on other architectures and datasets, use the following template as an example. Note that `optimizer.update_last_grad()` stores the current gradients so that the next update can form its gradient-correction term:
345 |
346 | ```python
347 | import torch
348 | import torch.nn.functional as F
349 | from mars import MARS
350 |
351 | # init the model, loss function, and input data
352 | model = Model()          # your torch.nn.Module; its forward pass returns (logits, loss)
353 | data_loader = ...        # your DataLoader yielding (X, Y) batches
354 |
355 | # init the optimizer
356 | optimizer = MARS(model.parameters(), lr=1e-3, betas=(0.9, 0.95), gamma=0.025)
357 |
358 | total_bs = len(data_loader)
359 | bs = total_bs * block_size
360 | k = 10
361 | iter_num = -1
362 |
363 | # training loop
364 | for epoch in range(epochs):
365 | for X, Y in data_loader:
366 | # standard training code
367 | logits, loss = model(X, Y)
368 | loss.backward()
369 | optimizer.step(bs=bs)
370 | optimizer.zero_grad(set_to_none=True)
371 | optimizer.update_last_grad()
372 | iter_num += 1
373 |
374 | ```
375 |
376 | ## Star History
377 |
378 | [](https://www.star-history.com/#AGI-Arena/MARS&Date)
379 |
380 | ## Citation
381 |
382 | If you find this repo useful for your research, please consider citing the paper
383 |
384 | ```tex
385 | @article{yuan2024mars,
386 | title={MARS: Unleashing the Power of Variance Reduction for Training Large Models},
387 | author={Yuan, Huizhuo and Liu, Yifeng and Wu, Shuang and Zhou, Xun and Gu, Quanquan},
388 | journal={arXiv preprint arXiv:2411.10438},
389 | year={2024}
390 | }
391 | ```
392 |
393 | ## Acknowledgements
394 |
395 | This repo is built upon [nanoGPT](https://github.com/karpathy/nanoGPT/), [levanter](https://github.com/stanford-crfm/levanter/) and [Sophia](https://github.com/Liuhong99/Sophia), we thank the authors for their great work!
396 |
--------------------------------------------------------------------------------
/assets/MARS-AdamW.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/MARS-AdamW.png
--------------------------------------------------------------------------------
/assets/MARS-Lion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/MARS-Lion.png
--------------------------------------------------------------------------------
/assets/MARS-Shampoo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/MARS-Shampoo.png
--------------------------------------------------------------------------------
/assets/MARS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/MARS.png
--------------------------------------------------------------------------------
/assets/ShampooH.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/ShampooH.png
--------------------------------------------------------------------------------
/assets/cifar100_test_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/cifar100_test_acc.png
--------------------------------------------------------------------------------
/assets/cifar100_test_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/cifar100_test_loss.png
--------------------------------------------------------------------------------
/assets/cifar10_test_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/cifar10_test_acc.png
--------------------------------------------------------------------------------
/assets/cifar10_test_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/cifar10_test_loss.png
--------------------------------------------------------------------------------
/assets/fineweb_hella.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/fineweb_hella.png
--------------------------------------------------------------------------------
/assets/small_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/small_train.png
--------------------------------------------------------------------------------
/assets/small_val.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/small_val.png
--------------------------------------------------------------------------------
/assets/time_large.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/time_large.png
--------------------------------------------------------------------------------
/assets/time_medium.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/time_medium.png
--------------------------------------------------------------------------------
/assets/time_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/time_small.png
--------------------------------------------------------------------------------
/assets/val_large.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/val_large.png
--------------------------------------------------------------------------------
/assets/val_medium.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/val_medium.png
--------------------------------------------------------------------------------
/assets/val_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/val_small.jpg
--------------------------------------------------------------------------------
/assets/val_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/val_small.png
--------------------------------------------------------------------------------
/assets/xl_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/xl_train.png
--------------------------------------------------------------------------------
/assets/xl_val.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/2bec90173df91810d3683de47157e95033209d33/assets/xl_val.png
--------------------------------------------------------------------------------
/config/train_gpt2_large_adamw.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-large-adamw-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 36
10 | n_head = 20
11 | n_embd = 1280
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'adamw'
27 | learning_rate = 2e-4 # max learning rate
28 | weight_decay = 1e-1
29 | beta1 = 0.9
30 | beta2 = 0.95
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 1e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_large_adamw_100k'
40 |
--------------------------------------------------------------------------------
/config/train_gpt2_large_mars.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-large-mars-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 36
10 | n_head = 20
11 | n_embd = 1280
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'mars'
27 | learning_rate = 2e-3 # max learning rate
28 | weight_decay = 1e-2
29 | beta1 = 0.95
30 | beta2 = 0.99
31 | lr_1d=1e-3
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 1e-5
37 |
38 | compile = True
39 |
40 | out_dir = 'out_large_mars_100k'
41 | gamma=0.025
42 |
--------------------------------------------------------------------------------
/config/train_gpt2_large_muon.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-large-muon-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 36
10 | n_head = 20
11 | n_embd = 1280
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'muon'
27 | learning_rate = 1e-3 # max learning rate
28 | weight_decay = 1e-1
29 | muon_learning_rate = 6.67e-3
30 | muon_weight_decay = 0.
31 | beta1 = 0.9
32 | beta2 = 0.95
33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
34 | # learning rate decay settings
35 | decay_lr = True # whether to decay the learning rate
36 | warmup_iters = 2000 # how many steps to warm up for
37 | min_lr = 1e-5
38 |
39 | compile = True
40 |
41 | out_dir = 'out_large_muon_100k'
42 |
--------------------------------------------------------------------------------
/config/train_gpt2_medium_adamw.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-medium-adamw-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 24
10 | n_head = 16
11 | n_embd = 1024
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'adamw'
27 | learning_rate = 3e-4 # max learning rate
28 | weight_decay = 1e-1
29 | beta1 = 0.9
30 | beta2 = 0.95
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 6e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_medium_adamw_100k'
40 |
--------------------------------------------------------------------------------
/config/train_gpt2_medium_mars.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-medium-mars-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 24
10 | n_head = 16
11 | n_embd = 1024
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'mars'
27 | learning_rate = 3e-3 # max learning rate
28 | weight_decay = 1e-2
29 | beta1 = 0.95
30 | beta2 = 0.99
31 | lr_1d=1.5e-3
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 6e-5
37 |
38 | compile = True
39 |
40 | out_dir = 'out_medium_mars_100k'
41 | gamma=0.025
42 |
--------------------------------------------------------------------------------
/config/train_gpt2_medium_muon.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-medium-muon-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 24
10 | n_head = 16
11 | n_embd = 1024
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'muon'
27 | learning_rate = 1.5e-3 # max learning rate
28 | weight_decay = 1e-1
29 | muon_learning_rate = 1e-2
30 | muon_weight_decay = 0.
31 | beta1 = 0.9
32 | beta2 = 0.95
33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
34 | # learning rate decay settings
35 | decay_lr = True # whether to decay the learning rate
36 | warmup_iters = 2000 # how many steps to warm up for
37 | min_lr = 6e-5
38 |
39 | compile = True
40 |
41 | out_dir = 'out_medium_muon_100k'
42 |
--------------------------------------------------------------------------------
/config/train_gpt2_small_adamw.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-small-adamw-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 12
10 | n_head = 12
11 | n_embd = 768
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 |
15 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
16 | max_iters = 100000
17 | lr_decay_iters = 100000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # optimizer
25 | optimizer_name = 'adamw'
26 | learning_rate = 6e-4 # max learning rate
27 | weight_decay = 1e-1
28 | beta1 = 0.9
29 | beta2 = 0.95
30 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
31 | # learning rate decay settings
32 | decay_lr = True # whether to decay the learning rate
33 | warmup_iters = 2000 # how many steps to warm up for
34 | min_lr = 3e-5
35 |
36 | compile = True
37 |
38 | out_dir = 'out_small_adamw_100k'
39 |
--------------------------------------------------------------------------------
/config/train_gpt2_small_mars.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-small-mars-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 12
10 | n_head = 12
11 | n_embd = 768
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 |
15 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
16 | max_iters = 100000
17 | lr_decay_iters = 100000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # optimizer
25 | optimizer_name = 'mars'
26 | learning_rate = 6e-3 # max learning rate
27 | weight_decay = 1e-2
28 | beta1 = 0.95
29 | beta2 = 0.99
30 | lr_1d=3e-3
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 3e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_small_mars_100k'
40 | gamma=0.025
41 |
--------------------------------------------------------------------------------
/config/train_gpt2_small_muon.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-small-muon-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 12
10 | n_head = 12
11 | n_embd = 768
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 |
15 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
16 | max_iters = 100000
17 | lr_decay_iters = 100000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # optimizer
25 | optimizer_name = 'muon'
26 | learning_rate = 3e-3 # max learning rate, original=6e-4
27 | weight_decay = 1e-1
28 | muon_learning_rate = 2e-2
29 | muon_weight_decay = 0.
30 | beta1 = 0.9
31 | beta2 = 0.95
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 3e-5
37 | schedule = 'cosine'
38 | compile = True
39 |
40 | out_dir = 'out_small_muon_100k'
41 |
--------------------------------------------------------------------------------
/config/train_gpt2_xl_adamw.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-xl-adamw-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 48
10 | n_head = 25
11 | n_embd = 1600
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'adamw'
27 | learning_rate = 2e-4 # max learning rate
28 | weight_decay = 1e-1
29 | beta1 = 0.9
30 | beta2 = 0.95
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 1e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_xl_adamw_100k'
40 |
--------------------------------------------------------------------------------
/config/train_gpt2_xl_mars.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-xl-mars-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 48
10 | n_head = 25
11 | n_embd = 1600
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes the total number of tokens ~50B (100k iterations x 480 x 1024 tokens per step)
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'mars'
27 | learning_rate = 2e-3 # max learning rate
28 | weight_decay = 1e-2
29 | beta1 = 0.95
30 | beta2 = 0.99
31 | lr_1d=1e-3
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 1e-5
37 |
38 | compile = True
39 |
40 | out_dir = 'out_xl_mars_100k'
41 | gamma=0.025
42 |
--------------------------------------------------------------------------------
/data/openwebtext/prepare.py:
--------------------------------------------------------------------------------
1 | # saves the openwebtext dataset to a binary file for training. following was helpful:
2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
3 |
4 | import os
5 | from tqdm import tqdm
6 | import numpy as np
7 | import tiktoken
8 | from datasets import load_dataset # huggingface datasets
9 |
10 | # number of workers in .map() call
11 | # good number to use is ~order number of cpu cores // 2
12 | num_proc = 52
13 |
14 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
15 | dataset = load_dataset("openwebtext", cache_dir="nanoGPT/cache")
16 |
17 | # owt by default only contains the 'train' split, so create a test split
18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
19 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val
20 |
21 | # this results in:
22 | # >>> split_dataset
23 | # DatasetDict({
24 | # train: Dataset({
25 | # features: ['text'],
26 | # num_rows: 8009762
27 | # })
28 | # val: Dataset({
29 | # features: ['text'],
30 | # num_rows: 4007
31 | # })
32 | # })
33 |
34 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
35 | enc = tiktoken.get_encoding("gpt2")
36 | def process(example):
37 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
38 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
39 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
40 | out = {'ids': ids, 'len': len(ids)}
41 | return out
42 |
43 | # tokenize the dataset
44 | tokenized = split_dataset.map(
45 | process,
46 | remove_columns=['text'],
47 | desc="tokenizing the splits",
48 | num_proc=num_proc,
49 | )
50 | print('tokenization finished')
51 | # concatenate all the ids in each dataset into one large file we can use for training
52 | for split, dset in tokenized.items():
53 | arr_len = np.sum(dset['len'])
54 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
55 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
56 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
57 |
58 | print(f"writing {filename}...")
59 | idx = 0
60 | for example in tqdm(dset):
61 | arr[idx : idx + example['len']] = example['ids']
62 | idx += example['len']
63 | arr.flush()
64 |
65 | # train.bin is ~17GB, val.bin ~8.5MB
66 | # train has ~9B tokens (9,035,582,198)
67 | # val has ~4M tokens (4,434,897)
68 |
69 | # to read the bin files later, e.g. with numpy:
70 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r')
71 |
--------------------------------------------------------------------------------
/scripts/run_CNN.sh:
--------------------------------------------------------------------------------
1 | python MARS/train_CNN.py
--------------------------------------------------------------------------------
/scripts/run_CV.sh:
--------------------------------------------------------------------------------
1 | python MARS/train_CV.py
--------------------------------------------------------------------------------
/scripts/run_adamw_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_large_adamw.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/scripts/run_adamw_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_medium_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_adamw_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_small_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_adamw_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw_fw.py \
3 | config/train_gpt2_small_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/scripts/run_adamw_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw_fw.py \
3 | config/train_gpt2_xl_adamw.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/scripts/run_mars_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars.py \
3 | config/train_gpt2_large_mars.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/scripts/run_mars_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars.py \
3 | config/train_gpt2_medium_mars.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_mars_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars.py \
3 | config/train_gpt2_small_mars.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_mars_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars_fw.py \
3 | config/train_gpt2_small_mars.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/scripts/run_mars_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars_fw.py \
3 | config/train_gpt2_xl_mars.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/scripts/run_muon_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_muon.py \
3 | config/train_gpt2_large_muon.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/scripts/run_muon_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_muon.py \
3 | config/train_gpt2_medium_muon.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_muon_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_muon.py \
3 | config/train_gpt2_small_muon.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------