├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── data
│   ├── cached_fineweb100B.py
│   ├── cached_fineweb10B.py
│   ├── cached_finewebedu10B.py
│   ├── fineweb.py
│   └── requirements.txt
├── img
│   ├── algo_optimizer.png
│   ├── dofa.jpg
│   ├── fig_optimizer.png
│   ├── fig_tuned_nanogpt.png
│   ├── nanogpt_speedrun51.png
│   ├── nanogpt_speedrun52.png
│   ├── nanogpt_speedrun53.png
│   └── nanogpt_speedrun54.png
├── records
├── 010425_SoftCap
│ ├── 31d6c427-f1f7-4d8a-91be-a67b5dcd13fd.txt
│ ├── README.md
│ └── curves_010425.png
├── 011325_Fp8LmHead
│ ├── README.md
│ └── c51969c2-d04c-40a7-bcea-c092c3c2d11a.txt
├── 011625_Sub3Min
│ ├── 1d3bd93b-a69e-4118-aeb8-8184239d7566.txt
│ ├── README.md
│ ├── attn-entropy.png
│ ├── attn-scales-pattern.gif
│ ├── learned-attn-scales.png
│ └── long-short-swa.png
├── 011825_GPT2Medium
│ ├── 241dd7a7-3d76-4dce-85a4-7df60387f32a.txt
│ └── main.log
├── 012625_BatchSize
│ ├── 0bdd5ee9-ac28-4202-bdf1-c906b102b0ec.txt
│ ├── README.md
│ ├── ablations.png
│ ├── c44090cc-1b99-4c95-8624-38fb4b5834f9.txt
│ ├── val_losses.png
│ └── wallclock.png
├── 020125_RuleTweak
│ └── eff63a8c-2f7e-4fc5-97ce-7f600dae0bc7.txt
├── 020825_GPT2MediumWeightDecay
│ └── b01743db-605c-4326-b5b1-d388ee5bebc5.txt
├── 021425_GPT2MediumOptCoeffs
│ └── 1baa66b2-bff7-4850-aced-d63885ffb4b6.txt
├── 030625_GPT2MediumLongerCooldown
│ ├── 779c041a-2a37-45d2-a18b-ec0f223c2bb7.txt
│ └── README.md
├── 032525_GPT2MediumArchOptTweaks
│ └── train_gpt-20250329.txt
├── 041625_GPT2Medium_Record7
│ ├── 223_3310d0b1-b24d-48ee-899f-d5c2a254a195.txt
│ └── README.md
├── 042225_GPT2Medium_Record8
│ └── 075_640429f2-e726-4e83-aa27-684626239ffc.txt
├── 052425_FasterReduce
│ └── 23f40b75-06fb-4c3f-87a8-743524769a35.txt
├── 052425_StableTorch
│ └── 89d9f224-3b01-4581-966e-358d692335e0.txt
├── 052525_EvenFasterReduce
│ └── 6ae86d05-5cb2-4e40-a512-63246fd08e45.txt
├── 052525_MuonWithAuxAdamExample
│ └── b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt
├── 060624_AdamW
│ ├── README.md
│ └── f66d43d7-e449-4029-8adf-e8537bab49ea.log
├── 100924_SOAP
│ ├── 5bdc3988-496c-4232-b4ef-53764cb81c92.txt
│ ├── README.md
│ └── train_gpt2.py
├── 101024_Muon
│ ├── eb5659d0-fb6a-49e5-a311-f1f89412f726.txt
│ └── train_gpt2.py
├── 101324_llmc
│ ├── README.md
│ └── main.log
├── 101424_ModernArch
│ ├── dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt
│ └── train_gpt2.py
├── 101724_DistributedMuon
│ └── 22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt
├── 101824_PyTorch25
│ └── d4bfb25f-688d-4da5-8743-33926fad4842.txt
├── 102024_ScaleUp1B
│ ├── 87bd51fd-6203-4c88-b3aa-8a849a6a83ca.txt
│ ├── ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt
│ └── c0078066-c8c9-49c8-868a-ff4d4f32e615.txt
├── 102924_Optimizers
│ ├── 8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt
│ ├── 8d6193f4-27fc-4e68-899f-af70019a4d54.txt
│ ├── 95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt
│ ├── README.md
│ ├── e21a2838-a0f2-46f2-a247-db0021165682.txt
│ ├── nanogpt_speedrun81w.png
│ └── nanogpt_speedrun82w.png
├── 110324_UntieEmbed
│ ├── README.md
│ └── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt
├── 110424_50Bruns
│ ├── 3d715d41-453a-40d6-9506-421ba69766b2.txt
│ ├── 4fbe61ec-f79a-4c19-836d-46d599deecce.txt
│ ├── 530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt
│ ├── 69c33fc9-eabb-4a38-aa08-6922914eb405.txt
│ └── README.md
├── 110624_ShortcutsTweaks
│ ├── 042f9e87-07e6-4504-bb04-4ec59a380211.txt
│ ├── 05b29e54-0be0-4a0f-a1e2-7d5317daedd3.txt
│ ├── 10119f53-7001-4248-bfd9-33d32427a912.txt
│ ├── 43f60c4f-0448-4de7-83d9-643ca26f61e7.txt
│ ├── 4a71cc92-0f43-4058-a033-23e85c1e98f1.txt
│ ├── README.md
│ ├── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt
│ ├── dd7304a6-cc43-4d5e-adb8-c070111464a1.txt
│ ├── nanogpt_speedrun110.png
│ └── nanogpt_speedrun111.png
├── 110824_CastBf16
│ └── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt
├── 110924_Replicateleloykun
│ ├── 1621af10-aa0c-42af-bf54-8a773c63a2af.txt
│ └── README.md
├── 111024_ScaleShortcuts
│ ├── 3e55eb2e-6261-466a-b1e9-2b31f56fb16a.txt
│ ├── 4897c987-9d09-435c-a23f-20585912936a.txt
│ ├── 70a0ada6-8dee-4fef-8980-135379479c21.txt
│ ├── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt
│ └── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt
├── 111024_UNetDoubleLr
│ ├── README.md
│ └── c87bb826-797b-4f37-98c7-d3a5dad2de74.txt
├── 111424_QuantizedFP4
│ ├── 433c1732-0c3d-4099-a4a8-ec31eae49b16.txt
│ ├── 70a0ada6-8dee-4fef-8980-135379479c21.txt
│ ├── 932bbe0e-41c3-4a5b-94bd-4ea3350909bd.txt
│ └── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt
├── 111924_FlexAttention
│ ├── 8384493d-dba9-4991-b16b-8696953f5e6d.txt
│ └── README.md
├── 112424_WindowWarmup
│ ├── 3151f69a-f89e-452c-ac80-b85118007583.txt
│ ├── 4428858e-7cb8-4a25-a936-818d8f28de51.txt
│ ├── ae732e01-04b2-4665-b570-a77210e73e28.txt
│ ├── ba299b7e-a36a-4fd8-a268-25bb772010dd.txt
│ ├── cf9e4571-c5fc-4323-abf3-a98d862ec6c8.txt
│ ├── d1cf11aa-7b8e-4d28-a94d-1aab632e0f38.txt
│ └── dca62101-15d4-4c76-842e-99213fa2508b.txt
├── 120424_ValueEmbed
│ ├── 00008ea0-21dd-442a-82ee-d12799249d0f.txt
│ ├── 123ac41a-8f7b-42a1-853c-5a791dcaa7ae.txt
│ ├── 14511f40-47db-4c94-b35b-70616770fd2d.txt
│ ├── 19bb65fb-f903-4a41-803b-fbd57562f653.txt
│ ├── 1b4db08c-46d9-4e1b-a8e6-872a685061c3.txt
│ ├── 2358dd3a-8ce4-4b8a-a367-fca6dcd38343.txt
│ ├── 2577721d-9ce9-400c-8902-ce95d6fbcf64.txt
│ ├── 2d140d43-3323-408e-b559-3639c0880323.txt
│ ├── 2e527b0b-3540-4bcd-a15d-955f86cb8bd2.txt
│ ├── 2f4ce5fe-b625-41b4-acbb-d4c20b591ead.txt
│ ├── 385d2312-0cf9-48c3-af3f-35c12c12a38d.txt
│ ├── 3e39d446-d7e8-4c49-b92c-62c3848a3c01.txt
│ ├── 4f8bcdc3-18cf-4743-895e-1deb08696fe7.txt
│ ├── 51b3baf0-69d6-43ee-a88b-2c5c28e3dd5b.txt
│ ├── 51faed93-0804-418c-9057-0d94c3f94a9c.txt
│ ├── 66173c47-b15b-4a24-a835-60c82f6b8283.txt
│ ├── 67716aee-6747-4997-a37c-b96932fab4dd.txt
│ ├── 6b244191-77a3-41ea-a314-82c6a9184b31.txt
│ ├── 74cba1d4-da56-4334-9622-e0aa960dfe3f.txt
│ ├── 75a3af7b-f1a6-47dc-a989-d95e4419ff31.txt
│ ├── 87f81569-fa04-4eb3-8b75-42c116e96ba0.txt
│ ├── 8bd08106-6eb1-4cd1-9779-ccf1192bda1b.txt
│ ├── 92b1541a-a6d6-4ddf-8932-ea4bcd31ba3b.txt
│ ├── 949e5cfd-cb9c-48e7-a888-551981582a9b.txt
│ ├── 968f73e2-b588-4102-80b0-996bae126be1.txt
│ ├── README.md
│ ├── a5c54ede-2647-4df1-872d-de033898a9e2.txt
│ ├── b19b341a-bf8d-46f9-8076-fef1fcb7445e.txt
│ ├── bb60727f-25b8-4e01-bf08-293aa3860e7c.txt
│ ├── c249f3c7-b947-4b6f-8a42-56e97ef7712e.txt
│ ├── c8e1a7d3-a37e-4a88-b28a-3afb2d8089ca.txt
│ ├── d12cb409-0f5d-4624-951c-60119a482bca.txt
│ ├── d6520673-0f5f-4c28-898b-f52d056b257d.txt
│ ├── d884d4cb-0656-460d-9454-897eb9789f2a.txt
│ ├── e0248b14-212b-436b-9a12-ba142d720ab4.txt
│ ├── e3ffb000-a388-4bdc-b88f-2f85794c13c8.txt
│ ├── ed14c8b2-2ac1-41e0-acea-3cc55cd94f83.txt
│ ├── eecabe35-2224-4988-9910-8e8434a0c281.txt
│ ├── f0f173e3-69ee-4970-ba0b-f7f3e5d92e33.txt
│ └── train_gpt2.py
├── 120824_UNetValueEmbedsTweaks
│ ├── 0069607b-aa90-49fd-9766-4368bcd168c4.txt
│ ├── 08d38857-18d8-406f-907b-c80142fec515.txt
│ ├── 0ad22fc7-dd0c-4d79-a216-c7667db09e15.txt
│ ├── 13daf1c4-986b-4938-b269-1ac935891505.txt
│ ├── 14dfd7b3-d65a-4813-8327-03a9f1bf9f6e.txt
│ ├── 16cacfc3-de94-429e-be3a-03c3c5293d08.txt
│ ├── 19c57414-499c-4892-b25c-a7083583fa59.txt
│ ├── 1a1fde38-fd1f-4cc9-a648-74b8d6dab734.txt
│ ├── 1b363e84-a47e-48c9-ae21-7fb825512c9a.txt
│ ├── 23ed0b57-544b-4fc5-b780-062dcf3f6f2c.txt
│ ├── 257dc15f-5bfd-4c0b-9349-a46b3c13d5ef.txt
│ ├── 26fa5797-44d0-4a63-9e57-f435f2f59aad.txt
│ ├── 29354340-7020-42e4-8b06-b9a7c85e3d51.txt
│ ├── 43ed2395-9260-48d6-9ab2-dd6f7f4684b5.txt
│ ├── 472b8553-de0d-4f91-9285-10781b4cd07f.txt
│ ├── 4732cc5e-a214-47d6-bfce-cb4ae2f663c8.txt
│ ├── 48359737-5c81-4653-8d5c-5c2da8405f16.txt
│ ├── 4a0f4d79-dc64-4731-b777-33f1e3ec6d3d.txt
│ ├── 4bcd20c5-9ffd-40e9-8d9a-6d9d647b16ba.txt
│ ├── 4de826f5-a244-490f-8025-5f0d370bd68a.txt
│ ├── 501b69a6-bcbc-4403-9bc0-c257bdc3f8b6.txt
│ ├── 52275a2c-80fd-4743-ad16-1d5098d97821.txt
│ ├── 554a738b-571d-4da1-8556-393e7592eda7.txt
│ ├── 563ca9ab-8f99-4a56-a8be-5ab3b21ac197.txt
│ ├── 569ca2ab-12e9-4035-a0c7-f18c7050d234.txt
│ ├── 59ba1f2d-a3b7-4fa8-b099-f13b838470ee.txt
│ ├── 5bd46aa9-aaa1-4390-b355-41b730732b10.txt
│ ├── 5dc85466-81b7-4770-ba8d-6937f4400758.txt
│ ├── 5e28cee7-fc40-4593-bc83-a22495bfaacc.txt
│ ├── 5f2dc8e6-dfc3-4078-b756-e7521561ebce.txt
│ ├── 5fca3d5f-6290-47af-9130-b78a777b24c4.txt
│ ├── 5fcbae90-4380-412e-8209-51fccc183040.txt
│ ├── 60adb82c-ba8d-41b0-81fc-565439384b72.txt
│ ├── 6175f2f0-8526-4f20-86d9-d7c2a6dcda19.txt
│ ├── 625a6fcc-203c-4545-b697-0b8daa2b6d07.txt
│ ├── 677ef1aa-30f7-4c4d-96c8-b85c7d70c4d6.txt
│ ├── 68b508e5-a986-40d7-a307-21c282389be5.txt
│ ├── 6a4a3bd3-c3b5-4bef-9267-03724ba49759.txt
│ ├── 6bf9356b-7ff8-4253-97b0-c2747ad5a52e.txt
│ ├── 71e11d12-c639-4966-bd9a-be49de53bc9c.txt
│ ├── 7639d7d4-a962-4193-8e57-c1d76ea22f54.txt
│ ├── 7de3e01c-69f5-4cfe-9e99-97413d15488c.txt
│ ├── 81074c42-378a-4231-a37d-10f5c477a78c.txt
│ ├── 83d3d075-ede1-4b85-93ab-24dc38ae1bf5.txt
│ ├── 83fd4151-f2d1-4d8d-ad31-5546543072f3.txt
│ ├── 88baa340-752f-4874-9dc8-f2f7831d78e4.txt
│ ├── 89ac99c6-838d-431b-967a-4bde4d81d6c4.txt
│ ├── 8a0afde4-391d-4e2d-8df1-e44fe0b80feb.txt
│ ├── 8b074f6f-a4ad-42ae-9460-679426524ddb.txt
│ ├── 8b82b55d-6955-4100-a1a9-5630cb85683c.txt
│ ├── 93406931-1cc4-419c-b7a7-adef928d4c54.txt
│ ├── 938ecd23-aaa8-4975-b6da-9407ab45e0f9.txt
│ ├── 94df9d58-590d-4171-a015-e59e144d7206.txt
│ ├── 95dc5712-2fd2-45bf-89c9-a86d7704e3d0.txt
│ ├── 9e4f1bfe-0f1a-4bd8-af1b-072dcaa61c4c.txt
│ ├── abf007e8-6370-4987-bd1a-85bff108fe5e.txt
│ ├── ad6d9498-e76d-4ce4-acb5-fd2116eb77d8.txt
│ ├── b7197dc5-b590-4e32-8590-a8c0076b64ab.txt
│ ├── bbc9dd90-96c5-480b-984c-594145555821.txt
│ ├── bc895608-3be1-4365-89ff-abbf6c316b37.txt
│ ├── bdcab079-8761-4bb8-b2bc-adc2a45fbc9a.txt
│ ├── be800e9a-90a5-457d-8f86-a943f4dea20c.txt
│ ├── c2e4c338-f228-4ece-a8cb-1ccf5e949817.txt
│ ├── c8cecdbc-504d-4df2-b844-fadc9bddabfa.txt
│ ├── d13cc58e-654c-432b-b1b6-1c8b8ae8a232.txt
│ ├── d30450c5-ecc2-426d-9375-bcf48e6123d8.txt
│ ├── d69df4f4-5e17-45a3-94ba-5c2ec73a82e5.txt
│ ├── e276c0fa-854b-44ad-830b-cdf887eaf6c3.txt
│ ├── e2c6f3ca-140e-43fa-862e-4ee48c04422e.txt
│ ├── e2f54f7a-f6c5-4fcf-bea6-e0442b62adca.txt
│ ├── e3319113-ec3a-457a-975c-e89dcddd15f2.txt
│ ├── e5589f56-4312-44f2-9254-95304cae10cf.txt
│ ├── e66b0dd9-9680-4e7a-85c7-ac77f1d89aa2.txt
│ ├── f2fda86f-4418-4ff8-a3a3-8112731cbcf9.txt
│ └── f9a93608-ed4e-46ab-9c06-07f90f1328a6.txt
├── 121024_MFUTweaks
│ ├── 04df8c70-2462-4b6c-9e49-c7589b321a81.txt
│ ├── 052953d9-0495-43db-9d15-ae98fff06f18.txt
│ ├── 0aa83756-53f0-4268-9721-db6d5985bc42.txt
│ ├── 217c8375-d34f-45ee-9bee-ca8f9a6608da.txt
│ ├── 2d25e67e-6ab4-45ff-92e0-d223bcf1784e.txt
│ ├── 2d564428-0822-4d01-a61b-f53e4cd646c6.txt
│ ├── 30beb830-0f15-4a71-8c98-187ea2015606.txt
│ ├── 35c28373-6cf8-463c-8007-1146801f8068.txt
│ ├── 38757b64-ade6-45fa-b628-11cc83b6ed0c.txt
│ ├── 3a714803-f1e1-4b81-b984-e2457ccb664f.txt
│ ├── 409647d3-2887-4760-9cb4-327d3bb219a4.txt
│ ├── 47029c45-997b-4498-a5c0-2fc76d365eb8.txt
│ ├── 5175d854-1dcb-41e1-a690-b223fa69fd7f.txt
│ ├── 571574de-d894-49b1-8f02-281ddd607823.txt
│ ├── 591496b6-fdce-4902-b279-4360b378865d.txt
│ ├── 59b474de-a528-4040-b6f4-19c89c2041cf.txt
│ ├── 7442abb4-a571-4340-9844-de16209c3762.txt
│ ├── 76ee9a40-0f61-42bb-849b-bf46c3a3e7c9.txt
│ ├── 7787c6ea-33f2-4c63-a345-a8e8a830e595.txt
│ ├── 8dff461f-9696-4d45-b717-bad7b1cb2060.txt
│ ├── 91b48522-0c7b-4fee-83ea-c01df3e3d5c0.txt
│ ├── 9323a20a-8f7f-4549-8997-dde75dfc6ec6.txt
│ ├── a3e860f0-98ed-4845-9aee-cccd01fd9b5b.txt
│ ├── a55c6778-e76a-44bc-a0e1-68c0e7580096.txt
│ ├── a880e114-d133-469c-85db-ab58bc507d04.txt
│ ├── ac154191-d2fb-478a-b6ea-55276b4f6655.txt
│ ├── afcec83b-9286-455e-81a6-5eb3710d0ead.txt
│ ├── c01ffedd-443f-421b-b7ca-492e8187557b.txt
│ ├── c80b525d-de02-4fb0-b4ae-44f3308474ea.txt
│ ├── d00194c9-e6d0-4767-b703-7e3c3ff24881.txt
│ ├── d6ab1d3a-92f5-4847-8f3b-69faabc89c67.txt
│ ├── d7c9acce-8b10-4f22-a505-4741319decc3.txt
│ ├── e0f4f91a-4b79-4a39-a316-9ce769921840.txt
│ ├── e1e718d6-89dc-4794-a969-17b2977b004c.txt
│ ├── e221be41-2c83-4158-98dd-2df74ad2b9ba.txt
│ ├── ea08b5fb-abc8-46b0-a643-4561d621d78d.txt
│ ├── eca39347-6616-4903-ad6b-bab828aa2f78.txt
│ ├── ed5a4791-fc3b-4fbf-b871-5c1c34f19b5a.txt
│ ├── f2546c91-e2b3-4906-9f5b-bee25b47216f.txt
│ └── feed17f1-d377-484b-b04b-0124144dfc62.txt
├── 121724_SparsifyEmbeds
│ ├── 165384d5-5a41-4fba-9b79-8ecc372616ea.txt
│ ├── README.md
│ └── loss_hist_121724.png
└── 123124_Target350M
│ ├── README.md
│ └── train_gpt.py
├── requirements.txt
├── run.sh
├── train_gpt.py
└── train_gpt_medium.py
/.gitignore:
--------------------------------------------------------------------------------
1 | fineweb10B/
2 | pylog124M/
3 | __pycache__/
4 | logs/
5 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.6.2-cudnn-devel-ubuntu24.04
2 |
3 | ENV DEBIAN_FRONTEND=noninteractive
4 | ENV PYTHON_VERSION=3.12.7
5 | ENV PATH=/usr/local/bin:$PATH
6 |
7 | RUN apt update && apt install -y --no-install-recommends build-essential libssl-dev zlib1g-dev \
8 | libbz2-dev libreadline-dev libsqlite3-dev curl git libncursesw5-dev xz-utils tk-dev libxml2-dev \
9 | libxmlsec1-dev libffi-dev liblzma-dev \
10 | && apt clean && rm -rf /var/lib/apt/lists/*
11 |
12 | RUN curl -O https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \
13 | tar -xzf Python-${PYTHON_VERSION}.tgz && \
14 | cd Python-${PYTHON_VERSION} && \
15 | ./configure --enable-optimizations && \
16 | make -j$(nproc) && \
17 | make altinstall && \
18 | cd .. && \
19 | rm -rf Python-${PYTHON_VERSION} Python-${PYTHON_VERSION}.tgz
20 |
21 | RUN ln -s /usr/local/bin/python3.12 /usr/local/bin/python && \
22 | ln -s /usr/local/bin/pip3.12 /usr/local/bin/pip
23 |
24 | COPY requirements.txt /modded-nanogpt/requirements.txt
25 | WORKDIR /modded-nanogpt
26 |
27 | RUN python -m pip install --upgrade pip && \
28 | pip install -r requirements.txt
29 |
30 | RUN pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --upgrade
31 |
32 | CMD ["bash"]
33 | ENTRYPOINT []
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Keller Jordan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Modded-NanoGPT
2 |
3 | This repository hosts the *NanoGPT speedrun*, in which we (collaboratively|competitively) search for the fastest algorithm to use 8 NVIDIA H100 GPUs to train a language model that attains 3.28 cross-entropy loss on the [FineWeb](https://huggingface.co/datasets/HuggingFaceFW/fineweb) validation set.
4 |
5 | The target (3.28 validation loss on FineWeb) follows Andrej Karpathy's [GPT-2 replication in llm.c, which attains that loss after running for 45 minutes](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29).
6 | The speedrun code also descends from llm.c's [PyTorch trainer](https://github.com/karpathy/llm.c/blob/master/train_gpt2.py), which itself descends from NanoGPT, hence the name of the repo.
7 | Thanks to the efforts of many contributors, this repo now contains a training algorithm which attains the target performance in:
8 | * 3 minutes on 8xH100 (the llm.c GPT-2 replication needed 45)
9 | * 0.73B tokens (the llm.c GPT-2 replication needed 10B)
10 |
11 | This improvement in training speed has been brought about by the following techniques:
12 | * Modernized architecture: Rotary embeddings, QK-Norm, and ReLU²
13 | * The Muon optimizer [[writeup](https://kellerjordan.github.io/posts/muon/)] [[repo](https://github.com/KellerJordan/Muon)]
14 | * Untie head from embedding, use FP8 matmul for head, and softcap logits (the latter following Gemma 2)
15 | * Initialization of projection and classification layers to zero (muP-like)
16 | * Skip connections from embedding to every block as well as between blocks in U-net pattern
17 | * Extra embeddings which are mixed into the values in attention layers (inspired by Zhou et al. 2024)
18 | * FlexAttention with long-short sliding window attention pattern (inspired by Gemma 2) and window size warmup
19 |
20 | As well as many systems optimizations.
21 |
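To make two of these concrete, here is a minimal illustrative sketch (not the repo's exact code) of the ReLU² activation and the tanh logit softcap; the cap value of 15 follows the 01/04/25 record:

```python
import torch
import torch.nn.functional as F

def relu_squared(x: torch.Tensor) -> torch.Tensor:
    # ReLU² activation used in the MLP blocks
    return F.relu(x).square()

def softcap_logits(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    # tanh softcap (following Gemma 2): smoothly bounds logits to (-cap, cap)
    return cap * torch.tanh(logits / cap)
```
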
22 | Contributors list (growing with each new record): [@bozavlado](https://x.com/bozavlado); [@brendanh0gan](https://x.com/brendanh0gan);
23 | [@fernbear.bsky.social](https://bsky.app/profile/fernbear.bsky.social); [@Grad62304977](https://x.com/Grad62304977);
24 | [@jxbz](https://x.com/jxbz); [@kellerjordan0](https://x.com/kellerjordan0);
25 | [@KoszarskyB](https://x.com/KoszarskyB); [@leloykun](https://x.com/@leloykun);
26 | [@YouJiacheng](https://x.com/YouJiacheng); [@jadenj3o](https://x.com/jadenj3o);
27 | [@KonstantinWilleke](https://github.com/KonstantinWilleke), [@alexrgilbert](https://github.com/alexrgilbert), [@adricarda](https://github.com/adricarda),
28 | [@tuttyfrutyee](https://github.com/tuttyfrutyee), [@vdlad](https://github.com/vdlad);
29 | [@ryanyang0](https://x.com/ryanyang0)
30 |
31 |
32 | ---
33 |
34 | ## Running the current record
35 |
36 | To run the current record, run the following commands.
37 | ```bash
38 | git clone https://github.com/KellerJordan/modded-nanogpt.git && cd modded-nanogpt
39 | pip install -r requirements.txt
40 | pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --upgrade
41 | # downloads only the first 800M training tokens to save time
42 | python data/cached_fineweb10B.py 8
43 | ./run.sh
44 | ```
45 |
46 | **Note: torch.compile will add around 5 minutes of latency the first time you run the code.**
47 |
48 | ## Alternative: Running with Docker (recommended for precise timing)
49 |
50 | For cases where CUDA or NCCL versions aren't compatible with your current system setup, Docker can be a helpful alternative.
51 | This approach standardizes versions for CUDA, NCCL, CUDNN, and Python, reducing dependency issues and simplifying setup.
52 | Note: an NVIDIA driver must already be installed on the system (useful if only the NVIDIA driver and Docker are available).
53 |
54 | ```bash
55 | git clone https://github.com/KellerJordan/modded-nanogpt.git && cd modded-nanogpt
56 | sudo docker build -t modded-nanogpt .
57 | sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt python data/cached_fineweb10B.py 8
58 | sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt sh run.sh
59 | ```
60 |
61 | To get an interactive docker, you can use
62 | ```bash
63 | sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt bash
64 | ```
65 |
66 | ---
67 |
68 | ## World record history
69 |
70 | The following is the historical progression of world speed records for this competitive task:
71 |
72 | > *Train a neural network to ≤3.28 validation loss on FineWeb using 8x NVIDIA H100s.*
73 |
74 | Note: The 3.28 target was selected to match [Andrej Karpathy's GPT-2 (small) reproduction](https://github.com/karpathy/llm.c/discussions/481).
75 |
76 | | # | Record time | Description | Date | Log | Contributors |
77 | | - | - | - | - | - | - |
78 | 1 | 45 minutes | [llm.c baseline](https://github.com/karpathy/llm.c/discussions/481) | 05/28/24 | [log](records/101324_llmc/main.log) | @karpathy, llm.c contributors
79 | 2 | 31.4 minutes | [Tuned learning rate & rotary embeddings](https://x.com/kellerjordan0/status/1798863559243513937) | 06/06/24 | [log](records/060624_AdamW/f66d43d7-e449-4029-8adf-e8537bab49ea.log) | @kellerjordan0
80 | 3 | 24.9 minutes | [Introduced the Muon optimizer](https://x.com/kellerjordan0/status/1842300916864844014) | 10/04/24 | none | @kellerjordan0, @jxbz
81 | 4 | 22.3 minutes | [Muon improvements](https://x.com/kellerjordan0/status/1844820919061287009) | 10/11/24 | [log](records/101024_Muon/eb5659d0-fb6a-49e5-a311-f1f89412f726.txt) | @kellerjordan0, @bozavlado
82 | 5 | 15.2 minutes | [Pad embeddings, ReLU², zero-init projections, QK-norm](https://x.com/kellerjordan0/status/1845865698532450646) | 10/14/24 | [log](records/101424_ModernArch/dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt) | @Grad62304977, @kellerjordan0
83 | 6 | 13.1 minutes | [Distributed the overhead of Muon](https://x.com/kellerjordan0/status/1847291684016783746) | 10/18/24 | [log](records/101724_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt) | @kellerjordan0
84 | 7 | 12.0 minutes | [Upgraded PyTorch 2.5.0](https://x.com/kellerjordan0/status/1847358578686152764) | 10/18/24 | [log](records/101824_PyTorch25/d4bfb25f-688d-4da5-8743-33926fad4842.txt) | @kellerjordan0
85 | 8 | 10.8 minutes | [Untied embedding and head](https://x.com/kellerjordan0/status/1853188916704387239) | 11/03/24 | [log](records/110324_UntieEmbed/d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt) | @Grad62304977, @kellerjordan0
86 | 9 | 8.2 minutes | [Value and embedding skip connections, momentum warmup, logit softcap](https://x.com/kellerjordan0/status/1854296101303800108) | 11/06/24 | [log](records/110624_ShortcutsTweaks/dd7304a6-cc43-4d5e-adb8-c070111464a1.txt) | @Grad62304977, @kellerjordan0
87 | 10 | 7.8 minutes | [Bfloat16 activations](https://x.com/kellerjordan0/status/1855267054774865980) | 11/08/24 | [log](records/110824_CastBf16/a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt) | @kellerjordan0
88 | 11 | 7.2 minutes | [U-net pattern skip connections & double lr](https://x.com/kellerjordan0/status/1856053121103093922) | 11/10/24 | [log](records/111024_UNetDoubleLr/c87bb826-797b-4f37-98c7-d3a5dad2de74.txt) | @brendanh0gan
89 | 12 | 5.03 minutes | [1024-ctx dense causal attention → 64K-ctx FlexAttention](https://x.com/kellerjordan0/status/1859331370268623321) | 11/19/24 | [log](records/111924_FlexAttention/8384493d-dba9-4991-b16b-8696953f5e6d.txt) | @KoszarskyB
90 | 13 | 4.66 minutes | [Attention window warmup](https://x.com/hi_tysam/status/1860851011797053450) | 11/24/24 | [log](records/112424_WindowWarmup/cf9e4571-c5fc-4323-abf3-a98d862ec6c8.txt) | @fernbear.bsky.social
91 | 14 | 4.41 minutes | [Value Embeddings](https://x.com/KoszarskyB/status/1864746625572257852) | 12/04/24 | [log](records/120424_ValueEmbed) | @KoszarskyB
92 | 15 | 3.95 minutes | [U-net pattern value embeddings, assorted code optimizations](https://x.com/YouJiacheng/status/1865761473886347747) | 12/08/24 | [log](records/120824_UNetValueEmbedsTweaks) | @leloykun, @YouJiacheng
93 | 16 | 3.80 minutes | [Split value embeddings, block sliding window, separate block mask](https://x.com/YouJiacheng/status/1866734331559071981) | 12/10/24 | [log](records/121024_MFUTweaks) | @YouJiacheng
94 | 17 | 3.57 minutes | [Sparsify value embeddings, improve rotary embeddings, drop an attn layer](https://x.com/YouJiacheng/status/1868938024731787640) | 12/17/24 | [log](records/121724_SparsifyEmbeds) | @YouJiacheng
95 | 18 | 3.4 minutes | [Lower logit softcap from 30 to 15](https://x.com/kellerjordan0/status/1876048851158880624) | 01/04/25 | [log](records/010425_SoftCap/31d6c427-f1f7-4d8a-91be-a67b5dcd13fd.txt) | @KoszarskyB
96 | 19 | 3.142 minutes | [FP8 head, offset logits, lr decay to 0.1 instead of 0.0](https://x.com/YouJiacheng/status/1878827972519772241) | 01/13/25 | [log](records/011325_Fp8LmHead/c51969c2-d04c-40a7-bcea-c092c3c2d11a.txt) | @YouJiacheng
97 | 20 | 2.992 minutes | [Merged QKV weights, long-short attention, attention scale, lower Adam epsilon, batched Muon](https://x.com/leloykun/status/1880301753213809016) | 01/16/25 | [log](records/011625_Sub3Min/1d3bd93b-a69e-4118-aeb8-8184239d7566.txt) | @leloykun, @fernbear.bsky.social, @YouJiacheng, @brendanh0gan, @scottjmaddox, @Grad62304977
98 | 21 | 2.933 minutes | [Reduced batch size](https://x.com/leloykun/status/1885640350368420160) | 01/26/25 | [log](records/012625_BatchSize/c44090cc-1b99-4c95-8624-38fb4b5834f9.txt) | @leloykun
99 | 21 | 2.997 minutes | 21st record with new timing | 02/01/25 | [log](records/020125_RuleTweak/eff63a8c-2f7e-4fc5-97ce-7f600dae0bc7.txt) | not a new record, just re-timing #21 with the [updated rules](#timing-change-after-record-21)
100 | 21 | 3.014 minutes | 21st record with latest torch | 05/24/25 | [log](records/052425_StableTorch/89d9f224-3b01-4581-966e-358d692335e0.txt) | not a new record, just re-timing #21 with latest torch
101 | 22 | 2.990 minutes | [Faster gradient all-reduce](https://x.com/KonstantinWille/status/1927137223238909969) | 05/24/25 | [log](records/052425_FasterReduce/23f40b75-06fb-4c3f-87a8-743524769a35.txt) | @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad; The Enigma project
102 | 23 | 2.979 minutes | Overlap computation and gradient communication | 05/25/25 | [log](records/052525_EvenFasterReduce/6ae86d05-5cb2-4e40-a512-63246fd08e45.txt) | @ryanyang0
103 |
104 | ## Rules
105 |
106 | The only rules are that new records must:
107 |
108 | 1. Not modify the train or validation data pipelines. (You can change the batch size, sequence length, attention structure etc.; just don't change the underlying streams of tokens.)
109 | 2. Attain ≤3.28 mean val loss. (Due to inter-run variance, submissions must provide enough run logs to attain a statistical significance level of p<0.01 that their mean val loss is ≤3.28. Example code to compute p-value can be found [here](records/010425_SoftCap#softer-softcap). For submissions which improve speed by optimizing the systems performance, without touching the ML, this requirement is waived.)
110 | 3. Not use any extra `torch._inductor.config` or `torch.compile` flags. (These can save a few seconds, but they can also make compilation take >30min. This rule was introduced after the 21st record.)
111 |
112 | > Note: `torch._inductor.config.coordinate_descent_tuning` is allowed for the GPT-2 Medium track (a.k.a. the 2.92 track).
113 |
114 | Other than that, anything and everything is fair game!
115 |
116 | [further clarifications](https://github.com/KellerJordan/modded-nanogpt/discussions/23?sort=new#discussioncomment-12109560)
117 |
118 | ---
119 |
120 | ### Comment on the target metric
121 |
122 | The target metric is *cross-entropy loss on the FineWeb val set*. To speak mathematically, the goal of the speedrun is *to obtain a probability model of language which assigns a probability of at least `math.exp(-3.28 * 10485760)` to the first 10,485,760 tokens of the FineWeb val set*. Hence, e.g., we allow evaluation at any sequence length, so long as we still have a valid probability model of language.
123 |
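As a sanity check on the arithmetic (a standalone sketch, not speedrun code): a mean cross-entropy of 3.28 nats/token over the 10,485,760 validation tokens is the same statement as assigning a total log-probability of -3.28 × 10,485,760 nats to that token stream.

```python
import math

num_val_tokens = 10_485_760
mean_ce = 3.28  # nats per token

total_log_prob = -mean_ce * num_val_tokens   # ≈ -3.44e7 nats over the whole val stream
print(total_log_prob)
print(math.exp(total_log_prob))  # 0.0 -- underflows; the true probability is just astronomically small
```
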
124 | ---
125 |
126 | ### Timing change after record 21
127 |
128 | After the 21st record, we made two changes to the timing. First, there used to be an initial "grace period" of 10 untimed steps to allow kernel warmup. We replaced this with an explicit kernel-warmup section which is untimed and uses dummy data. This results in an extra runtime of 850ms from the 10 extra timed steps.
129 | Second, we banned the use of `torch._inductor.config.coordinate_descent_tuning`. This saves ~25min of untimed pre-run compilation, but results in an extra runtime of ~3s.
130 |
131 |
135 |
138 |
139 |
145 |
146 | ---
147 |
148 | ### Notable attempts & forks
149 |
150 | **Notable runs:**
151 |
152 | * [@alexjc's 01/20/2025 2.77-minute TokenMonster-based record](https://x.com/alexjc/status/1881410039639863622).
153 | This record is technically outside the rules of the speedrun, since we specified that the train/val tokens must be kept fixed.
154 | However, it's very interesting, and worth including. The run is not more data-efficient; rather, the speedup comes from the improved tokenizer allowing
155 | the vocabulary size to be reduced (nearly halved!) while preserving the same bytes-per-token, which saves lots of parameters and FLOPs in the head and embeddings.
156 |
157 | **Notable forks:**
158 | * [https://github.com/BlinkDL/modded-nanogpt-rwkv](https://github.com/BlinkDL/modded-nanogpt-rwkv)
159 | * [https://github.com/nikhilvyas/modded-nanogpt-SOAP](https://github.com/nikhilvyas/modded-nanogpt-SOAP)
160 |
161 | ---
162 |
163 | ## Speedrun track 2: GPT-2 Medium
164 |
165 | The target loss for this track is lowered from 3.28 to 2.92, as per Andrej Karpathy's 350M-parameter llm.c baseline.
166 | This baseline generates a model with performance similar to the original GPT-2 Medium, whereas the first track's baseline generates a model on par with GPT-2 Small.
167 | All other rules remain the same.
168 |
169 | > Note: `torch._inductor.config.coordinate_descent_tuning` is turned on starting from record 6 (*).
170 |
171 | | # | Record time | Description | Date | Log | Contributors |
172 | | - | - | - | - | - | - |
173 | 1 | 5.8 hours | [llm.c baseline (350M parameters)](https://github.com/karpathy/llm.c/discussions/481) | 05/28/24 | [log](records/011825_GPT2Medium/main.log) | @karpathy, llm.c contributors
174 | 2 | 29.3 minutes | [Initial record based on scaling up the GPT-2 small track speedrun](https://x.com/kellerjordan0/status/1881959719012847703) | 01/18/25 | [log](records/011825_GPT2Medium/241dd7a7-3d76-4dce-85a4-7df60387f32a.txt) | @kellerjordan0
175 | 3 | 28.1 minutes | [Added standard weight decay](https://x.com/kellerjordan0/status/1888320690543284449) | 02/08/25 | [log](records/020825_GPT2MediumWeightDecay/b01743db-605c-4326-b5b1-d388ee5bebc5.txt) | @kellerjordan0
176 | 4 | 27.7 minutes | [Tuned Muon Newton-Schulz coefficients](https://x.com/leloykun/status/1892793848163946799) | 02/14/25 | [log](records/021425_GPT2MediumOptCoeffs/1baa66b2-bff7-4850-aced-d63885ffb4b6.txt) | @leloykun
177 | 5 | 27.2 minutes | [Increased learning rate cooldown phase duration](records/030625_GPT2MediumLongerCooldown/779c041a-2a37-45d2-a18b-ec0f223c2bb7.txt) | 03/06/25 | [log](records/030625_GPT2MediumLongerCooldown/779c041a-2a37-45d2-a18b-ec0f223c2bb7.txt) | @YouJiacheng
178 | 6 | 25.95 minutes* | [2x MLP wd, qkv norm, all_reduce/opt.step() overlap, optimized skip pattern](https://x.com/YouJiacheng/status/1905861218138804534) | 03/25/25 | [log](records/032525_GPT2MediumArchOptTweaks/train_gpt-20250329.txt) | @YouJiacheng
179 | 7 | 25.29 minutes | [Remove FP8 head; ISRU logits softcap; New sharded mixed precision Muon; merge weights](https://x.com/YouJiacheng/status/1912570883878842527) | 04/16/25 | [log](records/041625_GPT2Medium_Record7/223_3310d0b1-b24d-48ee-899f-d5c2a254a195.txt) | @YouJiacheng
180 | 8 | 24.50 minutes | [Cubic sliding window size schedule, 2× max window size (24.84 minutes)](https://x.com/jadenj3o/status/1914893086276169754) [24.5min repro](https://x.com/YouJiacheng/status/1915667616913645985) | 04/22/25 | [log](records/042225_GPT2Medium_Record8/075_640429f2-e726-4e83-aa27-684626239ffc.txt) | @jadenj3o
181 |
182 | ---
183 |
184 | ### Q: What is the point of NanoGPT speedrunning?
185 |
186 | A: The officially stated goal of NanoGPT speedrunning is as follows: `gotta go fast`. But for something a little more verbose involving an argument for good benchmarking, here's some kind of manifesto, adorned with a blessing from the master. [https://x.com/karpathy/status/1846790537262571739](https://x.com/karpathy/status/1846790537262571739)
187 |
188 | ### Q: What makes "NanoGPT speedrunning" not just another idiosyncratic benchmark?
189 |
190 | A: Because it is a *competitive* benchmark. In particular, if you attain a new speed record (using whatever method you want), there is an open invitation for you
191 | to post that record (on arXiv or X) and thereby vacuum up all the clout for yourself. I will even help you do it by reposting you as much as I can.
192 |
193 |
200 |
201 | ["Artificial intelligence advances by inventing games and gloating to goad others to play" - Professor Ben Recht](https://www.argmin.net/p/too-much-information)
202 |
203 | ### Q: NanoGPT speedrunning is cool and all, but meh it probably won't scale and is just overfitting to val loss
204 |
205 | A: This is hard to refute, since "at scale" is an infinite category (what if the methods stop working only for >100T models?), making it impossible to fully prove.
206 | Also, I would agree that some of the methods used in the speedrun are unlikely to scale, particularly those which *impose additional structure* on the network, such as logit softcapping.
207 | But if the reader cares about 1.5B models, they might be convinced by this result:
208 |
209 | *Straightforwardly scaling up the speedrun (10/18/24 version) to 1.5B parameters yields a model with GPT-2 (1.5B)-level HellaSwag performance 2.5x more cheaply than [@karpathy's baseline](https://github.com/karpathy/llm.c/discussions/677) ($233 instead of $576):*
210 |
211 | 
212 | [[reproducible log](https://github.com/KellerJordan/modded-nanogpt/blob/master/records/102024_ScaleUp1B/ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt)]
213 | 
214 |
215 | ---
216 |
217 | ## [Muon optimizer](https://github.com/KellerJordan/Muon)
218 |
219 | Muon is defined as follows:
220 |
221 | 
222 |
223 | Where NewtonSchulz5 is the following Newton-Schulz iteration [2, 3], which approximately replaces `G` with `U @ V.T` where `U, S, V = G.svd()`.
224 | ```python
225 | @torch.compile
226 | def zeroth_power_via_newtonschulz5(G, steps=5, eps=1e-7):
227 |     assert len(G.shape) == 2
228 |     a, b, c = (3.4445, -4.7750, 2.0315)
229 |     X = G.bfloat16() / (G.norm() + eps)
230 |     if G.size(0) > G.size(1):
231 |         X = X.T
232 |     for _ in range(steps):
233 |         A = X @ X.T
234 |         B = b * A + c * A @ A
235 |         X = a * X + B @ X
236 |     if G.size(0) > G.size(1):
237 |         X = X.T
238 |     return X.to(G.dtype)
239 | ```
240 |
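As a simplified illustration (a sketch of the update for a single 2D weight, not the repo's distributed implementation; the function name and hyperparameters here are placeholders), the Newton-Schulz step slots into the optimizer like this:

```python
def muon_step_(weight, grad, momentum_buffer, lr=0.02, momentum=0.95):
    # momentum first, then orthogonalization of the resulting update direction
    momentum_buffer.mul_(momentum).add_(grad)
    update = grad.add(momentum_buffer, alpha=momentum)   # Nesterov-style momentum
    update = zeroth_power_via_newtonschulz5(update)      # orthogonalize via the code above
    weight.add_(update, alpha=-lr)                       # (shape-dependent lr scaling omitted)
```
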
241 | For this training scenario, Muon has the following favorable properties:
242 | * Lower memory usage than Adam
243 | * ~1.5x better sample-efficiency
244 | * <2% wallclock overhead
245 |
246 |
247 | ### Provenance
248 |
249 | Many of the choices made to generate this optimizer were obtained experimentally by our pursuit of [CIFAR-10 speedrunning](https://github.com/KellerJordan/cifar10-airbench).
250 | In particular, we experimentally obtained the following practices:
251 | * Using Nesterov momentum inside the update, with orthogonalization applied after momentum.
252 | * Using a specifically quintic Newton-Schulz iteration as the method of orthogonalization.
253 | * Using non-convergent coefficients for the quintic polynomial in order to maximize slope at zero, and thereby minimize the number of necessary Newton-Schulz iterations.
254 | It turns out that the variance doesn't actually matter that much, so we end up with a quintic that rapidly converges to the range (0.68, 1.13) upon repeated application, rather than converging more slowly to 1 (see the short numerical sketch after this list).
255 | * Running the Newton-Schulz iteration in bfloat16 (whereas Shampoo implementations often depend on inverse-pth-roots run in fp32 or fp64).
256 |
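As a quick numerical illustration of the non-convergent-coefficients point in the list above (a standalone sketch, not code from the repo):

```python
import numpy as np

a, b, c = 3.4445, -4.7750, 2.0315
quintic = lambda x: a * x + b * x**3 + c * x**5

x = np.array([0.01, 0.1, 0.5, 1.0])  # sample (normalized) singular values
for _ in range(5):                   # 5 iterations, matching the code above
    x = quintic(x)
print(x.round(3))                    # every entry lands roughly in the (0.68, 1.13) band, not at 1
```
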
257 | Our use of a Newton-Schulz iteration for orthogonalization traces to [Bernstein & Newhouse (2024)](https://arxiv.org/abs/2409.20325),
258 | who suggested it as a way to compute Shampoo [5, 6] preconditioners, and theoretically explored Shampoo without preconditioner accumulation.
259 | In particular, Jeremy Bernstein @jxbz sent us the draft, which caused us to experiment with various Newton-Schulz iterations as the
260 | orthogonalization method for this optimizer.
261 | If we had used SVD instead of a Newton-Schulz iteration, this optimizer would have been too slow to be useful.
262 | Bernstein & Newhouse also pointed out that Shampoo without preconditioner accumulation is equivalent to steepest descent in the spectral norm,
263 | and therefore Shampoo can be thought of as a way to smooth out spectral steepest descent.
264 | The proposed optimizer can be thought of as a second way of smoothing spectral steepest descent, with a different set of memory and runtime tradeoffs
265 | compared to Shampoo.
266 |
267 | ---
268 |
269 | ## Running on fewer GPUs
270 |
271 | * To run experiments on fewer GPUs, simply modify `run.sh` to have a different `--nproc_per_node`. This should not change the behavior of the training.
272 | * If you're running out of memory, you may need to reduce the sequence length for FlexAttention (which does change the training; see [here](https://github.com/KellerJordan/modded-nanogpt/pull/38) for a guide)
273 |
274 | ---
275 |
276 | ## References
277 |
278 | 1. [Guilherme Penedo et al. "The fineweb datasets: Decanting the web for the finest text data at scale." arXiv preprint arXiv:2406.17557 (2024).](https://arxiv.org/abs/2406.17557)
279 | 2. Nicholas J. Higham. Functions of Matrices. Society for Industrial and Applied Mathematics (2008). Equation 5.22.
280 | 3. Günther Schulz. Iterative Berechnung der reziproken Matrix. Z. Angew. Math. Mech., 13:57–59 (1933).
281 | 4. [Jeremy Bernstein and Laker Newhouse. "Old Optimizer, New Norm: An Anthology." arXiv preprint arXiv:2409.20325 (2024).](https://arxiv.org/abs/2409.20325)
282 | 5. [Vineet Gupta, Tomer Koren, and Yoram Singer. "Shampoo: Preconditioned stochastic tensor optimization." International Conference on Machine Learning. PMLR, 2018.](https://arxiv.org/abs/1802.09568)
283 | 6. [Rohan Anil et al. "Scalable second order optimization for deep learning." arXiv preprint arXiv:2002.09018 (2020).](https://arxiv.org/abs/2002.09018)
284 | 7. [Alexander Hägele et al. "Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations." arXiv preprint arXiv:2405.18392 (2024).](https://arxiv.org/abs/2405.18392)
285 | 8. [Zhanchao Zhou et al. "Value Residual Learning For Alleviating Attention Concentration In Transformers." arXiv preprint arXiv:2410.17897 (2024).](https://arxiv.org/abs/2410.17897)
286 | 9. [Team, Gemma, et al. "Gemma 2: Improving open language models at a practical size." arXiv preprint arXiv:2408.00118 (2024).](https://arxiv.org/abs/2408.00118)
287 | 10. [Alec Radford et al. "Language models are unsupervised multitask learners." OpenAI blog 1.8 (2019).](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
288 |
289 | ## Citation
290 |
291 | ```
292 | @misc{modded_nanogpt_2024,
293 | author = {Keller Jordan and Jeremy Bernstein and Brendan Rappazzo and
294 | @fernbear.bsky.social and Boza Vlado and You Jiacheng and
295 | Franz Cesista and Braden Koszarsky and @Grad62304977},
296 | title = {modded-nanogpt: Speedrunning the NanoGPT baseline},
297 | year = {2024},
298 | url = {https://github.com/KellerJordan/modded-nanogpt}
299 | }
300 | ```
301 |
302 |
303 |
304 |
--------------------------------------------------------------------------------
/data/cached_fineweb100B.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from huggingface_hub import hf_hub_download
4 | # Download the GPT-2 tokens of Fineweb100B from huggingface. This
5 | # saves about an hour of startup time compared to regenerating them.
6 | def get(fname):
7 |     local_dir = os.path.join(os.path.dirname(__file__), 'fineweb100B')
8 |     if not os.path.exists(os.path.join(local_dir, fname)):
9 |         hf_hub_download(repo_id="kjj0/fineweb100B-gpt2", filename=fname,
10 |                         repo_type="dataset", local_dir=local_dir)
11 | get("fineweb_val_%06d.bin" % 0)
12 | num_chunks = 1030 # full fineweb100B. Each chunk is 100M tokens
13 | if len(sys.argv) >= 2: # we can pass an argument to download less
14 | num_chunks = int(sys.argv[1])
15 | for i in range(1, num_chunks+1):
16 |     get("fineweb_train_%06d.bin" % i)
17 |
--------------------------------------------------------------------------------
/data/cached_fineweb10B.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from huggingface_hub import hf_hub_download
4 | # Download the GPT-2 tokens of Fineweb10B from huggingface. This
5 | # saves about an hour of startup time compared to regenerating them.
6 | def get(fname):
7 |     local_dir = os.path.join(os.path.dirname(__file__), 'fineweb10B')
8 |     if not os.path.exists(os.path.join(local_dir, fname)):
9 |         hf_hub_download(repo_id="kjj0/fineweb10B-gpt2", filename=fname,
10 |                         repo_type="dataset", local_dir=local_dir)
11 | get("fineweb_val_%06d.bin" % 0)
12 | num_chunks = 103 # full fineweb10B. Each chunk is 100M tokens
13 | if len(sys.argv) >= 2: # we can pass an argument to download less
14 | num_chunks = int(sys.argv[1])
15 | for i in range(1, num_chunks+1):
16 |     get("fineweb_train_%06d.bin" % i)
17 |
--------------------------------------------------------------------------------
/data/cached_finewebedu10B.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from huggingface_hub import hf_hub_download
4 | # Download the GPT-2 tokens of FinewebEDU10B from huggingface. This
5 | # saves about an hour of startup time compared to regenerating them.
6 | def get(fname):
7 |     local_dir = os.path.join(os.path.dirname(__file__), 'finewebedu10B')
8 |     if not os.path.exists(os.path.join(local_dir, fname)):
9 |         hf_hub_download(repo_id="kjj0/finewebedu10B-gpt2", filename=fname,
10 |                         repo_type="dataset", local_dir=local_dir)
11 | get("finewebedu_val_%06d.bin" % 0)
12 | num_chunks = 99 # full FinewebEDU10B. Each chunk is 100M tokens
13 | if len(sys.argv) >= 2: # we can pass an argument to download less
14 | num_chunks = int(sys.argv[1])
15 | for i in range(1, num_chunks+1):
16 |     get("finewebedu_train_%06d.bin" % i)
17 |
--------------------------------------------------------------------------------
/data/fineweb.py:
--------------------------------------------------------------------------------
1 | """
2 | FineWeb dataset (for srs pretraining)
3 | https://huggingface.co/datasets/HuggingFaceFW/fineweb
4 |
5 | example doc to highlight the structure of the dataset:
6 | {
7 | "text": "Posted by mattsmith on 20th April 2012\nStraight from...",
8 | "id": "",
9 | "dump": "CC-MAIN-2013-20",
10 | "url": "http://nleastchatter.com/philliesphandom/tag/freddy-galvis/",
11 | "date": "2013-05-18T07:24:47Z",
12 | "file_path": "s3://commoncrawl/long.../path.../file.gz",
13 | "language": "en",
14 | "language_score": 0.9185474514961243,
15 | "token_count": 594
16 | }
17 | """
18 | import os
19 | import argparse
20 | import multiprocessing as mp
21 | import numpy as np
22 | import tiktoken
23 | # from huggingface_hub import snapshot_download
24 | from datasets import load_dataset
25 | from tqdm import tqdm
26 | import argparse
27 | import numpy as np
28 | def write_datafile(filename, toks):
29 | """
30 | Saves token data as a .bin file, for reading in C.
31 | - First comes a header with 256 int32s
32 | - The tokens follow, each as a uint16
33 | """
34 | assert len(toks) < 2**31, "token count too large" # ~2.1B tokens
35 | # construct the header
36 | header = np.zeros(256, dtype=np.int32)
37 | header[0] = 20240520 # magic
38 | header[1] = 1 # version
39 | header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16)
40 | # construct the tokens numpy array, if not already
41 | if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16:
42 | # validate that no token exceeds a uint16
43 | maxtok = 2**16
44 | assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16"
45 | toks_np = np.array(toks, dtype=np.uint16)
46 | else:
47 | toks_np = toks
48 | # write to file
49 | print(f"writing {len(toks):,} tokens to {filename}")
50 | with open(filename, "wb") as f:
51 | f.write(header.tobytes())
52 | f.write(toks_np.tobytes())
53 | # ------------------------------------------
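# Illustrative sketch (not part of this pipeline): reading back a shard written by
# write_datafile above, following the header layout documented in its docstring.
# The helper name read_datafile is hypothetical.
def read_datafile(filename):
    with open(filename, "rb") as f:
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch"
        assert header[1] == 1, "unsupported version"
        num_tokens = int(header[2])  # number of uint16 tokens that follow the header
        tokens = np.frombuffer(f.read(num_tokens * 2), dtype=np.uint16)
    return tokens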
54 |
55 | parser = argparse.ArgumentParser(description="FineWeb dataset preprocessing")
56 | parser.add_argument("-v", "--version", type=str, default="10B", help="Which version of fineweb to use 10B|100B")
57 | parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each shard in tokens")
58 | args = parser.parse_args()
59 |
60 | # FineWeb has a few possible subsamples available
61 | assert args.version in ["10B", "100B"], "version must be one of 10B, 100B"
62 | if args.version == "10B":
63 | local_dir = "fineweb10B"
64 | remote_name = "sample-10BT"
65 | elif args.version == "100B":
66 | local_dir = "fineweb100B"
67 | remote_name = "sample-100BT"
68 |
69 | # create the cache the local directory if it doesn't exist yet
70 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
71 | os.makedirs(DATA_CACHE_DIR, exist_ok=True)
72 |
73 | # download the dataset
74 | fw = load_dataset("HuggingFaceFW/fineweb", name=remote_name, split="train")
75 |
76 | # init the tokenizer
77 | enc = tiktoken.get_encoding("gpt2")
78 | eot = enc._special_tokens['<|endoftext|>'] # end of text token
79 | def tokenize(doc):
80 |     # tokenizes a single document and returns a numpy array of uint16 tokens
81 |     tokens = [eot] # the special <|endoftext|> token delimits all documents
82 |     tokens.extend(enc.encode_ordinary(doc["text"]))
83 |     tokens_np = np.array(tokens)
84 |     assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
85 |     tokens_np_uint16 = tokens_np.astype(np.uint16)
86 |     return tokens_np_uint16
87 |
88 | # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
89 | nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system
90 | with mp.Pool(nprocs) as pool:
91 |     shard_index = 0
92 |     # preallocate buffer to hold current shard
93 |     all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)
94 |     token_count = 0
95 |     progress_bar = None
96 |     for tokens in pool.imap(tokenize, fw, chunksize=16):
97 |
98 |         # is there enough space in the current shard for the new tokens?
99 |         if token_count + len(tokens) < args.shard_size:
100 |             # simply append tokens to current shard
101 |             all_tokens_np[token_count:token_count+len(tokens)] = tokens
102 |             token_count += len(tokens)
103 |             # update progress bar
104 |             if progress_bar is None:
105 |                 progress_bar = tqdm(total=args.shard_size, unit="tokens", desc=f"Shard {shard_index}")
106 |             progress_bar.update(len(tokens))
107 |         else:
108 |             # write the current shard and start a new one
109 |             split = "val" if shard_index == 0 else "train"
110 |             filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin")
111 |             # split the document into whatever fits in this shard; the remainder goes to next one
112 |             remainder = args.shard_size - token_count
113 |             progress_bar.update(remainder)
114 |             all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
115 |             write_datafile(filename, all_tokens_np)
116 |             shard_index += 1
117 |             progress_bar = None
118 |             # populate the next shard with the leftovers of the current doc
119 |             all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]
120 |             token_count = len(tokens)-remainder
121 |
122 |     # write any remaining tokens as the last shard
123 |     if token_count != 0:
124 |         split = "val" if shard_index == 0 else "train"
125 |         filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin")
126 |         write_datafile(filename, all_tokens_np[:token_count])
127 |
--------------------------------------------------------------------------------
/data/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets
2 | tiktoken
3 |
--------------------------------------------------------------------------------
/img/algo_optimizer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/algo_optimizer.png
--------------------------------------------------------------------------------
/img/dofa.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/dofa.jpg
--------------------------------------------------------------------------------
/img/fig_optimizer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/fig_optimizer.png
--------------------------------------------------------------------------------
/img/fig_tuned_nanogpt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/fig_tuned_nanogpt.png
--------------------------------------------------------------------------------
/img/nanogpt_speedrun51.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/nanogpt_speedrun51.png
--------------------------------------------------------------------------------
/img/nanogpt_speedrun52.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/nanogpt_speedrun52.png
--------------------------------------------------------------------------------
/img/nanogpt_speedrun53.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/nanogpt_speedrun53.png
--------------------------------------------------------------------------------
/img/nanogpt_speedrun54.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/nanogpt_speedrun54.png
--------------------------------------------------------------------------------
/records/010425_SoftCap/README.md:
--------------------------------------------------------------------------------
1 | # Softer softcap
2 |
3 | This record, by Braden Koszarsky, increases the degree of logit softcapping, yielding a 7% speedup.
4 | [reproducible log](31d6c427-f1f7-4d8a-91be-a67b5dcd13fd.txt)
5 |
6 | Previously, logits were softcapped (via tanh) to be at most 30. The new record lowers that to 15,
7 | which boosts performance such that the step count can be reduced from 1490 to 1390.
8 |
9 | Lowering the tanh softcap can be understood as a form of extra structure which we are imposing on the network, which improves
10 | performance in the small-scale regime.
11 |
12 | Running this new record 80 times yielded the following series of val losses:
13 | ```
14 | accs = [3.2798, 3.2804, 3.2837, 3.2808, 3.2782, 3.2801, 3.283, 3.2825, 3.2777, 3.2769, 3.2834, 3.2832, 3.2753,
15 | 3.2809, 3.2778, 3.2801, 3.2799, 3.2804, 3.2765, 3.2792, 3.2786, 3.2792, 3.2801, 3.2762, 3.2803, 3.2784,
16 | 3.2792, 3.2791, 3.2769, 3.279, 3.2784, 3.2775, 3.283, 3.2785, 3.2753, 3.2805, 3.2766, 3.2766, 3.2781,
17 | 3.2819, 3.2754, 3.2827, 3.2803, 3.2784, 3.2802, 3.2794, 3.2765, 3.278, 3.2782, 3.278, 3.2816, 3.279,
18 | 3.2771, 3.2791, 3.2768, 3.2781, 3.2794, 3.2798, 3.2785, 3.2804, 3.2777, 3.2765, 3.2796, 3.278, 3.2803,
19 | 3.2793, 3.2793, 3.2788, 3.2797, 3.278, 3.2799, 3.2813, 3.2803, 3.2768, 3.2803, 3.2796, 3.28, 3.2796,
20 | 3.2783, 3.278]
21 |
22 | import scipy.stats
23 | print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
24 | # p=0.0001
25 |
26 | import torch
27 | print(torch.std_mean(torch.tensor(accs)))
28 | # (tensor(0.0019), tensor(3.2791))
29 | ```
30 |
31 | 
32 |
33 |
--------------------------------------------------------------------------------
/records/010425_SoftCap/curves_010425.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/010425_SoftCap/curves_010425.png
--------------------------------------------------------------------------------
/records/011325_Fp8LmHead/README.md:
--------------------------------------------------------------------------------
1 | Note: statistical significance was obtained by @YouJiacheng [here](https://x.com/YouJiacheng/status/1878827972519772241).
2 |
--------------------------------------------------------------------------------
/records/011625_Sub3Min/README.md:
--------------------------------------------------------------------------------
1 | # Sub-3 minute record
2 |
3 | ## Evidence for <=3.28 mean loss
4 |
5 | ```bash
6 | $ grep "1393/1393 val" * | python -c "import sys; ss = list(sys.stdin); accs = [float(s.split()[1].split(':')[1]) for s in ss]; print(accs); import scipy.stats; mvs = scipy.stats.bayes_mvs(accs); print(mvs[0]); print(mvs[2]); print(f'p={scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue:.4f}')"
7 | [3.276, 3.2785, 3.2796, 3.2788, 3.2789, 3.2768, 3.2775, 3.2784, 3.2767, 3.2792, 3.2807, 3.2801, 3.2805, 3.2777, 3.2789, 3.2799, 3.2786, 3.2776, 3.2791, 3.2808, 3.2776, 3.2786, 3.2774, 3.2832, 3.277, 3.2789, 3.2784, 3.2766, 3.2755, 3.2784, 3.2798, 3.2825]
8 | Mean(statistic=np.float64(3.27869375), minmax=(np.float64(3.2781784751445135), np.float64(3.2792090248554864)))
9 | Std_dev(statistic=np.float64(0.0017621789337662857), minmax=(np.float64(0.0014271074116428265), np.float64(0.002179878373699496)))
10 | p=0.0001
11 | ```
12 |
13 | ```
14 | Mean runtime: 179.8 seconds
15 | Stddev: 101ms
16 | ```
17 |
18 | ## Details on the changes made
19 |
20 | ### Long-Short Sliding Window Attention
21 |
22 | 
23 |
24 | This attention mechanism is inspired by the Local-Global Attention introduced by the [Gemma 2](https://arxiv.org/abs/2408.00118) paper (and more recent "hybrid" architectures). But there are two key differences:
25 |
26 | 1. We use [Sliding Window Attention](https://arxiv.org/abs/2004.05150) for both the "global attention" (i.e. "long SWA") and the "local attention" (i.e. "short SWA") parts. The difference between the two is that the "long SWA" has double the context length of the "short SWA".
27 | 2. We also **warmup the context length** of both sliding window attention mechanisms, but **at different rates**. The "long SWA" context length is warmed up at double the rate compared to the "short SWA" (see the sketch below).
28 |
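A minimal sketch of this warmup schedule (illustrative only; the names and the exact schedule here are assumptions, not the record's code):

```python
def sliding_window_blocks(step: int, num_steps: int, max_long_blocks: int) -> tuple[int, int]:
    # both windows grow over training; the short window stays at half the long one,
    # so the long SWA's context length warms up at double the rate
    long_blocks = max(1, round(max_long_blocks * step / num_steps))
    short_blocks = max(1, long_blocks // 2)
    return long_blocks, short_blocks
```
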
29 | We also made a speedrun-specific decision to only use "long SWA" in the first, fifth, and last layers. The first, because we do not want to compress information too early in the network. The last, because the model architecture we use for the speedrun follows a UNet-like structure, and we want the first and the last layers to be symmetric. And finally, the fifth layer, mainly because it is empirically the best choice for the speedrun.
30 |
31 | This would have been very difficult to implement without PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).
32 |
33 | ```diff
34 | # In GPT.forward...
35 | def dense_to_ordered(dense_mask: torch.Tensor):
36 |     num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)
37 | -   indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(torch.int32)
38 | +   indices = dense_mask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
39 |     return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
40 |
41 | - def create_doc_swc_block_mask(sliding_window_num_blocks):
42 | + def create_doc_swc_block_masks(sliding_window_num_blocks: int):
43 |     kv_idx = block_idx = torch.arange(total_num_blocks, dtype=torch.int32, device='cuda')
44 |     q_idx = block_idx[:, None]
45 |     causal_bm = q_idx >= kv_idx
46 |     causal_full_bm = q_idx > kv_idx
47 | -   window_bm = q_idx - kv_idx < sliding_window_num_blocks
48 | -   window_full_bm = window_bm # block-wise sliding window by @YouJiacheng
49 |     document_bm = (docs_low[:, None] <= docs_high) & (docs_low <= docs_high[:, None])
50 |     document_full_bm = (docs_low[:, None] == docs_high) & (docs_low == docs_high[:, None])
51 | -   nonzero_bm = causal_bm & window_bm & document_bm
52 | -   full_bm = causal_full_bm & window_full_bm & document_full_bm
53 | +   nonzero_bm = causal_bm & document_bm
54 | +   full_bm = causal_full_bm & document_full_bm
55 |     kv_num_blocks, kv_indices = dense_to_ordered(nonzero_bm & ~full_bm)
56 |     full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)
57 | -   return BlockMask.from_kv_blocks(
58 | -       kv_num_blocks,
59 | -       kv_indices,
60 | -       full_kv_num_blocks,
61 | -       full_kv_indices,
62 | -       BLOCK_SIZE=BLOCK_SIZE,
63 | -       mask_mod=document_causal,
64 | -   )
65 | +   def build_bm(sw_num_blocks: Tensor) -> BlockMask:
66 | +       return BlockMask.from_kv_blocks(
67 | +           torch.clamp_max(kv_num_blocks, torch.clamp_min(sw_num_blocks - full_kv_num_blocks, 1)),
68 | +           kv_indices,
69 | +           torch.clamp_max(full_kv_num_blocks, sw_num_blocks - 1),
70 | +           full_kv_indices,
71 | +           BLOCK_SIZE=BLOCK_SIZE,
72 | +           mask_mod=document_causal,
73 | +       )
74 | +   return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
75 |
76 | - block_mask = create_doc_swc_block_mask(sliding_window_num_blocks)
77 | + long_bm, short_bm = create_doc_swc_block_masks(sliding_window_num_blocks)
78 | ...
79 | skip_connections = []
80 | # Encoder pass - process only the first half of the blocks
81 | + block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm]
82 | for i in range(self.num_encoder_layers):
83 | -   x = self.blocks[i](x, ve_enc[i], x0, block_mask)
84 | +   x = self.blocks[i](x, ve_enc[i], x0, block_masks[i])
85 |     skip_connections.append(x)
86 | # Decoder pass - process the remaining blocks with weighted skip connections
87 | + block_masks.reverse()
88 | for i in range(self.num_decoder_layers):
89 |     x = x + self.skip_weights[i] * skip_connections.pop()
90 | -   x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_mask)
91 | +   x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_masks[i])
92 | ```
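
As referenced in point 2 above, both window sizes are warmed up over the course of training. A minimal sketch of what such a schedule could look like (the helper name and the linear ramp are illustrative, not the exact code from this record):

```python
def get_sliding_window_num_blocks(step: int, num_steps: int, max_blocks: int) -> int:
    # Linear warmup of the long-SWA window size (in blocks). The short-SWA window is
    # derived from it as long // 2 inside create_doc_swc_block_masks above, so it
    # effectively warms up at half the rate.
    frac = min(1.0, (step + 1) / num_steps)
    return max(1, int(frac * max_blocks))
```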
93 |
94 | ### Attention Scale Modification
95 |
96 | We currently use QK-Normalization to stabilize the attention coefficients. This helps [reduce the wallclock time of the speedrun](https://x.com/kellerjordan0/status/1845865698532450646). However, unlike in larger-scale models such as [ViT-22B](https://arxiv.org/pdf/2302.05442) and [Chameleon](https://arxiv.org/pdf/2405.09818v1), we use a parameter-free RMSNorm instead of the usual LayerNorm with learnable parameters.
97 |
98 | But while the parameter-free RMSNorm is faster and leads to more stable training runs, it also constrains the logit sharpness, and consequently the entropy of the attention coefficients, to lie in the same range across different layers. In our setup, this leads to higher attention entropies, which means the model is less "certain" about which tokens to "attend to" during training. This is not a problem early in training, when we also do not want the model to overfit, but it becomes one later on, when we want the model to "focus" on the most important tokens. And the current record is now tight enough for this to be a problem.
99 |
100 | 
101 |
102 | To fix this issue, we first tried out (1) RMSNorm with learned channel-wise parameters and (2) a learned scalar "attention scale" parameter, one for each Attention layer. Both approaches allowed us to reduce training steps by 20, with a ~0.5-0.7 ms/step overhead. Overall, the wallclock time reduction was ~2-3 secs.
103 |
104 | Strangely, the models seemed to consistently learn a UNet-like attention-scale pattern, and hardcoding this pattern led to roughly the same results (e.g. `attn_scale(layer_idx) := 0.12 + 0.01 * min(layer_idx, 11 - layer_idx)`). We find this interesting, and it could be a potential area for future research. For now, though, we offer no explanation for why this pattern emerges and why it works well, aside from divine intervention.
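
A minimal sketch of the hardcoded UNet-like schedule described above (12 layers, as implied by the `11 - layer_idx` term; the helper name is illustrative):

```python
def attn_scale(layer_idx: int, n_layer: int = 12) -> float:
    # smallest scale at the first/last layers, peaking in the middle of the network
    return 0.12 + 0.01 * min(layer_idx, (n_layer - 1) - layer_idx)

# scales ramp 0.12 -> 0.17 -> 0.12 across the 12 layers
```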
105 |
106 | 
107 |
108 | We eventually settled on simply setting the attention scale to `0.12` (vs. the default `1.0 / sqrt(head_dim)`) for all layers. This leads to the same 20-step reduction, but with no per-step overhead; an overall speed gain of ~3 secs.
109 |
110 | ```diff
111 | # In CausalSelfAttention.__init__
112 | + self.attn_scale = 0.12
113 | ...
114 | # In CausalSelfAttention.forward
115 | - y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask)
116 | + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale)
117 | ```
118 |
119 | For logs on learnable attention scales, see: [README for 01/12/25 record attempt](https://github.com/leloykun/modded-nanogpt/blob/fc--learnable-attn-scale/records/011225_LearnableAttnScale/README.md)
120 |
121 | ### Stacked QKV Weights & Batched Muon Implementation
122 |
123 | This is an implementation/compiler-level optimization that leads to a 1-2 secs speed improvement. The crux is that, with a big enough GPU, doing one massive matmul for the QKV weights is faster than doing three smaller matmuls, one for each of the weights.
124 |
125 | The problem, however, is that Muon performs better on the unmerged QKV weights, primarily because of the massive matmuls in its Newton-Schulz iterations. Our previous implementation kept the weights stored separately, as before, and concatenated them in the forward pass, but this concatenation introduced a ~1 sec regression. We finally got rid of this overhead by stacking the QKV weights instead and using a batched implementation of Muon.
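
As a rough sketch (not the exact code from this record), a batched Newton-Schulz iteration over the stacked QKV weights might look like the following; the `(3, d_model, d_model)` shape and the quintic coefficients are assumptions carried over from the Muon implementations elsewhere in this repo:

```python
import torch

def batched_newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # G has shape (batch, m, n), e.g. (3, d_model, d_model) for stacked Q/K/V momentum updates.
    # Same quintic coefficients as the 2D Newton-Schulz used elsewhere in this repo.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)  # normalize each matrix in the batch
    if G.size(-2) > G.size(-1):
        X = X.transpose(-2, -1)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(-2) > G.size(-1):
        X = X.transpose(-2, -1)
    return X

# Hypothetical usage: one batched call orthogonalizes all three matrices at once,
# instead of three separate Newton-Schulz calls on q_w, k_w, v_w.
qkv_w_update = batched_newtonschulz5(torch.randn(3, 768, 768))
```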
126 |
127 | ### Adam `eps=1e-10` fix
128 |
129 | The speedrun is so tight now that even Adam's default epsilon parameter is already causing problems.
130 |
131 | For context, we initialize our LM head as a zero matrix. This leads to small gradients early on in training, which can sometimes be even smaller than Adam's default epsilon, causing training instability and increased validation loss.
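
To see why, recall that Adam's update is roughly `lr * m_hat / (sqrt(v_hat) + eps)`; with a tiny, steady gradient, `sqrt(v_hat)` is on the same order as the gradient itself, so the denominator is dominated by `eps` and the update gets squashed. A toy illustration (the `1e-9` gradient magnitude is just an assumed example):

```python
g = 1e-9  # assumed early-training gradient magnitude for the zero-init LM head
for eps in (1e-8, 1e-10):
    # with a steady gradient, m_hat ~ g and sqrt(v_hat) ~ |g|
    print(eps, g / (abs(g) + eps))
# eps=1e-8  -> ~0.09 of the intended update magnitude (heavily damped)
# eps=1e-10 -> ~0.91 of the intended update magnitude
```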
132 |
133 | To address this issue, we simply reduced Adam's `eps` from `1e-8` down to `1e-10`. This led to a 0.0014 validation loss improvement with no per-step overhead, thereby allowing us to reduce training steps by 10.
134 |
135 | ```diff
136 | - optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), fused=True)
137 | + optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), fused=True, eps=1e-10)
138 | ```
139 |
--------------------------------------------------------------------------------
/records/011625_Sub3Min/attn-entropy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/011625_Sub3Min/attn-entropy.png
--------------------------------------------------------------------------------
/records/011625_Sub3Min/attn-scales-pattern.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/011625_Sub3Min/attn-scales-pattern.gif
--------------------------------------------------------------------------------
/records/011625_Sub3Min/learned-attn-scales.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/011625_Sub3Min/learned-attn-scales.png
--------------------------------------------------------------------------------
/records/011625_Sub3Min/long-short-swa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/011625_Sub3Min/long-short-swa.png
--------------------------------------------------------------------------------
/records/011825_GPT2Medium/main.log:
--------------------------------------------------------------------------------
1 | s:0 tel:10.9830
2 | s:250 tel:6.3097
3 | s:500 tel:5.5887
4 | s:750 tel:4.9040
5 | s:1000 tel:4.3734
6 | s:1250 tel:4.1307
7 | s:1500 tel:3.9782
8 | s:1750 tel:3.8681
9 | s:2000 tel:3.7884
10 | s:2250 tel:3.7278
11 | s:2500 tel:3.6737
12 | s:2750 tel:3.6324
13 | s:3000 tel:3.6014
14 | s:3250 tel:3.5695
15 | s:3500 tel:3.5400
16 | s:3750 tel:3.5161
17 | s:4000 tel:3.4925
18 | s:4250 tel:3.4754
19 | s:4500 tel:3.4540
20 | s:4750 tel:3.4349
21 | s:5000 tel:3.4190
22 | s:5250 tel:3.4076
23 | s:5500 tel:3.3982
24 | s:5750 tel:3.3842
25 | s:6000 tel:3.3674
26 | s:6250 tel:3.3589
27 | s:6500 tel:3.3493
28 | s:6750 tel:3.3442
29 | s:7000 tel:3.3309
30 | s:7250 tel:3.3207
31 | s:7500 tel:3.3110
32 | s:7750 tel:3.3055
33 | s:8000 tel:3.2969
34 | s:8250 tel:3.2885
35 | s:8500 tel:3.2813
36 | s:8750 tel:3.2780
37 | s:9000 tel:3.2689
38 | s:9250 tel:3.2654
39 | s:9500 tel:3.2574
40 | s:9750 tel:3.2507
41 | s:10000 tel:3.2461
42 | s:10250 tel:3.2401
43 | s:10500 tel:3.2369
44 | s:10750 tel:3.2297
45 | s:11000 tel:3.2247
46 | s:11250 tel:3.2212
47 | s:11500 tel:3.2165
48 | s:11750 tel:3.2135
49 | s:12000 tel:3.2063
50 | s:12250 tel:3.2048
51 | s:12500 tel:3.1993
52 | s:12750 tel:3.1969
53 | s:13000 tel:3.1934
54 | s:13250 tel:3.1887
55 | s:13500 tel:3.1873
56 | s:13750 tel:3.1823
57 | s:14000 tel:3.1819
58 | s:14250 tel:3.1758
59 | s:14500 tel:3.1715
60 | s:14750 tel:3.1668
61 | s:15000 tel:3.1648
62 | s:15250 tel:3.1618
63 | s:15500 tel:3.1589
64 | s:15750 tel:3.1562
65 | s:16000 tel:3.1533
66 | s:16250 tel:3.1508
67 | s:16500 tel:3.1466
68 | s:16750 tel:3.1450
69 | s:17000 tel:3.1448
70 | s:17250 tel:3.1398
71 | s:17500 tel:3.1355
72 | s:17750 tel:3.1342
73 | s:18000 tel:3.1311
74 | s:18250 tel:3.1281
75 | s:18500 tel:3.1258
76 | s:18750 tel:3.1229
77 | s:19000 tel:3.1240
78 | s:19250 tel:3.1195
79 | s:19500 tel:3.1167
80 | s:19750 tel:3.1144
81 | s:20000 tel:3.1129
82 | s:20250 tel:3.1115
83 | s:20500 tel:3.1091
84 | s:20750 tel:3.1074
85 | s:21000 tel:3.1037
86 | s:21250 tel:3.1012
87 | s:21500 tel:3.1006
88 | s:21750 tel:3.0981
89 | s:22000 tel:3.0951
90 | s:22250 tel:3.0938
91 | s:22500 tel:3.0920
92 | s:22750 tel:3.0897
93 | s:23000 tel:3.0888
94 | s:23250 tel:3.0845
95 | s:23500 tel:3.0850
96 | s:23750 tel:3.0812
97 | s:24000 tel:3.0794
98 | s:24250 tel:3.0773
99 | s:24500 tel:3.0755
100 | s:24750 tel:3.0758
101 | s:25000 tel:3.0728
102 | s:25250 tel:3.0708
103 | s:25500 tel:3.0677
104 | s:25750 tel:3.0676
105 | s:26000 tel:3.0654
106 | s:26250 tel:3.0631
107 | s:26500 tel:3.0604
108 | s:26750 tel:3.0589
109 | s:27000 tel:3.0587
110 | s:27250 tel:3.0572
111 | s:27500 tel:3.0553
112 | s:27750 tel:3.0534
113 | s:28000 tel:3.0525
114 | s:28250 tel:3.0501
115 | s:28500 tel:3.0486
116 | s:28750 tel:3.0462
117 | s:29000 tel:3.0456
118 | s:29250 tel:3.0437
119 | s:29500 tel:3.0406
120 | s:29750 tel:3.0409
121 | s:30000 tel:3.0387
122 | s:30250 tel:3.0370
123 | s:30500 tel:3.0369
124 | s:30750 tel:3.0334
125 | s:31000 tel:3.0320
126 | s:31250 tel:3.0306
127 | s:31500 tel:3.0289
128 | s:31750 tel:3.0280
129 | s:32000 tel:3.0252
130 | s:32250 tel:3.0259
131 | s:32500 tel:3.0239
132 | s:32750 tel:3.0227
133 | s:33000 tel:3.0194
134 | s:33250 tel:3.0189
135 | s:33500 tel:3.0168
136 | s:33750 tel:3.0168
137 | s:34000 tel:3.0138
138 | s:34250 tel:3.0125
139 | s:34500 tel:3.0116
140 | s:34750 tel:3.0100
141 | s:35000 tel:3.0082
142 | s:35250 tel:3.0075
143 | s:35500 tel:3.0051
144 | s:35750 tel:3.0037
145 | s:36000 tel:3.0026
146 | s:36250 tel:3.0015
147 | s:36500 tel:3.0000
148 | s:36750 tel:2.9987
149 | s:37000 tel:2.9974
150 | s:37250 tel:2.9954
151 | s:37500 tel:2.9938
152 | s:37750 tel:2.9927
153 | s:38000 tel:2.9911
154 | s:38250 tel:2.9901
155 | s:38500 tel:2.9890
156 | s:38750 tel:2.9871
157 | s:39000 tel:2.9865
158 | s:39250 tel:2.9847
159 | s:39500 tel:2.9833
160 | s:39750 tel:2.9818
161 | s:40000 tel:2.9812
162 | s:40250 tel:2.9798
163 | s:40500 tel:2.9781
164 | s:40750 tel:2.9772
165 | s:41000 tel:2.9762
166 | s:41250 tel:2.9749
167 | s:41500 tel:2.9734
168 | s:41750 tel:2.9724
169 | s:42000 tel:2.9717
170 | s:42250 tel:2.9702
171 | s:42500 tel:2.9685
172 | s:42750 tel:2.9681
173 | s:43000 tel:2.9667
174 | s:43250 tel:2.9651
175 | s:43500 tel:2.9641
176 | s:43750 tel:2.9633
177 | s:44000 tel:2.9638
178 | s:44250 tel:2.9612
179 | s:44500 tel:2.9599
180 | s:44750 tel:2.9592
181 | s:45000 tel:2.9581
182 | s:45250 tel:2.9569
183 | s:45500 tel:2.9563
184 | s:45750 tel:2.9549
185 | s:46000 tel:2.9541
186 | s:46250 tel:2.9530
187 | s:46500 tel:2.9520
188 | s:46750 tel:2.9515
189 | s:47000 tel:2.9504
190 | s:47250 tel:2.9494
191 | s:47500 tel:2.9485
192 | s:47750 tel:2.9475
193 | s:48000 tel:2.9467
194 | s:48250 tel:2.9459
195 | s:48500 tel:2.9451
196 | s:48750 tel:2.9440
197 | s:49000 tel:2.9433
198 | s:49250 tel:2.9428
199 | s:49500 tel:2.9419
200 | s:49750 tel:2.9413
201 | s:50000 tel:2.9405
202 | s:50250 tel:2.9399
203 | s:50500 tel:2.9394
204 | s:50750 tel:2.9388
205 | s:51000 tel:2.9379
206 | s:51250 tel:2.9374
207 | s:51500 tel:2.9367
208 | s:51750 tel:2.9361
209 | s:52000 tel:2.9357
210 | s:52250 tel:2.9350
211 | s:52500 tel:2.9346
212 | s:52750 tel:2.9341
213 | s:53000 tel:2.9336
214 | s:53250 tel:2.9332
215 | s:53500 tel:2.9328
216 | s:53750 tel:2.9324
217 | s:54000 tel:2.9320
218 | s:54250 tel:2.9317
219 | s:54500 tel:2.9314
220 | s:54750 tel:2.9309
221 | s:55000 tel:2.9306
222 | s:55250 tel:2.9303
223 | s:55500 tel:2.9301
224 | s:55750 tel:2.9299
225 | s:56000 tel:2.9296
226 | s:56250 tel:2.9294
227 | s:56500 tel:2.9292
228 | s:56750 tel:2.9290
229 | s:57000 tel:2.9289
230 | s:57250 tel:2.9287
231 | s:57500 tel:2.9286
232 | s:57750 tel:2.9285
233 | s:58000 tel:2.9284
234 | s:58250 tel:2.9283
235 | s:58500 tel:2.9283
236 | s:58750 tel:2.9282
237 | s:59000 tel:2.9282
238 | s:59250 tel:2.9282
239 | s:59500 tel:2.9282
240 | s:59750 tel:2.9281
241 | s:60000 tel:2.9282
242 |
--------------------------------------------------------------------------------
/records/012625_BatchSize/README.md:
--------------------------------------------------------------------------------
1 | # 01/26/25 - Misc Tweaks
2 |
3 | Changelogs:
4 |
5 | 1. Reduced per-device training sequence length from `64*1024` to `48*1024`. See [Critical Batch Size](https://arxiv.org/abs/2410.21676) literature.
6 | 2. Increased per-device eval sequence length from `64*1024` to `4*64*1024`. This improves `val_loss` by `~0.0015`, equivalent to a reduction of 10 training steps. Overall it saves `~1 sec` of training time.
7 | 3. Modified scales for `fp8` training of LM Head. Saves `1 sec` and improves `val_loss` by as much as `~0.01` after reducing training sequence length down to `48*1024`. I don't know wtf is causing this and I'm NOT going crazy about this. I have evidence. See `records/012625_MiscTweaks/no-autocast-same-fp8-scales`.
8 | - `w_s = 2.0**9` (from `2.0**5`)
9 | - `grad_s = 2.0**19` (from `2.0**29`)
10 | 4. Upgraded PyTorch to 2.7.0 nightly version (20250125) for CUDA 12.6
11 | - `pip install --pre torch==2.7.0.dev20250125+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126`
12 |
13 | 
14 | 
15 |
16 | ```python
17 | accs = [3.2806, 3.2771, 3.2829, 3.2813, 3.2789, 3.2774, 3.2798, 3.2759, 3.2794,
18 | 3.2775, 3.2768, 3.2793, 3.2838, 3.2779, 3.2782, 3.2770, 3.2775, 3.2784,
19 | 3.2782, 3.2776, 3.2814, 3.2785, 3.2793, 3.2797, 3.2782, 3.2789, 3.2759,
20 | 3.2803, 3.2780, 3.2782, 3.2744, 3.2819, 3.2801, 3.2782, 3.2771, 3.2782,
21 | 3.2792, 3.2778, 3.2774, 3.2798, 3.2799, 3.2768, 3.2814, 3.2816, 3.2785,
22 | 3.2817, 3.2801, 3.2755, 3.2780, 3.2774, 3.2797, 3.2789, 3.2843, 3.2777,
23 | 3.2777, 3.2768, 3.2763, 3.2773, 3.2792, 3.2819, 3.2778, 3.2792, 3.2782,
24 | 3.2776, 3.2752, 3.2792, 3.2786, 3.2793, 3.2773, 3.2804, 3.2802, 3.2779,
25 | 3.2780, 3.2779, 3.2801, 3.2773, 3.2802, 3.2770, 3.2785, 3.2772, 3.2818]
26 |
27 | import scipy.stats
28 | print(f'p={scipy.stats.ttest_1samp(accs, 3.28, alternative="less").pvalue:.8f}')
29 | # p=0.00000002 (statistically significant)
30 |
31 | import torch
32 | print(torch.std_mean(torch.tensor(accs)))
33 | # (tensor(0.0019), tensor(3.2787))
34 | ```
35 |
36 | ---
37 |
38 | 
39 |
--------------------------------------------------------------------------------
/records/012625_BatchSize/ablations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/012625_BatchSize/ablations.png
--------------------------------------------------------------------------------
/records/012625_BatchSize/val_losses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/012625_BatchSize/val_losses.png
--------------------------------------------------------------------------------
/records/012625_BatchSize/wallclock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/012625_BatchSize/wallclock.png
--------------------------------------------------------------------------------
/records/030625_GPT2MediumLongerCooldown/README.md:
--------------------------------------------------------------------------------
1 | Changelog:
2 | - Increased learning rate cooldown phase duration from 40% to 60% of the total train duration (see the sketch below)
3 | - Decreased steps from 7050 to 6950
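
A minimal sketch of a constant-then-linear-cooldown learning rate multiplier with a 60% cooldown fraction (illustrative only; the exact schedule is in the record's training script):

```python
def get_lr_mult(step: int, num_steps: int, cooldown_frac: float = 0.6) -> float:
    # constant LR, then linear cooldown over the last `cooldown_frac` of training
    frac_done = step / num_steps
    if frac_done < 1 - cooldown_frac:
        return 1.0
    return (1 - frac_done) / cooldown_frac
```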
4 |
5 | Record author: @YouJiacheng
6 |
7 | Results from 36 runs (P(<2.92) > 99.9%):
8 | ```
9 | [2.9199, 2.9185, 2.9195, 2.9194, 2.9206, 2.9209, 2.9188, 2.9193, 2.9207, 2.9181, 2.9186, 2.9196, 2.9202, 2.9174, 2.9185, 2.9197, 2.9179, 2.9204, 2.9184, 2.9186, 2.9178, 2.9192, 2.9194, 2.9194, 2.9189, 2.9193, 2.9212, 2.9181, 2.9192, 2.9203, 2.9198, 2.9192, 2.919, 2.9196, 2.9182, 2.9186]
10 | ```
11 |
12 |
13 |
--------------------------------------------------------------------------------
/records/041625_GPT2Medium_Record7/README.md:
--------------------------------------------------------------------------------
1 | 1. remove FP8 lm head, 6710→6450 steps, ~0 wall clock change
2 | 2. logits softcap tanh→ISRU, ~1.5ms/step faster, slightly better (2.91970 -> 2.91944); see the sketch after this list
3 | 3. New sharded mixed precision Muon, remove CastedLinear, ~3.6ms/step faster
4 | 4. merge qkv&o weight, ~0.8ms/step faster
5 | 5. merge scalars weight, ~0.6ms/step faster
6 |
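Item 2 swaps the tanh-based logit softcap for an ISRU-based one. A hedged sketch of the two forms (the cap value `c` is a placeholder; the actual value and wiring are in the record's training script):

```python
import torch

def softcap_tanh(logits: torch.Tensor, c: float) -> torch.Tensor:
    # smoothly bounds logits to (-c, c) via tanh
    return c * torch.tanh(logits / c)

def softcap_isru(logits: torch.Tensor, c: float) -> torch.Tensor:
    # same asymptotic bound, but ISRU (x / sqrt(1 + x^2)) avoids the transcendental tanh
    x = logits / c
    return c * x * torch.rsqrt(1 + x * x)
```
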
--------------------------------------------------------------------------------
/records/060624_AdamW/README.md:
--------------------------------------------------------------------------------
1 | This is the log for my baseline AdamW training to which I compared the new Muon and SOAP optimizers.
2 |
3 | Just the log, which is in the old llm.c format ("tel" lines are val loss).
4 |
5 | This was batch size 2^19, so ~5B tokens.
6 |
7 | The learning rate was 0.0018, warmup=250, warmdown=2000, betas=(0.9, 0.95), IIRC.
8 |
9 |
--------------------------------------------------------------------------------
/records/100924_SOAP/README.md:
--------------------------------------------------------------------------------
1 | # SOAP record October 9 2024
2 |
3 | * New sample efficiency record: <3.28 validation loss in 3.15B tokens
4 | * Uses SOAP optimizer ([Vyas et al. 2024](https://arxiv.org/abs/2409.11321))
5 | * 363ms/step - not a new wallclock record (SOAP is in active development to reduce the wallclock overhead for distributed training, so this may change)
6 | * Set by Nikhil Vyas @vyasnikhil96. Hyperparameters also tuned slightly by me
7 | * [https://x.com/vyasnikhil96/status/1842656792217858063](https://x.com/vyasnikhil96/status/1842656792217858063)
8 | * [https://github.com/nikhilvyas/modded-nanogpt-SOAP/tree/master](https://github.com/nikhilvyas/modded-nanogpt-SOAP/tree/master)
9 |
10 |
--------------------------------------------------------------------------------
/records/101024_Muon/train_gpt2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | with open(sys.argv[0]) as f:
4 | code = f.read() # read the code of this file ASAP, for logging
5 | import uuid
6 | import glob
7 | import time
8 | from dataclasses import dataclass
9 |
10 | import numpy as np
11 | import torch
12 | from torch import nn
13 | import torch.nn.functional as F
14 | import torch.distributed as dist
15 | import torch._inductor.config as config
16 | from torch.nn.parallel import DistributedDataParallel as DDP
17 |
18 | # -----------------------------------------------------------------------------
19 | # Muon optimizer
20 |
21 | def zeropower_via_svd(G, steps=None):
22 | U, S, V = G.svd()
23 | return U @ V.T
24 |
25 | @torch.compile
26 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
27 | """
28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
32 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
33 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
34 | performance at all relative to UV^T, where USV^T = G is the SVD.
35 | """
36 | assert len(G.shape) == 2
37 | a, b, c = (3.4445, -4.7750, 2.0315)
38 | X = G.bfloat16() / (G.norm() + eps) # ensure top singular value <= 1
39 | if G.size(0) > G.size(1):
40 | X = X.T
41 | for _ in range(steps):
42 | A = X @ X.T
43 | B = A @ X
44 | X = a * X + b * B + c * A @ B
45 | if G.size(0) > G.size(1):
46 | X = X.T
47 | return X.to(G.dtype)
48 |
49 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
50 |
51 | class Muon(torch.optim.Optimizer):
52 | """
53 | Muon: MomentUm Orthogonalized by Newton-schulz
54 |
55 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
56 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
57 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
58 | the advantage that it can be stably run in bfloat16 on the GPU.
59 |
60 | Some warnings:
61 | - This optimizer assumes that all parameters passed in are 2D.
62 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
63 | parameters; those should all be optimized by a standard method (e.g., AdamW).
64 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
65 | - We believe it is unlikely to work well for training with small batch size.
66 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
67 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
68 |
69 | Arguments:
70 | lr: The learning rate used by the internal SGD.
71 | momentum: The momentum used by the internal SGD.
72 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
73 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
74 | backend_steps: The number of iteration steps to use in the backend, if it is iterative.
75 | """
76 | def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5):
77 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
78 | super().__init__(params, defaults)
79 |
80 | def step(self):
81 | for group in self.param_groups:
82 | lr = group['lr']
83 | momentum = group['momentum']
84 | zeropower_backend = zeropower_backends[group['backend']]
85 | for p in group['params']:
86 | g = p.grad
87 | if g is None:
88 | continue
89 | state = self.state[p]
90 | if 'momentum_buffer' not in state:
91 | state['momentum_buffer'] = torch.zeros_like(g)
92 | buf = state['momentum_buffer']
93 | buf.mul_(momentum).add_(g)
94 | if group['nesterov']:
95 | g = g.add(buf, alpha=momentum)
96 | if g.size(0) == 3 * g.size(1): # split grouped QKV parameters
97 | g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))])
98 | scale = g.size(1)**0.5
99 | else:
100 | g = zeropower_backend(g, steps=group['backend_steps'])
101 | scale = max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1
102 | p.data.add_(g, alpha=-lr * scale)
103 |
104 | # -----------------------------------------------------------------------------
105 | # PyTorch nn.Module definitions for the GPT-2 model
106 |
107 | class Rotary(torch.nn.Module):
108 |
109 | def __init__(self, dim, base=10000):
110 | super().__init__()
111 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
112 | self.register_buffer("inv_freq", inv_freq)
113 | self.seq_len_cached = None
114 | self.cos_cached = None
115 | self.sin_cached = None
116 |
117 | def forward(self, x):
118 | seq_len = x.shape[1]
119 | if seq_len != self.seq_len_cached:
120 | self.seq_len_cached = seq_len
121 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
122 | freqs = torch.outer(t, self.inv_freq).to(x.device)
123 | self.cos_cached = freqs.cos()
124 | self.sin_cached = freqs.sin()
125 | return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
126 |
127 | def apply_rotary_emb(x, cos, sin):
128 | assert x.ndim == 4 # multihead attention
129 | d = x.shape[3]//2
130 | x1 = x[..., :d]
131 | x2 = x[..., d:]
132 | y1 = x1 * cos + x2 * sin
133 | y2 = x1 * (-sin) + x2 * cos
134 | return torch.cat([y1, y2], 3)
135 |
136 | def rmsnorm(x0, eps=1e-6):
137 | x = x0.float()
138 | x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
139 | return x.type_as(x0)
140 |
141 | class CausalSelfAttention(nn.Module):
142 |
143 | def __init__(self, config):
144 | super().__init__()
145 | self.n_head = config.n_head
146 | self.n_embd = config.n_embd
147 | self.head_dim = self.n_embd // self.n_head
148 | assert self.n_embd % self.n_head == 0
149 | # key, query, value projections for all heads, but in a batch
150 | self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
151 | # output projection
152 | self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
153 | self.rotary = Rotary(self.head_dim)
154 |
155 | def forward(self, x):
156 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
157 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
158 | qkv = self.c_attn(x)
159 | q, k, v = qkv.split(self.n_embd, dim=2)
160 | k = k.view(B, T, self.n_head, self.head_dim)
161 | q = q.view(B, T, self.n_head, self.head_dim)
162 | v = v.view(B, T, self.n_head, self.head_dim)
163 | cos, sin = self.rotary(q)
164 | q = apply_rotary_emb(q, cos, sin)
165 | k = apply_rotary_emb(k, cos, sin)
166 | y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
167 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
168 | # output projection
169 | y = self.c_proj(y)
170 | return y
171 |
172 | class MLP(nn.Module):
173 |
174 | def __init__(self, config):
175 | super().__init__()
176 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
177 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
178 |
179 | def forward(self, x):
180 | x = self.c_fc(x)
181 | x = F.gelu(x)
182 | x = self.c_proj(x)
183 | return x
184 |
185 | class Block(nn.Module):
186 |
187 | def __init__(self, config):
188 | super().__init__()
189 | self.attn = CausalSelfAttention(config)
190 | self.mlp = MLP(config)
191 | self.attn_scale = (1 / (2 * config.n_layer)**0.5)
192 |
193 | def forward(self, x):
194 | x = x + self.attn_scale * self.attn(rmsnorm(x))
195 | x = x + self.mlp(rmsnorm(x))
196 | return x
197 |
198 | # -----------------------------------------------------------------------------
199 | # The main GPT-2 model
200 |
201 | @dataclass
202 | class GPTConfig:
203 | vocab_size : int = 50257
204 | n_layer : int = 12
205 | n_head : int = 12
206 | n_embd : int = 768
207 |
208 | class GPT(nn.Module):
209 |
210 | def __init__(self, config):
211 | super().__init__()
212 | self.config = config
213 |
214 | self.transformer = nn.ModuleDict(dict(
215 | wte = nn.Embedding(config.vocab_size, config.n_embd),
216 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
217 | ))
218 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
219 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
220 |
221 | def forward(self, idx, targets=None, return_logits=True):
222 | b, t = idx.size()
223 | pos = torch.arange(0, t, dtype=torch.long, device=idx.device) # shape (t)
224 |
225 | # forward the GPT model itself
226 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
227 |
228 | for block in self.transformer.h:
229 | x = block(x)
230 | x = rmsnorm(x)
231 |
232 | if targets is not None:
233 | # if we are given some desired targets also calculate the loss
234 | logits = self.lm_head(x)
235 | logits = logits.float() # use tf32/fp32 for logits
236 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
237 | else:
238 | # inference-time mini-optimization: only forward the lm_head on the very last position
239 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
240 | logits = logits.float() # use tf32/fp32 for logits
241 | loss = None
242 |
243 | # there are performance reasons why not returning logits is prudent, if not needed
244 | if not return_logits:
245 | logits = None
246 |
247 | return logits, loss
248 |
249 | # -----------------------------------------------------------------------------
250 | # Our own simple Distributed Data Loader
251 |
252 | def _peek_data_shard(filename):
253 | # only reads the header, returns header data
254 | with open(filename, "rb") as f:
255 | # first read the header, which is 256 int32 integers (4 bytes each)
256 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
257 | if header[0] != 20240520:
258 | print("ERROR: magic number mismatch in the data .bin file!")
259 | print("---> HINT: Are you passing in a correct file with --input_bin?")
260 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
261 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
262 | exit(1)
263 | assert header[1] == 1, "unsupported version"
264 | ntok = header[2] # number of tokens (claimed)
265 | return ntok # for now just return the number of tokens
266 |
267 | def _load_data_shard(filename):
268 | with open(filename, "rb") as f:
269 | # first read the header, which is 256 int32 integers (4 bytes each)
270 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
271 | assert header[0] == 20240520, "magic number mismatch in the data .bin file"
272 | assert header[1] == 1, "unsupported version"
273 | ntok = header[2] # number of tokens (claimed)
274 | # the rest of it are tokens, stored as uint16
275 | tokens = np.frombuffer(f.read(), dtype=np.uint16)
276 | assert len(tokens) == ntok, "number of tokens read does not match header?"
277 | return tokens
278 |
279 | class DistributedDataLoader:
280 | def __init__(self, filename_pattern, B, T, process_rank, num_processes):
281 | self.process_rank = process_rank
282 | self.num_processes = num_processes
283 | self.B = B
284 | self.T = T
285 |
286 | # glob files that match the pattern
287 | self.files = sorted(glob.glob(filename_pattern))
288 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
289 |
290 | # load and validate all data shards, count number of tokens in total
291 | ntok_total = 0
292 | for fname in self.files:
293 | shard_ntok = _peek_data_shard(fname)
294 | assert shard_ntok >= num_processes * B * T + 1
295 | ntok_total += int(shard_ntok)
296 | self.ntok_total = ntok_total
297 |
298 | # kick things off
299 | self.reset()
300 |
301 | def reset(self):
302 | self.current_shard = 0
303 | self.current_position = self.process_rank * self.B * self.T
304 | self.tokens = _load_data_shard(self.files[self.current_shard])
305 |
306 | def advance(self): # advance to next data shard
307 | self.current_shard = (self.current_shard + 1) % len(self.files)
308 | self.current_position = self.process_rank * self.B * self.T
309 | self.tokens = _load_data_shard(self.files[self.current_shard])
310 |
311 | def next_batch(self):
312 | B = self.B
313 | T = self.T
314 | buf = self.tokens[self.current_position : self.current_position+B*T+1]
315 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
316 | x = (buf[:-1]).view(B, T) # inputs
317 | y = (buf[1:]).view(B, T) # targets
318 | # advance current position and load next shard if necessary
319 | self.current_position += B * T * self.num_processes
320 | if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
321 | self.advance()
322 | return x.cuda(), y.cuda()
323 |
324 | # -----------------------------------------------------------------------------
325 | # int main
326 |
327 | @dataclass
328 | class Hyperparameters:
329 | # data hyperparams
330 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
331 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
332 | # optimization hyperparams
333 | batch_size : int = 8*64 # batch size, in sequences, across all devices
334 | device_batch_size : int = 64 # batch size, in sequences, per device
335 | sequence_length : int = 1024 # sequence length, in tokens
336 | num_iterations : int = 6200 # number of iterations to run
337 | learning_rate : float = 0.0036
338 | warmup_iters : int = 0
339 | warmdown_iters : int = 1800 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
340 | weight_decay : float = 0
341 | # evaluation and logging hyperparams
342 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
343 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
344 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end
345 | args = Hyperparameters()
346 |
347 | # set up DDP (distributed data parallel). torchrun sets this env variable
348 | assert torch.cuda.is_available()
349 | dist.init_process_group(backend='nccl')
350 | ddp_rank = int(os.environ['RANK'])
351 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
352 | ddp_world_size = int(os.environ['WORLD_SIZE'])
353 | device = f'cuda:{ddp_local_rank}'
354 | torch.cuda.set_device(device)
355 | print(f"using device: {device}")
356 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
357 |
358 | # convenience variables
359 | B, T = args.device_batch_size, args.sequence_length
360 | # calculate the number of steps to take in the val loop.
361 | assert args.val_tokens % (B * T * ddp_world_size) == 0
362 | val_steps = args.val_tokens // (B * T * ddp_world_size)
363 | # calculate the steps of gradient accumulation required to attain the desired global batch size.
364 | assert args.batch_size % (B * ddp_world_size) == 0
365 | train_accumulation_steps = args.batch_size // (B * ddp_world_size)
366 |
367 | # load tokens
368 | train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
369 | val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
370 | if master_process:
371 | print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
372 | print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
373 | x, y = train_loader.next_batch()
374 |
375 | # init the model from scratch
376 | num_vocab = 50257
377 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=12, n_embd=768))
378 | model = model.cuda()
379 | if hasattr(config, "coordinate_descent_tuning"):
380 | config.coordinate_descent_tuning = True # suggested by @Chillee
381 | model = torch.compile(model)
382 | # here we wrap model into DDP container
383 | model = DDP(model, device_ids=[ddp_local_rank])
384 | raw_model = model.module # always contains the "raw" unwrapped model
385 | ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
386 |
387 | # init the optimizer(s)
388 | optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95),
389 | weight_decay=args.weight_decay, fused=True)
390 | optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95)
391 | optimizers = [optimizer1, optimizer2]
392 | # learning rate decay scheduler (linear warmup and warmdown)
393 | def get_lr(it):
394 | assert it <= args.num_iterations
395 | # 1) linear warmup for warmup_iters steps
396 | if it < args.warmup_iters:
397 | return (it+1) / args.warmup_iters
398 | # 2) constant lr for a while
399 | elif it < args.num_iterations - args.warmdown_iters:
400 | return 1.0
401 | # 3) linear warmdown
402 | else:
403 | decay_ratio = (args.num_iterations - it) / args.warmdown_iters
404 | return decay_ratio
405 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
406 |
407 | # begin logging
408 | if master_process:
409 | run_id = str(uuid.uuid4())
410 | logdir = 'logs/%s/' % run_id
411 | os.makedirs(logdir, exist_ok=True)
412 | logfile = 'logs/%s.txt' % run_id
413 | # create the log file
414 | with open(logfile, "w") as f:
415 | # begin the log by printing this file (the Python code)
416 | f.write('='*100 + '\n')
417 | f.write(code)
418 | f.write('='*100 + '\n')
419 | # log information about the hardware/software environment this is running on
420 | # and print the full `nvidia-smi` to file
421 | f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
422 | import subprocess
423 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
424 | f.write(f'{result.stdout}\n')
425 | f.write('='*100 + '\n')
426 |
427 | training_time_ms = 0
428 | # start the clock
429 | torch.cuda.synchronize()
430 | t0 = time.time()
431 | # begin training
432 | train_loader.reset()
433 | for step in range(args.num_iterations + 1):
434 | last_step = (step == args.num_iterations)
435 | # This effectively ignores timing first 10 steps, which are slower for weird reasons.
436 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
437 | # steps with dummy data first, and then re-initialize the model and reset the loader.
438 | if step == 10:
439 | training_time_ms = 0
440 | t0 = time.time()
441 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
442 |
443 | # once in a while evaluate the validation dataset
444 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
445 | # stop the clock
446 | torch.cuda.synchronize()
447 | training_time_ms += 1000 * (time.time() - t0)
448 | # run validation batches
449 | model.eval()
450 | val_loader.reset()
451 | val_loss = 0.0
452 | for _ in range(val_steps):
453 | x_val, y_val = val_loader.next_batch()
454 | with torch.no_grad(): # of course, we'd like to use ctx here too, but that creates a torch.compile error for some reason
455 | _, loss = model(x_val, y_val, return_logits=False)
456 | val_loss += loss
457 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
458 | val_loss /= val_steps
459 | # log val loss to console and to logfile
460 | if master_process:
461 | print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
462 | with open(logfile, "a") as f:
463 | f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
464 | # start the clock again
465 | torch.cuda.synchronize()
466 | t0 = time.time()
467 |
468 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
469 | # stop the clock
470 | torch.cuda.synchronize()
471 | training_time_ms += 1000 * (time.time() - t0)
472 | # save the state of the training process
473 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
474 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
475 | # start the clock again
476 | torch.cuda.synchronize()
477 | t0 = time.time()
478 |
479 | # bit confusing: we want to make sure to eval on 0th iteration
480 | # but also after the very last iteration. so we loop for step <= num_iterations
481 | # instead of just < num_iterations (one extra due to <=), only to do
482 | # the validation/sampling one last time, and then we break right here as we're done.
483 | if last_step:
484 | break
485 |
486 | # --------------- TRAINING SECTION BEGIN -----------------
487 | model.train()
488 | for i in range(1, train_accumulation_steps+1):
489 | # forward pass
490 | with ctx:
491 | _, loss = model(x, y, return_logits=False)
492 | train_loss = loss.detach()
493 | # advance the dataset for the next batch
494 | x, y = train_loader.next_batch()
495 | # backward pass
496 | if i < train_accumulation_steps:
497 | with model.no_sync(): # there's no need to sync gradients every accumulation step
498 | loss.backward()
499 | else:
500 | loss.backward() # just sync on the last step
501 | for p in model.parameters():
502 | p.grad /= train_accumulation_steps
503 | # step the optimizers and schedulers
504 | for opt, sched in zip(optimizers, schedulers):
505 | opt.step()
506 | sched.step()
507 | # null the gradients
508 | model.zero_grad(set_to_none=True)
509 | # --------------- TRAINING SECTION END -------------------
510 | # everything that follows now is just diagnostics, prints, logging, etc.
511 |
512 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
513 | if master_process:
514 | approx_time = training_time_ms + 1000 * (time.time() - t0)
515 | print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
516 | with open(logfile, "a") as f:
517 | f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")
518 |
519 | if master_process:
520 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
521 |
522 | # -------------------------------------------------------------------------
523 | # clean up nice
524 | dist.destroy_process_group()
525 |
--------------------------------------------------------------------------------
/records/101324_llmc/README.md:
--------------------------------------------------------------------------------
1 | This is a log produced by running the current version of Andrej Karpathy's [llm.c](https://github.com/karpathy/llm.c), as of October 13th 2024.
2 |
3 | It was run on a node with 8x H100 HBM3 according to the instructions [here](https://github.com/karpathy/llm.c/discussions/481).
4 | The mean per-step time was 140ms. The total number of training tokens is 10.26B. The final validation loss was **3.2722**.
5 |
6 | This is (significantly) better than the quoted result of **3.29** val loss in
7 | [Andrej Karpathy's May 28th GPT-2 replication discussion](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29).
8 | So it appears that there have been some improvements to the training algorithm used by llm.c since then.
9 |
10 | Note that the set of examples which llm.c uses for validation appears to be the same as what we do in this repo, i.e., the first `10 * 2**20` tokens of the val set.
11 |
12 |
--------------------------------------------------------------------------------
/records/101424_ModernArch/train_gpt2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | with open(sys.argv[0]) as f:
4 | code = f.read() # read the code of this file ASAP, for logging
5 | import uuid
6 | import glob
7 | import time
8 | from dataclasses import dataclass
9 |
10 | import numpy as np
11 | import torch
12 | from torch import nn
13 | import torch.nn.functional as F
14 | import torch.distributed as dist
15 | import torch._inductor.config as config
16 | from torch.nn.parallel import DistributedDataParallel as DDP
17 |
18 | # -----------------------------------------------------------------------------
19 | # Muon optimizer
20 |
21 | def zeropower_via_svd(G, steps=None):
22 | U, S, V = G.svd()
23 | return U @ V.T
24 |
25 | @torch.compile
26 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
27 | """
28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
32 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
33 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
34 | performance at all relative to UV^T, where USV^T = G is the SVD.
35 | """
36 | assert len(G.shape) == 2
37 | a, b, c = (3.4445, -4.7750, 2.0315)
38 | X = G.bfloat16()
39 | X /= (X.norm() + eps) # ensure top singular value <= 1
40 | if G.size(0) > G.size(1):
41 | X = X.T
42 | for _ in range(steps):
43 | A = X @ X.T
44 | B = A @ X
45 | X = a * X + b * B + c * A @ B
46 | if G.size(0) > G.size(1):
47 | X = X.T
48 | return X
49 |
50 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
51 |
52 | class Muon(torch.optim.Optimizer):
53 | """
54 | Muon - MomentUm Orthogonalized by Newton-schulz
55 |
56 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
57 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
58 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
59 | the advantage that it can be stably run in bfloat16 on the GPU.
60 |
61 | Some warnings:
62 | - This optimizer assumes that all parameters passed in are 2D.
63 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
64 | parameters; those should all be optimized by a standard method (e.g., AdamW).
65 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
66 | - We believe it is unlikely to work well for training with small batch size.
67 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
68 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
69 |
70 | Arguments:
71 | lr: The learning rate used by the internal SGD.
72 | momentum: The momentum used by the internal SGD.
73 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
74 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
75 | backend_steps: The number of iteration steps to use in the backend, if it is iterative.
76 | """
77 | def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5):
78 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
79 | super().__init__(params, defaults)
80 |
81 | def step(self):
82 | for group in self.param_groups:
83 | lr = group['lr']
84 | momentum = group['momentum']
85 | zeropower_backend = zeropower_backends[group['backend']]
86 | for p in group['params']:
87 | g = p.grad
88 | if g is None:
89 | continue
90 | state = self.state[p]
91 | if 'momentum_buffer' not in state:
92 | state['momentum_buffer'] = torch.zeros_like(g)
93 | buf = state['momentum_buffer']
94 | buf.mul_(momentum).add_(g)
95 | if group['nesterov']:
96 | g = g.add(buf, alpha=momentum)
97 | if g.size(0) == 3 * g.size(1): # split grouped QKV parameters
98 | g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))])
99 | scale = g.size(1)**0.5
100 | else:
101 | g = zeropower_backend(g, steps=group['backend_steps'])
102 | scale = max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1
103 | p.data.add_(g, alpha=-lr * scale)
104 |
105 | # -----------------------------------------------------------------------------
106 | # PyTorch nn.Module definitions for the GPT-2 model
107 |
108 | class Rotary(torch.nn.Module):
109 |
110 | def __init__(self, dim, base=10000):
111 | super().__init__()
112 | self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
113 | self.seq_len_cached = None
114 | self.cos_cached = None
115 | self.sin_cached = None
116 |
117 | def forward(self, x):
118 | seq_len = x.shape[1]
119 | if seq_len != self.seq_len_cached:
120 | self.seq_len_cached = seq_len
121 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
122 | freqs = torch.outer(t, self.inv_freq).to(x.device)
123 | self.cos_cached = freqs.cos().bfloat16()
124 | self.sin_cached = freqs.sin().bfloat16()
125 | return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
126 |
127 | def apply_rotary_emb(x, cos, sin):
128 | assert x.ndim == 4 # multihead attention
129 | d = x.shape[3]//2
130 | x1 = x[..., :d]
131 | x2 = x[..., d:]
132 | y1 = x1 * cos + x2 * sin
133 | y2 = x1 * (-sin) + x2 * cos
134 | return torch.cat([y1, y2], 3).type_as(x)
135 |
136 | class CausalSelfAttention(nn.Module):
137 |
138 | def __init__(self, config):
139 | super().__init__()
140 | self.n_head = config.n_head
141 | self.n_embd = config.n_embd
142 | self.head_dim = self.n_embd // self.n_head
143 | assert self.n_embd % self.n_head == 0
144 | self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
145 | self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
146 | self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)
147 | # output projection
148 | self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
149 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
150 | self.rotary = Rotary(self.head_dim)
151 |
152 | def forward(self, x):
153 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
154 | q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
155 | k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
156 | v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
157 | cos, sin = self.rotary(q)
158 | q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
159 | q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm suggested by @Grad62304977
160 | y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
161 | y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
162 | y = self.c_proj(y)
163 | return y
164 |
165 | class MLP(nn.Module):
166 |
167 | def __init__(self, config):
168 | super().__init__()
169 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
170 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
171 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
172 |
173 | def forward(self, x):
174 | x = self.c_fc(x)
175 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
176 | x = self.c_proj(x)
177 | return x
178 |
179 | class Block(nn.Module):
180 |
181 | def __init__(self, config):
182 | super().__init__()
183 | self.attn = CausalSelfAttention(config)
184 | self.mlp = MLP(config)
185 |
186 | def forward(self, x):
187 | x = x + self.attn(F.rms_norm(x, (x.size(-1),)))
188 | x = x + self.mlp(F.rms_norm(x, (x.size(-1),)))
189 | return x
190 |
191 | # -----------------------------------------------------------------------------
192 | # The main GPT-2 model
193 |
194 | @dataclass
195 | class GPTConfig:
196 | vocab_size : int = 50304
197 | n_layer : int = 12
198 | n_head : int = 6 # head dim 128 suggested by @Grad62304977
199 | n_embd : int = 768
200 |
201 | class GPT(nn.Module):
202 |
203 | def __init__(self, config):
204 | super().__init__()
205 | self.config = config
206 |
207 | self.transformer = nn.ModuleDict(dict(
208 | wte = nn.Embedding(config.vocab_size, config.n_embd),
209 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
210 | ))
211 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
212 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
213 |
214 | def forward(self, idx, targets=None, return_logits=True):
215 |
216 | # forward the GPT model itself
217 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
218 | for block in self.transformer.h:
219 | x = block(x)
220 | x = F.rms_norm(x, (x.size(-1),))
221 |
222 | if targets is not None:
223 | # if we are given some desired targets also calculate the loss
224 | logits = self.lm_head(x)
225 | logits = logits.float() # use tf32/fp32 for logits
226 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
227 | else:
228 | # inference-time mini-optimization: only forward the lm_head on the very last position
229 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
230 | logits = logits.float() # use tf32/fp32 for logits
231 | loss = None
232 |
233 | # there are performance reasons why not returning logits is prudent, if not needed
234 | if not return_logits:
235 | logits = None
236 |
237 | return logits, loss
238 |
239 | # -----------------------------------------------------------------------------
240 | # Our own simple Distributed Data Loader
241 |
242 | def _peek_data_shard(filename):
243 | # only reads the header, returns header data
244 | with open(filename, "rb") as f:
245 | # first read the header, which is 256 int32 integers (4 bytes each)
246 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
247 | if header[0] != 20240520:
248 | print("ERROR: magic number mismatch in the data .bin file!")
249 | print("---> HINT: Are you passing in a correct file with --input_bin?")
250 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
251 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
252 | exit(1)
253 | assert header[1] == 1, "unsupported version"
254 | ntok = header[2] # number of tokens (claimed)
255 | return ntok # for now just return the number of tokens
256 |
257 | def _load_data_shard(filename):
258 | with open(filename, "rb") as f:
259 | # first read the header, which is 256 int32 integers (4 bytes each)
260 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
261 | assert header[0] == 20240520, "magic number mismatch in the data .bin file"
262 | assert header[1] == 1, "unsupported version"
263 | ntok = header[2] # number of tokens (claimed)
264 | # the rest of it are tokens, stored as uint16
265 | tokens = np.frombuffer(f.read(), dtype=np.uint16)
266 | assert len(tokens) == ntok, "number of tokens read does not match header?"
267 | return tokens
268 |
269 | class DistributedDataLoader:
270 | def __init__(self, filename_pattern, B, T, process_rank, num_processes):
271 | self.process_rank = process_rank
272 | self.num_processes = num_processes
273 | self.B = B
274 | self.T = T
275 |
276 | # glob files that match the pattern
277 | self.files = sorted(glob.glob(filename_pattern))
278 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
279 |
280 | # load and validate all data shards, count number of tokens in total
281 | ntok_total = 0
282 | for fname in self.files:
283 | shard_ntok = _peek_data_shard(fname)
284 | assert shard_ntok >= num_processes * B * T + 1
285 | ntok_total += int(shard_ntok)
286 | self.ntok_total = ntok_total
287 |
288 | # kick things off
289 | self.reset()
290 |
291 | def reset(self):
292 | self.current_shard = 0
293 | self.current_position = self.process_rank * self.B * self.T
294 | self.tokens = _load_data_shard(self.files[self.current_shard])
295 |
296 | def advance(self): # advance to next data shard
297 | self.current_shard = (self.current_shard + 1) % len(self.files)
298 | self.current_position = self.process_rank * self.B * self.T
299 | self.tokens = _load_data_shard(self.files[self.current_shard])
300 |
301 | def next_batch(self):
302 | B = self.B
303 | T = self.T
304 | buf = self.tokens[self.current_position : self.current_position+B*T+1]
305 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
306 | x = (buf[:-1]).view(B, T) # inputs
307 | y = (buf[1:]).view(B, T) # targets
308 | # advance current position and load next shard if necessary
309 | self.current_position += B * T * self.num_processes
310 | if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
311 | self.advance()
312 | return x.cuda(), y.cuda()
313 |
314 | # -----------------------------------------------------------------------------
315 | # int main
316 |
317 | @dataclass
318 | class Hyperparameters:
319 | # data hyperparams
320 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
321 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
322 | # optimization hyperparams
323 | batch_size : int = 8*64 # batch size, in sequences, across all devices
324 | device_batch_size : int = 64 # batch size, in sequences, per device
325 | sequence_length : int = 1024 # sequence length, in tokens
326 | num_iterations : int = 5100 # number of iterations to run
327 | learning_rate : float = 0.0036
328 | warmup_iters : int = 0
329 | warmdown_iters : int = 1450 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
330 | weight_decay : float = 0
331 | # evaluation and logging hyperparams
332 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
333 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
334 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end
335 | args = Hyperparameters()
336 |
337 | # set up DDP (distributed data parallel). torchrun sets this env variable
338 | assert torch.cuda.is_available()
339 | dist.init_process_group(backend='nccl')
340 | ddp_rank = int(os.environ['RANK'])
341 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
342 | ddp_world_size = int(os.environ['WORLD_SIZE'])
343 | device = f'cuda:{ddp_local_rank}'
344 | torch.cuda.set_device(device)
345 | print(f"using device: {device}")
346 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
347 |
348 | # convenience variables
349 | B, T = args.device_batch_size, args.sequence_length
350 | # calculate the number of steps to take in the val loop.
351 | assert args.val_tokens % (B * T * ddp_world_size) == 0
352 | val_steps = args.val_tokens // (B * T * ddp_world_size)
353 | # calculate the steps of gradient accumulation required to attain the desired global batch size.
354 | assert args.batch_size % (B * ddp_world_size) == 0
355 | train_accumulation_steps = args.batch_size // (B * ddp_world_size)
356 |
357 | # load tokens
358 | train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
359 | val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
360 | if master_process:
361 | print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
362 | print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
363 | x, y = train_loader.next_batch()
364 |
365 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
366 | # this originates from Karpathy's experiments.
367 | num_vocab = 50304
368 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
369 | model = model.cuda()
370 | if hasattr(config, "coordinate_descent_tuning"):
371 | config.coordinate_descent_tuning = True # suggested by @Chillee
372 | model = torch.compile(model)
373 | # here we wrap model into DDP container
374 | model = DDP(model, device_ids=[ddp_local_rank])
375 | raw_model = model.module # always contains the "raw" unwrapped model
376 | ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
377 |
378 | # init the optimizer(s)
379 | optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95),
380 | weight_decay=args.weight_decay, fused=True)
381 | optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95)
382 | optimizers = [optimizer1, optimizer2]
383 | # learning rate decay scheduler (linear warmup and warmdown)
384 | def get_lr(it):
385 | assert it <= args.num_iterations
386 | # 1) linear warmup for warmup_iters steps
387 | if it < args.warmup_iters:
388 | return (it+1) / args.warmup_iters
389 | # 2) constant lr for a while
390 | elif it < args.num_iterations - args.warmdown_iters:
391 | return 1.0
392 | # 3) linear warmdown
393 | else:
394 | decay_ratio = (args.num_iterations - it) / args.warmdown_iters
395 | return decay_ratio
396 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
397 |
398 | # begin logging
399 | if master_process:
400 | run_id = str(uuid.uuid4())
401 | logdir = 'logs/%s/' % run_id
402 | os.makedirs(logdir, exist_ok=True)
403 | logfile = 'logs/%s.txt' % run_id
404 | # create the log file
405 | with open(logfile, "w") as f:
406 | # begin the log by printing this file (the Python code)
407 | f.write('='*100 + '\n')
408 | f.write(code)
409 | f.write('='*100 + '\n')
410 | # log information about the hardware/software environment this is running on
411 | # and print the full `nvidia-smi` to file
412 | f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
413 | import subprocess
414 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
415 | f.write(f'{result.stdout}\n')
416 | f.write('='*100 + '\n')
417 |
418 | training_time_ms = 0
419 | # start the clock
420 | torch.cuda.synchronize()
421 | t0 = time.time()
422 | # begin training
423 | train_loader.reset()
424 | for step in range(args.num_iterations + 1):
425 | last_step = (step == args.num_iterations)
426 | # This effectively ignores timing first 10 steps, which are slower for weird reasons.
427 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
428 | # steps with dummy data first, and then re-initialize the model and reset the loader.
429 | if step == 10:
430 | training_time_ms = 0
431 | t0 = time.time()
432 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
433 |
434 | # once in a while evaluate the validation dataset
435 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
436 | # stop the clock
437 | torch.cuda.synchronize()
438 | training_time_ms += 1000 * (time.time() - t0)
439 | # run validation batches
440 | model.eval()
441 | val_loader.reset()
442 | val_loss = 0.0
443 | for _ in range(val_steps):
444 | x_val, y_val = val_loader.next_batch()
445 | with ctx: # of course, we'd like to use no_grad() here too, but that creates a torch.compile error for some reason
446 | _, loss = model(x_val, y_val, return_logits=False)
447 | val_loss += loss.detach()
448 | del loss
449 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
450 | val_loss /= val_steps
451 | # log val loss to console and to logfile
452 | if master_process:
453 | print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
454 | with open(logfile, "a") as f:
455 | f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
456 | # start the clock again
457 | torch.cuda.synchronize()
458 | t0 = time.time()
459 |
460 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
461 | # stop the clock
462 | torch.cuda.synchronize()
463 | training_time_ms += 1000 * (time.time() - t0)
464 | # save the state of the training process
465 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
466 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
467 | # start the clock again
468 | torch.cuda.synchronize()
469 | t0 = time.time()
470 |
471 | # bit confusing: we want to make sure to eval on 0th iteration
472 | # but also after the very last iteration. so we loop for step <= num_iterations
473 | # instead of just < num_iterations (one extra due to <=), only to do
474 | # the validation/sampling one last time, and then we break right here as we're done.
475 | if last_step:
476 | break
477 |
478 | # --------------- TRAINING SECTION BEGIN -----------------
479 | model.train()
480 | for i in range(1, train_accumulation_steps+1):
481 | # forward pass
482 | with ctx:
483 | _, loss = model(x, y, return_logits=False)
484 | train_loss = loss.detach()
485 | # advance the dataset for the next batch
486 | x, y = train_loader.next_batch()
487 | # backward pass
488 | if i < train_accumulation_steps:
489 | with model.no_sync(): # there's no need to sync gradients every accumulation step
490 | loss.backward()
491 | else:
492 | loss.backward() # just sync on the last step
493 | for p in model.parameters():
494 | p.grad /= train_accumulation_steps
495 | # step the optimizers and schedulers
496 | for opt, sched in zip(optimizers, schedulers):
497 | opt.step()
498 | sched.step()
499 | # null the gradients
500 | model.zero_grad(set_to_none=True)
501 | # --------------- TRAINING SECTION END -------------------
502 | # everything that follows now is just diagnostics, prints, logging, etc.
503 |
504 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
505 | if master_process:
506 | approx_time = training_time_ms + 1000 * (time.time() - t0)
507 | print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
508 | with open(logfile, "a") as f:
509 | f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")
510 |
511 | if master_process:
512 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
513 |
514 | # -------------------------------------------------------------------------
515 | # clean up nice
516 | dist.destroy_process_group()
517 |
--------------------------------------------------------------------------------
/records/102924_Optimizers/README.md:
--------------------------------------------------------------------------------
1 | # Optimizer comparison for NanoGPT speedrunning
2 |
3 | This is a comparison between the four best optimizers I am aware of for NanoGPT speedrunning. They are compared using the 10/18/24 NanoGPT speedrunning record.
4 |
5 | Reproducible logs:
6 | * [Adam](95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt)
7 | * [DistributedShampoo](8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt)
8 | * [SOAP](e21a2838-a0f2-46f2-a247-db0021165682.txt)
9 | * [Muon](8d6193f4-27fc-4e68-899f-af70019a4d54.txt)
10 |
11 | Results:
12 | 
13 | 
14 |
15 | ### General notes for all optimizers
16 |
17 | All optimizers are run using zero weight decay (which is found to be empirically optimal).
18 |
19 | And they are all run with a warmup-stable-decay / trapezoidal schedule, which also seems to be optimal. That's what causes the kink in the loss curve ~75% of the way to the end.
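
A minimal sketch of this schedule (it mirrors the `get_lr` function used in the training scripts; the exact warmup/warmdown lengths vary per record):
```
def lr_multiplier(it, num_iterations, warmup_iters, warmdown_iters):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    # 2) constant lr for most of training
    elif it < num_iterations - warmdown_iters:
        return 1.0
    # 3) linear warmdown; this is what produces the late-run kink
    else:
        return (num_iterations - it) / warmdown_iters
```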
20 |
21 | In addition, in all cases, we optimize the shared embedding/head layer just using Adam (which is also found to be empirically optimal).
22 | Note that in the following code snippets, `raw_model.transformer.h.parameters()` gives all parameters besides the shared embedding/head weights.
23 |
24 | In each case, the hyperparameters are the best ones I could find in around 20 attempts.
25 |
26 | ## [Adam](95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt)
27 | The optimizer here is equivalent to:
28 | ```
29 | torch.optim.Adam(raw_model.transformer.h.parameters(), lr=0.0018, betas=(0.9, 0.95))
30 | ```
31 |
32 |
33 | ## [DistributedShampoo](8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt)
34 | Run as follows:
35 | ```
36 | DistributedShampoo(
37 | raw_model.transformer.h.parameters(),
38 | lr=0.0018,
39 | betas=(0.95, 0.95),
40 | epsilon=1e-12,
41 | weight_decay=0,
42 | max_preconditioner_dim=8192,
43 | precondition_frequency=10,
44 | use_decoupled_weight_decay=True,
45 | grafting_config=AdamGraftingConfig(
46 | beta2=0.95,
47 | epsilon=1e-8,
48 | ),
49 | distributed_config=DDPShampooConfig(
50 | communication_dtype=CommunicationDType.FP32,
51 | num_trainers_per_group=8,
52 | communicate_params=False,
53 | ),
54 | )
55 | ```
56 |
57 | This is using the official `DistributedShampoo` implementation from [here](https://github.com/facebookresearch/optimizers/tree/ad2809a291c01859f68fcabbcb49a2aa75fd7827/distributed_shampoo).
58 |
59 | Things that turned out to be important:
60 | * Don't use epsilon above 1e-8; this loses performance. Epsilon 1e-12 performs as well as 1e-15
61 | * Betas=(0.95, 0.95) seemed optimal, which turns out to be the same thing that SOAP uses
62 | * Higher preconditioner update frequency is better but slower
63 |
64 | I'm open to hyperparameter suggestions; the experiment takes ~20-30 minutes to run on a fresh 8xH100 instance, so it's not hard for me to run more attempts.
65 |
66 |
67 | ## [SOAP](e21a2838-a0f2-46f2-a247-db0021165682.txt)
68 | ```
69 | SOAP(model.transformer.h.parameters(), lr=0.0018, betas=(.95, .95), precondition_frequency=10)
70 | ```
71 |
72 | This is using the official SOAP implementation [here](https://github.com/nikhilvyas/SOAP/blob/bbce86e890d3b697380f4376acb600c2d6c3d203/soap.py).
73 |
74 | Based on conversations with the authors, it is likely that a future SOAP implementation will significantly reduce the wallclock overhead.
75 |
76 |
77 | ## [Muon](8d6193f4-27fc-4e68-899f-af70019a4d54.txt)
78 | ```
79 | Muon(raw_model.transformer.h.parameters(), lr=0.02, momentum=0.95)
80 | ```
81 |
82 |
83 | ## Openness
84 |
85 | These training logs are reproducible (just strip out everything besides the code, and run it using the `run.sh` in the top-level folder). They take 12-25 minutes to run.
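
For example, one rough way to pull the embedded training script out of a log (a sketch, not official tooling; the logs delimit their sections with a line of 100 `=` characters):
```
sep = '=' * 100
with open('95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt') as f:  # e.g. the Adam log above
    segments = f.read().split(sep)
# the training script is the code-rich segment (its exact position varies between records)
code = max(segments, key=lambda s: s.count('def '))
with open('train_gpt2.py', 'w') as f:
    f.write(code)
```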
86 |
87 | I tried to do a good job sweeping the hyperparameters for each optimizer, but I can easily have missed something, or just not have performed enough runs.
88 |
89 | Therefore, I am interested in any better hyperparameter settings which other researchers can find, for any of the optimizers.
90 | If you post or send me your own reproducible log with one of these optimizers, I will be very happy to boost it in any way I can.
91 |
92 | ## Appendix: Negative results
93 |
94 | I believe it was Shazeer who said something like "negative results in machine learning are not worth much, because your inability to make something work doesn't prove that it can't work".
95 |
96 | Given that disclaimer, here are some optimizers that I tried to make work, but was unable to get a significant boost over Adam with:
97 | * Sophia
98 | * Lion
99 | * AdamWScheduleFree
100 | * AdEMAMix (actually this was slightly better than Adam, just not enough to get near competing with the three Shampoo-like optimizers)
101 |
102 | Of course, this is just for NanoGPT speedrunning (short train duration); it's quite possible they work better at longer training duration or for larger models.
103 |
104 |
--------------------------------------------------------------------------------
/records/102924_Optimizers/nanogpt_speedrun81w.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/102924_Optimizers/nanogpt_speedrun81w.png
--------------------------------------------------------------------------------
/records/102924_Optimizers/nanogpt_speedrun82w.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/102924_Optimizers/nanogpt_speedrun82w.png
--------------------------------------------------------------------------------
/records/110324_UntieEmbed/README.md:
--------------------------------------------------------------------------------
1 | # New record 11/03/24
2 |
3 |
4 | New NanoGPT training speed record: 3.28 FineWeb val loss in 10.8 minutes on 8xH100
5 |
6 | Previous record: 12.0 minutes
7 | Changelog:
8 | - untied embed and head weights
9 | - added RMSNorm after embed
10 | - init head to zero
11 |
12 | Driven by @Grad62304977
13 |
14 | ---
15 |
16 | Technically, this is somewhat of an "any%" record, since untying the embedding and lm_head adds 39M parameters.
17 |
18 | However, it doesn't change the number of active parameters or the inference throughput. Future records will stay constrained to 124M active parameters.
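
A minimal sketch of the three changelog items (not the record's exact code; the ~39M figure is just the untied 50304 x 768 head):
```
from torch import nn
import torch.nn.functional as F

vocab, dim = 50304, 768
wte = nn.Embedding(vocab, dim)               # embedding, no longer sharing its weight with the head
lm_head = nn.Linear(dim, vocab, bias=False)
lm_head.weight.data.zero_()                  # init head to zero
print(vocab * dim)                           # 38,633,472 extra parameters (~39M) from untying

def embed(idx):
    x = wte(idx)
    return F.rms_norm(x, (x.size(-1),))      # RMSNorm added right after the embedding
```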
19 |
20 | ---
21 |
22 | Like the last architectural change, this record was driven by @Grad62304977. I just finetuned some things and did bookkeeping.
23 |
24 | ---
25 |
26 | Shoutout to @cloneofsimo whose scaling guide already suggests initializing the head to zero. This works quite well and is a significant fraction of the record.
27 |
28 |
--------------------------------------------------------------------------------
/records/110424_50Bruns/README.md:
--------------------------------------------------------------------------------
1 | # 50B-token runs
2 |
3 | This folder contains four runs generated by extending the 11/03/24 speedrun record to 50B FineWeb tokens.
4 | The goal is to test how the speedrun generalizes to long durations, and especially how well Muon does.
5 |
6 | We compare two things:
7 | 1. We compare Muon to Adam as the optimizer for the transformer body. (The head and embedding are always optimized by Adam.)
8 | 2. We compare training on 5 epochs of 10B tokens to training on 50B tokens. (Surprisingly, these perform about the same.)
9 |
10 | The four resulting runs are as follows:
11 |
12 | * [Muon 50B tokens](./530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt) (HellaSwag=35.82)
13 | * [Adam 50B tokens](./69c33fc9-eabb-4a38-aa08-6922914eb405.txt) (HellaSwag=34.26)
14 | * [Muon 5x10B tokens](./4fbe61ec-f79a-4c19-836d-46d599deecce.txt) (HellaSwag=36.17)
15 | * [Adam 5x10B tokens](./3d715d41-453a-40d6-9506-421ba69766b2.txt) (HellaSwag=34.05)
16 |
17 | To get a sense of what a good HellaSwag score would be for this scale of model, here are some baselines:
18 | * Karpathy's baseline llm.c training (trained for 10B FineWeb tokens): 29.9
19 | * OpenAI GPT-2 (124M): 29.4
20 | * OpenAI GPT-3 (124M) (trained for 300B WebText tokens): 33.7
21 | * Huggingface SmolLM2-135M (trained for 2T FineWeb/DCLM/etc tokens): 42.1
22 |
23 | Note: I'm a little concerned that the learning rate schedule (WSD) and weight decay (zero), which are tuned for the speedrun duration,
24 | might become undertuned/suboptimal for training runs of this length.
25 | It does look like the gap between Muon/Adam is too large to be closed by something like this, and the HellaSwag scores look quite reasonable, but you never know.
26 |
27 |
--------------------------------------------------------------------------------
/records/110624_ShortcutsTweaks/README.md:
--------------------------------------------------------------------------------
1 | # New record 11/06/24
2 |
3 | 8.2 minutes on 8xH100 (previous record: 10.8 minutes)
4 |
5 | 
6 | 
7 |
8 | * [Old record 11/03/24](d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt)
9 | * [+shorten duration](4a71cc92-0f43-4058-a033-23e85c1e98f1.txt)
10 | * [+value residual](042f9e87-07e6-4504-bb04-4ec59a380211.txt) by @Grad62304977 following [1]
11 | * [+learnable lambda](43f60c4f-0448-4de7-83d9-643ca26f61e7.txt) @Grad62304977's innovation on top of [1]
12 | * [+embed shortcut](05b29e54-0be0-4a0f-a1e2-7d5317daedd3.txt)
13 | * [+momentum warmup](10119f53-7001-4248-bfd9-33d32427a912.txt)
14 | * [+tanh logit capping](dd7304a6-cc43-4d5e-adb8-c070111464a1.txt) by @Grad62304977 following [2]
15 |
16 | ## Code snippets
17 |
18 | ### Value residual
19 |
20 | In the attention layer:
21 | ```
22 | def forward(self, x, v1=None):
23 | ...
24 | v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
25 | if v1 is None:
26 | v1 = v
27 | v = 0.5 * v + 0.5 * v1.view_as(v)
28 | ```
29 | Where the first block receives v1=None, and subsequent blocks receive v1 as the value produced by the first block.
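
One way the threading could be wired up at the model level (a sketch, assuming each block returns the value tensor its attention used; see the reproducible logs for the exact wiring):
```
def forward_blocks(blocks, x):
    v1 = None
    for block in blocks:
        x, v1 = block(x, v1)  # the first block produces v1; later blocks mix it into their values
    return x
```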
30 |
31 | ### Learnable lambda
32 |
33 | In the attention block:
34 | ```
35 | def __init__(self, config):
36 | ...
37 | self.lamb = nn.Parameter(torch.tensor(0.5))
38 |
39 | def forward(self, x, v1=None):
40 | ...
41 | v = (1 - self.lamb) * v + self.lamb * v1.view_as(v)
42 | ```
43 | That is, we just replace the fixed 0.5 constant used in standard value residual [1] with a learnable scalar (optimized by Adam(lr=0.02)).
44 |
45 | ### Embed shortcut
46 |
47 | Replaces the standard transformer block with this:
48 |
49 | ```
50 | class Block(nn.Module):
51 |
52 | def __init__(self, config):
53 | super().__init__()
54 | self.attn = CausalSelfAttention(config)
55 | self.mlp = MLP(config)
56 | self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
57 |
58 | def forward(self, x, x0):
59 | x = self.lambdas[0] * x + self.lambdas[1] * x0
60 |         x = x + self.attn(F.rms_norm(x, (x.size(-1),)))  # (the value-residual v1 argument is omitted here for clarity)
61 | x = x + self.mlp(F.rms_norm(x, (x.size(-1),)))
62 | return x
63 | ```
64 |
65 | where the two scalars are optimized using Adam(lr=0.02), and `x0` is fed in from the initial embedding via:
66 | ```
67 | ...
68 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
69 | x = F.rms_norm(x, (x.size(-1),))
70 | x0 = x
71 | for block in self.transformer.h:
72 | x = block(x, x0)
73 | ...
74 | ```
75 |
76 | ### Momentum warmup
77 |
78 | Just adds the following two lines.
79 | ```
80 | frac = min(step/500, 1)
81 | optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95
82 | ```
83 | where `optimizer3` is the Muon for the body of the transformer.
84 |
85 | ### Tanh soft capping
86 |
87 | Just adds the following line.
88 |
89 | ```
90 | logits = 30 * torch.tanh(logits / 30)
91 | ```
92 |
93 |
94 | ## References
95 |
96 | 1. [Zhou, Zhanchao, et al. "Value Residual Learning For Alleviating Attention Concentration In Transformers." arXiv preprint arXiv:2410.17897 (2024).](https://arxiv.org/abs/2410.17897)
97 | 2. [Team, Gemma, et al. "Gemma 2: Improving open language models at a practical size." arXiv preprint arXiv:2408.00118 (2024).](https://arxiv.org/abs/2408.00118)
98 |
99 |
--------------------------------------------------------------------------------
/records/110624_ShortcutsTweaks/nanogpt_speedrun110.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/110624_ShortcutsTweaks/nanogpt_speedrun110.png
--------------------------------------------------------------------------------
/records/110624_ShortcutsTweaks/nanogpt_speedrun111.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/110624_ShortcutsTweaks/nanogpt_speedrun111.png
--------------------------------------------------------------------------------
/records/110924_Replicateleloykun/README.md:
--------------------------------------------------------------------------------
1 | This is a replication attempt for the record attempt described [here](https://x.com/leloykun/status/1854557419768254915) by @leloykun.
2 |
3 | The original record could not be directly accepted because it showed a slower wallclock time than the previous record;
4 | however, this was plausibly due to hardware differences, as the competitor's hardware was slightly slower.
5 |
6 | Therefore, to certify this attempt as the new record, here I replicated it on my own hardware.
7 | This did successfully reduce the wallclock time compared to the 11/07/24 record by ~11 seconds; however, it also
8 | resulted in an invalid val loss of 3.2824, above the threshold of 3.28.
9 |
10 | The [original record attempt's reproducible log](https://github.com/leloykun/modded-nanogpt/blob/224f10d190677d9dc3c9c45da280078196a6fe40/records/110724_EmbeddingBetasCooldown/6c9d875b-ad91-46c9-9ede-2c7f998b9b16.txt) attained a val loss of 3.2798, just barely below the 3.28 threshold. So this difference is plausibly due to random inter-run variance.
11 |
12 | This indicates that the true average val loss of the run may be worse than 3.28, meaning I am **unable to certify it as the new record.**
13 |
14 | Ideally, all records should attain a low enough val loss such that >95% of runs attain below 3.28. Good evidence for this would be a single run
15 | attaining <= 3.278. Previous records have adhered to this rule, but admittedly it's hard to define precisely and is therefore mostly a matter of taste.
16 |
17 |
--------------------------------------------------------------------------------
/records/111024_UNetDoubleLr/README.md:
--------------------------------------------------------------------------------
1 | This is a record by Brendan Hogan Rappazzo [@brendanh0gan](https://x.com/brendanh0gan).
2 |
3 | New record: 7.23 minutes
4 |
5 | Previous record: 7.8 minutes
6 |
7 | Changelog:
8 | - Added U-net-like skip connections into the transformer
9 | - Doubled the learning rate
10 |
11 | ---
12 |
13 | This record was first posted [here](https://x.com/brendanh0gan/status/1855273758681866352), and then a few iterations were required to benchmark it on 8x SXM H100s.
14 | Brendan's fork of modded-nanogpt is [here](https://github.com/brendanhogan/modded-nanogpt/tree/master). The code for the record can also be extracted from the reproducible log in this folder.
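
A rough sketch of the U-net-like skip connections described in the changelog (not Brendan's exact code; it mirrors the structure later records such as 12/04/24 use):
```
import torch
from torch import nn

class UNetBody(nn.Module):
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.n_enc = len(blocks) // 2                    # first half of the blocks: "encoder"
        self.n_dec = len(blocks) - self.n_enc            # second half: "decoder"
        self.skip_weights = nn.Parameter(torch.ones(self.n_dec))

    def forward(self, x):
        skips = []
        for i in range(self.n_enc):
            x = self.blocks[i](x)
            skips.append(x)                              # save encoder outputs
        for i in range(self.n_dec):
            x = x + self.skip_weights[i] * skips.pop()   # learnable-weighted skip connection
            x = self.blocks[self.n_enc + i](x)
        return x
```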
15 |
16 |
--------------------------------------------------------------------------------
/records/111924_FlexAttention/README.md:
--------------------------------------------------------------------------------
1 | ## 11/19/24 FlexAttention record
2 |
3 | This training has significant run-to-run variance, with a standard deviation of around 0.005 in val loss. So not all runs go beneath 3.28, though [the mean is around 3.279](https://x.com/YouJiacheng/status/1859876224639828068).
4 |
5 | This variance is probably caused by the previous record doubling the learning rate, plus this record significantly shortening the duration.
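
As a crude sanity check of those numbers (assuming per-run val loss is roughly normal with mean 3.279 and stddev 0.005, as quoted above):
```
from scipy.stats import norm
print(norm.cdf((3.28 - 3.279) / 0.005))  # ~0.58, i.e. only ~58% of runs expected to land below 3.28
```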
6 |
--------------------------------------------------------------------------------
/records/120424_ValueEmbed/README.md:
--------------------------------------------------------------------------------
1 | ## Statistical tests
2 |
3 | ```
4 | accs = [3.2759, 3.2781, 3.2791, 3.2771, 3.2838, 3.2749, 3.2793, 3.279, 3.2794, 3.2744, 3.2751,
5 | 3.2845, 3.2736, 3.2783, 3.2793, 3.2779, 3.2756, 3.281, 3.2803, 3.2766, 3.2851, 3.275,
6 | 3.2778, 3.2723, 3.2842, 3.2735, 3.275, 3.2796, 3.2782, 3.2758, 3.2763, 3.2751, 3.2791,
7 | 3.2804, 3.2725, 3.2898, 3.2718, 3.2764, 3.271, 3.2745]
8 |
9 | import scipy.stats
10 | print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
11 | # p=0.0003 (statistically significant)
12 |
13 | import torch
14 | print(torch.std_mean(torch.tensor(accs)))
15 | # (tensor(0.0040), tensor(3.2777))
16 | ```
17 |
18 | ## ChangeLog
19 |
20 | * Added 12 new embedding layers which get mixed into the value activations at each layer. (=463M new parameters, of which 9216 are active per token)
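
Where those counts come from (a quick back-of-the-envelope using the 50304-token vocab, 768-dim embeddings, and 12 layers from the config below):
```
vocab, dim, layers = 50304, 768, 12
print(vocab * dim * layers)  # 463,601,664 new parameters (~463M)
print(dim * layers)          # 9,216 of them are active (read) per token
```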
21 |
22 | ## Contributors
23 |
24 | * @KoszarskyB
25 |
26 |
27 |
--------------------------------------------------------------------------------
/records/120424_ValueEmbed/train_gpt2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | with open(sys.argv[0]) as f:
4 | code = f.read() # read the code of this file ASAP, for logging
5 | import uuid
6 | import glob
7 | import time
8 | import contextlib
9 | from dataclasses import dataclass
10 |
11 | import numpy as np
12 | import torch
13 | from torch import nn
14 | import torch.nn.functional as F
15 | import torch.distributed as dist
16 | import torch._inductor.config as config
17 | from torch.nn.parallel import DistributedDataParallel as DDP
18 | # Use of FlexAttention contributed by @KoszarskyB
19 | from torch.nn.attention.flex_attention import flex_attention, create_block_mask
20 | flex_attention = torch.compile(flex_attention, dynamic=False)
21 | create_block_mask = torch.compile(create_block_mask, dynamic=False)
22 |
23 | # -----------------------------------------------------------------------------
24 | # Muon optimizer
25 |
26 | def zeropower_via_svd(G, steps=None):
27 | U, S, V = G.svd()
28 | return U @ V.T
29 |
30 | @torch.compile
31 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
32 | """
33 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
34 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
35 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
36 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
37 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
38 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
39 | performance at all relative to UV^T, where USV^T = G is the SVD.
40 | """
41 | assert len(G.shape) == 2
42 | a, b, c = (3.4445, -4.7750, 2.0315)
43 | X = G.bfloat16()
44 | X /= (X.norm() + eps) # ensure top singular value <= 1
45 | if G.size(0) > G.size(1):
46 | X = X.T
47 | for _ in range(steps):
48 | A = X @ X.T
49 | B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
50 | X = a * X + B @ X
51 | if G.size(0) > G.size(1):
52 | X = X.T
53 | return X
54 |
55 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
56 |
57 | class Muon(torch.optim.Optimizer):
58 | """
59 | Muon - MomentUm Orthogonalized by Newton-schulz
60 |
61 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
62 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
63 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
64 | the advantage that it can be stably run in bfloat16 on the GPU.
65 |
66 | Some warnings:
67 | - This optimizer assumes that all parameters passed in are 2D.
68 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
69 | parameters; those should all be optimized by a standard method (e.g., AdamW).
70 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
71 | - We believe it is unlikely to work well for training with small batch size.
72 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
73 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
74 |
75 | Arguments:
76 | lr: The learning rate used by the internal SGD.
77 | momentum: The momentum used by the internal SGD.
78 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
79 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
80 | backend_steps: The number of iteration steps to use in the backend, if it is iterative.
81 | """
82 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
83 | backend='newtonschulz5', backend_steps=5):
84 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
85 | super().__init__(params, defaults)
86 |
87 | def step(self):
88 |
89 | for group in self.param_groups:
90 |
91 | lr = group['lr']
92 | momentum = group['momentum']
93 | zeropower_backend = zeropower_backends[group['backend']]
94 |
95 | # generate weight updates in distributed fashion
96 | total_params = sum(p.numel() for p in group['params'])
97 | updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16)
98 | curr_idx = 0
99 | for i, p in enumerate(group['params']):
100 | # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs
101 | if i % int(os.environ['WORLD_SIZE']) == int(os.environ['RANK']):
102 | g = p.grad
103 | assert g is not None
104 | state = self.state[p]
105 | if 'momentum_buffer' not in state:
106 | state['momentum_buffer'] = torch.zeros_like(g)
107 | buf = state['momentum_buffer']
108 | buf.mul_(momentum).add_(g)
109 | g = g.add(buf, alpha=momentum) if group['nesterov'] else buf
110 | g = zeropower_backend(g, steps=group['backend_steps'])
111 | g *= max(1, g.size(0)/g.size(1))**0.5
112 | updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
113 | curr_idx += p.numel()
114 |
115 | # sync updates across devices. we are not memory-constrained so can do this simple deserialization
116 | dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
117 |
118 | # deserialize and apply updates
119 | curr_idx = 0
120 | for p in group['params']:
121 | g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
122 | p.data.add_(g, alpha=-lr)
123 | curr_idx += p.numel()
124 |
125 | # -----------------------------------------------------------------------------
126 | # PyTorch nn.Module definitions for the GPT-2 model
127 |
128 | def norm(x):
129 | return F.rms_norm(x, (x.size(-1),))
130 |
131 | class CastedLinear(nn.Linear):
132 |
133 | def __init__(self, in_features, out_features):
134 | super().__init__(in_features, out_features, bias=False)
135 |
136 | def forward(self, x):
137 | return F.linear(x, self.weight.to(x.dtype))
138 |
139 | class Rotary(torch.nn.Module):
140 |
141 | def __init__(self, dim, base=10000):
142 | super().__init__()
143 | self.register_buffer('inv_freq', (1 / base) ** (torch.arange(0, dim, 2) / dim))
144 | self.seq_len_cached = None
145 | self.cos_cached = None
146 | self.sin_cached = None
147 |
148 | def forward(self, x):
149 | seq_len = x.shape[1]
150 | if seq_len != self.seq_len_cached:
151 | t = torch.arange(seq_len, device=x.device)
152 | freqs = torch.outer(t, self.inv_freq)
153 | self.seq_len_cached = seq_len
154 | self.cos_cached = freqs.cos()
155 | self.sin_cached = freqs.sin()
156 | cos, sin = self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
157 | # apply_rotary_emb(x, cos, sin)
158 | x1, x2 = x.chunk(2, dim=3)
159 | y1 = x1 * cos + x2 * sin
160 | y2 = x1 * (-sin) + x2 * cos
161 | return torch.cat((y1, y2), 3).type_as(x)
162 |
163 | class CausalSelfAttention(nn.Module):
164 |
165 | def __init__(self, dim, n_head):
166 | super().__init__()
167 | assert dim % n_head == 0
168 | self.n_head = n_head
169 | self.c_q = CastedLinear(dim, dim)
170 | self.c_k = CastedLinear(dim, dim)
171 | self.c_v = CastedLinear(dim, dim)
172 | # value residual lambda
173 | self.lamb = nn.Parameter(torch.tensor(0.5)) # @Grad62304977
174 | # rotary embeddings
175 | self.rotary = Rotary(dim // n_head) # dim // n_head = head_dim
176 | # output projection
177 | self.c_proj = CastedLinear(dim, dim)
178 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
179 |
180 | def forward(self, x, vi, block_mask):
181 | B, T = x.size(0), x.size(1) # batch size, sequence length
182 | assert B == 1, "Must use batch size = 1 for FlexAttention"
183 | q = self.c_q(x).view(B, T, self.n_head, -1)
184 | k = self.c_k(x).view(B, T, self.n_head, -1)
185 | v = self.c_v(x).view(B, T, self.n_head, -1)
186 | v = (1 - self.lamb) * v + self.lamb * vi.view_as(v) # @Grad62304977
187 | q, k = norm(q), norm(k) # QK norm suggested by @Grad62304977
188 | q, k = self.rotary(q), self.rotary(k)
189 | y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask)
190 | y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
191 | y = self.c_proj(y)
192 | return y
193 |
194 | class MLP(nn.Module):
195 |
196 | def __init__(self, dim):
197 | super().__init__()
198 | self.c_fc = CastedLinear(dim, 4 * dim)
199 | self.c_proj = CastedLinear(4 * dim, dim)
200 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
201 |
202 | def forward(self, x):
203 | x = self.c_fc(x)
204 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
205 | x = self.c_proj(x)
206 | return x
207 |
208 | class Block(nn.Module):
209 |
210 | def __init__(self, config):
211 | super().__init__()
212 | self.attn = CausalSelfAttention(config.n_embd, config.n_head)
213 | self.mlp = MLP(config.n_embd)
214 | self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
215 |
216 | def forward(self, x, vi, x0, block_mask):
217 | x = self.lambdas[0] * x + self.lambdas[1] * x0
218 | x = x + self.attn(norm(x), vi, block_mask)
219 | x = x + self.mlp(norm(x))
220 | return x
221 |
222 | # -----------------------------------------------------------------------------
223 | # The main GPT-2 model
224 |
225 | @dataclass
226 | class GPTConfig:
227 | vocab_size : int = 50304
228 | n_layer : int = 12
229 | n_head : int = 6 # head dim 128 suggested by @Grad62304977
230 | n_embd : int = 768
231 |
232 | class GPT(nn.Module):
233 |
234 | def __init__(self, config):
235 | super().__init__()
236 |
237 | # U-net design by @brendanh0gan
238 | self.num_encoder_layers = config.n_layer // 2 # Half of the layers for encoder
239 | self.num_decoder_layers = config.n_layer - self.num_encoder_layers # Remaining for decoder
240 | # Add learnable skip connection weights for decoder layers
241 | self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))
242 |
243 | self.transformer = nn.ModuleDict(dict(
244 | wte = nn.Embedding(config.vocab_size, config.n_embd),
245 | # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning
246 | vte = nn.Embedding(config.vocab_size, config.n_embd*12),
247 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
248 | ))
249 | self.lm_head = CastedLinear(config.n_embd, config.vocab_size)
250 | self.lm_head.weight.data.zero_() # @Grad62304977
251 |
252 | def forward(self, idx, target, attn_blocksize):
253 |
254 | docs = (idx == 50256).cumsum(0)
255 | def document_causal_mask(b, h, q_idx, kv_idx):
256 | causal_mask = q_idx >= kv_idx
257 | document_mask = docs[q_idx] == docs[kv_idx]
258 | window_mask = q_idx - kv_idx < attn_blocksize
259 | return causal_mask & document_mask & window_mask
260 |
261 | S = len(idx)
262 | block_mask = create_block_mask(document_causal_mask, None, None, S, S, device="cuda", _compile=True)
263 |
264 | # forward the GPT model itself
265 | x = self.transformer.wte(idx[None]) # token embeddings of shape (b, t, n_embd)
266 | x = norm(x) # @Grad62304977
267 | x0 = x
268 | vi = self.transformer.vte(idx[None]).chunk(12, dim=-1)
269 |
270 | # Store outputs for U-Net skip connections
271 | skip_connections = []
272 | # Encoder pass - process only the first half of the blocks
273 | for i in range(self.num_encoder_layers):
274 | x = self.transformer.h[i](x, vi[i], x0, block_mask)
275 | skip_connections.append(x)
276 | # Decoder pass - process the remaining blocks with weighted skip connections
277 | for i in range(self.num_decoder_layers):
278 | x = x + self.skip_weights[i] * skip_connections.pop()
279 | x = self.transformer.h[self.num_encoder_layers + i](x, vi[self.num_encoder_layers+i], x0, block_mask)
280 |
281 | x = norm(x)
282 | logits = self.lm_head(x)
283 | logits = 30 * torch.tanh(logits / 30) # @Grad62304977
284 | logits = logits.float()
285 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1))
286 | return loss
287 |
288 | # -----------------------------------------------------------------------------
289 | # Our own simple Distributed Data Loader
290 |
291 | def _peek_data_shard(filename):
292 | # only reads the header, returns header data
293 | with open(filename, "rb") as f:
294 | # first read the header, which is 256 int32 integers (4 bytes each)
295 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
296 | if header[0] != 20240520:
297 | print("ERROR: magic number mismatch in the data .bin file!")
298 | print("---> HINT: Are you passing in a correct file with --input_bin?")
299 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
300 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
301 | exit(1)
302 | assert header[1] == 1, "unsupported version"
303 | ntok = header[2] # number of tokens (claimed)
304 | return ntok # for now just return the number of tokens
305 |
306 | def _load_data_shard(filename):
307 | with open(filename, "rb") as f:
308 | # first read the header, which is 256 int32 integers (4 bytes each)
309 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
310 | assert header[0] == 20240520, "magic number mismatch in the data .bin file"
311 | assert header[1] == 1, "unsupported version"
312 | ntok = header[2] # number of tokens (claimed)
313 | # the rest of it are tokens, stored as uint16
314 | tokens = np.frombuffer(f.read(), dtype=np.uint16)
315 | assert len(tokens) == ntok, "number of tokens read does not match header?"
316 | return tokens
317 |
318 | class DistributedDataLoader:
319 | def __init__(self, filename_pattern, T, process_rank, num_processes):
320 | self.process_rank = process_rank
321 | self.num_processes = num_processes
322 | self.T = T
323 |
324 | # glob files that match the pattern
325 | self.files = sorted(glob.glob(filename_pattern))
326 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
327 |
328 | # load and validate all data shards, count number of tokens in total
329 | ntok_total = 0
330 | for fname in self.files:
331 | shard_ntok = _peek_data_shard(fname)
332 | assert shard_ntok >= num_processes * T + 1
333 | ntok_total += int(shard_ntok)
334 | self.ntok_total = ntok_total
335 |
336 | self.reset()
337 |
338 | def reset(self):
339 | self.current_shard = -1
340 | self.advance()
341 |
342 | def advance(self): # advance to next data shard
343 | self.current_shard = (self.current_shard + 1) % len(self.files)
344 | self.current_position = self.process_rank * self.T
345 | self.tokens = _load_data_shard(self.files[self.current_shard])
346 |
347 | def next_batch(self):
348 | batch_size = self.T * self.num_processes
349 | buf = self.tokens[self.current_position:self.current_position+self.T+1]
350 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
351 | x = buf[:-1] # inputs
352 | y = buf[1:] # targets
353 | # advance current position and load next shard if necessary
354 | self.current_position += batch_size
355 | if self.current_position + batch_size >= len(self.tokens):
356 | self.advance()
357 | return x.cuda(), y.cuda()
358 |
359 | # -----------------------------------------------------------------------------
360 | # int main
361 |
362 | @dataclass
363 | class Hyperparameters:
364 | # data hyperparams
365 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
366 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
367 | # optimization hyperparams
368 | batch_size : int = 8 # batch size, in sequences, across all devices
369 | sequence_length : int = 64*1024 # sequence length, in tokens
370 | num_iterations : int = 1530 # number of iterations to run
371 | warmup_iters : int = 0
372 | cooldown_iters : int = 600 # number of iterations of linear warmup/cooldown for triangular or trapezoidal schedule
373 | weight_decay : float = 0
374 | # evaluation and logging hyperparams
375 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
376 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
377 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end
378 | args = Hyperparameters()
379 |
380 | # set up DDP (distributed data parallel). torchrun sets this env variable
381 | assert torch.cuda.is_available()
382 | dist.init_process_group(backend='nccl')
383 | ddp_rank = int(os.environ['RANK'])
384 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
385 | ddp_world_size = int(os.environ['WORLD_SIZE'])
386 | device = f'cuda:{ddp_local_rank}'
387 | torch.cuda.set_device(device)
388 | print(f"using device: {device}")
389 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
390 |
391 | # begin logging
392 | logfile = None
393 | if master_process:
394 | run_id = str(uuid.uuid4())
395 | logdir = 'logs/%s/' % run_id
396 | os.makedirs(logdir, exist_ok=True)
397 | logfile = 'logs/%s.txt' % run_id
398 | # create the log file
399 | with open(logfile, "w") as f:
400 | # begin the log by printing this file (the Python code)
401 | f.write(code)
402 | f.write('='*100 + '\n')
403 | def print0(s, logonly=False):
404 | if master_process:
405 | with open(logfile, "a") as f:
406 | if not logonly:
407 | print(s)
408 | f.write(s+'\n')
409 | # log information about the hardware/software environment this is running on
410 | # and print the full `nvidia-smi` to file
411 | print0(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:")
412 | import subprocess
413 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
414 | print0(f'{result.stdout}', logonly=True)
415 | print0('='*100, logonly=True)
416 |
417 | # convenience variables
418 | T = args.sequence_length
419 | # calculate the number of steps to take in the val loop.
420 | assert args.val_tokens % (T * ddp_world_size) == 0
421 | val_steps = args.val_tokens // (T * ddp_world_size)
422 | # calculate the steps of gradient accumulation required to attain the desired global batch size.
423 | assert args.batch_size % (ddp_world_size) == 0
424 | train_accumulation_steps = args.batch_size // ddp_world_size
425 |
426 | # load tokens
427 | train_loader = DistributedDataLoader(args.input_bin, T, ddp_rank, ddp_world_size)
428 | val_loader = DistributedDataLoader(args.input_val_bin, T, ddp_rank, ddp_world_size)
429 | print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
430 | print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
431 | print0('='*100, logonly=True)
432 | x, y = train_loader.next_batch()
433 |
434 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
435 | # this originates from Karpathy's experiments.
436 | num_vocab = 50304
437 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
438 | model = model.cuda().bfloat16()
439 | for m in model.modules():
440 | if isinstance(m, CastedLinear):
441 | m.float()
442 | if hasattr(config, "coordinate_descent_tuning"):
443 | config.coordinate_descent_tuning = True # suggested by @Chillee
444 | model = torch.compile(model)
445 | # here we wrap model into DDP container
446 | model = DDP(model, device_ids=[ddp_local_rank])
447 | raw_model = model.module # always contains the "raw" unwrapped model
448 |
449 | # init the optimizer(s)
450 | optimizer1 = torch.optim.Adam([raw_model.transformer.wte.weight, raw_model.transformer.vte.weight], lr=0.6, betas=(0.8, 0.95), fused=True)
451 | optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.008, betas=(0.8, 0.95), fused=True)
452 | params = list(raw_model.transformer.h.parameters())
453 | matrix_params = [p for p in params if p.ndim == 2]
454 | scalar_params = [p for p in params if p.ndim < 2] + [raw_model.skip_weights]
455 | optimizer3 = Muon(matrix_params, lr=0.05, momentum=0.95)
456 | optimizer4 = torch.optim.Adam(scalar_params, lr=0.04, betas=(0.8, 0.95), fused=True) # note that this learning rate is neither sensitive nor tuned
457 | optimizers = [optimizer1, optimizer2, optimizer3, optimizer4]
458 | # learning rate decay scheduler (linear warmup and cooldown)
459 | def get_lr(it):
460 | assert it <= args.num_iterations
461 | # 1) linear warmup for warmup_iters steps
462 | if it < args.warmup_iters:
463 | return (it+1) / args.warmup_iters
464 | # 2) constant lr for a while
465 | elif it < args.num_iterations - args.cooldown_iters:
466 | return 1.0
467 | # 3) linear cooldown
468 | else:
469 | decay_ratio = (args.num_iterations - it) / args.cooldown_iters
470 | return decay_ratio
471 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
472 |
473 | # Start training loop
474 | training_time_ms = 0
475 | # start the clock
476 | torch.cuda.synchronize()
477 | t0 = time.time()
478 | # begin training
479 | for step in range(args.num_iterations + 1):
480 | last_step = (step == args.num_iterations)
481 | # This effectively ignores timing first 10 steps, which are slower for weird reasons.
482 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
483 | # steps with dummy data first, and then re-initialize the model and reset the loader.
484 | if step == 10:
485 | training_time_ms = 0
486 | t0 = time.time()
487 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
488 |
489 | # Set the attention blocksize for the current step, in chunks of 64. By @fernbear.bsky.social
490 | attn_blocksize = torch.tensor(64*((step/args.num_iterations * (1792 - 64) + 64)//64), dtype=torch.int, device='cuda')
491 |
492 | # once in a while evaluate the validation dataset
493 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
494 | # stop the clock
495 | torch.cuda.synchronize()
496 | training_time_ms += 1000 * (time.time() - t0)
497 | # run validation batches
498 | model.eval()
499 | val_loader.reset()
500 | val_loss = 0.0
501 | for _ in range(val_steps):
502 | with torch.no_grad():
503 | x_val, y_val = val_loader.next_batch()
504 | val_loss += model(x_val, y_val, attn_blocksize=attn_blocksize)
505 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
506 | val_loss /= val_steps
507 | # log val loss to console and to logfile
508 | print0(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
509 | # start the clock again
510 | torch.cuda.synchronize()
511 | t0 = time.time()
512 |
513 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
514 | # stop the clock
515 | torch.cuda.synchronize()
516 | training_time_ms += 1000 * (time.time() - t0)
517 | # save the state of the training process
518 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
519 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
520 | # start the clock again
521 | torch.cuda.synchronize()
522 | t0 = time.time()
523 |
524 | # bit confusing: we want to make sure to eval on 0th iteration
525 | # but also after the very last iteration. so we loop for step <= num_iterations
526 | # instead of just < num_iterations (one extra due to <=), only to do
527 | # the validation/sampling one last time, and then we break right here as we're done.
528 | if last_step:
529 | break
530 |
531 | # --------------- TRAINING SECTION BEGIN -----------------
532 | model.train()
533 | for i in range(1, train_accumulation_steps+1):
534 | ctx = model.no_sync() if i < train_accumulation_steps else contextlib.nullcontext()
535 | with ctx: # there's no need to sync gradients every accumulation step
536 | # forward pass
537 | loss = model(x, y, attn_blocksize=attn_blocksize)
538 | # advance the dataset for the next batch
539 | x, y = train_loader.next_batch()
540 | # backward pass
541 | loss.backward()
542 | train_loss = loss.detach()
543 | for p in model.parameters():
544 | p.grad /= train_accumulation_steps
545 | # momentum warmup for Muon
546 | frac = min(step/300, 1)
547 | optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95
548 | # step the optimizers and schedulers
549 | for opt, sched in zip(optimizers, schedulers):
550 | opt.step()
551 | sched.step()
552 | # null the gradients
553 | model.zero_grad(set_to_none=True)
554 | # --------------- TRAINING SECTION END -------------------
555 | # everything that follows now is just diagnostics, prints, logging, etc.
556 |
557 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
558 | approx_time = training_time_ms + 1000 * (time.time() - t0)
559 | print0(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
560 |
561 | if master_process:
562 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
563 |
564 | # -------------------------------------------------------------------------
565 | # clean up nice
566 | dist.destroy_process_group()
567 |
--------------------------------------------------------------------------------
/records/121724_SparsifyEmbeds/README.md:
--------------------------------------------------------------------------------
1 | # Sparsify embeds
2 |
3 | Running this new record by Jiacheng You 1261 times yielded the following series of val losses:
4 | ```
5 | accs = [3.2772, 3.2794, 3.2786, 3.2807, 3.2816, 3.2827, 3.2805, 3.2787, 3.2757, 3.2798, 3.2795, 3.2785, 3.2785, 3.2803, 3.2804, 3.2816, 3.2814, 3.2808, 3.2776, 3.2809, 3.2792, 3.2776, 3.2784, 3.2809, 3.2804, 3.2779, 3.2799, 3.2761, 3.2763, 3.2794, 3.2754, 3.2816, 3.2818, 3.2785, 3.2837, 3.2765, 3.2805, 3.2784, 3.2783, 3.28, 3.2791, 3.2777, 3.2815, 3.2789, 3.2796, 3.2804, 3.2793, 3.2817, 3.2799, 3.2803, 3.2773, 3.283, 3.2781, 3.2785, 3.2771, 3.2824, 3.2819, 3.2791, 3.2799, 3.2792, 3.2771, 3.2799, 3.2782, 3.2811, 3.2786, 3.2774, 3.2786, 3.2807, 3.2775, 3.2778, 3.2778, 3.2801, 3.2764, 3.2774, 3.2801, 3.28, 3.2785, 3.2813, 3.2799, 3.2787, 3.2802, 3.2776, 3.2818, 3.2783, 3.2774, 3.2779, 3.279, 3.2777, 3.2814, 3.2783, 3.2796, 3.2822, 3.2785, 3.2784, 3.283, 3.2799, 3.2786, 3.2799, 3.2797, 3.2791, 3.2761, 3.278, 3.2775, 3.28, 3.2804, 3.2781, 3.2778, 3.2806, 3.2767, 3.2787, 3.2769, 3.2794, 3.2856, 3.2764, 3.278, 3.2814, 3.2803, 3.2781, 3.28, 3.2787, 3.2787, 3.2797, 3.2799, 3.2815, 3.2777, 3.2792, 3.2799, 3.2789, 3.2772, 3.2762, 3.2792, 3.2769, 3.28, 3.2786, 3.2751, 3.2818, 3.2791, 3.2776, 3.2778, 3.2796, 3.2793, 3.2785, 3.2826, 3.281, 3.2781, 3.2796, 3.2783, 3.2788, 3.2811, 3.2818, 3.2803, 3.2794, 3.2757, 3.2793, 3.277, 3.2765, 3.2785, 3.2788, 3.2796, 3.2773, 3.2778, 3.2802, 3.2837, 3.2753, 3.2831, 3.276, 3.2773, 3.2762, 3.2789, 3.2769, 3.2805, 3.2816, 3.2761, 3.2788, 3.2787, 3.2785, 3.2818, 3.2787, 3.2838, 3.279, 3.2805, 3.2807, 3.2804, 3.2797, 3.2752, 3.2838, 3.2834, 3.2792, 3.2804, 3.2793, 3.282, 3.2829, 3.2796, 3.2789, 3.279, 3.2778, 3.2787, 3.279, 3.279, 3.2789, 3.2756, 3.281, 3.28, 3.2804, 3.2796, 3.2803, 3.2795, 3.2781, 3.2783, 3.2772, 3.2807, 3.279, 3.2787, 3.2777, 3.2781, 3.2818, 3.2748, 3.2786, 3.2758, 3.2762, 3.2801, 3.2817, 3.2807, 3.2804, 3.2772, 3.281, 3.2766, 3.278, 3.2753, 3.2803, 3.2787, 3.2799, 3.2797, 3.2794, 3.2823, 3.2769, 3.2789, 3.2769, 3.277, 3.2806, 3.2799, 3.2787, 3.2786, 3.28, 3.28, 3.2813, 3.279, 3.2795, 3.2792, 3.2807, 3.2806, 3.2779, 3.2783, 3.2796, 3.2778, 3.2808, 3.2778, 3.2785, 3.2781, 3.2808, 3.2802, 3.2819, 3.2794, 3.2784, 3.2819, 3.2824, 3.2814, 3.2791, 3.2779, 3.2788, 3.2788, 3.2796, 3.2798, 3.2782, 3.2782, 3.2768, 3.2785, 3.2788, 3.2812, 3.2811, 3.2793, 3.2812, 3.2824, 3.2786, 3.2787, 3.2806, 3.2807, 3.2771, 3.2825, 3.2791, 3.2761, 3.2831, 3.2803, 3.2807, 3.2793, 3.2795, 3.2825, 3.276, 3.279, 3.2817, 3.2808, 3.279, 3.2793, 3.282, 3.2835, 3.2789, 3.2791, 3.2792, 3.2797, 3.281, 3.2795, 3.2775, 3.2772, 3.2818, 3.2787, 3.2775, 3.2814, 3.2787, 3.2818, 3.2772, 3.2796, 3.2787, 3.2815, 3.2795, 3.2799, 3.2785, 3.2772, 3.2788, 3.279, 3.2776, 3.2819, 3.2783, 3.2751, 3.2763, 3.2771, 3.2797, 3.2783, 3.2823, 3.2798, 3.277, 3.2813, 3.2774, 3.2801, 3.2821, 3.2806, 3.2833, 3.281, 3.2819, 3.2794, 3.2815, 3.279, 3.2837, 3.2779, 3.28, 3.2803, 3.2784, 3.2786, 3.2782, 3.2782, 3.2791, 3.279, 3.2806, 3.2801, 3.2807, 3.2797, 3.2767, 3.2796, 3.2798, 3.2816, 3.2766, 3.2823, 3.2772, 3.2765, 3.2784, 3.2775, 3.2779, 3.284, 3.2778, 3.2806, 3.2806, 3.281, 3.2787, 3.2823, 3.2771, 3.2768, 3.2782, 3.2822, 3.2785, 3.279, 3.2811, 3.2785, 3.2781, 3.2802, 3.2793, 3.2794, 3.2811, 3.2837, 3.2785, 3.2809, 3.283, 3.2813, 3.2805, 3.2769, 3.2806, 3.276, 3.2814, 3.28, 3.277, 3.2791, 3.2775, 3.279, 3.2802, 3.2809, 3.2815, 3.2763, 3.2862, 3.2791, 3.2791, 3.2763, 3.2789, 3.2792, 3.2816, 3.2792, 3.2775, 3.2803, 3.2809, 3.2828, 3.2805, 3.2794, 3.2801, 3.281, 3.2772, 3.2806, 3.2789, 3.2827, 3.2796, 3.2846, 3.2812, 3.2791, 3.2765, 3.2784, 3.2777, 3.2773, 3.2778, 3.2768, 3.2783, 3.2793, 3.2778, 3.2776, 3.2742, 
3.2769, 3.2774, 3.2813, 3.2801, 3.2807, 3.2777, 3.2821, 3.2794, 3.2791, 3.279, 3.2763, 3.2804, 3.2803, 3.2795, 3.2805, 3.2815, 3.2801, 3.2823, 3.2798, 3.2802, 3.2784, 3.2801, 3.2792, 3.2856, 3.2805, 3.2782, 3.2808, 3.2793, 3.2804, 3.2775, 3.2798, 3.28, 3.2795, 3.2789, 3.2795, 3.2771, 3.2792, 3.2802, 3.2815, 3.2806, 3.2813, 3.2809, 3.2829, 3.2778, 3.282, 3.2825, 3.2789, 3.282, 3.2785, 3.2782, 3.2791, 3.2803, 3.276, 3.2777, 3.279, 3.278, 3.2781, 3.2815, 3.2776, 3.2823, 3.275, 3.2794, 3.2787, 3.2775, 3.2792, 3.2794, 3.2774, 3.2806, 3.2785, 3.28, 3.2758, 3.2795, 3.2778, 3.2796, 3.2817, 3.2785, 3.2815, 3.2797, 3.2749, 3.2785, 3.2804, 3.277, 3.2791, 3.2818, 3.2826, 3.2784, 3.2768, 3.2801, 3.2833, 3.2817, 3.2796, 3.2783, 3.2781, 3.2757, 3.2787, 3.2803, 3.2786, 3.2757, 3.2774, 3.2813, 3.2777, 3.2821, 3.2791, 3.2774, 3.2786, 3.2808, 3.2791, 3.279, 3.2813, 3.2818, 3.2771, 3.2861, 3.2805, 3.2789, 3.2769, 3.2809, 3.2823, 3.2854, 3.2819, 3.2789, 3.2796, 3.2815, 3.2781, 3.2819, 3.2802, 3.2788, 3.2767, 3.277, 3.2798, 3.2806, 3.2787, 3.2778, 3.2786, 3.2805, 3.2799, 3.2776, 3.2775, 3.2815, 3.2791, 3.28, 3.2789, 3.28, 3.2807, 3.2793, 3.2783, 3.2771, 3.2801, 3.2796, 3.2789, 3.2772, 3.2783, 3.2812, 3.2803, 3.2756, 3.2775, 3.2807, 3.2801, 3.2787, 3.28, 3.2827, 3.2801, 3.2798, 3.2785, 3.2798, 3.2787, 3.2785, 3.2796, 3.2762, 3.2808, 3.2788, 3.2775, 3.2765, 3.2792, 3.2784, 3.2787, 3.2793, 3.2793, 3.2784, 3.2773, 3.2812, 3.2785, 3.2759, 3.2781, 3.2786, 3.2783, 3.2804, 3.2791, 3.2791, 3.2772, 3.2803, 3.2773, 3.2778, 3.2809, 3.2815, 3.2784, 3.278, 3.2783, 3.2818, 3.2805, 3.2802, 3.2828, 3.2767, 3.2811, 3.2786, 3.2798, 3.2796, 3.2777, 3.2793, 3.277, 3.2762, 3.2773, 3.2796, 3.2786, 3.2809, 3.2797, 3.2796, 3.2815, 3.2803, 3.2833, 3.2793, 3.2773, 3.2761, 3.2832, 3.2798, 3.2801, 3.2806, 3.2803, 3.2797, 3.276, 3.2798, 3.2797, 3.2788, 3.2824, 3.2785, 3.2802, 3.2817, 3.2766, 3.2815, 3.2797, 3.279, 3.2808, 3.2776, 3.2789, 3.2783, 3.2772, 3.2803, 3.282, 3.2773, 3.2803, 3.28, 3.2772, 3.2827, 3.2804, 3.2776, 3.2794, 3.2815, 3.2836, 3.2813, 3.2794, 3.2795, 3.279, 3.2772, 3.2787, 3.2813, 3.2778, 3.2798, 3.2819, 3.2788, 3.2838, 3.2792, 3.2772, 3.2799, 3.2837, 3.2801, 3.2806, 3.2799, 3.2793, 3.2788, 3.2786, 3.2766, 3.2782, 3.281, 3.2783, 3.2789, 3.2801, 3.2759, 3.281, 3.2762, 3.2795, 3.2799, 3.2835, 3.2772, 3.2794, 3.2803, 3.2782, 3.2804, 3.2782, 3.28, 3.2766, 3.2823, 3.2771, 3.2775, 3.2811, 3.2789, 3.2808, 3.2787, 3.2805, 3.2812, 3.281, 3.2809, 3.2795, 3.2801, 3.2817, 3.2789, 3.2808, 3.2779, 3.2758, 3.2779, 3.276, 3.2779, 3.2823, 3.2818, 3.2816, 3.2806, 3.2807, 3.2788, 3.2778, 3.2821, 3.2777, 3.2779, 3.2775, 3.2785, 3.2794, 3.2813, 3.2825, 3.2812, 3.2801, 3.2782, 3.2807, 3.2797, 3.2781, 3.2778, 3.2778, 3.2803, 3.2832, 3.2819, 3.2783, 3.279, 3.2785, 3.279, 3.2786, 3.2793, 3.2798, 3.282, 3.2794, 3.2818, 3.2796, 3.2795, 3.2796, 3.2779, 3.2788, 3.2776, 3.2787, 3.2766, 3.279, 3.2764, 3.2831, 3.2819, 3.2791, 3.2784, 3.2793, 3.2824, 3.28, 3.2812, 3.2773, 3.2777, 3.283, 3.2774, 3.278, 3.2801, 3.2817, 3.2791, 3.2811, 3.281, 3.2817, 3.2803, 3.2791, 3.2816, 3.2785, 3.2797, 3.2805, 3.2809, 3.2825, 3.2799, 3.2777, 3.2803, 3.2787, 3.2783, 3.2784, 3.2781, 3.2754, 3.2801, 3.2782, 3.2792, 3.2776, 3.2786, 3.283, 3.28, 3.2771, 3.2808, 3.2774, 3.2787, 3.2788, 3.2787, 3.278, 3.2805, 3.2783, 3.2814, 3.2785, 3.2794, 3.2825, 3.2767, 3.2781, 3.2812, 3.2792, 3.2807, 3.2785, 3.2833, 3.2763, 3.2834, 3.2798, 3.278, 3.2783, 3.2781, 3.28, 3.2793, 3.2768, 3.2786, 3.2797, 3.2775, 3.2786, 3.2776, 3.2792, 3.277, 3.2784, 3.2804, 3.2795, 3.2802, 3.2761, 3.2778, 
3.2796, 3.278, 3.2776, 3.2781, 3.2791, 3.2768, 3.2797, 3.2835, 3.2774, 3.2783, 3.2814, 3.2799, 3.2799, 3.2812, 3.2787, 3.2815, 3.2786, 3.2771, 3.2796, 3.2812, 3.276, 3.2814, 3.2775, 3.28, 3.2824, 3.2806, 3.2806, 3.2786, 3.2839, 3.2794, 3.2787, 3.2809, 3.2838, 3.2794, 3.2812, 3.282, 3.2783, 3.2813, 3.2803, 3.2784, 3.2826, 3.279, 3.2784, 3.2835, 3.2811, 3.2843, 3.2805, 3.282, 3.2805, 3.2798, 3.2786, 3.2785, 3.2804, 3.276, 3.2764, 3.2774, 3.2783, 3.279, 3.2815, 3.2803, 3.2768, 3.2796, 3.2801, 3.2808, 3.2778, 3.2798, 3.2804, 3.2814, 3.2782, 3.2801, 3.2787, 3.2795, 3.2792, 3.2791, 3.2776, 3.2774, 3.2776, 3.2796, 3.2795, 3.2766, 3.2774, 3.2773, 3.2804, 3.2785, 3.2802, 3.2805, 3.2802, 3.2793, 3.281, 3.2763, 3.2834, 3.2803, 3.2781, 3.2768, 3.2771, 3.278, 3.2779, 3.2815, 3.2812, 3.2807, 3.2819, 3.2824, 3.2812, 3.2806, 3.2782, 3.2797, 3.2782, 3.2793, 3.2755, 3.2808, 3.2816, 3.2796, 3.2817, 3.2779, 3.2774, 3.2774, 3.2774, 3.2794, 3.278, 3.2836, 3.2828, 3.279, 3.2805, 3.2824, 3.2776, 3.2795, 3.2807, 3.2783, 3.2809, 3.2789, 3.2771, 3.2792, 3.2775, 3.2809, 3.2813, 3.2797, 3.2788, 3.2792, 3.2763, 3.282, 3.2762, 3.2787, 3.282, 3.2791, 3.2781, 3.2778, 3.279, 3.2788, 3.2791, 3.2829, 3.2769, 3.28, 3.2768, 3.277, 3.2774, 3.2841, 3.2777, 3.2749, 3.2785, 3.2805, 3.2814, 3.2768, 3.2767, 3.2803, 3.2785, 3.2808, 3.2811, 3.2805, 3.2794, 3.2772, 3.2791, 3.2809, 3.2807, 3.2815, 3.2793, 3.2784, 3.2794, 3.2819, 3.2812, 3.2799, 3.2805, 3.2782, 3.2805, 3.2793, 3.2788, 3.2783, 3.2804, 3.2795, 3.2785, 3.2808, 3.2823, 3.2787, 3.2787, 3.278, 3.2791, 3.2805, 3.2808, 3.2787, 3.2779, 3.2781, 3.2787, 3.2779, 3.2775, 3.2789, 3.2784, 3.2803, 3.2774, 3.2798, 3.2772, 3.28, 3.2816, 3.2792, 3.278, 3.2792, 3.2787, 3.2813, 3.2799, 3.2802, 3.281, 3.2768, 3.2811, 3.2772, 3.2802, 3.2822, 3.2789, 3.2762, 3.2775, 3.2799, 3.2792, 3.2795, 3.2792, 3.2793, 3.2817, 3.2784, 3.28, 3.2792, 3.2788, 3.2815, 3.2782, 3.2826, 3.28, 3.2782, 3.2792, 3.2757, 3.2766, 3.2788, 3.2778, 3.2788, 3.2797, 3.2797, 3.2777, 3.2783, 3.2778, 3.2799, 3.2812, 3.2813, 3.2802, 3.2818, 3.2801, 3.277, 3.2839, 3.2806, 3.2777, 3.2805, 3.278, 3.279, 3.2775, 3.28, 3.2774, 3.2789, 3.277, 3.2807, 3.2805, 3.2795, 3.2777, 3.2813, 3.2805, 3.2809, 3.2814, 3.2794, 3.2797, 3.2803, 3.2802, 3.2808, 3.278, 3.275, 3.283, 3.2791, 3.2761, 3.2787, 3.2797, 3.2781, 3.2754, 3.2775, 3.2797, 3.281, 3.2792, 3.2797, 3.2812, 3.2781, 3.2782, 3.2803, 3.2778, 3.2812, 3.2809, 3.2781, 3.2769, 3.2797, 3.2774, 3.2787, 3.2805, 3.2796, 3.2771, 3.2776, 3.2784, 3.2757, 3.2784, 3.2795, 3.2802, 3.2788, 3.2787, 3.2778, 3.2832, 3.2784, 3.2802, 3.2805, 3.2787, 3.2786, 3.2801, 3.2811, 3.2809, 3.2793, 3.2809, 3.2786, 3.2777, 3.2801, 3.2787, 3.2838, 3.2751, 3.2777, 3.2806, 3.2786, 3.2772, 3.2797, 3.278, 3.281, 3.2808, 3.2769, 3.2774, 3.2779, 3.2836, 3.2792, 3.2786, 3.2834, 3.2786, 3.2781, 3.2809, 3.2788, 3.2799, 3.28, 3.2807]
6 |
7 | import scipy.stats
8 | print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
9 | # p=0.0000
10 |
11 | import torch
12 | print(torch.std_mean(torch.tensor(accs)))
13 | # (tensor(0.0018), tensor(3.2794))
14 | ```
15 |
16 |
17 |
18 |
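As a rough companion to the t-test and std/mean above, a 95% confidence interval for the mean loss can be computed the same way (a minimal sketch using the `accs` list above; assumes `numpy` and `scipy` are installed, and the exact bounds are not reproduced here):

```python
import numpy as np
import scipy.stats

mean = np.mean(accs)
sem = scipy.stats.sem(accs)  # standard error of the mean
lo, hi = scipy.stats.t.interval(0.95, len(accs) - 1, loc=mean, scale=sem)
print(f"95% CI for the mean loss: [{lo:.4f}, {hi:.4f}]")
```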
--------------------------------------------------------------------------------
/records/121724_SparsifyEmbeds/loss_hist_121724.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/121724_SparsifyEmbeds/loss_hist_121724.png
--------------------------------------------------------------------------------
/records/123124_Target350M/README.md:
--------------------------------------------------------------------------------
1 | The `train_gpt.py` in this folder attains ~2.95 validation loss in ~26 minutes on 8xH100.
2 |
3 | It can therefore form the first record for the task of matching the performance attained by Karpathy using llm.c to train a [350M parameter model on 30B tokens](https://github.com/karpathy/llm.c/discussions/481).
4 |
5 | I'd be happy to see, and to help boost, competition on this task as well.
6 |
7 |
--------------------------------------------------------------------------------
/records/123124_Target350M/train_gpt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | with open(sys.argv[0]) as f:
4 | code = f.read() # read the code of this file ASAP, for logging
5 | import uuid
6 | import time
7 | import glob
8 | import contextlib
9 | from dataclasses import dataclass
10 |
11 | import torch
12 | torch.empty(1, device='cuda', requires_grad=True).backward()
13 | from torch import nn
14 | import torch.nn.functional as F
15 | import torch.distributed as dist
16 | from torch.nn.parallel import DistributedDataParallel as DDP
17 | from torch.nn.attention.flex_attention import BlockMask, flex_attention #KoszarskyB
18 |
19 | # -----------------------------------------------------------------------------
20 | # Muon optimizer
21 |
22 | @torch.compile
23 | def zeropower_via_newtonschulz5(G, steps):
24 | """
25 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
26 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
27 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
28 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
29 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
30 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
31 | performance at all relative to UV^T, where USV^T = G is the SVD.
32 | """
33 | assert len(G.shape) == 2
34 | a, b, c = (3.4445, -4.7750, 2.0315)
35 | X = G.bfloat16()
36 | if G.size(0) > G.size(1):
37 | X = X.T
38 |
39 | # Ensure spectral norm is at most 1
40 | X = X / (X.norm() + 1e-7)
41 | # Perform the NS iterations
42 | for _ in range(steps):
43 | A = X @ X.T
44 | B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
45 | X = a * X + B @ X
46 |
47 | if G.size(0) > G.size(1):
48 | X = X.T
49 | return X
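# Illustration of the property described in the docstring (a sketch, assuming a CUDA
# device; not executed by this script): the output has singular values near 1 rather
# than exactly 1.
#   G = torch.randn(256, 128, device='cuda')
#   S = torch.linalg.svdvals(zeropower_via_newtonschulz5(G, steps=5).float())
#   # S should lie roughly within (0.5, 1.5), per the US'V^T characterization above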
50 |
51 | class Muon(torch.optim.Optimizer):
52 | """
53 | Muon - MomentUm Orthogonalized by Newton-schulz
54 |
55 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
56 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
57 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
58 | the advantage that it can be stably run in bfloat16 on the GPU.
59 |
60 | Some warnings:
61 | - This optimizer assumes that all parameters passed in are 2D.
62 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
63 | parameters; those should all be optimized by a standard method (e.g., AdamW).
64 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
65 | - We believe it is unlikely to work well for training with small batch size.
66 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
67 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
68 |
69 | Arguments:
70 | lr: The learning rate used by the internal SGD.
71 | momentum: The momentum used by the internal SGD.
72 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
73 | ns_steps: The number of Newton-Schulz iteration steps to use.
74 | """
75 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
76 | self.world_size = int(os.environ['WORLD_SIZE'])
77 | self.rank = int(os.environ['RANK'])
78 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
79 | assert all(isinstance(p, torch.Tensor) for p in params)
80 | sizes = {p.numel() for p in params}
81 | param_groups = [dict(params=[p for p in params if p.numel() == size],
82 | update_buffer=[torch.empty(size, device='cuda', dtype=torch.bfloat16) for _ in range(self.world_size)])
83 | for size in sizes]
84 | super().__init__(param_groups, defaults)
85 |
86 | def step(self):
87 |
88 | for group in self.param_groups:
89 |
90 | lr = group['lr']
91 | momentum = group['momentum']
92 | nesterov = group['nesterov']
93 | ns_steps = group['ns_steps']
94 | update_buffers = group['update_buffer']
95 | # generate weight updates in distributed fashion
96 | params = group['params']
97 | handle = None
98 | params_world = None
99 | def update_prev():
100 | if params_world is None:
101 | return
102 | assert handle is not None
103 | handle.wait()
104 | for p_world, g_world in zip(params_world, update_buffers):
105 | p_world.data.add_(
106 | g_world.view_as(p_world),
107 | alpha=-lr * max(1, p_world.size(0) / p_world.size(1)) ** 0.5,
108 | )
109 | for base_i in range(len(params))[::self.world_size]:
110 | if base_i + self.rank < len(params):
111 | p = params[base_i + self.rank]
112 | g = p.grad
113 | assert g is not None
114 | state = self.state[p]
115 | if 'momentum_buffer' not in state:
116 | state['momentum_buffer'] = torch.zeros_like(g)
117 | buf = state['momentum_buffer']
118 | buf.lerp_(g, 1 - momentum)
119 | g = g.lerp_(buf, momentum) if nesterov else buf
120 | g = zeropower_via_newtonschulz5(g, steps=ns_steps).flatten()
121 | else:
122 | g = update_buffers[self.rank]
123 | update_prev()
124 | handle = dist.all_gather(update_buffers, g, async_op=True)
125 | params_world = params[base_i : base_i + self.world_size]
126 | update_prev()
127 |
128 | # -----------------------------------------------------------------------------
129 | # PyTorch nn.Module definitions for the GPT-2 model
130 |
131 | def norm(x):
132 | return F.rms_norm(x, (x.size(-1),))
133 |
134 | class CastedLinear(nn.Linear):
135 |
136 | def __init__(self, in_features, out_features):
137 | super().__init__(in_features, out_features, bias=False)
138 |
139 | def forward(self, x):
140 | return F.linear(x, self.weight.to(x.dtype))
141 |
142 | class Rotary(nn.Module):
143 |
144 | def __init__(self, dim, max_seq_len=65536):
145 | super().__init__()
146 | inv_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
147 | inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(dim//4)])
148 | t = torch.arange(max_seq_len, dtype=torch.float32)
149 | theta = torch.einsum('i,j -> ij', t, inv_freq)
150 | self.cos = nn.Buffer(theta.cos(), persistent=False)
151 | self.sin = nn.Buffer(theta.sin(), persistent=False)
152 |
153 | def forward(self, x):
154 | cos, sin = self.cos[None, :x.size(-3), None, :], self.sin[None, :x.size(-3), None, :]
155 | x1, x2 = x.float().chunk(2, dim=-1)
156 | y1 = x1 * cos + x2 * sin
157 | y2 = x1 * (-sin) + x2 * cos
158 | return torch.cat((y1, y2), 3).type_as(x)
159 |
160 | class CausalSelfAttention(nn.Module):
161 |
162 | def __init__(self, dim, num_heads):
163 | super().__init__()
164 | assert dim % num_heads == 0
165 | self.num_heads = num_heads
166 | self.c_q = CastedLinear(dim, dim)
167 | self.c_k = CastedLinear(dim, dim)
168 | self.c_v = CastedLinear(dim, dim)
169 | self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
170 | self.rotary = Rotary(dim // num_heads) # dim // num_heads = head_dim
171 | self.c_proj = CastedLinear(dim, dim)
172 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
173 |
174 | def forward(self, x, ve, block_mask):
175 | B, T = x.size(0), x.size(1) # batch size, sequence length
176 | assert B == 1, 'Must use batch size = 1 for FlexAttention'
177 | q = self.c_q(x).view(B, T, self.num_heads, -1)
178 | k = self.c_k(x).view(B, T, self.num_heads, -1)
179 | v = self.c_v(x).view(B, T, self.num_heads, -1)
180 | if ve is not None:
181 | v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
182 | else:
183 | v = self.lambdas[0] * v
184 | q, k = norm(q), norm(k) # QK norm @Grad62304977
185 | q, k = self.rotary(q), self.rotary(k)
186 | y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, enable_gqa=True)
187 | y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
188 | y = self.c_proj(y)
189 | return y
190 |
191 | class MLP(nn.Module):
192 |
193 | def __init__(self, dim):
194 | super().__init__()
195 | self.c_fc = CastedLinear(dim, 4 * dim)
196 | self.c_proj = CastedLinear(4 * dim, dim)
197 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
198 |
199 | def forward(self, x):
200 | x = self.c_fc(x)
201 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
202 | x = self.c_proj(x)
203 | return x
204 |
205 | class Block(nn.Module):
206 |
207 | def __init__(self, model_dim, num_heads, use_attn=True):
208 | super().__init__()
209 | self.attn = CausalSelfAttention(model_dim, num_heads) if use_attn else None
210 | self.mlp = MLP(model_dim)
211 | self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
212 |
213 | def forward(self, x, ve, x0, block_mask):
214 | x = self.lambdas[0] * x + self.lambdas[1] * x0
215 | if self.attn is not None:
216 | x = x + self.attn(norm(x), ve, block_mask)
217 | x = x + self.mlp(norm(x))
218 | return x
219 |
220 | class ValueEmbedding(nn.Module):
221 | def __init__(self, vocab_size, model_dim):
222 | super().__init__()
223 | self.embed = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
224 |
225 | def forward(self, inputs):
226 | ve = [emb(inputs) for emb in self.embed]
227 | ve = [ve[0], ve[1], ve[2], None, None, None, None, None, None, None, None, None, None, ve[0], ve[1], ve[2]]
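# blocks 0-2 and 13-15 of the 16 blocks receive value embeddings; the 10 middle blocks get None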
228 | return ve
229 |
230 | # -----------------------------------------------------------------------------
231 | # The main GPT-2 model
232 |
233 | class GPT(nn.Module):
234 |
235 | def __init__(self, vocab_size, num_layers, num_heads, model_dim):
236 | super().__init__()
237 | self.embed = nn.Embedding(vocab_size, model_dim)
238 | self.blocks = nn.ModuleList([Block(model_dim, num_heads, use_attn=(i != 7))
239 | for i in range(num_layers)])
240 | # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning
241 | # U-net structure on token value embeddings by @leloykun
242 | self.value_embeds = ValueEmbedding(vocab_size, model_dim)
243 | self.lm_head = CastedLinear(model_dim, vocab_size)
244 | self.lm_head.weight.data.zero_() # @Grad62304977
245 | # U-net design by @brendanh0gan
246 | self.num_encoder_layers = num_layers // 2 # Half of the layers for encoder
247 | self.num_decoder_layers = num_layers - self.num_encoder_layers # Remaining for decoder
248 | # Add learnable skip connection weights for decoder layers
249 | self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))
250 |
251 | def forward(self, inputs, targets, sliding_window_num_blocks):
252 | BLOCK_SIZE = 128
253 | seq_len = len(inputs)
254 | assert seq_len % BLOCK_SIZE == 0
255 | total_num_blocks = seq_len // BLOCK_SIZE
256 | assert inputs.ndim == 1
257 | docs = (inputs == 50256).cumsum(0)
258 | docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
259 | docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
260 |
261 | def document_causal(b, h, q_idx, kv_idx):
262 | causal_mask = q_idx >= kv_idx
263 | document_mask = docs[q_idx] == docs[kv_idx]
264 | return causal_mask & document_mask
265 |
266 | def dense_to_ordered(dense_mask):
267 | num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)
268 | indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(torch.int32)
269 | return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
270 |
271 | def create_doc_swc_block_mask(sliding_window_num_blocks):
272 | kv_idx = block_idx = torch.arange(total_num_blocks, dtype=torch.int32, device='cuda')
273 | q_idx = block_idx[:, None]
274 | causal_bm = q_idx >= kv_idx
275 | causal_full_bm = q_idx > kv_idx
276 | window_bm = q_idx - kv_idx < sliding_window_num_blocks
277 | window_full_bm = window_bm
278 | # document_bm = (docs_low[q_idx] <= docs_high[kv_idx]) & (docs_low[kv_idx] <= docs_high[q_idx])
279 | document_bm = (docs_low[:, None] <= docs_high) & (docs_low <= docs_high[:, None])
280 | document_full_bm = (docs_low[:, None] == docs_high) & (docs_low == docs_high[:, None])
281 | nonzero_bm = causal_bm & window_bm & document_bm
282 | full_bm = causal_full_bm & window_full_bm & document_full_bm
283 | kv_num_blocks, kv_indices = dense_to_ordered(nonzero_bm & ~full_bm)
284 | full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)
285 | return BlockMask.from_kv_blocks(
286 | kv_num_blocks,
287 | kv_indices,
288 | full_kv_num_blocks,
289 | full_kv_indices,
290 | BLOCK_SIZE=BLOCK_SIZE,
291 | mask_mod=document_causal,
292 | )
293 |
294 | block_mask = create_doc_swc_block_mask(sliding_window_num_blocks)
295 |
296 | # forward the GPT model itself
297 | x = self.embed(inputs[None]) # token embeddings of shape (b, t, model_dim)
298 | x = norm(x) # @Grad62304977
299 | x0 = x
300 | ve = self.value_embeds(inputs)
301 | ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:]
302 |
303 | # Store outputs for U-Net skip connections
304 | skip_connections = []
305 | # Encoder pass - process only the first half of the blocks
306 | for i in range(self.num_encoder_layers):
307 | x = self.blocks[i](x, ve_enc[i], x0, block_mask)
308 | skip_connections.append(x)
309 | # Decoder pass - process the remaining blocks with weighted skip connections
310 | for i in range(self.num_decoder_layers):
311 | x = x + self.skip_weights[i] * skip_connections.pop()
312 | # U-net structure on token value embeddings by @leloykun
313 | x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_mask)
314 |
315 | x = norm(x)
316 | logits = self.lm_head(x)
317 | logits = 30 * torch.tanh(logits / 30) # @Grad62304977
318 | logits = logits.float()
319 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
320 | return loss
321 |
322 | # -----------------------------------------------------------------------------
323 | # Our own simple Distributed Data Loader
324 |
325 | def _load_data_shard(path):
326 | # only reads the header, returns header data
327 | # header is 256 int32
328 | header = torch.from_file(path, False, 256, dtype=torch.int32)
329 | assert header[0] == 20240520, 'magic number mismatch in the data .bin file'
330 | assert header[1] == 1, 'unsupported version'
331 | num_tokens = int(header[2]) # number of tokens (claimed)
332 | with open(path, 'rb', buffering=0) as f:
333 | tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)
334 | f.seek(256 * 4)
335 | nbytes = f.readinto(tokens.numpy())
336 | assert nbytes == 2 * num_tokens, 'number of tokens read does not match header'
337 | return tokens
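# Shard layout expected by the asserts above: a 256-int32 header whose first three
# entries are the magic number 20240520, the version 1, and the token count, followed
# by the tokens themselves as uint16. A toy shard could be written like this
# (illustrative sketch, assuming numpy as np and a token list `toks`):
#   header = np.zeros(256, dtype=np.int32); header[:3] = [20240520, 1, len(toks)]
#   with open('toy.bin', 'wb') as f:
#       f.write(header.tobytes()); f.write(np.asarray(toks, dtype=np.uint16).tobytes())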
338 |
339 | class DistributedDataLoader:
340 |
341 | def __init__(self, filename_pattern):
342 | self.rank = int(os.environ['RANK'])
343 | self.world_size = int(os.environ['WORLD_SIZE'])
344 | self.files = sorted(glob.glob(filename_pattern))
345 | self.reset()
346 |
347 | def reset(self):
348 | self.current_shard = -1
349 | self.advance()
350 |
351 | def advance(self):
352 | self.current_shard = (self.current_shard + 1) % len(self.files)
353 | self.current_position = 0
354 | self.tokens = _load_data_shard(self.files[self.current_shard])
355 |
356 | def next_batch(self, batch_size):
357 | assert batch_size % self.world_size == 0
358 | device_batch_size = batch_size // self.world_size
359 | # load next shard if necessary
360 | if self.current_position + batch_size + 1 >= len(self.tokens):
361 | self.advance()
362 | pos = self.current_position + self.rank * device_batch_size
363 | device_batch_tokens = self.tokens[pos:pos+device_batch_size+1]
364 | # advance current position
365 | self.current_position += batch_size
366 | inputs = device_batch_tokens[:-1].to(device='cuda', dtype=torch.int32, non_blocking=True)
367 | targets = device_batch_tokens[1:].to(device='cuda', dtype=torch.int64, non_blocking=True)
368 | return inputs, targets
369 |
370 | # -----------------------------------------------------------------------------
371 | # int main
372 |
373 | @dataclass
374 | class Hyperparameters:
375 | # data
376 | train_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
377 | val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
378 | # optimization
379 | batch_size : int = 8*64*1024 # batch size in tokens
380 | device_batch_size : int = 64*1024 # batch size per device in tokens
381 | num_iterations : int = 6000 # number of iterations to run
382 | cooldown_iters : int = 2000 # number of iterations of linear warmup/cooldown for triangular or trapezoidal schedule
383 | bf16_embeds : bool = True
384 | # evaluation and logging
385 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
386 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
387 | # implementation
388 | save_checkpoint : bool = False
389 | args = Hyperparameters()
390 |
391 | # set up DDP (distributed data parallel). torchrun sets this env variable
392 | rank = int(os.environ['RANK'])
393 | local_rank = int(os.environ['LOCAL_RANK'])
394 | world_size = int(os.environ['WORLD_SIZE'])
395 | assert torch.cuda.is_available()
396 | torch.cuda.set_device(local_rank)
397 | dist.init_process_group(backend='nccl', device_id=torch.device(local_rank))
398 | dist.barrier()
399 | master_process = (rank == 0) # this process will do logging, checkpointing etc.
400 |
401 | assert args.batch_size % args.device_batch_size == 0
402 | assert (args.batch_size // args.device_batch_size) == world_size
403 |
404 | # begin logging
405 | logfile = None
406 | if master_process:
407 | run_id = uuid.uuid4()
408 | os.makedirs('logs', exist_ok=True)
409 | logfile = f'logs/{run_id}.txt'
410 | print(logfile)
411 |
412 | def print0(s, console=False):
413 | if master_process:
414 | with open(logfile, 'a') as f:
415 | if console:
416 | print(s)
417 | print(s, file=f)
418 |
419 | # begin by printing this file (the Python code)
420 | print0(code)
421 | print0('='*100)
422 | # log information about the hardware/software environment this is running on
423 | print0(f'Running Python {sys.version}')
424 | print0(f'Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}')
425 | import subprocess
426 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
427 | print0(f'{result.stdout}')
428 | print0('='*100)
429 |
430 | # load data
431 | train_loader = DistributedDataLoader(args.train_bin)
432 | val_loader = DistributedDataLoader(args.val_bin)
433 | print0(f'Training dataloader files: {train_loader.files}')
434 | print0(f'Validation dataloader files: {val_loader.files}')
435 | print0('='*100)
436 | inputs_train, targets_train = train_loader.next_batch(args.batch_size)
437 |
438 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
439 | # this originates from Karpathy's experiments.
440 | model = GPT(vocab_size=50304, num_layers=16, num_heads=8, model_dim=1024)
441 | model = model.cuda()
442 | if args.bf16_embeds:
443 | for m in model.modules():
444 | if isinstance(m, nn.Embedding):
445 | m.bfloat16()
446 | model = torch.compile(model)
447 | ddp_model = DDP(model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True)
448 | sliding_window_num_blocks = torch.tensor(1, dtype=torch.int32, device='cuda')
449 | sw_num_blocks_prev = 1
450 |
451 | # collect the parameters to optimize
452 | hidden_matrix_params = [p for p in model.blocks.parameters() if p.ndim == 2]
453 | embed_params = [model.embed.weight, *model.value_embeds.parameters()]
454 | scalar_params = [p for p in model.parameters() if p.ndim < 2]
455 | head_params = [model.lm_head.weight]
456 |
457 | # init the optimizer(s)
458 | optimizer1 = torch.optim.Adam([dict(params=embed_params, lr=0.35),
459 | dict(params=head_params, lr=0.004),
460 | dict(params=scalar_params, lr=0.02)],
461 | betas=(0.8, 0.95), fused=True)
462 | optimizer2 = Muon(hidden_matrix_params, lr=0.03, momentum=0.95)
463 | optimizers = [optimizer1, optimizer2]
464 |
465 | # learning rate decay scheduler (stable then decay)
466 | def get_lr(it):
467 | assert it <= args.num_iterations
468 | # 1) constant lr for first part of training
469 | if it < args.num_iterations - args.cooldown_iters:
470 | return 1.0
471 | # 2) then linear cooldown
472 | else:
473 | return (args.num_iterations - it) / args.cooldown_iters
474 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
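# With the defaults above (num_iterations=6000, cooldown_iters=2000) this is a
# constant-then-linear-decay schedule: get_lr stays at 1.0 through step 4000,
# then e.g. get_lr(5000) == 0.5 and get_lr(6000) == 0.0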
475 |
476 | # Start training loop
477 | training_time_ms = 0
478 | # start the clock
479 | torch.cuda.synchronize()
480 | t0 = time.perf_counter()
481 | # begin training
482 | train_steps = args.num_iterations
483 | for step in range(train_steps + 1):
484 | last_step = (step == train_steps)
485 | # This effectively ignores timing first 10 steps, which are slower for weird reasons.
486 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
487 | # steps with dummy data first, and then re-initialize the model and reset the loader.
488 | if step == 10:
489 | training_time_ms = 0
490 | t0 = time.perf_counter()
491 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
492 |
493 | # Linearly increase the sliding window size over training in chunks of 128 from 128 -> 1856. By @fernbear.bsky.social
494 | frac_done = step / train_steps # training progress
495 | sw_num_blocks = int(((1 - frac_done) * 128 + frac_done * 1856) // 128)
496 | if sw_num_blocks != sw_num_blocks_prev:
497 | sliding_window_num_blocks.copy_(sw_num_blocks, non_blocking=True)
498 | sw_num_blocks_prev = sw_num_blocks
499 |
500 | # --------------- VALIDATION SECTION -----------------
501 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
502 | # stop the clock
503 | torch.cuda.synchronize()
504 | training_time_ms += 1000 * (time.perf_counter() - t0)
505 | # run validation batches
506 | model.eval()
507 | val_loader.reset()
508 | val_loss = 0.0
509 | # calculate the number of steps to take in the val loop.
510 | assert args.val_tokens % args.batch_size == 0
511 | val_steps = args.val_tokens // args.batch_size
512 | for _ in range(val_steps):
513 | with torch.no_grad():
514 | inputs_val, targets_val = val_loader.next_batch(args.batch_size)
515 | val_loss += ddp_model(inputs_val, targets_val, sliding_window_num_blocks)
516 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
517 | val_loss /= val_steps
518 | # logging
519 | print0(f'step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms', console=True)
520 | # start the clock again
521 | torch.cuda.synchronize()
522 | t0 = time.perf_counter()
523 |
524 | if last_step:
525 | if master_process and args.save_checkpoint:
526 | log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
527 | os.makedirs(f'logs/{run_id}', exist_ok=True)
528 | torch.save(log, f'logs/{run_id}/state_step{step:06d}.pt')
529 | # the last step only has the validation loop, so break to avoid training
530 | break
531 |
532 | # --------------- TRAINING SECTION -----------------
533 | model.train()
534 | ddp_model(inputs_train, targets_train, sliding_window_num_blocks).backward()
535 | inputs_train, targets_train = train_loader.next_batch(args.batch_size)
536 | # momentum warmup for Muon
537 | frac = min(step/300, 1)
538 | for group in optimizer2.param_groups:
539 | group['momentum'] = (1 - frac) * 0.85 + frac * 0.95
540 | # step the optimizers and schedulers
541 | for opt, sched in zip(optimizers, schedulers):
542 | opt.step()
543 | sched.step()
544 | # null the gradients
545 | model.zero_grad(set_to_none=True)
546 | # logging
547 | approx_time = training_time_ms + 1000 * (time.perf_counter() - t0)
548 | print0(f'step:{step+1}/{train_steps} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms', console=True)
549 |
550 | print0(f'peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB')
551 | dist.destroy_process_group()
552 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm
3 | torch
4 | huggingface-hub
5 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 train_gpt.py
2 |
--------------------------------------------------------------------------------
/train_gpt_medium.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | with open(sys.argv[0]) as f:
4 | code = f.read() # read the code of this file ASAP, for logging
5 | import uuid
6 | import time
7 | import copy
8 | from dataclasses import dataclass
9 | from functools import lru_cache
10 | from pathlib import Path
11 |
12 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
13 | import torch
14 | torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
15 | from torch import Tensor, nn
16 | import torch.nn.functional as F
17 | import torch.distributed as dist
18 | # use of FlexAttention contributed by @KoszarskyB
19 | from torch.nn.attention.flex_attention import BlockMask, flex_attention
20 | torch._inductor.config.coordinate_descent_tuning = True # we allow this flag for medium track
21 | torch._dynamo.config.compiled_autograd = True
22 |
23 | # -----------------------------------------------------------------------------
24 | # Muon optimizer
25 |
26 | def zeropower_via_newtonschulz5(G: Tensor) -> Tensor:
27 | """
28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
32 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
33 | where S' is diagonal with S_{ii}' ∈ [1 - l, 1 + r], which turns out not to hurt model
34 | performance at all relative to UV^T, where USV^T = G is the SVD.
35 | """
36 | assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
37 | X = G.bfloat16()
38 | if G.size(-2) > G.size(-1):
39 | X = X.mT
40 |
41 | # Ensure spectral norm is at most 1
42 | X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
43 | # Perform the NS iterations
44 | for a, b, c in [
45 | (4.0848, -6.8946, 2.9270),
46 | (3.9505, -6.3029, 2.6377),
47 | (3.7418, -5.5913, 2.3037),
48 | (2.8769, -3.1427, 1.2046),
49 | (2.8366, -3.0525, 1.2012),
50 | ]:
51 | A = X @ X.mT
52 | B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
53 | X = a * X + B @ X
54 |
55 | if G.size(-2) > G.size(-1):
56 | X = X.mT
57 | return X
58 |
59 | @torch.compile
60 | def update(acc_bf16_view_u16: Tensor, mantissa: Tensor, momentum_buffer: Tensor, grad: Tensor, momentum: Tensor, eff_lr: Tensor, eff_weight_decay: Tensor):
61 | assert acc_bf16_view_u16.dtype == mantissa.dtype == torch.uint16
62 | grad = grad.float()
63 | momentum_buffer.copy_(momentum * momentum_buffer + (1 - momentum) * grad)
64 | v = zeropower_via_newtonschulz5(momentum * momentum_buffer + (1 - momentum) * grad)
65 |
66 | acc_m_u32 = (acc_bf16_view_u16.to(torch.uint32) << 16) | mantissa.to(torch.uint32)
67 | acc_m_u32.view(torch.float32).mul_(1 - eff_weight_decay)
68 | acc_m_u32.view(torch.float32).add_(other=v, alpha=-eff_lr)
69 | acc_bf16_view_u16.copy_((acc_m_u32 >> 16).to(torch.uint16))
70 | mantissa.copy_(acc_m_u32.to(torch.uint16))
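# The fp32 master weight is kept split across two uint16 views: the bf16 parameter
# itself supplies the high 16 bits and `mantissa` the low 16 bits, so
# (high << 16) | low reconstructs the exact fp32 accumulator while the forward pass
# only ever sees bf16 weights. Round-trip illustration (not executed here):
#   w = torch.randn(()).float(); bits = w.view(torch.uint32)
#   hi, lo = (bits >> 16).to(torch.uint16), bits.to(torch.uint16)
#   assert ((hi.to(torch.uint32) << 16) | lo.to(torch.uint32)).view(torch.float32) == w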
71 |
72 | class Muon(torch.optim.Optimizer):
73 | """
74 | Muon - MomentUm Orthogonalized by Newton-schulz
75 |
76 | https://kellerjordan.github.io/posts/muon/
77 |
78 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
79 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
80 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
81 | the advantage that it can be stably run in bfloat16 on the GPU.
82 |
83 | Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
84 | or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
85 | """
86 | def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, rank=0, world_size=1):
87 | self.rank = rank
88 | self.world_size = world_size
89 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
90 | super().__init__(params, defaults)
91 | assert all(p.dtype == torch.bfloat16 for group in self.param_groups for p in group["params"])
92 |
93 | @torch.no_grad()
94 | def step(self):
95 | futures: list[torch.Future] = []
96 | for group in self.param_groups:
97 | params: list[Tensor] = group["params"]
98 | params_pad = params + [torch.empty_like(params[-1])] * self.world_size
99 | momentum = torch._as_tensor_fullprec(group["momentum"])
100 | for base_i in range(len(params))[::self.world_size]:
101 | if base_i + self.rank < len(params):
102 | p = params[base_i + self.rank]
103 | state = self.state[p]
104 | if len(state) == 0:
105 | state["mantissa"] = torch.zeros_like(p, dtype=torch.uint16)
106 | state["momentum_buffer"] = torch.zeros_like(p, dtype=torch.float32)
107 | update(
108 | p.view(torch.uint16), state["mantissa"], state["momentum_buffer"],
109 | p.grad, momentum,
110 | eff_lr=torch._as_tensor_fullprec(group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5),
111 | eff_weight_decay=torch._as_tensor_fullprec(group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0)),
112 | )
113 | futures.append(dist.all_gather(params_pad[base_i:base_i + self.world_size], params_pad[base_i + self.rank], async_op=True).get_future())
114 | torch.futures.collect_all(futures).wait()
115 |
116 | # -----------------------------------------------------------------------------
117 | # PyTorch nn.Module definitions for the model
118 |
119 | def norm(x: Tensor):
120 | return F.rms_norm(x, (x.size(-1),))
121 |
122 | @torch.no_grad()
123 | def init_linear(w: Tensor):
124 | std = 0.5 * (w.size(-1) ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3)
125 | bound = (3 ** 0.5) * std
126 | return w.uniform_(-bound, bound)
127 |
128 | class Rotary(nn.Module):
129 | def __init__(self, dim: int, max_seq_len: int):
130 | super().__init__()
131 | # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
132 | angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
133 | angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
134 | t = torch.arange(max_seq_len, dtype=torch.float32)
135 | theta = torch.einsum("i,j -> ij", t, angular_freq)
136 | self.cos = nn.Buffer(theta.cos(), persistent=False)
137 | self.sin = nn.Buffer(theta.sin(), persistent=False)
138 |
139 | def forward(self, x_BTHD: Tensor):
140 | assert self.cos.size(0) >= x_BTHD.size(-3)
141 | cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
142 | x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
143 | y1 = x1 * cos + x2 * sin
144 | y2 = x1 * (-sin) + x2 * cos
145 | return torch.cat((y1, y2), 3).type_as(x_BTHD)
146 |
147 | class CausalSelfAttention(nn.Module):
148 | def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
149 | super().__init__()
150 | self.num_heads = num_heads
151 | self.head_dim = head_dim
152 | hdim = num_heads * head_dim
153 | # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
154 | # https://x.com/hi_tysam/status/1879699187107033311
155 | self.qkvo_w = nn.Parameter(init_linear(torch.empty(4, hdim, dim)).bfloat16())
156 | self.qkvo_w.detach()[3].zero_() # out zero init suggested by @Grad62304977
157 | self.rotary = Rotary(head_dim, max_seq_len)
158 | # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
159 | # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
160 | self.attn_scale = 0.12
161 |
162 | def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask, lambdas: Tensor):
163 | B, T = x.size(0), x.size(1) # batch size, sequence length
164 | assert B == 1, "Must use batch size = 1 for FlexAttention"
165 | q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
166 | q, k = norm(q), norm(k) # QK norm @Grad62304977
167 | q, k = self.rotary(q), self.rotary(k)
168 | v = norm(v)
169 | if ve is not None:
170 | v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
171 | else: # skip mid-layers token value embeddings by @YouJiacheng
172 | v = lambdas[0] * v
173 | y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale).transpose(1, 2)
174 | y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
175 | y = F.linear(y, self.qkvo_w[3])
176 | return y
177 |
178 | class MLP(nn.Module):
179 | def __init__(self, dim: int):
180 | super().__init__()
181 | hdim = 4 * dim
182 | self.fc_w = nn.Parameter(init_linear(torch.empty(hdim, dim)).bfloat16())
183 | self.proj_w = nn.Parameter(torch.zeros(dim, hdim).bfloat16())
184 | self.fc_w.wd_mul = 2.0
185 | self.proj_w.wd_mul = 2.0
186 |
187 | def forward(self, x: Tensor):
188 | x = F.linear(x, self.fc_w)
189 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
190 | x = F.linear(x, self.proj_w)
191 | return x
192 |
193 | class Block(nn.Module):
194 | def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
195 | super().__init__()
196 | # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
197 | self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
198 | self.mlp = MLP(dim)
199 |
200 | def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask, lambdas: Tensor, sa_lambdas: Tensor):
201 | x = lambdas[0] * x + lambdas[1] * x0
202 | if self.attn is not None:
203 | x = x + self.attn(x, ve, block_mask, sa_lambdas)
204 | x = x + self.mlp(norm(x))
205 | return x
206 |
207 | # -----------------------------------------------------------------------------
208 | # The main model
209 |
210 | def next_multiple_of_n(v: float | int, *, n: int):
211 | return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
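# e.g. next_multiple_of_n(50257, n=128) == 50304, the padded vocab size used for lm_head_w below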
212 |
213 | class GPT(nn.Module):
214 | def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
215 | super().__init__()
216 | self.embed = nn.Embedding(vocab_size, model_dim)
217 | # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
218 | # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
219 | self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
220 | self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
221 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
222 | # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
223 | self.lm_head_w = nn.Parameter(torch.zeros(next_multiple_of_n(vocab_size, n=128), model_dim))
224 | # Add learnable skip connection weights for decoder layers
225 | assert num_layers % 2 == 0
226 | self.scalars = nn.Parameter(torch.cat([
227 | torch.ones(num_layers), # skip_weights
228 | *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas
229 | *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas
230 | ]))
231 |
232 | def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
233 | BLOCK_SIZE = 128
234 | docs = (input_seq == 50256).cumsum(0)
235 |
236 | def document_causal(b, h, q_idx, kv_idx):
237 | causal_mask = q_idx >= kv_idx
238 | document_mask = docs[q_idx] == docs[kv_idx]
239 | return causal_mask & document_mask
240 |
241 | def dense_to_ordered(dense_blockmask: Tensor):
242 | num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
243 | indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
244 | return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
245 |
246 | # manual block mask creation by @YouJiacheng
247 | assert len(input_seq) % BLOCK_SIZE == 0
248 | NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
249 | block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
250 | causal_blockmask_any = block_idx[:, None] >= block_idx
251 | causal_blockmask_all = block_idx[:, None] > block_idx
252 | docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
253 | docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
254 | document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
255 | document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
256 | blockmask_any = causal_blockmask_any & document_blockmask_any
257 | blockmask_all = causal_blockmask_all & document_blockmask_all
258 | partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
259 | full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
260 | def build_bm(window_size_blocks: Tensor) -> BlockMask:
261 | return BlockMask.from_kv_blocks(
262 | torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
263 | partial_kv_indices,
264 | torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
265 | full_kv_indices,
266 | BLOCK_SIZE=BLOCK_SIZE,
267 | mask_mod=document_causal,
268 | )
269 | # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from a suggestion by @Grad62304977, following the Gemma 2 paper
270 | return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
271 |
272 | def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
273 | assert input_seq.ndim == 1
274 |
275 | ve = [value_embed(input_seq) for value_embed in self.value_embeds]
276 | # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
277 | ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
278 | assert len(ve) == len(self.blocks)
279 |
280 | long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
281 | block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
282 | assert len(block_masks) == len(self.blocks)
283 |
284 | x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
285 |
286 | skip_connections = []
287 | skip_map = {
288 | 9: 6,
289 | 10: 4,
290 | 11: 2,
291 | }
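# U-net-style decoder skips: block 9 adds the (scaled) output of block 6, block 10 of block 4, block 11 of block 2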
292 | skip_weights = self.scalars[:len(self.blocks)]
293 | lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
294 | sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
295 | for i in range(len(self.blocks)):
296 | if i in skip_map:
297 | x = x + skip_weights[skip_map[i]] * skip_connections[skip_map[i]]
298 | x = self.blocks[i](x, ve[i], x0, block_masks[i], lambdas[i], sa_lambdas[i])
299 | skip_connections.append(x)
300 |
301 | x = norm(x)
302 | if self.training:
303 | logits: Tensor = F.linear(x.flatten(end_dim=1), self.lm_head_w.bfloat16()).float()
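# logit soft-cap: 15 * x / sqrt(x^2 + 225) keeps logits in (-15, 15), a smooth cap analogous to the 30*tanh(x/30) used in earlier records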
304 | loss = F.cross_entropy(15 * logits * torch.rsqrt(logits.square() + 225), target_seq)
305 | return loss
306 |
307 | loss = 0
308 | for i in range(4):
309 | logits: Tensor = F.linear(x.flatten(end_dim=1).chunk(4)[i], self.lm_head_w.bfloat16()).float()
310 | loss += F.cross_entropy(15 * logits * torch.rsqrt(logits.square() + 225), target_seq.chunk(4)[i]) / 4
311 | return loss
312 |
313 | # -----------------------------------------------------------------------------
314 | # Our own simple Distributed Data Loader
315 |
316 | def _load_data_shard(file: Path):
317 | header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
318 | assert header[0] == 20240520, "magic number mismatch in the data .bin file"
319 | assert header[1] == 1, "unsupported version"
320 | num_tokens = int(header[2]) # number of tokens (claimed)
321 | with file.open("rb", buffering=0) as f:
322 | tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
323 | f.seek(256 * 4)
324 | nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
325 | assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
326 | return tokens
327 |
328 | def distributed_data_generator(filename_pattern: str, batch_size: int, rank : int, world_size : int):
329 | files = sorted(Path.cwd().glob(filename_pattern))
330 | assert batch_size % world_size == 0
331 | local_batch_size = batch_size // world_size
332 | file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
333 | tokens, pos = _load_data_shard(next(file_iter)), 0
334 | while True:
335 | if pos + batch_size + 1 >= len(tokens):
336 | tokens, pos = _load_data_shard(next(file_iter)), 0
337 | buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
338 | inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
339 | targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
340 | pos += batch_size
341 | yield inputs, targets
342 |
343 | # -----------------------------------------------------------------------------
344 | # int main
345 |
346 | @dataclass
347 | class Hyperparameters:
348 | # data
349 | train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
350 | val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
351 | val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
352 | train_seq_len = 64*1024 # FlexAttention sequence length
353 | val_seq_len = 4*64*1024 # FlexAttention sequence length for validation
354 | # optimization
355 | num_iterations = 5960 # number of iterations to run
356 | cooldown_frac = 0.7 # fraction of training spent cooling down the learning rate
357 | # architecture
358 | vocab_size = 50257
359 | # evaluation and logging
360 | val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end
361 | save_checkpoint = False
362 | args = Hyperparameters()
363 |
364 | run_id = int(os.environ.get("RUN_ID", 0))
365 | # torchrun sets these env variables
366 | rank = int(os.environ["RANK"])
367 | world_size = int(os.environ["WORLD_SIZE"])
368 | assert world_size == 8 # this code is designed for 8xH100
369 | assert torch.cuda.is_available()
370 | device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
371 | torch.cuda.set_device(device)
372 | dist.init_process_group(backend="nccl", device_id=device)
373 | dist.barrier()
374 | master_process = (rank == 0) # this process will do logging, checkpointing etc.
375 |
376 | # begin logging
377 | if master_process:
378 | run_id_full = f"{run_id:03d}_{uuid.uuid4()}"
379 | os.makedirs("logs", exist_ok=True)
380 | logfile = f"logs/{run_id_full}.txt"
381 | print(logfile)
382 | def print0(s, console=False):
383 | if master_process:
384 | with open(logfile, "a") as f:
385 | if console:
386 | print(s)
387 | print(s, file=f)
388 | from torch._logging._internal import trace_structured # noqa: E402
389 | import torch._inductor.codecache # noqa: E402
390 | import torch._inductor.graph # noqa: E402
391 | def _patched_trace_structured(name, metadata_fn, **kwargs):
392 | if name == "inductor_output_code":
393 | print0(f"inductor_output_code: {metadata_fn().get("filename", "Unknown")}")
394 | trace_structured(name, metadata_fn, **kwargs)
395 | torch._inductor.codecache.trace_structured = _patched_trace_structured
396 | torch._inductor.graph.trace_structured = _patched_trace_structured
397 |
398 | # begin by printing this file (the Python code)
399 | print0(code)
400 | print0("="*100)
401 | # log information about the hardware/software environment this is running on
402 | print0(f"Running Python {sys.version}")
403 | print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
404 | def nvidia_smi():
405 | import subprocess # avoid top level import
406 | return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
407 | print0(nvidia_smi())
408 | print0("="*100)
409 |
410 | ########################################
411 | # Construct model and optimizer #
412 | ########################################
413 |
414 | model: nn.Module = GPT(vocab_size=args.vocab_size, num_layers=16, num_heads=8, model_dim=1024,
415 | max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
416 | for m in model.modules():
417 | if isinstance(m, nn.Embedding):
418 | m.bfloat16()
419 | for param in model.parameters():
420 | dist.broadcast(param.detach(), 0)
421 |
422 | # collect the parameters to optimize
423 | hidden_matrix_params = sorted((p for p in model.blocks.parameters() if p.ndim >= 2), key=lambda x: x.size(), reverse=True)
424 | embed_params = [*model.embed.parameters(), *model.value_embeds.parameters()]
425 | scalar_params = [model.scalars]
426 | head_params: list[nn.Parameter] = [model.lm_head_w]
427 | # sanity check
428 | params_collections = [hidden_matrix_params, embed_params, scalar_params, head_params]
429 | optimized_parameters_set = {p for params in params_collections for p in params}
430 | assert optimized_parameters_set == {*model.parameters()}
431 | assert len(optimized_parameters_set) == sum(len(lst) for lst in params_collections)
432 |
433 | # init the optimizer(s)
434 | adam_param_groups = [dict(params=head_params, lr=1/320), dict(params=embed_params, lr=0.3), dict(params=scalar_params, lr=0.015)]
435 | # small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
436 | # discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
437 | optimizer1 = torch.optim.AdamW(adam_param_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0, fused=True)
438 | optimizer2 = Muon(hidden_matrix_params, lr=0.025, momentum=0.95, rank=rank, world_size=world_size)
439 | optimizers: list[torch.optim.Optimizer] = [optimizer1, optimizer2]
440 | def opt_params(opt: torch.optim.Optimizer) -> list[nn.Parameter]:
441 | return [p for group in opt.param_groups for p in group["params"]]
442 | opt2params = {opt: opt_params(opt) for opt in optimizers}
443 | for opt in optimizers:
444 | for group in opt.param_groups:
445 | group["initial_lr"] = group["lr"]
446 |
447 | # learning rate schedule: stable then decay
448 | def get_lr(step: int):
449 | x = step / args.num_iterations # progress in training
450 | assert 0 <= x < 1
451 | if x < 1 - args.cooldown_frac:
452 | return 1.0
453 | else:
454 | return (1 - x) / args.cooldown_frac
455 |
456 | # attention window size schedule: linearly increase
457 | @lru_cache(1)
458 | def get_window_size_blocks_helper(window_size: int):
459 | return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
460 | def get_window_size_blocks(step: int):
461 | x = step / args.num_iterations # progress in training
462 | assert 0 <= x <= 1
463 | # Increase the block-wise sliding window size over training from 128 to 3456 tokens
464 | # increase by @fernbear.bsky.social; block-wise by @YouJiacheng
465 | factor = 4 * x ** 3 - 6 * x ** 2 + 3 * x # cubic schedule by @jadenj3o
466 | window_size = next_multiple_of_n(3456 * factor, n=128)
467 | return get_window_size_blocks_helper(window_size)
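# e.g. at x=0 this yields a 128-token window (1 block), at x=0.5 a 1792-token window
# (14 blocks), and at x=1 the full 3456-token window (27 blocks)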
468 |
469 | model: nn.Module = torch.compile(model, dynamic=False)
470 |
471 | ########################################
472 | # Warmup kernels #
473 | ########################################
474 |
475 | # Warmup the training kernels, then re-initialize the state so we aren't cheating
476 | warmup_steps = 10
477 | initial_state = copy.deepcopy(dict(model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]))
478 | for _ in range(warmup_steps):
479 | inputs = targets = torch.randint(0, args.vocab_size, size=(args.train_seq_len,), device="cuda")
480 | model(inputs.to(torch.int32), targets, get_window_size_blocks(0)).backward()
481 | for param in model.parameters():
482 | dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
483 | for opt in optimizers:
484 | opt.step()
485 | model.zero_grad(set_to_none=True)
486 | model.load_state_dict(initial_state["model"])
487 | for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
488 | opt.load_state_dict(opt_state)
489 | del initial_state
490 |
491 | ########################################
492 | # Training and validation #
493 | ########################################
494 |
495 | torch.cuda.reset_peak_memory_stats()
496 | train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, rank, world_size)
497 | training_time_ms = 0
498 | # start the clock
499 | dist.barrier()
500 | t0 = time.perf_counter()
501 | # begin training
502 | train_steps = args.num_iterations
503 | for step in range(train_steps + 1):
504 | last_step = (step == train_steps)
505 |
506 | # --------------- VALIDATION SECTION -----------------
507 | if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
508 | # stop the clock
509 | dist.barrier()
510 | training_time_ms += 1000 * (time.perf_counter() - t0)
511 | model.eval()
512 | val_batch_size = world_size * args.val_seq_len
513 | assert args.val_tokens % val_batch_size == 0
514 | val_steps = args.val_tokens // val_batch_size
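        # each rank accumulates its own validation loss over val_steps batches; the results are averaged across ranks below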
515 | val_loader = distributed_data_generator(args.val_files, val_batch_size, rank, world_size)
516 | val_loss = 0
517 | with torch.no_grad():
518 | for _ in range(val_steps):
519 | inputs, targets = next(val_loader)
520 | val_loss += model(inputs, targets, get_window_size_blocks(step))
521 | val_loss /= val_steps
522 | del val_loader
523 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
524 | print0(f"step:{step}/{train_steps} val_loss:{val_loss:.6f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
525 | model.train()
526 | # start the clock again
527 | dist.barrier()
528 | t0 = time.perf_counter()
529 |
530 | if last_step:
531 | if master_process and args.save_checkpoint:
532 | log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
533 | os.makedirs(f"logs/{run_id_full}", exist_ok=True)
534 | torch.save(log, f"logs/{run_id_full}/state_step{step:06d}.pt")
535 | # the last step only has the validation loop, so break to avoid training
536 | break
537 |
538 | # --------------- TRAINING SECTION -----------------
539 | inputs, targets = next(train_loader)
540 | model(inputs, targets, get_window_size_blocks(step)).backward()
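    # kick off gradient all-reduces asynchronously, one future per parameter, grouped by optimizer;
    # each optimizer waits only on its own futures right before stepping, so communication overlaps
    # with the LR/momentum bookkeeping below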
541 | opt2futures = {
542 | opt: [dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True).get_future() for p in params]
543 | for opt, params in opt2params.items()
544 | }
545 | # set optimization hyperparameters
546 | for opt in optimizers:
547 | for group in opt.param_groups:
548 | group["lr"] = group["initial_lr"] * get_lr(step)
549 | for group in optimizer2.param_groups:
550 |         frac = min(step / 300, 1) # momentum warmup for muon: ramps 0.85 -> 0.95 over the first 300 steps
551 | group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
552 | # step the optimizers
553 | for opt in optimizers:
554 | torch.futures.collect_all(opt2futures[opt]).wait()
555 | opt.step()
556 | # null the gradients
557 | model.zero_grad(set_to_none=True)
558 | # logging
559 | approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
560 | print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
561 |
562 | print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
563 | f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
564 | dist.destroy_process_group()
565 |
--------------------------------------------------------------------------------