├── MARS_M
│   ├── scripts
│   │   ├── run_CV.sh
│   │   ├── run_mars_m_large.sh
│   │   ├── run_mars_m_medium.sh
│   │   ├── run_mars_m_small.sh
│   │   ├── run_mars_m_xl_fw.sh
│   │   ├── run_mars_m_small_fw.sh
│   │   ├── run_moonlight_large.sh
│   │   ├── run_moonlight_medium.sh
│   │   ├── run_moonlight_small.sh
│   │   ├── run_moonlight_xl_fw.sh
│   │   └── run_moonlight_small_fw.sh
│   ├── assets
│   │   ├── xl_val.png
│   │   ├── xl_train.png
│   │   ├── small_train.png
│   │   ├── small_val.png
│   │   ├── val_large.png
│   │   ├── val_medium.png
│   │   ├── val_small.png
│   │   ├── xl_val_global.png
│   │   ├── small_val_global.png
│   │   ├── val_large_global.png
│   │   ├── val_small_global.png
│   │   ├── xl_train_global.png
│   │   ├── small_train_global.png
│   │   └── val_medium_global.png
│   ├── config
│   │   ├── train_gpt2_small_moonlight.py
│   │   ├── train_gpt2_small_mars_m.py
│   │   ├── train_gpt2_xl_moonlight.py
│   │   ├── train_gpt2_large_moonlight.py
│   │   ├── train_gpt2_medium_moonlight.py
│   │   ├── train_gpt2_xl_mars_m.py
│   │   ├── train_gpt2_large_mars_m.py
│   │   └── train_gpt2_medium_mars_m.py
│   ├── utils
│   │   ├── configurator.py
│   │   ├── cv_utils.py
│   │   └── model_CNN.py
│   ├── openwebtext
│   │   └── prepare.py
│   ├── optimizers
│   │   ├── moonlight.py
│   │   └── mars_m.py
│   ├── train_CV.py
│   └── README.md
├── scripts
│   ├── run_CNN.sh
│   ├── run_CV.sh
│   ├── run_mars_large.sh
│   ├── run_mars_small.sh
│   ├── run_muon_large.sh
│   ├── run_muon_small.sh
│   ├── run_adamw_large.sh
│   ├── run_adamw_medium.sh
│   ├── run_adamw_small.sh
│   ├── run_mars_medium.sh
│   ├── run_mars_xl_fw.sh
│   ├── run_muon_medium.sh
│   ├── run_adamw_xl_fw.sh
│   ├── run_mars_small_fw.sh
│   └── run_adamw_small_fw.sh
├── assets
│   ├── MARS.png
│   ├── ShampooH.png
│   ├── xl_train.png
│   ├── xl_val.png
│   ├── MARS-AdamW.png
│   ├── MARS-Lion.png
│   ├── small_val.png
│   ├── time_large.png
│   ├── time_small.png
│   ├── val_large.png
│   ├── val_medium.png
│   ├── val_small.jpg
│   ├── val_small.png
│   ├── MARS-Shampoo.png
│   ├── fineweb_hella.png
│   ├── small_train.png
│   ├── time_medium.png
│   ├── cifar100_test_acc.png
│   ├── cifar100_test_loss.png
│   ├── cifar10_test_acc.png
│   └── cifar10_test_loss.png
├── config
│   ├── train_gpt2_small_adamw.py
│   ├── train_gpt2_small_mars.py
│   ├── train_gpt2_xl_adamw.py
│   ├── train_gpt2_large_adamw.py
│   ├── train_gpt2_medium_adamw.py
│   ├── train_gpt2_xl_mars.py
│   ├── train_gpt2_large_mars.py
│   ├── train_gpt2_medium_mars.py
│   ├── train_gpt2_small_muon.py
│   ├── train_gpt2_large_muon.py
│   └── train_gpt2_medium_muon.py
├── MARS
│   ├── utils
│   │   ├── configurator.py
│   │   ├── cv_utils.py
│   │   └── model_CNN.py
│   ├── opt.py
│   ├── optimizers
│   │   ├── muon.py
│   │   ├── adamw.py
│   │   └── mars.py
│   ├── train_CV.py
│   ├── train_CNN.py
│   ├── train_adamw.py
│   └── train_adamw_fw.py
├── data
│   └── openwebtext
│       └── prepare.py
└── LICENSE
/MARS_M/scripts/run_CV.sh:
--------------------------------------------------------------------------------
1 | python train_CV.py
--------------------------------------------------------------------------------
/scripts/run_CNN.sh:
--------------------------------------------------------------------------------
1 | python MARS/train_CNN.py
--------------------------------------------------------------------------------
/scripts/run_CV.sh:
--------------------------------------------------------------------------------
1 | python MARS/train_CV.py
--------------------------------------------------------------------------------
/assets/MARS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/MARS.png
--------------------------------------------------------------------------------
/assets/ShampooH.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/ShampooH.png
--------------------------------------------------------------------------------
/assets/xl_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/xl_train.png
--------------------------------------------------------------------------------
/assets/xl_val.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/xl_val.png
--------------------------------------------------------------------------------
/assets/MARS-AdamW.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/MARS-AdamW.png
--------------------------------------------------------------------------------
/assets/MARS-Lion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/MARS-Lion.png
--------------------------------------------------------------------------------
/assets/small_val.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/small_val.png
--------------------------------------------------------------------------------
/assets/time_large.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/time_large.png
--------------------------------------------------------------------------------
/assets/time_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/time_small.png
--------------------------------------------------------------------------------
/assets/val_large.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/val_large.png
--------------------------------------------------------------------------------
/assets/val_medium.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/val_medium.png
--------------------------------------------------------------------------------
/assets/val_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/val_small.jpg
--------------------------------------------------------------------------------
/assets/val_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/val_small.png
--------------------------------------------------------------------------------
/MARS_M/assets/xl_val.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/xl_val.png
--------------------------------------------------------------------------------
/assets/MARS-Shampoo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/MARS-Shampoo.png
--------------------------------------------------------------------------------
/assets/fineweb_hella.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/fineweb_hella.png
--------------------------------------------------------------------------------
/assets/small_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/small_train.png
--------------------------------------------------------------------------------
/assets/time_medium.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/time_medium.png
--------------------------------------------------------------------------------
/MARS_M/assets/xl_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/xl_train.png
--------------------------------------------------------------------------------
/MARS_M/assets/small_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/small_train.png
--------------------------------------------------------------------------------
/MARS_M/assets/small_val.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/small_val.png
--------------------------------------------------------------------------------
/MARS_M/assets/val_large.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_large.png
--------------------------------------------------------------------------------
/MARS_M/assets/val_medium.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_medium.png
--------------------------------------------------------------------------------
/MARS_M/assets/val_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_small.png
--------------------------------------------------------------------------------
/assets/cifar100_test_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/cifar100_test_acc.png
--------------------------------------------------------------------------------
/assets/cifar100_test_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/cifar100_test_loss.png
--------------------------------------------------------------------------------
/assets/cifar10_test_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/cifar10_test_acc.png
--------------------------------------------------------------------------------
/assets/cifar10_test_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/assets/cifar10_test_loss.png
--------------------------------------------------------------------------------
/MARS_M/assets/xl_val_global.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/xl_val_global.png
--------------------------------------------------------------------------------
/MARS_M/assets/small_val_global.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/small_val_global.png
--------------------------------------------------------------------------------
/MARS_M/assets/val_large_global.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_large_global.png
--------------------------------------------------------------------------------
/MARS_M/assets/val_small_global.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_small_global.png
--------------------------------------------------------------------------------
/MARS_M/assets/xl_train_global.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/xl_train_global.png
--------------------------------------------------------------------------------
/MARS_M/assets/small_train_global.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/small_train_global.png
--------------------------------------------------------------------------------
/MARS_M/assets/val_medium_global.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AGI-Arena/MARS/HEAD/MARS_M/assets/val_medium_global.png
--------------------------------------------------------------------------------
/scripts/run_mars_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars.py \
3 | config/train_gpt2_large_mars.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/scripts/run_mars_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars.py \
3 | config/train_gpt2_small_mars.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_muon_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_muon.py \
3 | config/train_gpt2_large_muon.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/scripts/run_muon_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_muon.py \
3 | config/train_gpt2_small_muon.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_adamw_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_large_adamw.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/scripts/run_adamw_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_medium_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_adamw_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw.py \
3 | config/train_gpt2_small_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_mars_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars.py \
3 | config/train_gpt2_medium_mars.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_mars_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars_fw.py \
3 | config/train_gpt2_xl_mars.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/scripts/run_muon_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_muon.py \
3 | config/train_gpt2_medium_muon.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/MARS_M/scripts/run_mars_m_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_mars_m.py \
3 | config/train_gpt2_large_mars_m.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/MARS_M/scripts/run_mars_m_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_mars_m.py \
3 | config/train_gpt2_medium_mars_m.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/MARS_M/scripts/run_mars_m_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_mars_m.py \
3 | config/train_gpt2_small_mars_m.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/scripts/run_adamw_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw_fw.py \
3 | config/train_gpt2_large_adamw.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/scripts/run_mars_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_mars_fw.py \
3 | config/train_gpt2_small_mars.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/MARS_M/scripts/run_mars_m_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_mars_m_fw.py \
3 | config/train_gpt2_xl_mars_m.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/scripts/run_adamw_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | MARS/train_adamw_fw.py \
3 | config/train_gpt2_small_adamw.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/MARS_M/scripts/run_mars_m_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_mars_m_fw.py \
3 | config/train_gpt2_small_mars_m.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/MARS_M/scripts/run_moonlight_large.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_moonlight.py \
3 | config/train_gpt2_large_moonlight.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
--------------------------------------------------------------------------------
/MARS_M/scripts/run_moonlight_medium.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_moonlight.py \
3 | config/train_gpt2_medium_moonlight.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/MARS_M/scripts/run_moonlight_small.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_moonlight.py \
3 | config/train_gpt2_small_moonlight.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
--------------------------------------------------------------------------------
/MARS_M/scripts/run_moonlight_xl_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_moonlight_fw.py \
3 | config/train_gpt2_large_moonlight.py \
4 | --batch_size=5 \
5 | --gradient_accumulation_steps=12
6 |
--------------------------------------------------------------------------------
/MARS_M/scripts/run_moonlight_small_fw.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 \
2 | train_moonlight_fw.py \
3 | config/train_gpt2_small_moonlight.py \
4 | --batch_size=15 \
5 | --gradient_accumulation_steps=4
6 |
--------------------------------------------------------------------------------
/config/train_gpt2_small_adamw.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-small-adamw-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 12
10 | n_head = 12
11 | n_embd = 768
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 |
15 | # this makes total number of tokens be ~50B
16 | max_iters = 100000
17 | lr_decay_iters = 100000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # optimizer
25 | optimizer_name = 'adamw'
26 | learning_rate = 6e-4 # max learning rate
27 | weight_decay = 1e-1
28 | beta1 = 0.9
29 | beta2 = 0.95
30 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
31 | # learning rate decay settings
32 | decay_lr = True # whether to decay the learning rate
33 | warmup_iters = 2000 # how many steps to warm up for
34 | min_lr = 3e-5
35 |
36 | compile = True
37 |
38 | out_dir = 'out_small_adamw_100k'
39 |
--------------------------------------------------------------------------------
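The "~50B tokens" comment in the config above can be checked directly; a minimal sketch in Python, assuming the 8-GPU launch used by scripts/run_adamw_small.sh (all other numbers are taken from this config file):

# Sanity-check the "~50B tokens" comment, assuming torchrun --nproc_per_node=8.
batch_size = 15                      # per-GPU micro-batch (sequences)
gradient_accumulation_steps = 4
n_gpus = 8                           # from the launch script, an assumption here
block_size = 1024                    # tokens per sequence
max_iters = 100_000

tokens_per_step = batch_size * gradient_accumulation_steps * n_gpus * block_size
total_tokens = tokens_per_step * max_iters
print(f"{tokens_per_step:,} tokens/step, {total_tokens:,} total")
# 491,520 tokens/step, 49,152,000,000 total (~50B)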
/MARS_M/config/train_gpt2_small_moonlight.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars-m'
3 | wandb_run_name='gpt2-small-moonlight-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 12
10 | n_head = 12
11 | n_embd = 768
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 |
15 | # this makes total number of tokens be ~50B
16 | max_iters = 100000
17 | lr_decay_iters = 100000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # optimizer
25 | optimizer_name = 'moonlight'
26 | learning_rate = 6e-3 # max learning rate
27 | weight_decay = 1e-1
28 | beta1 = 0.95
29 | beta2 = 0.99
30 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
31 | # learning rate decay settings
32 | decay_lr = True # whether to decay the learning rate
33 | warmup_iters = 2000 # how many steps to warm up for
34 | min_lr = 3e-5
35 |
36 | compile = True
37 |
38 | out_dir = 'out_small_moonlight_100k'
39 |
--------------------------------------------------------------------------------
/config/train_gpt2_small_mars.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-small-mars-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 12
10 | n_head = 12
11 | n_embd = 768
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 |
15 | # this makes total number of tokens be ~50B
16 | max_iters = 100000
17 | lr_decay_iters = 100000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # optimizer
25 | optimizer_name = 'mars'
26 | learning_rate = 6e-3 # max learning rate
27 | weight_decay = 1e-2
28 | beta1 = 0.95
29 | beta2 = 0.99
30 | lr_1d=3e-3
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 3e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_small_mars_100k'
40 | gamma=0.025
41 |
--------------------------------------------------------------------------------
/MARS_M/config/train_gpt2_small_mars_m.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars-m'
3 | wandb_run_name='gpt2-small-mars-m-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 12
10 | n_head = 12
11 | n_embd = 768
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 |
15 | # this makes total number of tokens be ~50B
16 | max_iters = 100000
17 | lr_decay_iters = 100000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # optimizer
25 | optimizer_name = 'mars-m'
26 | learning_rate = 6e-3 # max learning rate
27 | weight_decay = 1e-2
28 | beta1 = 0.95
29 | beta2 = 0.99
30 |
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 3e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_small_mars_m_100k'
40 | gamma=0.025
41 |
--------------------------------------------------------------------------------
/config/train_gpt2_xl_adamw.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-xl-adamw-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 48
10 | n_head = 25
11 | n_embd = 1600
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'adamw'
27 | learning_rate = 2e-4 # max learning rate
28 | weight_decay = 1e-1
29 | beta1 = 0.9
30 | beta2 = 0.95
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 1e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_large_adamw_100k'
40 |
--------------------------------------------------------------------------------
/config/train_gpt2_large_adamw.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-large-adamw-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 36
10 | n_head = 20
11 | n_embd = 1280
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'adamw'
27 | learning_rate = 2e-4 # max learning rate
28 | weight_decay = 1e-1
29 | beta1 = 0.9
30 | beta2 = 0.95
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 1e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_large_adamw_100k'
40 |
--------------------------------------------------------------------------------
/config/train_gpt2_medium_adamw.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-medium-adamw-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 24
10 | n_head = 16
11 | n_embd = 1024
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'adamw'
27 | learning_rate = 3e-4 # max learning rate
28 | weight_decay = 1e-1
29 | beta1 = 0.9
30 | beta2 = 0.95
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 6e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_medium_adamw_100k'
40 |
--------------------------------------------------------------------------------
/MARS_M/config/train_gpt2_xl_moonlight.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars-m'
3 | wandb_run_name='gpt2-xl-moonlight-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 48
10 | n_head = 25
11 | n_embd = 1600
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'moonlight'
27 | learning_rate = 3e-3 # max learning rate
28 | weight_decay = 1e-1
29 | beta1 = 0.95
30 | beta2 = 0.99
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 1e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_large_moonlight_100k'
40 |
--------------------------------------------------------------------------------
/MARS_M/config/train_gpt2_large_moonlight.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars-m'
3 | wandb_run_name='gpt2-large-moonlight-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 36
10 | n_head = 20
11 | n_embd = 1280
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'moonlight'
27 | learning_rate = 5e-3 # max learning rate
28 | weight_decay = 1e-1
29 | beta1 = 0.95
30 | beta2 = 0.99
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 1e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_large_moonlight_100k'
40 |
--------------------------------------------------------------------------------
/config/train_gpt2_xl_mars.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-xl-mars-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 48
10 | n_head = 25
11 | n_embd = 1600
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'mars'
27 | learning_rate = 2e-3 # max learning rate
28 | weight_decay = 1e-2
29 | beta1 = 0.95
30 | beta2 = 0.99
31 | lr_1d=1e-3
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 1e-5
37 |
38 | compile = True
39 |
40 | out_dir = 'out_large_mars_100k'
41 | gamma=0.025
42 |
--------------------------------------------------------------------------------
/MARS_M/config/train_gpt2_medium_moonlight.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars-m'
3 | wandb_run_name='gpt2-medium-moonlight-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 24
10 | n_head = 16
11 | n_embd = 1024
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'moonlight'
27 | learning_rate = 5e-3 # max learning rate
28 | weight_decay = 1e-1
29 | beta1 = 0.95
30 | beta2 = 0.99
31 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
32 | # learning rate decay settings
33 | decay_lr = True # whether to decay the learning rate
34 | warmup_iters = 2000 # how many steps to warm up for
35 | min_lr = 6e-5
36 |
37 | compile = True
38 |
39 | out_dir = 'out_medium_moonlight_100k'
40 |
--------------------------------------------------------------------------------
/MARS_M/config/train_gpt2_xl_mars_m.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars-m'
3 | wandb_run_name='gpt2-xl-mars-m-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 48
10 | n_head = 25
11 | n_embd = 1600
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'mars-m'
27 | learning_rate = 3e-3 # max learning rate
28 | weight_decay = 1e-2
29 | beta1 = 0.95
30 | beta2 = 0.99
31 |
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 1e-5
37 |
38 | compile = True
39 |
40 | out_dir = 'out_large_mars_m_100k'
41 | gamma=0.025
42 |
--------------------------------------------------------------------------------
/config/train_gpt2_large_mars.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-large-mars-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 36
10 | n_head = 20
11 | n_embd = 1280
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'mars'
27 | learning_rate = 2e-3 # max learning rate
28 | weight_decay = 1e-2
29 | beta1 = 0.95
30 | beta2 = 0.99
31 | lr_1d=1e-3
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 1e-5
37 |
38 | compile = True
39 |
40 | out_dir = 'out_large_mars_100k'
41 | gamma=0.025
42 |
--------------------------------------------------------------------------------
/MARS_M/config/train_gpt2_large_mars_m.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars-m'
3 | wandb_run_name='gpt2-large-mars-m-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 36
10 | n_head = 20
11 | n_embd = 1280
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'mars-m'
27 | learning_rate = 5e-3 # max learning rate
28 | weight_decay = 1e-2
29 | beta1 = 0.95
30 | beta2 = 0.99
31 |
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 1e-5
37 |
38 | compile = True
39 |
40 | out_dir = 'out_large_mars_m_100k'
41 | gamma=0.025
42 |
--------------------------------------------------------------------------------
/config/train_gpt2_medium_mars.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-medium-mars-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 24
10 | n_head = 16
11 | n_embd = 1024
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'mars'
27 | learning_rate = 3e-3 # max learning rate
28 | weight_decay = 1e-2
29 | beta1 = 0.95
30 | beta2 = 0.99
31 | lr_1d=1.5e-3
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 6e-5
37 |
38 | compile = True
39 |
40 | out_dir = 'out_medium_mars_100k'
41 | gamma=0.025
42 |
--------------------------------------------------------------------------------
/MARS_M/config/train_gpt2_medium_mars_m.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars-m'
3 | wandb_run_name='gpt2-medium-mars-m-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 24
10 | n_head = 16
11 | n_embd = 1024
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'mars-m'
27 | learning_rate = 5e-3 # max learning rate
28 | weight_decay = 1e-2
29 | beta1 = 0.95
30 | beta2 = 0.99
31 |
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 6e-5
37 |
38 | compile = True
39 |
40 | out_dir = 'out_medium_mars_m_100k'
41 | gamma=0.025
42 |
--------------------------------------------------------------------------------
/config/train_gpt2_small_muon.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-small-muon-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 12
10 | n_head = 12
11 | n_embd = 768
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 |
15 | # this makes total number of tokens be ~50B
16 | max_iters = 100000
17 | lr_decay_iters = 100000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # optimizer
25 | optimizer_name = 'muon'
26 | learning_rate = 3e-3 # max learning rate, original=6e-4
27 | weight_decay = 1e-1
28 | muon_learning_rate = 2e-2
29 | muon_weight_decay = 0.
30 | beta1 = 0.9
31 | beta2 = 0.95
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 3e-5
37 | schedule = 'cosine'
38 | compile = True
39 |
40 | out_dir = 'out_small_muon_100k'
41 |
--------------------------------------------------------------------------------
/config/train_gpt2_large_muon.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-large-muon-100k'
4 |
5 | batch_size = 5
6 | block_size = 1024
7 | gradient_accumulation_steps = 12
8 |
9 | n_layer = 36
10 | n_head = 20
11 | n_embd = 1280
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'muon'
27 | learning_rate = 1e-3 # max learning rate
28 | weight_decay = 1e-1
29 | muon_learning_rate = 6.67e-3
30 | muon_weight_decay = 0.
31 | beta1 = 0.9
32 | beta2 = 0.95
33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
34 | # learning rate decay settings
35 | decay_lr = True # whether to decay the learning rate
36 | warmup_iters = 2000 # how many steps to warm up for
37 | min_lr = 1e-5
38 |
39 | compile = True
40 |
41 | out_dir = 'out_large_muon_100k'
42 |
--------------------------------------------------------------------------------
/config/train_gpt2_medium_muon.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'mars'
3 | wandb_run_name='gpt2-medium-muon-100k'
4 |
5 | batch_size = 15
6 | block_size = 1024
7 | gradient_accumulation_steps = 4
8 |
9 | n_layer = 24
10 | n_head = 16
11 | n_embd = 1024
12 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
13 | bias = False
14 | scale_attn_by_inverse_layer_idx = True
15 |
16 | # this makes total number of tokens be ~50B
17 | max_iters = 100000
18 | lr_decay_iters = 100000
19 |
20 | # eval stuff
21 | eval_interval = 1000
22 | eval_iters = 200
23 | log_interval = 10
24 |
25 | # optimizer
26 | optimizer_name = 'muon'
27 | learning_rate = 1.5e-3 # max learning rate
28 | weight_decay = 1e-1
29 | muon_learning_rate = 1e-2
30 | muon_weight_decay = 0.
31 | beta1 = 0.9
32 | beta2 = 0.95
33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
34 | # learning rate decay settings
35 | decay_lr = True # whether to decay the learning rate
36 | warmup_iters = 2000 # how many steps to warm up for
37 | min_lr = 6e-5
38 |
39 | compile = True
40 |
41 | out_dir = 'out_medium_muon_100k'
42 |
--------------------------------------------------------------------------------
/MARS/utils/configurator.py:
--------------------------------------------------------------------------------
1 | """
2 | Poor Man's Configurator. Probably a terrible idea. Example usage:
3 | $ python train.py config/override_file.py --batch_size=32
4 | this will first run config/override_file.py, then override batch_size to 32
5 |
6 | The code in this file will be run as follows from e.g. train.py:
7 | >>> exec(open('configurator.py').read())
8 |
9 | So it's not a Python module, it's just shuttling this code away from train.py
10 | The code in this script then overrides the globals()
11 |
12 | I know people are not going to love this, I just really dislike configuration
13 | complexity and having to prepend config. to every single variable. If someone
14 | comes up with a better simple Python solution I am all ears.
15 | """
16 |
17 | import sys
18 | from ast import literal_eval
19 |
20 | for arg in sys.argv[1:]:
21 | if '=' not in arg:
22 | # assume it's the name of a config file
23 | assert not arg.startswith('--')
24 | config_file = arg
25 | print(f"Overriding config with {config_file}:")
26 | with open(config_file) as f:
27 | print(f.read())
28 | exec(open(config_file).read())
29 | else:
30 | # assume it's a --key=value argument
31 | assert arg.startswith('--')
32 | key, val = arg.split('=')
33 | key = key[2:]
34 | if key in globals():
35 | try:
36 |                 # attempt to eval it (e.g. if bool, number, etc.)
37 | attempt = literal_eval(val)
38 | except (SyntaxError, ValueError):
39 | # if that goes wrong, just use the string
40 | attempt = val
41 | # ensure the types match ok
42 | assert type(attempt) == type(globals()[key])
43 | # cross fingers
44 | print(f"Overriding: {key} = {attempt}")
45 | globals()[key] = attempt
46 | else:
47 | raise ValueError(f"Unknown config key: {key}")
48 |
--------------------------------------------------------------------------------
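As the docstring above explains, this file is not imported but exec'd from the training script, so that config files and --key=value flags can override module-level globals. A minimal sketch of that calling pattern; the script name, the default values, and the relative path to configurator.py are illustrative, not taken from this repo's training scripts:

# toy_train.py -- sketch of the nanoGPT-style calling convention described in the docstring.
# Defaults live as plain module-level globals:
batch_size = 12
learning_rate = 6e-4
out_dir = 'out'

# The configurator is exec'd (not imported), so it can rewrite the globals above
# from a config file and/or --key=value overrides on the command line:
exec(open('utils/configurator.py').read())   # path depends on where the script is run from

print(batch_size, learning_rate, out_dir)

# Example invocation (mirrors the docstring):
#   python toy_train.py config/train_gpt2_small_adamw.py --batch_size=8
# first exec's the config file, then overrides batch_size to 8.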
/MARS_M/utils/configurator.py:
--------------------------------------------------------------------------------
1 | """
2 | Poor Man's Configurator. Probably a terrible idea. Example usage:
3 | $ python train.py config/override_file.py --batch_size=32
4 | this will first run config/override_file.py, then override batch_size to 32
5 |
6 | The code in this file will be run as follows from e.g. train.py:
7 | >>> exec(open('configurator.py').read())
8 |
9 | So it's not a Python module, it's just shuttling this code away from train.py
10 | The code in this script then overrides the globals()
11 |
12 | I know people are not going to love this, I just really dislike configuration
13 | complexity and having to prepend config. to every single variable. If someone
14 | comes up with a better simple Python solution I am all ears.
15 | """
16 |
17 | import sys
18 | from ast import literal_eval
19 |
20 | for arg in sys.argv[1:]:
21 | if '=' not in arg:
22 | # assume it's the name of a config file
23 | assert not arg.startswith('--')
24 | config_file = arg
25 | print(f"Overriding config with {config_file}:")
26 | with open(config_file) as f:
27 | print(f.read())
28 | exec(open(config_file).read())
29 | else:
30 | # assume it's a --key=value argument
31 | assert arg.startswith('--')
32 | key, val = arg.split('=')
33 | key = key[2:]
34 | if key in globals():
35 | try:
36 |                 # attempt to eval it (e.g. if bool, number, etc.)
37 | attempt = literal_eval(val)
38 | except (SyntaxError, ValueError):
39 | # if that goes wrong, just use the string
40 | attempt = val
41 | # ensure the types match ok
42 | assert type(attempt) == type(globals()[key])
43 | # cross fingers
44 | print(f"Overriding: {key} = {attempt}")
45 | globals()[key] = attempt
46 | else:
47 | raise ValueError(f"Unknown config key: {key}")
48 |
--------------------------------------------------------------------------------
/MARS_M/openwebtext/prepare.py:
--------------------------------------------------------------------------------
1 | # saves the openwebtext dataset to a binary file for training. following was helpful:
2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
3 |
4 | import os
5 | from tqdm import tqdm
6 | import numpy as np
7 | import tiktoken
8 | from datasets import load_dataset # huggingface datasets
9 |
10 | # number of workers in .map() call
11 | # a good number to use is roughly half the number of cpu cores
12 | num_proc = 52
13 |
14 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
15 | dataset = load_dataset("openwebtext", cache_dir="nanoGPT/cache")
16 |
17 | # owt by default only contains the 'train' split, so create a test split
18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
19 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val
20 |
21 | # this results in:
22 | # >>> split_dataset
23 | # DatasetDict({
24 | # train: Dataset({
25 | # features: ['text'],
26 | # num_rows: 8009762
27 | # })
28 | # val: Dataset({
29 | # features: ['text'],
30 | # num_rows: 4007
31 | # })
32 | # })
33 |
34 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
35 | enc = tiktoken.get_encoding("gpt2")
36 | def process(example):
37 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
38 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
39 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
40 | out = {'ids': ids, 'len': len(ids)}
41 | return out
42 |
43 | # tokenize the dataset
44 | tokenized = split_dataset.map(
45 | process,
46 | remove_columns=['text'],
47 | desc="tokenizing the splits",
48 | num_proc=num_proc,
49 | )
50 | print('tokenization finished')
51 | # concatenate all the ids in each dataset into one large file we can use for training
52 | for split, dset in tokenized.items():
53 | arr_len = np.sum(dset['len'])
54 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
55 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
56 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
57 |
58 | print(f"writing {filename}...")
59 | idx = 0
60 | for example in tqdm(dset):
61 | arr[idx : idx + example['len']] = example['ids']
62 | idx += example['len']
63 | arr.flush()
64 |
65 | # train.bin is ~17GB, val.bin ~8.5MB
66 | # train has ~9B tokens (9,035,582,198)
67 | # val has ~4M tokens (4,434,897)
68 |
69 | # to read the bin files later, e.g. with numpy:
70 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r')
71 |
--------------------------------------------------------------------------------
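The closing comment above shows how to memmap the resulting .bin files; a minimal sketch of turning them into (input, target) batches in the usual nanoGPT style. The batch and block sizes here are illustrative and not tied to any config in this repo:

import numpy as np
import torch

block_size = 1024   # context length
batch_size = 8

# train.bin / val.bin are flat arrays of uint16 GPT-2 token ids
data = np.memmap('train.bin', dtype=np.uint16, mode='r')

ix = torch.randint(len(data) - block_size, (batch_size,))
x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])
# x: (batch_size, block_size) inputs; y: the same sequences shifted by one token (next-token targets)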
/data/openwebtext/prepare.py:
--------------------------------------------------------------------------------
1 | # saves the openwebtext dataset to a binary file for training. following was helpful:
2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
3 |
4 | import os
5 | from tqdm import tqdm
6 | import numpy as np
7 | import tiktoken
8 | from datasets import load_dataset # huggingface datasets
9 |
10 | # number of workers in .map() call
11 | # a good number to use is roughly half the number of cpu cores
12 | num_proc = 52
13 |
14 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
15 | dataset = load_dataset("openwebtext", cache_dir="nanoGPT/cache")
16 |
17 | # owt by default only contains the 'train' split, so create a test split
18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
19 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val
20 |
21 | # this results in:
22 | # >>> split_dataset
23 | # DatasetDict({
24 | # train: Dataset({
25 | # features: ['text'],
26 | # num_rows: 8009762
27 | # })
28 | # val: Dataset({
29 | # features: ['text'],
30 | # num_rows: 4007
31 | # })
32 | # })
33 |
34 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
35 | enc = tiktoken.get_encoding("gpt2")
36 | def process(example):
37 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
38 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
39 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
40 | out = {'ids': ids, 'len': len(ids)}
41 | return out
42 |
43 | # tokenize the dataset
44 | tokenized = split_dataset.map(
45 | process,
46 | remove_columns=['text'],
47 | desc="tokenizing the splits",
48 | num_proc=num_proc,
49 | )
50 | print('tokenization finished')
51 | # concatenate all the ids in each dataset into one large file we can use for training
52 | for split, dset in tokenized.items():
53 | arr_len = np.sum(dset['len'])
54 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
55 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
56 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
57 |
58 | print(f"writing {filename}...")
59 | idx = 0
60 | for example in tqdm(dset):
61 | arr[idx : idx + example['len']] = example['ids']
62 | idx += example['len']
63 | arr.flush()
64 |
65 | # train.bin is ~17GB, val.bin ~8.5MB
66 | # train has ~9B tokens (9,035,582,198)
67 | # val has ~4M tokens (4,434,897)
68 |
69 | # to read the bin files later, e.g. with numpy:
70 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r')
71 |
--------------------------------------------------------------------------------
/MARS/opt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import collections
3 | """
4 | Adapted from askerlee@github: https://github.com/KellerJordan/modded-nanogpt/issues/9
5 | """
6 | def separate_params(param_groups):
7 | param_groups_2d = []
8 | param_groups_non2d = []
9 | total_param_2d_count = 0
10 | total_param_non2d_count = 0
11 |
12 |
13 | # Convert iterators to lists
14 | if isinstance(param_groups, collections.abc.Iterable):
15 | param_groups = list(param_groups)
16 |
17 | # Check if param_groups is a list of dicts or list of params
18 | if (isinstance(param_groups, list) and isinstance(param_groups[0], dict)) \
19 | or isinstance(param_groups, dict):
20 | if isinstance(param_groups, dict):
21 | param_groups = [param_groups]
22 | # param_groups is a list of dicts
23 | for group in param_groups:
24 | params_2d, params_non2d, param_2d_count, param_non2d_count = separate_params(group['params'])
25 | param_group_2d = {'params': params_2d}
26 | param_group_non2d = {'params': params_non2d}
27 | # Copy the group dict and replace the 'params' key with the separated params
28 | for k in group.keys():
29 | if k != 'params':
30 | param_group_2d[k] = group[k]
31 | param_group_non2d[k] = group[k]
32 |
33 | param_groups_2d.append(param_group_2d)
34 | param_groups_non2d.append(param_group_non2d)
35 | total_param_2d_count += param_2d_count
36 | total_param_non2d_count += param_non2d_count
37 |
38 | return param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count
39 |
40 | elif isinstance(param_groups, list) and isinstance(param_groups[0], torch.Tensor):
41 | params_2d = []
42 | params_non2d = []
43 | param_group = param_groups
44 | # param_group is a list of param tensors
45 | for param in param_group:
46 | if param.ndim >= 2:
47 | params_2d.append(param)
48 | else:
49 | params_non2d.append(param)
50 | return params_2d, params_non2d, len(params_2d), len(params_non2d)
51 | else:
52 |         breakpoint()  # unexpected param_groups type; drop into the debugger to inspect it
53 |
54 | '''
55 | # CombinedOptimizer is now a torch.optim.Optimizer, compatible with pytorch lightning.
56 | # Original Example:
57 | optimizer = CombinedOptimizer([
58 | torch.optim.AdamW(self.lm_head.parameters(), lr=learning_rate, betas=betas, weight_decay=0, fused=True),
59 | OrthogonalNesterov(self.transformer.h.parameters(), lr=0.1*learning_rate, momentum=0.95)
60 | ])
61 | # Refactored Example:
62 | optimizer = CombinedOptimizer(\
63 | self.parameters(),
64 | [OrthogonalNesterov, torch.optim.AdamW],
65 | [{'lr': 0.1*learning_rate, 'momentum': 0.95},
66 | {'lr': learning_rate, 'betas': betas, 'weight_decay': 0, 'fused': True}
67 | ])
68 | '''
69 |
70 | class CombinedOptimizer(torch.optim.Optimizer):
71 | def __init__(self, params, optimizer_types, configs, raw_model = False):
72 | # Separate 2D and non-2D parameters.
73 | # If params is a list of tensors, then each of param_groups_2d and param_groups_non2d
74 | # will be a list of tensors.
75 | # If params is a list of dicts, then each of param_groups_2d and param_groups_non2d
76 | # will be a list of dicts.
77 | # If params is a dict, then each of param_groups_2d and param_groups_non2d will
78 | # be a list of dicts containing only one dict.
79 | if raw_model:
80 | params_others = list(params.transformer.h.parameters())
81 | param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count \
82 | = separate_params(params_others)
83 | param_groups_non2d.extend(list(params.lm_head.parameters()))
84 | total_param_non2d_count += 2
85 | else:
86 | param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count \
87 | = separate_params(params)
88 | param_groups_2d_non2d = (param_groups_non2d, param_groups_2d)
89 | print(f"Total 2D params: {total_param_2d_count}, Total non-2D params: {total_param_non2d_count}")
90 |
91 | assert len(optimizer_types) == len(configs) == 2
92 | self.optimizers = [ optimizer_types[i](param_groups_2d_non2d[i], **configs[i]) for i in range(2) ]
93 | self.param_groups = [pg for opt in self.optimizers for pg in opt.param_groups]
94 | self.base_lrs = [opt.param_groups[0]['lr'] for opt in self.optimizers]
95 | # Combine the state dicts of all opt in self.optimizers into a single dict
96 | self.state = {k: v for opt in self.optimizers for k, v in opt.state.items()}
97 | # Initially all states are empty. So no point to print their counts.
98 | # Only use the defaults of the OrthogonalNesterov optimizer
99 | self.defaults = self.optimizers[0].defaults
100 |
101 | def step(self, *args, **kwargs):
102 | for opt in self.optimizers:
103 | opt.step(*args, **kwargs)
104 |
105 | def zero_grad(self, **kwargs):
106 | for opt in self.optimizers:
107 | opt.zero_grad(**kwargs)
108 |
109 | def scale_lrs(self, lr_scale):
110 | for base_lr, opt in zip(self.base_lrs, self.optimizers):
111 | opt.param_groups[0]['lr'] = base_lr * lr_scale
112 |
113 | def state_dict(self):
114 | return [opt.state_dict() for opt in self.optimizers]
--------------------------------------------------------------------------------
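A minimal usage sketch of CombinedOptimizer, mirroring how MARS/train_CV.py constructs it (imports assume the MARS/ directory is on the Python path, as when the run scripts launch train_CV.py); the toy model and learning rates are illustrative, not the repository's training configuration, and Muon's step() itself expects a CUDA device:

import torch
from opt import CombinedOptimizer
from optimizers.adamw import AdamW
from optimizers.muon import Muon

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
# The first (optimizer, config) pair receives the non-2D parameters, the second the 2D ones.
optimizer = CombinedOptimizer(
    model.parameters(),
    [AdamW, Muon],
    [{'lr': 3e-3, 'betas': (0.9, 0.95), 'weight_decay': 0.0},   # biases and other 1D tensors
     {'lr': 0.02, 'weight_decay': 0.0}],                        # weight matrices
)
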
/MARS/optimizers/muon.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from KellerJordan/modded-nanogpt: https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt2.py
3 | """
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import os
8 |
9 | def zeropower_via_svd(G, steps=None):
10 | U, S, Vh = torch.linalg.svd(G, full_matrices=False)
11 | return U @ Vh
12 |
13 | @torch.compile
14 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
15 | """
16 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
17 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
18 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
19 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
20 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
21 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
22 | performance at all relative to UV^T, where USV^T = G is the SVD.
23 | """
24 | assert len(G.shape) == 2
25 | a, b, c = (3.4445, -4.7750, 2.0315)
26 | X = G.bfloat16()
27 | X /= (X.norm() + eps) # ensure top singular value <= 1
28 | if G.size(0) > G.size(1):
29 | X = X.T
30 | for _ in range(steps):
31 | A = X @ X.T
32 | B = A @ X
33 | X = a * X + b * B + c * A @ B
34 | if G.size(0) > G.size(1):
35 | X = X.T
36 | return X
37 |
38 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
39 |
40 | class Muon(torch.optim.Optimizer):
41 | """
42 | Muon - MomentUm Orthogonalized by Newton-schulz
43 |
44 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
45 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
46 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
47 | the advantage that it can be stably run in bfloat16 on the GPU.
48 |
49 | Some warnings:
50 | - This optimizer assumes that all parameters passed in are 2D.
51 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
52 | parameters; those should all be optimized by a standard method (e.g., AdamW).
53 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
54 | - We believe it is unlikely to work well for training with small batch size.
55 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
56 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
57 |
58 | Arguments:
59 | lr: The learning rate used by the internal SGD.
60 | momentum: The momentum used by the internal SGD.
61 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
62 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
63 | backend_steps: The number of iteration steps to use in the backend, if it is iterative.
64 | """
65 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
66 | backend='newtonschulz5', backend_steps=5, weight_decay=0.):
67 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps, weight_decay=weight_decay)
68 | super().__init__(params, defaults)
69 | if 'WORLD_SIZE' in os.environ:
70 | self.world_size = int(os.environ['WORLD_SIZE'])
71 | self.rank = int(os.environ['RANK'])
72 | else:
73 | self.world_size = 1
74 | self.rank = 0
75 |
76 | def step(self):
77 |
78 | for group in self.param_groups:
79 |
80 | lr = group['lr']
81 | weight_decay = group['weight_decay']
82 | momentum = group['momentum']
83 | zeropower_backend = zeropower_backends[group['backend']]
84 |
85 | # generate weight updates in distributed fashion
86 | total_params = sum(p.numel() for p in group['params'])
87 | updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16)
88 | curr_idx = 0
89 | for i, p in enumerate(group['params']):
90 | # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs
91 | if i % int(self.world_size) == int(self.rank):
92 | g = p.grad
93 | assert g is not None
94 | if g.ndim > 2:
95 | g = g.view(g.size(0), -1)
96 | state = self.state[p]
97 | if 'momentum_buffer' not in state:
98 | state['momentum_buffer'] = torch.zeros_like(g)
99 | buf = state['momentum_buffer']
100 | buf.mul_(momentum).add_(g)
101 | if group['nesterov']:
102 | g = g.add(buf, alpha=momentum)
103 | g = zeropower_backend(g, steps=group['backend_steps'])
104 | g *= max(1, g.size(0)/g.size(1))**0.5
105 | updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
106 | curr_idx += p.numel()
107 |
108 | # sync updates across devices. we are not memory-constrained so can do this simple deserialization
109 | if self.world_size > 1:
110 | dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
111 |
112 | # deserialize and apply updates
113 | curr_idx = 0
114 | for p in group['params']:
115 | g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
116 | p.data.mul_(1.-lr*weight_decay).add_(g, alpha=-lr)
117 | curr_idx += p.numel()
--------------------------------------------------------------------------------
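A small sanity check (not part of the repository) comparing the two orthogonalization backends defined above on a random matrix; it only verifies that the Newton-Schulz output has singular values in roughly the [0.5, 1.5] band described in the docstring:

import torch
from optimizers.muon import zeropower_via_svd, zeropower_via_newtonschulz5

G = torch.randn(256, 128)
exact = zeropower_via_svd(G)                        # U @ V^T from the SVD, singular values exactly 1
approx = zeropower_via_newtonschulz5(G, steps=10)   # quintic Newton-Schulz iteration in bfloat16
s = torch.linalg.svdvals(approx.float())
print('singular value range:', s.min().item(), s.max().item())
print('relative deviation from exact:', (approx.float() - exact).norm().item() / exact.norm().item())
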
/MARS/optimizers/adamw.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 | # from megatron.optimizer.l2_norm import l2_norm
5 |
6 | def exists(val):
7 | return val is not None
8 |
9 |
10 | class AdamW(Optimizer):
11 | """Implements Adam algorithm.
12 |
13 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
14 |
15 | Arguments:
16 | params (iterable): iterable of parameters to optimize or dicts defining
17 | parameter groups
18 | lr (float, optional): learning rate (default: 1e-3)
19 | betas (Tuple[float, float], optional): coefficients used for computing
20 | running averages of gradient and its square (default: (0.9, 0.999))
21 | eps (float, optional): term added to the denominator to improve
22 | numerical stability (default: 1e-8)
23 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
25 | algorithm from the paper `On the Convergence of Adam and Beyond`_
26 |
27 | .. _Adam\: A Method for Stochastic Optimization:
28 | https://arxiv.org/abs/1412.6980
29 | .. _On the Convergence of Adam and Beyond:
30 | https://openreview.net/forum?id=ryQu7f-RZ
31 | """
32 |
33 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
34 | weight_decay=0, amsgrad=False):
35 | if not 0.0 <= lr:
36 | raise ValueError("Invalid learning rate: {}".format(lr))
37 | if not 0.0 <= eps:
38 | raise ValueError("Invalid epsilon value: {}".format(eps))
39 | if not 0.0 <= betas[0] < 1.0:
40 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
41 | if not 0.0 <= betas[1] < 1.0:
42 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
43 | defaults = dict(lr=lr, betas=betas, eps=eps,
44 | weight_decay=weight_decay, amsgrad=amsgrad)
45 | super(AdamW, self).__init__(params, defaults)
46 | self.eps = eps
47 |
48 | def __setstate__(self, state):
49 | super(AdamW, self).__setstate__(state)
50 | for group in self.param_groups:
51 | group.setdefault('amsgrad', False)
52 |
53 | @torch.no_grad()
54 | def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None):
55 | """Performs a single optimization step.
56 |
57 | Arguments:
58 | closure (callable, optional): A closure that reevaluates the model
59 | and returns the loss.
60 | """
61 | if any(p is not None for p in [grads, output_params, scale, grad_norms]):
62 | raise RuntimeError('This AdamW implementation should be initialized like torch.optim.AdamW, and step() should be called with no extra arguments.')
63 |
64 | loss = None
65 | if exists(closure):
66 | with torch.enable_grad():
67 | loss = closure()
68 | real_update = 0
69 | real_update_wo_lr = 0
70 |
71 | for group in self.param_groups:
72 | for p in filter(lambda p: exists(p.grad), group['params']):
73 | if p.grad is None:
74 | continue
75 | grad = p.grad.data
76 | if grad.is_sparse:
77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
78 | amsgrad = group['amsgrad']
79 |
80 | state = self.state[p]
81 | #print('----- starting a parameter state', state.keys(), 'Length of state', len(state))
82 | # State initialization
83 | if len(state) == 0:
84 | state['step'] = 0
85 | # Exponential moving average of gradient values
86 | state['exp_avg'] = torch.zeros_like(p.data)
87 | # Exponential moving average of squared gradient values
88 | state['exp_avg_sq'] = torch.zeros_like(p.data)
89 | if amsgrad:
90 | # Maintains max of all exp. moving avg. of sq. grad. values
91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
92 |
93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
94 | if amsgrad:
95 | max_exp_avg_sq = state['max_exp_avg_sq']
96 | beta1, beta2 = group['betas']
97 |
98 | if 'step' in state:
99 | state['step'] += 1
100 | else:
101 | state['step'] = 1
102 |
103 | # Decay the first and second moment running average coefficient
104 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
105 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
106 | if amsgrad:
107 | # Maintains the maximum of all 2nd moment running avg. till now
108 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
109 | # Use the max. for normalizing running avg. of gradient
110 | denom = max_exp_avg_sq.sqrt().add_(self.eps)
111 | else:
112 | denom = exp_avg_sq.sqrt().add_(self.eps)
113 |
114 | bias_correction1 = 1 - beta1 ** state['step']
115 | bias_correction2 = 1 - beta2 ** state['step']
116 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
117 |
118 | # p.data.addcdiv_(-step_size, exp_avg, denom)
119 | real_update_tmp = -step_size * torch.mul(p.data, group['weight_decay']).addcdiv_(exp_avg, denom, value=1)
120 | real_update_wo_lr_tmp = torch.mul(p.data, group['weight_decay']).addcdiv_(exp_avg, denom, value=1)
121 |
122 | p.data.add_(real_update_tmp)
123 | return loss
124 |
125 |
126 |
127 |
--------------------------------------------------------------------------------
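For reference, the step above applies p <- p - step_size * (weight_decay * p + m / (sqrt(v) + eps)) with step_size = lr * sqrt(1 - beta2^t) / (1 - beta1^t), so the weight-decay term is multiplied by the bias-corrected step size rather than by lr directly as in torch.optim.AdamW. A toy usage sketch of my own, with illustrative hyperparameters:

import torch
from optimizers.adamw import AdamW

w = torch.nn.Parameter(torch.randn(10))
opt = AdamW([w], lr=1e-2, betas=(0.9, 0.999), weight_decay=0.01)
for _ in range(200):
    opt.zero_grad()
    loss = (w ** 2).sum()    # simple quadratic; the minimizer is w = 0
    loss.backward()
    opt.step()
print(loss.item())
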
/MARS/utils/cv_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | import torchvision
5 | from torchvision import datasets, transforms
6 | from torch.utils.data import DataLoader
7 |
8 | def get_model(args):
9 | """
10 | models including:
11 | - VGG16
12 | - resnet18
13 | from https://github.com/iShohei220/adopt/blob/main/adopt.py and https://github.com/uclaml/Padam/blob/master/models/resnet.py
14 | """
15 | if args.dataset in ['mnist', 'cifar10']:
16 | num_classes = 10
17 | elif args.dataset in ['cifar100']:
18 | num_classes = 100
19 | else:
20 | raise NotImplementedError(f"{args.dataset} is not implemented.")
21 | if args.net == 'simple_cnn':
22 | from .model_CNN import Network
23 | model_config = {
24 | "n_inputs": (3, 32, 32) if args.dataset == "cifar10" else (1, 28, 28),
25 | "conv_layers_list": [
26 | {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True},
27 | {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True},
28 | {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True},
29 | ],
30 | "n_hiddens_list": [512],
31 | "n_outputs": 10,
32 | "dropout": 0.2,
33 | }
34 | model = Network(**model_config)
35 | elif args.net == 'resnet18':
36 | from .model_CNN import ResNet18
37 | model = ResNet18(num_classes = num_classes)
38 | else:
39 | try:
40 | model = torchvision.models.get_model(args.net, num_classes=num_classes)
41 | except Exception:
42 | print(f'Model {args.net} not found')
43 | raise NotImplementedError(f"Network {args.net} is not implemented.")
44 | return model
45 |
46 | def get_datasets(dataset_name: str, train_batch_size: int, eval_batch_size: int):
47 | """Get train and test dataloaders."""
48 | print('==> Preparing data..')
49 | if dataset_name == "mnist":
50 | transform = transforms.Compose([
51 | transforms.ToTensor(),
52 | transforms.Normalize((0.1307,), (0.3081,))
53 | ])
54 | train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
55 | test_dataset = datasets.MNIST('./data', train=False, transform=transform)
56 | elif dataset_name == "cifar10":
57 | transform_train = transforms.Compose([
58 | transforms.RandomCrop(32, padding=4),
59 | transforms.RandomHorizontalFlip(),
60 | transforms.ToTensor(),
61 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
62 | ])
63 | transform_test = transforms.Compose([
64 | transforms.ToTensor(),
65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
66 | ])
67 | train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
68 | test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test)
69 | elif dataset_name == "cifar100":
70 | transform_train = transforms.Compose([
71 | transforms.RandomCrop(32, padding=4),
72 | transforms.RandomHorizontalFlip(),
73 | transforms.ToTensor(),
74 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]),
75 | ])
76 | transform_test = transforms.Compose([
77 | transforms.ToTensor(),
78 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]),
79 | ])
80 | train_dataset = datasets.CIFAR100('./data', train=True, download=True, transform=transform_train)
81 | test_dataset = datasets.CIFAR100('./data', train=False, transform=transform_test)
82 | else:
83 | raise NotImplementedError(f"{dataset_name=} is not implemented.")
84 |
85 | train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=4)
86 | test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=4)
87 |
88 | return train_loader, test_loader
89 |
90 |
91 | class WarmupCosineScheduler:
92 | """Custom learning rate scheduler with linear warmup and cosine decay."""
93 | def __init__(self, optimizer, warmup_iters: int, total_iters: int, min_lr=0.):
94 | self.optimizer = optimizer
95 | self.warmup_iters = warmup_iters
96 | self.total_iters = total_iters
97 | self.min_lr = min_lr
98 | self.max_lr_list = []
99 | for param_group in self.optimizer.param_groups:
100 | self.max_lr_list.append(param_group['lr'])
101 | self.current_iter = 0
102 | self.lr_list = []
103 | for param_group in self.optimizer.param_groups:
104 | self.lr_list.append(param_group['lr'])
105 |
106 | def step(self):
107 | self.current_iter += 1
108 | lr_list = []
109 | cnt = 0
110 | for param_group in self.optimizer.param_groups:
111 | max_lr = self.max_lr_list[cnt]
112 | if self.current_iter <= self.warmup_iters:
113 | lr = self.current_iter / self.warmup_iters * max_lr
114 | else:
115 | lr = self.min_lr + 0.5 * (max_lr - self.min_lr) * (
116 | np.cos((self.current_iter - self.warmup_iters) / (self.total_iters - self.warmup_iters) * 3.14159265 / 2)
117 | ).item()
118 | param_group['lr'] = lr
119 | cnt += 1
120 | lr_list.append(lr)
121 | self.lr_list = lr_list
122 | def get_lr(self):
123 | lr_list = []
124 | for param_group in self.optimizer.param_groups:
125 | lr_list.append(param_group['lr'])
126 | return lr_list
127 |
128 | class ConstantScheduler:
129 | """Constant learning rate scheduler."""
130 | def __init__(self, optimizer, lr: float):
131 | self.optimizer = optimizer
132 | lr_list = []
133 | for param_group in self.optimizer.param_groups:
134 | lr_list.append(lr)
135 |
136 | def step(self):
137 | pass
138 |
139 | def get_lr(self):
140 | lr_list = []
141 | for param_group in self.optimizer.param_groups:
142 | lr_list.append(param_group['lr'])
143 | return lr_list
144 |
145 | def get_scheduler(optimizer, args):
146 | if args.scheduler == 'multistep':
147 | from torch.optim.lr_scheduler import MultiStepLR
148 | scheduler = MultiStepLR(optimizer, milestones=[args.Nepoch // 2, (args.Nepoch * 3) // 4], gamma=0.1)
149 | elif args.scheduler == 'cosine':
150 | scheduler = WarmupCosineScheduler(optimizer, warmup_iters = args.Nepoch // 10, total_iters = args.Nepoch,
151 | min_lr = 0.)
152 | elif args.scheduler == 'constant':
153 | scheduler = ConstantScheduler(optimizer, lr = args.lr)
154 | return scheduler
--------------------------------------------------------------------------------
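A short sketch (illustrative numbers of my own) of how WarmupCosineScheduler evolves the learning rate: it ramps linearly up to each group's initial lr over warmup_iters steps, then follows the quarter-cosine decay above, which starts the decay phase at (max_lr + min_lr)/2 and ends at min_lr:

import torch
from utils.cv_utils import WarmupCosineScheduler

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = WarmupCosineScheduler(opt, warmup_iters=5, total_iters=20, min_lr=0.0)
trace = []
for _ in range(20):
    sched.step()
    trace.append(round(sched.get_lr()[0], 4))
print(trace)   # 0.02 ... 0.1 over the warmup, then ~0.05 decaying towards 0.0
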
/MARS_M/utils/cv_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | import torchvision
5 | from torchvision import datasets, transforms
6 | from torch.utils.data import DataLoader
7 |
8 | def get_model(args):
9 | """
10 | models including:
11 | - VGG16
12 | - resnet18
13 | from https://github.com/iShohei220/adopt/blob/main/adopt.py and https://github.com/uclaml/Padam/blob/master/models/resnet.py
14 | """
15 | if args.dataset in ['mnist', 'cifar10']:
16 | num_classes = 10
17 | elif args.dataset in ['cifar100']:
18 | num_classes = 100
19 | else:
20 | raise NotImplementedError(f"{args.dataset} is not implemented.")
21 | if args.net == 'simple_cnn':
22 | from .model_CNN import Network
23 | model_config = {
24 | "n_inputs": (3, 32, 32) if args.dataset == "cifar10" else (1, 28, 28),
25 | "conv_layers_list": [
26 | {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True},
27 | {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True},
28 | {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True},
29 | ],
30 | "n_hiddens_list": [512],
31 | "n_outputs": 10,
32 | "dropout": 0.2,
33 | }
34 | model = Network(**model_config)
35 | elif args.net == 'resnet18':
36 | from .model_CNN import ResNet18
37 | model = ResNet18(num_classes = num_classes)
38 | else:
39 | try:
40 | model = torchvision.models.get_model(args.net, num_classes=num_classes)
41 | except Exception:
42 | print(f'Model {args.net} not found')
43 | raise NotImplementedError(f"Network {args.net} is not implemented.")
44 | return model
45 |
46 | def get_datasets(dataset_name: str, train_batch_size: int, eval_batch_size: int):
47 | """Get train and test dataloaders."""
48 | print('==> Preparing data..')
49 | if dataset_name == "mnist":
50 | transform = transforms.Compose([
51 | transforms.ToTensor(),
52 | transforms.Normalize((0.1307,), (0.3081,))
53 | ])
54 | train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
55 | test_dataset = datasets.MNIST('./data', train=False, transform=transform)
56 | elif dataset_name == "cifar10":
57 | transform_train = transforms.Compose([
58 | transforms.RandomCrop(32, padding=4),
59 | transforms.RandomHorizontalFlip(),
60 | transforms.ToTensor(),
61 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
62 | ])
63 | transform_test = transforms.Compose([
64 | transforms.ToTensor(),
65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
66 | ])
67 | train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
68 | test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test)
69 | elif dataset_name == "cifar100":
70 | transform_train = transforms.Compose([
71 | transforms.RandomCrop(32, padding=4),
72 | transforms.RandomHorizontalFlip(),
73 | transforms.ToTensor(),
74 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]),
75 | ])
76 | transform_test = transforms.Compose([
77 | transforms.ToTensor(),
78 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]),
79 | ])
80 | train_dataset = datasets.CIFAR100('./data', train=True, download=True, transform=transform_train)
81 | test_dataset = datasets.CIFAR100('./data', train=False, transform=transform_test)
82 | else:
83 | raise NotImplementedError(f"{dataset_name=} is not implemented.")
84 |
85 | train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=4)
86 | test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=4)
87 |
88 | return train_loader, test_loader
89 |
90 |
91 | class WarmupCosineScheduler:
92 | """Custom learning rate scheduler with linear warmup and cosine decay."""
93 | def __init__(self, optimizer, warmup_iters: int, total_iters: int, min_lr=0.):
94 | self.optimizer = optimizer
95 | self.warmup_iters = warmup_iters
96 | self.total_iters = total_iters
97 | self.min_lr = min_lr
98 | self.max_lr_list = []
99 | for param_group in self.optimizer.param_groups:
100 | self.max_lr_list.append(param_group['lr'])
101 | self.current_iter = 0
102 | self.lr_list = []
103 | for param_group in self.optimizer.param_groups:
104 | self.lr_list.append(param_group['lr'])
105 |
106 | def step(self):
107 | self.current_iter += 1
108 | lr_list = []
109 | cnt = 0
110 | for param_group in self.optimizer.param_groups:
111 | max_lr = self.max_lr_list[cnt]
112 | if self.current_iter <= self.warmup_iters:
113 | lr = self.current_iter / self.warmup_iters * max_lr
114 | else:
115 | lr = self.min_lr + 0.5 * (max_lr - self.min_lr) * (
116 | np.cos((self.current_iter - self.warmup_iters) / (self.total_iters - self.warmup_iters) * 3.14159265 / 2)
117 | ).item()
118 | param_group['lr'] = lr
119 | cnt += 1
120 | lr_list.append(lr)
121 | self.lr_list = lr_list
122 | def get_lr(self):
123 | lr_list = []
124 | for param_group in self.optimizer.param_groups:
125 | lr_list.append(param_group['lr'])
126 | return lr_list
127 |
128 | class ConstantScheduler:
129 | """Constant learning rate scheduler."""
130 | def __init__(self, optimizer, lr: float):
131 | self.optimizer = optimizer
132 | lr_list = []
133 | for param_group in self.optimizer.param_groups:
134 | lr_list.append(lr)
135 |
136 | def step(self):
137 | pass
138 |
139 | def get_lr(self):
140 | lr_list = []
141 | for param_group in self.optimizer.param_groups:
142 | lr_list.append(param_group['lr'])
143 | return lr_list
144 |
145 | def get_scheduler(optimizer, args):
146 | if args.scheduler == 'multistep':
147 | from torch.optim.lr_scheduler import MultiStepLR
148 | scheduler = MultiStepLR(optimizer, milestones=[args.Nepoch // 2, (args.Nepoch * 3) // 4], gamma=0.1)
149 | elif args.scheduler == 'cosine':
150 | scheduler = WarmupCosineScheduler(optimizer, warmup_iters = args.Nepoch // 10, total_iters = args.Nepoch,
151 | min_lr = 0.)
152 | elif args.scheduler == 'constant':
153 | scheduler = ConstantScheduler(optimizer, lr = args.lr)
154 | return scheduler
--------------------------------------------------------------------------------
/MARS/train_CV.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from uclaml/Padam: https://github.com/uclaml/Padam/blob/master/run_cnn_test_cifar10.py
3 | """
4 | import numpy as np
5 | import os
6 | import argparse
7 | import json
8 | from tqdm import tqdm
9 |
10 | parser = argparse.ArgumentParser(description='PyTorch Training')
11 | parser.add_argument(
12 | "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10", "cifar100"], help="dataset to use"
13 | )
14 | parser.add_argument(
15 | "--scheduler", type=str, default="multistep", choices=["multistep", "cosine", "constant"], help="scheduler to use"
16 | )
17 | parser.add_argument("--train_bsz", type=int, default=128, help="training batch size")
18 | parser.add_argument("--eval_bsz", type=int, default=100, help="eval batch size")
19 | parser.add_argument("--seed", type=int, default=0, help="random seed")
20 | parser.add_argument("--cpu", action="store_true", help="use cpu only")
21 | parser.add_argument("--cuda", type=str, default="0", help="device to use")
22 |
23 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
24 | parser.add_argument('--adamw_lr', default=0.003, type=float, help='learning rate for adamw')
25 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
26 | parser.add_argument('--optim', '-m', type=str, choices=["adam", "adamw", "mars", "muon"], default='mars', help='optimization method, default: mars')
27 | parser.add_argument('--net', '-n', type=str, default="resnet18", help='network architecture, choosing from "simple_cnn" or torchvision models. default: resnet18')
28 | parser.add_argument('--wd', default=0., type=float, help='weight decay')
29 | parser.add_argument('--Nepoch', default=200, type=int, help='number of epoch')
30 | parser.add_argument('--beta1', default=0.9, type=float, help='beta1')
31 | parser.add_argument('--beta2', default=0.999, type=float, help='beta2')
32 | parser.add_argument('--wandb', action='store_true', help='use wandb')
33 | parser.add_argument('--save_dir', type=str, default="./checkpoint", help='save directory')
34 | parser.add_argument('--wandb_name', type=str, default="None", help='wandb run name')
35 |
36 |
37 | args = parser.parse_args()
38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
39 | if args.wandb:
40 | import wandb
41 | if args.wandb_name == "None":
42 | wandb.init(project="CV", name=args.dataset+"_"+args.net+"_"+args.optim+"_"+str(args.lr), config=args)
43 | else:
44 | wandb.init(project="CV", name=args.wandb_name, config=args)
45 |
46 | import torch
47 | import torch.nn as nn
48 | import torch.optim as optim
49 | import torch.backends.cudnn as cudnn
50 | from utils.cv_utils import get_datasets, get_scheduler, get_model
51 | use_cuda = torch.cuda.is_available() and not args.cpu
52 |
53 | os.environ['PYTHONHASHSEED'] = str(args.seed)
54 | np.random.seed(args.seed)
55 | torch.manual_seed(args.seed)
56 | torch.cuda.manual_seed(args.seed)
57 | torch.cuda.manual_seed_all(args.seed)
58 |
59 | trainloader, testloader = get_datasets(args.dataset, args.train_bsz, args.eval_bsz)
60 | if args.resume:
61 | # Load checkpoint.
62 | print('==> Resuming from checkpoint..')
63 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
64 | checkpoint = torch.load(f'./checkpoint/{args.net}_{args.dataset}_'+args.optim)
65 | model = checkpoint['model']
66 | start_epoch = checkpoint['epoch']
67 | train_losses = checkpoint['train_losses']
68 | test_losses = checkpoint['test_losses']
69 | train_errs = checkpoint['train_errs']
70 | test_errs = checkpoint['test_errs']
71 | else:
72 | print('==> Building model..')
73 |
74 | model = get_model(args)
75 |
76 | if use_cuda:
77 | model.cuda()
78 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
79 | cudnn.benchmark = True
80 |
81 |
82 | criterion = nn.CrossEntropyLoss()
83 |
84 | betas = (args.beta1, args.beta2)
85 | from optimizers.mars import MARS
86 | from optimizers.muon import Muon
87 | from opt import CombinedOptimizer
88 | from optimizers.adamw import AdamW
89 | if args.optim == 'adam':
90 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas)
91 | elif args.optim == 'adamw':
92 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas)
93 | elif args.optim == 'muon':
94 | optimizer = CombinedOptimizer(model.parameters(), [AdamW, Muon], [{'lr': args.adamw_lr, 'betas': betas, 'weight_decay': args.wd},
95 | {'lr': args.lr, 'weight_decay': 0.}])
96 | elif args.optim == 'mars':
97 | optimizer = MARS(model.parameters(), lr=args.lr, weight_decay = args.wd, lr_1d=args.adamw_lr)
98 |
99 | scheduler = get_scheduler(optimizer, args)
100 | best_acc = 0 # best test accuracy
101 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
102 | train_errs = []
103 | test_errs = []
104 | train_losses = []
105 | test_losses = []
106 | acc_list = []
107 | t_bar = tqdm(total=len(trainloader))
108 | t_bar2 = tqdm(total=len(testloader))
109 | for epoch in range(start_epoch+1, args.Nepoch+1):
110 |
111 | scheduler.step()
112 | # print ('\nEpoch: %d' % epoch, ' Learning rate:', scheduler.get_lr())
113 | model.train() # Training
114 |
115 | train_loss = 0
116 | correct_train = 0
117 | total_train = 0
118 | print(scheduler.get_lr())
119 | t_bar.reset()
120 | for batch_idx, (inputs, targets) in enumerate(trainloader):
121 | if use_cuda:
122 | inputs, targets = inputs.cuda(), targets.cuda()
123 |
124 | optimizer.zero_grad()
125 | outputs = model(inputs)
126 | loss = criterion(outputs, targets)
127 | loss.backward()
128 | optimizer.step()
129 |
130 | train_loss += loss.item()
131 | _, predicted = torch.max(outputs.data, 1)
132 | total_train += targets.size(0)
133 | correct_train += predicted.eq(targets.data).cpu().sum().item()
134 |
135 | t_bar.update(1)
136 | t_bar.set_description('Epoch: %d | Loss: %.3f | Acc: %.3f%% ' % (epoch, train_loss/(batch_idx+1), 100.0/total_train*(correct_train)))
137 | t_bar.refresh()
138 | train_losses.append(train_loss/(batch_idx+1))
139 | train_errs.append(1 - correct_train/total_train)
140 |
141 | model.eval() # Testing
142 |
143 | test_loss = 0
144 | correct = 0
145 | total = 0
146 | t_bar2.reset()
147 | for batch_idx, (inputs, targets) in enumerate(testloader):
148 | if use_cuda:
149 | inputs, targets = inputs.cuda(), targets.cuda()
150 | outputs = model(inputs)
151 | loss = criterion(outputs, targets)
152 |
153 | test_loss += loss.item()
154 | _, predicted = torch.max(outputs.data, 1)
155 | total += targets.size(0)
156 | correct += predicted.eq(targets.data).cpu().sum().item()
157 |
158 | t_bar2.update(1)
159 | t_bar2.set_description('Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc))
160 | t_bar2.refresh()
161 | test_errs.append(1 - correct/total)
162 | test_losses.append(test_loss/(batch_idx+1))
163 | if args.wandb:
164 | wandb.log({"epoch": epoch,
165 | "train_loss": train_loss/(batch_idx+1),
166 | "train_acc": 100.0/total_train*(correct_train),
167 | "test_loss": test_loss/(batch_idx+1),
168 | "test_acc": 100.0/total*(correct),
169 | "lr": scheduler.get_lr()[0]}, step=epoch)
170 | # Save checkpoint
171 | acc = 100.0/total*(correct)
172 | if acc > best_acc:
173 | if not os.path.isdir('checkpoint'):
174 | os.mkdir('checkpoint')
175 | state = {
176 | 'model': model,
177 | 'epoch': epoch,
178 | }
179 | # torch.save(state, './checkpoint/cnn_cifar10_' + args.optim)
180 | torch.save(state, os.path.join(args.save_dir, "-".join([args.net, args.dataset, args.optim, str(args.lr).replace(".", "_")])+".pth"))
181 | best_acc = acc
182 | t_bar2.set_description('Model Saved! | Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc))
183 | t_bar2.refresh()
184 | acc_list.append(acc)
185 |
--------------------------------------------------------------------------------
/MARS_M/optimizers/moonlight.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | # This code snippet is a modified version adapted from the following GitHub repository:
5 | # https://github.com/KellerJordan/Muon/blob/master/muon.py
6 | @torch.compile
7 | def zeropower_via_newtonschulz5(G, steps):
8 | """
9 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
10 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
11 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
12 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
13 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
14 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
15 | performance at all relative to UV^T, where USV^T = G is the SVD.
16 | """
17 | assert len(G.shape) == 2
18 | a, b, c = (3.4445, -4.7750, 2.0315)
19 | X = G.bfloat16()
20 | if G.size(0) > G.size(1):
21 | X = X.T
22 | # Ensure spectral norm is at most 1
23 | X = X / (X.norm() + 1e-7)
24 | # Perform the NS iterations
25 | for _ in range(steps):
26 | A = X @ X.T
27 | B = (
28 | b * A + c * A @ A
29 | ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
30 | X = a * X + B @ X
31 |
32 | if G.size(0) > G.size(1):
33 | X = X.T
34 | return X
35 |
36 |
37 | class Moonlight(torch.optim.Optimizer):
38 | """
39 | Moonlight: Muon (MomentUm Orthogonalized by Newton-Schulz) with an internal AdamW for non-matrix parameters
40 |
41 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
42 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
43 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
44 | the advantage that it can be stably run in bfloat16 on the GPU.
45 |
46 | Some warnings:
47 | - We believe this optimizer is unlikely to work well for training with small batch size.
48 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
49 |
50 | Arguments:
51 | lr: The learning rate shared by both parameter sets. For Muon parameters it is rescaled
52 | per matrix by `adjust_lr_for_muon`; for AdamW parameters it is used directly. (default 1e-3)
53 | wd: The decoupled weight decay applied to both parameter sets. (default 0.1)
54 | muon_params: The parameters to be optimized by Muon; each must be a 2D matrix.
55 | momentum: The momentum used by the internal SGD. (0.95 is a good default)
56 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
57 | ns_steps: The number of Newton-Schulz iterations to run. (default 5)
58 | adamw_params: The parameters to be optimized by AdamW, typically the {0, 1}-D parameters
59 | and any others that should not be orthogonalized.
60 | adamw_betas: The betas for the internal AdamW.
61 | adamw_eps: The epsilon for the internal AdamW.
62 | """
63 |
64 | def __init__(
65 | self,
66 | lr=1e-3,
67 | wd=0.1,
68 | muon_params=None,
69 | momentum=0.95,
70 | nesterov=True,
71 | ns_steps=5,
72 | adamw_params=None,
73 | adamw_betas=(0.9, 0.95), # MARS implement: 0.95, 0.95
74 | adamw_eps=1e-8,
75 | ):
76 |
77 | defaults = dict(
78 | lr=lr,
79 | wd=wd,
80 | momentum=momentum,
81 | nesterov=nesterov,
82 | ns_steps=ns_steps,
83 | adamw_betas=adamw_betas,
84 | adamw_eps=adamw_eps,
85 | )
86 |
87 | params = list(muon_params)
88 | adamw_params = list(adamw_params) if adamw_params is not None else []
89 | params.extend(adamw_params)
90 | super().__init__(params, defaults)
91 | # Sort parameters into those for which we will use Muon, and those for which we will not
92 | for p in muon_params:
93 | # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
94 | assert p.ndim == 2, p.ndim
95 | self.state[p]["use_muon"] = True
96 | for p in adamw_params:
97 | # Do not use Muon for parameters in adamw_params
98 | self.state[p]["use_muon"] = False
99 |
100 | def adjust_lr_for_muon(self, lr, param_shape):
101 | A, B = param_shape[:2]
102 | # We adjust the learning rate and weight decay based on the size of the parameter matrix
103 | # as described in the paper
104 | adjusted_ratio = 0.2 * math.sqrt(max(A, B))
105 | adjusted_lr = lr * adjusted_ratio
106 | return adjusted_lr
107 |
108 | def step(self, closure=None):
109 | """Perform a single optimization step.
110 |
111 | Args:
112 | closure (Callable, optional): A closure that reevaluates the model
113 | and returns the loss.
114 | """
115 | loss = None
116 | if closure is not None:
117 | with torch.enable_grad():
118 | loss = closure()
119 |
120 | for group in self.param_groups:
121 |
122 | ############################
123 | # Muon #
124 | ############################
125 |
126 | params = [p for p in group["params"] if self.state[p]["use_muon"]]
127 | # import pdb; pdb.set_trace()
128 | lr = group["lr"]
129 | wd = group["wd"]
130 | momentum = group["momentum"]
131 |
132 | # generate weight updates in distributed fashion
133 | for p in params:
134 | # sanity check
135 | g = p.grad
136 | if g is None:
137 | continue
138 | if g.ndim > 2:
139 | g = g.view(g.size(0), -1)
140 | assert g is not None
141 |
142 | # calc update
143 | state = self.state[p]
144 | if "momentum_buffer" not in state:
145 | state["momentum_buffer"] = torch.zeros_like(g)
146 | buf = state["momentum_buffer"]
147 | buf.mul_(momentum).add_(g)
148 | if group["nesterov"]:
149 | g = g.add(buf, alpha=momentum)
150 | else:
151 | g = buf
152 | u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
153 |
154 | # scale update
155 | adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
156 |
157 | # apply weight decay
158 | p.data.mul_(1 - lr * wd)
159 |
160 | # apply update
161 | p.data.add_(u, alpha=-adjusted_lr)
162 |
163 | ############################
164 | # AdamW backup #
165 | ############################
166 |
167 | params = [p for p in group["params"] if not self.state[p]["use_muon"]]
168 | lr = group['lr']
169 | beta1, beta2 = group["adamw_betas"]
170 | eps = group["adamw_eps"]
171 | weight_decay = group["wd"]
172 |
173 | for p in params:
174 | g = p.grad
175 | if g is None:
176 | continue
177 | state = self.state[p]
178 | if "step" not in state:
179 | state["step"] = 0
180 | state["moment1"] = torch.zeros_like(g)
181 | state["moment2"] = torch.zeros_like(g)
182 | state["step"] += 1
183 | step = state["step"]
184 | buf1 = state["moment1"]
185 | buf2 = state["moment2"]
186 | buf1.lerp_(g, 1 - beta1)
187 | buf2.lerp_(g.square(), 1 - beta2)
188 |
189 | g = buf1 / (eps + buf2.sqrt())
190 |
191 | bias_correction1 = 1 - beta1**step
192 | bias_correction2 = 1 - beta2**step
193 | scale = bias_correction1 / bias_correction2**0.5
194 | p.data.mul_(1 - lr * weight_decay)
195 | p.data.add_(g, alpha=-lr / scale)
196 |
197 | return loss
--------------------------------------------------------------------------------
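A minimal construction sketch for Moonlight using the same ndim-based split as MARS_M/train_CV.py; the model and learning rate are illustrative. Note that adjust_lr_for_muon rescales the rate per matrix: for a 512x512 weight at lr=1e-3 the Muon update is applied with roughly 0.2 * sqrt(512) * 1e-3 ≈ 4.5e-3.

import torch
from optimizers.moonlight import Moonlight

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 10))
muon_params = [p for p in model.parameters() if p.ndim == 2]    # weight matrices -> Muon
adamw_params = [p for p in model.parameters() if p.ndim != 2]   # biases etc. -> internal AdamW
optimizer = Moonlight(lr=1e-3, wd=0.1, muon_params=muon_params,
                      adamw_params=adamw_params, adamw_betas=(0.9, 0.95))
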
/MARS_M/train_CV.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from uclaml/Padam: https://github.com/uclaml/Padam/blob/master/run_cnn_test_cifar10.py
3 | """
4 | import numpy as np
5 | import os
6 | import argparse
7 | import json
8 | from tqdm import tqdm
9 |
10 | parser = argparse.ArgumentParser(description='PyTorch Training')
11 | parser.add_argument(
12 | "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10", "cifar100"], help="dataset to use"
13 | )
14 | parser.add_argument(
15 | "--scheduler", type=str, default="multistep", choices=["multistep", "cosine", "constant"], help="scheduler to use"
16 | )
17 | parser.add_argument("--train_bsz", type=int, default=128, help="training batch size")
18 | parser.add_argument("--eval_bsz", type=int, default=100, help="eval batch size")
19 | parser.add_argument("--seed", type=int, default=0, help="random seed")
20 | parser.add_argument("--cpu", action="store_true", help="use cpu only")
21 | parser.add_argument("--cuda", type=str, default="0", help="device to use")
22 |
23 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
24 | parser.add_argument('--adamw_lr', default=0.003, type=float, help='learning rate for adamw')
25 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
26 | parser.add_argument('--optim', '-m', type=str, choices=["adam", "adamw", "mars-m", "moonlight"], default='mars-m', help='optimization method, default: mars-m')
27 | parser.add_argument('--net', '-n', type=str, default="resnet18", help='network architecture, choosing from "simple_cnn" or torchvision models. default: resnet18')
28 | parser.add_argument('--wd', default=0., type=float, help='weight decay')
29 | parser.add_argument('--Nepoch', default=200, type=int, help='number of epoch')
30 | parser.add_argument('--beta1', default=0.9, type=float, help='beta1')
31 | parser.add_argument('--beta2', default=0.999, type=float, help='beta2')
32 | parser.add_argument('--gamma', default=0.025, type=float, help='gamma in MARS-M')
33 | parser.add_argument('--clip_c', action='store_true', help='whether to clip c_t in MARS-M')
34 | parser.add_argument('--mars_exact', action='store_true', help='whether to use the exact (non-approximate) version of MARS-M')
35 | parser.add_argument('--wandb', action='store_true', help='use wandb')
36 | parser.add_argument('--save_dir', type=str, default="./checkpoint", help='save directory')
37 | parser.add_argument('--wandb_name', type=str, default="None", help='wandb run name')
38 |
39 |
40 | args = parser.parse_args()
41 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
42 | if args.wandb:
43 | import wandb
44 | if args.wandb_name == "None":
45 | wandb.init(project="CV", name=args.dataset+"_"+args.net+"_"+args.optim+"_"+str(args.lr), config=args)
46 | else:
47 | wandb.init(project="CV", name=args.wandb_name, config=args)
48 |
49 | import torch
50 | import torch.nn as nn
51 | import torch.optim as optim
52 | import torch.backends.cudnn as cudnn
53 | from utils.cv_utils import get_datasets, get_scheduler, get_model
54 | use_cuda = torch.cuda.is_available() and not args.cpu
55 |
56 | os.environ['PYTHONHASHSEED'] = str(args.seed)
57 | np.random.seed(args.seed)
58 | torch.manual_seed(args.seed)
59 | torch.cuda.manual_seed(args.seed)
60 | torch.cuda.manual_seed_all(args.seed)
61 |
62 | trainloader, testloader = get_datasets(args.dataset, args.train_bsz, args.eval_bsz)
63 | if args.resume:
64 | # Load checkpoint.
65 | print('==> Resuming from checkpoint..')
66 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
67 | checkpoint = torch.load(f'./checkpoint/{args.net}_{args.dataset}_'+args.optim)
68 | model = checkpoint['model']
69 | start_epoch = checkpoint['epoch']
70 | train_losses = checkpoint['train_losses']
71 | test_losses = checkpoint['test_losses']
72 | train_errs = checkpoint['train_errs']
73 | test_errs = checkpoint['test_errs']
74 | else:
75 | print('==> Building model..')
76 |
77 | model = get_model(args)
78 |
79 | if use_cuda:
80 | model.cuda()
81 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
82 | cudnn.benchmark = True
83 |
84 |
85 | criterion = nn.CrossEntropyLoss()
86 |
87 | betas = (args.beta1, args.beta2)
88 | from optimizers.mars_m import MARS_M
89 | from optimizers.moonlight import Moonlight
90 | if args.optim == 'adam':
91 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas)
92 | elif args.optim == 'adamw':
93 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay = args.wd, betas = betas)
94 | elif args.optim == 'moonlight':
95 | muon_params = [p for p in model.parameters() if p.ndim == 2]
96 | adamw_params = [p for p in model.parameters() if p.ndim != 2]
97 | optimizer = Moonlight(muon_params=muon_params, adamw_params=adamw_params,
98 | lr=args.lr, wd=args.wd, adamw_betas=betas)
99 | elif args.optim == 'mars-m':
100 | muon_params = [p for p in model.parameters() if p.ndim >= 2]
101 | adamw_params = [p for p in model.parameters() if p.ndim < 2]
102 | if args.mars_exact:
103 | is_approx = False
104 | else:
105 | is_approx = True
106 | optimizer = MARS_M(muon_params=muon_params, adamw_params=adamw_params,
107 | lr=args.lr, wd=args.wd, adamw_betas=betas, gamma=args.gamma,
108 | clip_c=args.clip_c, is_approx=is_approx)
109 |
110 | scheduler = get_scheduler(optimizer, args)
111 | best_acc = 0 # best test accuracy
112 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
113 | train_errs = []
114 | test_errs = []
115 | train_losses = []
116 | test_losses = []
117 | acc_list = []
118 | t_bar = tqdm(total=len(trainloader))
119 | t_bar2 = tqdm(total=len(testloader))
120 | previous_input = None
121 | previous_target = None
122 | for epoch in range(start_epoch+1, args.Nepoch+1):
123 |
124 | scheduler.step()
125 | # print ('\nEpoch: %d' % epoch, ' Learning rate:', scheduler.get_lr())
126 | model.train() # Training
127 |
128 | train_loss = 0
129 | correct_train = 0
130 | total_train = 0
131 | t_bar.reset()
132 | for batch_idx, (inputs, targets) in enumerate(trainloader):
133 | if args.mars_exact:
134 | if previous_input is not None:
135 | p_inputs, p_targets = previous_input, previous_target
136 | if use_cuda:
137 | p_inputs, p_targets = p_inputs.cuda(), p_targets.cuda()
138 | optimizer.zero_grad()
139 | p_outputs = model(p_inputs)
140 | p_loss = criterion(p_outputs, p_targets)
141 | p_loss.backward()
142 | optimizer.update_previous_grad()
143 | # keep the current batch so its gradient can be recomputed at the updated weights
144 | previous_input = inputs.clone()
145 | previous_target = targets.clone()
146 | if use_cuda:
147 | inputs, targets = inputs.cuda(), targets.cuda()
148 |
149 | optimizer.zero_grad()
150 | outputs = model(inputs)
151 | loss = criterion(outputs, targets)
152 | loss.backward()
153 | optimizer.step()
154 | if args.mars_exact:
155 | optimizer.update_last_grad()
156 |
157 | train_loss += loss.item()
158 | _, predicted = torch.max(outputs.data, 1)
159 | total_train += targets.size(0)
160 | correct_train += predicted.eq(targets.data).cpu().sum().item()
161 |
162 | t_bar.update(1)
163 | t_bar.set_description('Epoch: %d | Loss: %.3f | Acc: %.3f%% ' % (epoch, train_loss/(batch_idx+1), 100.0/total_train*(correct_train)))
164 | t_bar.refresh()
165 | train_losses.append(train_loss/(batch_idx+1))
166 | train_errs.append(1 - correct_train/total_train)
167 |
168 | model.eval() # Testing
169 |
170 | test_loss = 0
171 | correct = 0
172 | total = 0
173 | t_bar2.reset()
174 | for batch_idx, (inputs, targets) in enumerate(testloader):
175 | if use_cuda:
176 | inputs, targets = inputs.cuda(), targets.cuda()
177 | outputs = model(inputs)
178 | loss = criterion(outputs, targets)
179 |
180 | test_loss += loss.item()
181 | _, predicted = torch.max(outputs.data, 1)
182 | total += targets.size(0)
183 | correct += predicted.eq(targets.data).cpu().sum().item()
184 |
185 | t_bar2.update(1)
186 | t_bar2.set_description('Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc))
187 | t_bar2.refresh()
188 | test_errs.append(1 - correct/total)
189 | test_losses.append(test_loss/(batch_idx+1))
190 | if args.wandb:
191 | wandb.log({"epoch": epoch,
192 | "train_loss": train_loss/(batch_idx+1),
193 | "train_acc": 100.0/total_train*(correct_train),
194 | "test_loss": test_loss/(batch_idx+1),
195 | "test_acc": 100.0/total*(correct),
196 | "lr": scheduler.get_lr()[0]}, step=epoch)
197 | # Save checkpoint
198 | acc = 100.0/total*(correct)
199 | if acc > best_acc:
200 | if not os.path.isdir('checkpoint'):
201 | os.mkdir('checkpoint')
202 | state = {
203 | 'model': model,
204 | 'epoch': epoch,
205 | }
206 | # torch.save(state, './checkpoint/cnn_cifar10_' + args.optim)
207 | torch.save(state, os.path.join(args.save_dir, "-".join([args.net, args.dataset, args.optim, str(args.lr).replace(".", "_")])+".pth"))
208 | best_acc = acc
209 | t_bar2.set_description('Model Saved! | Loss: %.3f | Acc: %.3f%% (Best: %.3f%%)' % (test_loss/(batch_idx+1), 100.0/total*(correct), best_acc))
210 | t_bar2.refresh()
211 | acc_list.append(acc)
212 |
--------------------------------------------------------------------------------
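Stripped of the CV bookkeeping, the --mars_exact path above follows this pattern: before each step, the previous mini-batch's gradient is recomputed at the current weights and handed to the optimizer, and after the step the current batch's gradient is stored for the next iteration. A self-contained sketch under the assumption that update_previous_grad() and update_last_grad() snapshot the current .grad buffers, as their usage above suggests; the toy model, data, and hyperparameters are illustrative:

import torch
from optimizers.mars_m import MARS_M

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
criterion = torch.nn.CrossEntropyLoss()
optimizer = MARS_M(muon_params=[p for p in model.parameters() if p.ndim >= 2],
                   adamw_params=[p for p in model.parameters() if p.ndim < 2],
                   lr=1e-3, wd=0.1, adamw_betas=(0.9, 0.999), gamma=0.025,
                   clip_c=False, is_approx=False)
loader = [(torch.randn(8, 32), torch.randint(0, 4, (8,))) for _ in range(10)]

previous_batch = None
for inputs, targets in loader:
    if previous_batch is not None:
        optimizer.zero_grad()
        criterion(model(previous_batch[0]), previous_batch[1]).backward()
        optimizer.update_previous_grad()   # gradient of the previous batch at the current weights
    optimizer.zero_grad()
    criterion(model(inputs), targets).backward()
    optimizer.step()
    optimizer.update_last_grad()           # gradient of the current batch, reused next iteration
    previous_batch = (inputs, targets)
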
/MARS/train_CNN.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2 | # SPDX-License-Identifier: Apache-2.0
3 | import argparse
4 | from typing import List, Tuple, Type
5 |
6 | import matplotlib.pyplot as plt
7 | import torch
8 | import torch.nn as nn
9 | from torch.optim import Adam, AdamW
10 | from torch.optim.lr_scheduler import CosineAnnealingLR
11 | from torch.utils.data import DataLoader
12 | from torchvision import datasets, transforms
13 | import numpy as np
14 | from utils.model_CNN import Network
15 | from optimizers.adopt import ADOPT
16 | from optimizers.mars import MARS
17 | import random
18 | parser = argparse.ArgumentParser(add_help=True)
19 | parser.add_argument(
20 | "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10"], help="dataset to use"
21 | )
22 | parser.add_argument("-b", "--batch_size", type=int, default=128, help="batch size")
23 | parser.add_argument("-e", "--epochs", type=int, default=50, help="number of epochs")
24 | parser.add_argument("--seed", type=int, default=0, help="random seed")
25 | parser.add_argument("--cpu", action="store_true", help="use cpu only")
26 |
27 |
28 | def get_datasets(dataset_name: str, batch_size: int) -> Tuple[DataLoader, DataLoader]:
29 | """Get train and test dataloaders."""
30 | if dataset_name == "mnist":
31 | transform = transforms.Compose([
32 | transforms.ToTensor(),
33 | transforms.Normalize((0.1307,), (0.3081,))
34 | ])
35 | train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
36 | test_dataset = datasets.MNIST('./data', train=False, transform=transform)
37 | elif dataset_name == "cifar10":
38 | transform_train = transforms.Compose([
39 | transforms.RandomHorizontalFlip(),
40 | transforms.ToTensor(),
41 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
42 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
43 | ])
44 | transform_test = transforms.Compose([
45 | transforms.ToTensor(),
46 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
47 | ])
48 | train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
49 | test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test)
50 | else:
51 | raise NotImplementedError(f"{dataset_name=} is not implemented.")
52 |
53 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
54 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
55 |
56 | return train_loader, test_loader
57 |
58 |
59 | class WarmupCosineScheduler:
60 | """Custom learning rate scheduler with linear warmup and cosine decay."""
61 | def __init__(self, optimizer, warmup_iters: int, total_iters: int, min_lr: float, max_lr: float):
62 | self.optimizer = optimizer
63 | self.warmup_iters = warmup_iters
64 | self.total_iters = total_iters
65 | self.min_lr = min_lr
66 | self.max_lr = max_lr
67 | self.current_iter = 0
68 | self.lr = 0
69 |
70 | def step(self):
71 | self.current_iter += 1
72 | if self.current_iter <= self.warmup_iters:
73 | lr = self.current_iter / self.warmup_iters * self.max_lr
74 | else:
75 | lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
76 | np.cos((self.current_iter - self.warmup_iters) / (self.total_iters - self.warmup_iters) * 3.14159265 / 2)
77 | ).item()
78 |
79 | for param_group in self.optimizer.param_groups:
80 | param_group['lr'] = lr
81 | self.lr = lr
82 |
83 | class Trainer:
84 | """Training manager for PyTorch models."""
85 | def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer, scheduler, device: torch.device):
86 | self.model = model
87 | self.optimizer = optimizer
88 | self.scheduler = scheduler
89 | self.device = device
90 | self.criterion = nn.CrossEntropyLoss()
91 | self.train_acc_trace = []
92 | self.val_acc_trace = []
93 |
94 | def train_epoch(self, train_loader: DataLoader) -> float:
95 | self.model.train()
96 | correct = 0
97 | total = 0
98 |
99 | for batch in train_loader:
100 | images, targets = batch[0].to(self.device), batch[1].to(self.device)
101 |
102 | self.optimizer.zero_grad()
103 | outputs = self.model(images)
104 | loss = self.criterion(outputs, targets)
105 | loss.backward()
106 | self.optimizer.step()
107 |
108 | _, predicted = outputs.max(1)
109 | total += targets.size(0)
110 | correct += predicted.eq(targets).sum().item()
111 | if self.scheduler is not None:
112 | self.scheduler.step()
113 | return 100. * correct / total
114 |
115 | def evaluate(self, test_loader: DataLoader) -> float:
116 | self.model.eval()
117 | correct = 0
118 | total = 0
119 |
120 | with torch.no_grad():
121 | for batch in test_loader:
122 | images, targets = batch[0].to(self.device), batch[1].to(self.device)
123 | outputs = self.model(images)
124 |
125 | _, predicted = outputs.max(1)
126 | total += targets.size(0)
127 | correct += predicted.eq(targets).sum().item()
128 |
129 | return 100. * correct / total
130 |
131 | def train(self, train_loader: DataLoader, test_loader: DataLoader, epochs: int):
132 | for epoch in range(epochs):
133 | train_acc = self.train_epoch(train_loader)
134 | val_acc = self.evaluate(test_loader)
135 |
136 | self.train_acc_trace.append(train_acc)
137 | self.val_acc_trace.append(val_acc)
138 |
139 | # if self.scheduler is not None:
140 | # self.scheduler.step()
141 |
142 | print(f"Epoch {epoch+1}/{epochs} - Train Acc: {train_acc:.2f}% - Val Acc: {val_acc:.2f}%")
143 |
144 |
145 | def get_optimizers(model: nn.Module, opt_name, args):
146 | """Configure optimizers and schedulers."""
147 | total_steps = 50_000 // args.batch_size * args.epochs
148 | n_warmup = int(total_steps * 0.10) # % of total steps
149 | weight_decay = 1e-4
150 | max_lr = 6e-4
151 | min_lr = 1e-6
152 |
153 | if opt_name == "Adam":
154 | # Adam
155 | adam = Adam(model.parameters(), lr=max_lr)
156 | adam_scheduler = WarmupCosineScheduler(
157 | adam, n_warmup, total_steps, min_lr, max_lr
158 | )
159 | optimizer = (adam, adam_scheduler, "Adam")
160 |
161 | elif opt_name == "AdamW":
162 | # AdamW
163 | adamw = AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
164 | adamw_scheduler = WarmupCosineScheduler(
165 | adamw, n_warmup, total_steps, min_lr, max_lr
166 | )
167 | optimizer = (adamw, adamw_scheduler, "AdamW")
168 | elif opt_name == "ADOPT":
169 | # ADOPT
170 | adopt = ADOPT(model.parameters(), lr=max_lr, weight_decay=weight_decay)
171 | adopt_scheduler = WarmupCosineScheduler(
172 | adopt, n_warmup, total_steps, min_lr, max_lr
173 | )
174 | optimizer = (adopt, adopt_scheduler, "ADOPT")
175 | elif opt_name == "MARS":
176 | # MARS
177 | mars = MARS(model.parameters(), lr=3e-3, weight_decay=weight_decay, optimize_1d=False)
178 | mars_scheduler = WarmupCosineScheduler(
179 | mars, n_warmup, total_steps, min_lr, 3e-3
180 | )
181 | optimizer = (mars, mars_scheduler, "MARS")
182 | return optimizer
183 |
184 |
185 | def plot_results(results: List[List[float]], optimizer_names: List[str], args):
186 | """Plot training results."""
187 | fig, ax = plt.subplots(figsize=(5.5, 3.5))
188 | colors = ["#74add1", "#1730bd", "#1a9850", "#001c01"]
189 |
190 | for i, acc in enumerate(results):
191 | ax.plot(range(1, len(acc) + 1), acc, label=optimizer_names[i], lw=2, color=colors[i])
192 |
193 | ax.set_title(f"{args.dataset.upper()} (val)", loc="left")
194 | ax.set_xlabel("Epoch", fontsize="medium")
195 | ax.set_ylabel("Accuracy (%)", fontsize="medium")
196 |
197 | ax.legend(ncols=2, columnspacing=0.8, fontsize="medium")
198 | ax.grid(alpha=0.2)
199 |
200 | ax.set_ylim(90 if args.dataset == "mnist" else 70)
201 | acc_min, acc_max = ax.get_ylim()
202 | ax.set_yticks(torch.linspace(acc_min, acc_max, 5).int().tolist())
203 | ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
204 |
205 | fig.tight_layout()
206 | fig.savefig(
207 | f"./compare-{args.dataset}-blank.png",
208 | dpi=300,
209 | bbox_inches="tight",
210 | )
211 | plt.show()
212 |
213 |
214 | def main(args):
215 | # Set random seed and device
216 | torch.manual_seed(args.seed)
217 | device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
218 |
219 | # Get dataloaders
220 | train_loader, test_loader = get_datasets(args.dataset, args.batch_size)
221 | # Model configuration
222 | model_config = {
223 | "n_inputs": (3, 32, 32) if args.dataset == "cifar10" else (1, 28, 28),
224 | "conv_layers_list": [
225 | {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True},
226 | {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True},
227 | {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True},
228 | ],
229 | "n_hiddens_list": [512],
230 | "n_outputs": 10,
231 | "dropout": 0.2,
232 | }
233 |
234 | results = []
235 | optimizer_names = []
236 | # Train with different optimizers
237 | opt_names = ["Adam", "AdamW", "ADOPT", "MARS"]
238 | for opt_name in opt_names:
239 | print(opt_name)
240 | torch.manual_seed(args.seed)
241 | model = Network(**model_config).to(device)
242 | optimizer, scheduler, name = get_optimizers(model, opt_name, args)
243 | trainer = Trainer(model, optimizer, scheduler, device)
244 | trainer.train(train_loader, test_loader, args.epochs)
245 | results.append(trainer.val_acc_trace)
246 | optimizer_names.append(name)
247 |
248 | plot_results(results, optimizer_names, args)
249 |
250 |
251 | if __name__ == "__main__":
252 | args = parser.parse_args()
253 | main(args)
--------------------------------------------------------------------------------
/MARS_M/optimizers/mars_m.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | @torch.compile
5 | def zeropower_via_newtonschulz5(G, steps):
6 | """
7 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
8 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
9 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
10 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
11 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
12 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
13 | performance at all relative to UV^T, where USV^T = G is the SVD.
14 |
15 | This version allows for G to be more than 2D.
16 | """
17 | # assert len(G.shape) == 2
18 | a, b, c = (3.4445, -4.7750, 2.0315)
19 | X = G.bfloat16()
20 | if G.size(-2) > G.size(-1):
21 | X = X.transpose(-2, -1)
22 | # Ensure spectral norm is at most 1
23 | X = X / (X.norm() + 1e-7)
24 | # Perform the NS iterations
25 | for _ in range(steps):
26 | A = X @ X.transpose(-2, -1)
27 | B = (
28 | b * A + c * A @ A
29 | ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
30 | X = a * X + B @ X
31 |
32 | if G.size(-2) > G.size(-1):
33 | X = X.transpose(-2, -1)
34 | return X
35 |
36 | class MARS_M(torch.optim.Optimizer):
37 | """
38 | MARS_M: MARS on Matrix-Level
39 |
40 | Arguments:
41 | lr: The learning rate. The updates will have spectral norm of `lr`.
42 | wd: The weight decay.
43 | muon_params: The parameters to be optimized by Muon.
44 | momentum: The momentum used by the internal SGD.
45 | ns_steps: The number of Newton-Schulz iterations to run.
46 |         adamw_params: The parameters to be optimized by AdamW. {0, 1}-D parameters, as well as the embedding and
47 |         lm_head weights, should be passed here rather than in `muon_params` (they are not routed automatically).
48 | adamw_betas: The betas used by AdamW.
49 | adamw_eps: The epsilon used by AdamW.
50 | gamma: The gamma parameter in MARS.
51 | clip_c: Whether to clip the c_t vector to have norm at most 1.
52 | is_approx: Whether to use the approximate version of MARS-M (True) or the exact version (False).
53 | The approximate version does not require any extra gradient computations, while the exact version
54 | requires one extra gradient computation per parameter per step.
55 | However, the exact version may yield slightly better performance.
56 | See the MARS-M paper for details.
57 |
58 | """
59 |
60 | def __init__(
61 | self,
62 | lr=1e-3,
63 | wd=0.1,
64 | muon_params=None,
65 | momentum=0.95,
66 | ns_steps=5,
67 | adamw_params=None,
68 | adamw_betas=(0.9, 0.95),
69 | adamw_eps=1e-8,
70 | gamma=0.025,
71 | clip_c=False,
72 | is_approx=True,
73 | ):
74 | mars_factor = gamma * momentum / (1-momentum)
75 | defaults = dict(
76 | lr=lr,
77 | wd=wd,
78 | momentum=momentum,
79 | ns_steps=ns_steps,
80 | adamw_betas=adamw_betas,
81 | adamw_eps=adamw_eps,
82 | gamma=gamma,
83 | mars_factor=mars_factor,
84 | clip_c=clip_c,
85 | is_approx=is_approx
86 | )
87 |
88 | params = list(muon_params)
89 | adamw_params = list(adamw_params) if adamw_params is not None else []
90 | params.extend(adamw_params)
91 | super().__init__(params, defaults)
92 | self.is_approx = is_approx
93 | # Sort parameters into those for which we will use Muon, and those for which we will not
94 | for p in muon_params:
95 | # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
96 | # assert p.ndim >= 2, p.ndim
97 | self.state[p]["use_muon"] = True
98 | for p in adamw_params:
99 | # Do not use Muon for parameters in adamw_params
100 | self.state[p]["use_muon"] = False
101 |
102 | def adjust_lr_for_muon(self, lr, param_shape):
103 | A, B = param_shape[:2]
104 |         # We adjust the learning rate based on the shape of the parameter matrix,
105 |         # as described in the paper
106 | adjusted_ratio = 0.2 * math.sqrt(max(A, B))
107 | adjusted_lr = lr * adjusted_ratio
108 | return adjusted_lr
109 |
110 | @torch.no_grad()
111 | def update_last_grad(self):
112 | if not self.is_approx:
113 | for group in self.param_groups:
114 | for p in group['params']:
115 | state = self.state[p]
116 | ## only update previous grad for muon params
117 | if not state["use_muon"]:
118 | continue
119 | ## end skip
120 | if "last_grad" not in state:
121 | state["last_grad"] = torch.zeros_like(p)
122 | state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0)
123 | @torch.no_grad()
124 | def update_previous_grad(self):
125 | if not self.is_approx:
126 | for group in self.param_groups:
127 | for p in group['params']:
128 | if p.grad is None:
129 | print (p, "grad is none")
130 | continue
131 | state = self.state[p]
132 | ## only update previous grad for muon params
133 | if not state["use_muon"]:
134 | continue
135 | ## end skip
136 | if "previous_grad" not in state:
137 | state['previous_grad'] = torch.zeros_like(p)
138 | state['previous_grad'].zero_().add_(p.grad, alpha=1.0)
139 |
140 | def step(self, closure=None):
141 | """Performs a single optimization step.
142 |
143 | Arguments:
144 | closure (callable, optional): A closure that reevaluates the model
145 | and returns the loss.
146 |
147 | If using exact version, the example usage is as follows:
148 | previous_X, previous_Y = None, None
149 | for epoch in range(epochs):
150 | for X, Y in data_loader:
151 |                     if previous_X is not None:
152 | logits, loss = model(X, Y)
153 | loss.backward()
154 | optimizer.update_previous_grad()
155 | optimizer.zero_grad(set_to_none=True)
156 | logits, loss = model(X, Y)
157 | loss.backward()
158 |                     optimizer.step()
159 | optimizer.zero_grad(set_to_none=True)
160 | optimizer.update_last_grad()
161 | iter_num += 1
162 | previous_X, previous_Y = X.clone(), Y.clone()
163 | """
164 | loss = None
165 | if closure is not None:
166 | with torch.enable_grad():
167 | loss = closure()
168 |
169 | for group in self.param_groups:
170 |
171 | ############################
172 | # Muon #
173 | ############################
174 |
175 | params = [p for p in group["params"] if self.state[p]["use_muon"]]
176 | # import pdb; pdb.set_trace()
177 | lr = group["lr"]
178 | wd = group["wd"]
179 | momentum = group["momentum"]
180 | gamma = group["gamma"]
181 | mars_factor = group["mars_factor"]
182 | for p in params:
183 | # sanity check
184 | g = p.grad
185 | if g is None:
186 | continue
187 | assert g is not None
188 | state = self.state[p]
189 |
190 | # default: MARS, nesterov
191 | if "last_grad" not in state:
192 | state["last_grad"] = torch.zeros_like(g)
193 | # calc update
194 | if "momentum_buffer" not in state:
195 | state["momentum_buffer"] = torch.zeros_like(g)
196 | # mars_factor = gamma * momentum / (1-momentum)
197 | c_t = (g - state["last_grad"]).mul(mars_factor).add(g)
198 | c_t_norm = c_t.norm()
199 | if c_t_norm > 1:
200 | c_t.div_(c_t_norm)
201 | buf = state["momentum_buffer"]
202 | buf.mul_(momentum).add_(c_t, alpha=(1 - momentum))
203 |
204 | u = zeropower_via_newtonschulz5(buf, steps=group["ns_steps"])
205 |
206 |
207 | # scale update
208 | adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
209 |
210 | # apply weight decay
211 | p.data.mul_(1 - lr * wd)
212 |
213 | # apply update
214 | p.data.add_(u, alpha=-adjusted_lr)
215 | if self.is_approx:
216 | state["last_grad"] = g
217 |
218 | ############################
219 | # AdamW backup #
220 | ############################
221 |
222 | params = [p for p in group["params"] if not self.state[p]["use_muon"]]
223 | lr = group['lr']
224 | beta1, beta2 = group["adamw_betas"]
225 | eps = group["adamw_eps"]
226 | weight_decay = group["wd"]
227 |
228 | for p in params:
229 | g = p.grad
230 | if g is None:
231 | continue
232 | state = self.state[p]
233 | if "step" not in state:
234 | state["step"] = 0
235 | state["moment1"] = torch.zeros_like(g)
236 | state["moment2"] = torch.zeros_like(g)
237 | state["step"] += 1
238 | step = state["step"]
239 | buf1 = state["moment1"]
240 | buf2 = state["moment2"]
241 | buf1.lerp_(g, 1 - beta1)
242 | buf2.lerp_(g.square(), 1 - beta2)
243 |
244 | g = buf1 / (eps + buf2.sqrt())
245 |
246 | bias_correction1 = 1 - beta1**step
247 | bias_correction2 = 1 - beta2**step
248 | scale = bias_correction1 / bias_correction2**0.5
249 | p.data.mul_(1 - lr * weight_decay)
250 | p.data.add_(g, alpha=-lr / scale)
251 |
252 | return loss
--------------------------------------------------------------------------------
/MARS/optimizers/mars.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2 | # SPDX-License-Identifier: Apache-2.0
3 | import math
4 | import torch
5 | from torch.optim.optimizer import Optimizer
6 | import os
7 | import numpy as np
8 | import math
9 | # from megatron.optimizer.l2_norm import l2_norm
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def update_fn(p, grad, exp_avg, exp_avg_sq, lr, wd, beta1, beta2, last_grad, eps, amsgrad, max_exp_avg_sq, step, gamma,
16 | mars_type, is_grad_2d, optimize_1d, lr_1d_factor, betas_1d, weight_decay_1d):
17 | # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para
18 | if optimize_1d or is_grad_2d:
19 | c_t = (grad - last_grad).mul(gamma * (beta1 / (1. - beta1))).add(grad)
20 | c_t_norm = torch.norm(c_t)
21 | if c_t_norm > 1.:
22 | c_t = c_t / c_t_norm
23 | exp_avg.mul_(beta1).add_(c_t, alpha=1. - beta1)
24 | if (mars_type == "mars-adamw") or (mars_type == "mars-shampoo" and not is_grad_2d):
25 | exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
26 | bias_correction1 = 1.0 - beta1 ** step
27 | bias_correction2 = 1.0 - beta2 ** step
28 | if amsgrad:
29 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
30 | denom = max_exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
31 | else:
32 | denom = exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
33 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.div(denom))
34 | elif mars_type == "mars-lion":
35 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.sign())
36 | elif mars_type == "mars-shampoo" and is_grad_2d:
37 | factor = max(1, grad.size(0)/grad.size(1))**0.5
38 | real_update_tmp = NewtonSchulz(exp_avg.mul(1./(1.-beta1)), eps=eps).mul(factor).add(wd, p.data).mul(-lr)
39 | p.data.add_(real_update_tmp)
40 | else:
41 | beta1_1d, beta2_1d = betas_1d
42 | exp_avg.mul_(beta1_1d).add_(grad, alpha=1. - beta1_1d)
43 | exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1. - beta2_1d)
44 | bias_correction1 = 1.0 - beta1_1d ** step
45 | bias_correction2 = 1.0 - beta2_1d ** step
46 | if amsgrad:
47 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
48 | denom = max_exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
49 | else:
50 | denom = exp_avg_sq.sqrt().mul(1/math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
51 | real_update_tmp = -lr * lr_1d_factor * torch.mul(p.data, weight_decay_1d).add(exp_avg.div(denom))
52 | p.data.add_(real_update_tmp)
53 | return exp_avg, exp_avg_sq
54 |
55 | class MARS(Optimizer):
56 | def __init__(self, params, lr=3e-3, betas=(0.95, 0.99), eps=1e-8, weight_decay=0., amsgrad=False, gamma=0.025,
57 | is_approx=True, mars_type="mars-adamw", optimize_1d=False, lr_1d=3e-3, betas_1d=(0.9, 0.95), weight_decay_1d=0.1):
58 | if not 0.0 <= lr:
59 | raise ValueError("Invalid learning rate: {}".format(lr))
60 | if not 0.0 <= eps:
61 | raise ValueError("Invalid epsilon value: {}".format(eps))
62 | if not 0.0 <= betas[0] < 1.0:
63 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
64 | if not 0.0 <= betas[1] < 1.0:
65 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
66 | assert mars_type in ["mars-adamw", "mars-lion", "mars-shampoo"], "MARS type not supported"
67 | defaults = dict(lr=lr, betas=betas, eps=eps,
68 | weight_decay=weight_decay, amsgrad=amsgrad,
69 | mars_type=mars_type, gamma=gamma,
70 | optimize_1d=optimize_1d, weight_decay_1d=weight_decay_1d)
71 | super(MARS, self).__init__(params, defaults)
72 | self.eps = eps
73 | self.update_fn = update_fn
74 | self.lr = lr
75 | self.weight_decay=weight_decay
76 | self.amsgrad = amsgrad
77 | self.step_num = 0
78 | self.is_approx = is_approx
79 | self.gamma = gamma
80 | self.mars_type = mars_type
81 | self.optimize_1d = optimize_1d
82 | self.lr_1d_factor = lr_1d / lr
83 | self.weight_decay_1d = weight_decay_1d
84 | self.betas_1d = betas_1d
85 |
86 | @torch.no_grad()
87 | def update_last_grad(self):
88 | if not self.is_approx:
89 | for group in self.param_groups:
90 | for p in group['params']:
91 | state = self.state[p]
92 | if "last_grad" not in state:
93 | state["last_grad"] = torch.zeros_like(p)
94 | state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0)
95 | @torch.no_grad()
96 | def update_previous_grad(self):
97 | if not self.is_approx:
98 | for group in self.param_groups:
99 | #print ("para name", len(group['params']), len(group['names']), group['names'])
100 | for p in group['params']:
101 | # import pdb
102 | # pdb.set_trace()
103 | if p.grad is None:
104 | print (p, "grad is none")
105 | continue
106 | state = self.state[p]
107 | if "previous_grad" not in state:
108 | state['previous_grad'] = torch.zeros_like(p)
109 | state['previous_grad'].zero_().add_(p.grad, alpha=1.0)
110 |
111 | def __setstate__(self, state):
112 | super(MARS, self).__setstate__(state)
113 | for group in self.param_groups:
114 | group.setdefault('amsgrad', False)
115 |
116 | @torch.no_grad()
117 | def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None):
118 | """Performs a single optimization step.
119 |
120 | Arguments:
121 | closure (callable, optional): A closure that reevaluates the model
122 | and returns the loss.
123 |
124 | If using exact version, the example usage is as follows:
125 | previous_X, previous_Y = None, None
126 | for epoch in range(epochs):
127 | for X, Y in data_loader:
128 |                     if previous_X is not None:
129 | logits, loss = model(X, Y)
130 | loss.backward()
131 | optimizer.update_previous_grad()
132 | optimizer.zero_grad(set_to_none=True)
133 | logits, loss = model(X, Y)
134 | loss.backward()
135 |                     optimizer.step()
136 | optimizer.zero_grad(set_to_none=True)
137 | optimizer.update_last_grad()
138 | iter_num += 1
139 | previous_X, previous_Y = X.clone(), Y.clone()
140 | """
141 | if any(p is not None for p in [grads, output_params, scale, grad_norms]):
142 | raise RuntimeError('FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.')
143 |
144 | loss = None
145 | if exists(closure):
146 | with torch.enable_grad():
147 | loss = closure()
148 | real_update = 0
149 | real_update_wo_lr = 0
150 | gamma = self.gamma
151 | # import pdb
152 | # pdb.set_trace()
153 | for group in self.param_groups:
154 | for p in filter(lambda p: exists(p.grad), group['params']):
155 | if p.grad is None:
156 | continue
157 | grad = p.grad.data
158 | if grad.is_sparse:
159 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
160 | amsgrad = group['amsgrad']
161 |
162 | state = self.state[p]
163 | #('----- starting a parameter state', state.keys(), 'Length of state', len(state))
164 | # State initialization
165 | if len(state) <= 1:
166 | state['step'] = 0
167 | # Exponential moving average of gradient values
168 | state['exp_avg'] = torch.zeros_like(p.data)
169 | # Last Gradient
170 | state['last_grad'] = torch.zeros_like(p)
171 | #state['previous_grad'] = torch.zeros_like(p)
172 | # Exponential moving average of squared gradient values
173 | state['exp_avg_sq'] = torch.zeros_like(p.data)
174 | if amsgrad:
175 | # Maintains max of all exp. moving avg. of sq. grad. values
176 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
177 | # import pdb
178 | # pdb.set_trace()
179 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
180 | last_grad = state['last_grad']
181 | lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas']
182 | if amsgrad:
183 | max_exp_avg_sq = state['max_exp_avg_sq']
184 | else:
185 | max_exp_avg_sq = 0
186 |
187 | if 'step' in state:
188 | state['step'] += 1
189 | else:
190 | state['step'] = 1
191 | step = state['step']
192 | is_grad_2d = (len(grad.shape) == 2)
193 | exp_avg, exp_avg_sq = self.update_fn(
194 | p,
195 | grad,
196 | exp_avg,
197 | exp_avg_sq,
198 | lr,
199 | wd,
200 | beta1,
201 | beta2,
202 | last_grad,
203 | self.eps,
204 | amsgrad,
205 | max_exp_avg_sq,
206 | step,
207 | gamma,
208 | mars_type=self.mars_type,
209 | is_grad_2d=is_grad_2d,
210 | optimize_1d=self.optimize_1d,
211 | lr_1d_factor=self.lr_1d_factor,
212 | betas_1d=self.betas_1d,
213 | weight_decay_1d=self.weight_decay if self.optimize_1d else self.weight_decay_1d
214 | )
215 | if self.is_approx:
216 | state['last_grad'] = grad
217 | self.step_num = step
218 |
219 | return loss
220 |
221 | @torch.compile
222 | def NewtonSchulz(M, steps=5, eps=1e-7):
223 | a, b, c = (3.4445, -4.7750, 2.0315)
224 | X = M.bfloat16() / (M.norm() + eps)
225 | if M.size(0) > M.size(1):
226 | X = X.T
227 | for _ in range(steps):
228 | A = X @ X.T
229 | B = A @ X
230 | X = a * X + b * B + c * A @ B
231 | if M.size(0) > M.size(1):
232 | X = X.T
233 | return X.to(M.dtype)
234 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MARS/utils/model_CNN.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Type, Union
2 | import importlib
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | def pair(t):
9 | return t if isinstance(t, tuple) else (t, t)
10 |
11 |
12 | def get_activation(activation_f: str) -> Type:
13 | """Get PyTorch activation function by name."""
14 | package_name = "torch.nn"
15 | module = importlib.import_module(package_name)
16 |
17 | activations = [getattr(module, attr) for attr in dir(module)]
18 | activations = [
19 | cls for cls in activations if isinstance(cls, type) and issubclass(cls, nn.Module)
20 | ]
21 | names = [cls.__name__.lower() for cls in activations]
22 |
23 | try:
24 | index = names.index(activation_f.lower())
25 | return activations[index]
26 | except ValueError:
27 | raise NotImplementedError(f"get_activation: {activation_f=} is not yet implemented.")
28 |
29 |
30 | def compute_padding(
31 | input_size: tuple, kernel_size: int | tuple, stride: int | tuple = 1, dilation: int | tuple = 1
32 | ) -> Tuple[int, int]:
33 | """Compute padding for 'same' convolution."""
34 | if len(input_size) == 2:
35 | input_size = (*input_size, 1)
36 | if isinstance(kernel_size, int):
37 | kernel_size = (kernel_size, kernel_size)
38 | if isinstance(stride, int):
39 | stride = (stride, stride)
40 | if isinstance(dilation, int):
41 | dilation = (dilation, dilation)
42 |
43 | input_h, input_w, _ = input_size
44 | kernel_h, kernel_w = kernel_size
45 | stride_h, stride_w = stride
46 | dilation_h, dilation_w = dilation
47 |
48 | # Compute the effective kernel size after dilation
49 | effective_kernel_h = (kernel_h - 1) * dilation_h + 1
50 | effective_kernel_w = (kernel_w - 1) * dilation_w + 1
51 |
52 | # Compute the padding needed for same convolution
53 | pad_h = int(max((input_h - 1) * stride_h + effective_kernel_h - input_h, 0))
54 | pad_w = int(max((input_w - 1) * stride_w + effective_kernel_w - input_w, 0))
55 |
56 | # Compute the padding for each side
57 | pad_top = pad_h // 2
58 | pad_left = pad_w // 2
59 |
60 | return (pad_top, pad_left)
61 |
62 |
63 | class Base(nn.Module):
64 | """Base class for neural network models."""
65 | def __init__(self, **kwargs):
66 | super().__init__()
67 | self.__dict__.update(kwargs)
68 |
69 | @property
70 | def num_params(self):
71 | return sum(p.numel() for p in self.parameters())
72 |
73 | @property
74 | def shapes(self):
75 | return {name: p.shape for name, p in self.named_parameters()}
76 |
77 | def summary(self):
78 | print(self)
79 | print(f"Number of parameters: {self.num_params}")
80 |
81 |
82 | class Network(Base):
83 | """Fully Connected / Convolutional Neural Network
84 |
85 | Args:
86 | n_inputs (Union[List[int], Tuple[int], torch.Size]): Input shape
87 | n_outputs (int): Number of output classes
88 | conv_layers_list (List[dict], optional): List of convolutional layers. Defaults to [].
89 | n_hiddens_list (Union[List, int], optional): List of hidden units. Defaults to 0.
90 | activation_f (str, optional): Activation function. Defaults to "ReLU".
91 | dropout (float, optional): Dropout rate. Defaults to 0.0.
92 |
93 | conv_layers_list dict keys:
94 | filters: int
95 | kernel_size: int
96 | stride: int
97 | dilation: int
98 | padding: int
99 | bias: bool
100 | batch_norm: bool
101 | repeat: int
102 | """
103 | def __init__(
104 | self,
105 | n_inputs: Union[List[int], Tuple[int], torch.Size],
106 | n_outputs: int,
107 | conv_layers_list: List[dict] = [],
108 | n_hiddens_list: Union[List, int] = 0,
109 | activation_f: str = "ReLU",
110 | dropout: float = 0.0,
111 | ):
112 | super().__init__()
113 |
114 | if isinstance(n_hiddens_list, int):
115 | n_hiddens_list = [n_hiddens_list]
116 |
117 | if n_hiddens_list == [] or n_hiddens_list == [0]:
118 | self.n_hidden_layers = 0
119 | else:
120 | self.n_hidden_layers = len(n_hiddens_list)
121 |
122 | activation = get_activation(activation_f)
123 |
124 | # Convert n_inputs to tensor for shape calculations
125 | ni = torch.tensor(n_inputs)
126 |
127 | conv_layers = []
128 | if conv_layers_list:
129 | for conv_layer in conv_layers_list:
130 | n_channels = int(ni[0])
131 |
132 | padding = conv_layer.get(
133 | "padding",
134 | compute_padding( # same padding
135 | tuple(ni.tolist()),
136 | conv_layer["kernel_size"],
137 | conv_layer.get("stride", 1),
138 | conv_layer.get("dilation", 1),
139 | ),
140 | )
141 |
142 | # Add repeated conv blocks
143 | for i in range(conv_layer.get("repeat", 1)):
144 | # Convolutional layer
145 | conv_layers.append(
146 | nn.Conv2d(
147 | n_channels if i == 0 else conv_layer["filters"],
148 | conv_layer["filters"],
149 | conv_layer["kernel_size"],
150 | stride=conv_layer.get("stride", 1),
151 | padding=padding,
152 | dilation=conv_layer.get("dilation", 1),
153 | bias=conv_layer.get("bias", True),
154 | )
155 | )
156 |
157 | # Activation
158 | conv_layers.append(activation())
159 |
160 | # Optional batch norm
161 | if conv_layer.get("batch_norm"):
162 | conv_layers.append(nn.BatchNorm2d(conv_layer["filters"]))
163 |
164 | # Max pooling after each conv block
165 | conv_layers.append(nn.MaxPool2d(2, stride=2))
166 |
167 | # Optional dropout
168 | if dropout > 0:
169 | conv_layers.append(nn.Dropout(dropout))
170 |
171 | # Update input shape for next layer
172 | ni = torch.cat([torch.tensor([conv_layer["filters"]]), ni[1:] // 2])
173 |
174 | self.conv = nn.Sequential(*conv_layers)
175 |
176 | # Fully connected layers
177 | ni = int(torch.prod(ni))
178 | fcn_layers = []
179 | if self.n_hidden_layers > 0:
180 | for _, n_units in enumerate(n_hiddens_list):
181 | fcn_layers.extend([
182 | nn.Linear(ni, n_units),
183 | activation()
184 | ])
185 | if dropout > 0:
186 | fcn_layers.append(nn.Dropout(dropout))
187 | ni = n_units
188 |
189 | self.fcn = nn.Sequential(*fcn_layers)
190 | self.output = nn.Linear(ni, n_outputs)
191 |
192 | def forward(self, x: torch.Tensor) -> torch.Tensor:
193 | x = self.conv(x)
194 | x = x.view(x.size(0), -1)
195 | x = self.fcn(x)
196 | return self.output(x)
197 |
198 | '''ResNet in PyTorch.
199 |
200 | For Pre-activation ResNet, see 'preact_resnet.py'.
201 |
202 | Reference:
203 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
204 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
205 | '''
206 |
207 |
208 |
209 | class BasicBlock(nn.Module):
210 | expansion = 1
211 |
212 | def __init__(self, in_planes, planes, stride=1):
213 | super(BasicBlock, self).__init__()
214 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
215 | self.bn1 = nn.BatchNorm2d(planes)
216 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
217 | self.bn2 = nn.BatchNorm2d(planes)
218 |
219 | self.shortcut = nn.Sequential()
220 | if stride != 1 or in_planes != self.expansion*planes:
221 | self.shortcut = nn.Sequential(
222 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
223 | nn.BatchNorm2d(self.expansion*planes)
224 | )
225 |
226 | def forward(self, x):
227 | out = F.relu(self.bn1(self.conv1(x)))
228 | out = self.bn2(self.conv2(out))
229 | out += self.shortcut(x)
230 | out = F.relu(out)
231 | return out
232 |
233 |
234 | class Bottleneck(nn.Module):
235 | expansion = 4
236 |
237 | def __init__(self, in_planes, planes, stride=1):
238 | super(Bottleneck, self).__init__()
239 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
240 | self.bn1 = nn.BatchNorm2d(planes)
241 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
242 | self.bn2 = nn.BatchNorm2d(planes)
243 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
244 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
245 |
246 | self.shortcut = nn.Sequential()
247 | if stride != 1 or in_planes != self.expansion*planes:
248 | self.shortcut = nn.Sequential(
249 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
250 | nn.BatchNorm2d(self.expansion*planes)
251 | )
252 |
253 | def forward(self, x):
254 | out = F.relu(self.bn1(self.conv1(x)))
255 | out = F.relu(self.bn2(self.conv2(out)))
256 | out = self.bn3(self.conv3(out))
257 | out += self.shortcut(x)
258 | out = F.relu(out)
259 | return out
260 |
261 |
262 | class ResNet(nn.Module):
263 | def __init__(self, block, num_blocks, num_classes=10):
264 | super(ResNet, self).__init__()
265 | self.in_planes = 64
266 |
267 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
268 | self.bn1 = nn.BatchNorm2d(64)
269 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
270 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
271 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
272 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
273 | self.linear = nn.Linear(512*block.expansion, num_classes)
274 |
275 | def _make_layer(self, block, planes, num_blocks, stride):
276 | strides = [stride] + [1]*(num_blocks-1)
277 | layers = []
278 | for stride in strides:
279 | layers.append(block(self.in_planes, planes, stride))
280 | self.in_planes = planes * block.expansion
281 | return nn.Sequential(*layers)
282 |
283 | def forward(self, x):
284 | out = F.relu(self.bn1(self.conv1(x)))
285 | out = self.layer1(out)
286 | out = self.layer2(out)
287 | out = self.layer3(out)
288 | out = self.layer4(out)
289 | out = F.avg_pool2d(out, 4)
290 | out = out.view(out.size(0), -1)
291 | out = self.linear(out)
292 | return out
293 |
294 |
295 | def ResNet18(num_classes = 10):
296 | return ResNet(BasicBlock, [2,2,2,2], num_classes = num_classes)
297 |
298 | def ResNet34(num_classes = 10):
299 | return ResNet(BasicBlock, [3,4,6,3], num_classes = num_classes)
300 |
301 | def ResNet50(num_classes = 10):
302 | return ResNet(Bottleneck, [3,4,6,3], num_classes = num_classes)
303 |
304 | def ResNet101(num_classes = 10):
305 | return ResNet(Bottleneck, [3,4,23,3], num_classes = num_classes)
306 |
307 | def ResNet152(num_classes = 10):
308 | return ResNet(Bottleneck, [3,8,36,3], num_classes = num_classes)
--------------------------------------------------------------------------------
/MARS_M/utils/model_CNN.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Type, Union
2 | import importlib
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | def pair(t):
9 | return t if isinstance(t, tuple) else (t, t)
10 |
11 |
12 | def get_activation(activation_f: str) -> Type:
13 | """Get PyTorch activation function by name."""
14 | package_name = "torch.nn"
15 | module = importlib.import_module(package_name)
16 |
17 | activations = [getattr(module, attr) for attr in dir(module)]
18 | activations = [
19 | cls for cls in activations if isinstance(cls, type) and issubclass(cls, nn.Module)
20 | ]
21 | names = [cls.__name__.lower() for cls in activations]
22 |
23 | try:
24 | index = names.index(activation_f.lower())
25 | return activations[index]
26 | except ValueError:
27 | raise NotImplementedError(f"get_activation: {activation_f=} is not yet implemented.")
28 |
29 |
30 | def compute_padding(
31 | input_size: tuple, kernel_size: int | tuple, stride: int | tuple = 1, dilation: int | tuple = 1
32 | ) -> Tuple[int, int]:
33 | """Compute padding for 'same' convolution."""
34 | if len(input_size) == 2:
35 | input_size = (*input_size, 1)
36 | if isinstance(kernel_size, int):
37 | kernel_size = (kernel_size, kernel_size)
38 | if isinstance(stride, int):
39 | stride = (stride, stride)
40 | if isinstance(dilation, int):
41 | dilation = (dilation, dilation)
42 |
43 | input_h, input_w, _ = input_size
44 | kernel_h, kernel_w = kernel_size
45 | stride_h, stride_w = stride
46 | dilation_h, dilation_w = dilation
47 |
48 | # Compute the effective kernel size after dilation
49 | effective_kernel_h = (kernel_h - 1) * dilation_h + 1
50 | effective_kernel_w = (kernel_w - 1) * dilation_w + 1
51 |
52 | # Compute the padding needed for same convolution
53 | pad_h = int(max((input_h - 1) * stride_h + effective_kernel_h - input_h, 0))
54 | pad_w = int(max((input_w - 1) * stride_w + effective_kernel_w - input_w, 0))
55 |
56 | # Compute the padding for each side
57 | pad_top = pad_h // 2
58 | pad_left = pad_w // 2
59 |
60 | return (pad_top, pad_left)
61 |
62 |
63 | class Base(nn.Module):
64 | """Base class for neural network models."""
65 | def __init__(self, **kwargs):
66 | super().__init__()
67 | self.__dict__.update(kwargs)
68 |
69 | @property
70 | def num_params(self):
71 | return sum(p.numel() for p in self.parameters())
72 |
73 | @property
74 | def shapes(self):
75 | return {name: p.shape for name, p in self.named_parameters()}
76 |
77 | def summary(self):
78 | print(self)
79 | print(f"Number of parameters: {self.num_params}")
80 |
81 |
82 | class Network(Base):
83 | """Fully Connected / Convolutional Neural Network
84 |
85 | Args:
86 | n_inputs (Union[List[int], Tuple[int], torch.Size]): Input shape
87 | n_outputs (int): Number of output classes
88 | conv_layers_list (List[dict], optional): List of convolutional layers. Defaults to [].
89 | n_hiddens_list (Union[List, int], optional): List of hidden units. Defaults to 0.
90 | activation_f (str, optional): Activation function. Defaults to "ReLU".
91 | dropout (float, optional): Dropout rate. Defaults to 0.0.
92 |
93 | conv_layers_list dict keys:
94 | filters: int
95 | kernel_size: int
96 | stride: int
97 | dilation: int
98 | padding: int
99 | bias: bool
100 | batch_norm: bool
101 | repeat: int
102 | """
103 | def __init__(
104 | self,
105 | n_inputs: Union[List[int], Tuple[int], torch.Size],
106 | n_outputs: int,
107 | conv_layers_list: List[dict] = [],
108 | n_hiddens_list: Union[List, int] = 0,
109 | activation_f: str = "ReLU",
110 | dropout: float = 0.0,
111 | ):
112 | super().__init__()
113 |
114 | if isinstance(n_hiddens_list, int):
115 | n_hiddens_list = [n_hiddens_list]
116 |
117 | if n_hiddens_list == [] or n_hiddens_list == [0]:
118 | self.n_hidden_layers = 0
119 | else:
120 | self.n_hidden_layers = len(n_hiddens_list)
121 |
122 | activation = get_activation(activation_f)
123 |
124 | # Convert n_inputs to tensor for shape calculations
125 | ni = torch.tensor(n_inputs)
126 |
127 | conv_layers = []
128 | if conv_layers_list:
129 | for conv_layer in conv_layers_list:
130 | n_channels = int(ni[0])
131 |
132 | padding = conv_layer.get(
133 | "padding",
134 | compute_padding( # same padding
135 | tuple(ni.tolist()),
136 | conv_layer["kernel_size"],
137 | conv_layer.get("stride", 1),
138 | conv_layer.get("dilation", 1),
139 | ),
140 | )
141 |
142 | # Add repeated conv blocks
143 | for i in range(conv_layer.get("repeat", 1)):
144 | # Convolutional layer
145 | conv_layers.append(
146 | nn.Conv2d(
147 | n_channels if i == 0 else conv_layer["filters"],
148 | conv_layer["filters"],
149 | conv_layer["kernel_size"],
150 | stride=conv_layer.get("stride", 1),
151 | padding=padding,
152 | dilation=conv_layer.get("dilation", 1),
153 | bias=conv_layer.get("bias", True),
154 | )
155 | )
156 |
157 | # Activation
158 | conv_layers.append(activation())
159 |
160 | # Optional batch norm
161 | if conv_layer.get("batch_norm"):
162 | conv_layers.append(nn.BatchNorm2d(conv_layer["filters"]))
163 |
164 | # Max pooling after each conv block
165 | conv_layers.append(nn.MaxPool2d(2, stride=2))
166 |
167 | # Optional dropout
168 | if dropout > 0:
169 | conv_layers.append(nn.Dropout(dropout))
170 |
171 | # Update input shape for next layer
172 | ni = torch.cat([torch.tensor([conv_layer["filters"]]), ni[1:] // 2])
173 |
174 | self.conv = nn.Sequential(*conv_layers)
175 |
176 | # Fully connected layers
177 | ni = int(torch.prod(ni))
178 | fcn_layers = []
179 | if self.n_hidden_layers > 0:
180 | for _, n_units in enumerate(n_hiddens_list):
181 | fcn_layers.extend([
182 | nn.Linear(ni, n_units),
183 | activation()
184 | ])
185 | if dropout > 0:
186 | fcn_layers.append(nn.Dropout(dropout))
187 | ni = n_units
188 |
189 | self.fcn = nn.Sequential(*fcn_layers)
190 | self.output = nn.Linear(ni, n_outputs)
191 |
192 | def forward(self, x: torch.Tensor) -> torch.Tensor:
193 | x = self.conv(x)
194 | x = x.view(x.size(0), -1)
195 | x = self.fcn(x)
196 | return self.output(x)
197 |
198 | '''ResNet in PyTorch.
199 |
200 | For Pre-activation ResNet, see 'preact_resnet.py'.
201 |
202 | Reference:
203 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
204 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
205 | '''
206 |
207 |
208 |
209 | class BasicBlock(nn.Module):
210 | expansion = 1
211 |
212 | def __init__(self, in_planes, planes, stride=1):
213 | super(BasicBlock, self).__init__()
214 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
215 | self.bn1 = nn.BatchNorm2d(planes)
216 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
217 | self.bn2 = nn.BatchNorm2d(planes)
218 |
219 | self.shortcut = nn.Sequential()
220 | if stride != 1 or in_planes != self.expansion*planes:
221 | self.shortcut = nn.Sequential(
222 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
223 | nn.BatchNorm2d(self.expansion*planes)
224 | )
225 |
226 | def forward(self, x):
227 | out = F.relu(self.bn1(self.conv1(x)))
228 | out = self.bn2(self.conv2(out))
229 | out += self.shortcut(x)
230 | out = F.relu(out)
231 | return out
232 |
233 |
234 | class Bottleneck(nn.Module):
235 | expansion = 4
236 |
237 | def __init__(self, in_planes, planes, stride=1):
238 | super(Bottleneck, self).__init__()
239 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
240 | self.bn1 = nn.BatchNorm2d(planes)
241 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
242 | self.bn2 = nn.BatchNorm2d(planes)
243 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
244 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
245 |
246 | self.shortcut = nn.Sequential()
247 | if stride != 1 or in_planes != self.expansion*planes:
248 | self.shortcut = nn.Sequential(
249 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
250 | nn.BatchNorm2d(self.expansion*planes)
251 | )
252 |
253 | def forward(self, x):
254 | out = F.relu(self.bn1(self.conv1(x)))
255 | out = F.relu(self.bn2(self.conv2(out)))
256 | out = self.bn3(self.conv3(out))
257 | out += self.shortcut(x)
258 | out = F.relu(out)
259 | return out
260 |
261 |
262 | class ResNet(nn.Module):
263 | def __init__(self, block, num_blocks, num_classes=10):
264 | super(ResNet, self).__init__()
265 | self.in_planes = 64
266 |
267 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
268 | self.bn1 = nn.BatchNorm2d(64)
269 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
270 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
271 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
272 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
273 | self.linear = nn.Linear(512*block.expansion, num_classes)
274 |
275 | def _make_layer(self, block, planes, num_blocks, stride):
276 | strides = [stride] + [1]*(num_blocks-1)
277 | layers = []
278 | for stride in strides:
279 | layers.append(block(self.in_planes, planes, stride))
280 | self.in_planes = planes * block.expansion
281 | return nn.Sequential(*layers)
282 |
283 | def forward(self, x):
284 | out = F.relu(self.bn1(self.conv1(x)))
285 | out = self.layer1(out)
286 | out = self.layer2(out)
287 | out = self.layer3(out)
288 | out = self.layer4(out)
289 | out = F.avg_pool2d(out, 4)
290 | out = out.view(out.size(0), -1)
291 | out = self.linear(out)
292 | return out
293 |
294 |
295 | def ResNet18(num_classes = 10):
296 | return ResNet(BasicBlock, [2,2,2,2], num_classes = num_classes)
297 |
298 | def ResNet34(num_classes = 10):
299 | return ResNet(BasicBlock, [3,4,6,3], num_classes = num_classes)
300 |
301 | def ResNet50(num_classes = 10):
302 | return ResNet(Bottleneck, [3,4,6,3], num_classes = num_classes)
303 |
304 | def ResNet101(num_classes = 10):
305 | return ResNet(Bottleneck, [3,4,23,3], num_classes = num_classes)
306 |
307 | def ResNet152(num_classes = 10):
308 | return ResNet(Bottleneck, [3,8,36,3], num_classes = num_classes)
--------------------------------------------------------------------------------
/MARS_M/README.md:
--------------------------------------------------------------------------------
1 | # MARS-M: When Variance Reduction Meets Matrices
2 |
3 | This repository contains the official code for the paper [MARS-M: When Variance Reduction Meets Matrices](https://arxiv.org/abs/2510.21800).
4 |
5 | Authors: [Yifeng Liu](https://scholar.google.com/citations?user=mFvOVkMAAAAJ&hl=zh-CN)\*, [Angela Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)
6 |
7 | ## 🔔 NEWS
8 | - **[10/20/2025]** Our paper is released on arXiv: https://arxiv.org/abs/2510.21800.
9 | - **[10/05/2025]** Our code is released.
10 |
11 | ## MARS-M
12 |
13 | **MARS-M** is a new optimizer that integrates matrix-based optimizers (i.e., Muon and Moonlight) with the variance-reduction-based optimizer MARS, reducing the high variance of stochastic gradients during training.
14 |
15 | In detail, the **MARS-M** optimizer is built on the **MARS** framework:
16 |
17 | ---
18 |
19 | **Algorithm 1** MARS
20 |
21 | ---
22 |
23 | $$
24 | \begin{align*}
24 | &\pmb{input: }\mathbf{x}_0\in\mathbb{R}^{d}, \lambda, \beta, \{\gamma_t\}, \{\eta_t\}\\
26 | &\text{Set }\mathbf{m}_0\leftarrow \mathbf{0}\textbf{ and }\mathbf{x}_1\leftarrow\mathbf{x}_0\\
27 | &\pmb{for }\textbf{ }t=1,\pmb{ to }\textbf{ }n\textbf{ }\pmb{ do}\\
28 | &\qquad\textbf{sample }\mathbf{\xi}_t\textbf{ and let }\mathbf{c}_t = \nabla f(\mathbf{x}_t, \mathbf{\xi}_t)+\gamma_t\bigg(\frac{\beta}{1-\beta}\bigg)\big(\nabla f(\mathbf{x}_t, \mathbf{\xi}_t)-\nabla f(\mathbf{x}_{t-1}, \mathbf{\xi}_t)\big)\\
29 | &\qquad\mathbf{m}_t = \beta \mathbf{m}_{t-1} + (1-\beta)\text{Clip}(\mathbf{c}_t, 1)\\
30 | &\qquad\mathbf{x}\_{t+1} = \arg\min_{\mathbf{x} \in \mathbb{R}^d} \left\\{\eta_t \left\langle \mathbf{m}_t, \mathbf{x} \right\rangle + \frac{1}{2} \\|\mathbf{x} - \mathbf{x}\_t
31 | \\|\_{\mathbf{H}_t}^2\right\\}\\
32 | &\pmb{end}\textbf{ }\pmb{for}
33 | \end{align*}
34 | $$
35 |
36 | ---
37 |
38 | where
39 |
40 | $$
41 | \text{Clip}(\mathbf{c}_t,1) = \begin{cases}
42 | \frac{\mathbf{c}_t}{\\|\mathbf{c}_t\\|_2} & \text{if } \\|\mathbf{c}_t\\|_2 > 1,\\
43 | \mathbf{c}_t & \text{otherwise}.
44 | \end{cases}
45 | $$
46 |
47 |
48 |
49 | Here ${\color{red}\gamma_t}$ is a scaling parameter that controls the strength of gradient correction and plays a central role in MARS.
50 |
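For a concrete sense of scale: with the default values in `optimizers/mars_m.py` ($\gamma = 0.025$, $\beta = 0.95$), the correction coefficient is $\gamma \cdot \beta / (1-\beta) = 0.025 \cdot 0.95 / 0.05 = 0.475$, so the gradient-difference term enters $\mathbf{c}_t$ with roughly half the weight of the raw gradient.
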
51 | Under the **MARS** framework, we propose **MARS-M**, which applies MARS to matrix-based optimizers (see `optimizers/mars_m.py` for the implementation):
52 |
53 | ---
54 |
55 | **Algorithm 2** MARS-M
56 |
57 | ---
58 |
59 | $$
60 | \begin{align*}
61 | &\pmb{input: }\mathbf{X}_0\in\mathbb{R}^{A\times B}, \lambda, \beta, \{\gamma_t\}, \{\eta_t\}\\
62 | &\text{Set }\mathbf{M}_0\leftarrow \mathbf{0}\textbf{ and }\mathbf{X}_1\leftarrow\mathbf{X}_0\\
63 | &\pmb{for }\textbf{ }t=1,\pmb{ to }\textbf{ }n\textbf{ }\pmb{ do}\\
64 | &\qquad\textbf{sample }\mathbf{\xi}_t\textbf{ and let }\mathbf{C}_t = \nabla f(\mathbf{X}_t, \mathbf{\xi}_t)+\gamma_t\bigg(\frac{\beta}{1-\beta}\bigg)\big(\nabla f(\mathbf{X}_t, \mathbf{\xi}_t)-\nabla f(\mathbf{X}_{t-1}, \mathbf{\xi}_t)\big)\\
65 | &\qquad\mathbf{M}_t = \beta \mathbf{M}_{t-1} + (1-\beta)\text{Clip}(\mathbf{C}_t, 1)\\
66 | &\qquad\mathbf{O}_t = \text{NewtonSchulz}(\mathbf{M}_t)\\
67 | &\qquad\mathbf{X}_{t+1} = \mathbf{X}_t - \eta_t(0.2\cdot\mathbf{O}_t\cdot\sqrt{\max(A,B)} + \lambda \mathbf{X}_t)\\
68 | &\pmb{end}\textbf{ }\pmb{for}
69 | \end{align*}
70 | $$
71 |
72 | ---
73 |
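As a reading aid, here is a minimal PyTorch sketch of a single MARS-M update for one weight matrix, mirroring the logic in `optimizers/mars_m.py`. The helper names `newton_schulz` and `mars_m_update` are illustrative only and are not part of the repository's API:

```python
import torch

def newton_schulz(M, steps=5, eps=1e-7):
    # Quintic Newton-Schulz orthogonalization (same coefficients as optimizers/mars_m.py).
    a, b, c = 3.4445, -4.7750, 2.0315
    X = M.bfloat16()
    if M.size(-2) > M.size(-1):
        X = X.transpose(-2, -1)
    X = X / (X.norm() + eps)                        # keep the spectral norm at most ~1
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        X = a * X + (b * A + c * A @ A) @ X
    if M.size(-2) > M.size(-1):
        X = X.transpose(-2, -1)
    return X.to(M.dtype)

@torch.no_grad()
def mars_m_update(W, grad, last_grad, buf, lr=1e-3, wd=0.1, beta=0.95, gamma=0.025):
    # C_t = g_t + gamma * beta / (1 - beta) * (g_t - g_{t-1}), clipped to unit Frobenius norm.
    c_t = (grad - last_grad).mul(gamma * beta / (1 - beta)).add(grad)
    if c_t.norm() > 1:
        c_t = c_t / c_t.norm()
    buf.mul_(beta).add_(c_t, alpha=1 - beta)        # M_t = beta * M_{t-1} + (1 - beta) * Clip(C_t, 1)
    o_t = newton_schulz(buf)                        # O_t = NewtonSchulz(M_t)
    scaled_lr = lr * 0.2 * max(W.shape[:2]) ** 0.5  # 0.2 * sqrt(max(A, B)) scaling from Algorithm 2
    W.mul_(1 - lr * wd)                             # decoupled weight decay (lambda)
    W.add_(o_t, alpha=-scaled_lr)                   # X_{t+1} = X_t - eta_t * (scaled O_t + lambda * X_t)

# Example: one step on a random 256 x 128 matrix.
W = torch.randn(256, 128)
g, g_prev = torch.randn_like(W), torch.randn_like(W)
momentum_buf = torch.zeros_like(W)
mars_m_update(W, g, g_prev, momentum_buf)
```

In the full optimizer, `last_grad` is the stored gradient from the previous step (the approximate variant reuses it directly, avoiding a second backward pass), and non-matrix parameters are handled by an AdamW-style branch.
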
74 | To accelerate the training process, we also propose an approximate version of MARS-M by substituting $\nabla f(\mathbf{X}\_{t-1}, \mathbf{\xi}\_t)$ with $\nabla f(\mathbf{X}\_{t-1}, \mathbf{\xi}\_{t-1})$ as follows:
75 |
76 | ---
77 |
78 | **Algorithm 3** MARS-M-approx
79 |
80 | ---
81 |
82 | $$
83 | \begin{align*}
84 | &\pmb{input: }\mathbf{X}_0\in\mathbb{R}^{A\times B}, \lambda, \beta, \{\gamma_t\}, \{\eta_t\}\\
85 | &\text{Set }\mathbf{M}_0\leftarrow \mathbf{0}\textbf{ and }\mathbf{X}_1\leftarrow\mathbf{X}_0\\
86 | &\pmb{for }\textbf{ }t=1,\pmb{ to }\textbf{ }n\textbf{ }\pmb{ do}\\
87 | &\qquad\textbf{sample }\mathbf{\xi}_t\textbf{ and let }\mathbf{C}_t = \nabla f(\mathbf{X}_t, \mathbf{\xi}_t)+\gamma_t\bigg(\frac{\beta}{1-\beta}\bigg)\big(\nabla f(\mathbf{X}_t, \mathbf{\xi}_t)-\nabla f(\mathbf{X}_{t-1}, \mathbf{\xi}_{t-1})\big)\\
88 | &\qquad\mathbf{M}_t = \beta \mathbf{M}_{t-1} + (1-\beta)\text{Clip}(\mathbf{C}_t, 1)\\
89 | &\qquad\mathbf{O}_t = \text{NewtonSchulz}(\mathbf{M}_t)\\
90 | &\qquad\mathbf{X}_{t+1} = \mathbf{X}_t - \eta_t(0.2\cdot\mathbf{O}_t\cdot\sqrt{\max(A,B)} + \lambda \mathbf{X}_t)\\
91 | &\pmb{end}\textbf{ }\pmb{for}
92 | \end{align*}
93 | $$
94 |
95 | ---
96 |
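In code, the approximate version drops into an ordinary training loop with a single backward pass per step. Below is a minimal sketch, assuming it is run from the `MARS_M` directory so that `optimizers.mars_m` is importable; the toy model, the split of parameters by dimensionality, and the hyperparameter values are illustrative only (the GPT-2 training scripts build the parameter groups from the model definition):

```python
import torch
from optimizers.mars_m import MARS_M  # assumes the MARS_M directory is the working directory

model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10))
matrix_params = [p for p in model.parameters() if p.ndim >= 2]  # updated with the MARS-M matrix rule
other_params = [p for p in model.parameters() if p.ndim < 2]    # biases etc. use the AdamW branch

optimizer = MARS_M(lr=1e-3, wd=0.1, muon_params=matrix_params,
                   adamw_params=other_params, gamma=0.025, is_approx=True)

criterion = torch.nn.CrossEntropyLoss()
for _ in range(10):  # toy random data; replace with a real dataloader
    x, y = torch.randn(32, 128), torch.randint(0, 10, (32,))
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```
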
97 | ### **Performance of MARS-M Compared to the Muon (Moonlight) and AdamW Baselines**
98 |
99 | #### Experiment Settings
100 |
101 | We perform a grid search over learning rates for AdamW and Muon (Moonlight), and use the same hyperparameters as Muon (Moonlight) for the experiments with **MARS-M**.
102 |
103 | #### Experiments on OpenWebText
104 |
105 | In our experiments, gradients are computed once per sample and per update (**MARS-M**-approx). Exact gradient computation with two evaluations per update, as in the exact form of **MARS-M**, can slightly improve performance, but doubles the computational cost. Moreover, **MARS-M** also outperforms AdamW in terms of the best achieved loss.
106 |
107 | **MARS-M** consistently outperforms the [Muon (Moonlight version)](https://arxiv.org/abs/2502.16982) optimizer across GPT-2 model sizes:
108 |
109 | | **GPT-2 small** | **GPT-2 medium** | **GPT-2 large** |
110 | | ------------------------------------------------ | ------------------------------------------------- | ------------------------------------------------ |
111 | | *(validation loss curve)* | *(validation loss curve)* | *(validation loss curve)* |
112 |
113 | ---
114 |
115 | **Zoomed-in loss curves**
116 | | **GPT-2 small** | **GPT-2 medium** | **GPT-2 large** |
117 | | ------------------------------------------------ | ------------------------------------------------- | ------------------------------------------------ |
118 | | *(zoomed-in validation loss curve)* | *(zoomed-in validation loss curve)* | *(zoomed-in validation loss curve)* |
119 |
120 | #### Experiments on FineWeb-Edu
121 |
122 | Below are the training and validation loss curves for GPT-2 Small and GPT-2 XL trained with **MARS-M** versus the [Muon (Moonlight version)](https://arxiv.org/abs/2502.16982) optimizer. MARS-M yields faster convergence and consistently lower losses throughout training, and it also outperforms AdamW in terms of the best achieved loss.
123 |
124 | | Model | **GPT-2 small** | **GPT-2 XL** |
125 | | ------------------------- | -------------------------------------------------- | ----------------------------------------------- |
126 | | **Training Loss** | *(training loss curve)* | *(training loss curve)* |
127 | | **Validation Loss** | *(validation loss curve)* | *(validation loss curve)* |
128 |
129 | ---
130 |
131 | **Zoomed-in loss curves**
132 | | Model | **GPT-2 small** | **GPT-2 XL** |
133 | | ------------------------- | -------------------------------------------------- | ----------------------------------------------- |
134 | | **Training Loss** | *(zoomed-in training loss curve)* | *(zoomed-in training loss curve)* |
135 | | **Validation Loss** | *(zoomed-in validation loss curve)* | *(zoomed-in validation loss curve)* |
136 |
137 | ## Training GPT-2 from Scratch
138 |
139 | ### Install Dependencies
140 |
141 | ```
142 | $ pip install torch==2.1.2 transformers==4.33.0 datasets tiktoken numpy==1.26.4 wandb
143 | ```
144 |
145 | ### Data Preparation
146 |
147 | Prepare the [OpenWebText](https://huggingface.co/datasets/openwebtext) data following [nanoGPT](https://github.com/karpathy/nanoGPT/):
148 |
149 | ```
150 | $ python data/openwebtext/prepare.py
151 | ```
152 |
153 | ### **Start Training**
154 |
155 | To train a model using the **MARS-M** optimizer, run the following command:
156 |
157 | ```bash
158 | $ torchrun --standalone --nproc_per_node=8 train_mars_m.py config/${your_config_file}
159 | ```
160 |
161 | This command trains a GPT-2 model on the OpenWebText dataset with the **MARS-M** optimizer. All relevant hyperparameters (training, model, and optimizer) are specified in the configuration file (`${your_config_file}`); they can be adjusted directly in the configuration file, through the bash scripts, or via command-line overrides.
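For instance, individual keys defined in the training script can be overridden with `--key=value` flags; the values below are simply the GPT-2 Small defaults quoted later in this README:

```bash
$ torchrun --standalone --nproc_per_node=8 train_mars_m.py \
    config/train_gpt2_small_mars_m.py \
    --learning_rate=6e-3 --batch_size=15 --gradient_accumulation_steps=4
```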
162 |
163 | ### **Hyperparameter Details**
164 |
165 | #### **Model Hyperparameters**:
166 |
167 | - **`n_layer`**: Number of transformer layers: 12 for GPT-2 Small, 24 for GPT-2 Medium, 36 for GPT-2 Large
168 | - **`n_head`**: Number of attention heads: 12 for GPT-2 Small, 16 for GPT-2 Medium, 20 for GPT-2 Large
169 | - **`n_embd`**: Embedding dimension: 768 for GPT-2 Small, 1024 for GPT-2 Medium, 1280 for GPT-2 Large
170 |
171 | #### **Optimizer Hyperparameters**:
172 |
173 | - **`learning_rate`**: Learning rate for the **MARS-M** optimizer.
174 | - **`weight_decay`**: Weight decay for the **MARS-M** optimizer.
175 | - **`beta1`**: Momentum parameter for the **MARS-M** optimizer.
176 |
177 | - Default: `beta1=0.95, beta2=0.99`
178 | - **`betas_1d`**: Exponential-moving-average weights for the AdamW optimizer applied to 1-D parameters.
179 |
180 | - Default: `(0.9, 0.95)`
181 | - **`is_approx`**: Whether to use approximate gradient calculation (**MARS-M**-approx).
182 |
183 | - Default: `True`
184 | - **`gamma`**: The scaling parameter that controls the strength of gradient correction.
185 |
186 |   - Default: `0.025`
187 |
188 | #### **Training Hyperparameters**:
189 |
190 | - **`batch_size`**: Mini-batch size per device (for example, GPT-2 Small on an A100 GPU typically uses a batch size of 15).
191 | - **`gradient_accumulation_steps`**: Number of gradient accumulation steps, chosen so that the total effective batch size matches the desired scale (for example, $15 \times 4 \times 8 \, \text{GPUs} = 480$).
192 | - **`schedule`**: Learning rate schedule.
193 | - Default: `cosine`
194 |
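For orientation, a minimal GPT-2 Small configuration with the hyperparameters above might look like the sketch below (values are the defaults quoted in this README; this is not a copy of the shipped `config/train_gpt2_small_mars_m.py`):

```python
# illustrative MARS-M config sketch for GPT-2 Small (not the shipped config file)
# model
n_layer = 12
n_head = 12
n_embd = 768
# optimizer
learning_rate = 6e-3              # OpenWebText LR for GPT-2 Small
weight_decay = 1e-1
beta1 = 0.95
beta2 = 0.99
betas_1d = (0.9, 0.95)            # AdamW EMA weights for 1-D parameters
is_approx = True                  # MARS-M-approx: one gradient evaluation per update
gamma = 0.025
# training
batch_size = 15
gradient_accumulation_steps = 4   # 15 x 4 x 8 GPUs = 480 total batch size
schedule = 'cosine'
```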
195 | For more detailed hyperparameter examples, refer to:
196 |
197 | - `config/train_gpt2_small_mars_m.py`
198 | - `scripts/run_mars_m_small.sh`
199 |
200 | ---
201 |
202 | ### Reproducing Our Results
203 |
204 | #### **Reproducing GPT-2 Small (125M) Results**
205 |
206 | Train with MARS-M using:
207 |
208 | ```
209 | $ bash scripts/run_mars_m_small.sh
210 | ```
211 |
212 | or
213 |
214 | ```
215 | $ torchrun --standalone --nproc_per_node=8 \
216 | train_mars_m.py \
217 | config/train_gpt2_small_mars_m.py \
218 | --batch_size=15 \
219 | --gradient_accumulation_steps=4
220 | ```
221 |
222 | #### **Reproducing GPT-2 Medium (355M) Results**
223 |
224 | Train with MARS-M using:
225 |
226 | ```
227 | $ bash scripts/run_mars_m_medium.sh
228 | ```
229 |
230 | or
231 |
232 | ```
233 | $ torchrun --standalone --nproc_per_node=8 \
234 | train_mars_m.py \
235 | config/train_gpt2_medium_mars_m.py \
236 | --batch_size=15 \
237 | --gradient_accumulation_steps=4
238 | ```
239 |
240 | #### **Reproducing GPT-2 Large (770M) Results**
241 |
242 | Train with MARS-M using:
243 |
244 | ```
245 | $ bash scripts/run_mars_m_large.sh
246 | ```
247 |
248 | or
249 |
250 | ```
251 | $ torchrun --standalone --nproc_per_node=8 \
252 | train_mars_m.py \
253 | config/train_gpt2_large_mars_m.py \
254 | --batch_size=5 \
255 | --gradient_accumulation_steps=12
256 | ```
257 |
258 | #### **Reproducing GPT-2 XL (1.5B) Results on FineWeb-Edu**
259 |
260 | ```
261 | $ bash scripts/run_mars_m_xl_fw.sh
262 | ```
263 |
264 | or
265 |
266 | ```
267 | $ torchrun --standalone --nproc_per_node=8 \
268 | train_mars_m_fw.py \
269 | config/train_gpt2_xl_mars_m.py \
270 | --batch_size=5 \
271 | --gradient_accumulation_steps=12
272 | ```
273 |
274 | #### Reproducing Baseline Results
275 |
276 | To reproduce the Moonlight baseline:
277 |
278 | ```
279 | $ bash scripts/run_moonlight_{small/medium/large}.sh
280 | ```
281 |
282 | Other baselines can be run with the code in the `../MARS` folder.
283 |
284 | Please adjust ``nproc_per_node``, ``batch_size``, and ``gradient_accumulation_steps`` accordingly if you use a different hardware setup, and make sure their product equals 480.
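As a quick sanity check, the default commands in this README satisfy that constraint:

```python
# total batch size = nproc_per_node * batch_size * gradient_accumulation_steps
assert 8 * 15 * 4 == 480   # GPT-2 Small / Medium commands above
assert 8 * 5 * 12 == 480   # GPT-2 Large / XL commands above
```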
285 |
286 | #### Hyperparameters for GPT-2 models
287 |
288 | | Model Name | Model Size | OpenWebText LR | FineWeb-Edu LR | weight decay |
289 | | :----------: | :--------: | :------------: | :------------: | :----------: |
290 | | GPT-2 small | 125M | 6e-3 | 1e-2 | 1e-1 |
291 | | GPT-2 medium | 355M | 5e-3 | 5e-3 | 1e-1 |
292 | | GPT-2 large | 770M | 5e-3 | 5e-3 | 1e-1 |
293 | | GPT-2 xl | 1.5B | - | 3e-3 | 1e-1 |
294 |
295 | ### Customized Training
296 |
297 | To build your own training pipeline on other architectures and datasets, use the following template as an example:
298 |
299 | ```python
300 | import torch
301 | import torch.nn.functional as F
302 | from mars_m import MARS_M
303 |
304 | # init the model, loss function, and input data (placeholders)
305 | model = Model()
306 | data_loader = ...
307 |
308 | # init the optimizer
309 | muon_params = [p for name, p in model.named_parameters() if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name]
310 | adamw_params = [p for name, p in model.named_parameters() if p.ndim < 2 or "embed_tokens" in name or "lm_head" in name]
311 | optimizer = MARS_M(muon_params=muon_params, adamw_params=adamw_params, lr=1e-3, betas=(0.9, 0.95), gamma=0.025)
312 |
313 | total_bs = len(data_loader)
314 | bs = total_bs * block_size
315 | k = 10
316 | iter_num = -1
317 |
318 | # training loop
319 | for epoch in range(epochs):
320 | for X, Y in data_loader:
321 | # standard training code
322 | logits, loss = model(X, Y)
323 | loss.backward()
324 | optimizer.step(bs=bs)
325 | optimizer.zero_grad(set_to_none=True)
326 |         optimizer.update_last_grad()  # cache current gradients as the previous-step gradients for the next correction
327 | iter_num += 1
328 |
329 | ```
330 |
331 | ## Citation
332 |
333 | If you find this repo useful for your research, please consider citing our GitHub repository:
334 |
335 | ```tex
336 | @misc{liu2025MARS,
337 | author = {Yifeng Liu and Angela Yuan and Quanquan Gu},
338 | title = {MARS-M: When Variance Reduction Meets Matrices},
339 | year = {2025},
340 | url = {https://github.com/AGI-Arena/MARS/tree/main/MARS_M/}
341 | }
342 | ```
343 |
344 | ## Acknowledgements
345 |
346 | This repo is built upon [nanoGPT](https://github.com/karpathy/nanoGPT/), [levanter](https://github.com/stanford-crfm/levanter/), and [Sophia](https://github.com/Liuhong99/Sophia); we thank the authors for their great work!
347 |
--------------------------------------------------------------------------------
/MARS/train_adamw.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from https://github.com/Liuhong99/Sophia/blob/main/train_adam.py
3 | """
4 | import os
5 | import time
6 | import math
7 | import pickle
8 | from contextlib import nullcontext
9 |
10 | import numpy as np
11 | import torch
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | from torch.distributed import init_process_group, destroy_process_group
14 |
15 | from model import GPTConfig, GPT
16 | import sys
17 | from ast import literal_eval
18 | # -----------------------------------------------------------------------------
19 | # default config values designed to train a gpt2 (124M) on OpenWebText
20 | # I/O
21 | out_dir = 'out'
22 | eval_interval = 2000
23 | log_interval = 1
24 | eval_iters = 200
25 | eval_only = False # if True, script exits right after the first eval
26 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
27 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
28 | # wandb logging
29 | wandb_log = False # disabled by default
30 | wandb_project = 'mars'
31 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
32 | # data
33 | dataset = 'openwebtext'
34 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes
35 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
36 | block_size = 1024
37 | # model
38 | n_layer = 12
39 | n_head = 12
40 | n_embd = 768
41 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
42 | bias = False # do we use bias inside LayerNorm and Linear layers?
43 | # optimizer
44 | optimizer_name = 'adamw'
45 | learning_rate = 6e-4 # max learning rate
46 | max_iters = 600000 # total number of training iterations
47 | weight_decay = 1e-1
48 | beta1 = 0.9
49 | beta2 = 0.95
50 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
51 | interval = 10
52 | variant = 4
53 | schedule='cosine'
54 | # learning rate decay settings
55 | decay_lr = True # whether to decay the learning rate
56 | warmup_iters = 2000 # how many steps to warm up for
57 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
58 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
59 | # DDP settings
60 | backend = 'nccl' # 'nccl', 'gloo', etc.
61 | # system
62 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
63 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
64 | compile = True # use PyTorch 2.0 to compile the model to be faster
65 | scale_attn_by_inverse_layer_idx = True
66 | # -----------------------------------------------------------------------------
67 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
68 | for arg in sys.argv[1:]:
69 | if '=' not in arg:
70 | # assume it's the name of a config file
71 | assert not arg.startswith('--')
72 | config_file = arg
73 | print(f"Overriding config with {config_file}:")
74 | with open(config_file) as f:
75 | print(f.read())
76 | exec(open(config_file).read())
77 | else:
78 | # assume it's a --key=value argument
79 | assert arg.startswith('--')
80 | key, val = arg.split('=')
81 | key = key[2:]
82 | if key in globals():
83 | try:
84 |                 # attempt to eval it (e.g. bool, number, etc.)
85 | attempt = literal_eval(val)
86 | except (SyntaxError, ValueError):
87 | # if that goes wrong, just use the string
88 | attempt = val
89 | # ensure the types match ok
90 | assert type(attempt) == type(globals()[key])
91 | # cross fingers
92 | print(f"Overriding: {key} = {attempt}")
93 | globals()[key] = attempt
94 | else:
95 | raise ValueError(f"Unknown config key: {key}")
96 |
97 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
98 | # -----------------------------------------------------------------------------
99 |
100 | # various inits, derived attributes, I/O setup
101 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
102 | if ddp:
103 | init_process_group(backend=backend)
104 | ddp_rank = int(os.environ['RANK'])
105 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
106 | device = f'cuda:{ddp_local_rank}'
107 | torch.cuda.set_device(device)
108 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
109 | seed_offset = ddp_rank # each process gets a different seed
110 | else:
111 | # if not ddp, we are running on a single gpu, and one process
112 | master_process = True
113 | seed_offset = 0
114 | gradient_accumulation_steps *= 8 # simulate 8 gpus
115 |
116 | if master_process:
117 | os.makedirs(out_dir, exist_ok=True)
118 | torch.manual_seed(5000 + seed_offset)
119 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
120 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
121 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
122 | # note: float16 data type will automatically use a GradScaler
123 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
124 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)
125 |
126 | # poor man's data loader
127 | data_dir = os.path.join('data', dataset)
128 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
129 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
130 | def get_batch(split):
131 | data = train_data if split == 'train' else val_data
132 | ix = torch.randint(len(data) - block_size, (batch_size,))
133 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
134 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
135 | if device_type == 'cuda':
136 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
137 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
138 | else:
139 | x, y = x.to(device), y.to(device)
140 | return x, y
141 |
142 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
143 | iter_num = 0
144 | best_val_loss = 1e9
145 |
146 | # attempt to derive vocab_size from the dataset
147 | meta_path = os.path.join(data_dir, 'meta.pkl')
148 | meta_vocab_size = None
149 | if os.path.exists(meta_path):
150 | with open(meta_path, 'rb') as f:
151 | meta = pickle.load(f)
152 | meta_vocab_size = meta['vocab_size']
153 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
154 |
155 | # model init
156 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
157 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line
158 | if init_from == 'scratch':
159 | # init a new model from scratch
160 | print("Initializing a new model from scratch")
161 | # determine the vocab size we'll use for from-scratch training
162 | if meta_vocab_size is None:
163 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
164 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
165 | gptconf = GPTConfig(**model_args)
166 | model = GPT(gptconf)
167 | elif init_from == 'resume':
168 | print(f"Resuming training from {out_dir}")
169 | # resume training from a checkpoint.
170 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
171 | checkpoint = torch.load(ckpt_path, map_location=device)
172 | checkpoint_model_args = checkpoint['model_args']
173 | # force these config attributes to be equal otherwise we can't even resume training
174 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
175 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
176 | model_args[k] = checkpoint_model_args[k]
177 | # create the model
178 | gptconf = GPTConfig(**model_args)
179 | model = GPT(gptconf)
180 | state_dict = checkpoint['model']
181 | # fix the keys of the state dictionary :(
182 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
183 | unwanted_prefix = '_orig_mod.'
184 | for k,v in list(state_dict.items()):
185 | if k.startswith(unwanted_prefix):
186 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
187 | model.load_state_dict(state_dict)
188 | iter_num = checkpoint['iter_num']
189 | best_val_loss = checkpoint['best_val_loss']
190 | elif init_from.startswith('gpt2'):
191 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
192 | # initialize from OpenAI GPT-2 weights
193 | override_args = dict(dropout=dropout)
194 | model = GPT.from_pretrained(init_from, override_args)
195 | # read off the created config params, so we can store them into checkpoint correctly
196 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
197 | model_args[k] = getattr(model.config, k)
198 | # crop down the model block size if desired, using model surgery
199 | if block_size < model.config.block_size:
200 | model.crop_block_size(block_size)
201 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
202 | model.to(device)
203 |
204 | # initialize a GradScaler. If enabled=False scaler is a no-op
205 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
206 |
207 | # optimizer
208 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), device_type)
209 | if init_from == 'resume':
210 | optimizer.load_state_dict(checkpoint['optimizer'])
211 | del state_dict
212 | del checkpoint
213 | # compile the model
214 | if compile:
215 | print("compiling the model... (takes a ~minute)")
216 | unoptimized_model = model
217 | model = torch.compile(model) # requires PyTorch 2.0
218 |
219 | # wrap model into DDP container
220 | if ddp:
221 | model = DDP(model, device_ids=[ddp_local_rank])
222 |
223 | # helps estimate an arbitrarily accurate loss over either split using many batches
224 | @torch.no_grad()
225 | def estimate_loss():
226 | out = {}
227 | model.eval()
228 | for split in ['train', 'val']:
229 | losses = torch.zeros(eval_iters)
230 | for k in range(eval_iters):
231 | X, Y = get_batch(split)
232 | with ctx:
233 | logits, loss = model(X, Y)
234 | losses[k] = loss.item()
235 | out[split] = losses.mean()
236 | model.train()
237 | return out
238 |
239 | # learning rate decay scheduler (cosine with warmup)
240 | def get_lr(it, schedule='cosine'):
241 |     # returns the learning rate for iteration `it` under the given schedule
242 | # 1) linear warmup for warmup_iters steps
243 | if it < warmup_iters:
244 | return learning_rate * it / warmup_iters
245 | # 2) if it > lr_decay_iters, return min learning rate
246 | if it > lr_decay_iters:
247 | return min_lr
248 | # 3) in between, use cosine decay down to min learning rate
249 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
250 | assert 0 <= decay_ratio <= 1
251 | if schedule=='cosine':
252 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
253 | elif schedule=='exp':
254 | coeff = np.power(0.9, 100 * decay_ratio)
255 | return min_lr + coeff * (learning_rate - min_lr)
256 |
257 | # logging
258 | if wandb_log and master_process:
259 | import wandb
260 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
261 |
262 | # training loop
263 | X, Y = get_batch('train') # fetch the very first batch
264 | t0 = time.time()
265 | local_iter_num = 0 # number of iterations in the lifetime of this process
266 | raw_model = model.module if ddp else model # unwrap DDP container if needed
267 | running_mfu = -1.0
268 | clip_time = 0
269 | while True:
270 |
271 | # determine and set the learning rate for this iteration
272 | lr = get_lr(iter_num, schedule=schedule) if decay_lr else learning_rate
273 | for param_group in optimizer.param_groups:
274 | param_group['lr'] = lr
275 |
276 | # evaluate the loss on train/val sets and write checkpoints
277 | if iter_num % eval_interval == 0 and master_process:
278 | losses = estimate_loss()
279 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
280 | if wandb_log:
281 | wandb.log({
282 | "iter": iter_num,
283 | "train/loss": losses['train'],
284 | "val/loss": losses['val'],
285 | "lr": lr,
286 | "mfu": running_mfu*100, # convert to percentage
287 | }, step=iter_num)
288 | if losses['val'] < best_val_loss or always_save_checkpoint:
289 | best_val_loss = losses['val']
290 | if iter_num > 0:
291 | checkpoint = {
292 | 'model': raw_model.state_dict(),
293 | 'optimizer': optimizer.state_dict(),
294 | 'model_args': model_args,
295 | 'iter_num': iter_num,
296 | 'best_val_loss': best_val_loss,
297 | 'config': config,
298 | }
299 | print(f"saving checkpoint to {out_dir}")
300 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
301 | if iter_num % (eval_interval * 5) == 0:
302 | checkpoint = {
303 | 'model': raw_model.state_dict(),
304 | 'optimizer': optimizer.state_dict(),
305 | 'model_args': model_args,
306 | 'iter_num': iter_num,
307 | 'best_val_loss': best_val_loss,
308 | 'config': config,
309 | }
310 | print(f"saving checkpoint to {out_dir}")
311 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt'))
312 | if iter_num == 0 and eval_only:
313 | break
314 |
315 | # forward backward update, with optional gradient accumulation to simulate larger batch size
316 | # and using the GradScaler if data type is float16
317 | for micro_step in range(gradient_accumulation_steps):
318 | if ddp:
319 | # in DDP training we only need to sync gradients at the last micro step.
320 | # the official way to do this is with model.no_sync() context manager, but
321 | # I really dislike that this bloats the code and forces us to repeat code
322 | # looking at the source of that context manager, it just toggles this variable
323 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
324 | with ctx:
325 | logits, loss = model(X, Y)
326 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
327 | X, Y = get_batch('train')
328 | # backward pass, with gradient scaling if training in fp16
329 | scaler.scale(loss).backward()
330 | # clip the gradient
331 | if grad_clip != 0.0:
332 | scaler.unscale_(optimizer)
333 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
334 | if total_norm.item() > grad_clip:
335 | clip_time += 1
336 | # step the optimizer and scaler if training in fp16
337 | scaler.step(optimizer)
338 | scaler.update()
339 | # flush the gradients as soon as we can, no need for this memory anymore
340 | optimizer.zero_grad(set_to_none=True)
341 |
342 | # timing and logging
343 | t1 = time.time()
344 | dt = t1 - t0
345 | t0 = t1
346 | if iter_num % log_interval == 0 and master_process:
347 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
348 | if local_iter_num >= 5: # let the training loop settle a bit
349 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
350 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
351 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
352 | params = []
353 | for (name, p) in model.named_parameters():
354 | params.append(p)
355 | total_param_norm = 0
356 | for p in params:
357 | param_norm = p.data.norm(2)
358 | total_param_norm += param_norm.item() ** 2
359 | total_param_norm = total_param_norm ** 0.5
360 | momentum_norm = 0
361 | LL = len(optimizer.state_dict()['state'])
362 | for jj in range(LL):
363 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2
364 | momentum_norm = torch.sqrt(momentum_norm).item()
365 | if wandb_log:
366 | wandb.log({
367 | "iter": iter_num,
368 | "train/loss": lossf,
369 | "lr": lr,
370 | "param_norm": total_param_norm,
371 | "momentum_norm" : momentum_norm,
372 | "train/clip_rate": clip_time / (iter_num + 1)
373 | }, step=iter_num)
374 | iter_num += 1
375 | local_iter_num += 1
376 |
377 | # termination conditions
378 | if iter_num > max_iters:
379 | break
380 |
381 | if ddp:
382 | destroy_process_group()
383 |
--------------------------------------------------------------------------------
/MARS/train_adamw_fw.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import pickle
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from torch.distributed import init_process_group, destroy_process_group
11 |
12 | from model import GPTConfig, GPT
13 | import sys
14 | from ast import literal_eval
15 | # -----------------------------------------------------------------------------
16 | # default config values designed to train a gpt2 (124M) on OpenWebText
17 | # I/O
18 | out_dir = 'out'
19 | eval_interval = 2000
20 | log_interval = 1
21 | eval_iters = 200
22 | eval_only = False # if True, script exits right after the first eval
23 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
24 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
25 | # wandb logging
26 | wandb_log = False # disabled by default
27 | wandb_project = 'mars'
28 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
29 | # data
30 | dataset = 'fineweb-edu100B'
31 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes
32 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
33 | block_size = 1024
34 | # model
35 | n_layer = 12
36 | n_head = 12
37 | n_embd = 768
38 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
39 | bias = False # do we use bias inside LayerNorm and Linear layers?
40 | # optimizer
41 | optimizer_name = 'adamw'
42 | learning_rate = 6e-4 # max learning rate
43 | max_iters = 600000 # total number of training iterations
44 | weight_decay = 1e-1
45 | beta1 = 0.9
46 | beta2 = 0.95
47 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
48 | interval = 10
49 | variant = 4
50 | schedule='cosine'
51 | # learning rate decay settings
52 | decay_lr = True # whether to decay the learning rate
53 | warmup_iters = 2000 # how many steps to warm up for
54 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
55 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
56 | # DDP settings
57 | backend = 'nccl' # 'nccl', 'gloo', etc.
58 | # system
59 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
60 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
61 | compile = True # use PyTorch 2.0 to compile the model to be faster
62 | scale_attn_by_inverse_layer_idx = True
63 | # -----------------------------------------------------------------------------
64 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
65 | for arg in sys.argv[1:]:
66 | if '=' not in arg:
67 | # assume it's the name of a config file
68 | assert not arg.startswith('--')
69 | config_file = arg
70 | print(f"Overriding config with {config_file}:")
71 | with open(config_file) as f:
72 | print(f.read())
73 | exec(open(config_file).read())
74 | else:
75 | # assume it's a --key=value argument
76 | assert arg.startswith('--')
77 | key, val = arg.split('=')
78 | key = key[2:]
79 | if key in globals():
80 | try:
81 | # attempt to eval it it (e.g. if bool, number, or etc)
82 | attempt = literal_eval(val)
83 | except (SyntaxError, ValueError):
84 | # if that goes wrong, just use the string
85 | attempt = val
86 | # ensure the types match ok
87 | assert type(attempt) == type(globals()[key])
88 | # cross fingers
89 | print(f"Overriding: {key} = {attempt}")
90 | globals()[key] = attempt
91 | else:
92 | raise ValueError(f"Unknown config key: {key}")
93 |
94 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
95 | # -----------------------------------------------------------------------------
96 |
97 | # various inits, derived attributes, I/O setup
98 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
99 | if ddp:
100 | init_process_group(backend=backend)
101 | ddp_rank = int(os.environ['RANK'])
102 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
103 | device = f'cuda:{ddp_local_rank}'
104 | torch.cuda.set_device(device)
105 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
106 | seed_offset = ddp_rank # each process gets a different seed
107 | else:
108 | # if not ddp, we are running on a single gpu, and one process
109 | master_process = True
110 | seed_offset = 0
111 | gradient_accumulation_steps *= 8 # simulate 8 gpus
112 |
113 | if master_process:
114 | os.makedirs(out_dir, exist_ok=True)
115 | torch.manual_seed(5000 + seed_offset)
116 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
117 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
118 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
119 | # note: float16 data type will automatically use a GradScaler
120 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
121 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)
122 |
123 | # poor man's data loader
124 | data_dir = os.path.join('data', dataset)
125 | train_file_list = list(filter(lambda x: x.endswith('.bin') and x.startswith('fineweb_train'), os.listdir(data_dir)))
126 | train_data_list = [np.memmap(os.path.join(data_dir, file), dtype=np.uint16, mode='r') for file in train_file_list]
127 | val_data = np.memmap(os.path.join(data_dir, 'fineweb_val_000000.bin'), dtype=np.uint16, mode='r')
128 | import random
129 | random.seed(5000 + seed_offset)
130 | def get_batch(split):
131 | if split == 'train':
132 | data = random.choice(train_data_list)
133 | else:
134 | data = val_data
135 | offset = 512
136 | ix = torch.randint(len(data) - block_size - offset, (batch_size,))
137 | x = torch.stack([torch.from_numpy((data[offset+i:offset+i+block_size]).astype(np.int64)) for i in ix])
138 | y = torch.stack([torch.from_numpy((data[offset+i+1:offset+i+1+block_size]).astype(np.int64)) for i in ix])
139 | if device_type == 'cuda':
140 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
141 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
142 | else:
143 | x, y = x.to(device), y.to(device)
144 | return x, y
145 |
146 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
147 | iter_num = 0
148 | best_val_loss = 1e9
149 |
150 | # attempt to derive vocab_size from the dataset
151 | meta_path = os.path.join(data_dir, 'meta.pkl')
152 | meta_vocab_size = None
153 | if os.path.exists(meta_path):
154 | with open(meta_path, 'rb') as f:
155 | meta = pickle.load(f)
156 | meta_vocab_size = meta['vocab_size']
157 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
158 |
159 | # model init
160 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
161 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line
162 | if init_from == 'scratch':
163 | # init a new model from scratch
164 | print("Initializing a new model from scratch")
165 | # determine the vocab size we'll use for from-scratch training
166 | if meta_vocab_size is None:
167 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
168 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
169 | gptconf = GPTConfig(**model_args)
170 | model = GPT(gptconf)
171 | elif init_from == 'resume':
172 | print(f"Resuming training from {out_dir}")
173 | # resume training from a checkpoint.
174 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
175 | checkpoint = torch.load(ckpt_path, map_location=device)
176 | checkpoint_model_args = checkpoint['model_args']
177 | # force these config attributes to be equal otherwise we can't even resume training
178 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
179 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
180 | model_args[k] = checkpoint_model_args[k]
181 | # create the model
182 | gptconf = GPTConfig(**model_args)
183 | model = GPT(gptconf)
184 | state_dict = checkpoint['model']
185 | # fix the keys of the state dictionary :(
186 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
187 | unwanted_prefix = '_orig_mod.'
188 | for k,v in list(state_dict.items()):
189 | if k.startswith(unwanted_prefix):
190 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
191 | model.load_state_dict(state_dict)
192 | iter_num = checkpoint['iter_num']
193 | best_val_loss = checkpoint['best_val_loss']
194 | elif init_from.startswith('gpt2'):
195 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
196 | # initialize from OpenAI GPT-2 weights
197 | override_args = dict(dropout=dropout)
198 | model = GPT.from_pretrained(init_from, override_args)
199 | # read off the created config params, so we can store them into checkpoint correctly
200 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
201 | model_args[k] = getattr(model.config, k)
202 | # crop down the model block size if desired, using model surgery
203 | if block_size < model.config.block_size:
204 | model.crop_block_size(block_size)
205 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
206 | model.to(device)
207 |
208 | # initialize a GradScaler. If enabled=False scaler is a no-op
209 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
210 |
211 | # optimizer
212 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), device_type)
213 | if init_from == 'resume':
214 | optimizer.load_state_dict(checkpoint['optimizer'])
215 | del state_dict
216 | del checkpoint
217 | # compile the model
218 | if compile:
219 | print("compiling the model... (takes a ~minute)")
220 | unoptimized_model = model
221 | model = torch.compile(model) # requires PyTorch 2.0
222 |
223 | # wrap model into DDP container
224 | if ddp:
225 | model = DDP(model, device_ids=[ddp_local_rank])
226 |
227 | # helps estimate an arbitrarily accurate loss over either split using many batches
228 | @torch.no_grad()
229 | def estimate_loss():
230 | out = {}
231 | model.eval()
232 | for split in ['train', 'val']:
233 | losses = torch.zeros(eval_iters)
234 | for k in range(eval_iters):
235 | X, Y = get_batch(split)
236 | with ctx:
237 | logits, loss = model(X, Y)
238 | losses[k] = loss.item()
239 | out[split] = losses.mean()
240 | model.train()
241 | return out
242 |
243 | # learning rate decay scheduler (cosine with warmup)
244 | def get_lr(it, schedule='cosine'):
245 |     # returns the learning rate for iteration `it` under the given schedule
246 | # 1) linear warmup for warmup_iters steps
247 | if it < warmup_iters:
248 | return learning_rate * it / warmup_iters
249 | # 2) if it > lr_decay_iters, return min learning rate
250 | if schedule=='wsd':
251 | if it < 0.8 * max_iters:
252 | return learning_rate
253 | else:
254 | return learning_rate * (max_iters - it) / (max_iters * 0.2)
255 | if it > lr_decay_iters:
256 | return min_lr
257 | # 3) in between, use cosine decay down to min learning rate
258 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
259 | assert 0 <= decay_ratio <= 1
260 | if schedule=='cosine':
261 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
262 | elif schedule=='exp':
263 | coeff = np.power(0.9, 100 * decay_ratio)
264 |
265 | return min_lr + coeff * (learning_rate - min_lr)
266 |
267 | # logging
268 | if wandb_log and master_process:
269 | import wandb
270 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
271 |
272 | # training loop
273 | X, Y = get_batch('train') # fetch the very first batch
274 | t0 = time.time()
275 | local_iter_num = 0 # number of iterations in the lifetime of this process
276 | raw_model = model.module if ddp else model # unwrap DDP container if needed
277 | running_mfu = -1.0
278 | clip_time = 0
279 | while True:
280 |
281 | # determine and set the learning rate for this iteration
282 | lr = get_lr(iter_num, schedule=schedule) if decay_lr else learning_rate
283 | for param_group in optimizer.param_groups:
284 | param_group['lr'] = lr
285 |
286 | # evaluate the loss on train/val sets and write checkpoints
287 | if iter_num % eval_interval == 0 and master_process:
288 | losses = estimate_loss()
289 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
290 | if wandb_log:
291 | wandb.log({
292 | "iter": iter_num,
293 | "train/loss": losses['train'],
294 | "val/loss": losses['val'],
295 | "lr": lr,
296 | "mfu": running_mfu*100, # convert to percentage
297 | }, step=iter_num)
298 | if losses['val'] < best_val_loss or always_save_checkpoint:
299 | best_val_loss = losses['val']
300 | if iter_num > 0:
301 | checkpoint = {
302 | 'model': raw_model.state_dict(),
303 | 'optimizer': optimizer.state_dict(),
304 | 'model_args': model_args,
305 | 'iter_num': iter_num,
306 | 'best_val_loss': best_val_loss,
307 | 'config': config,
308 | }
309 | print(f"saving checkpoint to {out_dir}")
310 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
311 | if iter_num % (eval_interval * 5) == 0:
312 | checkpoint = {
313 | 'model': raw_model.state_dict(),
314 | 'optimizer': optimizer.state_dict(),
315 | 'model_args': model_args,
316 | 'iter_num': iter_num,
317 | 'best_val_loss': best_val_loss,
318 | 'config': config,
319 | }
320 | print(f"saving checkpoint to {out_dir}")
321 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt'))
322 | if iter_num == 0 and eval_only:
323 | break
324 |
325 | # forward backward update, with optional gradient accumulation to simulate larger batch size
326 | # and using the GradScaler if data type is float16
327 | for micro_step in range(gradient_accumulation_steps):
328 | if ddp:
329 | # in DDP training we only need to sync gradients at the last micro step.
330 | # the official way to do this is with model.no_sync() context manager, but
331 | # I really dislike that this bloats the code and forces us to repeat code
332 | # looking at the source of that context manager, it just toggles this variable
333 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
334 | with ctx:
335 | logits, loss = model(X, Y)
336 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
337 | X, Y = get_batch('train')
338 | # backward pass, with gradient scaling if training in fp16
339 | scaler.scale(loss).backward()
340 | # clip the gradient
341 | if grad_clip != 0.0:
342 | scaler.unscale_(optimizer)
343 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
344 | if total_norm.item() > grad_clip:
345 | clip_time += 1
346 | # step the optimizer and scaler if training in fp16
347 | scaler.step(optimizer)
348 | scaler.update()
349 | # flush the gradients as soon as we can, no need for this memory anymore
350 | optimizer.zero_grad(set_to_none=True)
351 |
352 | # timing and logging
353 | t1 = time.time()
354 | dt = t1 - t0
355 | t0 = t1
356 | if iter_num % log_interval == 0 and master_process:
357 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
358 | if local_iter_num >= 5: # let the training loop settle a bit
359 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
360 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
361 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
362 | params = []
363 | for (name, p) in model.named_parameters():
364 | params.append(p)
365 | total_param_norm = 0
366 | for p in params:
367 | param_norm = p.data.norm(2)
368 | total_param_norm += param_norm.item() ** 2
369 | total_param_norm = total_param_norm ** 0.5
370 | momentum_norm = 0
371 | LL = len(optimizer.state_dict()['state'])
372 | for jj in range(LL):
373 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2
374 | momentum_norm = torch.sqrt(momentum_norm).item()
375 | if wandb_log:
376 | wandb.log({
377 | "iter": iter_num,
378 | "train/loss": lossf,
379 | "lr": lr,
380 | "param_norm": total_param_norm,
381 | "momentum_norm" : momentum_norm,
382 | "train/clip_rate": clip_time / (iter_num + 1)
383 | }, step=iter_num)
384 | iter_num += 1
385 | local_iter_num += 1
386 |
387 | # termination conditions
388 | if iter_num > max_iters:
389 | break
390 |
391 | if ddp:
392 | destroy_process_group()
393 |
--------------------------------------------------------------------------------