├── .gitignore
├── Moonlight.pdf
├── figures
│   ├── logo.png
│   ├── banner.png
│   ├── megatron.png
│   ├── scaling.png
│   ├── banner_short.png
│   ├── fig_MMLU_performance.png
│   └── chinlaw_8k_flops_ratio.png
├── Moonlight_intermediate_checkpoints.pdf
├── requirements.txt
├── LICENSE
├── README.md
└── examples
    └── toy_train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | logs
2 | *.bin
--------------------------------------------------------------------------------
/Moonlight.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/Moonlight.pdf
--------------------------------------------------------------------------------
/figures/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/logo.png
--------------------------------------------------------------------------------
/figures/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/banner.png
--------------------------------------------------------------------------------
/figures/megatron.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/megatron.png
--------------------------------------------------------------------------------
/figures/scaling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/scaling.png
--------------------------------------------------------------------------------
/figures/banner_short.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/banner_short.png
--------------------------------------------------------------------------------
/figures/fig_MMLU_performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/fig_MMLU_performance.png
--------------------------------------------------------------------------------
/figures/chinlaw_8k_flops_ratio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/chinlaw_8k_flops_ratio.png
--------------------------------------------------------------------------------
/Moonlight_intermediate_checkpoints.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/Moonlight_intermediate_checkpoints.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==3.3.2
2 | loguru==0.7.3
3 | numpy==2.2.3
4 | torch==2.6.0
5 | tqdm==4.67.1
6 | transformers==4.49.0
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | Copyright © 2025 Moonshot AI
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5 |
6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7 |
8 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
14 | ## Abstract
15 | Recently, the [Muon optimizer](https://github.com/KellerJordan/Muon), which is based on matrix orthogonalization, has demonstrated strong results in training small-scale language models, but its scalability to larger models has not been proven. We identify two crucial techniques for scaling up Muon: (1) adding weight decay and (2) carefully adjusting the per-parameter update scale. These techniques allow Muon to work out of the box in large-scale training without the need for hyperparameter tuning. Scaling-law experiments indicate that Muon achieves ∼2× the computational efficiency of AdamW under compute-optimal training.
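As a reading aid, the per-matrix update implemented by `Muon.step` in `examples/toy_train.py` can be written out as follows for a weight $W \in \mathbb{R}^{A \times B}$ with gradient $G_t$, momentum $\mu$, learning rate $\eta_t$, and weight decay $\lambda$. The notation is ours and may differ from the paper's; $\mathrm{NS}_5$ stands for the quintic Newton-Schulz iteration `zeropower_via_newtonschulz5`.

$$
\begin{aligned}
M_t &= \mu M_{t-1} + G_t, \\
O_t &= \mathrm{NS}_5\left(G_t + \mu M_t\right) \quad \text{(Nesterov; } \mathrm{NS}_5(M_t) \text{ otherwise)}, \\
W_t &= (1 - \eta_t \lambda)\, W_{t-1} - 0.2\sqrt{\max(A, B)}\;\eta_t\, O_t .
\end{aligned}
$$

The factor $(1 - \eta_t \lambda)$ corresponds to technique (1), weight decay, and the factor $0.2\sqrt{\max(A, B)}$ to technique (2), the shape-dependent update scale.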
16 |
17 | Based on these improvements, we introduce **Moonlight**, a 3B/16B-parameter (activated/total) Mixture-of-Experts (MoE) model trained on 5.7T tokens with Muon. Our model improves the current Pareto frontier, achieving better performance with far fewer training FLOPs than prior models.
18 |
19 | We open-source our distributed Muon implementation, which is memory-optimal and communication-efficient. We also release the pretrained, instruction-tuned, and intermediate checkpoints to support future research.
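The ZeRO-1-style idea can be illustrated with a single-process simulation. This is a minimal sketch under assumptions that are ours, not the released implementation's: each hypothetical rank owns a contiguous block of rows of one weight matrix, `torch.chunk`/`torch.cat` stand in for the reduce-scatter and all-gather collectives, momentum is plain heavy-ball, and an exact SVD replaces the Newton-Schulz iteration. It checks that keeping momentum sharded, and materializing the full matrix only for orthogonalization, reproduces the single-device update.

```python
import torch


def orthogonalize(M: torch.Tensor) -> torch.Tensor:
    # Exact U @ V^T via SVD, standing in for Newton-Schulz to keep the sketch short.
    U, _, Vh = torch.linalg.svd(M, full_matrices=False)
    return U @ Vh


def reference_step(grad, buf, mu=0.95):
    """Single-device Muon direction with heavy-ball (non-Nesterov) momentum."""
    buf.mul_(mu).add_(grad)
    return orthogonalize(buf)


def zero1_style_step(grad, shard_bufs, mu=0.95):
    """Simulated ZeRO-1-style step: each hypothetical rank owns one row shard of
    the momentum buffer; the full matrix exists only for orthogonalization."""
    world_size = len(shard_bufs)
    grad_shards = grad.chunk(world_size, dim=0)   # stands in for reduce-scatter
    for buf, g in zip(shard_bufs, grad_shards):
        buf.mul_(mu).add_(g)                       # momentum on the local shard only
    full = torch.cat(shard_bufs, dim=0)            # stands in for all-gather
    update = orthogonalize(full)
    return list(update.chunk(world_size, dim=0))   # each rank keeps only its rows


torch.manual_seed(0)
A, B, world_size = 256, 64, 4
ref_buf = torch.zeros(A, B)
shard_bufs = [torch.zeros(A // world_size, B) for _ in range(world_size)]
for _ in range(3):
    grad = torch.randn(A, B)
    ref = reference_step(grad, ref_buf)
    shards = zero1_style_step(grad, shard_bufs)
    assert torch.allclose(torch.cat(shards), ref)
print("sharded and single-device updates match")
```

Orthogonalization needs the whole matrix, so the gather cannot be skipped; the memory saving comes from the optimizer state living only on the shard that owns it.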
20 |
21 | Our code is available at [MoonshotAI/Moonlight](https://github.com/MoonshotAI/Moonlight).
22 |
23 | ## Key Ingredients
24 |
25 | Our work builds upon Muon while systematically identifying and resolving its limitations in large-scale training scenarios. Our technical contributions include:
26 |
27 | - **Analysis for Effective Scaling of Muon**: Through extensive analysis, we identify that weight decay plays a crucial role in Muon's scalability. In addition, we propose keeping the update root mean square (RMS) consistent across matrix and non-matrix parameters through parameter-wise update-scale adjustments. These adjustments significantly improve training stability.
28 |
29 | - **Efficient Distributed Implementation**: We develop a distributed version of Muon with ZeRO-1 style optimization, achieving optimal memory efficiency and reduced communication overhead while preserving the mathematical properties of the algorithm.
30 |
31 | - **Scaling Law Validation**: We performed scaling-law studies comparing Muon with strong AdamW baselines and found that Muon performs better (see Figure 1). Based on the scaling-law results, Muon matches the performance of AdamW-trained counterparts while requiring only about 52% of the training FLOPs.
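Why `adjust_lr_for_muon` in `examples/toy_train.py` multiplies the learning rate by `0.2 * sqrt(max(A, B))`: a fully orthogonalized A×B update has RMS exactly `1 / sqrt(max(A, B))`, so without rescaling, larger matrices would receive smaller per-element updates. The sketch below uses an exact SVD in place of Newton-Schulz (an idealization, since the repo's iteration only approximately equalizes singular values) and checks that this factor brings the update RMS to 0.2 regardless of shape.

```python
import math
import torch


def scaled_update_rms(A: int, B: int, seed: int = 0) -> float:
    """RMS of an exactly orthogonalized A x B update after 0.2 * sqrt(max(A, B)) scaling."""
    gen = torch.Generator().manual_seed(seed)
    G = torch.randn(A, B, generator=gen, dtype=torch.float64)
    U, _, Vh = torch.linalg.svd(G, full_matrices=False)
    O = U @ Vh  # min(A, B) singular values, all equal to 1
    raw_rms = O.pow(2).mean().sqrt().item()
    # ||O||_F^2 = min(A, B), so the unscaled RMS is 1 / sqrt(max(A, B))
    assert math.isclose(raw_rms, 1 / math.sqrt(max(A, B)), rel_tol=1e-9)
    scaled = 0.2 * math.sqrt(max(A, B)) * O
    return scaled.pow(2).mean().sqrt().item()


for shape in [(128, 128), (128, 512), (512, 128)]:
    print(shape, round(scaled_update_rms(*shape), 6))  # 0.2 for every shape
```

With Newton-Schulz instead of SVD, the singular values are only approximately 1, so the actual update RMS is close to, rather than exactly, 0.2.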
32 |
33 | *Figure 1: Scaling up with Muon. (a) Scaling law experiments comparing Muon and Adam: Muon is 2 times more sample-efficient than Adam. (b) The MMLU performance of our Moonlight model optimized with Muon and other comparable models: Moonlight advances the Pareto frontier of performance vs. training FLOPs.*
34 |
39 | ## Performance
40 |
41 | We name our lightweight model trained with Muon "Moonlight" and compare it with SOTA public models at a similar scale:
42 |
43 | - **Llama3.2-3B** is a 3B-parameter dense model trained on 9T tokens
44 | - **Qwen2.5-3B** is a 3B-parameter dense model trained on 18T tokens
45 | - **DeepSeek-V2-Lite** is a 2.4B/16B-parameter MoE model trained on 5.7T tokens
46 |
47 | | | **Benchmark (Metric)** | **Llama3.2-3B** | **Qwen2.5-3B** | **DSV2-Lite** | **Moonlight** |
48 | |---|---|---|---|---|---|
49 | | | Activated Params† | 2.81B | 2.77B | 2.24B | 2.24B |
50 | | | Total Params† | 2.81B | 2.77B | 15.29B | 15.29B |
51 | | | Training Tokens | 9T | 18T | 5.7T | 5.7T |
52 | | | Optimizer | AdamW | * | AdamW | Muon |
53 | | **English** | MMLU | 54.75 | 65.6 | 58.3 | **70.0** |
54 | | | MMLU-pro | 25.0 | 34.6 | 25.5 | **42.4** |
55 | | | BBH | 46.8 | 56.3 | 44.1 | **65.2** |
56 | | | TriviaQA‡ | 59.6 | 51.1 | 65.1 | **66.3** |
57 | | **Code** | HumanEval | 28.0 | 42.1 | 29.9 | **48.1** |
58 | | | MBPP | 48.7 | 57.1 | 43.2 | **63.8** |
59 | | **Math** | GSM8K | 34.0 | **79.1** | 41.1 | 77.4 |
60 | | | MATH | 8.5 | 42.6 | 17.1 | **45.3** |
61 | | | CMath | - | 80.0 | 58.4 | **81.1** |
62 | | **Chinese** | C-Eval | - | 75.0 | 60.3 | **77.2** |
63 | | | CMMLU | - | 75.0 | 64.3 | **78.2** |
64 |
65 | *\*The Qwen 2 and Qwen 2.5 reports do not disclose their optimizer. †Reported parameter counts exclude embedding parameters. ‡We evaluate all listed models on the full TriviaQA set.*
66 |
67 |
68 | ## Example usage
69 | ### Model Download
70 |
71 |
72 |
73 | | **Model** | **#Total Params** | **#Activated Params** | **Context Length** | **Download Link** |
74 | | :------------: | :------------: | :------------: | :------------: | :------------: |
75 | | Moonlight | 16B | 3B | 8K | [🤗 Hugging Face](https://huggingface.co/moonshotai/Moonlight-16B-A3B) |
76 | | Moonlight-Instruct | 16B | 3B | 8K | [🤗 Hugging Face](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct) |
77 |
78 |
79 |
80 | ### Inference with Hugging Face Transformers
81 |
82 | Below we show how to run inference with our models using the Hugging Face Transformers library. We recommend Python 3.10, torch>=2.1.0, and transformers==4.48.2 as the development environment.
83 |
84 | For our pretrained model (Moonlight):
85 | ```python
86 | from transformers import AutoModelForCausalLM, AutoTokenizer
87 |
88 | model_path = "moonshotai/Moonlight-16B-A3B"
89 | model = AutoModelForCausalLM.from_pretrained(
90 | model_path,
91 | torch_dtype="auto",
92 | device_map="auto",
93 | trust_remote_code=True,
94 | )
95 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
96 |
97 | prompt = "1+1=2, 1+2="
98 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
99 | generated_ids = model.generate(**inputs, max_new_tokens=100)
100 | response = tokenizer.batch_decode(generated_ids)[0]
101 | print(response)
102 | ```
103 |
104 | For our instruct model (Moonlight-Instruct):
105 |
106 | ```python
107 | from transformers import AutoModelForCausalLM, AutoTokenizer
108 |
109 | model_path = "moonshotai/Moonlight-16B-A3B-Instruct"
110 | model = AutoModelForCausalLM.from_pretrained(
111 | model_path,
112 | torch_dtype="auto",
113 | device_map="auto",
114 | trust_remote_code=True
115 | )
116 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
117 |
118 | messages = [
119 | {"role": "system", "content": "You are a helpful assistant provided by Moonshot-AI."},
120 | {"role": "user", "content": "Is 123 a prime?"}
121 | ]
122 | input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
123 | generated_ids = model.generate(inputs=input_ids, max_new_tokens=500)
124 | response = tokenizer.batch_decode(generated_ids)[0]
125 | print(response)
126 | ```
127 |
128 | Moonlight has the same architecture as DeepSeek-V3, which is supported by many popular inference engines such as vLLM and SGLang. As a result, our model can also be easily deployed with these tools.
129 |
130 | ### Training
131 | ```bash
132 | # train qwen-like dense model with muon
133 | python3 examples/toy_train.py --model qwen --optimizer muon --dataset openwebtext-100k --hidden_size 896 --lr 1e-3
134 |
135 | # train qwen-like dense model with adamw
136 | python3 examples/toy_train.py --model qwen --optimizer adamw --dataset openwebtext-100k --hidden_size 896 --lr 1e-3
137 | ```
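With `--optimizer muon`, `get_optimizer` in `examples/toy_train.py` sends every parameter that is at least 2D and whose name contains neither `embed_tokens` nor `lm_head` to Muon, and everything else (embeddings, output head, norms, biases) to the internal AdamW. The sketch below applies the same rule to a hypothetical `TinyLM` module so the split is easy to inspect; for readability it returns parameter names rather than tensors.

```python
import torch.nn as nn


def split_params_for_muon(model: nn.Module):
    """Same selection rule as get_optimizer() in examples/toy_train.py, returning names."""
    muon, adamw = [], []
    for name, p in model.named_parameters():
        is_hidden_matrix = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
        (muon if is_hidden_matrix else adamw).append(name)
    return muon, adamw


class TinyLM(nn.Module):  # hypothetical stand-in for a Qwen2-style model
    def __init__(self, vocab: int = 100, hidden: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.proj = nn.Linear(hidden, hidden)  # 2D weight goes to Muon, 1D bias to AdamW
        self.norm = nn.LayerNorm(hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


muon_names, adamw_names = split_params_for_muon(TinyLM())
print(muon_names)   # ['proj.weight']
print(adamw_names)  # ['embed_tokens.weight', 'proj.bias', 'norm.weight', 'norm.bias', 'lm_head.weight']
```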
138 |
139 | ## Intermediate Checkpoints
140 | To support ongoing research efforts, we will release our intermediate checkpoints soon.
141 |
142 | ## Citation
143 | If you find Moonlight useful or use it in your projects, please cite our paper:
144 | ```bibtex
145 | @misc{liu2025muonscalablellmtraining,
146 | title={Muon is Scalable for LLM Training},
147 | author={Jingyuan Liu and Jianlin Su and Xingcheng Yao and Zhejun Jiang and Guokun Lai and Yulun Du and Yidao Qin and Weixin Xu and Enzhe Lu and Junjie Yan and Yanru Chen and Huabin Zheng and Yibo Liu and Shaowei Liu and Bohong Yin and Weiran He and Han Zhu and Yuzhi Wang and Jianzhou Wang and Mengnan Dong and Zheng Zhang and Yongsheng Kang and Hao Zhang and Xinran Xu and Yutao Zhang and Yuxin Wu and Xinyu Zhou and Zhilin Yang},
148 | year={2025},
149 | eprint={2502.16982},
150 | archivePrefix={arXiv},
151 | primaryClass={cs.LG},
152 | url={https://arxiv.org/abs/2502.16982},
153 | }
154 | ```
155 |
--------------------------------------------------------------------------------
/examples/toy_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | from loguru import logger
5 | from datasets import load_dataset
6 | from torch.utils.data import DataLoader, Dataset
7 | from transformers import (
8 | Qwen2Config,
9 | Qwen2ForCausalLM,
10 | Qwen2Tokenizer,
11 | get_cosine_schedule_with_warmup,
12 | )
13 | from tqdm import tqdm
14 |
15 |
16 | class MoonDataset(Dataset):
17 | def __init__(self, dataset_name, dataset, tokenizer, max_length=512):
18 | self.dataset_name = dataset_name
19 | self.dataset = dataset
20 | self.tokenizer = tokenizer
21 | self.texts = dataset["train"]["text"]
22 | self.max_length = max_length
23 | self.tokens = []
24 | self._tokenize_texts()
25 |
26 | def _tokenize_texts(self):
27 | if os.path.exists(f"{self.dataset_name}.bin"):
28 | self.tokens = torch.load(f"{self.dataset_name}.bin")
29 | else:
30 | for text in tqdm(self.texts, desc="Tokenizing texts"):
31 | encoded = self.tokenizer.encode(text, add_special_tokens=True)
32 | self.tokens.extend(encoded)
33 | torch.save(self.tokens, f"{self.dataset_name}.bin")
34 |
35 | def __len__(self):
36 | return len(self.tokens) // self.max_length
37 |
38 | def __getitem__(self, idx):
39 | start_idx = idx * (self.max_length)
40 | end_idx = start_idx + (self.max_length)
41 | token_slice = self.tokens[start_idx:end_idx]
42 | data = torch.tensor(token_slice, dtype=torch.long)
43 | return data
44 |
45 |
46 | # This code snippet is a modified version adapted from the following GitHub repository:
47 | # https://github.com/KellerJordan/Muon/blob/master/muon.py
48 | @torch.compile
49 | def zeropower_via_newtonschulz5(G, steps):
50 | """
51 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
52 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
53 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
54 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
55 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
56 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
57 | performance at all relative to UV^T, where USV^T = G is the SVD.
58 | """
59 | assert len(G.shape) == 2
60 | a, b, c = (3.4445, -4.7750, 2.0315)
61 | X = G.bfloat16()
62 | if G.size(0) > G.size(1):
63 | X = X.T
64 | # Ensure spectral norm is at most 1
65 | X = X / (X.norm() + 1e-7)
66 | # Perform the NS iterations
67 | for _ in range(steps):
68 | A = X @ X.T
69 | B = (
70 | b * A + c * A @ A
71 | ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
72 | X = a * X + B @ X
73 |
74 | if G.size(0) > G.size(1):
75 | X = X.T
76 | return X
77 |
78 |
79 | class Muon(torch.optim.Optimizer):
80 | """
81 | Muon - MomentUm Orthogonalized by Newton-schulz
82 |
83 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
84 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
85 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
86 | the advantage that it can be stably run in bfloat16 on the GPU.
87 |
88 | Some warnings:
89 | - We believe this optimizer is unlikely to work well for training with small batch size.
90 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
91 |
92 |     Arguments:
93 |         lr: The base learning rate shared by Muon and the internal AdamW. For Muon parameters it is
94 |             rescaled per matrix by `adjust_lr_for_muon` so that the update RMS is roughly 0.2.
95 |         wd: The decoupled weight decay shared by Muon and the internal AdamW.
96 |         muon_params: The parameters to be optimized by Muon. Every one of them must be 2D.
97 |         momentum: The momentum used by the internal SGD. (0.95 is a good default)
98 |         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
99 |         ns_steps: The number of Newton-Schulz iterations to run. (defaults to 5)
100 |         adamw_params: The parameters to be optimized by AdamW, e.g. embeddings, the lm_head,
101 |             and all {0, 1}-D parameters.
102 |         adamw_betas: The betas for the internal AdamW.
103 |         adamw_eps: The epsilon for the internal AdamW.
104 |     """
105 |
106 | def __init__(
107 | self,
108 | lr=1e-3,
109 | wd=0.1,
110 | muon_params=None,
111 | momentum=0.95,
112 | nesterov=True,
113 | ns_steps=5,
114 | adamw_params=None,
115 | adamw_betas=(0.9, 0.95),
116 | adamw_eps=1e-8,
117 | ):
118 |
119 | defaults = dict(
120 | lr=lr,
121 | wd=wd,
122 | momentum=momentum,
123 | nesterov=nesterov,
124 | ns_steps=ns_steps,
125 | adamw_betas=adamw_betas,
126 | adamw_eps=adamw_eps,
127 | )
128 |
129 |         muon_params = list(muon_params)  # materialize so generators can be iterated twice
130 |         adamw_params = list(adamw_params) if adamw_params is not None else []
131 |         params = muon_params + adamw_params
132 |         super().__init__(params, defaults)
133 |         # Sort parameters into those for which we will use Muon, and those for which we will not
134 |         for p in muon_params:
135 |             # Muon is only applied to 2D weight matrices (embeddings and lm_head are excluded by the caller)
136 |             assert p.ndim == 2, p.ndim
137 |             self.state[p]["use_muon"] = True
138 |         for p in adamw_params:
139 |             # Do not use Muon for parameters in adamw_params
140 |             self.state[p]["use_muon"] = False
141 |
142 | def adjust_lr_for_muon(self, lr, param_shape):
143 | A, B = param_shape[:2]
144 |         # Scale the learning rate by the shape of the parameter matrix, as described in the paper:
145 |         # an orthogonalized A x B update has RMS 1/sqrt(max(A, B)), so this makes the update RMS ~0.2
146 | adjusted_ratio = 0.2 * math.sqrt(max(A, B))
147 | adjusted_lr = lr * adjusted_ratio
148 | return adjusted_lr
149 |
150 | def step(self, closure=None):
151 | """Perform a single optimization step.
152 |
153 | Args:
154 | closure (Callable, optional): A closure that reevaluates the model
155 | and returns the loss.
156 | """
157 | loss = None
158 | if closure is not None:
159 | with torch.enable_grad():
160 | loss = closure()
161 |
162 | for group in self.param_groups:
163 |
164 | ############################
165 | # Muon #
166 | ############################
167 |
168 | params = [p for p in group["params"] if self.state[p]["use_muon"]]
169 |             # Hyperparameters shared by all Muon parameters in this group
170 | lr = group["lr"]
171 | wd = group["wd"]
172 | momentum = group["momentum"]
173 |
174 | # generate weight updates
175 | for p in params:
176 | # sanity check
177 | g = p.grad
178 | if g is None:
179 | continue
180 | if g.ndim > 2:
181 | g = g.view(g.size(0), -1)
182 | assert g is not None
183 |
184 | # calc update
185 | state = self.state[p]
186 | if "momentum_buffer" not in state:
187 | state["momentum_buffer"] = torch.zeros_like(g)
188 | buf = state["momentum_buffer"]
189 | buf.mul_(momentum).add_(g)
190 | if group["nesterov"]:
191 | g = g.add(buf, alpha=momentum)
192 | else:
193 | g = buf
194 | u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
195 |
196 | # scale update
197 | adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
198 |
199 |                 # apply decoupled weight decay, scaled by the unadjusted lr as in AdamW
200 | p.data.mul_(1 - lr * wd)
201 |
202 | # apply update
203 | p.data.add_(u, alpha=-adjusted_lr)
204 |
205 | ############################
206 | # AdamW backup #
207 | ############################
208 |
209 | params = [p for p in group["params"] if not self.state[p]["use_muon"]]
210 | lr = group['lr']
211 | beta1, beta2 = group["adamw_betas"]
212 | eps = group["adamw_eps"]
213 | weight_decay = group["wd"]
214 |
215 | for p in params:
216 | g = p.grad
217 | if g is None:
218 | continue
219 | state = self.state[p]
220 | if "step" not in state:
221 | state["step"] = 0
222 | state["moment1"] = torch.zeros_like(g)
223 | state["moment2"] = torch.zeros_like(g)
224 | state["step"] += 1
225 | step = state["step"]
226 | buf1 = state["moment1"]
227 | buf2 = state["moment2"]
228 | buf1.lerp_(g, 1 - beta1)
229 | buf2.lerp_(g.square(), 1 - beta2)
230 |
231 | g = buf1 / (eps + buf2.sqrt())
232 |
233 | bias_correction1 = 1 - beta1**step
234 | bias_correction2 = 1 - beta2**step
235 | scale = bias_correction1 / bias_correction2**0.5
236 | p.data.mul_(1 - lr * weight_decay)
237 | p.data.add_(g, alpha=-lr / scale)
238 |
239 | return loss
240 |
241 |
242 | def get_model_and_dataloader(model_name, dataset_name, hidden_size):
243 | name2path = {
244 | "openwebtext-100k": "Elriggs/openwebtext-100k",
245 | }
246 | train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True)
247 | if model_name == "qwen":
248 | tokenizer = Qwen2Tokenizer.from_pretrained(
249 | "Qwen/Qwen2.5-0.5B", trust_remote_code=True
250 | )
251 | else:
252 |         raise ValueError(f"model {model_name} not supported")
253 | train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer)
254 | train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
255 |
256 | if model_name == "qwen":
257 | config = Qwen2Config(
258 | attention_dropout=0.0,
259 | bos_token_id=151643,
260 | eos_token_id=151643,
261 | hidden_act="silu",
262 | hidden_size=hidden_size,
263 | initializer_range=0.02,
264 | intermediate_size=4864,
265 | max_position_embeddings=513,
266 | max_window_layers=12,
267 | model_type="qwen2",
268 | num_attention_heads=16,
269 | num_hidden_layers=12,
270 | num_key_value_heads=16,
271 | rms_norm_eps=1e-06,
272 | rope_theta=1000000.0,
273 | sliding_window=1024,
274 | tie_word_embeddings=True,
275 | torch_dtype="bfloat16",
276 | use_cache=True,
277 | use_mrope=False,
278 | use_sliding_window=False,
279 | vocab_size=151936,
280 | )
281 | model = Qwen2ForCausalLM(config)
282 | else:
283 |         raise ValueError(f"model {model_name} not supported")
284 | return model, train_loader
285 |
286 |
287 | def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1):
288 | if optimizer_name == "adamw":
289 | return torch.optim.AdamW(
290 | model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95)
291 | )
292 | elif optimizer_name == "muon":
293 | muon_params = [
294 | p
295 | for name, p in model.named_parameters()
296 | if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
297 | ]
298 | adamw_params = [
299 | p
300 | for name, p in model.named_parameters()
301 | if not (
302 | p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
303 | )
304 | ]
305 |
306 | return Muon(
307 | lr=lr,
308 | wd=wd,
309 | muon_params=muon_params,
310 | adamw_params=adamw_params,
311 | )
312 | else:
313 |         raise ValueError(f"optimizer {optimizer_name} not supported")
314 |
315 |
316 | if __name__ == "__main__":
317 | import argparse
318 |
319 | parser = argparse.ArgumentParser()
320 | parser.add_argument("--model", type=str, default="qwen")
321 | parser.add_argument("--optimizer", type=str, default="adamw")
322 | parser.add_argument("--lr", type=float, default=1e-3)
323 | parser.add_argument("--wd", type=float, default=0.1)
324 | parser.add_argument("--dataset", type=str, default="openwebtext-100k")
325 | parser.add_argument("--hidden_size", type=int, default=1024)
326 | args = parser.parse_args()
327 | logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log")
328 |
329 | model, train_loader = get_model_and_dataloader(
330 | args.model, args.dataset, args.hidden_size
331 | )
332 | optimizer = get_optimizer(
333 |         args.optimizer, model, lr=args.lr, wd=args.wd
334 | )
335 |
336 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
337 | model.to(device)
338 |
339 | model.train()
340 |     num_epochs = 1
341 |     lr_scheduler = get_cosine_schedule_with_warmup(
342 |         optimizer=optimizer,
343 |         num_warmup_steps=100,
344 |         num_training_steps=len(train_loader) * num_epochs,
345 |         num_cycles=0.5,
346 |     )
347 |     for epoch in range(num_epochs):
348 | for step, batch in enumerate(train_loader):
349 | batch = batch.to(device)
350 | input_ids = batch
351 | outputs = model(input_ids=input_ids, labels=input_ids)
352 | loss = outputs.loss
353 | loss.backward()
354 | optimizer.step()
355 | lr_scheduler.step()
356 | optimizer.zero_grad()
357 | logger.info(
358 | f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}"
359 | )
360 |
--------------------------------------------------------------------------------
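The docstring of `zeropower_via_newtonschulz5` says the quintic iteration does not produce UV^T exactly but something like US'V^T, with singular values spread around 1. One way to check this is to compare singular values before and after the iteration. The sketch below is a float32 copy of that function (no `torch.compile`, no bfloat16, so numbers will differ slightly from the repo's version) applied to random rectangular Gaussian matrices; the shapes and seed are arbitrary choices.

```python
import torch


def newton_schulz5_fp32(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """float32 copy of zeropower_via_newtonschulz5 from examples/toy_train.py."""
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.float()
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    # The Frobenius norm bounds the spectral norm, so all singular values become <= 1
    X = X / (X.norm() + 1e-7)
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X


torch.manual_seed(0)
for shape in [(64, 256), (256, 64)]:
    G = torch.randn(*shape)
    sv_before = torch.linalg.svdvals(G / G.norm())
    sv_after = torch.linalg.svdvals(newton_schulz5_fp32(G))
    print(
        shape,
        f"before: [{sv_before.min().item():.3f}, {sv_before.max().item():.3f}]",
        f"after: [{sv_after.min().item():.3f}, {sv_after.max().item():.3f}]",
    )
```

For inputs like these, the normalized singular values start well below 1 and, after five steps, land in a band around 1, consistent with the US'V^T behavior the docstring describes. Matrices with near-zero singular values can need more steps, because each step multiplies very small singular values by only about `a = 3.4445`.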