├── .gitignore ├── Moonlight.pdf ├── figures ├── logo.png ├── banner.png ├── megatron.png ├── scaling.png ├── banner_short.png ├── fig_MMLU_performance.png └── chinlaw_8k_flops_ratio.png ├── Moonlight_intermediate_checkpoints.pdf ├── requirements.txt ├── LICENSE ├── README.md └── examples └── toy_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs 2 | *.bin -------------------------------------------------------------------------------- /Moonlight.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/Moonlight.pdf -------------------------------------------------------------------------------- /figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/logo.png -------------------------------------------------------------------------------- /figures/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/banner.png -------------------------------------------------------------------------------- /figures/megatron.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/megatron.png -------------------------------------------------------------------------------- /figures/scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/scaling.png -------------------------------------------------------------------------------- /figures/banner_short.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/banner_short.png -------------------------------------------------------------------------------- /figures/fig_MMLU_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/fig_MMLU_performance.png -------------------------------------------------------------------------------- /figures/chinlaw_8k_flops_ratio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/figures/chinlaw_8k_flops_ratio.png -------------------------------------------------------------------------------- /Moonlight_intermediate_checkpoints.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Moonlight/HEAD/Moonlight_intermediate_checkpoints.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==3.3.2 2 | loguru==0.7.3 3 | numpy==2.2.3 4 | torch==2.6.0 5 | tqdm==4.67.1 6 | transformers==4.49.0 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright © 2025 Moonshot AI 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the 
“Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | Tech Report | HuggingFace | Megatron-LM 11 |
12 | 13 | 14 | ## Abstract 15 | Recently, the [Muon optimizer](https://github.com/KellerJordan/Muon), based on matrix orthogonalization, has demonstrated strong results in training small-scale language models, but its scalability to larger models had not been proven. We identify two crucial techniques for scaling up Muon: (1) adding weight decay and (2) carefully adjusting the per-parameter update scale. These techniques allow Muon to work out of the box on large-scale training without hyper-parameter tuning. Scaling-law experiments indicate that Muon achieves ~2× the computational efficiency of AdamW under compute-optimal training. 16 | 17 | Based on these improvements, we introduce **Moonlight**, a 3B/16B-parameter Mixture-of-Experts (MoE) model trained with 5.7T tokens using Muon. Our model improves the current Pareto frontier, achieving better performance with far fewer training FLOPs than prior models. 18 | 19 | We open-source our distributed Muon implementation, which is memory-optimal and communication-efficient. We also release the pretrained, instruction-tuned, and intermediate checkpoints to support future research. 20 | 21 | Our code is available at [MoonshotAI/Moonlight](https://github.com/MoonshotAI/Moonlight). 22 | 23 | ## Key Ingredients 24 | 25 | Our work builds upon Muon while systematically identifying and resolving its limitations in large-scale training scenarios. Our technical contributions include: 26 | 27 | - **Analysis for Effective Scaling of Muon**: Through extensive analysis, we identify that weight decay plays a crucial role in Muon's scalability. In addition, we propose keeping a consistent update root mean square (RMS) across matrix and non-matrix parameters via parameter-wise update-scale adjustments; these adjustments significantly enhance training stability. A simplified sketch of the resulting update rule follows this list. 28 | 29 | - **Efficient Distributed Implementation**: We develop a distributed version of Muon with ZeRO-1-style optimization, achieving optimal memory efficiency and reduced communication overhead while preserving the mathematical properties of the algorithm (a simplified sketch follows Figure 1). 30 | 31 | - **Scaling Law Validation**: We performed scaling-law experiments comparing Muon with strong AdamW baselines, which showed Muon's superior performance (see Figure 1). Based on these results, Muon matches the performance of AdamW-trained counterparts while requiring only about 52% of the training FLOPs. 32 |
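For intuition, here is a minimal, illustrative sketch of the resulting update rule for a single weight matrix of shape (A, B). The helper and variable names are ours; `zeropower_via_newtonschulz5` is the Newton-Schulz orthogonalization defined in `examples/toy_train.py`, where the same scale factor appears as `adjust_lr_for_muon`:

```python
import math

def muon_matrix_update(W, grad, buf, lr, wd, momentum=0.95, ns_steps=5):
    # Illustrative sketch, not the official distributed implementation.
    buf.mul_(momentum).add_(grad)                   # standard SGD momentum
    u = zeropower_via_newtonschulz5(buf, ns_steps)  # orthogonalize: roughly U V^T
    # A semi-orthogonal update has RMS ~ sqrt(1 / max(A, B)), so rescaling by
    # 0.2 * sqrt(max(A, B)) matches AdamW's typical update RMS of ~0.2 and lets
    # one learning rate serve matrix and non-matrix parameters alike.
    adjusted_lr = lr * 0.2 * math.sqrt(max(W.size(0), W.size(1)))
    W.mul_(1 - lr * wd)                             # decoupled weight decay
    W.add_(u, alpha=-adjusted_lr)
```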
34 | 35 |

Scaling up with Muon. (a) Scaling law experiments comparing Muon and Adam. Muon is 2 times more sample efficient than Adam. (b) The MMLU performance of our Moonlight model optimized with Muon and other comparable models. Moonlight advances the Pareto frontier of performance vs training FLOPs.

36 |
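As a companion to the second bullet above, here is a much-simplified sketch of the ZeRO-1-style distributed Muon for one weight matrix. It is illustrative only: the function name is ours, it assumes each data-parallel rank holds an equal flat shard of the parameter and its momentum, it reuses `zeropower_via_newtonschulz5` from `examples/toy_train.py`, and it omits the communication/compute overlap of the real implementation (see the tech report):

```python
import math
import torch
import torch.distributed as dist

def distributed_muon_matrix_step(p_shard, g_shard, buf_shard, full_shape,
                                 lr, wd, momentum=0.95, ns_steps=5):
    world, rank = dist.get_world_size(), dist.get_rank()
    # Momentum state stays partitioned across data-parallel ranks (ZeRO-1).
    buf_shard.mul_(momentum).add_(g_shard)
    # Newton-Schulz operates on the full matrix, so gather the momentum shards.
    full = torch.empty(world * buf_shard.numel(),
                       dtype=buf_shard.dtype, device=buf_shard.device)
    dist.all_gather_into_tensor(full, buf_shard.contiguous().view(-1))
    u = zeropower_via_newtonschulz5(full.view(full_shape), steps=ns_steps)
    # Each rank keeps and applies only its own shard of the update.
    u_shard = u.reshape(-1).chunk(world)[rank].view_as(p_shard)
    p_shard.mul_(1 - lr * wd)
    p_shard.add_(u_shard, alpha=-lr * 0.2 * math.sqrt(max(full_shape)))
```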
37 | 38 | 39 | ## Performance 40 | 41 | We name our lightweight model trained with Muon "Moonlight", and we compare it with state-of-the-art public models at a similar scale: 42 | 43 | - **Llama3.2-3B** is a 3B-parameter dense model trained with 9T tokens 44 | - **Qwen2.5-3B** is a 3B-parameter dense model trained with 18T tokens 45 | - **DeepSeek-V2-Lite** (DSV2-Lite) is a 2.4B/16B-parameter MoE model trained with 5.7T tokens 46 | 47 | | | **Benchmark (Metric)** | **Llama3.2-3B** | **Qwen2.5-3B** | **DSV2-Lite** | **Moonlight** | 48 | |---|---|---|---|---|---| 49 | | | Activated Params† | 2.81B | 2.77B | 2.24B | 2.24B | 50 | | | Total Params† | 2.81B | 2.77B | 15.29B | 15.29B | 51 | | | Training Tokens | 9T | 18T | 5.7T | 5.7T | 52 | | | Optimizer | AdamW | * | AdamW | Muon | 53 | | **English** | MMLU | 54.75 | 65.6 | 58.3 | **70.0** | 54 | | | MMLU-pro | 25.0 | 34.6 | 25.5 | **42.4** | 55 | | | BBH | 46.8 | 56.3 | 44.1 | **65.2** | 56 | | | TriviaQA‡ | 59.6 | 51.1 | 65.1 | **66.3** | 57 | | **Code** | HumanEval | 28.0 | 42.1 | 29.9 | **48.1** | 58 | | | MBPP | 48.7 | 57.1 | 43.2 | **63.8** | 59 | | **Math** | GSM8K | 34.0 | **79.1** | 41.1 | 77.4 | 60 | | | MATH | 8.5 | 42.6 | 17.1 | **45.3** | 61 | | | CMath | - | 80.0 | 58.4 | **81.1** | 62 | | **Chinese** | C-Eval | - | 75.0 | 60.3 | **77.2** | 63 | | | CMMLU | - | 75.0 | 64.3 | **78.2** | 64 | 65 | *The Qwen 2 & 2.5 reports did not disclose optimizer information. †The reported parameter counts exclude the embedding parameters. ‡We test all listed models on the full set of TriviaQA.* 66 | 67 | 68 | ## Example Usage 69 | ### Model Download 70 |
72 | 73 | | **Model** | **#Total Params** | **#Activated Params** | **Context Length** | **Download Link** | 74 | | :------------: | :------------: | :------------: | :------------: | :------------: | 75 | | Moonlight | 16B | 3B | 8K | [🤗 Hugging Face](https://huggingface.co/moonshotai/Moonlight-16B-A3B) | 76 | | Moonlight-Instruct | 16B | 3B | 8K | [🤗 Hugging Face](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct) | 77 | 78 |
79 | 80 | ### Inference with Hugging Face Transformers 81 | 82 | We introduce how to use our model at inference stage using transformers library. It is recommended to use python=3.10, torch>=2.1.0, and transformers=4.48.2 as the development environment. 83 | 84 | For our pretrained model (Moonlight): 85 | ```python 86 | from transformers import AutoModelForCausalLM, AutoTokenizer 87 | 88 | model_path = "moonshotai/Moonlight-16B-A3B" 89 | model = AutoModelForCausalLM.from_pretrained( 90 | model_path, 91 | torch_dtype="auto", 92 | device_map="auto", 93 | trust_remote_code=True, 94 | ) 95 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 96 | 97 | prompt = "1+1=2, 1+2=" 98 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device) 99 | generated_ids = model.generate(**inputs, max_new_tokens=100) 100 | response = tokenizer.batch_decode(generated_ids)[0] 101 | print(response) 102 | ``` 103 | 104 | For our instruct model (Moonlight-Instruct): 105 | 106 | ```python 107 | from transformers import AutoModelForCausalLM, AutoTokenizer 108 | 109 | model_path = "moonshotai/Moonlight-16B-A3B-Instruct" 110 | model = AutoModelForCausalLM.from_pretrained( 111 | model_path, 112 | torch_dtype="auto", 113 | device_map="auto", 114 | trust_remote_code=True 115 | ) 116 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 117 | 118 | messages = [ 119 | {"role": "system", "content": "You are a helpful assistant provided by Moonshot-AI."}, 120 | {"role": "user", "content": "Is 123 a prime?"} 121 | ] 122 | input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device) 123 | generated_ids = model.generate(inputs=input_ids, max_new_tokens=500) 124 | response = tokenizer.batch_decode(generated_ids)[0] 125 | print(response) 126 | ``` 127 | 128 | Moonlight has the same architecture as DeepSeek-V3, which is supported by many popular inference engines, such as VLLM and SGLang. As a result, our model can also be easily deployed using these tools. 129 | 130 | ### Training 131 | ``` 132 | # train qwen-like dense model with muon 133 | python3 examples/toy_train.py --model qwen --optimizer muon --dataset openwebtext-100k --hidden_size 896 --lr 1e-3 134 | 135 | # train qwen-like dense model with adamw 136 | python3 examples/toy_train.py --model qwen --optimizer adamw --dataset openwebtext-100k --hidden_size 896 --lr 1e-3 137 | ``` 138 | 139 | ## Intermediate Checkpoints 140 | To support ongoing research efforts, we will soon release our intermediate checkpoints. Coming soon... 
141 | 142 | ## Citation 143 | If you find Moonlight useful or want to use it in your projects, please cite our paper: 144 | ``` 145 | @misc{liu2025muonscalablellmtraining, 146 | title={Muon is Scalable for LLM Training}, 147 | author={Jingyuan Liu and Jianlin Su and Xingcheng Yao and Zhejun Jiang and Guokun Lai and Yulun Du and Yidao Qin and Weixin Xu and Enzhe Lu and Junjie Yan and Yanru Chen and Huabin Zheng and Yibo Liu and Shaowei Liu and Bohong Yin and Weiran He and Han Zhu and Yuzhi Wang and Jianzhou Wang and Mengnan Dong and Zheng Zhang and Yongsheng Kang and Hao Zhang and Xinran Xu and Yutao Zhang and Yuxin Wu and Xinyu Zhou and Zhilin Yang}, 148 | year={2025}, 149 | eprint={2502.16982}, 150 | archivePrefix={arXiv}, 151 | primaryClass={cs.LG}, 152 | url={https://arxiv.org/abs/2502.16982}, 153 | } 154 | ``` 155 | -------------------------------------------------------------------------------- /examples/toy_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | from loguru import logger 5 | from datasets import load_dataset 6 | from torch.utils.data import DataLoader, Dataset 7 | from transformers import ( 8 | Qwen2Config, 9 | Qwen2ForCausalLM, 10 | Qwen2Tokenizer, 11 | get_cosine_schedule_with_warmup, 12 | ) 13 | from tqdm import tqdm 14 | 15 | 16 | class MoonDataset(Dataset): 17 | def __init__(self, dataset_name, dataset, tokenizer, max_length=512): 18 | self.dataset_name = dataset_name 19 | self.dataset = dataset 20 | self.tokenizer = tokenizer 21 | self.texts = dataset["train"]["text"] 22 | self.max_length = max_length 23 | self.tokens = [] 24 | self._tokenize_texts() 25 | 26 | def _tokenize_texts(self): 27 | # Cache the tokenized corpus on disk so repeated runs skip tokenization 28 | if os.path.exists(f"{self.dataset_name}.bin"): 29 | self.tokens = torch.load(f"{self.dataset_name}.bin") 30 | else: 31 | for text in tqdm(self.texts, desc="Tokenizing texts"): 32 | encoded = self.tokenizer.encode(text, add_special_tokens=True) 33 | self.tokens.extend(encoded) 34 | torch.save(self.tokens, f"{self.dataset_name}.bin") 35 | 36 | def __len__(self): 37 | return len(self.tokens) // self.max_length 38 | 39 | def __getitem__(self, idx): 40 | start_idx = idx * self.max_length 41 | end_idx = start_idx + self.max_length 42 | token_slice = self.tokens[start_idx:end_idx] 43 | data = torch.tensor(token_slice, dtype=torch.long) 44 | return data 45 | 46 | 47 | # This code snippet is adapted from the following GitHub repository: 48 | # https://github.com/KellerJordan/Muon/blob/master/muon.py 49 | @torch.compile 50 | def zeropower_via_newtonschulz5(G, steps): 51 | """ 52 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 53 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 54 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 55 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 56 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 57 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 58 | performance at all relative to UV^T, where USV^T = G is the SVD. 
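Concretely, each iteration computes A = X @ X^T and updates X <- a*X + (b*A + c*A^2) @ X, using the quintic coefficients (a, b, c) defined below.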
58 | """ 59 | assert len(G.shape) == 2 60 | a, b, c = (3.4445, -4.7750, 2.0315) 61 | X = G.bfloat16() 62 | if G.size(0) > G.size(1): 63 | X = X.T 64 | # Ensure spectral norm is at most 1 65 | X = X / (X.norm() + 1e-7) 66 | # Perform the NS iterations 67 | for _ in range(steps): 68 | A = X @ X.T 69 | B = ( 70 | b * A + c * A @ A 71 | ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 72 | X = a * X + B @ X 73 | 74 | if G.size(0) > G.size(1): 75 | X = X.T 76 | return X 77 | 78 | 79 | class Muon(torch.optim.Optimizer): 80 | """ 81 | Muon - MomentUm Orthogonalized by Newton-schulz 82 | 83 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 84 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 85 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 86 | the advantage that it can be stably run in bfloat16 on the GPU. 87 | 88 | Some warnings: 89 | - We believe this optimizer is unlikely to work well for training with small batch size. 90 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 91 | 92 | Arguments: 93 | muon_params: The parameters to be optimized by Muon. 94 | lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) 95 | momentum: The momentum used by the internal SGD. (0.95 is a good default) 96 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 97 | ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) 98 | adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are 99 | {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. 100 | adamw_lr: The learning rate for the internal AdamW. 101 | adamw_betas: The betas for the internal AdamW. 102 | adamw_eps: The epsilon for the internal AdamW. 103 | adamw_wd: The weight decay for the internal AdamW. 104 | """ 105 | 106 | def __init__( 107 | self, 108 | lr=1e-3, 109 | wd=0.1, 110 | muon_params=None, 111 | momentum=0.95, 112 | nesterov=True, 113 | ns_steps=5, 114 | adamw_params=None, 115 | adamw_betas=(0.9, 0.95), 116 | adamw_eps=1e-8, 117 | ): 118 | 119 | defaults = dict( 120 | lr=lr, 121 | wd=wd, 122 | momentum=momentum, 123 | nesterov=nesterov, 124 | ns_steps=ns_steps, 125 | adamw_betas=adamw_betas, 126 | adamw_eps=adamw_eps, 127 | ) 128 | 129 | params = list(muon_params) 130 | adamw_params = list(adamw_params) if adamw_params is not None else [] 131 | params.extend(adamw_params) 132 | super().__init__(params, defaults) 133 | # Sort parameters into those for which we will use Muon, and those for which we will not 134 | for p in muon_params: 135 | # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer 136 | assert p.ndim == 2, p.ndim 137 | self.state[p]["use_muon"] = True 138 | for p in adamw_params: 139 | # Do not use Muon for parameters in adamw_params 140 | self.state[p]["use_muon"] = False 141 | 142 | def adjust_lr_for_muon(self, lr, param_shape): 143 | A, B = param_shape[:2] 144 | # We adjust the learning rate and weight decay based on the size of the parameter matrix 145 | # as describted in the paper 146 | adjusted_ratio = 0.2 * math.sqrt(max(A, B)) 147 | adjusted_lr = lr * adjusted_ratio 148 | return adjusted_lr 149 | 150 | def step(self, closure=None): 151 | """Perform a single optimization step. 
Parameters flagged with `use_muon` receive orthogonalized momentum updates scaled via `adjust_lr_for_muon`; all remaining parameters fall back to the internal AdamW, sharing `lr` and `wd`. 152 | 153 | Args: 154 | closure (Callable, optional): A closure that reevaluates the model 155 | and returns the loss. 156 | """ 157 | loss = None 158 | if closure is not None: 159 | with torch.enable_grad(): 160 | loss = closure() 161 | 162 | for group in self.param_groups: 163 | 164 | ############################ 165 | # Muon # 166 | ############################ 167 | 168 | params = [p for p in group["params"] if self.state[p]["use_muon"]] 169 | 170 | lr = group["lr"] 171 | wd = group["wd"] 172 | momentum = group["momentum"] 173 | 174 | # generate weight updates 175 | for p in params: 176 | # skip parameters without gradients 177 | g = p.grad 178 | if g is None: 179 | continue 180 | # flatten any >2D gradient into a matrix 181 | if g.ndim > 2: 182 | g = g.view(g.size(0), -1) 183 | 184 | # calc update 185 | state = self.state[p] 186 | if "momentum_buffer" not in state: 187 | state["momentum_buffer"] = torch.zeros_like(g) 188 | buf = state["momentum_buffer"] 189 | buf.mul_(momentum).add_(g) 190 | if group["nesterov"]: 191 | g = g.add(buf, alpha=momentum) 192 | else: 193 | g = buf 194 | u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) 195 | 196 | # scale update to keep a consistent update RMS 197 | adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) 198 | 199 | # apply decoupled weight decay 200 | p.data.mul_(1 - lr * wd) 201 | 202 | # apply update 203 | p.data.add_(u, alpha=-adjusted_lr) 204 | 205 | ############################ 206 | # AdamW backup # 207 | ############################ 208 | 209 | params = [p for p in group["params"] if not self.state[p]["use_muon"]] 210 | lr = group["lr"] 211 | beta1, beta2 = group["adamw_betas"] 212 | eps = group["adamw_eps"] 213 | weight_decay = group["wd"] 214 | 215 | for p in params: 216 | g = p.grad 217 | if g is None: 218 | continue 219 | state = self.state[p] 220 | if "step" not in state: 221 | state["step"] = 0 222 | state["moment1"] = torch.zeros_like(g) 223 | state["moment2"] = torch.zeros_like(g) 224 | state["step"] += 1 225 | step = state["step"] 226 | buf1 = state["moment1"] 227 | buf2 = state["moment2"] 228 | buf1.lerp_(g, 1 - beta1) 229 | buf2.lerp_(g.square(), 1 - beta2) 230 | 231 | g = buf1 / (eps + buf2.sqrt()) 232 | 233 | bias_correction1 = 1 - beta1**step 234 | bias_correction2 = 1 - beta2**step 235 | scale = bias_correction1 / bias_correction2**0.5 236 | p.data.mul_(1 - lr * weight_decay) 237 | p.data.add_(g, alpha=-lr / scale) 238 | 239 | return loss 240 | 241 | 242 | def get_model_and_dataloader(model_name, dataset_name, hidden_size): 243 | name2path = { 244 | "openwebtext-100k": "Elriggs/openwebtext-100k", 245 | } 246 | train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) 247 | if model_name == "qwen": 248 | tokenizer = Qwen2Tokenizer.from_pretrained( 249 | "Qwen/Qwen2.5-0.5B", trust_remote_code=True 250 | ) 251 | else: 252 | assert 0, f"model {model_name} not supported" 253 | train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) 254 | train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) 255 | 256 | if model_name == "qwen": 257 | config = Qwen2Config( 258 | attention_dropout=0.0, 259 | bos_token_id=151643, 260 | eos_token_id=151643, 261 | hidden_act="silu", 262 | hidden_size=hidden_size, 263 | initializer_range=0.02, 264 | intermediate_size=4864, 265 | max_position_embeddings=513, 266 | max_window_layers=12, 267 | model_type="qwen2", 268 | num_attention_heads=16, 269 | num_hidden_layers=12, 270 | num_key_value_heads=16, 271 | rms_norm_eps=1e-06, 272 | rope_theta=1000000.0, 273 | sliding_window=1024, 274 | 
tie_word_embeddings=True, 275 | torch_dtype="bfloat16", 276 | use_cache=True, 277 | use_mrope=False, 278 | use_sliding_window=False, 279 | vocab_size=151936, 280 | ) 281 | model = Qwen2ForCausalLM(config) 282 | else: 283 | assert 0, f"model {model_name} not supported" 284 | return model, train_loader 285 | 286 | 287 | def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): 288 | if optimizer_name == "adamw": 289 | return torch.optim.AdamW( 290 | model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95) 291 | ) 292 | elif optimizer_name == "muon": 293 | # 2D weight matrices go to Muon; embeddings, the lm_head, and 1D params go to AdamW 294 | muon_params = [ 295 | p 296 | for name, p in model.named_parameters() 297 | if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name 298 | ] 299 | adamw_params = [ 300 | p 301 | for name, p in model.named_parameters() 302 | if not ( 303 | p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name 304 | ) 305 | ] 306 | 307 | return Muon( 308 | lr=lr, 309 | wd=wd, 310 | muon_params=muon_params, 311 | adamw_params=adamw_params, 312 | ) 313 | else: 314 | assert 0, f"optimizer {optimizer_name} not supported" 315 | 316 | 317 | if __name__ == "__main__": 318 | import argparse 319 | 320 | parser = argparse.ArgumentParser() 321 | parser.add_argument("--model", type=str, default="qwen") 322 | parser.add_argument("--optimizer", type=str, default="adamw") 323 | parser.add_argument("--lr", type=float, default=1e-3) 324 | parser.add_argument("--wd", type=float, default=0.1) 325 | parser.add_argument("--dataset", type=str, default="openwebtext-100k") 326 | parser.add_argument("--hidden_size", type=int, default=1024) 327 | args = parser.parse_args() 328 | logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") 329 | 330 | model, train_loader = get_model_and_dataloader( 331 | args.model, args.dataset, args.hidden_size 332 | ) 333 | optimizer = get_optimizer( 334 | args.optimizer, model, lr=args.lr, wd=args.wd 335 | ) 336 | 337 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 338 | model.to(device) 339 | 340 | model.train() 341 | num_epochs = 1 342 | lr_scheduler = get_cosine_schedule_with_warmup( 343 | optimizer=optimizer, 344 | num_warmup_steps=100, 345 | num_training_steps=len(train_loader) * num_epochs, 346 | num_cycles=0.5, 347 | ) 348 | for epoch in range(num_epochs): 349 | for step, batch in enumerate(train_loader): 350 | batch = batch.to(device) 351 | input_ids = batch 352 | outputs = model(input_ids=input_ids, labels=input_ids) 353 | loss = outputs.loss 354 | loss.backward() 355 | optimizer.step() 356 | lr_scheduler.step() 357 | optimizer.zero_grad() 358 | logger.info( 359 | f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}" 360 | ) 361 | --------------------------------------------------------------------------------