├── MARS_M
    ├── scripts
    │   ├── run_CV.sh
    │   ├── run_mars_m_large.sh
    │   ├── run_mars_m_medium.sh
    │   ├── run_mars_m_small.sh
    │   ├── run_mars_m_xl_fw.sh
    │   ├── run_mars_m_small_fw.sh
    │   ├── run_moonlight_large.sh
    │   ├── run_moonlight_medium.sh
    │   ├── run_moonlight_small.sh
    │   ├── run_moonlight_xl_fw.sh
    │   └── run_moonlight_small_fw.sh
    ├── assets
    │   ├── xl_val.png
    │   ├── xl_train.png
    │   ├── small_train.png
    │   ├── small_val.png
    │   ├── val_large.png
    │   ├── val_medium.png
    │   ├── val_small.png
    │   ├── xl_val_global.png
    │   ├── small_val_global.png
    │   ├── val_large_global.png
    │   ├── val_small_global.png
    │   ├── xl_train_global.png
    │   ├── small_train_global.png
    │   └── val_medium_global.png
    ├── config
    │   ├── train_gpt2_small_moonlight.py
    │   ├── train_gpt2_small_mars_m.py
    │   ├── train_gpt2_xl_moonlight.py
    │   ├── train_gpt2_large_moonlight.py
    │   ├── train_gpt2_medium_moonlight.py
    │   ├── train_gpt2_xl_mars_m.py
    │   ├── train_gpt2_large_mars_m.py
    │   └── train_gpt2_medium_mars_m.py
    ├── utils
    │   ├── configurator.py
    │   ├── cv_utils.py
    │   └── model_CNN.py
    ├── openwebtext
    │   └── prepare.py
    ├── optimizers
    │   ├── moonlight.py
    │   └── mars_m.py
    ├── train_CV.py
    └── README.md
├── scripts
    ├── run_CNN.sh
    ├── run_CV.sh
    ├── run_mars_large.sh
    ├── run_mars_small.sh
    ├── run_muon_large.sh
    ├── run_muon_small.sh
    ├── run_adamw_large.sh
    ├── run_adamw_medium.sh
    ├── run_adamw_small.sh
    ├── run_mars_medium.sh
    ├── run_mars_xl_fw.sh
    ├── run_muon_medium.sh
    ├── run_adamw_xl_fw.sh
    ├── run_mars_small_fw.sh
    └── run_adamw_small_fw.sh
├── assets
    ├── MARS.png
    ├── ShampooH.png
    ├── xl_train.png
    ├── xl_val.png
    ├── MARS-AdamW.png
    ├── MARS-Lion.png
    ├── small_val.png
    ├── time_large.png
    ├── time_small.png
    ├── val_large.png
    ├── val_medium.png
    ├── val_small.jpg
    ├── val_small.png
    ├── MARS-Shampoo.png
    ├── fineweb_hella.png
    ├── small_train.png
    ├── time_medium.png
    ├── cifar100_test_acc.png
    ├── cifar100_test_loss.png
    ├── cifar10_test_acc.png
    └── cifar10_test_loss.png
├── config
    ├── train_gpt2_small_adamw.py
    ├── train_gpt2_small_mars.py
    ├── train_gpt2_xl_adamw.py
    ├── train_gpt2_large_adamw.py
    ├── train_gpt2_medium_adamw.py
    ├── train_gpt2_xl_mars.py
    ├── train_gpt2_large_mars.py
    ├── train_gpt2_medium_mars.py
    ├── train_gpt2_small_muon.py
    ├── train_gpt2_large_muon.py
    └── train_gpt2_medium_muon.py
├── MARS
    ├── utils
    │   ├── configurator.py
    │   ├── cv_utils.py
    │   └── model_CNN.py
    ├── opt.py
    ├── optimizers
    │   ├── muon.py
    │   ├── adamw.py
    │   └── mars.py
    ├── train_CV.py
    ├── train_CNN.py
    ├── train_adamw.py
    └── train_adamw_fw.py
├── data
    └── openwebtext
    │   └── prepare.py
└── LICENSE
/MARS_M/scripts/run_CV.sh:
--------------------------------------------------------------------------------
1 | python train_CV.py
--------------------------------------------------------------------------------
/scripts/run_CNN.sh:
--------------------------------------------------------------------------------
1 | python MARS/train_CNN.py
--------------------------------------------------------------------------------
/scripts/run_CV.sh:
--------------------------------------------------------------------------------
1 | python MARS/train_CV.py
--------------------------------------------------------------------------------
/assets/MARS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/MARS.png
--------------------------------------------------------------------------------
/assets/ShampooH.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/ShampooH.png -------------------------------------------------------------------------------- /assets/xl_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/xl_train.png -------------------------------------------------------------------------------- /assets/xl_val.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/xl_val.png -------------------------------------------------------------------------------- /assets/MARS-AdamW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/MARS-AdamW.png -------------------------------------------------------------------------------- /assets/MARS-Lion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/MARS-Lion.png -------------------------------------------------------------------------------- /assets/small_val.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/small_val.png -------------------------------------------------------------------------------- /assets/time_large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/time_large.png -------------------------------------------------------------------------------- /assets/time_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/time_small.png -------------------------------------------------------------------------------- /assets/val_large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/val_large.png -------------------------------------------------------------------------------- /assets/val_medium.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/val_medium.png -------------------------------------------------------------------------------- /assets/val_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/val_small.jpg -------------------------------------------------------------------------------- /assets/val_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/val_small.png -------------------------------------------------------------------------------- /MARS_M/assets/xl_val.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/xl_val.png -------------------------------------------------------------------------------- /assets/MARS-Shampoo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/MARS-Shampoo.png -------------------------------------------------------------------------------- 
/assets/fineweb_hella.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/fineweb_hella.png -------------------------------------------------------------------------------- /assets/small_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/small_train.png -------------------------------------------------------------------------------- /assets/time_medium.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/time_medium.png -------------------------------------------------------------------------------- /MARS_M/assets/xl_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/xl_train.png -------------------------------------------------------------------------------- /MARS_M/assets/small_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/small_train.png -------------------------------------------------------------------------------- /MARS_M/assets/small_val.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/small_val.png -------------------------------------------------------------------------------- /MARS_M/assets/val_large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_large.png -------------------------------------------------------------------------------- /MARS_M/assets/val_medium.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_medium.png -------------------------------------------------------------------------------- /MARS_M/assets/val_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_small.png -------------------------------------------------------------------------------- /assets/cifar100_test_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/cifar100_test_acc.png -------------------------------------------------------------------------------- /assets/cifar100_test_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/cifar100_test_loss.png -------------------------------------------------------------------------------- /assets/cifar10_test_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/cifar10_test_acc.png -------------------------------------------------------------------------------- /assets/cifar10_test_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/cifar10_test_loss.png -------------------------------------------------------------------------------- 
/MARS_M/assets/xl_val_global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/xl_val_global.png -------------------------------------------------------------------------------- /MARS_M/assets/small_val_global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/small_val_global.png -------------------------------------------------------------------------------- /MARS_M/assets/val_large_global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_large_global.png -------------------------------------------------------------------------------- /MARS_M/assets/val_small_global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_small_global.png -------------------------------------------------------------------------------- /MARS_M/assets/xl_train_global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/xl_train_global.png -------------------------------------------------------------------------------- /MARS_M/assets/small_train_global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/small_train_global.png -------------------------------------------------------------------------------- /MARS_M/assets/val_medium_global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_medium_global.png -------------------------------------------------------------------------------- /scripts/run_mars_large.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 \ 2 | MARS/train_mars.py \ 3 | config/train_gpt2_large_mars.py \ 4 | --batch_size=5 \ 5 | --gradient_accumulation_steps=12 -------------------------------------------------------------------------------- /scripts/run_mars_small.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 \ 2 | MARS/train_mars.py \ 3 | config/train_gpt2_small_mars.py \ 4 | --batch_size=15 \ 5 | --gradient_accumulation_steps=4 -------------------------------------------------------------------------------- /scripts/run_muon_large.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 \ 2 | MARS/train_muon.py \ 3 | config/train_gpt2_large_muon.py \ 4 | --batch_size=5 \ 5 | --gradient_accumulation_steps=12 -------------------------------------------------------------------------------- /scripts/run_muon_small.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 \ 2 | MARS/train_muon.py \ 3 | config/train_gpt2_small_muon.py \ 4 | --batch_size=15 \ 5 | --gradient_accumulation_steps=4 -------------------------------------------------------------------------------- /scripts/run_adamw_large.sh: 
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_large_adamw.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/scripts/run_adamw_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_medium_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_adamw_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_small_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_mars_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars.py \
3 | config/train_gpt2_medium_mars.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_mars_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars_fw.py \
3 | config/train_gpt2_xl_mars.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/scripts/run_muon_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_muon.py \
3 | config/train_gpt2_medium_muon.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/MARS_M/scripts/run_mars_m_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_mars_m.py \
3 | config/train_gpt2_large_mars_m.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/MARS_M/scripts/run_mars_m_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_mars_m.py \
3 | config/train_gpt2_medium_mars_m.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/MARS_M/scripts/run_mars_m_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_mars_m.py \
3 | config/train_gpt2_small_mars_m.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_adamw_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw_fw.py \
3 | config/train_gpt2_large_adamw.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/scripts/run_mars_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars_fw.py \
3 | config/train_gpt2_small_mars.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/MARS_M/scripts/run_mars_m_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_mars_m_fw.py \
3 | config/train_gpt2_xl_mars_m.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/scripts/run_adamw_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw_fw.py \
3 | config/train_gpt2_small_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/MARS_M/scripts/run_mars_m_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_mars_m_fw.py \
3 | config/train_gpt2_small_mars_m.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/MARS_M/scripts/run_moonlight_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_moonlight.py \
3 | config/train_gpt2_large_moonlight.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/MARS_M/scripts/run_moonlight_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_moonlight.py \
3 | config/train_gpt2_medium_moonlight.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/MARS_M/scripts/run_moonlight_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_moonlight.py \
3 | config/train_gpt2_small_moonlight.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/MARS_M/scripts/run_moonlight_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_moonlight_fw.py \
3 | config/train_gpt2_large_moonlight.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/MARS_M/scripts/run_moonlight_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_moonlight_fw.py \
3 | config/train_gpt2_small_moonlight.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/config/train_gpt2_small_adamw.py:
-------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-small-adamw-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 12 10 | n_head = 12 11 | n_embd = 768 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | 15 | # this makes total number of tokens be ~50B 16 | max_iters = 100000 17 | lr_decay_iters = 100000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # optimizer 25 | optimizer_name = 'adamw' 26 | learning_rate = 6e-4 # max learning rate 27 | weight_decay = 1e-1 28 | beta1 = 0.9 29 | beta2 = 0.95 30 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 31 | # learning rate decay settings 32 | decay_lr = True # whether to decay the learning rate 33 | warmup_iters = 2000 # how many steps to warm up for 34 | min_lr = 3e-5 35 | 36 | compile = True 37 | 38 | out_dir = 'out_small_adamw_100k' 39 | -------------------------------------------------------------------------------- /MARS_M/config/train_gpt2_small_moonlight.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars-m' 3 | wandb_run_name='gpt2-small-moonlight-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 12 10 | n_head = 12 11 | n_embd = 768 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | 15 | # this makes total number of tokens be ~50B 16 | max_iters = 100000 17 | lr_decay_iters = 100000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # optimizer 25 | optimizer_name = 'moonlight' 26 | learning_rate = 6e-3 # max learning rate 27 | weight_decay = 1e-1 28 | beta1 = 0.95 29 | beta2 = 0.99 30 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 31 | # learning rate decay settings 32 | decay_lr = True # whether to decay the learning rate 33 | warmup_iters = 2000 # how many steps to warm up for 34 | min_lr = 3e-5 35 | 36 | compile = True 37 | 38 | out_dir = 'out_small_moonlight_100k' 39 | -------------------------------------------------------------------------------- /config/train_gpt2_small_mars.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-small-mars-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 12 10 | n_head = 12 11 | n_embd = 768 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | 15 | # this makes total number of tokens be ~50B 16 | max_iters = 100000 17 | lr_decay_iters = 100000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # optimizer 25 | optimizer_name = 'mars' 26 | learning_rate = 6e-3 # max learning rate 27 | weight_decay = 1e-2 28 | beta1 = 0.95 29 | beta2 = 0.99 30 | lr_1d=3e-3 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 3e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_small_mars_100k' 40 | gamma=0.025 41 | 
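The decay-related fields shared by these configs (warmup_iters, lr_decay_iters, learning_rate, min_lr, decay_lr) describe a linear-warmup-then-cosine schedule in the nanoGPT style. A minimal sketch of the schedule they imply, assuming a nanoGPT-style get_lr helper (hypothetical here, since the training scripts themselves are not included in this listing), with the GPT-2 small values plugged in as defaults:

import math

def get_lr(it, learning_rate=6e-3, min_lr=3e-5, warmup_iters=2000, lr_decay_iters=100000):
    # 1) linear warmup for the first warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) past the decay horizon, hold the learning rate at min_lr
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, cosine-decay from learning_rate down to min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0
    return min_lr + coeff * (learning_rate - min_lr)

When decay_lr is False, nanoGPT-style trainers skip this schedule and keep learning_rate constant.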
-------------------------------------------------------------------------------- /MARS_M/config/train_gpt2_small_mars_m.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars-m' 3 | wandb_run_name='gpt2-small-mars-m-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 12 10 | n_head = 12 11 | n_embd = 768 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | 15 | # this makes total number of tokens be ~50B 16 | max_iters = 100000 17 | lr_decay_iters = 100000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # optimizer 25 | optimizer_name = 'mars-m' 26 | learning_rate = 6e-3 # max learning rate 27 | weight_decay = 1e-2 28 | beta1 = 0.95 29 | beta2 = 0.99 30 | 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 3e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_small_mars_m_100k' 40 | gamma=0.025 41 | -------------------------------------------------------------------------------- /config/train_gpt2_xl_adamw.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-xl-adamw-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 48 10 | n_head = 25 11 | n_embd = 1600 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'adamw' 27 | learning_rate = 2e-4 # max learning rate 28 | weight_decay = 1e-1 29 | beta1 = 0.9 30 | beta2 = 0.95 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 1e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_large_adamw_100k' 40 | -------------------------------------------------------------------------------- /config/train_gpt2_large_adamw.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-large-adamw-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 36 10 | n_head = 20 11 | n_embd = 1280 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'adamw' 27 | learning_rate = 2e-4 # max learning rate 28 | weight_decay = 1e-1 29 | beta1 = 0.9 30 | beta2 = 0.95 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | 
warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 1e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_large_adamw_100k' 40 | -------------------------------------------------------------------------------- /config/train_gpt2_medium_adamw.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-medium-adamw-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 24 10 | n_head = 16 11 | n_embd = 1024 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'adamw' 27 | learning_rate = 3e-4 # max learning rate 28 | weight_decay = 1e-1 29 | beta1 = 0.9 30 | beta2 = 0.95 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 6e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_medium_adamw_100k' 40 | -------------------------------------------------------------------------------- /MARS_M/config/train_gpt2_xl_moonlight.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars-m' 3 | wandb_run_name='gpt2-xl-moonlight-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 48 10 | n_head = 25 11 | n_embd = 1600 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'moonlight' 27 | learning_rate = 3e-3 # max learning rate 28 | weight_decay = 1e-1 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 1e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_large_moonlight_100k' 40 | -------------------------------------------------------------------------------- /MARS_M/config/train_gpt2_large_moonlight.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars-m' 3 | wandb_run_name='gpt2-large-moonlight-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 36 10 | n_head = 20 11 | n_embd = 1280 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'moonlight' 27 | learning_rate = 5e-3 # max learning rate 28 | weight_decay = 1e-1 29 | 
beta1 = 0.95 30 | beta2 = 0.99 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 1e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_large_moonlight_100k' 40 | -------------------------------------------------------------------------------- /config/train_gpt2_xl_mars.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-xl-mars-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 48 10 | n_head = 25 11 | n_embd = 1600 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'mars' 27 | learning_rate = 2e-3 # max learning rate 28 | weight_decay = 1e-2 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | lr_1d=1e-3 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 1e-5 37 | 38 | compile = True 39 | 40 | out_dir = 'out_large_mars_100k' 41 | gamma=0.025 42 | -------------------------------------------------------------------------------- /MARS_M/config/train_gpt2_medium_moonlight.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars-m' 3 | wandb_run_name='gpt2-medium-moonlight-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 24 10 | n_head = 16 11 | n_embd = 1024 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'moonlight' 27 | learning_rate = 5e-3 # max learning rate 28 | weight_decay = 1e-1 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 32 | # learning rate decay settings 33 | decay_lr = True # whether to decay the learning rate 34 | warmup_iters = 2000 # how many steps to warm up for 35 | min_lr = 6e-5 36 | 37 | compile = True 38 | 39 | out_dir = 'out_medium_moonlight_100k' 40 | -------------------------------------------------------------------------------- /MARS_M/config/train_gpt2_xl_mars_m.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars-m' 3 | wandb_run_name='gpt2-xl-mars-m-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 48 10 | n_head = 25 11 | n_embd = 1600 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 
| 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'mars-m' 27 | learning_rate = 3e-3 # max learning rate 28 | weight_decay = 1e-2 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 1e-5 37 | 38 | compile = True 39 | 40 | out_dir = 'out_large_mars_m_100k' 41 | gamma=0.025 42 | -------------------------------------------------------------------------------- /config/train_gpt2_large_mars.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-large-mars-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 36 10 | n_head = 20 11 | n_embd = 1280 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'mars' 27 | learning_rate = 2e-3 # max learning rate 28 | weight_decay = 1e-2 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | lr_1d=1e-3 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 1e-5 37 | 38 | compile = True 39 | 40 | out_dir = 'out_large_mars_100k' 41 | gamma=0.025 42 | -------------------------------------------------------------------------------- /MARS_M/config/train_gpt2_large_mars_m.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars-m' 3 | wandb_run_name='gpt2-large-mars-m-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 36 10 | n_head = 20 11 | n_embd = 1280 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'mars-m' 27 | learning_rate = 5e-3 # max learning rate 28 | weight_decay = 1e-2 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 1e-5 37 | 38 | compile = True 39 | 40 | out_dir = 'out_large_mars_m_100k' 41 | gamma=0.025 42 | -------------------------------------------------------------------------------- /config/train_gpt2_medium_mars.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-medium-mars-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 24 10 | n_head = 16 11 | n_embd = 
1024 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'mars' 27 | learning_rate = 3e-3 # max learning rate 28 | weight_decay = 1e-2 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | lr_1d=1.5e-3 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 6e-5 37 | 38 | compile = True 39 | 40 | out_dir = 'out_medium_mars_100k' 41 | gamma=0.025 42 | -------------------------------------------------------------------------------- /MARS_M/config/train_gpt2_medium_mars_m.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars-m' 3 | wandb_run_name='gpt2-medium-mars-m-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 24 10 | n_head = 16 11 | n_embd = 1024 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'mars-m' 27 | learning_rate = 5e-3 # max learning rate 28 | weight_decay = 1e-2 29 | beta1 = 0.95 30 | beta2 = 0.99 31 | 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 6e-5 37 | 38 | compile = True 39 | 40 | out_dir = 'out_medium_mars_m_100k' 41 | gamma=0.025 42 | -------------------------------------------------------------------------------- /config/train_gpt2_small_muon.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-small-muon-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 12 10 | n_head = 12 11 | n_embd = 768 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | 15 | # this makes total number of tokens be ~50B 16 | max_iters = 100000 17 | lr_decay_iters = 100000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # optimizer 25 | optimizer_name = 'muon' 26 | learning_rate = 3e-3 # max learning rate, original=6e-4 27 | weight_decay = 1e-1 28 | muon_learning_rate = 2e-2 29 | muon_weight_decay = 0. 
30 | beta1 = 0.9 31 | beta2 = 0.95 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 3e-5 37 | schedule = 'cosine' 38 | compile = True 39 | 40 | out_dir = 'out_small_muon_100k' 41 | -------------------------------------------------------------------------------- /config/train_gpt2_large_muon.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-large-muon-100k' 4 | 5 | batch_size = 5 6 | block_size = 1024 7 | gradient_accumulation_steps = 12 8 | 9 | n_layer = 36 10 | n_head = 20 11 | n_embd = 1280 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'muon' 27 | learning_rate = 1e-3 # max learning rate 28 | weight_decay = 1e-1 29 | muon_learning_rate = 6.67e-3 30 | muon_weight_decay = 0. 31 | beta1 = 0.9 32 | beta2 = 0.95 33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 34 | # learning rate decay settings 35 | decay_lr = True # whether to decay the learning rate 36 | warmup_iters = 2000 # how many steps to warm up for 37 | min_lr = 1e-5 38 | 39 | compile = True 40 | 41 | out_dir = 'out_large_muon_100k' 42 | -------------------------------------------------------------------------------- /config/train_gpt2_medium_muon.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'mars' 3 | wandb_run_name='gpt2-medium-muon-100k' 4 | 5 | batch_size = 15 6 | block_size = 1024 7 | gradient_accumulation_steps = 4 8 | 9 | n_layer = 24 10 | n_head = 16 11 | n_embd = 1024 12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 13 | bias = False 14 | scale_attn_by_inverse_layer_idx = True 15 | 16 | # this makes total number of tokens be ~50B 17 | max_iters = 100000 18 | lr_decay_iters = 100000 19 | 20 | # eval stuff 21 | eval_interval = 1000 22 | eval_iters = 200 23 | log_interval = 10 24 | 25 | # optimizer 26 | optimizer_name = 'muon' 27 | learning_rate = 1.5e-3 # max learning rate 28 | weight_decay = 1e-1 29 | muon_learning_rate = 1e-2 30 | muon_weight_decay = 0. 31 | beta1 = 0.9 32 | beta2 = 0.95 33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 34 | # learning rate decay settings 35 | decay_lr = True # whether to decay the learning rate 36 | warmup_iters = 2000 # how many steps to warm up for 37 | min_lr = 6e-5 38 | 39 | compile = True 40 | 41 | out_dir = 'out_medium_muon_100k' 42 | -------------------------------------------------------------------------------- /MARS/utils/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. 
train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /MARS_M/utils/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. 
if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /MARS_M/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 52 13 | 14 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 15 | dataset = load_dataset("openwebtext", cache_dir="nanoGPT/cache") 16 | 17 | # owt by default only contains the 'train' split, so create a test split 18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 19 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 20 | 21 | # this results in: 22 | # >>> split_dataset 23 | # DatasetDict({ 24 | # train: Dataset({ 25 | # features: ['text'], 26 | # num_rows: 8009762 27 | # }) 28 | # val: Dataset({ 29 | # features: ['text'], 30 | # num_rows: 4007 31 | # }) 32 | # }) 33 | 34 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 35 | enc = tiktoken.get_encoding("gpt2") 36 | def process(example): 37 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 38 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 39 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 40 | out = {'ids': ids, 'len': len(ids)} 41 | return out 42 | 43 | # tokenize the dataset 44 | tokenized = split_dataset.map( 45 | process, 46 | remove_columns=['text'], 47 | desc="tokenizing the splits", 48 | num_proc=num_proc, 49 | ) 50 | print('tokenization finished') 51 | # concatenate all the ids in each dataset into one large file we can use for training 52 | for split, dset in tokenized.items(): 53 | arr_len = np.sum(dset['len']) 54 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 55 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 56 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 57 | 58 | print(f"writing {filename}...") 59 | idx = 0 60 | for example in tqdm(dset): 61 | arr[idx : idx + example['len']] = example['ids'] 62 | idx += example['len'] 63 | arr.flush() 64 | 65 | # train.bin is ~17GB, val.bin ~8.5MB 66 | # train has ~9B tokens (9,035,582,198) 67 | # val has ~4M tokens (4,434,897) 68 | 69 | # to read the bin files later, e.g. with numpy: 70 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 71 | -------------------------------------------------------------------------------- /data/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. 
following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 52 13 | 14 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 15 | dataset = load_dataset("openwebtext", cache_dir="nanoGPT/cache") 16 | 17 | # owt by default only contains the 'train' split, so create a test split 18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 19 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 20 | 21 | # this results in: 22 | # >>> split_dataset 23 | # DatasetDict({ 24 | # train: Dataset({ 25 | # features: ['text'], 26 | # num_rows: 8009762 27 | # }) 28 | # val: Dataset({ 29 | # features: ['text'], 30 | # num_rows: 4007 31 | # }) 32 | # }) 33 | 34 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 35 | enc = tiktoken.get_encoding("gpt2") 36 | def process(example): 37 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 38 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 39 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 40 | out = {'ids': ids, 'len': len(ids)} 41 | return out 42 | 43 | # tokenize the dataset 44 | tokenized = split_dataset.map( 45 | process, 46 | remove_columns=['text'], 47 | desc="tokenizing the splits", 48 | num_proc=num_proc, 49 | ) 50 | print('tokenization finished') 51 | # concatenate all the ids in each dataset into one large file we can use for training 52 | for split, dset in tokenized.items(): 53 | arr_len = np.sum(dset['len']) 54 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 55 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 56 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 57 | 58 | print(f"writing {filename}...") 59 | idx = 0 60 | for example in tqdm(dset): 61 | arr[idx : idx + example['len']] = example['ids'] 62 | idx += example['len'] 63 | arr.flush() 64 | 65 | # train.bin is ~17GB, val.bin ~8.5MB 66 | # train has ~9B tokens (9,035,582,198) 67 | # val has ~4M tokens (4,434,897) 68 | 69 | # to read the bin files later, e.g. 
with numpy: 70 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 71 | -------------------------------------------------------------------------------- /MARS/opt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import collections 3 | """ 4 | Adapted from askerlee@github: https://github.com/KellerJordan/modded-nanogpt/issues/9 5 | """ 6 | def separate_params(param_groups): 7 | param_groups_2d = [] 8 | param_groups_non2d = [] 9 | total_param_2d_count = 0 10 | total_param_non2d_count = 0 11 | 12 | 13 | # Convert iterators to lists 14 | if isinstance(param_groups, collections.abc.Iterable): 15 | param_groups = list(param_groups) 16 | 17 | # Check if param_groups is a list of dicts or list of params 18 | if (isinstance(param_groups, list) and isinstance(param_groups[0], dict)) \ 19 | or isinstance(param_groups, dict): 20 | if isinstance(param_groups, dict): 21 | param_groups = [param_groups] 22 | # param_groups is a list of dicts 23 | for group in param_groups: 24 | params_2d, params_non2d, param_2d_count, param_non2d_count = separate_params(group['params']) 25 | param_group_2d = {'params': params_2d} 26 | param_group_non2d = {'params': params_non2d} 27 | # Copy the group dict and replace the 'params' key with the separated params 28 | for k in group.keys(): 29 | if k != 'params': 30 | param_group_2d[k] = group[k] 31 | param_group_non2d[k] = group[k] 32 | 33 | param_groups_2d.append(param_group_2d) 34 | param_groups_non2d.append(param_group_non2d) 35 | total_param_2d_count += param_2d_count 36 | total_param_non2d_count += param_non2d_count 37 | 38 | return param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count 39 | 40 | elif isinstance(param_groups, list) and isinstance(param_groups[0], torch.Tensor): 41 | params_2d = [] 42 | params_non2d = [] 43 | param_group = param_groups 44 | # param_group is a list of param tensors 45 | for param in param_group: 46 | if param.ndim >= 2: 47 | params_2d.append(param) 48 | else: 49 | params_non2d.append(param) 50 | return params_2d, params_non2d, len(params_2d), len(params_non2d) 51 | else: 52 | breakpoint() 53 | 54 | ''' 55 | # CombinedOptimizer is now a torch.optim.Optimizer, compatible with pytorch lightning. 56 | # Original Example: 57 | optimizer = CombinedOptimizer([ 58 | torch.optim.AdamW(self.lm_head.parameters(), lr=learning_rate, betas=betas, weight_decay=0, fused=True), 59 | OrthogonalNesterov(self.transformer.h.parameters(), lr=0.1*learning_rate, momentum=0.95) 60 | ]) 61 | # Refactored Example: 62 | optimizer = CombinedOptimizer(\ 63 | self.parameters(), 64 | [OrthogonalNesterov, torch.optim.AdamW], 65 | [{'lr': 0.1*learning_rate, 'momentum': 0.95}, 66 | {'lr': learning_rate, 'betas': betas, 'weight_decay': 0, 'fused': True} 67 | ]) 68 | ''' 69 | 70 | class CombinedOptimizer(torch.optim.Optimizer): 71 | def __init__(self, params, optimizer_types, configs, raw_model = False): 72 | # Separate 2D and non-2D parameters. 73 | # If params is a list of tensors, then each of param_groups_2d and param_groups_non2d 74 | # will be a list of tensors. 75 | # If params is a list of dicts, then each of param_groups_2d and param_groups_non2d 76 | # will be a list of dicts. 77 | # If params is a dict, then each of param_groups_2d and param_groups_non2d will 78 | # be a list of dicts containing only one dict. 
79 | if raw_model: 80 | params_others = list(params.transformer.h.parameters()) 81 | param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count \ 82 | = separate_params(params_others) 83 | param_groups_non2d.extend(list(params.lm_head.parameters())) 84 | total_param_non2d_count += 2 85 | else: 86 | param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count \ 87 | = separate_params(params) 88 | param_groups_2d_non2d = (param_groups_non2d, param_groups_2d) 89 | print(f"Total 2D params: {total_param_2d_count}, Total non-2D params: {total_param_non2d_count}") 90 | 91 | assert len(optimizer_types) == len(configs) == 2 92 | self.optimizers = [ optimizer_types[i](param_groups_2d_non2d[i], **configs[i]) for i in range(2) ] 93 | self.param_groups = [pg for opt in self.optimizers for pg in opt.param_groups] 94 | self.base_lrs = [opt.param_groups[0]['lr'] for opt in self.optimizers] 95 | # Combine the state dicts of all opt in self.optimizers into a single dict 96 | self.state = {k: v for opt in self.optimizers for k, v in opt.state.items()} 97 | # Initially all states are empty. So no point to print their counts. 98 | # Only use the defaults of the OrthogonalNesterov optimizer 99 | self.defaults = self.optimizers[0].defaults 100 | 101 | def step(self, *args, **kwargs): 102 | for opt in self.optimizers: 103 | opt.step(*args, **kwargs) 104 | 105 | def zero_grad(self, **kwargs): 106 | for opt in self.optimizers: 107 | opt.zero_grad(**kwargs) 108 | 109 | def scale_lrs(self, lr_scale): 110 | for base_lr, opt in zip(self.base_lrs, self.optimizers): 111 | opt.param_groups[0]['lr'] = base_lr * lr_scale 112 | 113 | def state_dict(self): 114 | return [opt.state_dict() for opt in self.optimizers] -------------------------------------------------------------------------------- /MARS/optimizers/muon.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from KellerJordan/modded-nanogpt: https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt2.py 3 | """ 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import os 8 | 9 | def zeropower_via_svd(G, steps=None): 10 | U, S, V = G.svd() 11 | return U @ V.T 12 | 13 | @torch.compile 14 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 15 | """ 16 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 17 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 18 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 19 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 20 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 21 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model 22 | performance at all relative to UV^T, where USV^T = G is the SVD. 
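    Concretely, each iteration below forms A = X X^T and applies the odd polynomial update
    X <- a*X + b*A*X + c*A^2*X with (a, b, c) = (3.4445, -4.7750, 2.0315), which pushes the
    singular values of X toward 1 while leaving its singular vectors unchanged.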
23 | """ 24 | assert len(G.shape) == 2 25 | a, b, c = (3.4445, -4.7750, 2.0315) 26 | X = G.bfloat16() 27 | X /= (X.norm() + eps) # ensure top singular value <= 1 28 | if G.size(0) > G.size(1): 29 | X = X.T 30 | for _ in range(steps): 31 | A = X @ X.T 32 | B = A @ X 33 | X = a * X + b * B + c * A @ B 34 | if G.size(0) > G.size(1): 35 | X = X.T 36 | return X 37 | 38 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5) 39 | 40 | class Muon(torch.optim.Optimizer): 41 | """ 42 | Muon - MomentUm Orthogonalized by Newton-schulz 43 | 44 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 45 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 46 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 47 | the advantage that it can be stably run in bfloat16 on the GPU. 48 | 49 | Some warnings: 50 | - This optimizer assumes that all parameters passed in are 2D. 51 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D 52 | parameters; those should all be optimized by a standard method (e.g., AdamW). 53 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 54 | - We believe it is unlikely to work well for training with small batch size. 55 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 56 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). 57 | 58 | Arguments: 59 | lr: The learning rate used by the internal SGD. 60 | momentum: The momentum used by the internal SGD. 61 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 62 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') 63 | backend_steps: The number of iteration steps to use in the backend, if it is iterative. 
64 | """ 65 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, 66 | backend='newtonschulz5', backend_steps=5, weight_decay=0.): 67 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps, weight_decay=weight_decay) 68 | super().__init__(params, defaults) 69 | if 'WORLD_SIZE' in os.environ: 70 | self.world_size = int(os.environ['WORLD_SIZE']) 71 | self.rank = int(os.environ['RANK']) 72 | else: 73 | self.world_size = 1 74 | self.rank = 0 75 | 76 | def step(self): 77 | 78 | for group in self.param_groups: 79 | 80 | lr = group['lr'] 81 | weight_decay = group['weight_decay'] 82 | momentum = group['momentum'] 83 | zeropower_backend = zeropower_backends[group['backend']] 84 | 85 | # generate weight updates in distributed fashion 86 | total_params = sum(p.numel() for p in group['params']) 87 | updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16) 88 | curr_idx = 0 89 | for i, p in enumerate(group['params']): 90 | # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs 91 | if i % int(self.world_size) == int(self.rank): 92 | g = p.grad 93 | assert g is not None 94 | if g.ndim > 2: 95 | g = g.view(g.size(0), -1) 96 | state = self.state[p] 97 | if 'momentum_buffer' not in state: 98 | state['momentum_buffer'] = torch.zeros_like(g) 99 | buf = state['momentum_buffer'] 100 | buf.mul_(momentum).add_(g) 101 | if group['nesterov']: 102 | g = g.add(buf, alpha=momentum) 103 | g = zeropower_backend(g, steps=group['backend_steps']) 104 | g *= max(1, g.size(0)/g.size(1))**0.5 105 | updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten() 106 | curr_idx += p.numel() 107 | 108 | # sync updates across devices. we are not memory-constrained so can do this simple deserialization 109 | if self.world_size > 1: 110 | dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) 111 | 112 | # deserialize and apply updates 113 | curr_idx = 0 114 | for p in group['params']: 115 | g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data) 116 | p.data.mul_(1.-lr*weight_decay).add_(g, alpha=-lr) 117 | curr_idx += p.numel() -------------------------------------------------------------------------------- /MARS/optimizers/adamw.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | # from megatron.optimizer.l2_norm import l2_norm 5 | 6 | def exists(val): 7 | return val is not None 8 | 9 | 10 | class AdamW(Optimizer): 11 | """Implements Adam algorithm. 12 | 13 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | 27 | .. _Adam\: A Method for Stochastic Optimization: 28 | https://arxiv.org/abs/1412.6980 29 | .. 
_On the Convergence of Adam and Beyond: 30 | https://openreview.net/forum?id=ryQu7f-RZ 31 | """ 32 | 33 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 34 | weight_decay=0, amsgrad=False): 35 | if not 0.0 <= lr: 36 | raise ValueError("Invalid learning rate: {}".format(lr)) 37 | if not 0.0 <= eps: 38 | raise ValueError("Invalid epsilon value: {}".format(eps)) 39 | if not 0.0 <= betas[0] < 1.0: 40 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 41 | if not 0.0 <= betas[1] < 1.0: 42 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 43 | defaults = dict(lr=lr, betas=betas, eps=eps, 44 | weight_decay=weight_decay, amsgrad=amsgrad) 45 | super(AdamW, self).__init__(params, defaults) 46 | self.eps = eps 47 | 48 | def __setstate__(self, state): 49 | super(AdamW, self).__setstate__(state) 50 | for group in self.param_groups: 51 | group.setdefault('amsgrad', False) 52 | 53 | @torch.no_grad() 54 | def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): 55 | """Performs a single optimization step. 56 | 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 60 | """ 61 | if any(p is not None for p in [grads, output_params, scale, grad_norms]): 62 | raise RuntimeError('FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.') 63 | 64 | loss = None 65 | if exists(closure): 66 | with torch.enable_grad(): 67 | loss = closure() 68 | real_update = 0 69 | real_update_wo_lr = 0 70 | 71 | for group in self.param_groups: 72 | for p in filter(lambda p: exists(p.grad), group['params']): 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 78 | amsgrad = group['amsgrad'] 79 | 80 | state = self.state[p] 81 | #print('----- starting a parameter state', state.keys(), 'Length of state', len(state)) 82 | # State initialization 83 | if len(state) == 0: 84 | state['step'] = 0 85 | # Exponential moving average of gradient values 86 | state['exp_avg'] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state['exp_avg_sq'] = torch.zeros_like(p.data) 89 | if amsgrad: 90 | # Maintains max of all exp. moving avg. of sq. grad. values 91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 92 | 93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 94 | if amsgrad: 95 | max_exp_avg_sq = state['max_exp_avg_sq'] 96 | beta1, beta2 = group['betas'] 97 | 98 | if 'step' in state: 99 | state['step'] += 1 100 | else: 101 | state['step'] = 1 102 | 103 | # Decay the first and second moment running average coefficient 104 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 105 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 106 | if amsgrad: 107 | # Maintains the maximum of all 2nd moment running avg. till now 108 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 109 | # Use the max. for normalizing running avg. 
of gradient 110 | denom = max_exp_avg_sq.sqrt().add_(self.eps) 111 | else: 112 | denom = exp_avg_sq.sqrt().add_(self.eps) 113 | 114 | bias_correction1 = 1 - beta1 ** state['step'] 115 | bias_correction2 = 1 - beta2 ** state['step'] 116 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 117 | 118 | # p.data.addcdiv_(-step_size, exp_avg, denom) 119 | real_update_tmp = -step_size * torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom) 120 | real_update_wo_lr_tmp = torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom) 121 | 122 | p.data.add_(real_update_tmp) 123 | return loss 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /MARS/utils/cv_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import torchvision 5 | from torchvision import datasets, transforms 6 | from torch.utils.data import DataLoader 7 | 8 | def get_model(args): 9 | """ 10 | models including: 11 | - VGG16 12 | - resnet18 13 | from https://github.com/iShohei220/adopt/blob/main/adopt.py and https://github.com/uclaml/Padam/blob/master/models/resnet.py 14 | """ 15 | if args.dataset in ['mnist', 'cifar10']: 16 | num_classes = 10 17 | elif args.dataset in ['cifar100']: 18 | num_classes = 100 19 | else: 20 | raise NotImplementedError(f"{args.dataset} is not implemented.") 21 | if args.net == 'simple_cnn': 22 | from .model_CNN import Network 23 | model_config = { 24 | "n_inputs": (3, 32, 32) if args.dataset == "cifar10" else (1, 28, 28), 25 | "conv_layers_list": [ 26 | {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 27 | {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 28 | {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 29 | ], 30 | "n_hiddens_list": [512], 31 | "n_outputs": 10, 32 | "dropout": 0.2, 33 | } 34 | model = Network(**model_config) 35 | elif args.net == 'resnet18': 36 | from .model_CNN import ResNet18 37 | model = ResNet18(num_classes = num_classes) 38 | else: 39 | try: 40 | model = torchvision.models.get_model(args.net, num_classes=num_classes) 41 | except: 42 | print('Model not found') 43 | raise NotImplementedError 44 | return model 45 | 46 | def get_datasets(dataset_name: str, train_batch_size: int, eval_batch_size: int): 47 | """Get train and test dataloaders.""" 48 | print('==> Preparing data..') 49 | if dataset_name == "mnist": 50 | transform = transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.1307,), (0.3081,)) 53 | ]) 54 | train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) 55 | test_dataset = datasets.MNIST('./data', train=False, transform=transform) 56 | elif dataset_name == "cifar10": 57 | transform_train = transforms.Compose([ 58 | transforms.RandomCrop(32, padding=4), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ToTensor(), 61 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 62 | ]) 63 | transform_test = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 66 | ]) 67 | train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train) 68 | test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test) 69 | elif dataset_name == "cifar100": 70 | transform_train = transforms.Compose([ 71 | transforms.RandomCrop(32, padding=4), 72 | 
transforms.RandomHorizontalFlip(), 73 | transforms.ToTensor(), 74 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]), 75 | ]) 76 | transform_test = transforms.Compose([ 77 | transforms.ToTensor(), 78 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]), 79 | ]) 80 | train_dataset = datasets.CIFAR100('./data', train=True, download=True, transform=transform_train) 81 | test_dataset = datasets.CIFAR100('./data', train=False, transform=transform_test) 82 | else: 83 | raise NotImplementedError(f"{dataset_name=} is not implemented.") 84 | 85 | train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=4) 86 | test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=4) 87 | 88 | return train_loader, test_loader 89 | 90 | 91 | class WarmupCosineScheduler: 92 | """Custom learning rate scheduler with linear warmup and cosine decay.""" 93 | def __init__(self, optimizer, warmup_iters: int, total_iters: int, min_lr=0.): 94 | self.optimizer = optimizer 95 | self.warmup_iters = warmup_iters 96 | self.total_iters = total_iters 97 | self.min_lr = min_lr 98 | self.max_lr_list = [] 99 | for param_group in self.optimizer.param_groups: 100 | self.max_lr_list.append(param_group['lr']) 101 | self.current_iter = 0 102 | self.lr_list = [] 103 | for param_group in self.optimizer.param_groups: 104 | self.lr_list.append(param_group['lr']) 105 | 106 | def step(self): 107 | self.current_iter += 1 108 | lr_list = [] 109 | cnt = 0 110 | for param_group in self.optimizer.param_groups: 111 | max_lr = self.max_lr_list[cnt] 112 | if self.current_iter <= self.warmup_iters: 113 | lr = self.current_iter / self.warmup_iters * max_lr 114 | else: 115 | lr = self.min_lr + 0.5 * (max_lr - self.min_lr) * ( 116 | np.cos((self.current_iter - self.warmup_iters) / (self.total_iters - self.warmup_iters) * 3.14159265 / 2) 117 | ).item() 118 | param_group['lr'] = lr 119 | cnt += 1 120 | lr_list.append(lr) 121 | self.lr_list = lr_list 122 | def get_lr(self): 123 | lr_list = [] 124 | for param_group in self.optimizer.param_groups: 125 | lr_list.append(param_group['lr']) 126 | return lr_list 127 | 128 | class ConstantScheduler: 129 | """Constant learning rate scheduler.""" 130 | def __init__(self, optimizer, lr: float): 131 | self.optimizer = optimizer 132 | lr_list = [] 133 | for param_group in self.optimizer.param_groups: 134 | lr_list.append(lr) 135 | 136 | def step(self): 137 | pass 138 | 139 | def get_lr(self): 140 | lr_list = [] 141 | for param_group in self.optimizer.param_groups: 142 | lr_list.append(param_group['lr']) 143 | return lr_list 144 | 145 | def get_scheduler(optimizer, args): 146 | if args.scheduler == 'multistep': 147 | from torch.optim.lr_scheduler import MultiStepLR 148 | scheduler = MultiStepLR(optimizer, milestones=[args.Nepoch // 2, (args.Nepoch * 3) // 4], gamma=0.1) 149 | elif args.scheduler == 'cosine': 150 | scheduler = WarmupCosineScheduler(optimizer, warmup_iters = args.Nepoch // 10, total_iters = args.Nepoch, 151 | min_lr = 0.) 
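        # Note on the cosine branch above: scheduler.step() is called once per epoch in
        # train_CV.py, the lr ramps linearly from 0 to the optimizer's initial lr over the
        # first Nepoch // 10 epochs, and afterwards follows
        #     lr = min_lr + 0.5 * (max_lr - min_lr) * cos(progress * pi / 2),
        # a quarter-cosine that starts near 0.5 * max_lr right after warmup and reaches min_lr
        # at the final epoch. E.g. with max_lr = 0.1, min_lr = 0, Nepoch = 200: lr ~ 0.035 at epoch 110.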
152 | elif args.scheduler == 'constant': 153 | scheduler = ConstantScheduler(optimizer, lr = args.lr) 154 | return scheduler -------------------------------------------------------------------------------- /MARS_M/utils/cv_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import torchvision 5 | from torchvision import datasets, transforms 6 | from torch.utils.data import DataLoader 7 | 8 | def get_model(args): 9 | """ 10 | models including: 11 | - VGG16 12 | - resnet18 13 | from https://github.com/iShohei220/adopt/blob/main/adopt.py and https://github.com/uclaml/Padam/blob/master/models/resnet.py 14 | """ 15 | if args.dataset in ['mnist', 'cifar10']: 16 | num_classes = 10 17 | elif args.dataset in ['cifar100']: 18 | num_classes = 100 19 | else: 20 | raise NotImplementedError(f"{args.dataset} is not implemented.") 21 | if args.net == 'simple_cnn': 22 | from .model_CNN import Network 23 | model_config = { 24 | "n_inputs": (3, 32, 32) if args.dataset == "cifar10" else (1, 28, 28), 25 | "conv_layers_list": [ 26 | {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 27 | {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 28 | {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 29 | ], 30 | "n_hiddens_list": [512], 31 | "n_outputs": 10, 32 | "dropout": 0.2, 33 | } 34 | model = Network(**model_config) 35 | elif args.net == 'resnet18': 36 | from .model_CNN import ResNet18 37 | model = ResNet18(num_classes = num_classes) 38 | else: 39 | try: 40 | model = torchvision.models.get_model(args.net, num_classes=num_classes) 41 | except: 42 | print('Model not found') 43 | raise NotImplementedError 44 | return model 45 | 46 | def get_datasets(dataset_name: str, train_batch_size: int, eval_batch_size: int): 47 | """Get train and test dataloaders.""" 48 | print('==> Preparing data..') 49 | if dataset_name == "mnist": 50 | transform = transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.1307,), (0.3081,)) 53 | ]) 54 | train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) 55 | test_dataset = datasets.MNIST('./data', train=False, transform=transform) 56 | elif dataset_name == "cifar10": 57 | transform_train = transforms.Compose([ 58 | transforms.RandomCrop(32, padding=4), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ToTensor(), 61 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 62 | ]) 63 | transform_test = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 66 | ]) 67 | train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train) 68 | test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test) 69 | elif dataset_name == "cifar100": 70 | transform_train = transforms.Compose([ 71 | transforms.RandomCrop(32, padding=4), 72 | transforms.RandomHorizontalFlip(), 73 | transforms.ToTensor(), 74 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]), 75 | ]) 76 | transform_test = transforms.Compose([ 77 | transforms.ToTensor(), 78 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]), 79 | ]) 80 | train_dataset = datasets.CIFAR100('./data', train=True, download=True, transform=transform_train) 81 | test_dataset = datasets.CIFAR100('./data', train=False, transform=transform_test) 82 | else: 83 | raise 
NotImplementedError(f"{dataset_name=} is not implemented.") 84 | 85 | train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=4) 86 | test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=4) 87 | 88 | return train_loader, test_loader 89 | 90 | 91 | class WarmupCosineScheduler: 92 | """Custom learning rate scheduler with linear warmup and cosine decay.""" 93 | def __init__(self, optimizer, warmup_iters: int, total_iters: int, min_lr=0.): 94 | self.optimizer = optimizer 95 | self.warmup_iters = warmup_iters 96 | self.total_iters = total_iters 97 | self.min_lr = min_lr 98 | self.max_lr_list = [] 99 | for param_group in self.optimizer.param_groups: 100 | self.max_lr_list.append(param_group['lr']) 101 | self.current_iter = 0 102 | self.lr_list = [] 103 | for param_group in self.optimizer.param_groups: 104 | self.lr_list.append(param_group['lr']) 105 | 106 | def step(self): 107 | self.current_iter += 1 108 | lr_list = [] 109 | cnt = 0 110 | for param_group in self.optimizer.param_groups: 111 | max_lr = self.max_lr_list[cnt] 112 | if self.current_iter <= self.warmup_iters: 113 | lr = self.current_iter / self.warmup_iters * max_lr 114 | else: 115 | lr = self.min_lr + 0.5 * (max_lr - self.min_lr) * ( 116 | np.cos((self.current_iter - self.warmup_iters) / (self.total_iters - self.warmup_iters) * 3.14159265 / 2) 117 | ).item() 118 | param_group['lr'] = lr 119 | cnt += 1 120 | lr_list.append(lr) 121 | self.lr_list = lr_list 122 | def get_lr(self): 123 | lr_list = [] 124 | for param_group in self.optimizer.param_groups: 125 | lr_list.append(param_group['lr']) 126 | return lr_list 127 | 128 | class ConstantScheduler: 129 | """Constant learning rate scheduler.""" 130 | def __init__(self, optimizer, lr: float): 131 | self.optimizer = optimizer 132 | lr_list = [] 133 | for param_group in self.optimizer.param_groups: 134 | lr_list.append(lr) 135 | 136 | def step(self): 137 | pass 138 | 139 | def get_lr(self): 140 | lr_list = [] 141 | for param_group in self.optimizer.param_groups: 142 | lr_list.append(param_group['lr']) 143 | return lr_list 144 | 145 | def get_scheduler(optimizer, args): 146 | if args.scheduler == 'multistep': 147 | from torch.optim.lr_scheduler import MultiStepLR 148 | scheduler = MultiStepLR(optimizer, milestones=[args.Nepoch // 2, (args.Nepoch * 3) // 4], gamma=0.1) 149 | elif args.scheduler == 'cosine': 150 | scheduler = WarmupCosineScheduler(optimizer, warmup_iters = args.Nepoch // 10, total_iters = args.Nepoch, 151 | min_lr = 0.) 
152 | elif args.scheduler == 'constant': 153 | scheduler = ConstantScheduler(optimizer, lr = args.lr) 154 | return scheduler -------------------------------------------------------------------------------- /MARS/train_CV.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from uclaml/Padam: https://github.com/uclaml/Padam/blob/master/run_cnn_test_cifar10.py 3 | """ 4 | import numpy as np 5 | import os 6 | import argparse 7 | import json 8 | from tqdm import tqdm 9 | 10 | parser = argparse.ArgumentParser(description='PyTorch Training') 11 | parser.add_argument( 12 | "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10", "cifar100"], help="dataset to use" 13 | ) 14 | parser.add_argument( 15 | "--scheduler", type=str, default="multistep", choices=["multistep", "cosine", "constant"], help="scheduler to use" 16 | ) 17 | parser.add_argument("--train_bsz", type=int, default=128, help="training batch size") 18 | parser.add_argument("--eval_bsz", type=int, default=100, help="eval batch size") 19 | parser.add_argument("--seed", type=int, default=0, help="random seed") 20 | parser.add_argument("--cpu", action="store_true", help="use cpu only") 21 | parser.add_argument("--cuda", type=str, default="0", help="device to use") 22 | 23 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 24 | parser.add_argument('--adamw_lr', default=0.003, type=float, help='learning rate for adamw') 25 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 26 | parser.add_argument('--optim', '-m', type=str, choices=["adam", "adamw", "mars", "muon"], default='mars', help='optimization method, default: mars') 27 | parser.add_argument('--net', '-n', type=str, default="resnet18", help='network archtecture, choosing from "simple_cnn" or torchvision models. default: resnet18') 28 | parser.add_argument('--wd', default=0., type=float, help='weight decay') 29 | parser.add_argument('--Nepoch', default=200, type=int, help='number of epoch') 30 | parser.add_argument('--beta1', default=0.9, type=float, help='beta1') 31 | parser.add_argument('--beta2', default=0.999, type=float, help='beta2') 32 | parser.add_argument('--wandb', action='store_true', help='use wandb') 33 | parser.add_argument('--save_dir', type=str, default="./checkpoint", help='save directory') 34 | parser.add_argument('--wandb_name', type=str, default="None", help='log directory') 35 | 36 | 37 | args = parser.parse_args() 38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda 39 | if args.wandb: 40 | import wandb 41 | if args.wandb_name == "None": 42 | wandb.init(project="CV", name=args.dataset+"_"+args.net+"_"+args.optim+"_"+str(args.lr), config=args) 43 | else: 44 | wandb.init(project="CV", name=args.wandb_name, config=args) 45 | 46 | import torch 47 | import torch.nn as nn 48 | import torch.optim as optim 49 | import torch.backends.cudnn as cudnn 50 | from utils.cv_utils import get_datasets, get_scheduler, get_model 51 | use_cuda = torch.cuda.is_available() and not args.cpu 52 | 53 | os.environ['PYTHONHASHSEED'] = str(args.seed) 54 | np.random.seed(args.seed) 55 | torch.manual_seed(args.seed) 56 | torch.cuda.manual_seed(args.seed) 57 | torch.cuda.manual_seed_all(args.seed) 58 | 59 | trainloader, testloader = get_datasets(args.dataset, args.train_bsz, args.eval_bsz) 60 | if args.resume: 61 | # Load checkpoint. 62 | print('==> Resuming from checkpoint..') 63 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
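    # Example invocations (illustrative only; the flags are the argparse options defined above,
    # and the numeric values are placeholders rather than tuned recommendations):
    #     python MARS/train_CV.py --optim mars  --net resnet18 --dataset cifar10 --lr 3e-3 --adamw_lr 3e-3 --scheduler cosine --wandb
    #     python MARS/train_CV.py --optim muon  --net resnet18 --dataset cifar100 --lr 2e-2 --adamw_lr 3e-3
    #     python MARS/train_CV.py --optim adamw --net simple_cnn --dataset mnist --lr 1e-3 --wd 0.1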
64 | checkpoint = torch.load(f'./checkpoint/{args.net}_{args.dataset}_'+args.optim) 65 | model = checkpoint['model'] 66 | start_epoch = checkpoint['epoch'] 67 | train_losses = checkpoint['train_losses'] 68 | test_losses = checkpoint['test_losses'] 69 | train_errs = checkpoint['train_errs'] 70 | test_errs = checkpoint['test_errs'] 71 | else: 72 | print('==> Building model..') 73 | 74 | model = get_model(args) 75 | 76 | if use_cuda: 77 | model.cuda() 78 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 79 | cudnn.benchmark = True 80 | 81 | 82 | criterion = nn.CrossEntropyLoss() 83 | 84 | betas = (args.beta1, args.beta2) 85 | from optimizers.mars import MARS 86 | from optimizers.muon import Muon 87 | from opt import CombinedOptimizer 88 | from optimizers.adamw import AdamW 89 | if args.optim == 'adam': 90 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas) 91 | elif args.optim == 'adamw': 92 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas) 93 | elif args.optim == 'muon': 94 | optimizer = CombinedOptimizer(model.parameters(), [AdamW, Muon], [{'lr': args.adamw_lr, 'betas': betas, 'weight_decay': args.wd}, 95 | {'lr': args.lr, 'weight_decay': 0.}]) 96 | elif args.optim == 'mars': 97 | optimizer = MARS(model.parameters(), lr=args.lr, weight_decay = args.wd, lr_1d=args.adamw_lr) 98 | 99 | scheduler = get_scheduler(optimizer, args) 100 | best_acc = 0 # best test accuracy 101 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 102 | train_errs = [] 103 | test_errs = [] 104 | train_losses = [] 105 | test_losses = [] 106 | acc_list = [] 107 | t_bar = tqdm(total=len(trainloader)) 108 | t_bar2 = tqdm(total=len(testloader)) 109 | for epoch in range(start_epoch+1, args.Nepoch+1): 110 | 111 | scheduler.step() 112 | # print ('\nEpoch: %d' % epoch, ' Learning rate:', scheduler.get_lr()) 113 | model.train() # Training 114 | 115 | train_loss = 0 116 | correct_train = 0 117 | total_train = 0 118 | print(scheduler.get_lr()) 119 | t_bar.reset() 120 | for batch_idx, (inputs, targets) in enumerate(trainloader): 121 | if use_cuda: 122 | inputs, targets = inputs.cuda(), targets.cuda() 123 | 124 | optimizer.zero_grad() 125 | outputs = model(inputs) 126 | loss = criterion(outputs, targets) 127 | loss.backward() 128 | optimizer.step() 129 | 130 | train_loss += loss.item() 131 | _, predicted = torch.max(outputs.data, 1) 132 | total_train += targets.size(0) 133 | correct_train += predicted.eq(targets.data).cpu().sum().item() 134 | 135 | t_bar.update(1) 136 | t_bar.set_description('Epoch: %d | Loss: %.3f | Acc: %.3f%% ' % (epoch, train_loss/(batch_idx+1), 100.0/total_train*(correct_train))) 137 | t_bar.refresh() 138 | train_losses.append(train_loss/(batch_idx+1)) 139 | train_errs.append(1 - correct_train/total_train) 140 | 141 | model.eval() # Testing 142 | 143 | test_loss = 0 144 | correct = 0 145 | total = 0 146 | t_bar2.reset() 147 | for batch_idx, (inputs, targets) in enumerate(testloader): 148 | if use_cuda: 149 | inputs, targets = inputs.cuda(), targets.cuda() 150 | outputs = model(inputs) 151 | loss = criterion(outputs, targets) 152 | 153 | test_loss += loss.item() 154 | _, predicted = torch.max(outputs.data, 1) 155 | total += targets.size(0) 156 | correct += predicted.eq(targets.data).cpu().sum().item() 157 | 158 | t_bar2.update(1) 159 | t_bar2.set_description('Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc)) 160 | 
t_bar2.refresh() 161 | test_errs.append(1 - correct/total) 162 | test_losses.append(test_loss/(batch_idx+1)) 163 | if args.wandb: 164 | wandb.log({"epoch": epoch, 165 | "train_loss": train_loss/(batch_idx+1), 166 | "train_acc": 100.0/total_train*(correct_train), 167 | "test_loss": test_loss/(batch_idx+1), 168 | "test_acc": 100.0/total*(correct), 169 | "lr": scheduler.get_lr()[0]}, step=epoch) 170 | # Save checkpoint 171 | acc = 100.0/total*(correct) 172 | if acc > best_acc: 173 | if not os.path.isdir('checkpoint'): 174 | os.mkdir('checkpoint') 175 | state = { 176 | 'model': model, 177 | 'epoch': epoch, 178 | } 179 | # torch.save(state, './checkpoint/cnn_cifar10_' + args.optim) 180 | torch.save(state, os.path.join(args.save_dir, "-".join([args.net, args.dataset, args.optim, str(args.lr).replace(".", "_")])+".pth")) 181 | best_acc = acc 182 | t_bar2.set_description('Model Saved! | Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc)) 183 | t_bar2.refresh() 184 | acc_list.append(acc) 185 | -------------------------------------------------------------------------------- /MARS_M/optimizers/moonlight.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | # This code snippet is a modified version adapted from the following GitHub repository: 5 | # https://github.com/KellerJordan/Muon/blob/master/muon.py 6 | @torch.compile 7 | def zeropower_via_newtonschulz5(G, steps): 8 | """ 9 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 10 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 11 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 12 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 13 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 14 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 15 | performance at all relative to UV^T, where USV^T = G is the SVD. 16 | """ 17 | assert len(G.shape) == 2 18 | a, b, c = (3.4445, -4.7750, 2.0315) 19 | X = G.bfloat16() 20 | if G.size(0) > G.size(1): 21 | X = X.T 22 | # Ensure spectral norm is at most 1 23 | X = X / (X.norm() + 1e-7) 24 | # Perform the NS iterations 25 | for _ in range(steps): 26 | A = X @ X.T 27 | B = ( 28 | b * A + c * A @ A 29 | ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 30 | X = a * X + B @ X 31 | 32 | if G.size(0) > G.size(1): 33 | X = X.T 34 | return X 35 | 36 | 37 | class Moonlight(torch.optim.Optimizer): 38 | """ 39 | Muon - MomentUm Orthogonalized by Newton-schulz 40 | 41 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 42 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 43 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 44 | the advantage that it can be stably run in bfloat16 on the GPU. 45 | 46 | Some warnings: 47 | - We believe this optimizer is unlikely to work well for training with small batch size. 48 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 49 | 50 | Arguments: 51 | muon_params: The parameters to be optimized by Muon. 52 | lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) 53 | momentum: The momentum used by the internal SGD. (0.95 is a good default) 54 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 55 | ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) 56 | adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are 57 | {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. 58 | adamw_lr: The learning rate for the internal AdamW. 59 | adamw_betas: The betas for the internal AdamW. 60 | adamw_eps: The epsilon for the internal AdamW. 61 | adamw_wd: The weight decay for the internal AdamW. 62 | """ 63 | 64 | def __init__( 65 | self, 66 | lr=1e-3, 67 | wd=0.1, 68 | muon_params=None, 69 | momentum=0.95, 70 | nesterov=True, 71 | ns_steps=5, 72 | adamw_params=None, 73 | adamw_betas=(0.9, 0.95), # MARS implement: 0.95, 0.95 74 | adamw_eps=1e-8, 75 | ): 76 | 77 | defaults = dict( 78 | lr=lr, 79 | wd=wd, 80 | momentum=momentum, 81 | nesterov=nesterov, 82 | ns_steps=ns_steps, 83 | adamw_betas=adamw_betas, 84 | adamw_eps=adamw_eps, 85 | ) 86 | 87 | params = list(muon_params) 88 | adamw_params = list(adamw_params) if adamw_params is not None else [] 89 | params.extend(adamw_params) 90 | super().__init__(params, defaults) 91 | # Sort parameters into those for which we will use Muon, and those for which we will not 92 | for p in muon_params: 93 | # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer 94 | assert p.ndim == 2, p.ndim 95 | self.state[p]["use_muon"] = True 96 | for p in adamw_params: 97 | # Do not use Muon for parameters in adamw_params 98 | self.state[p]["use_muon"] = False 99 | 100 | def adjust_lr_for_muon(self, lr, param_shape): 101 | A, B = param_shape[:2] 102 | # We adjust the learning rate and weight decay based on the size of the parameter matrix 103 | # as describted in the paper 104 | adjusted_ratio = 0.2 * math.sqrt(max(A, B)) 105 | adjusted_lr = lr * adjusted_ratio 106 | return adjusted_lr 107 | 108 | def step(self, closure=None): 109 | """Perform a single optimization step. 110 | 111 | Args: 112 | closure (Callable, optional): A closure that reevaluates the model 113 | and returns the loss. 
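        Note: for the Muon-handled 2-D parameters the step size is rescaled by
        adjust_lr_for_muon, i.e. adjusted_lr = 0.2 * sqrt(max(A, B)) * lr for an A x B
        weight, while weight decay is still applied with the unscaled lr. For example, a
        1024 x 1024 weight trained with lr = 1e-3 receives updates with an effective step
        size of about 0.2 * 32 * 1e-3 = 6.4e-3.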
114 | """ 115 | loss = None 116 | if closure is not None: 117 | with torch.enable_grad(): 118 | loss = closure() 119 | 120 | for group in self.param_groups: 121 | 122 | ############################ 123 | # Muon # 124 | ############################ 125 | 126 | params = [p for p in group["params"] if self.state[p]["use_muon"]] 127 | # import pdb; pdb.set_trace() 128 | lr = group["lr"] 129 | wd = group["wd"] 130 | momentum = group["momentum"] 131 | 132 | # generate weight updates in distributed fashion 133 | for p in params: 134 | # sanity check 135 | g = p.grad 136 | if g is None: 137 | continue 138 | if g.ndim > 2: 139 | g = g.view(g.size(0), -1) 140 | assert g is not None 141 | 142 | # calc update 143 | state = self.state[p] 144 | if "momentum_buffer" not in state: 145 | state["momentum_buffer"] = torch.zeros_like(g) 146 | buf = state["momentum_buffer"] 147 | buf.mul_(momentum).add_(g) 148 | if group["nesterov"]: 149 | g = g.add(buf, alpha=momentum) 150 | else: 151 | g = buf 152 | u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) 153 | 154 | # scale update 155 | adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) 156 | 157 | # apply weight decay 158 | p.data.mul_(1 - lr * wd) 159 | 160 | # apply update 161 | p.data.add_(u, alpha=-adjusted_lr) 162 | 163 | ############################ 164 | # AdamW backup # 165 | ############################ 166 | 167 | params = [p for p in group["params"] if not self.state[p]["use_muon"]] 168 | lr = group['lr'] 169 | beta1, beta2 = group["adamw_betas"] 170 | eps = group["adamw_eps"] 171 | weight_decay = group["wd"] 172 | 173 | for p in params: 174 | g = p.grad 175 | if g is None: 176 | continue 177 | state = self.state[p] 178 | if "step" not in state: 179 | state["step"] = 0 180 | state["moment1"] = torch.zeros_like(g) 181 | state["moment2"] = torch.zeros_like(g) 182 | state["step"] += 1 183 | step = state["step"] 184 | buf1 = state["moment1"] 185 | buf2 = state["moment2"] 186 | buf1.lerp_(g, 1 - beta1) 187 | buf2.lerp_(g.square(), 1 - beta2) 188 | 189 | g = buf1 / (eps + buf2.sqrt()) 190 | 191 | bias_correction1 = 1 - beta1**step 192 | bias_correction2 = 1 - beta2**step 193 | scale = bias_correction1 / bias_correction2**0.5 194 | p.data.mul_(1 - lr * weight_decay) 195 | p.data.add_(g, alpha=-lr / scale) 196 | 197 | return loss -------------------------------------------------------------------------------- /MARS_M/train_CV.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from uclaml/Padam: https://github.com/uclaml/Padam/blob/master/run_cnn_test_cifar10.py 3 | """ 4 | import numpy as np 5 | import os 6 | import argparse 7 | import json 8 | from tqdm import tqdm 9 | 10 | parser = argparse.ArgumentParser(description='PyTorch Training') 11 | parser.add_argument( 12 | "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10", "cifar100"], help="dataset to use" 13 | ) 14 | parser.add_argument( 15 | "--scheduler", type=str, default="multistep", choices=["multistep", "cosine", "constant"], help="scheduler to use" 16 | ) 17 | parser.add_argument("--train_bsz", type=int, default=128, help="training batch size") 18 | parser.add_argument("--eval_bsz", type=int, default=100, help="eval batch size") 19 | parser.add_argument("--seed", type=int, default=0, help="random seed") 20 | parser.add_argument("--cpu", action="store_true", help="use cpu only") 21 | parser.add_argument("--cuda", type=str, default="0", help="device to use") 22 | 23 | parser.add_argument('--lr', default=0.1, type=float, 
help='learning rate') 24 | parser.add_argument('--adamw_lr', default=0.003, type=float, help='learning rate for adamw') 25 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 26 | parser.add_argument('--optim', '-m', type=str, choices=["adam", "adamw", "mars-m", "moonlight"], default='mars', help='optimization method, default: mars') 27 | parser.add_argument('--net', '-n', type=str, default="resnet18", help='network archtecture, choosing from "simple_cnn" or torchvision models. default: resnet18') 28 | parser.add_argument('--wd', default=0., type=float, help='weight decay') 29 | parser.add_argument('--Nepoch', default=200, type=int, help='number of epoch') 30 | parser.add_argument('--beta1', default=0.9, type=float, help='beta1') 31 | parser.add_argument('--beta2', default=0.999, type=float, help='beta2') 32 | parser.add_argument('--gamma', default=0.025, type=float, help='gamma in MARS-M') 33 | parser.add_argument('--clip_c', action='store_true', help='whether to clip c_t in MARS-M') 34 | parser.add_argument('--mars_exact', action='store_true', help='whether to use the approximate version of MARS-M') 35 | parser.add_argument('--wandb', action='store_true', help='use wandb') 36 | parser.add_argument('--save_dir', type=str, default="./checkpoint", help='save directory') 37 | parser.add_argument('--wandb_name', type=str, default="None", help='log directory') 38 | 39 | 40 | args = parser.parse_args() 41 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda 42 | if args.wandb: 43 | import wandb 44 | if args.wandb_name == "None": 45 | wandb.init(project="CV", name=args.dataset+"_"+args.net+"_"+args.optim+"_"+str(args.lr), config=args) 46 | else: 47 | wandb.init(project="CV", name=args.wandb_name, config=args) 48 | 49 | import torch 50 | import torch.nn as nn 51 | import torch.optim as optim 52 | import torch.backends.cudnn as cudnn 53 | from utils.cv_utils import get_datasets, get_scheduler, get_model 54 | use_cuda = torch.cuda.is_available() and not args.cpu 55 | 56 | os.environ['PYTHONHASHSEED'] = str(args.seed) 57 | np.random.seed(args.seed) 58 | torch.manual_seed(args.seed) 59 | torch.cuda.manual_seed(args.seed) 60 | torch.cuda.manual_seed_all(args.seed) 61 | 62 | trainloader, testloader = get_datasets(args.dataset, args.train_bsz, args.eval_bsz) 63 | if args.resume: 64 | # Load checkpoint. 65 | print('==> Resuming from checkpoint..') 66 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
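    # Example invocations (illustrative only; the flags follow the argparse options above and
    # the numeric values are placeholders, not tuned settings):
    #     python MARS_M/train_CV.py --optim mars-m --net resnet18 --dataset cifar10 --lr 1e-3 --gamma 0.025 --scheduler cosine
    #     python MARS_M/train_CV.py --optim mars-m --net resnet18 --dataset cifar100 --lr 1e-3 --clip_c --mars_exact
    # Note that the Moonlight branch below checks for optim == 'muon', which is not among the
    # argparse choices, so only 'adam', 'adamw' and 'mars-m' are reachable from the CLI as written.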
67 | checkpoint = torch.load(f'./checkpoint/{args.net}_{args.dataset}_'+args.optim) 68 | model = checkpoint['model'] 69 | start_epoch = checkpoint['epoch'] 70 | train_losses = checkpoint['train_losses'] 71 | test_losses = checkpoint['test_losses'] 72 | train_errs = checkpoint['train_errs'] 73 | test_errs = checkpoint['test_errs'] 74 | else: 75 | print('==> Building model..') 76 | 77 | model = get_model(args) 78 | 79 | if use_cuda: 80 | model.cuda() 81 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 82 | cudnn.benchmark = True 83 | 84 | 85 | criterion = nn.CrossEntropyLoss() 86 | 87 | betas = (args.beta1, args.beta2) 88 | from optimizers.mars_m import MARS_M 89 | from optimizers.moonlight import Moonlight 90 | if args.optim == 'adam': 91 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas) 92 | elif args.optim == 'adamw': 93 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas) 94 | elif args.optim == 'muon': 95 | muon_params = [p for p in model.parameters() if p.ndim == 2] 96 | adamw_params = [p for p in model.parameters() if p.ndim != 2] 97 | optimizer = Moonlight(muon_params=muon_params, adamw_params=adamw_params, 98 | lr=args.lr, wd=args.wd, adamw_betas=betas) 99 | elif args.optim == 'mars-m': 100 | muon_params = [p for p in model.parameters() if p.ndim >= 2] 101 | adamw_params = [p for p in model.parameters() if p.ndim < 2] 102 | if args.mars_exact: 103 | is_approx = False 104 | else: 105 | is_approx = True 106 | optimizer = MARS_M(muon_params=muon_params, adamw_params=adamw_params, 107 | lr=args.lr, wd=args.wd, adamw_betas=betas, gamma=args.gamma, 108 | clip_c=args.clip_c, is_approx=is_approx) 109 | 110 | scheduler = get_scheduler(optimizer, args) 111 | best_acc = 0 # best test accuracy 112 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 113 | train_errs = [] 114 | test_errs = [] 115 | train_losses = [] 116 | test_losses = [] 117 | acc_list = [] 118 | t_bar = tqdm(total=len(trainloader)) 119 | t_bar2 = tqdm(total=len(testloader)) 120 | previous_input = None 121 | previous_target = None 122 | for epoch in range(start_epoch+1, args.Nepoch+1): 123 | 124 | scheduler.step() 125 | # print ('\nEpoch: %d' % epoch, ' Learning rate:', scheduler.get_lr()) 126 | model.train() # Training 127 | 128 | train_loss = 0 129 | correct_train = 0 130 | total_train = 0 131 | t_bar.reset() 132 | for batch_idx, (inputs, targets) in enumerate(trainloader): 133 | if args.mars_exact: 134 | if previous_input is not None: 135 | if use_cuda: 136 | p_inputs, p_targets = previous_input.cuda(), previous_target.cuda() 137 | optimizer.zero_grad() 138 | p_outputs = model(p_inputs) 139 | p_loss = criterion(p_outputs, p_targets) 140 | p_loss.backward() 141 | optimizer.update_previous_grad() 142 | else: 143 | # copy inputs to previous_input 144 | previous_input = inputs.clone() 145 | previous_target = targets.clone() 146 | if use_cuda: 147 | inputs, targets = inputs.cuda(), targets.cuda() 148 | 149 | optimizer.zero_grad() 150 | outputs = model(inputs) 151 | loss = criterion(outputs, targets) 152 | loss.backward() 153 | optimizer.step() 154 | if args.mars_exact: 155 | optimizer.update_last_grad() 156 | 157 | train_loss += loss.item() 158 | _, predicted = torch.max(outputs.data, 1) 159 | total_train += targets.size(0) 160 | correct_train += predicted.eq(targets.data).cpu().sum().item() 161 | 162 | t_bar.update(1) 163 | t_bar.set_description('Epoch: %d | Loss: %.3f | Acc: %.3f%% ' % (epoch, 
train_loss/(batch_idx+1), 100.0/total_train*(correct_train))) 164 | t_bar.refresh() 165 | train_losses.append(train_loss/(batch_idx+1)) 166 | train_errs.append(1 - correct_train/total_train) 167 | 168 | model.eval() # Testing 169 | 170 | test_loss = 0 171 | correct = 0 172 | total = 0 173 | t_bar2.reset() 174 | for batch_idx, (inputs, targets) in enumerate(testloader): 175 | if use_cuda: 176 | inputs, targets = inputs.cuda(), targets.cuda() 177 | outputs = model(inputs) 178 | loss = criterion(outputs, targets) 179 | 180 | test_loss += loss.item() 181 | _, predicted = torch.max(outputs.data, 1) 182 | total += targets.size(0) 183 | correct += predicted.eq(targets.data).cpu().sum().item() 184 | 185 | t_bar2.update(1) 186 | t_bar2.set_description('Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc)) 187 | t_bar2.refresh() 188 | test_errs.append(1 - correct/total) 189 | test_losses.append(test_loss/(batch_idx+1)) 190 | if args.wandb: 191 | wandb.log({"epoch": epoch, 192 | "train_loss": train_loss/(batch_idx+1), 193 | "train_acc": 100.0/total_train*(correct_train), 194 | "test_loss": test_loss/(batch_idx+1), 195 | "test_acc": 100.0/total*(correct), 196 | "lr": scheduler.get_lr()[0]}, step=epoch) 197 | # Save checkpoint 198 | acc = 100.0/total*(correct) 199 | if acc > best_acc: 200 | if not os.path.isdir('checkpoint'): 201 | os.mkdir('checkpoint') 202 | state = { 203 | 'model': model, 204 | 'epoch': epoch, 205 | } 206 | # torch.save(state, './checkpoint/cnn_cifar10_' + args.optim) 207 | torch.save(state, os.path.join(args.save_dir, "-".join([args.net, args.dataset, args.optim, str(args.lr).replace(".", "_")])+".pth")) 208 | best_acc = acc 209 | t_bar2.set_description('Model Saved! | Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc)) 210 | t_bar2.refresh() 211 | acc_list.append(acc) 212 | -------------------------------------------------------------------------------- /MARS/train_CNN.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. 
and/or its affiliates 2 | # SPDX-License-Identifier: Apache-2.0 3 | import argparse 4 | from typing import List, Tuple, Type 5 | 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torch.nn as nn 9 | from torch.optim import Adam, AdamW 10 | from torch.optim.lr_scheduler import CosineAnnealingLR 11 | from torch.utils.data import DataLoader 12 | from torchvision import datasets, transforms 13 | import numpy as np 14 | from utils.model_CNN import Network 15 | from optimizers.adopt import ADOPT 16 | from optimizers.mars import MARS 17 | import random 18 | parser = argparse.ArgumentParser(add_help=True) 19 | parser.add_argument( 20 | "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10"], help="dataset to use" 21 | ) 22 | parser.add_argument("-b", "--batch_size", type=int, default=128, help="batch size") 23 | parser.add_argument("-e", "--epochs", type=int, default=50, help="number of epochs") 24 | parser.add_argument("--seed", type=int, default=0, help="random seed") 25 | parser.add_argument("--cpu", action="store_true", help="use cpu only") 26 | 27 | 28 | def get_datasets(dataset_name: str, batch_size: int) -> Tuple[DataLoader, DataLoader]: 29 | """Get train and test dataloaders.""" 30 | if dataset_name == "mnist": 31 | transform = transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize((0.1307,), (0.3081,)) 34 | ]) 35 | train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) 36 | test_dataset = datasets.MNIST('./data', train=False, transform=transform) 37 | elif dataset_name == "cifar10": 38 | transform_train = transforms.Compose([ 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 42 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) 43 | ]) 44 | transform_test = transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) 47 | ]) 48 | train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train) 49 | test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test) 50 | else: 51 | raise NotImplementedError(f"{dataset_name=} is not implemented.") 52 | 53 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) 54 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4) 55 | 56 | return train_loader, test_loader 57 | 58 | 59 | class WarmupCosineScheduler: 60 | """Custom learning rate scheduler with linear warmup and cosine decay.""" 61 | def __init__(self, optimizer, warmup_iters: int, total_iters: int, min_lr: float, max_lr: float): 62 | self.optimizer = optimizer 63 | self.warmup_iters = warmup_iters 64 | self.total_iters = total_iters 65 | self.min_lr = min_lr 66 | self.max_lr = max_lr 67 | self.current_iter = 0 68 | self.lr = 0 69 | 70 | def step(self): 71 | self.current_iter += 1 72 | if self.current_iter <= self.warmup_iters: 73 | lr = self.current_iter / self.warmup_iters * self.max_lr 74 | else: 75 | lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * ( 76 | np.cos((self.current_iter - self.warmup_iters) / (self.total_iters - self.warmup_iters) * 3.14159265 / 2) 77 | ).item() 78 | 79 | for param_group in self.optimizer.param_groups: 80 | param_group['lr'] = lr 81 | self.lr = lr 82 | 83 | class Trainer: 84 | """Training manager for PyTorch models.""" 85 | def __init__(self, model: nn.Module, 
optimizer: torch.optim.Optimizer, scheduler, device: torch.device): 86 | self.model = model 87 | self.optimizer = optimizer 88 | self.scheduler = scheduler 89 | self.device = device 90 | self.criterion = nn.CrossEntropyLoss() 91 | self.train_acc_trace = [] 92 | self.val_acc_trace = [] 93 | 94 | def train_epoch(self, train_loader: DataLoader) -> float: 95 | self.model.train() 96 | correct = 0 97 | total = 0 98 | 99 | for batch in train_loader: 100 | images, targets = batch[0].to(self.device), batch[1].to(self.device) 101 | 102 | self.optimizer.zero_grad() 103 | outputs = self.model(images) 104 | loss = self.criterion(outputs, targets) 105 | loss.backward() 106 | self.optimizer.step() 107 | 108 | _, predicted = outputs.max(1) 109 | total += targets.size(0) 110 | correct += predicted.eq(targets).sum().item() 111 | if self.scheduler is not None: 112 | self.scheduler.step() 113 | return 100. * correct / total 114 | 115 | def evaluate(self, test_loader: DataLoader) -> float: 116 | self.model.eval() 117 | correct = 0 118 | total = 0 119 | 120 | with torch.no_grad(): 121 | for batch in test_loader: 122 | images, targets = batch[0].to(self.device), batch[1].to(self.device) 123 | outputs = self.model(images) 124 | 125 | _, predicted = outputs.max(1) 126 | total += targets.size(0) 127 | correct += predicted.eq(targets).sum().item() 128 | 129 | return 100. * correct / total 130 | 131 | def train(self, train_loader: DataLoader, test_loader: DataLoader, epochs: int): 132 | for epoch in range(epochs): 133 | train_acc = self.train_epoch(train_loader) 134 | val_acc = self.evaluate(test_loader) 135 | 136 | self.train_acc_trace.append(train_acc) 137 | self.val_acc_trace.append(val_acc) 138 | 139 | # if self.scheduler is not None: 140 | # self.scheduler.step() 141 | 142 | print(f"Epoch {epoch+1}/{epochs} - Train Acc: {train_acc:.2f}% - Val Acc: {val_acc:.2f}%") 143 | 144 | 145 | def get_optimizers(model: nn.Module, opt_name, args): 146 | """Configure optimizers and schedulers.""" 147 | total_steps = 50_000 // args.batch_size * args.epochs 148 | n_warmup = int(total_steps * 0.10) # % of total steps 149 | weight_decay = 1e-4 150 | max_lr = 6e-4 151 | min_lr = 1e-6 152 | 153 | if opt_name == "Adam": 154 | # Adam 155 | adam = Adam(model.parameters(), lr=max_lr) 156 | adam_scheduler = WarmupCosineScheduler( 157 | adam, n_warmup, total_steps, min_lr, max_lr 158 | ) 159 | optimizer = (adam, adam_scheduler, "Adam") 160 | 161 | elif opt_name == "AdamW": 162 | # AdamW 163 | adamw = AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay) 164 | adamw_scheduler = WarmupCosineScheduler( 165 | adamw, n_warmup, total_steps, min_lr, max_lr 166 | ) 167 | optimizer = (adamw, adamw_scheduler, "AdamW") 168 | elif opt_name == "ADOPT": 169 | # ADOPT 170 | adopt = ADOPT(model.parameters(), lr=max_lr, weight_decay=weight_decay) 171 | adopt_scheduler = WarmupCosineScheduler( 172 | adopt, n_warmup, total_steps, min_lr, max_lr 173 | ) 174 | optimizer = (adopt, adopt_scheduler, "ADOPT") 175 | elif opt_name == "MARS": 176 | # MARS 177 | mars = MARS(model.parameters(), lr=3e-3, weight_decay=weight_decay, optimize_1d=False) 178 | mars_scheduler = WarmupCosineScheduler( 179 | mars, n_warmup, total_steps, min_lr, 3e-3 180 | ) 181 | optimizer = (mars, mars_scheduler, "MARS") 182 | return optimizer 183 | 184 | 185 | def plot_results(results: List[List[float]], optimizer_names: List[str], args): 186 | """Plot training results.""" 187 | fig, ax = plt.subplots(figsize=(5.5, 3.5)) 188 | colors = ["#74add1", "#1730bd", "#1a9850", "#001c01"] 
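    # results is a list of per-epoch validation-accuracy traces (one list per optimizer,
    # collected by main() below as trainer.val_acc_trace), and optimizer_names holds the
    # matching labels; each trace is plotted against epoch indices starting from 1.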
189 | 190 | for i, acc in enumerate(results): 191 | ax.plot(range(1, len(acc) + 1), acc, label=optimizer_names[i], lw=2, color=colors[i]) 192 | 193 | ax.set_title(f"{args.dataset.upper()} (val)", loc="left") 194 | ax.set_xlabel("Epoch", fontsize="medium") 195 | ax.set_ylabel("Accuracy (%)", fontsize="medium") 196 | 197 | ax.legend(ncols=2, columnspacing=0.8, fontsize="medium") 198 | ax.grid(alpha=0.2) 199 | 200 | ax.set_ylim(90 if args.dataset == "mnist" else 70) 201 | acc_min, acc_max = ax.get_ylim() 202 | ax.set_yticks(torch.linspace(acc_min, acc_max, 5).int().tolist()) 203 | ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True)) 204 | 205 | fig.tight_layout() 206 | fig.savefig( 207 | f"./compare-{args.dataset}-blank.png", 208 | dpi=300, 209 | bbox_inches="tight", 210 | ) 211 | plt.show() 212 | 213 | 214 | def main(args): 215 | # Set random seed and device 216 | torch.manual_seed(args.seed) 217 | device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu") 218 | 219 | # Get dataloaders 220 | train_loader, test_loader = get_datasets(args.dataset, args.batch_size) 221 | # Model configuration 222 | model_config = { 223 | "n_inputs": (3, 32, 32) if args.dataset == "cifar10" else (1, 28, 28), 224 | "conv_layers_list": [ 225 | {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 226 | {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 227 | {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True}, 228 | ], 229 | "n_hiddens_list": [512], 230 | "n_outputs": 10, 231 | "dropout": 0.2, 232 | } 233 | 234 | results = [] 235 | optimizer_names = [] 236 | # Train with different optimizers 237 | opt_names = ["Adam", "AdamW", "ADOPT", "MARS"] 238 | for opt_name in opt_names: 239 | print(opt_name) 240 | torch.manual_seed(args.seed) 241 | model = Network(**model_config).to(device) 242 | optimizer, scheduler, name = get_optimizers(model, opt_name, args) 243 | trainer = Trainer(model, optimizer, scheduler, device) 244 | trainer.train(train_loader, test_loader, args.epochs) 245 | results.append(trainer.val_acc_trace) 246 | optimizer_names.append(name) 247 | 248 | plot_results(results, optimizer_names, args) 249 | 250 | 251 | if __name__ == "__main__": 252 | args = parser.parse_args() 253 | main(args) -------------------------------------------------------------------------------- /MARS_M/optimizers/mars_m.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | @torch.compile 5 | def zeropower_via_newtonschulz5(G, steps): 6 | """ 7 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 8 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 9 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 10 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 11 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 12 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 13 | performance at all relative to UV^T, where USV^T = G is the SVD. 14 | 15 | This version allows for G to be more than 2D. 
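    When G has more than two dimensions, the transposes and matrix products below act on the
    last two dimensions, so every trailing (m, n) slice is orthogonalized independently (the
    initial scaling, however, uses a single norm computed over all of G).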
16 | """ 17 | # assert len(G.shape) == 2 18 | a, b, c = (3.4445, -4.7750, 2.0315) 19 | X = G.bfloat16() 20 | if G.size(-2) > G.size(-1): 21 | X = X.transpose(-2, -1) 22 | # Ensure spectral norm is at most 1 23 | X = X / (X.norm() + 1e-7) 24 | # Perform the NS iterations 25 | for _ in range(steps): 26 | A = X @ X.transpose(-2, -1) 27 | B = ( 28 | b * A + c * A @ A 29 | ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 30 | X = a * X + B @ X 31 | 32 | if G.size(-2) > G.size(-1): 33 | X = X.transpose(-2, -1) 34 | return X 35 | 36 | class MARS_M(torch.optim.Optimizer): 37 | """ 38 | MARS_M: MARS on Matrix-Level 39 | 40 | Arguments: 41 | lr: The learning rate. The updates will have spectral norm of `lr`. 42 | wd: The weight decay. 43 | muon_params: The parameters to be optimized by Muon. 44 | momentum: The momentum used by the internal SGD. 45 | ns_steps: The number of Newton-Schulz iterations to run. 46 | adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are 47 | {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. 48 | adamw_betas: The betas used by AdamW. 49 | adamw_eps: The epsilon used by AdamW. 50 | gamma: The gamma parameter in MARS. 51 | clip_c: Whether to clip the c_t vector to have norm at most 1. 52 | is_approx: Whether to use the approximate version of MARS-M (True) or the exact version (False). 53 | The approximate version does not require any extra gradient computations, while the exact version 54 | requires one extra gradient computation per parameter per step. 55 | However, the exact version may yield slightly better performance. 56 | See the MARS-M paper for details. 57 | 58 | """ 59 | 60 | def __init__( 61 | self, 62 | lr=1e-3, 63 | wd=0.1, 64 | muon_params=None, 65 | momentum=0.95, 66 | ns_steps=5, 67 | adamw_params=None, 68 | adamw_betas=(0.9, 0.95), 69 | adamw_eps=1e-8, 70 | gamma=0.025, 71 | clip_c=False, 72 | is_approx=True, 73 | ): 74 | mars_factor = gamma * momentum / (1-momentum) 75 | defaults = dict( 76 | lr=lr, 77 | wd=wd, 78 | momentum=momentum, 79 | ns_steps=ns_steps, 80 | adamw_betas=adamw_betas, 81 | adamw_eps=adamw_eps, 82 | gamma=gamma, 83 | mars_factor=mars_factor, 84 | clip_c=clip_c, 85 | is_approx=is_approx 86 | ) 87 | 88 | params = list(muon_params) 89 | adamw_params = list(adamw_params) if adamw_params is not None else [] 90 | params.extend(adamw_params) 91 | super().__init__(params, defaults) 92 | self.is_approx = is_approx 93 | # Sort parameters into those for which we will use Muon, and those for which we will not 94 | for p in muon_params: 95 | # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer 96 | # assert p.ndim >= 2, p.ndim 97 | self.state[p]["use_muon"] = True 98 | for p in adamw_params: 99 | # Do not use Muon for parameters in adamw_params 100 | self.state[p]["use_muon"] = False 101 | 102 | def adjust_lr_for_muon(self, lr, param_shape): 103 | A, B = param_shape[:2] 104 | # We adjust the learning rate and weight decay based on the size of the parameter matrix 105 | # as describted in the paper 106 | adjusted_ratio = 0.2 * math.sqrt(max(A, B)) 107 | adjusted_lr = lr * adjusted_ratio 108 | return adjusted_lr 109 | 110 | @torch.no_grad() 111 | def update_last_grad(self): 112 | if not self.is_approx: 113 | for group in self.param_groups: 114 | for p in group['params']: 115 | state = self.state[p] 116 | ## only update previous grad for muon params 117 | if not state["use_muon"]: 118 | continue 119 | 
## end skip 120 | if "last_grad" not in state: 121 | state["last_grad"] = torch.zeros_like(p) 122 | state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0) 123 | @torch.no_grad() 124 | def update_previous_grad(self): 125 | if not self.is_approx: 126 | for group in self.param_groups: 127 | for p in group['params']: 128 | if p.grad is None: 129 | print (p, "grad is none") 130 | continue 131 | state = self.state[p] 132 | ## only update previous grad for muon params 133 | if not state["use_muon"]: 134 | continue 135 | ## end skip 136 | if "previous_grad" not in state: 137 | state['previous_grad'] = torch.zeros_like(p) 138 | state['previous_grad'].zero_().add_(p.grad, alpha=1.0) 139 | 140 | def step(self, closure=None): 141 | """Performs a single optimization step. 142 | 143 | Arguments: 144 | closure (callable, optional): A closure that reevaluates the model 145 | and returns the loss. 146 | 147 | If using exact version, the example usage is as follows: 148 | previous_X, previous_Y = None, None 149 | for epoch in range(epochs): 150 | for X, Y in data_loader: 151 | if previous_X: 152 | logits, loss = model(X, Y) 153 | loss.backward() 154 | optimizer.update_previous_grad() 155 | optimizer.zero_grad(set_to_none=True) 156 | logits, loss = model(X, Y) 157 | loss.backward() 158 | optimizer.step(bs=bs) 159 | optimizer.zero_grad(set_to_none=True) 160 | optimizer.update_last_grad() 161 | iter_num += 1 162 | previous_X, previous_Y = X.clone(), Y.clone() 163 | """ 164 | loss = None 165 | if closure is not None: 166 | with torch.enable_grad(): 167 | loss = closure() 168 | 169 | for group in self.param_groups: 170 | 171 | ############################ 172 | # Muon # 173 | ############################ 174 | 175 | params = [p for p in group["params"] if self.state[p]["use_muon"]] 176 | # import pdb; pdb.set_trace() 177 | lr = group["lr"] 178 | wd = group["wd"] 179 | momentum = group["momentum"] 180 | gamma = group["gamma"] 181 | mars_factor = group["mars_factor"] 182 | for p in params: 183 | # sanity check 184 | g = p.grad 185 | if g is None: 186 | continue 187 | assert g is not None 188 | state = self.state[p] 189 | 190 | # default: MARS, nesterov 191 | if "last_grad" not in state: 192 | state["last_grad"] = torch.zeros_like(g) 193 | # calc update 194 | if "momentum_buffer" not in state: 195 | state["momentum_buffer"] = torch.zeros_like(g) 196 | # mars_factor = gamma * momentum / (1-momentum) 197 | c_t = (g - state["last_grad"]).mul(mars_factor).add(g) 198 | c_t_norm = c_t.norm() 199 | if c_t_norm > 1: 200 | c_t.div_(c_t_norm) 201 | buf = state["momentum_buffer"] 202 | buf.mul_(momentum).add_(c_t, alpha=(1 - momentum)) 203 | 204 | u = zeropower_via_newtonschulz5(buf, steps=group["ns_steps"]) 205 | 206 | 207 | # scale update 208 | adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) 209 | 210 | # apply weight decay 211 | p.data.mul_(1 - lr * wd) 212 | 213 | # apply update 214 | p.data.add_(u, alpha=-adjusted_lr) 215 | if self.is_approx: 216 | state["last_grad"] = g 217 | 218 | ############################ 219 | # AdamW backup # 220 | ############################ 221 | 222 | params = [p for p in group["params"] if not self.state[p]["use_muon"]] 223 | lr = group['lr'] 224 | beta1, beta2 = group["adamw_betas"] 225 | eps = group["adamw_eps"] 226 | weight_decay = group["wd"] 227 | 228 | for p in params: 229 | g = p.grad 230 | if g is None: 231 | continue 232 | state = self.state[p] 233 | if "step" not in state: 234 | state["step"] = 0 235 | state["moment1"] = torch.zeros_like(g) 236 | state["moment2"] = 
torch.zeros_like(g) 237 | state["step"] += 1 238 | step = state["step"] 239 | buf1 = state["moment1"] 240 | buf2 = state["moment2"] 241 | buf1.lerp_(g, 1 - beta1) 242 | buf2.lerp_(g.square(), 1 - beta2) 243 | 244 | g = buf1 / (eps + buf2.sqrt()) 245 | 246 | bias_correction1 = 1 - beta1**step 247 | bias_correction2 = 1 - beta2**step 248 | scale = bias_correction1 / bias_correction2**0.5 249 | p.data.mul_(1 - lr * weight_decay) 250 | p.data.add_(g, alpha=-lr / scale) 251 | 252 | return loss -------------------------------------------------------------------------------- /MARS/optimizers/mars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # SPDX-License-Identifier: Apache-2.0 3 | import math 4 | import torch 5 | from torch.optim.optimizer import Optimizer 6 | import os 7 | import numpy as np 8 | import math 9 | # from megatron.optimizer.l2_norm import l2_norm 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def update_fn(p, grad, exp_avg, exp_avg_sq, lr, wd, beta1, beta2, last_grad, eps, amsgrad, max_exp_avg_sq, step, gamma, 16 | mars_type, is_grad_2d, optimize_1d, lr_1d_factor, betas_1d, weight_decay_1d): 17 | # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para 18 | if optimize_1d or is_grad_2d: 19 | c_t = (grad - last_grad).mul(gamma * (beta1 / (1. - beta1))).add(grad) 20 | c_t_norm = torch.norm(c_t) 21 | if c_t_norm > 1.: 22 | c_t = c_t / c_t_norm 23 | exp_avg.mul_(beta1).add_(c_t, alpha=1. - beta1) 24 | if (mars_type == "mars-adamw") or (mars_type == "mars-shampoo" and not is_grad_2d): 25 | exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2) 26 | bias_correction1 = 1.0 - beta1 ** step 27 | bias_correction2 = 1.0 - beta2 ** step 28 | if amsgrad: 29 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 30 | denom = max_exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 31 | else: 32 | denom = exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 33 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.div(denom)) 34 | elif mars_type == "mars-lion": 35 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.sign()) 36 | elif mars_type == "mars-shampoo" and is_grad_2d: 37 | factor = max(1, grad.size(0)/grad.size(1))**0.5 38 | real_update_tmp = NewtonSchulz(exp_avg.mul(1./(1.-beta1)), eps=eps).mul(factor).add(wd, p.data).mul(-lr) 39 | p.data.add_(real_update_tmp) 40 | else: 41 | beta1_1d, beta2_1d = betas_1d 42 | exp_avg.mul_(beta1_1d).add_(grad, alpha=1. - beta1_1d) 43 | exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1. 
- beta2_1d) 44 | bias_correction1 = 1.0 - beta1_1d ** step 45 | bias_correction2 = 1.0 - beta2_1d ** step 46 | if amsgrad: 47 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 48 | denom = max_exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 49 | else: 50 | denom = exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 51 | real_update_tmp = -lr * lr_1d_factor * torch.mul(p.data, weight_decay_1d).add(exp_avg.div(denom)) 52 | p.data.add_(real_update_tmp) 53 | return exp_avg, exp_avg_sq 54 | 55 | class MARS(Optimizer): 56 | def __init__(self, params, lr=3e-3, betas=(0.95, 0.99), eps=1e-8, weight_decay=0., amsgrad=False, gamma=0.025, 57 | is_approx=True, mars_type="mars-adamw", optimize_1d=False, lr_1d=3e-3, betas_1d=(0.9, 0.95), weight_decay_1d=0.1): 58 | if not 0.0 <= lr: 59 | raise ValueError("Invalid learning rate: {}".format(lr)) 60 | if not 0.0 <= eps: 61 | raise ValueError("Invalid epsilon value: {}".format(eps)) 62 | if not 0.0 <= betas[0] < 1.0: 63 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 64 | if not 0.0 <= betas[1] < 1.0: 65 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 66 | assert mars_type in ["mars-adamw", "mars-lion", "mars-shampoo"], "MARS type not supported" 67 | defaults = dict(lr=lr, betas=betas, eps=eps, 68 | weight_decay=weight_decay, amsgrad=amsgrad, 69 | mars_type=mars_type, gamma=gamma, 70 | optimize_1d=optimize_1d, weight_decay_1d=weight_decay_1d) 71 | super(MARS, self).__init__(params, defaults) 72 | self.eps = eps 73 | self.update_fn = update_fn 74 | self.lr = lr 75 | self.weight_decay=weight_decay 76 | self.amsgrad = amsgrad 77 | self.step_num = 0 78 | self.is_approx = is_approx 79 | self.gamma = gamma 80 | self.mars_type = mars_type 81 | self.optimize_1d = optimize_1d 82 | self.lr_1d_factor = lr_1d / lr 83 | self.weight_decay_1d = weight_decay_1d 84 | self.betas_1d = betas_1d 85 | 86 | @torch.no_grad() 87 | def update_last_grad(self): 88 | if not self.is_approx: 89 | for group in self.param_groups: 90 | for p in group['params']: 91 | state = self.state[p] 92 | if "last_grad" not in state: 93 | state["last_grad"] = torch.zeros_like(p) 94 | state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0) 95 | @torch.no_grad() 96 | def update_previous_grad(self): 97 | if not self.is_approx: 98 | for group in self.param_groups: 99 | #print ("para name", len(group['params']), len(group['names']), group['names']) 100 | for p in group['params']: 101 | # import pdb 102 | # pdb.set_trace() 103 | if p.grad is None: 104 | print (p, "grad is none") 105 | continue 106 | state = self.state[p] 107 | if "previous_grad" not in state: 108 | state['previous_grad'] = torch.zeros_like(p) 109 | state['previous_grad'].zero_().add_(p.grad, alpha=1.0) 110 | 111 | def __setstate__(self, state): 112 | super(MARS, self).__setstate__(state) 113 | for group in self.param_groups: 114 | group.setdefault('amsgrad', False) 115 | 116 | @torch.no_grad() 117 | def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): 118 | """Performs a single optimization step. 119 | 120 | Arguments: 121 | closure (callable, optional): A closure that reevaluates the model 122 | and returns the loss. 
123 | 124 | If using exact version, the example usage is as follows: 125 | previous_X, previous_Y = None, None 126 | for epoch in range(epochs): 127 | for X, Y in data_loader: 128 | if previous_X: 129 | logits, loss = model(X, Y) 130 | loss.backward() 131 | optimizer.update_previous_grad() 132 | optimizer.zero_grad(set_to_none=True) 133 | logits, loss = model(X, Y) 134 | loss.backward() 135 | optimizer.step(bs=bs) 136 | optimizer.zero_grad(set_to_none=True) 137 | optimizer.update_last_grad() 138 | iter_num += 1 139 | previous_X, previous_Y = X.clone(), Y.clone() 140 | """ 141 | if any(p is not None for p in [grads, output_params, scale, grad_norms]): 142 | raise RuntimeError('FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.') 143 | 144 | loss = None 145 | if exists(closure): 146 | with torch.enable_grad(): 147 | loss = closure() 148 | real_update = 0 149 | real_update_wo_lr = 0 150 | gamma = self.gamma 151 | # import pdb 152 | # pdb.set_trace() 153 | for group in self.param_groups: 154 | for p in filter(lambda p: exists(p.grad), group['params']): 155 | if p.grad is None: 156 | continue 157 | grad = p.grad.data 158 | if grad.is_sparse: 159 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 160 | amsgrad = group['amsgrad'] 161 | 162 | state = self.state[p] 163 | #('----- starting a parameter state', state.keys(), 'Length of state', len(state)) 164 | # State initialization 165 | if len(state) <= 1: 166 | state['step'] = 0 167 | # Exponential moving average of gradient values 168 | state['exp_avg'] = torch.zeros_like(p.data) 169 | # Last Gradient 170 | state['last_grad'] = torch.zeros_like(p) 171 | #state['previous_grad'] = torch.zeros_like(p) 172 | # Exponential moving average of squared gradient values 173 | state['exp_avg_sq'] = torch.zeros_like(p.data) 174 | if amsgrad: 175 | # Maintains max of all exp. moving avg. of sq. grad. 
values 176 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 177 | # import pdb 178 | # pdb.set_trace() 179 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 180 | last_grad = state['last_grad'] 181 | lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas'] 182 | if amsgrad: 183 | max_exp_avg_sq = state['max_exp_avg_sq'] 184 | else: 185 | max_exp_avg_sq = 0 186 | 187 | if 'step' in state: 188 | state['step'] += 1 189 | else: 190 | state['step'] = 1 191 | step = state['step'] 192 | is_grad_2d = (len(grad.shape) == 2) 193 | exp_avg, exp_avg_sq = self.update_fn( 194 | p, 195 | grad, 196 | exp_avg, 197 | exp_avg_sq, 198 | lr, 199 | wd, 200 | beta1, 201 | beta2, 202 | last_grad, 203 | self.eps, 204 | amsgrad, 205 | max_exp_avg_sq, 206 | step, 207 | gamma, 208 | mars_type=self.mars_type, 209 | is_grad_2d=is_grad_2d, 210 | optimize_1d=self.optimize_1d, 211 | lr_1d_factor=self.lr_1d_factor, 212 | betas_1d=self.betas_1d, 213 | weight_decay_1d=self.weight_decay if self.optimize_1d else self.weight_decay_1d 214 | ) 215 | if self.is_approx: 216 | state['last_grad'] = grad 217 | self.step_num = step 218 | 219 | return loss 220 | 221 | @torch.compile 222 | def NewtonSchulz(M, steps=5, eps=1e-7): 223 | a, b, c = (3.4445, -4.7750, 2.0315) 224 | X = M.bfloat16() / (M.norm() + eps) 225 | if M.size(0) > M.size(1): 226 | X = X.T 227 | for _ in range(steps): 228 | A = X @ X.T 229 | B = A @ X 230 | X = a * X + b * B + c * A @ B 231 | if M.size(0) > M.size(1): 232 | X = X.T 233 | return X.to(M.dtype) 234 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MARS/utils/model_CNN.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Type, Union 2 | import importlib 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | def pair(t): 9 | return t if isinstance(t, tuple) else (t, t) 10 | 11 | 12 | def get_activation(activation_f: str) -> Type: 13 | """Get PyTorch activation function by name.""" 14 | package_name = "torch.nn" 15 | module = importlib.import_module(package_name) 16 | 17 | activations = [getattr(module, attr) for attr in dir(module)] 18 | activations = [ 19 | cls for cls in activations if isinstance(cls, type) and issubclass(cls, nn.Module) 20 | ] 21 | names = [cls.__name__.lower() for cls in activations] 22 | 23 | try: 24 | index = names.index(activation_f.lower()) 25 | return activations[index] 26 | except ValueError: 27 | raise NotImplementedError(f"get_activation: {activation_f=} is not yet implemented.") 28 | 29 | 30 | def compute_padding( 31 | input_size: tuple, kernel_size: int | tuple, stride: int | tuple = 1, dilation: int | tuple = 1 32 | ) -> Tuple[int, int]: 33 | """Compute padding for 'same' convolution.""" 34 | if len(input_size) == 2: 35 | input_size = (*input_size, 1) 36 | if isinstance(kernel_size, int): 37 | kernel_size = (kernel_size, kernel_size) 38 | if isinstance(stride, int): 39 | stride = (stride, stride) 40 | if isinstance(dilation, int): 41 | dilation = (dilation, dilation) 42 | 43 | input_h, input_w, _ = input_size 44 | kernel_h, kernel_w = kernel_size 45 | stride_h, stride_w = stride 46 | dilation_h, dilation_w = dilation 47 | 48 | # Compute the effective kernel size after dilation 49 | effective_kernel_h = (kernel_h - 1) * dilation_h + 1 50 | effective_kernel_w = (kernel_w - 1) * dilation_w + 1 51 | 52 | # Compute the padding needed for same convolution 53 | pad_h = int(max((input_h - 1) * stride_h + effective_kernel_h - input_h, 0)) 54 | pad_w = int(max((input_w - 1) * stride_w + effective_kernel_w - input_w, 0)) 55 | 56 | # Compute the padding for each side 57 | pad_top = pad_h // 2 58 | pad_left = pad_w // 2 59 | 60 | return (pad_top, pad_left) 61 | 62 | 63 | class Base(nn.Module): 64 | """Base class for neural network models.""" 65 | def __init__(self, **kwargs): 66 | super().__init__() 67 | self.__dict__.update(kwargs) 68 | 69 | @property 70 | def num_params(self): 71 | return sum(p.numel() for p in self.parameters()) 72 | 73 | @property 74 | def shapes(self): 75 | return {name: p.shape for name, p in self.named_parameters()} 76 | 77 | def summary(self): 78 | print(self) 79 | print(f"Number of parameters: {self.num_params}") 80 | 81 | 82 | class Network(Base): 83 | """Fully Connected / Convolutional Neural Network 84 | 85 | Args: 86 | n_inputs (Union[List[int], Tuple[int], torch.Size]): Input shape 87 | n_outputs (int): Number of output classes 88 | conv_layers_list (List[dict], optional): List of convolutional layers. Defaults to []. 89 | n_hiddens_list (Union[List, int], optional): List of hidden units. Defaults to 0. 90 | activation_f (str, optional): Activation function. Defaults to "ReLU". 91 | dropout (float, optional): Dropout rate. Defaults to 0.0. 
92 | 93 | conv_layers_list dict keys: 94 | filters: int 95 | kernel_size: int 96 | stride: int 97 | dilation: int 98 | padding: int 99 | bias: bool 100 | batch_norm: bool 101 | repeat: int 102 | """ 103 | def __init__( 104 | self, 105 | n_inputs: Union[List[int], Tuple[int], torch.Size], 106 | n_outputs: int, 107 | conv_layers_list: List[dict] = [], 108 | n_hiddens_list: Union[List, int] = 0, 109 | activation_f: str = "ReLU", 110 | dropout: float = 0.0, 111 | ): 112 | super().__init__() 113 | 114 | if isinstance(n_hiddens_list, int): 115 | n_hiddens_list = [n_hiddens_list] 116 | 117 | if n_hiddens_list == [] or n_hiddens_list == [0]: 118 | self.n_hidden_layers = 0 119 | else: 120 | self.n_hidden_layers = len(n_hiddens_list) 121 | 122 | activation = get_activation(activation_f) 123 | 124 | # Convert n_inputs to tensor for shape calculations 125 | ni = torch.tensor(n_inputs) 126 | 127 | conv_layers = [] 128 | if conv_layers_list: 129 | for conv_layer in conv_layers_list: 130 | n_channels = int(ni[0]) 131 | 132 | padding = conv_layer.get( 133 | "padding", 134 | compute_padding( # same padding 135 | tuple(ni.tolist()), 136 | conv_layer["kernel_size"], 137 | conv_layer.get("stride", 1), 138 | conv_layer.get("dilation", 1), 139 | ), 140 | ) 141 | 142 | # Add repeated conv blocks 143 | for i in range(conv_layer.get("repeat", 1)): 144 | # Convolutional layer 145 | conv_layers.append( 146 | nn.Conv2d( 147 | n_channels if i == 0 else conv_layer["filters"], 148 | conv_layer["filters"], 149 | conv_layer["kernel_size"], 150 | stride=conv_layer.get("stride", 1), 151 | padding=padding, 152 | dilation=conv_layer.get("dilation", 1), 153 | bias=conv_layer.get("bias", True), 154 | ) 155 | ) 156 | 157 | # Activation 158 | conv_layers.append(activation()) 159 | 160 | # Optional batch norm 161 | if conv_layer.get("batch_norm"): 162 | conv_layers.append(nn.BatchNorm2d(conv_layer["filters"])) 163 | 164 | # Max pooling after each conv block 165 | conv_layers.append(nn.MaxPool2d(2, stride=2)) 166 | 167 | # Optional dropout 168 | if dropout > 0: 169 | conv_layers.append(nn.Dropout(dropout)) 170 | 171 | # Update input shape for next layer 172 | ni = torch.cat([torch.tensor([conv_layer["filters"]]), ni[1:] // 2]) 173 | 174 | self.conv = nn.Sequential(*conv_layers) 175 | 176 | # Fully connected layers 177 | ni = int(torch.prod(ni)) 178 | fcn_layers = [] 179 | if self.n_hidden_layers > 0: 180 | for _, n_units in enumerate(n_hiddens_list): 181 | fcn_layers.extend([ 182 | nn.Linear(ni, n_units), 183 | activation() 184 | ]) 185 | if dropout > 0: 186 | fcn_layers.append(nn.Dropout(dropout)) 187 | ni = n_units 188 | 189 | self.fcn = nn.Sequential(*fcn_layers) 190 | self.output = nn.Linear(ni, n_outputs) 191 | 192 | def forward(self, x: torch.Tensor) -> torch.Tensor: 193 | x = self.conv(x) 194 | x = x.view(x.size(0), -1) 195 | x = self.fcn(x) 196 | return self.output(x) 197 | 198 | '''ResNet in PyTorch. 199 | 200 | For Pre-activation ResNet, see 'preact_resnet.py'. 201 | 202 | Reference: 203 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 204 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 205 | ''' 206 | 207 | 208 | 209 | class BasicBlock(nn.Module): 210 | expansion = 1 211 | 212 | def __init__(self, in_planes, planes, stride=1): 213 | super(BasicBlock, self).__init__() 214 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 215 | self.bn1 = nn.BatchNorm2d(planes) 216 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 217 | self.bn2 = nn.BatchNorm2d(planes) 218 | 219 | self.shortcut = nn.Sequential() 220 | if stride != 1 or in_planes != self.expansion*planes: 221 | self.shortcut = nn.Sequential( 222 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 223 | nn.BatchNorm2d(self.expansion*planes) 224 | ) 225 | 226 | def forward(self, x): 227 | out = F.relu(self.bn1(self.conv1(x))) 228 | out = self.bn2(self.conv2(out)) 229 | out += self.shortcut(x) 230 | out = F.relu(out) 231 | return out 232 | 233 | 234 | class Bottleneck(nn.Module): 235 | expansion = 4 236 | 237 | def __init__(self, in_planes, planes, stride=1): 238 | super(Bottleneck, self).__init__() 239 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 240 | self.bn1 = nn.BatchNorm2d(planes) 241 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 242 | self.bn2 = nn.BatchNorm2d(planes) 243 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 244 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 245 | 246 | self.shortcut = nn.Sequential() 247 | if stride != 1 or in_planes != self.expansion*planes: 248 | self.shortcut = nn.Sequential( 249 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 250 | nn.BatchNorm2d(self.expansion*planes) 251 | ) 252 | 253 | def forward(self, x): 254 | out = F.relu(self.bn1(self.conv1(x))) 255 | out = F.relu(self.bn2(self.conv2(out))) 256 | out = self.bn3(self.conv3(out)) 257 | out += self.shortcut(x) 258 | out = F.relu(out) 259 | return out 260 | 261 | 262 | class ResNet(nn.Module): 263 | def __init__(self, block, num_blocks, num_classes=10): 264 | super(ResNet, self).__init__() 265 | self.in_planes = 64 266 | 267 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 268 | self.bn1 = nn.BatchNorm2d(64) 269 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 270 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 271 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 272 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 273 | self.linear = nn.Linear(512*block.expansion, num_classes) 274 | 275 | def _make_layer(self, block, planes, num_blocks, stride): 276 | strides = [stride] + [1]*(num_blocks-1) 277 | layers = [] 278 | for stride in strides: 279 | layers.append(block(self.in_planes, planes, stride)) 280 | self.in_planes = planes * block.expansion 281 | return nn.Sequential(*layers) 282 | 283 | def forward(self, x): 284 | out = F.relu(self.bn1(self.conv1(x))) 285 | out = self.layer1(out) 286 | out = self.layer2(out) 287 | out = self.layer3(out) 288 | out = self.layer4(out) 289 | out = F.avg_pool2d(out, 4) 290 | out = out.view(out.size(0), -1) 291 | out = self.linear(out) 292 | return out 293 | 294 | 295 | def ResNet18(num_classes = 10): 296 | return ResNet(BasicBlock, [2,2,2,2], num_classes = num_classes) 297 | 298 | def ResNet34(num_classes = 10): 299 | return ResNet(BasicBlock, [3,4,6,3], num_classes = num_classes) 300 | 301 | def 
ResNet50(num_classes = 10): 302 | return ResNet(Bottleneck, [3,4,6,3], num_classes = num_classes) 303 | 304 | def ResNet101(num_classes = 10): 305 | return ResNet(Bottleneck, [3,4,23,3], num_classes = num_classes) 306 | 307 | def ResNet152(num_classes = 10): 308 | return ResNet(Bottleneck, [3,8,36,3], num_classes = num_classes) -------------------------------------------------------------------------------- /MARS_M/utils/model_CNN.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Type, Union 2 | import importlib 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | def pair(t): 9 | return t if isinstance(t, tuple) else (t, t) 10 | 11 | 12 | def get_activation(activation_f: str) -> Type: 13 | """Get PyTorch activation function by name.""" 14 | package_name = "torch.nn" 15 | module = importlib.import_module(package_name) 16 | 17 | activations = [getattr(module, attr) for attr in dir(module)] 18 | activations = [ 19 | cls for cls in activations if isinstance(cls, type) and issubclass(cls, nn.Module) 20 | ] 21 | names = [cls.__name__.lower() for cls in activations] 22 | 23 | try: 24 | index = names.index(activation_f.lower()) 25 | return activations[index] 26 | except ValueError: 27 | raise NotImplementedError(f"get_activation: {activation_f=} is not yet implemented.") 28 | 29 | 30 | def compute_padding( 31 | input_size: tuple, kernel_size: int | tuple, stride: int | tuple = 1, dilation: int | tuple = 1 32 | ) -> Tuple[int, int]: 33 | """Compute padding for 'same' convolution.""" 34 | if len(input_size) == 2: 35 | input_size = (*input_size, 1) 36 | if isinstance(kernel_size, int): 37 | kernel_size = (kernel_size, kernel_size) 38 | if isinstance(stride, int): 39 | stride = (stride, stride) 40 | if isinstance(dilation, int): 41 | dilation = (dilation, dilation) 42 | 43 | input_h, input_w, _ = input_size 44 | kernel_h, kernel_w = kernel_size 45 | stride_h, stride_w = stride 46 | dilation_h, dilation_w = dilation 47 | 48 | # Compute the effective kernel size after dilation 49 | effective_kernel_h = (kernel_h - 1) * dilation_h + 1 50 | effective_kernel_w = (kernel_w - 1) * dilation_w + 1 51 | 52 | # Compute the padding needed for same convolution 53 | pad_h = int(max((input_h - 1) * stride_h + effective_kernel_h - input_h, 0)) 54 | pad_w = int(max((input_w - 1) * stride_w + effective_kernel_w - input_w, 0)) 55 | 56 | # Compute the padding for each side 57 | pad_top = pad_h // 2 58 | pad_left = pad_w // 2 59 | 60 | return (pad_top, pad_left) 61 | 62 | 63 | class Base(nn.Module): 64 | """Base class for neural network models.""" 65 | def __init__(self, **kwargs): 66 | super().__init__() 67 | self.__dict__.update(kwargs) 68 | 69 | @property 70 | def num_params(self): 71 | return sum(p.numel() for p in self.parameters()) 72 | 73 | @property 74 | def shapes(self): 75 | return {name: p.shape for name, p in self.named_parameters()} 76 | 77 | def summary(self): 78 | print(self) 79 | print(f"Number of parameters: {self.num_params}") 80 | 81 | 82 | class Network(Base): 83 | """Fully Connected / Convolutional Neural Network 84 | 85 | Args: 86 | n_inputs (Union[List[int], Tuple[int], torch.Size]): Input shape 87 | n_outputs (int): Number of output classes 88 | conv_layers_list (List[dict], optional): List of convolutional layers. Defaults to []. 89 | n_hiddens_list (Union[List, int], optional): List of hidden units. Defaults to 0. 90 | activation_f (str, optional): Activation function. Defaults to "ReLU". 
91 | dropout (float, optional): Dropout rate. Defaults to 0.0. 92 | 93 | conv_layers_list dict keys: 94 | filters: int 95 | kernel_size: int 96 | stride: int 97 | dilation: int 98 | padding: int 99 | bias: bool 100 | batch_norm: bool 101 | repeat: int 102 | """ 103 | def __init__( 104 | self, 105 | n_inputs: Union[List[int], Tuple[int], torch.Size], 106 | n_outputs: int, 107 | conv_layers_list: List[dict] = [], 108 | n_hiddens_list: Union[List, int] = 0, 109 | activation_f: str = "ReLU", 110 | dropout: float = 0.0, 111 | ): 112 | super().__init__() 113 | 114 | if isinstance(n_hiddens_list, int): 115 | n_hiddens_list = [n_hiddens_list] 116 | 117 | if n_hiddens_list == [] or n_hiddens_list == [0]: 118 | self.n_hidden_layers = 0 119 | else: 120 | self.n_hidden_layers = len(n_hiddens_list) 121 | 122 | activation = get_activation(activation_f) 123 | 124 | # Convert n_inputs to tensor for shape calculations 125 | ni = torch.tensor(n_inputs) 126 | 127 | conv_layers = [] 128 | if conv_layers_list: 129 | for conv_layer in conv_layers_list: 130 | n_channels = int(ni[0]) 131 | 132 | padding = conv_layer.get( 133 | "padding", 134 | compute_padding( # same padding 135 | tuple(ni.tolist()), 136 | conv_layer["kernel_size"], 137 | conv_layer.get("stride", 1), 138 | conv_layer.get("dilation", 1), 139 | ), 140 | ) 141 | 142 | # Add repeated conv blocks 143 | for i in range(conv_layer.get("repeat", 1)): 144 | # Convolutional layer 145 | conv_layers.append( 146 | nn.Conv2d( 147 | n_channels if i == 0 else conv_layer["filters"], 148 | conv_layer["filters"], 149 | conv_layer["kernel_size"], 150 | stride=conv_layer.get("stride", 1), 151 | padding=padding, 152 | dilation=conv_layer.get("dilation", 1), 153 | bias=conv_layer.get("bias", True), 154 | ) 155 | ) 156 | 157 | # Activation 158 | conv_layers.append(activation()) 159 | 160 | # Optional batch norm 161 | if conv_layer.get("batch_norm"): 162 | conv_layers.append(nn.BatchNorm2d(conv_layer["filters"])) 163 | 164 | # Max pooling after each conv block 165 | conv_layers.append(nn.MaxPool2d(2, stride=2)) 166 | 167 | # Optional dropout 168 | if dropout > 0: 169 | conv_layers.append(nn.Dropout(dropout)) 170 | 171 | # Update input shape for next layer 172 | ni = torch.cat([torch.tensor([conv_layer["filters"]]), ni[1:] // 2]) 173 | 174 | self.conv = nn.Sequential(*conv_layers) 175 | 176 | # Fully connected layers 177 | ni = int(torch.prod(ni)) 178 | fcn_layers = [] 179 | if self.n_hidden_layers > 0: 180 | for _, n_units in enumerate(n_hiddens_list): 181 | fcn_layers.extend([ 182 | nn.Linear(ni, n_units), 183 | activation() 184 | ]) 185 | if dropout > 0: 186 | fcn_layers.append(nn.Dropout(dropout)) 187 | ni = n_units 188 | 189 | self.fcn = nn.Sequential(*fcn_layers) 190 | self.output = nn.Linear(ni, n_outputs) 191 | 192 | def forward(self, x: torch.Tensor) -> torch.Tensor: 193 | x = self.conv(x) 194 | x = x.view(x.size(0), -1) 195 | x = self.fcn(x) 196 | return self.output(x) 197 | 198 | '''ResNet in PyTorch. 199 | 200 | For Pre-activation ResNet, see 'preact_resnet.py'. 201 | 202 | Reference: 203 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 204 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 205 | ''' 206 | 207 | 208 | 209 | class BasicBlock(nn.Module): 210 | expansion = 1 211 | 212 | def __init__(self, in_planes, planes, stride=1): 213 | super(BasicBlock, self).__init__() 214 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 215 | self.bn1 = nn.BatchNorm2d(planes) 216 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 217 | self.bn2 = nn.BatchNorm2d(planes) 218 | 219 | self.shortcut = nn.Sequential() 220 | if stride != 1 or in_planes != self.expansion*planes: 221 | self.shortcut = nn.Sequential( 222 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 223 | nn.BatchNorm2d(self.expansion*planes) 224 | ) 225 | 226 | def forward(self, x): 227 | out = F.relu(self.bn1(self.conv1(x))) 228 | out = self.bn2(self.conv2(out)) 229 | out += self.shortcut(x) 230 | out = F.relu(out) 231 | return out 232 | 233 | 234 | class Bottleneck(nn.Module): 235 | expansion = 4 236 | 237 | def __init__(self, in_planes, planes, stride=1): 238 | super(Bottleneck, self).__init__() 239 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 240 | self.bn1 = nn.BatchNorm2d(planes) 241 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 242 | self.bn2 = nn.BatchNorm2d(planes) 243 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 244 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 245 | 246 | self.shortcut = nn.Sequential() 247 | if stride != 1 or in_planes != self.expansion*planes: 248 | self.shortcut = nn.Sequential( 249 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 250 | nn.BatchNorm2d(self.expansion*planes) 251 | ) 252 | 253 | def forward(self, x): 254 | out = F.relu(self.bn1(self.conv1(x))) 255 | out = F.relu(self.bn2(self.conv2(out))) 256 | out = self.bn3(self.conv3(out)) 257 | out += self.shortcut(x) 258 | out = F.relu(out) 259 | return out 260 | 261 | 262 | class ResNet(nn.Module): 263 | def __init__(self, block, num_blocks, num_classes=10): 264 | super(ResNet, self).__init__() 265 | self.in_planes = 64 266 | 267 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 268 | self.bn1 = nn.BatchNorm2d(64) 269 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 270 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 271 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 272 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 273 | self.linear = nn.Linear(512*block.expansion, num_classes) 274 | 275 | def _make_layer(self, block, planes, num_blocks, stride): 276 | strides = [stride] + [1]*(num_blocks-1) 277 | layers = [] 278 | for stride in strides: 279 | layers.append(block(self.in_planes, planes, stride)) 280 | self.in_planes = planes * block.expansion 281 | return nn.Sequential(*layers) 282 | 283 | def forward(self, x): 284 | out = F.relu(self.bn1(self.conv1(x))) 285 | out = self.layer1(out) 286 | out = self.layer2(out) 287 | out = self.layer3(out) 288 | out = self.layer4(out) 289 | out = F.avg_pool2d(out, 4) 290 | out = out.view(out.size(0), -1) 291 | out = self.linear(out) 292 | return out 293 | 294 | 295 | def ResNet18(num_classes = 10): 296 | return ResNet(BasicBlock, [2,2,2,2], num_classes = num_classes) 297 | 298 | def ResNet34(num_classes = 10): 299 | return ResNet(BasicBlock, [3,4,6,3], num_classes = num_classes) 300 | 301 | def 
ResNet50(num_classes = 10): 302 | return ResNet(Bottleneck, [3,4,6,3], num_classes = num_classes) 303 | 304 | def ResNet101(num_classes = 10): 305 | return ResNet(Bottleneck, [3,4,23,3], num_classes = num_classes) 306 | 307 | def ResNet152(num_classes = 10): 308 | return ResNet(Bottleneck, [3,8,36,3], num_classes = num_classes) -------------------------------------------------------------------------------- /MARS_M/README.md: -------------------------------------------------------------------------------- 1 | # MARS-M: When Variance Reduction Meets Matrices 2 | 3 | This repository contains the official code for the paper [MARS-M: When Variance Reduction Meets Matrices](https://arxiv.org/abs/2510.21800). 4 | 5 | Authors: [Yifeng Liu](https://scholar.google.com/citations?user=mFvOVkMAAAAJ&hl=zh-CN)\*, [Angela Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) 6 | 7 | ## 🔔 NEWS 8 | - **[10/20/2025]** Our paper is released on arXiv: https://arxiv.org/abs/2510.21800. 9 | - **[10/05/2025]** Our code is released. 10 | 11 | ## MARS-M 12 | 13 | **MARS-M** is a brand-new optimizer that integrates matrix-based optimizer (i.e., Muon and Moonlight) with the variance-reduction based optimizer MARS to reduce high stochastic gradient variance in the training process. 14 | 15 | In detail, the **MARS-M** optimizer is built on **MARS** framework: 16 | 17 | --- 18 | 19 | **Algorithm 1** MARS 20 | 21 | --- 22 | 23 | $$ 24 | \begin{align*} 25 | &\pmb{input: }\mathbf{x}_0\in\mathbb{R}^{A\times B}, \lambda, \beta, \{\gamma_t\}, \{\eta_t\}\\ 26 | &\text{Set }\mathbf{m}_0\leftarrow \mathbf{0}\textbf{ and }\mathbf{x}_1\leftarrow\mathbf{x}_0\\ 27 | &\pmb{for }\textbf{ }t=1,\pmb{ to }\textbf{ }n\textbf{ }\pmb{ do}\\ 28 | &\qquad\textbf{sample }\mathbf{\xi}_t\textbf{ and let }\mathbf{c}_t = \nabla f(\mathbf{x}_t, \mathbf{\xi}_t)+\gamma_t\bigg(\frac{\beta}{1-\beta}\bigg)\big(\nabla f(\mathbf{x}_t, \mathbf{\xi}_t)-\nabla f(\mathbf{x}_{t-1}, \mathbf{\xi}_t)\big)\\ 29 | &\qquad\mathbf{m}_t = \beta \mathbf{m}_{t-1} + (1-\beta)\text{Clip}(\mathbf{c}_t, 1)\\ 30 | &\qquad\mathbf{x}\_{t+1} = \arg\min_{\mathbf{x} \in \mathbb{R}^d} \left\\{\eta_t \left\langle \mathbf{m}_t, \mathbf{x} \right\rangle + \frac{1}{2} \\|\mathbf{x} - \mathbf{x}\_t 31 | \\|\_{\mathbf{H}_t}^2\right\\}\\ 32 | &\pmb{end}\textbf{ }\pmb{for} 33 | \end{align*} 34 | $$ 35 | 36 | --- 37 | 38 | where 39 | 40 | $$ 41 | \text{Clip}(\mathbf{c}_t,1) = \begin{cases} 42 | \frac{\mathbf{c}_t}{\\|\mathbf{c}_t\\|_2} & \text{if } \\|\mathbf{c}_t\\|_2 > 1,\\ 43 | \mathbf{c}_t & \text{otherwise}. 44 | \end{cases} 45 | $$ 46 | 47 | 48 | 49 | Here ${\color{red}\gamma_t}$ is a scaling parameter that controls the strength of gradient correction and plays a central role in MARS. 
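In code, this correction-and-clipping step is only a few lines. The sketch below mirrors the corresponding lines of `optimizers/mars_m.py`; `mars_correction` is just an illustrative helper name, `grad` and `last_grad` stand for the current and stored previous stochastic gradients, and the defaults follow the optimizer's `gamma=0.025` and `momentum=0.95`:

```python
import torch

def mars_correction(grad: torch.Tensor, last_grad: torch.Tensor,
                    gamma: float = 0.025, beta: float = 0.95) -> torch.Tensor:
    # c_t = grad + gamma * (beta / (1 - beta)) * (grad - last_grad)
    c_t = (grad - last_grad).mul(gamma * beta / (1.0 - beta)).add(grad)
    # Clip(c_t, 1): rescale only when the norm exceeds 1
    c_norm = c_t.norm()
    return c_t / c_norm if c_norm > 1 else c_t

# toy example on a random weight-matrix gradient
g_t, g_prev = torch.randn(128, 64), torch.randn(128, 64)
c_t = mars_correction(g_t, g_prev)
```

The clipped `c_t` is then fed into the momentum buffer $\mathbf{m}_t$ exactly as in Algorithm 1.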
50 | 51 | Under the **MARS** framework, we propose **MARS-M** that applies MARS to matrix-based optimizers (See `optimizers/mars_m.py` for the implementation): 52 | 53 | --- 54 | 55 | **Algorithm 2** MARS-M 56 | 57 | --- 58 | 59 | $$ 60 | \begin{align*} 61 | &\pmb{input: }\mathbf{X}_0\in\mathbb{R}^{A\times B}, \lambda, \beta, \{\gamma_t\}, \{\eta_t\}\\ 62 | &\text{Set }\mathbf{M}_0\leftarrow \mathbf{0}\textbf{ and }\mathbf{X}_1\leftarrow\mathbf{X}_0\\ 63 | &\pmb{for }\textbf{ }t=1,\pmb{ to }\textbf{ }n\textbf{ }\pmb{ do}\\ 64 | &\qquad\textbf{sample }\mathbf{\xi}_t\textbf{ and let }\mathbf{C}_t = \nabla f(\mathbf{X}_t, \mathbf{\xi}_t)+\gamma_t\bigg(\frac{\beta}{1-\beta}\bigg)\big(\nabla f(\mathbf{X}_t, \mathbf{\xi}_t)-\nabla f(\mathbf{X}_{t-1}, \mathbf{\xi}_t)\big)\\ 65 | &\qquad\mathbf{M}_t = \beta \mathbf{M}_{t-1} + (1-\beta)\text{Clip}(\mathbf{C}_t, 1)\\ 66 | &\qquad\mathbf{O}_t = \text{NewtonSchulz}(\mathbf{M}_t)\\ 67 | &\qquad\mathbf{X}_{t+1} = \mathbf{X}_t - \eta_t(0.2\cdot\mathbf{O}_t\cdot\sqrt{\max(A,B)} + \lambda \mathbf{X}_t)\\ 68 | &\pmb{end}\textbf{ }\pmb{for} 69 | \end{align*} 70 | $$ 71 | 72 | --- 73 | 74 | To accelerate training process, we also propose the approximated version of MARS-M by substituting $f(\mathbf{X}\_{t-1}, \mathbf{\xi}\_t)$ with $f(\mathbf{X}\_{t-1}, \mathbf{\xi}\_{t-1})$ as follows: 75 | 76 | --- 77 | 78 | **Algorithm 3** MARS-M-approx 79 | 80 | --- 81 | 82 | $$ 83 | \begin{align*} 84 | &\pmb{input: }\mathbf{X}_0\in\mathbb{R}^{A\times B}, \lambda, \beta, \{\gamma_t\}, \{\eta_t\}\\ 85 | &\text{Set }\mathbf{M}_0\leftarrow \mathbf{0}\textbf{ and }\mathbf{X}_1\leftarrow\mathbf{X}_0\\ 86 | &\pmb{for }\textbf{ }t=1,\pmb{ to }\textbf{ }n\textbf{ }\pmb{ do}\\ 87 | &\qquad\textbf{sample }\mathbf{\xi}_t\textbf{ and let }\mathbf{C}_t = \nabla f(\mathbf{X}_t, \mathbf{\xi}_t)+\gamma_t\bigg(\frac{\beta}{1-\beta}\bigg)\big(\nabla f(\mathbf{X}_t, \mathbf{\xi}_t)-\nabla f(\mathbf{X}_{t-1}, \mathbf{\xi}_{t-1})\big)\\ 88 | &\qquad\mathbf{M}_t = \beta \mathbf{M}_{t-1} + (1-\beta)\text{Clip}(\mathbf{C}_t, 1)\\ 89 | &\qquad\mathbf{O}_t = \text{NewtonSchulz}(\mathbf{M}_t)\\ 90 | &\qquad\mathbf{X}_{t+1} = \mathbf{X}_t - \eta_t(0.2\cdot\mathbf{O}_t\cdot\sqrt{\max(A,B)} + \lambda \mathbf{X}_t)\\ 91 | &\pmb{end}\textbf{ }\pmb{for} 92 | \end{align*} 93 | $$ 94 | 95 | --- 96 | 97 | ### **Performance of MARS-M Compared to Baseline of Muon (Moonlight) and AdamW** 98 | 99 | #### Experiment Settings 100 | 101 | We implement grid search on learning rates for AdamW and Muon (Moonlight) and use the same hyper-parameters of Muon (Moonlight) for experiments with **MARS-M**. 102 | 103 | #### Experiments on OpenWebText 104 | 105 | In our experiments, gradients are calculated once per sample and per update (**MARS-M**-approx). Performing exact gradient computation with two evaluations per update, as in the exact form of **MARS-M**, can slightly enhance performance but at the cost of doubling the computational cost. Moreover, **MARS-M** also outperforms AdamW for the best loss value. 
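Either way, the per-matrix work inside each optimizer step is the same: momentum on the clipped correction `c_t` (as computed in the earlier sketch), Newton-Schulz orthogonalization, the $0.2\sqrt{\max(A,B)}$ learning-rate scaling, and decoupled weight decay. The snippet below is a condensed, single-matrix sketch of the Muon branch in `optimizers/mars_m.py`; the helper name `mars_m_matrix_update` is illustrative, the import path assumes the `MARS_M` directory is on `PYTHONPATH`, and it is not a drop-in replacement for the optimizer class.

```python
import torch
from optimizers.mars_m import zeropower_via_newtonschulz5  # quintic Newton-Schulz from this repo

@torch.no_grad()
def mars_m_matrix_update(p, c_t, buf, lr=1e-3, wd=0.1, momentum=0.95, ns_steps=5):
    """One MARS-M step for a 2D parameter p, given the clipped c_t and its momentum buffer buf."""
    buf.mul_(momentum).add_(c_t, alpha=1 - momentum)        # M_t = beta * M_{t-1} + (1 - beta) * Clip(C_t, 1)
    o_t = zeropower_via_newtonschulz5(buf, steps=ns_steps)  # O_t = NewtonSchulz(M_t)
    adjusted_lr = lr * 0.2 * max(p.shape[:2]) ** 0.5        # 0.2 * sqrt(max(A, B)) scaling
    p.mul_(1 - lr * wd)                                     # decoupled weight decay
    p.add_(o_t, alpha=-adjusted_lr)                         # X_{t+1} = X_t - eta_t * (0.2 * sqrt(max(A, B)) * O_t + lambda * X_t)

# toy usage on a random 2D "weight"
p = torch.randn(128, 64)
mars_m_matrix_update(p, torch.randn_like(p), torch.zeros_like(p))
```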
106 | 107 | **MARS-M** consistently outperforms [Muon (Moonlight version)](https://arxiv.org/abs/2502.16982) optimizers across GPT-2 models: 108 | 109 | | **GPT-2 small** | **GPT-2 medium** | **GPT-2 large** | 110 | | ------------------------------------------------ | ------------------------------------------------- | ------------------------------------------------ | 111 | | | | | 112 | 113 | --- 114 | 115 | Zoomed-in loss curves 116 | | **GPT-2 small** | **GPT-2 medium** | **GPT-2 large** | 117 | | ------------------------------------------------ | ------------------------------------------------- | ------------------------------------------------ | 118 | | | | | 119 | 120 | #### Experiments on FineWeb-Edu 121 | 122 | Below are the training and validation loss curves for both GPT‑2 Small and GPT‑2 XL when using our MARS-M approach versus [Muon (Moonlight version)](https://arxiv.org/abs/2502.16982) optimizers. As you can see, MARS-M often yields faster convergence and consistently lower losses across different training steps. Moreover, **MARS-M** also outperforms AdamW for the best loss value. 123 | 124 | | Model | **GPT-2 small** | **GPT-2 XL** | 125 | | ------------------------- | -------------------------------------------------- | ----------------------------------------------- | 126 | | **Training Loss** | | | 127 | | **Validation Loss** | | | 128 | 129 | --- 130 | 131 | Zoomed-in loss curves 132 | | Model | **GPT-2 small** | **GPT-2 XL** | 133 | | ------------------------- | -------------------------------------------------- | ----------------------------------------------- | 134 | | **Training Loss** | | | 135 | | **Validation Loss** | | | 136 | 137 | ## Training GPT-2 from Scratch: 138 | 139 | ### Install Dependencies 140 | 141 | ``` 142 | $ pip install torch==2.1.2 transformers==4.33.0 datasets tiktoken numpy==1.26.4 wandb 143 | ``` 144 | 145 | ### Data Preparation 146 | 147 | Prepare the [OpenWebText](https://huggingface.co/datasets/openwebtext) data following [nanoGPT](https://github.com/karpathy/nanoGPT/): 148 | 149 | ``` 150 | $ python data/openwebtext/prepare.py 151 | ``` 152 | 153 | ### **Start Training** 154 | 155 | To train a model using the **MARS-M** optimizer, run the following command: 156 | 157 | ```bash 158 | $ torchrun --standalone --nproc_per_node=8 train_mars_m.py config/${your_config_file} 159 | ``` 160 | 161 | This command initiates the training of a GPT-2 model on the OpenWebText dataset using the **MARS-M** optimizer. All relevant hyperparameters—training, model, and optimizer—are specified in the configuration file (`${your_config_file}`). These parameters can be adjusted directly in the configuration file or through the bash script. 162 | 163 | ### **Hyperparameter Details** 164 | 165 | #### **Model Hyperparameters**: 166 | 167 | - **n_layer**: Layers of networks, 12 for GPT2 Small, 24 for GPT2 Medium, 36 for GPT2 Large 168 | - **n_head**: Number of heads, 12 for GPT2 small, 16 for GPT2 Medium, 20 for GPT2 Large 169 | - **n_embd**: Embedding dimension, 768 for GPT2 small, 1024 for GPT2 Medium, 1280 for GPT2 Large 170 | 171 | #### **Optimizer Hyperparameters**: 172 | 173 | - **`learning_rate`**: Learning rate for the **MARS-M** optimizer. 174 | - **`weight_decay`**: Weight decay for the **MARS-M** optimizer. 175 | - **`beta1`**: momentum for **MARS-M** optimizer. 176 | 177 | - Default: `beta1=0.95, beta2=0.99` 178 | - **`betas_1d`**: Weights for exponential moving average in AdamW optimizer (for 1d parameters). 
179 | 180 | - Default: `(0.9, 0.95)` 181 | - **`is_approx`**: Whether to use approximate gradient calculation (**MARS-M**-approx). 182 | 183 | - Default: `True` 184 | - **`gamma`**: The scaling parameter that controls the strength of the gradient correction. 185 | 186 | - Default: `0.025` 187 | 188 | #### **Training Hyperparameters**: 189 | 190 | - **`batch_size`**: Mini-batch size per device (for example, GPT-2 Small on an A100 GPU typically uses a batch size of 15). 191 | - **`gradient_accumulation_steps`**: Number of gradient accumulation steps, chosen so that the total effective batch size matches the desired scale (for example, a total batch size of 480 corresponds to $15 \times 4 \times 8 \, \text{GPUs}$). 192 | - **`schedule`**: Learning rate schedule. 193 | - Default: `cosine` 194 | 195 | For more detailed hyperparameter examples, refer to: 196 | 197 | - `config/train_gpt2_small_mars_m.py` 198 | - `scripts/run_mars_m_small.sh` 199 | 200 | --- 201 | 202 | ### Reproducing Our Results 203 | 204 | #### **Reproducing GPT-2 Small (125M) Results** 205 | 206 | Train with MARS-M by running 207 | 208 | ``` 209 | $ bash scripts/run_mars_m_small.sh 210 | ``` 211 | 212 | or 213 | 214 | ``` 215 | $ torchrun --standalone --nproc_per_node=8 \ 216 | train_mars_m.py \ 217 | config/train_gpt2_small_mars_m.py \ 218 | --batch_size=15 \ 219 | --gradient_accumulation_steps=4 220 | ``` 221 | 222 | #### Reproducing GPT-2 Medium (355M) Results 223 | 224 | Train with MARS-M by running 225 | 226 | ``` 227 | $ bash scripts/run_mars_m_medium.sh 228 | ``` 229 | 230 | or 231 | 232 | ``` 233 | $ torchrun --standalone --nproc_per_node=8 \ 234 | train_mars_m.py \ 235 | config/train_gpt2_medium_mars_m.py \ 236 | --batch_size=15 \ 237 | --gradient_accumulation_steps=4 238 | ``` 239 | 240 | #### Reproducing GPT-2 Large (770M) Results 241 | 242 | Train with MARS-M by running 243 | 244 | ``` 245 | $ bash scripts/run_mars_m_large.sh 246 | ``` 247 | 248 | or 249 | 250 | ``` 251 | $ torchrun --standalone --nproc_per_node=8 \ 252 | train_mars_m.py \ 253 | config/train_gpt2_large_mars_m.py \ 254 | --batch_size=5 \ 255 | --gradient_accumulation_steps=12 256 | ``` 257 | 258 | #### **Reproducing GPT-2 XL (1.5B) Results on FineWeb-Edu** 259 | 260 | ``` 261 | $ bash scripts/run_mars_m_xl_fw.sh 262 | ``` 263 | 264 | or 265 | 266 | ``` 267 | $ torchrun --standalone --nproc_per_node=8 \ 268 | train_mars_m_fw.py \ 269 | config/train_gpt2_xl_mars_m.py \ 270 | --batch_size=5 \ 271 | --gradient_accumulation_steps=12 272 | ``` 273 | 274 | #### Reproducing Baseline Results 275 | 276 | To reproduce the Moonlight baseline: 277 | 278 | ``` 279 | $ bash scripts/run_moonlight_{small/medium/large}.sh 280 | ``` 281 | 282 | Other baselines can be run with the code in the `../MARS` folder. 283 | 284 | Please adjust `nproc_per_node`, `batch_size`, and `gradient_accumulation_steps` accordingly if you use a different hardware setup, and make sure their product equals 480. 
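For example, on a hypothetical 4-GPU machine you could keep `batch_size=15` and raise `gradient_accumulation_steps` to 8, so that $15 \times 8 \times 4 = 480$:

```
$ torchrun --standalone --nproc_per_node=4 \
    train_mars_m.py \
    config/train_gpt2_small_mars_m.py \
    --batch_size=15 \
    --gradient_accumulation_steps=8
```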
285 | 286 | #### Hyperparameters for GPT-2 models 287 | 288 | | Model Name | Model Size | OpenWebText LR | FineWeb-Edu LR | Weight Decay | 289 | | :----------: | :--------: | :------------: | :------------: | :----------: | 290 | | GPT-2 small | 125M | 6e-3 | 1e-2 | 1e-1 | 291 | | GPT-2 medium | 355M | 5e-3 | 5e-3 | 1e-1 | 292 | | GPT-2 large | 770M | 5e-3 | 5e-3 | 1e-1 | 293 | | GPT-2 xl | 1.5B | - | 3e-3 | 1e-1 | 294 | 295 | ### Customized Training 296 | 297 | To build your own training pipeline on other architectures and datasets, use the following template as an example: 298 | 299 | ```python 300 | import torch 301 | import torch.nn.functional as F 302 | from mars_m import MARS_M 303 | 304 | # init the model, loss function and input data 305 | model = Model() 306 | data_loader = ... 307 | 308 | # init the optimizer: matrix parameters (excluding embeddings / lm_head) use the MARS-M update, the rest use AdamW 309 | muon_params = [p for name, p in model.named_parameters() if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name] 310 | adamw_params = [p for name, p in model.named_parameters() if p.ndim < 2 or "embed_tokens" in name or "lm_head" in name] 311 | optimizer = MARS_M(muon_params=muon_params, adamw_params=adamw_params, lr=1e-3, betas=(0.9, 0.95), gamma=0.025) 312 | 313 | total_bs = len(data_loader) 314 | bs = total_bs * block_size 315 | k = 10 316 | iter_num = -1 317 | 318 | # training loop 319 | for epoch in range(epochs): 320 | for X, Y in data_loader: 321 | # standard training code 322 | logits, loss = model(X, Y) 323 | loss.backward() 324 | optimizer.step(bs=bs) 325 | optimizer.zero_grad(set_to_none=True) 326 | optimizer.update_last_grad() 327 | iter_num += 1 328 | 329 | ``` 330 | 331 | ## Citation 332 | 333 | If you find this repo useful for your research, please consider citing our GitHub repository: 334 | 335 | ```tex 336 | @misc{liu2025MARS, 337 | author = {Yifeng Liu and Angela Yuan and Quanquan Gu}, 338 | title = {MARS-M: When Variance Reduction Meets Matrices}, 339 | year = {2025}, 340 | url = {https://github.com/AGI-Arena/MARS/tree/main/MARS_M/} 341 | } 342 | ``` 343 | 344 | ## Acknowledgements 345 | 346 | This repo is built upon [nanoGPT](https://github.com/karpathy/nanoGPT/), [levanter](https://github.com/stanford-crfm/levanter/) and [Sophia](https://github.com/Liuhong99/Sophia); we thank the authors for their great work! 
347 | -------------------------------------------------------------------------------- /MARS/train_adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/Liuhong99/Sophia/blob/main/train_adam.py 3 | """ 4 | import os 5 | import time 6 | import math 7 | import pickle 8 | from contextlib import nullcontext 9 | 10 | import numpy as np 11 | import torch 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from torch.distributed import init_process_group, destroy_process_group 14 | 15 | from model import GPTConfig, GPT 16 | import sys 17 | from ast import literal_eval 18 | # ----------------------------------------------------------------------------- 19 | # default config values designed to train a gpt2 (124M) on OpenWebText 20 | # I/O 21 | out_dir = 'out' 22 | eval_interval = 2000 23 | log_interval = 1 24 | eval_iters = 200 25 | eval_only = False # if True, script exits right after the first eval 26 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 27 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 28 | # wandb logging 29 | wandb_log = False # disabled by default 30 | wandb_project = 'mars' 31 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 32 | # data 33 | dataset = 'openwebtext' 34 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes 35 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 36 | block_size = 1024 37 | # model 38 | n_layer = 12 39 | n_head = 12 40 | n_embd = 768 41 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 42 | bias = False # do we use bias inside LayerNorm and Linear layers? 43 | # optimizer 44 | optimizer_name = 'adamw' 45 | learning_rate = 6e-4 # max learning rate 46 | max_iters = 600000 # total number of training iterations 47 | weight_decay = 1e-1 48 | beta1 = 0.9 49 | beta2 = 0.95 50 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 51 | interval = 10 52 | variant = 4 53 | schedule='cosine' 54 | # learning rate decay settings 55 | decay_lr = True # whether to decay the learning rate 56 | warmup_iters = 2000 # how many steps to warm up for 57 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 58 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 59 | # DDP settings 60 | backend = 'nccl' # 'nccl', 'gloo', etc. 61 | # system 62 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 63 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 64 | compile = True # use PyTorch 2.0 to compile the model to be faster 65 | scale_attn_by_inverse_layer_idx = True 66 | # ----------------------------------------------------------------------------- 67 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 68 | for arg in sys.argv[1:]: 69 | if '=' not in arg: 70 | # assume it's the name of a config file 71 | assert not arg.startswith('--') 72 | config_file = arg 73 | print(f"Overriding config with {config_file}:") 74 | with open(config_file) as f: 75 | print(f.read()) 76 | exec(open(config_file).read()) 77 | else: 78 | # assume it's a --key=value argument 79 | assert arg.startswith('--') 80 | key, val = arg.split('=') 81 | key = key[2:] 82 | if key in globals(): 83 | try: 84 | # attempt to eval it it (e.g. 
if bool, number, or etc) 85 | attempt = literal_eval(val) 86 | except (SyntaxError, ValueError): 87 | # if that goes wrong, just use the string 88 | attempt = val 89 | # ensure the types match ok 90 | assert type(attempt) == type(globals()[key]) 91 | # cross fingers 92 | print(f"Overriding: {key} = {attempt}") 93 | globals()[key] = attempt 94 | else: 95 | raise ValueError(f"Unknown config key: {key}") 96 | 97 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 98 | # ----------------------------------------------------------------------------- 99 | 100 | # various inits, derived attributes, I/O setup 101 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 102 | if ddp: 103 | init_process_group(backend=backend) 104 | ddp_rank = int(os.environ['RANK']) 105 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 106 | device = f'cuda:{ddp_local_rank}' 107 | torch.cuda.set_device(device) 108 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 109 | seed_offset = ddp_rank # each process gets a different seed 110 | else: 111 | # if not ddp, we are running on a single gpu, and one process 112 | master_process = True 113 | seed_offset = 0 114 | gradient_accumulation_steps *= 8 # simulate 8 gpus 115 | 116 | if master_process: 117 | os.makedirs(out_dir, exist_ok=True) 118 | torch.manual_seed(5000 + seed_offset) 119 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 120 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 121 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 122 | # note: float16 data type will automatically use a GradScaler 123 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 124 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype) 125 | 126 | # poor man's data loader 127 | data_dir = os.path.join('data', dataset) 128 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 129 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 130 | def get_batch(split): 131 | data = train_data if split == 'train' else val_data 132 | ix = torch.randint(len(data) - block_size, (batch_size,)) 133 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 134 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 135 | if device_type == 'cuda': 136 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 137 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 138 | else: 139 | x, y = x.to(device), y.to(device) 140 | return x, y 141 | 142 | # init these up here, can override if init_from='resume' (i.e. 
from a checkpoint) 143 | iter_num = 0 144 | best_val_loss = 1e9 145 | 146 | # attempt to derive vocab_size from the dataset 147 | meta_path = os.path.join(data_dir, 'meta.pkl') 148 | meta_vocab_size = None 149 | if os.path.exists(meta_path): 150 | with open(meta_path, 'rb') as f: 151 | meta = pickle.load(f) 152 | meta_vocab_size = meta['vocab_size'] 153 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 154 | 155 | # model init 156 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 157 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line 158 | if init_from == 'scratch': 159 | # init a new model from scratch 160 | print("Initializing a new model from scratch") 161 | # determine the vocab size we'll use for from-scratch training 162 | if meta_vocab_size is None: 163 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 164 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 165 | gptconf = GPTConfig(**model_args) 166 | model = GPT(gptconf) 167 | elif init_from == 'resume': 168 | print(f"Resuming training from {out_dir}") 169 | # resume training from a checkpoint. 170 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 171 | checkpoint = torch.load(ckpt_path, map_location=device) 172 | checkpoint_model_args = checkpoint['model_args'] 173 | # force these config attributes to be equal otherwise we can't even resume training 174 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 175 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 176 | model_args[k] = checkpoint_model_args[k] 177 | # create the model 178 | gptconf = GPTConfig(**model_args) 179 | model = GPT(gptconf) 180 | state_dict = checkpoint['model'] 181 | # fix the keys of the state dictionary :( 182 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 183 | unwanted_prefix = '_orig_mod.' 184 | for k,v in list(state_dict.items()): 185 | if k.startswith(unwanted_prefix): 186 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 187 | model.load_state_dict(state_dict) 188 | iter_num = checkpoint['iter_num'] 189 | best_val_loss = checkpoint['best_val_loss'] 190 | elif init_from.startswith('gpt2'): 191 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 192 | # initialize from OpenAI GPT-2 weights 193 | override_args = dict(dropout=dropout) 194 | model = GPT.from_pretrained(init_from, override_args) 195 | # read off the created config params, so we can store them into checkpoint correctly 196 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 197 | model_args[k] = getattr(model.config, k) 198 | # crop down the model block size if desired, using model surgery 199 | if block_size < model.config.block_size: 200 | model.crop_block_size(block_size) 201 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 202 | model.to(device) 203 | 204 | # initialize a GradScaler. 
If enabled=False scaler is a no-op 205 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 206 | 207 | # optimizer 208 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), device_type) 209 | if init_from == 'resume': 210 | optimizer.load_state_dict(checkpoint['optimizer']) 211 | del state_dict 212 | del checkpoint 213 | # compile the model 214 | if compile: 215 | print("compiling the model... (takes a ~minute)") 216 | unoptimized_model = model 217 | model = torch.compile(model) # requires PyTorch 2.0 218 | 219 | # wrap model into DDP container 220 | if ddp: 221 | model = DDP(model, device_ids=[ddp_local_rank]) 222 | 223 | # helps estimate an arbitrarily accurate loss over either split using many batches 224 | @torch.no_grad() 225 | def estimate_loss(): 226 | out = {} 227 | model.eval() 228 | for split in ['train', 'val']: 229 | losses = torch.zeros(eval_iters) 230 | for k in range(eval_iters): 231 | X, Y = get_batch(split) 232 | with ctx: 233 | logits, loss = model(X, Y) 234 | losses[k] = loss.item() 235 | out[split] = losses.mean() 236 | model.train() 237 | return out 238 | 239 | # learning rate decay scheduler (cosine with warmup) 240 | def get_lr(it, schedule='cosine'): 241 | #ing rate schedule {schedule}") 242 | # 1) linear warmup for warmup_iters steps 243 | if it < warmup_iters: 244 | return learning_rate * it / warmup_iters 245 | # 2) if it > lr_decay_iters, return min learning rate 246 | if it > lr_decay_iters: 247 | return min_lr 248 | # 3) in between, use cosine decay down to min learning rate 249 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 250 | assert 0 <= decay_ratio <= 1 251 | if schedule=='cosine': 252 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 253 | elif schedule=='exp': 254 | coeff = np.power(0.9, 100 * decay_ratio) 255 | return min_lr + coeff * (learning_rate - min_lr) 256 | 257 | # logging 258 | if wandb_log and master_process: 259 | import wandb 260 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 261 | 262 | # training loop 263 | X, Y = get_batch('train') # fetch the very first batch 264 | t0 = time.time() 265 | local_iter_num = 0 # number of iterations in the lifetime of this process 266 | raw_model = model.module if ddp else model # unwrap DDP container if needed 267 | running_mfu = -1.0 268 | clip_time = 0 269 | while True: 270 | 271 | # determine and set the learning rate for this iteration 272 | lr = get_lr(iter_num, schedule=schedule) if decay_lr else learning_rate 273 | for param_group in optimizer.param_groups: 274 | param_group['lr'] = lr 275 | 276 | # evaluate the loss on train/val sets and write checkpoints 277 | if iter_num % eval_interval == 0 and master_process: 278 | losses = estimate_loss() 279 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 280 | if wandb_log: 281 | wandb.log({ 282 | "iter": iter_num, 283 | "train/loss": losses['train'], 284 | "val/loss": losses['val'], 285 | "lr": lr, 286 | "mfu": running_mfu*100, # convert to percentage 287 | }, step=iter_num) 288 | if losses['val'] < best_val_loss or always_save_checkpoint: 289 | best_val_loss = losses['val'] 290 | if iter_num > 0: 291 | checkpoint = { 292 | 'model': raw_model.state_dict(), 293 | 'optimizer': optimizer.state_dict(), 294 | 'model_args': model_args, 295 | 'iter_num': iter_num, 296 | 'best_val_loss': best_val_loss, 297 | 'config': config, 298 | } 299 | print(f"saving checkpoint to {out_dir}") 300 | 
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 301 | if iter_num % (eval_interval * 5) == 0: 302 | checkpoint = { 303 | 'model': raw_model.state_dict(), 304 | 'optimizer': optimizer.state_dict(), 305 | 'model_args': model_args, 306 | 'iter_num': iter_num, 307 | 'best_val_loss': best_val_loss, 308 | 'config': config, 309 | } 310 | print(f"saving checkpoint to {out_dir}") 311 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt')) 312 | if iter_num == 0 and eval_only: 313 | break 314 | 315 | # forward backward update, with optional gradient accumulation to simulate larger batch size 316 | # and using the GradScaler if data type is float16 317 | for micro_step in range(gradient_accumulation_steps): 318 | if ddp: 319 | # in DDP training we only need to sync gradients at the last micro step. 320 | # the official way to do this is with model.no_sync() context manager, but 321 | # I really dislike that this bloats the code and forces us to repeat code 322 | # looking at the source of that context manager, it just toggles this variable 323 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 324 | with ctx: 325 | logits, loss = model(X, Y) 326 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 327 | X, Y = get_batch('train') 328 | # backward pass, with gradient scaling if training in fp16 329 | scaler.scale(loss).backward() 330 | # clip the gradient 331 | if grad_clip != 0.0: 332 | scaler.unscale_(optimizer) 333 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 334 | if total_norm.item() > grad_clip: 335 | clip_time += 1 336 | # step the optimizer and scaler if training in fp16 337 | scaler.step(optimizer) 338 | scaler.update() 339 | # flush the gradients as soon as we can, no need for this memory anymore 340 | optimizer.zero_grad(set_to_none=True) 341 | 342 | # timing and logging 343 | t1 = time.time() 344 | dt = t1 - t0 345 | t0 = t1 346 | if iter_num % log_interval == 0 and master_process: 347 | lossf = loss.item() # loss as float. 
note: this is a CPU-GPU sync point 348 | if local_iter_num >= 5: # let the training loop settle a bit 349 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 350 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 351 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 352 | params = [] 353 | for (name, p) in model.named_parameters(): 354 | params.append(p) 355 | total_param_norm = 0 356 | for p in params: 357 | param_norm = p.data.norm(2) 358 | total_param_norm += param_norm.item() ** 2 359 | total_param_norm = total_param_norm ** 0.5 360 | momentum_norm = 0 361 | LL = len(optimizer.state_dict()['state']) 362 | for jj in range(LL): 363 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2 364 | momentum_norm = torch.sqrt(momentum_norm).item() 365 | if wandb_log: 366 | wandb.log({ 367 | "iter": iter_num, 368 | "train/loss": lossf, 369 | "lr": lr, 370 | "param_norm": total_param_norm, 371 | "momentum_norm" : momentum_norm, 372 | "train/clip_rate": clip_time / (iter_num + 1) 373 | }, step=iter_num) 374 | iter_num += 1 375 | local_iter_num += 1 376 | 377 | # termination conditions 378 | if iter_num > max_iters: 379 | break 380 | 381 | if ddp: 382 | destroy_process_group() 383 | -------------------------------------------------------------------------------- /MARS/train_adamw_fw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import pickle 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from torch.distributed import init_process_group, destroy_process_group 11 | 12 | from model import GPTConfig, GPT 13 | import sys 14 | from ast import literal_eval 15 | # ----------------------------------------------------------------------------- 16 | # default config values designed to train a gpt2 (124M) on OpenWebText 17 | # I/O 18 | out_dir = 'out' 19 | eval_interval = 2000 20 | log_interval = 1 21 | eval_iters = 200 22 | eval_only = False # if True, script exits right after the first eval 23 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 24 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 25 | # wandb logging 26 | wandb_log = False # disabled by default 27 | wandb_project = 'mars' 28 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 29 | # data 30 | dataset = 'fineweb-edu100B' 31 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes 32 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 33 | block_size = 1024 34 | # model 35 | n_layer = 12 36 | n_head = 12 37 | n_embd = 768 38 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 39 | bias = False # do we use bias inside LayerNorm and Linear layers? 
40 | # optimizer 41 | optimizer_name = 'adamw' 42 | learning_rate = 6e-4 # max learning rate 43 | max_iters = 600000 # total number of training iterations 44 | weight_decay = 1e-1 45 | beta1 = 0.9 46 | beta2 = 0.95 47 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 48 | interval = 10 49 | variant = 4 50 | schedule='cosine' 51 | # learning rate decay settings 52 | decay_lr = True # whether to decay the learning rate 53 | warmup_iters = 2000 # how many steps to warm up for 54 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 55 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 56 | # DDP settings 57 | backend = 'nccl' # 'nccl', 'gloo', etc. 58 | # system 59 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 60 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 61 | compile = True # use PyTorch 2.0 to compile the model to be faster 62 | scale_attn_by_inverse_layer_idx = True 63 | # ----------------------------------------------------------------------------- 64 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 65 | for arg in sys.argv[1:]: 66 | if '=' not in arg: 67 | # assume it's the name of a config file 68 | assert not arg.startswith('--') 69 | config_file = arg 70 | print(f"Overriding config with {config_file}:") 71 | with open(config_file) as f: 72 | print(f.read()) 73 | exec(open(config_file).read()) 74 | else: 75 | # assume it's a --key=value argument 76 | assert arg.startswith('--') 77 | key, val = arg.split('=') 78 | key = key[2:] 79 | if key in globals(): 80 | try: 81 | # attempt to eval it it (e.g. if bool, number, or etc) 82 | attempt = literal_eval(val) 83 | except (SyntaxError, ValueError): 84 | # if that goes wrong, just use the string 85 | attempt = val 86 | # ensure the types match ok 87 | assert type(attempt) == type(globals()[key]) 88 | # cross fingers 89 | print(f"Overriding: {key} = {attempt}") 90 | globals()[key] = attempt 91 | else: 92 | raise ValueError(f"Unknown config key: {key}") 93 | 94 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 95 | # ----------------------------------------------------------------------------- 96 | 97 | # various inits, derived attributes, I/O setup 98 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 99 | if ddp: 100 | init_process_group(backend=backend) 101 | ddp_rank = int(os.environ['RANK']) 102 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 103 | device = f'cuda:{ddp_local_rank}' 104 | torch.cuda.set_device(device) 105 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
106 | seed_offset = ddp_rank # each process gets a different seed 107 | else: 108 | # if not ddp, we are running on a single gpu, and one process 109 | master_process = True 110 | seed_offset = 0 111 | gradient_accumulation_steps *= 8 # simulate 8 gpus 112 | 113 | if master_process: 114 | os.makedirs(out_dir, exist_ok=True) 115 | torch.manual_seed(5000 + seed_offset) 116 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 117 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 118 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 119 | # note: float16 data type will automatically use a GradScaler 120 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 121 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype) 122 | 123 | # poor man's data loader 124 | data_dir = os.path.join('data', dataset) 125 | train_file_list = list(filter(lambda x: x.endswith('.bin') and x.startswith('fineweb_train'), os.listdir(data_dir))) 126 | train_data_list = [np.memmap(os.path.join(data_dir, file), dtype=np.uint16, mode='r') for file in train_file_list] 127 | val_data = np.memmap(os.path.join(data_dir, 'fineweb_val_000000.bin'), dtype=np.uint16, mode='r') 128 | import random 129 | random.seed(5000 + seed_offset) 130 | def get_batch(split): 131 | if split == 'train': 132 | data = random.choice(train_data_list) 133 | else: 134 | data = val_data 135 | offset = 512 136 | ix = torch.randint(len(data) - block_size - offset, (batch_size,)) 137 | x = torch.stack([torch.from_numpy((data[offset+i:offset+i+block_size]).astype(np.int64)) for i in ix]) 138 | y = torch.stack([torch.from_numpy((data[offset+i+1:offset+i+1+block_size]).astype(np.int64)) for i in ix]) 139 | if device_type == 'cuda': 140 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 141 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 142 | else: 143 | x, y = x.to(device), y.to(device) 144 | return x, y 145 | 146 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 147 | iter_num = 0 148 | best_val_loss = 1e9 149 | 150 | # attempt to derive vocab_size from the dataset 151 | meta_path = os.path.join(data_dir, 'meta.pkl') 152 | meta_vocab_size = None 153 | if os.path.exists(meta_path): 154 | with open(meta_path, 'rb') as f: 155 | meta = pickle.load(f) 156 | meta_vocab_size = meta['vocab_size'] 157 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 158 | 159 | # model init 160 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 161 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line 162 | if init_from == 'scratch': 163 | # init a new model from scratch 164 | print("Initializing a new model from scratch") 165 | # determine the vocab size we'll use for from-scratch training 166 | if meta_vocab_size is None: 167 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 168 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 169 | gptconf = GPTConfig(**model_args) 170 | model = GPT(gptconf) 171 | elif init_from == 'resume': 172 | print(f"Resuming training from {out_dir}") 173 | # resume training from a checkpoint. 
174 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 175 | checkpoint = torch.load(ckpt_path, map_location=device) 176 | checkpoint_model_args = checkpoint['model_args'] 177 | # force these config attributes to be equal otherwise we can't even resume training 178 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 179 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 180 | model_args[k] = checkpoint_model_args[k] 181 | # create the model 182 | gptconf = GPTConfig(**model_args) 183 | model = GPT(gptconf) 184 | state_dict = checkpoint['model'] 185 | # fix the keys of the state dictionary :( 186 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 187 | unwanted_prefix = '_orig_mod.' 188 | for k,v in list(state_dict.items()): 189 | if k.startswith(unwanted_prefix): 190 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 191 | model.load_state_dict(state_dict) 192 | iter_num = checkpoint['iter_num'] 193 | best_val_loss = checkpoint['best_val_loss'] 194 | elif init_from.startswith('gpt2'): 195 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 196 | # initialize from OpenAI GPT-2 weights 197 | override_args = dict(dropout=dropout) 198 | model = GPT.from_pretrained(init_from, override_args) 199 | # read off the created config params, so we can store them into checkpoint correctly 200 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 201 | model_args[k] = getattr(model.config, k) 202 | # crop down the model block size if desired, using model surgery 203 | if block_size < model.config.block_size: 204 | model.crop_block_size(block_size) 205 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 206 | model.to(device) 207 | 208 | # initialize a GradScaler. If enabled=False scaler is a no-op 209 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 210 | 211 | # optimizer 212 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), device_type) 213 | if init_from == 'resume': 214 | optimizer.load_state_dict(checkpoint['optimizer']) 215 | del state_dict 216 | del checkpoint 217 | # compile the model 218 | if compile: 219 | print("compiling the model... 
(takes a ~minute)") 220 | unoptimized_model = model 221 | model = torch.compile(model) # requires PyTorch 2.0 222 | 223 | # wrap model into DDP container 224 | if ddp: 225 | model = DDP(model, device_ids=[ddp_local_rank]) 226 | 227 | # helps estimate an arbitrarily accurate loss over either split using many batches 228 | @torch.no_grad() 229 | def estimate_loss(): 230 | out = {} 231 | model.eval() 232 | for split in ['train', 'val']: 233 | losses = torch.zeros(eval_iters) 234 | for k in range(eval_iters): 235 | X, Y = get_batch(split) 236 | with ctx: 237 | logits, loss = model(X, Y) 238 | losses[k] = loss.item() 239 | out[split] = losses.mean() 240 | model.train() 241 | return out 242 | 243 | # learning rate decay scheduler (cosine with warmup) 244 | def get_lr(it, schedule='cosine'): 245 | #ing rate schedule {schedule}") 246 | # 1) linear warmup for warmup_iters steps 247 | if it < warmup_iters: 248 | return learning_rate * it / warmup_iters 249 | # 2) if it > lr_decay_iters, return min learning rate 250 | if schedule=='wsd': 251 | if it < 0.8 * max_iters: 252 | return learning_rate 253 | else: 254 | return learning_rate * (max_iters - it) / (max_iters * 0.2) 255 | if it > lr_decay_iters: 256 | return min_lr 257 | # 3) in between, use cosine decay down to min learning rate 258 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 259 | assert 0 <= decay_ratio <= 1 260 | if schedule=='cosine': 261 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 262 | elif schedule=='exp': 263 | coeff = np.power(0.9, 100 * decay_ratio) 264 | 265 | return min_lr + coeff * (learning_rate - min_lr) 266 | 267 | # logging 268 | if wandb_log and master_process: 269 | import wandb 270 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 271 | 272 | # training loop 273 | X, Y = get_batch('train') # fetch the very first batch 274 | t0 = time.time() 275 | local_iter_num = 0 # number of iterations in the lifetime of this process 276 | raw_model = model.module if ddp else model # unwrap DDP container if needed 277 | running_mfu = -1.0 278 | clip_time = 0 279 | while True: 280 | 281 | # determine and set the learning rate for this iteration 282 | lr = get_lr(iter_num, schedule=schedule) if decay_lr else learning_rate 283 | for param_group in optimizer.param_groups: 284 | param_group['lr'] = lr 285 | 286 | # evaluate the loss on train/val sets and write checkpoints 287 | if iter_num % eval_interval == 0 and master_process: 288 | losses = estimate_loss() 289 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 290 | if wandb_log: 291 | wandb.log({ 292 | "iter": iter_num, 293 | "train/loss": losses['train'], 294 | "val/loss": losses['val'], 295 | "lr": lr, 296 | "mfu": running_mfu*100, # convert to percentage 297 | }, step=iter_num) 298 | if losses['val'] < best_val_loss or always_save_checkpoint: 299 | best_val_loss = losses['val'] 300 | if iter_num > 0: 301 | checkpoint = { 302 | 'model': raw_model.state_dict(), 303 | 'optimizer': optimizer.state_dict(), 304 | 'model_args': model_args, 305 | 'iter_num': iter_num, 306 | 'best_val_loss': best_val_loss, 307 | 'config': config, 308 | } 309 | print(f"saving checkpoint to {out_dir}") 310 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 311 | if iter_num % (eval_interval * 5) == 0: 312 | checkpoint = { 313 | 'model': raw_model.state_dict(), 314 | 'optimizer': optimizer.state_dict(), 315 | 'model_args': model_args, 316 | 'iter_num': iter_num, 317 | 'best_val_loss': 
best_val_loss, 318 | 'config': config, 319 | } 320 | print(f"saving checkpoint to {out_dir}") 321 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt')) 322 | if iter_num == 0 and eval_only: 323 | break 324 | 325 | # forward backward update, with optional gradient accumulation to simulate larger batch size 326 | # and using the GradScaler if data type is float16 327 | for micro_step in range(gradient_accumulation_steps): 328 | if ddp: 329 | # in DDP training we only need to sync gradients at the last micro step. 330 | # the official way to do this is with model.no_sync() context manager, but 331 | # I really dislike that this bloats the code and forces us to repeat code 332 | # looking at the source of that context manager, it just toggles this variable 333 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 334 | with ctx: 335 | logits, loss = model(X, Y) 336 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 337 | X, Y = get_batch('train') 338 | # backward pass, with gradient scaling if training in fp16 339 | scaler.scale(loss).backward() 340 | # clip the gradient 341 | if grad_clip != 0.0: 342 | scaler.unscale_(optimizer) 343 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 344 | if total_norm.item() > grad_clip: 345 | clip_time += 1 346 | # step the optimizer and scaler if training in fp16 347 | scaler.step(optimizer) 348 | scaler.update() 349 | # flush the gradients as soon as we can, no need for this memory anymore 350 | optimizer.zero_grad(set_to_none=True) 351 | 352 | # timing and logging 353 | t1 = time.time() 354 | dt = t1 - t0 355 | t0 = t1 356 | if iter_num % log_interval == 0 and master_process: 357 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point 358 | if local_iter_num >= 5: # let the training loop settle a bit 359 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 360 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 361 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 362 | params = [] 363 | for (name, p) in model.named_parameters(): 364 | params.append(p) 365 | total_param_norm = 0 366 | for p in params: 367 | param_norm = p.data.norm(2) 368 | total_param_norm += param_norm.item() ** 2 369 | total_param_norm = total_param_norm ** 0.5 370 | momentum_norm = 0 371 | LL = len(optimizer.state_dict()['state']) 372 | for jj in range(LL): 373 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2 374 | momentum_norm = torch.sqrt(momentum_norm).item() 375 | if wandb_log: 376 | wandb.log({ 377 | "iter": iter_num, 378 | "train/loss": lossf, 379 | "lr": lr, 380 | "param_norm": total_param_norm, 381 | "momentum_norm" : momentum_norm, 382 | "train/clip_rate": clip_time / (iter_num + 1) 383 | }, step=iter_num) 384 | iter_num += 1 385 | local_iter_num += 1 386 | 387 | # termination conditions 388 | if iter_num > max_iters: 389 | break 390 | 391 | if ddp: 392 | destroy_process_group() 393 | --------------------------------------------------------------------------------