├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── data ├── cached_fineweb100B.py ├── cached_fineweb10B.py ├── cached_finewebedu10B.py ├── fineweb.py └── requirements.txt ├── img ├── algo_optimizer.png ├── dofa.jpg ├── fig_optimizer.png ├── fig_tuned_nanogpt.png ├── nanogpt_speedrun51.png ├── nanogpt_speedrun52.png ├── nanogpt_speedrun53.png └── nanogpt_speedrun54.png ├── records ├── 010425_SoftCap │ ├── 31d6c427-f1f7-4d8a-91be-a67b5dcd13fd.txt │ ├── README.md │ └── curves_010425.png ├── 011325_Fp8LmHead │ ├── README.md │ └── c51969c2-d04c-40a7-bcea-c092c3c2d11a.txt ├── 011625_Sub3Min │ ├── 1d3bd93b-a69e-4118-aeb8-8184239d7566.txt │ ├── README.md │ ├── attn-entropy.png │ ├── attn-scales-pattern.gif │ ├── learned-attn-scales.png │ └── long-short-swa.png ├── 011825_GPT2Medium │ ├── 241dd7a7-3d76-4dce-85a4-7df60387f32a.txt │ └── main.log ├── 012625_BatchSize │ ├── 0bdd5ee9-ac28-4202-bdf1-c906b102b0ec.txt │ ├── README.md │ ├── ablations.png │ ├── c44090cc-1b99-4c95-8624-38fb4b5834f9.txt │ ├── val_losses.png │ └── wallclock.png ├── 020125_RuleTweak │ └── eff63a8c-2f7e-4fc5-97ce-7f600dae0bc7.txt ├── 020825_GPT2MediumWeightDecay │ └── b01743db-605c-4326-b5b1-d388ee5bebc5.txt ├── 021425_GPT2MediumOptCoeffs │ └── 1baa66b2-bff7-4850-aced-d63885ffb4b6.txt ├── 030625_GPT2MediumLongerCooldown │ ├── 779c041a-2a37-45d2-a18b-ec0f223c2bb7.txt │ └── README.md ├── 032525_GPT2MediumArchOptTweaks │ └── train_gpt-20250329.txt ├── 041625_GPT2Medium_Record7 │ ├── 223_3310d0b1-b24d-48ee-899f-d5c2a254a195.txt │ └── README.md ├── 042225_GPT2Medium_Record8 │ └── 075_640429f2-e726-4e83-aa27-684626239ffc.txt ├── 052425_FasterReduce │ └── 23f40b75-06fb-4c3f-87a8-743524769a35.txt ├── 052425_StableTorch │ └── 89d9f224-3b01-4581-966e-358d692335e0.txt ├── 052525_EvenFasterReduce │ └── 6ae86d05-5cb2-4e40-a512-63246fd08e45.txt ├── 052525_MuonWithAuxAdamExample │ └── b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt ├── 060624_AdamW │ ├── README.md │ └── f66d43d7-e449-4029-8adf-e8537bab49ea.log ├── 100924_SOAP │ ├── 5bdc3988-496c-4232-b4ef-53764cb81c92.txt │ ├── README.md │ └── train_gpt2.py ├── 101024_Muon │ ├── eb5659d0-fb6a-49e5-a311-f1f89412f726.txt │ └── train_gpt2.py ├── 101324_llmc │ ├── README.md │ └── main.log ├── 101424_ModernArch │ ├── dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt │ └── train_gpt2.py ├── 101724_DistributedMuon │ └── 22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt ├── 101824_PyTorch25 │ └── d4bfb25f-688d-4da5-8743-33926fad4842.txt ├── 102024_ScaleUp1B │ ├── 87bd51fd-6203-4c88-b3aa-8a849a6a83ca.txt │ ├── ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt │ └── c0078066-c8c9-49c8-868a-ff4d4f32e615.txt ├── 102924_Optimizers │ ├── 8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt │ ├── 8d6193f4-27fc-4e68-899f-af70019a4d54.txt │ ├── 95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt │ ├── README.md │ ├── e21a2838-a0f2-46f2-a247-db0021165682.txt │ ├── nanogpt_speedrun81w.png │ └── nanogpt_speedrun82w.png ├── 110324_UntieEmbed │ ├── README.md │ └── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt ├── 110424_50Bruns │ ├── 3d715d41-453a-40d6-9506-421ba69766b2.txt │ ├── 4fbe61ec-f79a-4c19-836d-46d599deecce.txt │ ├── 530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt │ ├── 69c33fc9-eabb-4a38-aa08-6922914eb405.txt │ └── README.md ├── 110624_ShortcutsTweaks │ ├── 042f9e87-07e6-4504-bb04-4ec59a380211.txt │ ├── 05b29e54-0be0-4a0f-a1e2-7d5317daedd3.txt │ ├── 10119f53-7001-4248-bfd9-33d32427a912.txt │ ├── 43f60c4f-0448-4de7-83d9-643ca26f61e7.txt │ ├── 4a71cc92-0f43-4058-a033-23e85c1e98f1.txt │ ├── README.md │ ├── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt │ ├── 
dd7304a6-cc43-4d5e-adb8-c070111464a1.txt │ ├── nanogpt_speedrun110.png │ └── nanogpt_speedrun111.png ├── 110824_CastBf16 │ └── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt ├── 110924_Replicateleloykun │ ├── 1621af10-aa0c-42af-bf54-8a773c63a2af.txt │ └── README.md ├── 111024_ScaleShortcuts │ ├── 3e55eb2e-6261-466a-b1e9-2b31f56fb16a.txt │ ├── 4897c987-9d09-435c-a23f-20585912936a.txt │ ├── 70a0ada6-8dee-4fef-8980-135379479c21.txt │ ├── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt │ └── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt ├── 111024_UNetDoubleLr │ ├── README.md │ └── c87bb826-797b-4f37-98c7-d3a5dad2de74.txt ├── 111424_QuantizedFP4 │ ├── 433c1732-0c3d-4099-a4a8-ec31eae49b16.txt │ ├── 70a0ada6-8dee-4fef-8980-135379479c21.txt │ ├── 932bbe0e-41c3-4a5b-94bd-4ea3350909bd.txt │ └── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt ├── 111924_FlexAttention │ ├── 8384493d-dba9-4991-b16b-8696953f5e6d.txt │ └── README.md ├── 112424_WindowWarmup │ ├── 3151f69a-f89e-452c-ac80-b85118007583.txt │ ├── 4428858e-7cb8-4a25-a936-818d8f28de51.txt │ ├── ae732e01-04b2-4665-b570-a77210e73e28.txt │ ├── ba299b7e-a36a-4fd8-a268-25bb772010dd.txt │ ├── cf9e4571-c5fc-4323-abf3-a98d862ec6c8.txt │ ├── d1cf11aa-7b8e-4d28-a94d-1aab632e0f38.txt │ └── dca62101-15d4-4c76-842e-99213fa2508b.txt ├── 120424_ValueEmbed │ ├── 00008ea0-21dd-442a-82ee-d12799249d0f.txt │ ├── 123ac41a-8f7b-42a1-853c-5a791dcaa7ae.txt │ ├── 14511f40-47db-4c94-b35b-70616770fd2d.txt │ ├── 19bb65fb-f903-4a41-803b-fbd57562f653.txt │ ├── 1b4db08c-46d9-4e1b-a8e6-872a685061c3.txt │ ├── 2358dd3a-8ce4-4b8a-a367-fca6dcd38343.txt │ ├── 2577721d-9ce9-400c-8902-ce95d6fbcf64.txt │ ├── 2d140d43-3323-408e-b559-3639c0880323.txt │ ├── 2e527b0b-3540-4bcd-a15d-955f86cb8bd2.txt │ ├── 2f4ce5fe-b625-41b4-acbb-d4c20b591ead.txt │ ├── 385d2312-0cf9-48c3-af3f-35c12c12a38d.txt │ ├── 3e39d446-d7e8-4c49-b92c-62c3848a3c01.txt │ ├── 4f8bcdc3-18cf-4743-895e-1deb08696fe7.txt │ ├── 51b3baf0-69d6-43ee-a88b-2c5c28e3dd5b.txt │ ├── 51faed93-0804-418c-9057-0d94c3f94a9c.txt │ ├── 66173c47-b15b-4a24-a835-60c82f6b8283.txt │ ├── 67716aee-6747-4997-a37c-b96932fab4dd.txt │ ├── 6b244191-77a3-41ea-a314-82c6a9184b31.txt │ ├── 74cba1d4-da56-4334-9622-e0aa960dfe3f.txt │ ├── 75a3af7b-f1a6-47dc-a989-d95e4419ff31.txt │ ├── 87f81569-fa04-4eb3-8b75-42c116e96ba0.txt │ ├── 8bd08106-6eb1-4cd1-9779-ccf1192bda1b.txt │ ├── 92b1541a-a6d6-4ddf-8932-ea4bcd31ba3b.txt │ ├── 949e5cfd-cb9c-48e7-a888-551981582a9b.txt │ ├── 968f73e2-b588-4102-80b0-996bae126be1.txt │ ├── README.md │ ├── a5c54ede-2647-4df1-872d-de033898a9e2.txt │ ├── b19b341a-bf8d-46f9-8076-fef1fcb7445e.txt │ ├── bb60727f-25b8-4e01-bf08-293aa3860e7c.txt │ ├── c249f3c7-b947-4b6f-8a42-56e97ef7712e.txt │ ├── c8e1a7d3-a37e-4a88-b28a-3afb2d8089ca.txt │ ├── d12cb409-0f5d-4624-951c-60119a482bca.txt │ ├── d6520673-0f5f-4c28-898b-f52d056b257d.txt │ ├── d884d4cb-0656-460d-9454-897eb9789f2a.txt │ ├── e0248b14-212b-436b-9a12-ba142d720ab4.txt │ ├── e3ffb000-a388-4bdc-b88f-2f85794c13c8.txt │ ├── ed14c8b2-2ac1-41e0-acea-3cc55cd94f83.txt │ ├── eecabe35-2224-4988-9910-8e8434a0c281.txt │ ├── f0f173e3-69ee-4970-ba0b-f7f3e5d92e33.txt │ └── train_gpt2.py ├── 120824_UNetValueEmbedsTweaks │ ├── 0069607b-aa90-49fd-9766-4368bcd168c4.txt │ ├── 08d38857-18d8-406f-907b-c80142fec515.txt │ ├── 0ad22fc7-dd0c-4d79-a216-c7667db09e15.txt │ ├── 13daf1c4-986b-4938-b269-1ac935891505.txt │ ├── 14dfd7b3-d65a-4813-8327-03a9f1bf9f6e.txt │ ├── 16cacfc3-de94-429e-be3a-03c3c5293d08.txt │ ├── 19c57414-499c-4892-b25c-a7083583fa59.txt │ ├── 1a1fde38-fd1f-4cc9-a648-74b8d6dab734.txt │ ├── 
1b363e84-a47e-48c9-ae21-7fb825512c9a.txt │ ├── 23ed0b57-544b-4fc5-b780-062dcf3f6f2c.txt │ ├── 257dc15f-5bfd-4c0b-9349-a46b3c13d5ef.txt │ ├── 26fa5797-44d0-4a63-9e57-f435f2f59aad.txt │ ├── 29354340-7020-42e4-8b06-b9a7c85e3d51.txt │ ├── 43ed2395-9260-48d6-9ab2-dd6f7f4684b5.txt │ ├── 472b8553-de0d-4f91-9285-10781b4cd07f.txt │ ├── 4732cc5e-a214-47d6-bfce-cb4ae2f663c8.txt │ ├── 48359737-5c81-4653-8d5c-5c2da8405f16.txt │ ├── 4a0f4d79-dc64-4731-b777-33f1e3ec6d3d.txt │ ├── 4bcd20c5-9ffd-40e9-8d9a-6d9d647b16ba.txt │ ├── 4de826f5-a244-490f-8025-5f0d370bd68a.txt │ ├── 501b69a6-bcbc-4403-9bc0-c257bdc3f8b6.txt │ ├── 52275a2c-80fd-4743-ad16-1d5098d97821.txt │ ├── 554a738b-571d-4da1-8556-393e7592eda7.txt │ ├── 563ca9ab-8f99-4a56-a8be-5ab3b21ac197.txt │ ├── 569ca2ab-12e9-4035-a0c7-f18c7050d234.txt │ ├── 59ba1f2d-a3b7-4fa8-b099-f13b838470ee.txt │ ├── 5bd46aa9-aaa1-4390-b355-41b730732b10.txt │ ├── 5dc85466-81b7-4770-ba8d-6937f4400758.txt │ ├── 5e28cee7-fc40-4593-bc83-a22495bfaacc.txt │ ├── 5f2dc8e6-dfc3-4078-b756-e7521561ebce.txt │ ├── 5fca3d5f-6290-47af-9130-b78a777b24c4.txt │ ├── 5fcbae90-4380-412e-8209-51fccc183040.txt │ ├── 60adb82c-ba8d-41b0-81fc-565439384b72.txt │ ├── 6175f2f0-8526-4f20-86d9-d7c2a6dcda19.txt │ ├── 625a6fcc-203c-4545-b697-0b8daa2b6d07.txt │ ├── 677ef1aa-30f7-4c4d-96c8-b85c7d70c4d6.txt │ ├── 68b508e5-a986-40d7-a307-21c282389be5.txt │ ├── 6a4a3bd3-c3b5-4bef-9267-03724ba49759.txt │ ├── 6bf9356b-7ff8-4253-97b0-c2747ad5a52e.txt │ ├── 71e11d12-c639-4966-bd9a-be49de53bc9c.txt │ ├── 7639d7d4-a962-4193-8e57-c1d76ea22f54.txt │ ├── 7de3e01c-69f5-4cfe-9e99-97413d15488c.txt │ ├── 81074c42-378a-4231-a37d-10f5c477a78c.txt │ ├── 83d3d075-ede1-4b85-93ab-24dc38ae1bf5.txt │ ├── 83fd4151-f2d1-4d8d-ad31-5546543072f3.txt │ ├── 88baa340-752f-4874-9dc8-f2f7831d78e4.txt │ ├── 89ac99c6-838d-431b-967a-4bde4d81d6c4.txt │ ├── 8a0afde4-391d-4e2d-8df1-e44fe0b80feb.txt │ ├── 8b074f6f-a4ad-42ae-9460-679426524ddb.txt │ ├── 8b82b55d-6955-4100-a1a9-5630cb85683c.txt │ ├── 93406931-1cc4-419c-b7a7-adef928d4c54.txt │ ├── 938ecd23-aaa8-4975-b6da-9407ab45e0f9.txt │ ├── 94df9d58-590d-4171-a015-e59e144d7206.txt │ ├── 95dc5712-2fd2-45bf-89c9-a86d7704e3d0.txt │ ├── 9e4f1bfe-0f1a-4bd8-af1b-072dcaa61c4c.txt │ ├── abf007e8-6370-4987-bd1a-85bff108fe5e.txt │ ├── ad6d9498-e76d-4ce4-acb5-fd2116eb77d8.txt │ ├── b7197dc5-b590-4e32-8590-a8c0076b64ab.txt │ ├── bbc9dd90-96c5-480b-984c-594145555821.txt │ ├── bc895608-3be1-4365-89ff-abbf6c316b37.txt │ ├── bdcab079-8761-4bb8-b2bc-adc2a45fbc9a.txt │ ├── be800e9a-90a5-457d-8f86-a943f4dea20c.txt │ ├── c2e4c338-f228-4ece-a8cb-1ccf5e949817.txt │ ├── c8cecdbc-504d-4df2-b844-fadc9bddabfa.txt │ ├── d13cc58e-654c-432b-b1b6-1c8b8ae8a232.txt │ ├── d30450c5-ecc2-426d-9375-bcf48e6123d8.txt │ ├── d69df4f4-5e17-45a3-94ba-5c2ec73a82e5.txt │ ├── e276c0fa-854b-44ad-830b-cdf887eaf6c3.txt │ ├── e2c6f3ca-140e-43fa-862e-4ee48c04422e.txt │ ├── e2f54f7a-f6c5-4fcf-bea6-e0442b62adca.txt │ ├── e3319113-ec3a-457a-975c-e89dcddd15f2.txt │ ├── e5589f56-4312-44f2-9254-95304cae10cf.txt │ ├── e66b0dd9-9680-4e7a-85c7-ac77f1d89aa2.txt │ ├── f2fda86f-4418-4ff8-a3a3-8112731cbcf9.txt │ └── f9a93608-ed4e-46ab-9c06-07f90f1328a6.txt ├── 121024_MFUTweaks │ ├── 04df8c70-2462-4b6c-9e49-c7589b321a81.txt │ ├── 052953d9-0495-43db-9d15-ae98fff06f18.txt │ ├── 0aa83756-53f0-4268-9721-db6d5985bc42.txt │ ├── 217c8375-d34f-45ee-9bee-ca8f9a6608da.txt │ ├── 2d25e67e-6ab4-45ff-92e0-d223bcf1784e.txt │ ├── 2d564428-0822-4d01-a61b-f53e4cd646c6.txt │ ├── 30beb830-0f15-4a71-8c98-187ea2015606.txt │ ├── 35c28373-6cf8-463c-8007-1146801f8068.txt │ ├── 
38757b64-ade6-45fa-b628-11cc83b6ed0c.txt │ ├── 3a714803-f1e1-4b81-b984-e2457ccb664f.txt │ ├── 409647d3-2887-4760-9cb4-327d3bb219a4.txt │ ├── 47029c45-997b-4498-a5c0-2fc76d365eb8.txt │ ├── 5175d854-1dcb-41e1-a690-b223fa69fd7f.txt │ ├── 571574de-d894-49b1-8f02-281ddd607823.txt │ ├── 591496b6-fdce-4902-b279-4360b378865d.txt │ ├── 59b474de-a528-4040-b6f4-19c89c2041cf.txt │ ├── 7442abb4-a571-4340-9844-de16209c3762.txt │ ├── 76ee9a40-0f61-42bb-849b-bf46c3a3e7c9.txt │ ├── 7787c6ea-33f2-4c63-a345-a8e8a830e595.txt │ ├── 8dff461f-9696-4d45-b717-bad7b1cb2060.txt │ ├── 91b48522-0c7b-4fee-83ea-c01df3e3d5c0.txt │ ├── 9323a20a-8f7f-4549-8997-dde75dfc6ec6.txt │ ├── a3e860f0-98ed-4845-9aee-cccd01fd9b5b.txt │ ├── a55c6778-e76a-44bc-a0e1-68c0e7580096.txt │ ├── a880e114-d133-469c-85db-ab58bc507d04.txt │ ├── ac154191-d2fb-478a-b6ea-55276b4f6655.txt │ ├── afcec83b-9286-455e-81a6-5eb3710d0ead.txt │ ├── c01ffedd-443f-421b-b7ca-492e8187557b.txt │ ├── c80b525d-de02-4fb0-b4ae-44f3308474ea.txt │ ├── d00194c9-e6d0-4767-b703-7e3c3ff24881.txt │ ├── d6ab1d3a-92f5-4847-8f3b-69faabc89c67.txt │ ├── d7c9acce-8b10-4f22-a505-4741319decc3.txt │ ├── e0f4f91a-4b79-4a39-a316-9ce769921840.txt │ ├── e1e718d6-89dc-4794-a969-17b2977b004c.txt │ ├── e221be41-2c83-4158-98dd-2df74ad2b9ba.txt │ ├── ea08b5fb-abc8-46b0-a643-4561d621d78d.txt │ ├── eca39347-6616-4903-ad6b-bab828aa2f78.txt │ ├── ed5a4791-fc3b-4fbf-b871-5c1c34f19b5a.txt │ ├── f2546c91-e2b3-4906-9f5b-bee25b47216f.txt │ └── feed17f1-d377-484b-b04b-0124144dfc62.txt ├── 121724_SparsifyEmbeds │ ├── 165384d5-5a41-4fba-9b79-8ecc372616ea.txt │ ├── README.md │ └── loss_hist_121724.png └── 123124_Target350M │ ├── README.md │ └── train_gpt.py ├── requirements.txt ├── run.sh ├── train_gpt.py └── train_gpt_medium.py /.gitignore: -------------------------------------------------------------------------------- 1 | fineweb10B/ 2 | pylog124M/ 3 | __pycache__/ 4 | logs/ 5 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.6.2-cudnn-devel-ubuntu24.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | ENV PYTHON_VERSION=3.12.7 5 | ENV PATH=/usr/local/bin:$PATH 6 | 7 | RUN apt update && apt install -y --no-install-recommends build-essential libssl-dev zlib1g-dev \ 8 | libbz2-dev libreadline-dev libsqlite3-dev curl git libncursesw5-dev xz-utils tk-dev libxml2-dev \ 9 | libxmlsec1-dev libffi-dev liblzma-dev \ 10 | && apt clean && rm -rf /var/lib/apt/lists/* 11 | 12 | RUN curl -O https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \ 13 | tar -xzf Python-${PYTHON_VERSION}.tgz && \ 14 | cd Python-${PYTHON_VERSION} && \ 15 | ./configure --enable-optimizations && \ 16 | make -j$(nproc) && \ 17 | make altinstall && \ 18 | cd .. 
&& \ 19 | rm -rf Python-${PYTHON_VERSION} Python-${PYTHON_VERSION}.tgz 20 | 21 | RUN ln -s /usr/local/bin/python3.12 /usr/local/bin/python && \ 22 | ln -s /usr/local/bin/pip3.12 /usr/local/bin/pip 23 | 24 | COPY requirements.txt /modded-nanogpt/requirements.txt 25 | WORKDIR /modded-nanogpt 26 | 27 | RUN python -m pip install --upgrade pip && \ 28 | pip install -r requirements.txt 29 | 30 | RUN pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --upgrade 31 | 32 | CMD ["bash"] 33 | ENTRYPOINT [] 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Keller Jordan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modded-NanoGPT 2 | 3 | This repository hosts the *NanoGPT speedrun*, in which we (collaboratively|competitively) search for the fastest algorithm to use 8 NVIDIA H100 GPUs to train a language model that attains 3.28 cross-entropy loss on the [FineWeb](https://huggingface.co/datasets/HuggingFaceFW/fineweb) validation set. 4 | 5 | The target (3.28 validation loss on FineWeb) follows Andrej Karpathy's [GPT-2 replication in llm.c, which attains that loss after running for 45 minutes](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29). 6 | The speedrun code also descends from llm.c's [PyTorch trainer](https://github.com/karpathy/llm.c/blob/master/train_gpt2.py), which itself descends from NanoGPT, hence the name of the repo. 
7 | Thanks to the efforts of many contributors, this repo now contains a training algorithm which attains the target performance in: 8 | * 3 minutes on 8xH100 (the llm.c GPT-2 replication needed 45) 9 | * 0.73B tokens (the llm.c GPT-2 replication needed 10B) 10 | 11 | This improvement in training speed has been brought about by the following techniques: 12 | * Modernized architecture: Rotary embeddings, QK-Norm, and ReLU² 13 | * The Muon optimizer [[writeup](https://kellerjordan.github.io/posts/muon/)] [[repo](https://github.com/KellerJordan/Muon)] 14 | * Untie head from embedding, use FP8 matmul for head, and softcap logits (the latter following Gemma 2) 15 | * Initialization of projection and classification layers to zero (muP-like) 16 | * Skip connections from embedding to every block as well as between blocks in U-net pattern 17 | * Extra embeddings which are mixed into the values in attention layers (inspired by Zhou et al. 2024) 18 | * FlexAttention with long-short sliding window attention pattern (inspired by Gemma 2) and window size warmup 19 | 20 | As well as many systems optimizations. 21 | 22 | Contributors list (growing with each new record): [@bozavlado](https://x.com/bozavlado); [@brendanh0gan](https://x.com/brendanh0gan); 23 | [@fernbear.bsky.social](https://bsky.app/profile/fernbear.bsky.social); [@Grad62304977](https://x.com/Grad62304977); 24 | [@jxbz](https://x.com/jxbz); [@kellerjordan0](https://x.com/kellerjordan0); 25 | [@KoszarskyB](https://x.com/KoszarskyB); [@leloykun](https://x.com/@leloykun); 26 | [@YouJiacheng](https://x.com/YouJiacheng); [@jadenj3o](https://x.com/jadenj3o); 27 | [@KonstantinWilleke](https://github.com/KonstantinWilleke), [@alexrgilbert](https://github.com/alexrgilbert), [@adricarda](https://github.com/adricarda), 28 | [@tuttyfrutyee](https://github.com/tuttyfrutyee), [@vdlad](https://github.com/vdlad); 29 | [@ryanyang0](https://x.com/ryanyang0) 30 | 31 | 32 | --- 33 | 34 | ## Running the current record 35 | 36 | To run the current record, run the following commands. 37 | ```bash 38 | git clone https://github.com/KellerJordan/modded-nanogpt.git && cd modded-nanogpt 39 | pip install -r requirements.txt 40 | pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --upgrade 41 | # downloads only the first 800M training tokens to save time 42 | python data/cached_fineweb10B.py 8 43 | ./run.sh 44 | ``` 45 | 46 | **Note: torch.compile will add around 5 minutes of latency the first time you run the code.** 47 | 48 | ## Alternative: Running with Docker (recommended for precise timing) 49 | 50 | For cases where CUDA or NCCL versions aren't compatible with your current system setup, Docker can be a helpful alternative. 51 | This approach standardizes versions for CUDA, NCCL, CUDNN, and Python, reducing dependency issues and simplifying setup. 52 | Note: an NVIDIA driver must already be installed on the system (useful if only the NVIDIA driver and Docker are available). 53 | 54 | ```bash 55 | git clone https://github.com/KellerJordan/modded-nanogpt.git && cd modded-nanogpt 56 | sudo docker build -t modded-nanogpt . 
57 | sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt python data/cached_fineweb10B.py 8 58 | sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt sh run.sh 59 | ``` 60 | 61 | To get an interactive docker, you can use 62 | ```bash 63 | sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt bash 64 | ``` 65 | 66 | --- 67 | 68 | ## World record history 69 | 70 | The following is the historical progression of world speed records for the following competitive task: 71 | 72 | > *Train a neural network to ≤3.28 validation loss on FineWeb using 8x NVIDIA H100s.* 73 | 74 | Note: The 3.28 target was selected to match [Andrej Karpathy's GPT-2 (small) reproduction](https://github.com/karpathy/llm.c/discussions/481). 75 | 76 | | # | Record time | Description | Date | Log | Contributors | 77 | | - | - | - | - | - | - | 78 | 1 | 45 minutes | [llm.c baseline](https://github.com/karpathy/llm.c/discussions/481) | 05/28/24 | [log](records/101324_llmc/main.log) | @karpathy, llm.c contributors 79 | 2 | 31.4 minutes | [Tuned learning rate & rotary embeddings](https://x.com/kellerjordan0/status/1798863559243513937) | 06/06/24 | [log](records/060624_AdamW/f66d43d7-e449-4029-8adf-e8537bab49ea.log) | @kellerjordan0 80 | 3 | 24.9 minutes | [Introduced the Muon optimizer](https://x.com/kellerjordan0/status/1842300916864844014) | 10/04/24 | none | @kellerjordan0, @jxbz 81 | 4 | 22.3 minutes | [Muon improvements](https://x.com/kellerjordan0/status/1844820919061287009) | 10/11/24 | [log](records/101024_Muon/eb5659d0-fb6a-49e5-a311-f1f89412f726.txt) | @kellerjordan0, @bozavlado 82 | 5 | 15.2 minutes | [Pad embeddings, ReLU², zero-init projections, QK-norm](https://x.com/kellerjordan0/status/1845865698532450646) | 10/14/24 | [log](records/101424_ModernArch/dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt) | @Grad62304977, @kellerjordan0 83 | 6 | 13.1 minutes | [Distributed the overhead of Muon](https://x.com/kellerjordan0/status/1847291684016783746) | 10/18/24 | [log](records/101724_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt) | @kellerjordan0 84 | 7 | 12.0 minutes | [Upgraded PyTorch 2.5.0](https://x.com/kellerjordan0/status/1847358578686152764) | 10/18/24 | [log](records/101824_PyTorch25/d4bfb25f-688d-4da5-8743-33926fad4842.txt) | @kellerjordan0 85 | 8 | 10.8 minutes | [Untied embedding and head](https://x.com/kellerjordan0/status/1853188916704387239) | 11/03/24 | [log](records/110324_UntieEmbed/d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt) | @Grad62304977, @kellerjordan0 86 | 9 | 8.2 minutes | [Value and embedding skip connections, momentum warmup, logit softcap](https://x.com/kellerjordan0/status/1854296101303800108) | 11/06/24 | [log](records/110624_ShortcutsTweaks/dd7304a6-cc43-4d5e-adb8-c070111464a1.txt) | @Grad62304977, @kellerjordan0 87 | 10 | 7.8 minutes | [Bfloat16 activations](https://x.com/kellerjordan0/status/1855267054774865980) | 11/08/24 | [log](records/110824_CastBf16/a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt) | @kellerjordan0 88 | 11 | 7.2 minutes | [U-net pattern skip connections & double lr](https://x.com/kellerjordan0/status/1856053121103093922) | 11/10/24 | [log](records/111024_UNetDoubleLr/c87bb826-797b-4f37-98c7-d3a5dad2de74.txt) | @brendanh0gan 89 | 12 | 5.03 minutes | [1024-ctx dense causal attention → 64K-ctx FlexAttention](https://x.com/kellerjordan0/status/1859331370268623321) | 11/19/24 | [log](records/111924_FlexAttention/8384493d-dba9-4991-b16b-8696953f5e6d.txt) | @KoszarskyB 90 | 13 | 4.66 minutes | [Attention 
window warmup](https://x.com/hi_tysam/status/1860851011797053450) | 11/24/24 | [log](records/112424_WindowWarmup/cf9e4571-c5fc-4323-abf3-a98d862ec6c8.txt) | @fernbear.bsky.social 91 | 14 | 4.41 minutes | [Value Embeddings](https://x.com/KoszarskyB/status/1864746625572257852) | 12/04/24 | [log](records/120424_ValueEmbed) | @KoszarskyB 92 | 15 | 3.95 minutes | [U-net pattern value embeddings, assorted code optimizations](https://x.com/YouJiacheng/status/1865761473886347747) | 12/08/24 | [log](records/120824_UNetValueEmbedsTweaks) | @leloykun, @YouJiacheng 93 | 16 | 3.80 minutes | [Split value embeddings, block sliding window, separate block mask](https://x.com/YouJiacheng/status/1866734331559071981) | 12/10/24 | [log](records/121024_MFUTweaks) | @YouJiacheng 94 | 17 | 3.57 minutes | [Sparsify value embeddings, improve rotary embeddings, drop an attn layer](https://x.com/YouJiacheng/status/1868938024731787640) | 12/17/24 | [log](records/121724_SparsifyEmbeds) | @YouJiacheng 95 | 18 | 3.4 minutes | [Lower logit softcap from 30 to 15](https://x.com/kellerjordan0/status/1876048851158880624) | 01/04/25 | [log](records/010425_SoftCap/31d6c427-f1f7-4d8a-91be-a67b5dcd13fd.txt) | @KoszarskyB 96 | 19 | 3.142 minutes | [FP8 head, offset logits, lr decay to 0.1 instead of 0.0](https://x.com/YouJiacheng/status/1878827972519772241) | 01/13/25 | [log](records/011325_Fp8LmHead/c51969c2-d04c-40a7-bcea-c092c3c2d11a.txt) | @YouJiacheng 97 | 20 | 2.992 minutes | [Merged QKV weights, long-short attention, attention scale, lower Adam epsilon, batched Muon](https://x.com/leloykun/status/1880301753213809016) | 01/16/25 | [log](records/011625_Sub3Min/1d3bd93b-a69e-4118-aeb8-8184239d7566.txt) | @leloykun, @fernbear.bsky.social, @YouJiacheng, @brendanh0gan, @scottjmaddox, @Grad62304977 98 | 21 | 2.933 minutes | [Reduced batch size](https://x.com/leloykun/status/1885640350368420160) | 01/26/25 | [log](records/012625_BatchSize/c44090cc-1b99-4c95-8624-38fb4b5834f9.txt) | @leloykun 99 | 21 | 2.997 minutes | 21st record with new timing | 02/01/25 | [log](records/020125_RuleTweak/eff63a8c-2f7e-4fc5-97ce-7f600dae0bc7.txt) | not a new record, just re-timing #21 with the [updated rules](#timing-change-after-record-21) 100 | 21 | 3.014 minutes | 21st record with latest torch | 05/24/25 | [log](records/052425_StableTorch/89d9f224-3b01-4581-966e-358d692335e0.txt) | not a new record, just re-timing #21 with latest torch 101 | 22 | 2.990 minutes | [Faster gradient all-reduce](https://x.com/KonstantinWille/status/1927137223238909969) | 05/24/25 | [log](records/052425_FasterReduce/23f40b75-06fb-4c3f-87a8-743524769a35.txt) | @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad; The Enigma project 102 | 23 | 2.979 minutes | Overlap computation and gradient communication | 05/25/25 | [log](records/052525_EvenFasterReduce/6ae86d05-5cb2-4e40-a512-63246fd08e45.txt) | @ryanyang0 103 | 104 | ## Rules 105 | 106 | The only rules are that new records must: 107 | 108 | 1. Not modify the train or validation data pipelines. (You can change the batch size, sequence length, attention structure etc.; just don't change the underlying streams of tokens.) 109 | 2. Attain ≤3.28 mean val loss. (Due to inter-run variance, submissions must provide enough run logs to attain a statistical significance level of p<0.01 that their mean val loss is ≤3.28. Example code to compute p-value can be found [here](records/010425_SoftCap#softer-softcap). 
For submissions which improve speed by optimizing the systems performance, without touching the ML, this requirement is waived.) 110 | 3. Not use any extra `torch._inductor.config` or `torch.compile` flags. (These can save a few seconds, but they can also make compilation take >30min. This rule was introduced after the 21st record.) 111 | 112 | > Note: `torch._inductor.config.coordinate_descent_tuning` is allowed for GPT-2 Medium track (a.k.a. 2.92 track). 113 | 114 | Other than that, anything and everything is fair game! 115 | 116 | [further clarifications](https://github.com/KellerJordan/modded-nanogpt/discussions/23?sort=new#discussioncomment-12109560) 117 | 118 | --- 119 | 120 | ### Comment on the target metric 121 | 122 | The target metric is *cross-entropy loss on the FineWeb val set*. To speak mathematically, the goal of the speedrun is *to obtain a probability model of language which assigns a probability of at least `math.exp(-3.28 * 10485760)` to the first 10,485,760 tokens of the FineWeb valset. Hence, e.g., we allow evaluation at any sequence length, so long as we still have a valid probability model of language. 123 | 124 | --- 125 | 126 | ### Timing change after record 21 127 | 128 | After the 21st record, we made two changes to the timing. First, there used to be an initial "grace period" of 10 untimed steps to allow kernel warmup. We replaced this with an explicit kernel-warmup section which is untimed and uses dummy data. This results in an extra runtime of 850ms from the 10 extra timed steps. 129 | Second, we banned the use of `torch._inductor.config.coordinate_descent_tuning`. This saves ~25min of untimed pre-run compilation, but results in an extra runtime of ~3s. 130 | 131 | 135 | 138 | 139 | 145 | 146 | --- 147 | 148 | ### Notable attempts & forks 149 | 150 | **Notable runs:** 151 | 152 | * [@alexjc's 01/20/2025 2.77-minute TokenMonster-based record](https://x.com/alexjc/status/1881410039639863622). 153 | This record is technically outside the rules of the speedrun, since we specified that the train/val tokens must be kept fixed. 154 | However, it's very interesting, and worth including. The run is not more data-efficient; rather, the speedup comes from the improved tokenizer allowing 155 | the vocabulary size to be reduced (nearly halved!) while preserving the same bytes-per-token, which saves lots of parameters and FLOPs in the head and embeddings. 156 | 157 | **Notable forks:** 158 | * [https://github.com/BlinkDL/modded-nanogpt-rwkv](https://github.com/BlinkDL/modded-nanogpt-rwkv) 159 | * [https://github.com/nikhilvyas/modded-nanogpt-SOAP](https://github.com/nikhilvyas/modded-nanogpt-SOAP) 160 | 161 | --- 162 | 163 | ## Speedrun track 2: GPT-2 Medium 164 | 165 | The target loss for this track is lowered from 3.28 to 2.92, as per Andrej Karpathy's 350M-parameter llm.c baseline. 166 | This baseline generates a model with performance similar to the original GPT-2 Medium, whereas the first track's baseline generates a model on par with GPT-2 Small. 167 | All other rules remain the same. 168 | 169 | > Note: `torch._inductor.config.coordinate_descent_tuning` is turned on after the record 6 (*). 
170 | 171 | | # | Record time | Description | Date | Log | Contributors | 172 | | - | - | - | - | - | - | 173 | 1 | 5.8 hours | [llm.c baseline (350M parameters)](https://github.com/karpathy/llm.c/discussions/481) | 05/28/24 | [log](records/011825_GPT2Medium/main.log) | @karpathy, llm.c contributors 174 | 2 | 29.3 minutes | [Initial record based on scaling up the GPT-2 small track speedrun](https://x.com/kellerjordan0/status/1881959719012847703) | 01/18/25 | [log](records/011825_GPT2Medium/241dd7a7-3d76-4dce-85a4-7df60387f32a.txt) | @kellerjordan0 175 | 3 | 28.1 minutes | [Added standard weight decay](https://x.com/kellerjordan0/status/1888320690543284449) | 02/08/25 | [log](records/020825_GPT2MediumWeightDecay/b01743db-605c-4326-b5b1-d388ee5bebc5.txt) | @kellerjordan0 176 | 4 | 27.7 minutes | [Tuned Muon Newton-Schulz coefficients](https://x.com/leloykun/status/1892793848163946799) | 02/14/25 | [log](records/021425_GPT2MediumOptCoeffs/1baa66b2-bff7-4850-aced-d63885ffb4b6.txt) | @leloykun 177 | 5 | 27.2 minutes | [Increased learning rate cooldown phase duration](records/030625_GPT2MediumLongerCooldown/779c041a-2a37-45d2-a18b-ec0f223c2bb7.txt) | 03/06/25 | [log](records/030625_GPT2MediumLongerCooldown/779c041a-2a37-45d2-a18b-ec0f223c2bb7.txt) | @YouJiacheng 178 | 6 | 25.95 minutes* | [2x MLP wd, qkv norm, all_reduce/opt.step() overlap, optimized skip pattern](https://x.com/YouJiacheng/status/1905861218138804534) | 03/25/25 | [log](records/032525_GPT2MediumArchOptTweaks/train_gpt-20250329.txt) | @YouJiacheng 179 | 7 | 25.29 minutes | [Remove FP8 head; ISRU logits softcap; New sharded mixed precision Muon; merge weights](https://x.com/YouJiacheng/status/1912570883878842527) | 04/16/25 | [log](records/041625_GPT2Medium_Record7/223_3310d0b1-b24d-48ee-899f-d5c2a254a195.txt) | @YouJiacheng 180 | 8 | 24.50 minutes | [Cubic sliding window size schedule, 2× max window size (24.84 minutes)](https://x.com/jadenj3o/status/1914893086276169754) [24.5min repro](https://x.com/YouJiacheng/status/1915667616913645985) | 04/22/25 | [log](records/042225_GPT2Medium_Record8/075_640429f2-e726-4e83-aa27-684626239ffc.txt) | @jadenj3o 181 | 182 | --- 183 | 184 | ### Q: What is the point of NanoGPT speedrunning? 185 | 186 | A: The officially stated goal of NanoGPT speedrunning is as follows: `gotta go fast`. But for something a little more verbose involving an argument for good benchmarking, here's some kind of manifesto, adorned with a blessing from the master. [https://x.com/karpathy/status/1846790537262571739](https://x.com/karpathy/status/1846790537262571739) 187 | 188 | ### Q: What makes "NanoGPT speedrunning" not just another idiosyncratic benchmark? 189 | 190 | A: Because it is a *competitive* benchmark. In particular, if you attain a new speed record (using whatever method you want), there is an open invitation for you 191 | to post that record (on arXiv or X) and thereby vacuum up all the clout for yourself. I will even help you do it by reposting you as much as I can. 192 | 193 | 200 | 201 | ["Artificial intelligence advances by inventing games and gloating to goad others to play" - Professor Ben Recht](https://www.argmin.net/p/too-much-information) 202 | 203 | ### Q: NanoGPT speedrunning is cool and all, but meh it probably won't scale and is just overfitting to val loss 204 | 205 | A: This is hard to refute, since "at scale" is an infinite category (what if the methods stop working only for >100T models?), making it impossible to fully prove. 
206 | Also, I would agree that some of the methods used in the speedrun are unlikely to scale, particularly those which *impose additional structure* on the network, such as logit softcapping. 207 | But if the reader cares about 1.5B models, they might be convinced by this result: 208 | 209 | *Straightforwardly scaling up the speedrun (10/18/24 version) to 1.5B parameters yields a model with GPT-2 (1.5B)-level HellaSwag performance 2.5x more cheaply than [@karpathy's baseline](https://github.com/karpathy/llm.c/discussions/677) ($233 instead of $576):* 210 | 211 | ![](img/nanogpt_speedrun51.png) 212 | [[reproducible log](https://github.com/KellerJordan/modded-nanogpt/blob/master/records/102024_ScaleUp1B/ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt)] 213 | ![](img/nanogpt_speedrun52.png) 214 | 215 | --- 216 | 217 | ## [Muon optimizer](https://github.com/KellerJordan/Muon) 218 | 219 | Muon is defined as follows: 220 | 221 | ![](img/algo_optimizer.png) 222 | 223 | Where NewtonSchulz5 is the following Newton-Schulz iteration [2, 3], which approximately replaces `G` with `U @ V.T` where `U, S, V = G.svd()`. 224 | ```python 225 | @torch.compile 226 | def zeroth_power_via_newtonschulz5(G, steps=5, eps=1e-7): 227 | assert len(G.shape) == 2 228 | a, b, c = (3.4445, -4.7750, 2.0315) 229 | X = G.bfloat16() / (G.norm() + eps) 230 | if G.size(0) > G.size(1): 231 | X = X.T 232 | for _ in range(steps): 233 | A = X @ X.T 234 | B = b * A + c * A @ A 235 | X = a * X + B @ X 236 | if G.size(0) > G.size(1): 237 | X = X.T 238 | return X.to(G.dtype) 239 | ``` 240 | 241 | For this training scenario, Muon has the following favorable properties: 242 | * Lower memory usage than Adam 243 | * ~1.5x better sample-efficiency 244 | * <2% wallclock overhead 245 | 246 | 247 | ### Provenance 248 | 249 | Many of the choices made to generate this optimizer were obtained experimentally by our pursuit of [CIFAR-10 speedrunning](https://github.com/KellerJordan/cifar10-airbench). 250 | In particular, we experimentally obtained the following practices: 251 | * Using Nesterov momentum inside the update, with orthogonalization applied after momentum. 252 | * Using a specifically quintic Newton-Schulz iteration as the method of orthogonalization. 253 | * Using non-convergent coefficients for the quintic polynomial in order to maximize slope at zero, and thereby minimize the number of necessary Newton-Schulz iterations. 254 | It turns out that the variance doesn't actually matter that much, so we end up with a quintic that rapidly converges to the range 0.68, 1.13 upon repeated application, rather than converging more slowly to 1. 255 | * Running the Newton-Schulz iteration in bfloat16 (whereas Shampoo implementations often depend on inverse-pth-roots run in fp32 or fp64). 256 | 257 | Our use of a Newton-Schulz iteration for orthogonalization traces to [Bernstein & Newhouse (2024)](https://arxiv.org/abs/2409.20325), 258 | who suggested it as a way to compute Shampoo [5, 6] preconditioners, and theoretically explored Shampoo without preconditioner accumulation. 259 | In particular, Jeremy Bernstein @jxbz sent us the draft, which caused us to experiment with various Newton-Schulz iterations as the 260 | orthogonalization method for this optimizer. 261 | If we had used SVD instead of a Newton-Schulz iteration, this optimizer would have been too slow to be useful. 
262 | Bernstein & Newhouse also pointed out that Shampoo without preconditioner accumulation is equivalent to steepest descent in the spectral norm, 263 | and therefore Shampoo can be thought of as a way to smooth out spectral steepest descent. 264 | The proposed optimizer can be thought of as a second way of smoothing spectral steepest descent, with a different set of memory and runtime tradeoffs 265 | compared to Shampoo. 266 | 267 | --- 268 | 269 | ## Running on fewer GPUs 270 | 271 | * To run experiments on fewer GPUs, simply modify `run.sh` to have a different `--nproc_per_node`. This should not change the behavior of the training. 272 | * If you're running out of memory, you may need to reduce the sequence length for FlexAttention (which does change the training. see [here](https://github.com/KellerJordan/modded-nanogpt/pull/38) for a guide) 273 | 274 | --- 275 | 276 | ## References 277 | 278 | 1. [Guilherme Penedo et al. "The fineweb datasets: Decanting the web for the finest text data at scale." arXiv preprint arXiv:2406.17557 (2024).](https://arxiv.org/abs/2406.17557) 279 | 2. Nicholas J. Higham. Functions of Matrices. Society for Industrial and Applied Mathematics (2008). Equation 5.22. 280 | 3. Günther Schulz. Iterative Berechnung der reziproken Matrix. Z. Angew. Math. Mech., 13:57–59 (1933). 281 | 4. [Jeremy Bernstein and Laker Newhouse. "Old Optimizer, New Norm: An Anthology." arxiv preprint arXiv:2409.20325 (2024).](https://arxiv.org/abs/2409.20325) 282 | 5. [Vineet Gupta, Tomer Koren, and Yoram Singer. "Shampoo: Preconditioned stochastic tensor optimization." International Conference on Machine Learning. PMLR, 2018.](https://arxiv.org/abs/1802.09568) 283 | 6. [Rohan Anil et al. "Scalable second order optimization for deep learning." arXiv preprint arXiv:2002.09018 (2020).](https://arxiv.org/abs/2002.09018) 284 | 7. [Alexander Hägele et al. "Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations." arXiv preprint arXiv:2405.18392 (2024).](https://arxiv.org/abs/2405.18392) 285 | 8. [Zhanchao Zhou et al. "Value Residual Learning For Alleviating Attention Concentration In Transformers." arXiv preprint arXiv:2410.17897 (2024).](https://arxiv.org/abs/2410.17897) 286 | 9. [Team, Gemma, et al. "Gemma 2: Improving open language models at a practical size." arXiv preprint arXiv:2408.00118 (2024).](https://arxiv.org/abs/2408.00118) 287 | 10. [Alec Radford et al. "Language models are unsupervised multitask learners." OpenAI blog 1.8 (2019).](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) 288 | 289 | ## Citation 290 | 291 | ``` 292 | @misc{modded_nanogpt_2024, 293 | author = {Keller Jordan and Jeremy Bernstein and Brendan Rappazzo and 294 | @fernbear.bsky.social and Boza Vlado and You Jiacheng and 295 | Franz Cesista and Braden Koszarsky and @Grad62304977}, 296 | title = {modded-nanogpt: Speedrunning the NanoGPT baseline}, 297 | year = {2024}, 298 | url = {https://github.com/KellerJordan/modded-nanogpt} 299 | } 300 | ``` 301 | 302 | itsover_wereback 303 | 304 | -------------------------------------------------------------------------------- /data/cached_fineweb100B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from huggingface_hub import hf_hub_download 4 | # Download the GPT-2 tokens of Fineweb100B from huggingface. This 5 | # saves about an hour of startup time compared to regenerating them. 
6 | def get(fname): 7 | local_dir = os.path.join(os.path.dirname(__file__), 'fineweb100B') 8 | if not os.path.exists(os.path.join(local_dir, fname)): 9 | hf_hub_download(repo_id="kjj0/fineweb100B-gpt2", filename=fname, 10 | repo_type="dataset", local_dir=local_dir) 11 | get("fineweb_val_%06d.bin" % 0) 12 | num_chunks = 1030 # full fineweb100B. Each chunk is 100M tokens 13 | if len(sys.argv) >= 2: # we can pass an argument to download less 14 | num_chunks = int(sys.argv[1]) 15 | for i in range(1, num_chunks+1): 16 | get("fineweb_train_%06d.bin" % i) 17 | -------------------------------------------------------------------------------- /data/cached_fineweb10B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from huggingface_hub import hf_hub_download 4 | # Download the GPT-2 tokens of Fineweb10B from huggingface. This 5 | # saves about an hour of startup time compared to regenerating them. 6 | def get(fname): 7 | local_dir = os.path.join(os.path.dirname(__file__), 'fineweb10B') 8 | if not os.path.exists(os.path.join(local_dir, fname)): 9 | hf_hub_download(repo_id="kjj0/fineweb10B-gpt2", filename=fname, 10 | repo_type="dataset", local_dir=local_dir) 11 | get("fineweb_val_%06d.bin" % 0) 12 | num_chunks = 103 # full fineweb10B. Each chunk is 100M tokens 13 | if len(sys.argv) >= 2: # we can pass an argument to download less 14 | num_chunks = int(sys.argv[1]) 15 | for i in range(1, num_chunks+1): 16 | get("fineweb_train_%06d.bin" % i) 17 | -------------------------------------------------------------------------------- /data/cached_finewebedu10B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from huggingface_hub import hf_hub_download 4 | # Download the GPT-2 tokens of FinewebEDU10B from huggingface. This 5 | # saves about an hour of startup time compared to regenerating them. 6 | def get(fname): 7 | local_dir = os.path.join(os.path.dirname(__file__), 'finewebedu10B') 8 | if not os.path.exists(os.path.join(local_dir, fname)): 9 | hf_hub_download(repo_id="kjj0/finewebedu10B-gpt2", filename=fname, 10 | repo_type="dataset", local_dir=local_dir) 11 | get("finewebedu_val_%06d.bin" % 0) 12 | num_chunks = 99 # full FinewebEDU10B. 
Each chunk is 100M tokens 13 | if len(sys.argv) >= 2: # we can pass an argument to download less 14 | num_chunks = int(sys.argv[1]) 15 | for i in range(1, num_chunks+1): 16 | get("finewebedu_train_%06d.bin" % i) 17 | -------------------------------------------------------------------------------- /data/fineweb.py: -------------------------------------------------------------------------------- 1 | """ 2 | FineWeb dataset (for srs pretraining) 3 | https://huggingface.co/datasets/HuggingFaceFW/fineweb 4 | 5 | example doc to highlight the structure of the dataset: 6 | { 7 | "text": "Posted by mattsmith on 20th April 2012\nStraight from...", 8 | "id": "", 9 | "dump": "CC-MAIN-2013-20", 10 | "url": "http://nleastchatter.com/philliesphandom/tag/freddy-galvis/", 11 | "date": "2013-05-18T07:24:47Z", 12 | "file_path": "s3://commoncrawl/long.../path.../file.gz", 13 | "language": "en", 14 | "language_score": 0.9185474514961243, 15 | "token_count": 594 16 | } 17 | """ 18 | import os 19 | import argparse 20 | import multiprocessing as mp 21 | import numpy as np 22 | import tiktoken 23 | # from huggingface_hub import snapshot_download 24 | from datasets import load_dataset 25 | from tqdm import tqdm 26 | import argparse 27 | import numpy as np 28 | def write_datafile(filename, toks): 29 | """ 30 | Saves token data as a .bin file, for reading in C. 31 | - First comes a header with 256 int32s 32 | - The tokens follow, each as a uint16 33 | """ 34 | assert len(toks) < 2**31, "token count too large" # ~2.1B tokens 35 | # construct the header 36 | header = np.zeros(256, dtype=np.int32) 37 | header[0] = 20240520 # magic 38 | header[1] = 1 # version 39 | header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16) 40 | # construct the tokens numpy array, if not already 41 | if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16: 42 | # validate that no token exceeds a uint16 43 | maxtok = 2**16 44 | assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16" 45 | toks_np = np.array(toks, dtype=np.uint16) 46 | else: 47 | toks_np = toks 48 | # write to file 49 | print(f"writing {len(toks):,} tokens to {filename}") 50 | with open(filename, "wb") as f: 51 | f.write(header.tobytes()) 52 | f.write(toks_np.tobytes()) 53 | # ------------------------------------------ 54 | 55 | parser = argparse.ArgumentParser(description="FineWeb dataset preprocessing") 56 | parser.add_argument("-v", "--version", type=str, default="10B", help="Which version of fineweb to use 10B|100B") 57 | parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each shard in tokens") 58 | args = parser.parse_args() 59 | 60 | # FineWeb has a few possible subsamples available 61 | assert args.version in ["10B", "100B"], "version must be one of 10B, 100B" 62 | if args.version == "10B": 63 | local_dir = "fineweb10B" 64 | remote_name = "sample-10BT" 65 | elif args.version == "100B": 66 | local_dir = "fineweb100B" 67 | remote_name = "sample-100BT" 68 | 69 | # create the cache the local directory if it doesn't exist yet 70 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir) 71 | os.makedirs(DATA_CACHE_DIR, exist_ok=True) 72 | 73 | # download the dataset 74 | fw = load_dataset("HuggingFaceFW/fineweb", name=remote_name, split="train") 75 | 76 | # init the tokenizer 77 | enc = tiktoken.get_encoding("gpt2") 78 | eot = enc._special_tokens['<|endoftext|>'] # end of text token 79 | def tokenize(doc): 80 | # tokenizes a single document and returns a 
numpy array of uint16 tokens 81 | tokens = [eot] # the special <|endoftext|> token delimits all documents 82 | tokens.extend(enc.encode_ordinary(doc["text"])) 83 | tokens_np = np.array(tokens) 84 | assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16" 85 | tokens_np_uint16 = tokens_np.astype(np.uint16) 86 | return tokens_np_uint16 87 | 88 | # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder) 89 | nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system 90 | with mp.Pool(nprocs) as pool: 91 | shard_index = 0 92 | # preallocate buffer to hold current shard 93 | all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16) 94 | token_count = 0 95 | progress_bar = None 96 | for tokens in pool.imap(tokenize, fw, chunksize=16): 97 | 98 | # is there enough space in the current shard for the new tokens? 99 | if token_count + len(tokens) < args.shard_size: 100 | # simply append tokens to current shard 101 | all_tokens_np[token_count:token_count+len(tokens)] = tokens 102 | token_count += len(tokens) 103 | # update progress bar 104 | if progress_bar is None: 105 | progress_bar = tqdm(total=args.shard_size, unit="tokens", desc=f"Shard {shard_index}") 106 | progress_bar.update(len(tokens)) 107 | else: 108 | # write the current shard and start a new one 109 | split = "val" if shard_index == 0 else "train" 110 | filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin") 111 | # split the document into whatever fits in this shard; the remainder goes to next one 112 | remainder = args.shard_size - token_count 113 | progress_bar.update(remainder) 114 | all_tokens_np[token_count:token_count+remainder] = tokens[:remainder] 115 | write_datafile(filename, all_tokens_np) 116 | shard_index += 1 117 | progress_bar = None 118 | # populate the next shard with the leftovers of the current doc 119 | all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:] 120 | token_count = len(tokens)-remainder 121 | 122 | # write any remaining tokens as the last shard 123 | if token_count != 0: 124 | split = "val" if shard_index == 0 else "train" 125 | filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin") 126 | write_datafile(filename, all_tokens_np[:token_count]) 127 | -------------------------------------------------------------------------------- /data/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | tiktoken 3 | -------------------------------------------------------------------------------- /img/algo_optimizer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/algo_optimizer.png -------------------------------------------------------------------------------- /img/dofa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/dofa.jpg -------------------------------------------------------------------------------- /img/fig_optimizer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/fig_optimizer.png -------------------------------------------------------------------------------- 
/img/fig_tuned_nanogpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/fig_tuned_nanogpt.png -------------------------------------------------------------------------------- /img/nanogpt_speedrun51.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/nanogpt_speedrun51.png -------------------------------------------------------------------------------- /img/nanogpt_speedrun52.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/nanogpt_speedrun52.png -------------------------------------------------------------------------------- /img/nanogpt_speedrun53.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/nanogpt_speedrun53.png -------------------------------------------------------------------------------- /img/nanogpt_speedrun54.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/img/nanogpt_speedrun54.png -------------------------------------------------------------------------------- /records/010425_SoftCap/README.md: -------------------------------------------------------------------------------- 1 | # Softer softcap 2 | 3 | This record, by Braden Koszarsky, increases the degree of logit softcapping, yielding a 7% speedup. 4 | [reproducible log](31d6c427-f1f7-4d8a-91be-a67b5dcd13fd.txt) 5 | 6 | Previously, logits were softcapped (via tanh) to be at most 30. The new record lowers that to 15, 7 | which boosts performance such that the step count can be reduced from 1490 to 1390. 8 | 9 | Lowering the tanh softcap can be understood as a form of extra structure which we are imposing on the network, which improves 10 | performance in the small-scale regime. 
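For concreteness, here is a minimal sketch of the tanh softcapping described above (not the record's exact code; the tensor shapes and the point where it is applied to the LM-head logits are simplified). Before this record the cap was 30; the record lowers it to 15:

```python
import torch

def softcap(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    # Smoothly squashes logits into (-cap, cap): approximately the identity for
    # small values, saturating near +/-cap for large ones. This record lowers
    # cap from 30 to 15, i.e. it imposes a tighter cap on the logits.
    return cap * torch.tanh(logits / cap)
```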
11 | 12 | Running this new record 80 times yielded the following series of val losses: 13 | ``` 14 | accs = [3.2798, 3.2804, 3.2837, 3.2808, 3.2782, 3.2801, 3.283, 3.2825, 3.2777, 3.2769, 3.2834, 3.2832, 3.2753, 15 | 3.2809, 3.2778, 3.2801, 3.2799, 3.2804, 3.2765, 3.2792, 3.2786, 3.2792, 3.2801, 3.2762, 3.2803, 3.2784, 16 | 3.2792, 3.2791, 3.2769, 3.279, 3.2784, 3.2775, 3.283, 3.2785, 3.2753, 3.2805, 3.2766, 3.2766, 3.2781, 17 | 3.2819, 3.2754, 3.2827, 3.2803, 3.2784, 3.2802, 3.2794, 3.2765, 3.278, 3.2782, 3.278, 3.2816, 3.279, 18 | 3.2771, 3.2791, 3.2768, 3.2781, 3.2794, 3.2798, 3.2785, 3.2804, 3.2777, 3.2765, 3.2796, 3.278, 3.2803, 19 | 3.2793, 3.2793, 3.2788, 3.2797, 3.278, 3.2799, 3.2813, 3.2803, 3.2768, 3.2803, 3.2796, 3.28, 3.2796, 20 | 3.2783, 3.278] 21 | 22 | import scipy.stats 23 | print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue) 24 | # p=0.0001 25 | 26 | import torch 27 | print(torch.std_mean(torch.tensor(accs))) 28 | # (tensor(0.0019), tensor(3.2791)) 29 | ``` 30 | 31 | ![](curves_010425.png) 32 | 33 | -------------------------------------------------------------------------------- /records/010425_SoftCap/curves_010425.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/010425_SoftCap/curves_010425.png -------------------------------------------------------------------------------- /records/011325_Fp8LmHead/README.md: -------------------------------------------------------------------------------- 1 | Note: statistical significance was obtained by @YouJiacheng [here](https://x.com/YouJiacheng/status/1878827972519772241). 2 | -------------------------------------------------------------------------------- /records/011625_Sub3Min/README.md: -------------------------------------------------------------------------------- 1 | # Sub-3 minute record 2 | 3 | ## Evidence for <=3.28 mean loss 4 | 5 | ```bash 6 | $ grep "1393/1393 val" * | python -c "import sys; ss = list(sys.stdin); accs = [float(s.split()[1].split(':')[1]) for s in ss]; print(accs); import scipy.stats; mvs = scipy.stats.bayes_mvs(accs); print(mvs[0]); print(mvs[2]); print(f'p={scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue:.4f}')" 7 | [3.276, 3.2785, 3.2796, 3.2788, 3.2789, 3.2768, 3.2775, 3.2784, 3.2767, 3.2792, 3.2807, 3.2801, 3.2805, 3.2777, 3.2789, 3.2799, 3.2786, 3.2776, 3.2791, 3.2808, 3.2776, 3.2786, 3.2774, 3.2832, 3.277, 3.2789, 3.2784, 3.2766, 3.2755, 3.2784, 3.2798, 3.2825] 8 | Mean(statistic=np.float64(3.27869375), minmax=(np.float64(3.2781784751445135), np.float64(3.2792090248554864))) 9 | Std_dev(statistic=np.float64(0.0017621789337662857), minmax=(np.float64(0.0014271074116428265), np.float64(0.002179878373699496))) 10 | p=0.0001 11 | ``` 12 | 13 | ``` 14 | Mean runtime: 179.8 seconds 15 | Stddev: 101ms 16 | ``` 17 | 18 | ## Details on the changes made 19 | 20 | ### Long-Short Sliding Window Attention 21 | 22 | ![](long-short-swa.png) 23 | 24 | This attention mechanism is inspired by the Local-Global Attention introduced by the [Gemma 2](https://arxiv.org/abs/2408.00118) paper (and more recent "hybrid" architectures). But there are two key differences: 25 | 26 | 1. We use [Sliding Window Attention](https://arxiv.org/abs/2004.05150) for both the "global attention" (i.e. "long SWA") and the "local attention" (i.e. "short SWA") parts. 
The difference between the two is that the "long SWA" has double the context length of the "short SWA". 27 | 2. We also **warmup the context length** of both the sliding window attention mechanisms, but **at different rates**. The "long SWA" context length is warmed up at a double the rate compared to the "short SWA". 28 | 29 | We also made a speedrun-specific decision to only use "long SWA" in the first, fifth, and last layers. The first, because we do not want to compress information too early in the network. The last, because the model architecture we use for the speedrun follows a UNet-like structure, and we want the first and the last layers to be symmetric. And finally, the fifth layer, mainly because it is empirically the best choice for the speedrun. 30 | 31 | This would have been very difficult to implement without PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/). 32 | 33 | ```diff 34 | # In GPT.forward... 35 | def dense_to_ordered(dense_mask: torch.Tensor): 36 | num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32) 37 | - indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(torch.int32) 38 | + indices = dense_mask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32) 39 | return num_blocks[None, None].contiguous(), indices[None, None].contiguous() 40 | 41 | - def create_doc_swc_block_mask(sliding_window_num_blocks): 42 | + def create_doc_swc_block_masks(sliding_window_num_blocks: int): 43 | kv_idx = block_idx = torch.arange(total_num_blocks, dtype=torch.int32, device='cuda') 44 | q_idx = block_idx[:, None] 45 | causal_bm = q_idx >= kv_idx 46 | causal_full_bm = q_idx > kv_idx 47 | - window_bm = q_idx - kv_idx < sliding_window_num_blocks 48 | - window_full_bm = window_bm # block-wise sliding window by @YouJiacheng 49 | document_bm = (docs_low[:, None] <= docs_high) & (docs_low <= docs_high[:, None]) 50 | document_full_bm = (docs_low[:, None] == docs_high) & (docs_low == docs_high[:, None]) 51 | - nonzero_bm = causal_bm & window_bm & document_bm 52 | - full_bm = causal_full_bm & window_full_bm & document_full_bm 53 | + nonzero_bm = causal_bm & document_bm 54 | + full_bm = causal_full_bm & document_full_bm 55 | kv_num_blocks, kv_indices = dense_to_ordered(nonzero_bm & ~full_bm) 56 | full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm) 57 | - return BlockMask.from_kv_blocks( 58 | - kv_num_blocks, 59 | - kv_indices, 60 | - full_kv_num_blocks, 61 | - full_kv_indices, 62 | - BLOCK_SIZE=BLOCK_SIZE, 63 | - mask_mod=document_causal, 64 | - ) 65 | + def build_bm(sw_num_blocks: Tensor) -> BlockMask: 66 | + return BlockMask.from_kv_blocks( 67 | + torch.clamp_max(kv_num_blocks, torch.clamp_min(sw_num_blocks - full_kv_num_blocks, 1)), 68 | + kv_indices, 69 | + torch.clamp_max(full_kv_num_blocks, sw_num_blocks - 1), 70 | + full_kv_indices, 71 | + BLOCK_SIZE=BLOCK_SIZE, 72 | + mask_mod=document_causal, 73 | + ) 74 | + return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2) 75 | 76 | - block_mask = create_doc_swc_block_mask(sliding_window_num_blocks) 77 | + long_bm, short_bm = create_doc_swc_block_masks(sliding_window_num_blocks) 78 | ... 
79 | skip_connections = [] 80 | # Encoder pass - process only the first half of the blocks 81 | + block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm] 82 | for i in range(self.num_encoder_layers): 83 | - x = self.blocks[i](x, ve_enc[i], x0, block_mask) 84 | + x = self.blocks[i](x, ve_enc[i], x0, block_masks[i]) 85 | skip_connections.append(x) 86 | # Decoder pass - process the remaining blocks with weighted skip connections 87 | + block_masks.reverse() 88 | for i in range(self.num_decoder_layers): 89 | x = x + self.skip_weights[i] * skip_connections.pop() 90 | - x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_mask) 91 | + x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_masks[i]) 92 | ``` 93 | 94 | ### Attention Scale Modification 95 | 96 | We currently use QK-Normalization to stabilize the attention coefficients. This helps [reduce the wallclock time of the speedrun](https://x.com/kellerjordan0/status/1845865698532450646). However, unlike in larger-scale models such as [ViT-22B](https://arxiv.org/pdf/2302.05442) and [Chameleon](https://arxiv.org/pdf/2405.09818v1), we use a parameter-free RMSNorm instead of the usual LayerNorm with learnable parameters. 97 | 98 | But while the parameter-free RMSNorm is faster and leads to more stable training runs, it also constrains the logit sharpness, and consequently the entropy of the attention coefficients, to be in the same range across different layers. In our setup, this leads to higher attention entropies, which means the model is less "certain" about which tokens to "attend to" during training. This is not problematic early in training, since we also don't want the model to overfit early on, but it becomes problematic later, when we want the model to "focus" on the most important tokens. And the current record is now tight enough for this to be a problem. 99 | 100 | ![](attn-entropy.png) 101 | 102 | To fix this issue, we first tried out (1) RMSNorm with learned channel-wise parameters and (2) a learned scalar "attention scale" parameter, one for each Attention layer. Both approaches allowed us to reduce training steps by 20, with a ~0.5-0.7 ms/step overhead. Overall, the wallclock time reduction was ~2-3 secs. 103 | 104 | Strangely, the models seemed to consistently learn a UNet-like pattern of attention scales. Hardcoding this pattern led to roughly the same results (e.g. `attn_scale(layer_idx) := 0.12 + 0.01 * min(layer_idx, 11 - layer_idx)`). We find this interesting, and it could be a potential area for future research. But for now, we offer no explanation for why this pattern emerges and why it works well, aside from divine intervention. 105 | 106 | ![](attn-scales-pattern.gif) 107 | 108 | We eventually settled on simply setting the attention scale to `0.12` (vs. the default `1.0 / sqrt(head_dim)`) for all layers. This leads to the same 20-step reduction, but with no per-step overhead; an overall speed gain of ~3 secs. 109 | 110 | ```diff 111 | # In CausalSelfAttention.__init__ 112 | + self.attn_scale = 0.12 113 | ...
114 | # In CausalSelfAttention.forward 115 | - y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask) 116 | + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale) 117 | ``` 118 | 119 | For logs on learnable attention scales, see: [README for 01/12/25 record attempt](https://github.com/leloykun/modded-nanogpt/blob/fc--learnable-attn-scale/records/011225_LearnableAttnScale/README.md) 120 | 121 | ### Stacked QKV Weights & Batched Muon Implementation 122 | 123 | This is an implementation/compiler-level optimization that leads to a 1-2 sec speed improvement. The crux is that, with a big enough GPU, doing one massive matmul for the QKV weights is faster than doing three smaller matmuls, one for each of the weights. 124 | 125 | The problem, however, is that Muon performs better on the unmerged QKV weights, primarily because of the massive matmuls in its Newton-Schulz iterations. Our previous implementation therefore kept storing these weights separately, as before, and concatenated them in the forward pass. But this concatenation operation introduced a ~1 sec regression. We finally got rid of this overhead by stacking the QKV weights instead and using a batched implementation of Muon. 126 | 127 | ### Adam `eps=1e-10` fix 128 | 129 | The speedrun is so tight now that even Adam's default epsilon parameter is already causing problems. 130 | 131 | For context, we initialize our LM head as a zero matrix. This leads to small gradients early on in training, which can sometimes be even smaller than Adam's default epsilon, causing training instability and increased validation loss. 132 | 133 | To address this issue, we simply reduced Adam's `eps` from `1e-8` down to `1e-10`. This led to a 0.0014 validation loss improvement with no per-step overhead, thereby allowing us to reduce training steps by 10.
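To see why `eps` matters here, recall that Adam's update is roughly `lr * m_hat / (sqrt(v_hat) + eps)`: once the gradient magnitude is comparable to `eps`, the denominator is dominated by `eps` and the update gets damped. The snippet below is only an illustrative toy, not code from the record; the gradient magnitude and learning rate are made-up placeholders (the betas match the record's Adam settings), and it just compares the effective step size under the old and new `eps`:

```python
# Toy illustration (not from the speedrun code): how Adam's effective step size
# collapses when the gradient magnitude is comparable to eps.
# The gradient value and lr below are made-up placeholders.
import torch

def adam_effective_step(grad, eps, lr=1.0, betas=(0.8, 0.95), steps=100):
    m = torch.zeros_like(grad)
    v = torch.zeros_like(grad)
    for t in range(1, steps + 1):
        m = betas[0] * m + (1 - betas[0]) * grad
        v = betas[1] * v + (1 - betas[1]) * grad * grad
        m_hat = m / (1 - betas[0] ** t)  # bias-corrected first moment
        v_hat = v / (1 - betas[1] ** t)  # bias-corrected second moment
    return (lr * m_hat / (v_hat.sqrt() + eps)).abs().item()

tiny_grad = torch.tensor(3e-9)  # pretend early-training gradient from a zero-init LM head
print(adam_effective_step(tiny_grad, eps=1e-8))   # ~0.23 * lr: update heavily damped by eps
print(adam_effective_step(tiny_grad, eps=1e-10))  # ~0.97 * lr: update keeps its intended size
```

The actual change in the record is just the one-line diff below.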
134 | 135 | ```diff 136 | - optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), fused=True) 137 | + optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), fused=True, eps=1e-10) 138 | ``` 139 | -------------------------------------------------------------------------------- /records/011625_Sub3Min/attn-entropy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/011625_Sub3Min/attn-entropy.png -------------------------------------------------------------------------------- /records/011625_Sub3Min/attn-scales-pattern.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/011625_Sub3Min/attn-scales-pattern.gif -------------------------------------------------------------------------------- /records/011625_Sub3Min/learned-attn-scales.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/011625_Sub3Min/learned-attn-scales.png -------------------------------------------------------------------------------- /records/011625_Sub3Min/long-short-swa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/011625_Sub3Min/long-short-swa.png -------------------------------------------------------------------------------- /records/011825_GPT2Medium/main.log: -------------------------------------------------------------------------------- 1 | s:0 tel:10.9830 2 | s:250 tel:6.3097 3 | s:500 tel:5.5887 4 | s:750 tel:4.9040 5 | s:1000 tel:4.3734 6 | s:1250 tel:4.1307 7 | s:1500 tel:3.9782 8 | s:1750 tel:3.8681 9 | s:2000 tel:3.7884 10 | s:2250 tel:3.7278 11 | s:2500 tel:3.6737 12 | s:2750 tel:3.6324 13 | s:3000 tel:3.6014 14 | s:3250 tel:3.5695 15 | s:3500 tel:3.5400 16 | s:3750 tel:3.5161 17 | s:4000 tel:3.4925 18 | s:4250 tel:3.4754 19 | s:4500 tel:3.4540 20 | s:4750 tel:3.4349 21 | s:5000 tel:3.4190 22 | s:5250 tel:3.4076 23 | s:5500 tel:3.3982 24 | s:5750 tel:3.3842 25 | s:6000 tel:3.3674 26 | s:6250 tel:3.3589 27 | s:6500 tel:3.3493 28 | s:6750 tel:3.3442 29 | s:7000 tel:3.3309 30 | s:7250 tel:3.3207 31 | s:7500 tel:3.3110 32 | s:7750 tel:3.3055 33 | s:8000 tel:3.2969 34 | s:8250 tel:3.2885 35 | s:8500 tel:3.2813 36 | s:8750 tel:3.2780 37 | s:9000 tel:3.2689 38 | s:9250 tel:3.2654 39 | s:9500 tel:3.2574 40 | s:9750 tel:3.2507 41 | s:10000 tel:3.2461 42 | s:10250 tel:3.2401 43 | s:10500 tel:3.2369 44 | s:10750 tel:3.2297 45 | s:11000 tel:3.2247 46 | s:11250 tel:3.2212 47 | s:11500 tel:3.2165 48 | s:11750 tel:3.2135 49 | s:12000 tel:3.2063 50 | s:12250 tel:3.2048 51 | s:12500 tel:3.1993 52 | s:12750 tel:3.1969 53 | s:13000 tel:3.1934 54 | s:13250 tel:3.1887 55 | s:13500 tel:3.1873 56 | s:13750 tel:3.1823 57 | s:14000 tel:3.1819 58 | s:14250 tel:3.1758 59 | s:14500 tel:3.1715 60 | s:14750 tel:3.1668 61 | s:15000 tel:3.1648 62 | s:15250 tel:3.1618 63 | s:15500 tel:3.1589 64 | s:15750 tel:3.1562 65 | s:16000 tel:3.1533 66 | s:16250 tel:3.1508 67 | s:16500 tel:3.1466 68 | s:16750 tel:3.1450 69 | s:17000 tel:3.1448 70 | s:17250 tel:3.1398 71 | s:17500 tel:3.1355 72 | s:17750 tel:3.1342 73 | s:18000 tel:3.1311 74 | s:18250 tel:3.1281 75 
| s:18500 tel:3.1258 76 | s:18750 tel:3.1229 77 | s:19000 tel:3.1240 78 | s:19250 tel:3.1195 79 | s:19500 tel:3.1167 80 | s:19750 tel:3.1144 81 | s:20000 tel:3.1129 82 | s:20250 tel:3.1115 83 | s:20500 tel:3.1091 84 | s:20750 tel:3.1074 85 | s:21000 tel:3.1037 86 | s:21250 tel:3.1012 87 | s:21500 tel:3.1006 88 | s:21750 tel:3.0981 89 | s:22000 tel:3.0951 90 | s:22250 tel:3.0938 91 | s:22500 tel:3.0920 92 | s:22750 tel:3.0897 93 | s:23000 tel:3.0888 94 | s:23250 tel:3.0845 95 | s:23500 tel:3.0850 96 | s:23750 tel:3.0812 97 | s:24000 tel:3.0794 98 | s:24250 tel:3.0773 99 | s:24500 tel:3.0755 100 | s:24750 tel:3.0758 101 | s:25000 tel:3.0728 102 | s:25250 tel:3.0708 103 | s:25500 tel:3.0677 104 | s:25750 tel:3.0676 105 | s:26000 tel:3.0654 106 | s:26250 tel:3.0631 107 | s:26500 tel:3.0604 108 | s:26750 tel:3.0589 109 | s:27000 tel:3.0587 110 | s:27250 tel:3.0572 111 | s:27500 tel:3.0553 112 | s:27750 tel:3.0534 113 | s:28000 tel:3.0525 114 | s:28250 tel:3.0501 115 | s:28500 tel:3.0486 116 | s:28750 tel:3.0462 117 | s:29000 tel:3.0456 118 | s:29250 tel:3.0437 119 | s:29500 tel:3.0406 120 | s:29750 tel:3.0409 121 | s:30000 tel:3.0387 122 | s:30250 tel:3.0370 123 | s:30500 tel:3.0369 124 | s:30750 tel:3.0334 125 | s:31000 tel:3.0320 126 | s:31250 tel:3.0306 127 | s:31500 tel:3.0289 128 | s:31750 tel:3.0280 129 | s:32000 tel:3.0252 130 | s:32250 tel:3.0259 131 | s:32500 tel:3.0239 132 | s:32750 tel:3.0227 133 | s:33000 tel:3.0194 134 | s:33250 tel:3.0189 135 | s:33500 tel:3.0168 136 | s:33750 tel:3.0168 137 | s:34000 tel:3.0138 138 | s:34250 tel:3.0125 139 | s:34500 tel:3.0116 140 | s:34750 tel:3.0100 141 | s:35000 tel:3.0082 142 | s:35250 tel:3.0075 143 | s:35500 tel:3.0051 144 | s:35750 tel:3.0037 145 | s:36000 tel:3.0026 146 | s:36250 tel:3.0015 147 | s:36500 tel:3.0000 148 | s:36750 tel:2.9987 149 | s:37000 tel:2.9974 150 | s:37250 tel:2.9954 151 | s:37500 tel:2.9938 152 | s:37750 tel:2.9927 153 | s:38000 tel:2.9911 154 | s:38250 tel:2.9901 155 | s:38500 tel:2.9890 156 | s:38750 tel:2.9871 157 | s:39000 tel:2.9865 158 | s:39250 tel:2.9847 159 | s:39500 tel:2.9833 160 | s:39750 tel:2.9818 161 | s:40000 tel:2.9812 162 | s:40250 tel:2.9798 163 | s:40500 tel:2.9781 164 | s:40750 tel:2.9772 165 | s:41000 tel:2.9762 166 | s:41250 tel:2.9749 167 | s:41500 tel:2.9734 168 | s:41750 tel:2.9724 169 | s:42000 tel:2.9717 170 | s:42250 tel:2.9702 171 | s:42500 tel:2.9685 172 | s:42750 tel:2.9681 173 | s:43000 tel:2.9667 174 | s:43250 tel:2.9651 175 | s:43500 tel:2.9641 176 | s:43750 tel:2.9633 177 | s:44000 tel:2.9638 178 | s:44250 tel:2.9612 179 | s:44500 tel:2.9599 180 | s:44750 tel:2.9592 181 | s:45000 tel:2.9581 182 | s:45250 tel:2.9569 183 | s:45500 tel:2.9563 184 | s:45750 tel:2.9549 185 | s:46000 tel:2.9541 186 | s:46250 tel:2.9530 187 | s:46500 tel:2.9520 188 | s:46750 tel:2.9515 189 | s:47000 tel:2.9504 190 | s:47250 tel:2.9494 191 | s:47500 tel:2.9485 192 | s:47750 tel:2.9475 193 | s:48000 tel:2.9467 194 | s:48250 tel:2.9459 195 | s:48500 tel:2.9451 196 | s:48750 tel:2.9440 197 | s:49000 tel:2.9433 198 | s:49250 tel:2.9428 199 | s:49500 tel:2.9419 200 | s:49750 tel:2.9413 201 | s:50000 tel:2.9405 202 | s:50250 tel:2.9399 203 | s:50500 tel:2.9394 204 | s:50750 tel:2.9388 205 | s:51000 tel:2.9379 206 | s:51250 tel:2.9374 207 | s:51500 tel:2.9367 208 | s:51750 tel:2.9361 209 | s:52000 tel:2.9357 210 | s:52250 tel:2.9350 211 | s:52500 tel:2.9346 212 | s:52750 tel:2.9341 213 | s:53000 tel:2.9336 214 | s:53250 tel:2.9332 215 | s:53500 tel:2.9328 216 | s:53750 tel:2.9324 217 | s:54000 tel:2.9320 218 | 
s:54250 tel:2.9317 219 | s:54500 tel:2.9314 220 | s:54750 tel:2.9309 221 | s:55000 tel:2.9306 222 | s:55250 tel:2.9303 223 | s:55500 tel:2.9301 224 | s:55750 tel:2.9299 225 | s:56000 tel:2.9296 226 | s:56250 tel:2.9294 227 | s:56500 tel:2.9292 228 | s:56750 tel:2.9290 229 | s:57000 tel:2.9289 230 | s:57250 tel:2.9287 231 | s:57500 tel:2.9286 232 | s:57750 tel:2.9285 233 | s:58000 tel:2.9284 234 | s:58250 tel:2.9283 235 | s:58500 tel:2.9283 236 | s:58750 tel:2.9282 237 | s:59000 tel:2.9282 238 | s:59250 tel:2.9282 239 | s:59500 tel:2.9282 240 | s:59750 tel:2.9281 241 | s:60000 tel:2.9282 242 | -------------------------------------------------------------------------------- /records/012625_BatchSize/README.md: -------------------------------------------------------------------------------- 1 | # 11/26/25 - Misc Tweaks 2 | 3 | Changelogs: 4 | 5 | 1. Reduced per-device training sequence length from `64*1024` to `48*1024`. See [Critical Batch Size](https://arxiv.org/abs/2410.21676) literature. 6 | 2. Increased per-device eval sequence length from `64*1024` to `4*64*1024`. This improves `val_loss` by `~0.0015` or an equivalent of a reduction of 10 training steps. Overall it saves `~1 sec` of training time. 7 | 3. Modified scales for `fp8` training of LM Head. Saves `1 sec` and improves `val_loss` by as much as `~0.01` after reducing training sequence length down to `48*1024`. I don't know wtf is causing this and I'm NOT going crazy about this. I have evidence. See `records/012625_MiscTweaks/no-autocast-same-fp8-scales`. 8 | - `w_s = 2.0**9` (from `2.0**5`) 9 | - `grad_s = 2.0**19` (from `2.0**29`) 10 | 4. Upgraded PyTorch to 2.7.0 nightly version (20250125) for CUDA 12.6 11 | - `pip install --pre torch==2.7.0.dev20250125+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126` 12 | 13 | ![](val_losses.png) 14 | ![](wallclock.png) 15 | 16 | ```python 17 | accs = [3.2806, 3.2771, 3.2829, 3.2813, 3.2789, 3.2774, 3.2798, 3.2759, 3.2794, 18 | 3.2775, 3.2768, 3.2793, 3.2838, 3.2779, 3.2782, 3.2770, 3.2775, 3.2784, 19 | 3.2782, 3.2776, 3.2814, 3.2785, 3.2793, 3.2797, 3.2782, 3.2789, 3.2759, 20 | 3.2803, 3.2780, 3.2782, 3.2744, 3.2819, 3.2801, 3.2782, 3.2771, 3.2782, 21 | 3.2792, 3.2778, 3.2774, 3.2798, 3.2799, 3.2768, 3.2814, 3.2816, 3.2785, 22 | 3.2817, 3.2801, 3.2755, 3.2780, 3.2774, 3.2797, 3.2789, 3.2843, 3.2777, 23 | 3.2777, 3.2768, 3.2763, 3.2773, 3.2792, 3.2819, 3.2778, 3.2792, 3.2782, 24 | 3.2776, 3.2752, 3.2792, 3.2786, 3.2793, 3.2773, 3.2804, 3.2802, 3.2779, 25 | 3.2780, 3.2779, 3.2801, 3.2773, 3.2802, 3.2770, 3.2785, 3.2772, 3.2818] 26 | 27 | import scipy.stats 28 | print(f'p={scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue:.8f}') 29 | # p=0.00000002 (statistically significant) 30 | 31 | import torch 32 | print(torch.std_mean(torch.tensor(accs))) 33 | # (tensor(0.0019), tensor(3.2787)) 34 | ``` 35 | 36 | --- 37 | 38 | ![](ablations.png) 39 | -------------------------------------------------------------------------------- /records/012625_BatchSize/ablations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/012625_BatchSize/ablations.png -------------------------------------------------------------------------------- /records/012625_BatchSize/val_losses.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/012625_BatchSize/val_losses.png -------------------------------------------------------------------------------- /records/012625_BatchSize/wallclock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/012625_BatchSize/wallclock.png -------------------------------------------------------------------------------- /records/030625_GPT2MediumLongerCooldown/README.md: -------------------------------------------------------------------------------- 1 | Changelog: 2 | - Increased learning rate cooldown phase duration from 40% to 60% of the total train duration 3 | - Decreased steps from 7050 to 6950 4 | 5 | Record author: @YouJiacheng 6 | 7 | Results from 36 runs: (P(<2.92) > 99.9%) 8 | ``` 9 | [2.9199, 2.9185, 2.9195, 2.9194, 2.9206, 2.9209, 2.9188, 2.9193, 2.9207, 2.9181, 2.9186, 2.9196, 2.9202, 2.9174, 2.9185, 2.9197, 2.9179, 2.9204, 2.9184, 2.9186, 2.9178, 2.9192, 2.9194, 2.9194, 2.9189, 2.9193, 2.9212, 2.9181, 2.9192, 2.9203, 2.9198, 2.9192, 2.919, 2.9196, 2.9182, 2.9186] 10 | ``` 11 | 12 | 13 | -------------------------------------------------------------------------------- /records/041625_GPT2Medium_Record7/README.md: -------------------------------------------------------------------------------- 1 | 1. remove FP8 lm head, 6710→6450 steps, ~0 wall clock change 2 | 2. logits softcap tanh→ISRU, ~1.5ms/step faster, slightly better (2.91970 -> 2.91944) 3 | 3. New sharded mixed precision Muon, remove CastedLinear, ~3.6ms/step faster 4 | 4. merge qkv&o weight, ~0.8ms/step faster 5 | 5. merge scalars weight, ~0.6ms/step faster 6 | -------------------------------------------------------------------------------- /records/060624_AdamW/README.md: -------------------------------------------------------------------------------- 1 | This is the log for my baseline AdamW training to which I compared the new Muon and SOAP optimizers. 2 | 3 | just the log, which is in the old llm.c format ("tel" lines are val loss) 4 | 5 | this was batch size 2^19, so ~5B tokens 6 | 7 | was learning rate 0.0018, warmup=250, warmdown=2000, betas=(0.9, 0.95) IIRC 8 | 9 | -------------------------------------------------------------------------------- /records/100924_SOAP/README.md: -------------------------------------------------------------------------------- 1 | # SOAP record October 9 2024 2 | 3 | * New sample efficiency record: <3.28 validation loss in 3.15B tokens 4 | * Uses SOAP optimizer ([Vyas et al. 2024](https://arxiv.org/abs/2409.11321)) 5 | * 363ms/step - not a new wallclock record (SOAP is in active development to reduce the wallclock overhead for distributed training, so this may change) 6 | * Set by Nikhil Vyas @vyasnikhil96. 
Hyperparameters also tuned slightly by me 7 | * [https://x.com/vyasnikhil96/status/1842656792217858063](https://x.com/vyasnikhil96/status/1842656792217858063) 8 | * [https://github.com/nikhilvyas/modded-nanogpt-SOAP/tree/master](https://github.com/nikhilvyas/modded-nanogpt-SOAP/tree/master) 9 | 10 | -------------------------------------------------------------------------------- /records/101024_Muon/train_gpt2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | with open(sys.argv[0]) as f: 4 | code = f.read() # read the code of this file ASAP, for logging 5 | import uuid 6 | import glob 7 | import time 8 | from dataclasses import dataclass 9 | 10 | import numpy as np 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | import torch.distributed as dist 15 | import torch._inductor.config as config 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Muon optimizer 20 | 21 | def zeropower_via_svd(G, steps=None): 22 | U, S, V = G.svd() 23 | return U @ V.T 24 | 25 | @torch.compile 26 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 27 | """ 28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 32 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 33 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model 34 | performance at all relative to UV^T, where USV^T = G is the SVD. 35 | """ 36 | assert len(G.shape) == 2 37 | a, b, c = (3.4445, -4.7750, 2.0315) 38 | X = G.bfloat16() / (G.norm() + eps) # ensure top singular value <= 1 39 | if G.size(0) > G.size(1): 40 | X = X.T 41 | for _ in range(steps): 42 | A = X @ X.T 43 | B = A @ X 44 | X = a * X + b * B + c * A @ B 45 | if G.size(0) > G.size(1): 46 | X = X.T 47 | return X.to(G.dtype) 48 | 49 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5) 50 | 51 | class Muon(torch.optim.Optimizer): 52 | """ 53 | Muon: MomentUm Orthogonalized by Newton-schulz 54 | 55 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 56 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 57 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 58 | the advantage that it can be stably run in bfloat16 on the GPU. 59 | 60 | Some warnings: 61 | - This optimizer assumes that all parameters passed in are 2D. 62 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D 63 | parameters; those should all be optimized by a standard method (e.g., AdamW). 64 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 65 | - We believe it is unlikely to work well for training with small batch size. 66 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 67 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). 
68 | 69 | Arguments: 70 | lr: The learning rate used by the internal SGD. 71 | momentum: The momentum used by the internal SGD. 72 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 73 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') 74 | backend_steps: The number of iteration steps to use in the backend, if it is iterative. 75 | """ 76 | def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5): 77 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps) 78 | super().__init__(params, defaults) 79 | 80 | def step(self): 81 | for group in self.param_groups: 82 | lr = group['lr'] 83 | momentum = group['momentum'] 84 | zeropower_backend = zeropower_backends[group['backend']] 85 | for p in group['params']: 86 | g = p.grad 87 | if g is None: 88 | continue 89 | state = self.state[p] 90 | if 'momentum_buffer' not in state: 91 | state['momentum_buffer'] = torch.zeros_like(g) 92 | buf = state['momentum_buffer'] 93 | buf.mul_(momentum).add_(g) 94 | if group['nesterov']: 95 | g = g.add(buf, alpha=momentum) 96 | if g.size(0) == 3 * g.size(1): # split grouped QKV parameters 97 | g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))]) 98 | scale = g.size(1)**0.5 99 | else: 100 | g = zeropower_backend(g, steps=group['backend_steps']) 101 | scale = max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1 102 | p.data.add_(g, alpha=-lr * scale) 103 | 104 | # ----------------------------------------------------------------------------- 105 | # PyTorch nn.Module definitions for the GPT-2 model 106 | 107 | class Rotary(torch.nn.Module): 108 | 109 | def __init__(self, dim, base=10000): 110 | super().__init__() 111 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 112 | self.register_buffer("inv_freq", inv_freq) 113 | self.seq_len_cached = None 114 | self.cos_cached = None 115 | self.sin_cached = None 116 | 117 | def forward(self, x): 118 | seq_len = x.shape[1] 119 | if seq_len != self.seq_len_cached: 120 | self.seq_len_cached = seq_len 121 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) 122 | freqs = torch.outer(t, self.inv_freq).to(x.device) 123 | self.cos_cached = freqs.cos() 124 | self.sin_cached = freqs.sin() 125 | return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :] 126 | 127 | def apply_rotary_emb(x, cos, sin): 128 | assert x.ndim == 4 # multihead attention 129 | d = x.shape[3]//2 130 | x1 = x[..., :d] 131 | x2 = x[..., d:] 132 | y1 = x1 * cos + x2 * sin 133 | y2 = x1 * (-sin) + x2 * cos 134 | return torch.cat([y1, y2], 3) 135 | 136 | def rmsnorm(x0, eps=1e-6): 137 | x = x0.float() 138 | x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 139 | return x.type_as(x0) 140 | 141 | class CausalSelfAttention(nn.Module): 142 | 143 | def __init__(self, config): 144 | super().__init__() 145 | self.n_head = config.n_head 146 | self.n_embd = config.n_embd 147 | self.head_dim = self.n_embd // self.n_head 148 | assert self.n_embd % self.n_head == 0 149 | # key, query, value projections for all heads, but in a batch 150 | self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False) 151 | # output projection 152 | self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) 153 | self.rotary = Rotary(self.head_dim) 154 | 155 | def forward(self, x): 156 | B, T, C = x.size() # batch size, sequence length, 
embedding dimensionality (n_embd) 157 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 158 | qkv = self.c_attn(x) 159 | q, k, v = qkv.split(self.n_embd, dim=2) 160 | k = k.view(B, T, self.n_head, self.head_dim) 161 | q = q.view(B, T, self.n_head, self.head_dim) 162 | v = v.view(B, T, self.n_head, self.head_dim) 163 | cos, sin = self.rotary(q) 164 | q = apply_rotary_emb(q, cos, sin) 165 | k = apply_rotary_emb(k, cos, sin) 166 | y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True) 167 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 168 | # output projection 169 | y = self.c_proj(y) 170 | return y 171 | 172 | class MLP(nn.Module): 173 | 174 | def __init__(self, config): 175 | super().__init__() 176 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) 177 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) 178 | 179 | def forward(self, x): 180 | x = self.c_fc(x) 181 | x = F.gelu(x) 182 | x = self.c_proj(x) 183 | return x 184 | 185 | class Block(nn.Module): 186 | 187 | def __init__(self, config): 188 | super().__init__() 189 | self.attn = CausalSelfAttention(config) 190 | self.mlp = MLP(config) 191 | self.attn_scale = (1 / (2 * config.n_layer)**0.5) 192 | 193 | def forward(self, x): 194 | x = x + self.attn_scale * self.attn(rmsnorm(x)) 195 | x = x + self.mlp(rmsnorm(x)) 196 | return x 197 | 198 | # ----------------------------------------------------------------------------- 199 | # The main GPT-2 model 200 | 201 | @dataclass 202 | class GPTConfig: 203 | vocab_size : int = 50257 204 | n_layer : int = 12 205 | n_head : int = 12 206 | n_embd : int = 768 207 | 208 | class GPT(nn.Module): 209 | 210 | def __init__(self, config): 211 | super().__init__() 212 | self.config = config 213 | 214 | self.transformer = nn.ModuleDict(dict( 215 | wte = nn.Embedding(config.vocab_size, config.n_embd), 216 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 217 | )) 218 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 219 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 220 | 221 | def forward(self, idx, targets=None, return_logits=True): 222 | b, t = idx.size() 223 | pos = torch.arange(0, t, dtype=torch.long, device=idx.device) # shape (t) 224 | 225 | # forward the GPT model itself 226 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 227 | 228 | for block in self.transformer.h: 229 | x = block(x) 230 | x = rmsnorm(x) 231 | 232 | if targets is not None: 233 | # if we are given some desired targets also calculate the loss 234 | logits = self.lm_head(x) 235 | logits = logits.float() # use tf32/fp32 for logits 236 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 237 | else: 238 | # inference-time mini-optimization: only forward the lm_head on the very last position 239 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 240 | logits = logits.float() # use tf32/fp32 for logits 241 | loss = None 242 | 243 | # there are performance reasons why not returning logits is prudent, if not needed 244 | if not return_logits: 245 | logits = None 246 | 247 | return logits, loss 248 | 249 | # ----------------------------------------------------------------------------- 250 | # Our own simple Distributed Data Loader 251 | 252 | def 
_peek_data_shard(filename): 253 | # only reads the header, returns header data 254 | with open(filename, "rb") as f: 255 | # first read the header, which is 256 int32 integers (4 bytes each) 256 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 257 | if header[0] != 20240520: 258 | print("ERROR: magic number mismatch in the data .bin file!") 259 | print("---> HINT: Are you passing in a correct file with --input_bin?") 260 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README") 261 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try") 262 | exit(1) 263 | assert header[1] == 1, "unsupported version" 264 | ntok = header[2] # number of tokens (claimed) 265 | return ntok # for now just return the number of tokens 266 | 267 | def _load_data_shard(filename): 268 | with open(filename, "rb") as f: 269 | # first read the header, which is 256 int32 integers (4 bytes each) 270 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 271 | assert header[0] == 20240520, "magic number mismatch in the data .bin file" 272 | assert header[1] == 1, "unsupported version" 273 | ntok = header[2] # number of tokens (claimed) 274 | # the rest of it are tokens, stored as uint16 275 | tokens = np.frombuffer(f.read(), dtype=np.uint16) 276 | assert len(tokens) == ntok, "number of tokens read does not match header?" 277 | return tokens 278 | 279 | class DistributedDataLoader: 280 | def __init__(self, filename_pattern, B, T, process_rank, num_processes): 281 | self.process_rank = process_rank 282 | self.num_processes = num_processes 283 | self.B = B 284 | self.T = T 285 | 286 | # glob files that match the pattern 287 | self.files = sorted(glob.glob(filename_pattern)) 288 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}" 289 | 290 | # load and validate all data shards, count number of tokens in total 291 | ntok_total = 0 292 | for fname in self.files: 293 | shard_ntok = _peek_data_shard(fname) 294 | assert shard_ntok >= num_processes * B * T + 1 295 | ntok_total += int(shard_ntok) 296 | self.ntok_total = ntok_total 297 | 298 | # kick things off 299 | self.reset() 300 | 301 | def reset(self): 302 | self.current_shard = 0 303 | self.current_position = self.process_rank * self.B * self.T 304 | self.tokens = _load_data_shard(self.files[self.current_shard]) 305 | 306 | def advance(self): # advance to next data shard 307 | self.current_shard = (self.current_shard + 1) % len(self.files) 308 | self.current_position = self.process_rank * self.B * self.T 309 | self.tokens = _load_data_shard(self.files[self.current_shard]) 310 | 311 | def next_batch(self): 312 | B = self.B 313 | T = self.T 314 | buf = self.tokens[self.current_position : self.current_position+B*T+1] 315 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long) 316 | x = (buf[:-1]).view(B, T) # inputs 317 | y = (buf[1:]).view(B, T) # targets 318 | # advance current position and load next shard if necessary 319 | self.current_position += B * T * self.num_processes 320 | if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens): 321 | self.advance() 322 | return x.cuda(), y.cuda() 323 | 324 | # ----------------------------------------------------------------------------- 325 | # int main 326 | 327 | @dataclass 328 | class Hyperparameters: 329 | # data hyperparams 330 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on 331 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' 
# input .bin to eval validation loss on 332 | # optimization hyperparams 333 | batch_size : int = 8*64 # batch size, in sequences, across all devices 334 | device_batch_size : int = 64 # batch size, in sequences, per device 335 | sequence_length : int = 1024 # sequence length, in tokens 336 | num_iterations : int = 6200 # number of iterations to run 337 | learning_rate : float = 0.0036 338 | warmup_iters : int = 0 339 | warmdown_iters : int = 1800 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule 340 | weight_decay : float = 0 341 | # evaluation and logging hyperparams 342 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end 343 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons 344 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end 345 | args = Hyperparameters() 346 | 347 | # set up DDP (distributed data parallel). torchrun sets this env variable 348 | assert torch.cuda.is_available() 349 | dist.init_process_group(backend='nccl') 350 | ddp_rank = int(os.environ['RANK']) 351 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 352 | ddp_world_size = int(os.environ['WORLD_SIZE']) 353 | device = f'cuda:{ddp_local_rank}' 354 | torch.cuda.set_device(device) 355 | print(f"using device: {device}") 356 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc. 357 | 358 | # convenience variables 359 | B, T = args.device_batch_size, args.sequence_length 360 | # calculate the number of steps to take in the val loop. 361 | assert args.val_tokens % (B * T * ddp_world_size) == 0 362 | val_steps = args.val_tokens // (B * T * ddp_world_size) 363 | # calculate the steps of gradient accumulation required to attain the desired global batch size. 
364 | assert args.batch_size % (B * ddp_world_size) == 0 365 | train_accumulation_steps = args.batch_size // (B * ddp_world_size) 366 | 367 | # load tokens 368 | train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size) 369 | val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size) 370 | if master_process: 371 | print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files") 372 | print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files") 373 | x, y = train_loader.next_batch() 374 | 375 | # init the model from scratch 376 | num_vocab = 50257 377 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=12, n_embd=768)) 378 | model = model.cuda() 379 | if hasattr(config, "coordinate_descent_tuning"): 380 | config.coordinate_descent_tuning = True # suggested by @Chillee 381 | model = torch.compile(model) 382 | # here we wrap model into DDP container 383 | model = DDP(model, device_ids=[ddp_local_rank]) 384 | raw_model = model.module # always contains the "raw" unwrapped model 385 | ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16) 386 | 387 | # init the optimizer(s) 388 | optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95), 389 | weight_decay=args.weight_decay, fused=True) 390 | optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95) 391 | optimizers = [optimizer1, optimizer2] 392 | # learning rate decay scheduler (linear warmup and warmdown) 393 | def get_lr(it): 394 | assert it <= args.num_iterations 395 | # 1) linear warmup for warmup_iters steps 396 | if it < args.warmup_iters: 397 | return (it+1) / args.warmup_iters 398 | # 2) constant lr for a while 399 | elif it < args.num_iterations - args.warmdown_iters: 400 | return 1.0 401 | # 3) linear warmdown 402 | else: 403 | decay_ratio = (args.num_iterations - it) / args.warmdown_iters 404 | return decay_ratio 405 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] 406 | 407 | # begin logging 408 | if master_process: 409 | run_id = str(uuid.uuid4()) 410 | logdir = 'logs/%s/' % run_id 411 | os.makedirs(logdir, exist_ok=True) 412 | logfile = 'logs/%s.txt' % run_id 413 | # create the log file 414 | with open(logfile, "w") as f: 415 | # begin the log by printing this file (the Python code) 416 | f.write('='*100 + '\n') 417 | f.write(code) 418 | f.write('='*100 + '\n') 419 | # log information about the hardware/software environment this is running on 420 | # and print the full `nvidia-smi` to file 421 | f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n") 422 | import subprocess 423 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 424 | f.write(f'{result.stdout}\n') 425 | f.write('='*100 + '\n') 426 | 427 | training_time_ms = 0 428 | # start the clock 429 | torch.cuda.synchronize() 430 | t0 = time.time() 431 | # begin training 432 | train_loader.reset() 433 | for step in range(args.num_iterations + 1): 434 | last_step = (step == args.num_iterations) 435 | # This effectively ignores timing first 10 steps, which are slower for weird reasons. 436 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10 437 | # steps with dummy data first, and then re-initialize the model and reset the loader. 
438 | if step == 10: 439 | training_time_ms = 0 440 | t0 = time.time() 441 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val 442 | 443 | # once in a while evaluate the validation dataset 444 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)): 445 | # stop the clock 446 | torch.cuda.synchronize() 447 | training_time_ms += 1000 * (time.time() - t0) 448 | # run validation batches 449 | model.eval() 450 | val_loader.reset() 451 | val_loss = 0.0 452 | for _ in range(val_steps): 453 | x_val, y_val = val_loader.next_batch() 454 | with torch.no_grad(): # of course, we'd like to use ctx here too, but that creates a torch.compile error for some reason 455 | _, loss = model(x_val, y_val, return_logits=False) 456 | val_loss += loss 457 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) 458 | val_loss /= val_steps 459 | # log val loss to console and to logfile 460 | if master_process: 461 | print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms') 462 | with open(logfile, "a") as f: 463 | f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n') 464 | # start the clock again 465 | torch.cuda.synchronize() 466 | t0 = time.time() 467 | 468 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)): 469 | # stop the clock 470 | torch.cuda.synchronize() 471 | training_time_ms += 1000 * (time.time() - t0) 472 | # save the state of the training process 473 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) 474 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step)) 475 | # start the clock again 476 | torch.cuda.synchronize() 477 | t0 = time.time() 478 | 479 | # bit confusing: we want to make sure to eval on 0th iteration 480 | # but also after the very last iteration. so we loop for step <= num_iterations 481 | # instead of just < num_iterations (one extra due to <=), only to do 482 | # the validation/sampling one last time, and then we break right here as we're done. 483 | if last_step: 484 | break 485 | 486 | # --------------- TRAINING SECTION BEGIN ----------------- 487 | model.train() 488 | for i in range(1, train_accumulation_steps+1): 489 | # forward pass 490 | with ctx: 491 | _, loss = model(x, y, return_logits=False) 492 | train_loss = loss.detach() 493 | # advance the dataset for the next batch 494 | x, y = train_loader.next_batch() 495 | # backward pass 496 | if i < train_accumulation_steps: 497 | with model.no_sync(): # there's no need to sync gradients every accumulation step 498 | loss.backward() 499 | else: 500 | loss.backward() # just sync on the last step 501 | for p in model.parameters(): 502 | p.grad /= train_accumulation_steps 503 | # step the optimizers and schedulers 504 | for opt, sched in zip(optimizers, schedulers): 505 | opt.step() 506 | sched.step() 507 | # null the gradients 508 | model.zero_grad(set_to_none=True) 509 | # --------------- TRAINING SECTION END ------------------- 510 | # everything that follows now is just diagnostics, prints, logging, etc. 
511 | 512 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower 513 | if master_process: 514 | approx_time = training_time_ms + 1000 * (time.time() - t0) 515 | print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms") 516 | with open(logfile, "a") as f: 517 | f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n") 518 | 519 | if master_process: 520 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") 521 | 522 | # ------------------------------------------------------------------------- 523 | # clean up nice 524 | dist.destroy_process_group() 525 | -------------------------------------------------------------------------------- /records/101324_llmc/README.md: -------------------------------------------------------------------------------- 1 | This is a log produced by running the current version of Andrej Karpathy's [llm.c](https://github.com/karpathy/llm.c), as of October 13th 2024. 2 | 3 | It was run on a node with 8x H100 HBM3 according to the instructions [here](https://github.com/karpathy/llm.c/discussions/481). 4 | The mean per-step time was 140ms. The total number of training tokens is 10.26B. The final validation loss was **3.2722**. 5 | 6 | This is (significantly) better than the quoted result of **3.29** val loss in 7 | [Andrej Karpathy's May 28th GPT-2 replication discussion](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29). 8 | So it appears that there have been some improvements to the training algorithm used by llm.c since then. 9 | 10 | Note that the set of examples which llm.c uses for validation appears to be the same as what we do in this repo, i.e., the first `10 * 2**20` tokens of the val set. 11 | 12 | -------------------------------------------------------------------------------- /records/101424_ModernArch/train_gpt2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | with open(sys.argv[0]) as f: 4 | code = f.read() # read the code of this file ASAP, for logging 5 | import uuid 6 | import glob 7 | import time 8 | from dataclasses import dataclass 9 | 10 | import numpy as np 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | import torch.distributed as dist 15 | import torch._inductor.config as config 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Muon optimizer 20 | 21 | def zeropower_via_svd(G, steps=None): 22 | U, S, V = G.svd() 23 | return U @ V.T 24 | 25 | @torch.compile 26 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 27 | """ 28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 32 | on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T 33 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model 34 | performance at all relative to UV^T, where USV^T = G is the SVD. 35 | """ 36 | assert len(G.shape) == 2 37 | a, b, c = (3.4445, -4.7750, 2.0315) 38 | X = G.bfloat16() 39 | X /= (X.norm() + eps) # ensure top singular value <= 1 40 | if G.size(0) > G.size(1): 41 | X = X.T 42 | for _ in range(steps): 43 | A = X @ X.T 44 | B = A @ X 45 | X = a * X + b * B + c * A @ B 46 | if G.size(0) > G.size(1): 47 | X = X.T 48 | return X 49 | 50 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5) 51 | 52 | class Muon(torch.optim.Optimizer): 53 | """ 54 | Muon - MomentUm Orthogonalized by Newton-schulz 55 | 56 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 57 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 58 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 59 | the advantage that it can be stably run in bfloat16 on the GPU. 60 | 61 | Some warnings: 62 | - This optimizer assumes that all parameters passed in are 2D. 63 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D 64 | parameters; those should all be optimized by a standard method (e.g., AdamW). 65 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 66 | - We believe it is unlikely to work well for training with small batch size. 67 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 68 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). 69 | 70 | Arguments: 71 | lr: The learning rate used by the internal SGD. 72 | momentum: The momentum used by the internal SGD. 73 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 74 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') 75 | backend_steps: The number of iteration steps to use in the backend, if it is iterative. 
76 | """ 77 | def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5): 78 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps) 79 | super().__init__(params, defaults) 80 | 81 | def step(self): 82 | for group in self.param_groups: 83 | lr = group['lr'] 84 | momentum = group['momentum'] 85 | zeropower_backend = zeropower_backends[group['backend']] 86 | for p in group['params']: 87 | g = p.grad 88 | if g is None: 89 | continue 90 | state = self.state[p] 91 | if 'momentum_buffer' not in state: 92 | state['momentum_buffer'] = torch.zeros_like(g) 93 | buf = state['momentum_buffer'] 94 | buf.mul_(momentum).add_(g) 95 | if group['nesterov']: 96 | g = g.add(buf, alpha=momentum) 97 | if g.size(0) == 3 * g.size(1): # split grouped QKV parameters 98 | g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))]) 99 | scale = g.size(1)**0.5 100 | else: 101 | g = zeropower_backend(g, steps=group['backend_steps']) 102 | scale = max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1 103 | p.data.add_(g, alpha=-lr * scale) 104 | 105 | # ----------------------------------------------------------------------------- 106 | # PyTorch nn.Module definitions for the GPT-2 model 107 | 108 | class Rotary(torch.nn.Module): 109 | 110 | def __init__(self, dim, base=10000): 111 | super().__init__() 112 | self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 113 | self.seq_len_cached = None 114 | self.cos_cached = None 115 | self.sin_cached = None 116 | 117 | def forward(self, x): 118 | seq_len = x.shape[1] 119 | if seq_len != self.seq_len_cached: 120 | self.seq_len_cached = seq_len 121 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) 122 | freqs = torch.outer(t, self.inv_freq).to(x.device) 123 | self.cos_cached = freqs.cos().bfloat16() 124 | self.sin_cached = freqs.sin().bfloat16() 125 | return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :] 126 | 127 | def apply_rotary_emb(x, cos, sin): 128 | assert x.ndim == 4 # multihead attention 129 | d = x.shape[3]//2 130 | x1 = x[..., :d] 131 | x2 = x[..., d:] 132 | y1 = x1 * cos + x2 * sin 133 | y2 = x1 * (-sin) + x2 * cos 134 | return torch.cat([y1, y2], 3).type_as(x) 135 | 136 | class CausalSelfAttention(nn.Module): 137 | 138 | def __init__(self, config): 139 | super().__init__() 140 | self.n_head = config.n_head 141 | self.n_embd = config.n_embd 142 | self.head_dim = self.n_embd // self.n_head 143 | assert self.n_embd % self.n_head == 0 144 | self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False) 145 | self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False) 146 | self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False) 147 | # output projection 148 | self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) 149 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 150 | self.rotary = Rotary(self.head_dim) 151 | 152 | def forward(self, x): 153 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 154 | q = self.c_q(x).view(B, T, self.n_head, self.head_dim) 155 | k = self.c_k(x).view(B, T, self.n_head, self.head_dim) 156 | v = self.c_v(x).view(B, T, self.n_head, self.head_dim) 157 | cos, sin = self.rotary(q) 158 | q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) 159 | q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm suggested by @Grad62304977 160 | y = 
F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True) 161 | y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side 162 | y = self.c_proj(y) 163 | return y 164 | 165 | class MLP(nn.Module): 166 | 167 | def __init__(self, config): 168 | super().__init__() 169 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) 170 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) 171 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 172 | 173 | def forward(self, x): 174 | x = self.c_fc(x) 175 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 176 | x = self.c_proj(x) 177 | return x 178 | 179 | class Block(nn.Module): 180 | 181 | def __init__(self, config): 182 | super().__init__() 183 | self.attn = CausalSelfAttention(config) 184 | self.mlp = MLP(config) 185 | 186 | def forward(self, x): 187 | x = x + self.attn(F.rms_norm(x, (x.size(-1),))) 188 | x = x + self.mlp(F.rms_norm(x, (x.size(-1),))) 189 | return x 190 | 191 | # ----------------------------------------------------------------------------- 192 | # The main GPT-2 model 193 | 194 | @dataclass 195 | class GPTConfig: 196 | vocab_size : int = 50304 197 | n_layer : int = 12 198 | n_head : int = 6 # head dim 128 suggested by @Grad62304977 199 | n_embd : int = 768 200 | 201 | class GPT(nn.Module): 202 | 203 | def __init__(self, config): 204 | super().__init__() 205 | self.config = config 206 | 207 | self.transformer = nn.ModuleDict(dict( 208 | wte = nn.Embedding(config.vocab_size, config.n_embd), 209 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 210 | )) 211 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 212 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 213 | 214 | def forward(self, idx, targets=None, return_logits=True): 215 | 216 | # forward the GPT model itself 217 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 218 | for block in self.transformer.h: 219 | x = block(x) 220 | x = F.rms_norm(x, (x.size(-1),)) 221 | 222 | if targets is not None: 223 | # if we are given some desired targets also calculate the loss 224 | logits = self.lm_head(x) 225 | logits = logits.float() # use tf32/fp32 for logits 226 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 227 | else: 228 | # inference-time mini-optimization: only forward the lm_head on the very last position 229 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 230 | logits = logits.float() # use tf32/fp32 for logits 231 | loss = None 232 | 233 | # there are performance reasons why not returning logits is prudent, if not needed 234 | if not return_logits: 235 | logits = None 236 | 237 | return logits, loss 238 | 239 | # ----------------------------------------------------------------------------- 240 | # Our own simple Distributed Data Loader 241 | 242 | def _peek_data_shard(filename): 243 | # only reads the header, returns header data 244 | with open(filename, "rb") as f: 245 | # first read the header, which is 256 int32 integers (4 bytes each) 246 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 247 | if header[0] != 20240520: 248 | print("ERROR: magic number mismatch in the data .bin file!") 249 | print("---> HINT: Are you passing in a correct file with --input_bin?") 
250 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README") 251 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try") 252 | exit(1) 253 | assert header[1] == 1, "unsupported version" 254 | ntok = header[2] # number of tokens (claimed) 255 | return ntok # for now just return the number of tokens 256 | 257 | def _load_data_shard(filename): 258 | with open(filename, "rb") as f: 259 | # first read the header, which is 256 int32 integers (4 bytes each) 260 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 261 | assert header[0] == 20240520, "magic number mismatch in the data .bin file" 262 | assert header[1] == 1, "unsupported version" 263 | ntok = header[2] # number of tokens (claimed) 264 | # the rest of it are tokens, stored as uint16 265 | tokens = np.frombuffer(f.read(), dtype=np.uint16) 266 | assert len(tokens) == ntok, "number of tokens read does not match header?" 267 | return tokens 268 | 269 | class DistributedDataLoader: 270 | def __init__(self, filename_pattern, B, T, process_rank, num_processes): 271 | self.process_rank = process_rank 272 | self.num_processes = num_processes 273 | self.B = B 274 | self.T = T 275 | 276 | # glob files that match the pattern 277 | self.files = sorted(glob.glob(filename_pattern)) 278 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}" 279 | 280 | # load and validate all data shards, count number of tokens in total 281 | ntok_total = 0 282 | for fname in self.files: 283 | shard_ntok = _peek_data_shard(fname) 284 | assert shard_ntok >= num_processes * B * T + 1 285 | ntok_total += int(shard_ntok) 286 | self.ntok_total = ntok_total 287 | 288 | # kick things off 289 | self.reset() 290 | 291 | def reset(self): 292 | self.current_shard = 0 293 | self.current_position = self.process_rank * self.B * self.T 294 | self.tokens = _load_data_shard(self.files[self.current_shard]) 295 | 296 | def advance(self): # advance to next data shard 297 | self.current_shard = (self.current_shard + 1) % len(self.files) 298 | self.current_position = self.process_rank * self.B * self.T 299 | self.tokens = _load_data_shard(self.files[self.current_shard]) 300 | 301 | def next_batch(self): 302 | B = self.B 303 | T = self.T 304 | buf = self.tokens[self.current_position : self.current_position+B*T+1] 305 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long) 306 | x = (buf[:-1]).view(B, T) # inputs 307 | y = (buf[1:]).view(B, T) # targets 308 | # advance current position and load next shard if necessary 309 | self.current_position += B * T * self.num_processes 310 | if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens): 311 | self.advance() 312 | return x.cuda(), y.cuda() 313 | 314 | # ----------------------------------------------------------------------------- 315 | # int main 316 | 317 | @dataclass 318 | class Hyperparameters: 319 | # data hyperparams 320 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on 321 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on 322 | # optimization hyperparams 323 | batch_size : int = 8*64 # batch size, in sequences, across all devices 324 | device_batch_size : int = 64 # batch size, in sequences, per device 325 | sequence_length : int = 1024 # sequence length, in tokens 326 | num_iterations : int = 5100 # number of iterations to run 327 | learning_rate : float = 0.0036 328 | warmup_iters : int = 0 329 | 
warmdown_iters : int = 1450 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule 330 | weight_decay : float = 0 331 | # evaluation and logging hyperparams 332 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end 333 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons 334 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end 335 | args = Hyperparameters() 336 | 337 | # set up DDP (distributed data parallel). torchrun sets this env variable 338 | assert torch.cuda.is_available() 339 | dist.init_process_group(backend='nccl') 340 | ddp_rank = int(os.environ['RANK']) 341 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 342 | ddp_world_size = int(os.environ['WORLD_SIZE']) 343 | device = f'cuda:{ddp_local_rank}' 344 | torch.cuda.set_device(device) 345 | print(f"using device: {device}") 346 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc. 347 | 348 | # convenience variables 349 | B, T = args.device_batch_size, args.sequence_length 350 | # calculate the number of steps to take in the val loop. 351 | assert args.val_tokens % (B * T * ddp_world_size) == 0 352 | val_steps = args.val_tokens // (B * T * ddp_world_size) 353 | # calculate the steps of gradient accumulation required to attain the desired global batch size. 354 | assert args.batch_size % (B * ddp_world_size) == 0 355 | train_accumulation_steps = args.batch_size // (B * ddp_world_size) 356 | 357 | # load tokens 358 | train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size) 359 | val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size) 360 | if master_process: 361 | print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files") 362 | print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files") 363 | x, y = train_loader.next_batch() 364 | 365 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977. 366 | # this originates from Karpathy's experiments. 
367 | num_vocab = 50304 368 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768)) 369 | model = model.cuda() 370 | if hasattr(config, "coordinate_descent_tuning"): 371 | config.coordinate_descent_tuning = True # suggested by @Chillee 372 | model = torch.compile(model) 373 | # here we wrap model into DDP container 374 | model = DDP(model, device_ids=[ddp_local_rank]) 375 | raw_model = model.module # always contains the "raw" unwrapped model 376 | ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16) 377 | 378 | # init the optimizer(s) 379 | optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95), 380 | weight_decay=args.weight_decay, fused=True) 381 | optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95) 382 | optimizers = [optimizer1, optimizer2] 383 | # learning rate decay scheduler (linear warmup and warmdown) 384 | def get_lr(it): 385 | assert it <= args.num_iterations 386 | # 1) linear warmup for warmup_iters steps 387 | if it < args.warmup_iters: 388 | return (it+1) / args.warmup_iters 389 | # 2) constant lr for a while 390 | elif it < args.num_iterations - args.warmdown_iters: 391 | return 1.0 392 | # 3) linear warmdown 393 | else: 394 | decay_ratio = (args.num_iterations - it) / args.warmdown_iters 395 | return decay_ratio 396 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] 397 | 398 | # begin logging 399 | if master_process: 400 | run_id = str(uuid.uuid4()) 401 | logdir = 'logs/%s/' % run_id 402 | os.makedirs(logdir, exist_ok=True) 403 | logfile = 'logs/%s.txt' % run_id 404 | # create the log file 405 | with open(logfile, "w") as f: 406 | # begin the log by printing this file (the Python code) 407 | f.write('='*100 + '\n') 408 | f.write(code) 409 | f.write('='*100 + '\n') 410 | # log information about the hardware/software environment this is running on 411 | # and print the full `nvidia-smi` to file 412 | f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n") 413 | import subprocess 414 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 415 | f.write(f'{result.stdout}\n') 416 | f.write('='*100 + '\n') 417 | 418 | training_time_ms = 0 419 | # start the clock 420 | torch.cuda.synchronize() 421 | t0 = time.time() 422 | # begin training 423 | train_loader.reset() 424 | for step in range(args.num_iterations + 1): 425 | last_step = (step == args.num_iterations) 426 | # This effectively ignores timing first 10 steps, which are slower for weird reasons. 427 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10 428 | # steps with dummy data first, and then re-initialize the model and reset the loader. 
429 | if step == 10: 430 | training_time_ms = 0 431 | t0 = time.time() 432 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val 433 | 434 | # once in a while evaluate the validation dataset 435 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)): 436 | # stop the clock 437 | torch.cuda.synchronize() 438 | training_time_ms += 1000 * (time.time() - t0) 439 | # run validation batches 440 | model.eval() 441 | val_loader.reset() 442 | val_loss = 0.0 443 | for _ in range(val_steps): 444 | x_val, y_val = val_loader.next_batch() 445 | with ctx: # of course, we'd like to use no_grad() here too, but that creates a torch.compile error for some reason 446 | _, loss = model(x_val, y_val, return_logits=False) 447 | val_loss += loss.detach() 448 | del loss 449 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) 450 | val_loss /= val_steps 451 | # log val loss to console and to logfile 452 | if master_process: 453 | print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms') 454 | with open(logfile, "a") as f: 455 | f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n') 456 | # start the clock again 457 | torch.cuda.synchronize() 458 | t0 = time.time() 459 | 460 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)): 461 | # stop the clock 462 | torch.cuda.synchronize() 463 | training_time_ms += 1000 * (time.time() - t0) 464 | # save the state of the training process 465 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) 466 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step)) 467 | # start the clock again 468 | torch.cuda.synchronize() 469 | t0 = time.time() 470 | 471 | # bit confusing: we want to make sure to eval on 0th iteration 472 | # but also after the very last iteration. so we loop for step <= num_iterations 473 | # instead of just < num_iterations (one extra due to <=), only to do 474 | # the validation/sampling one last time, and then we break right here as we're done. 475 | if last_step: 476 | break 477 | 478 | # --------------- TRAINING SECTION BEGIN ----------------- 479 | model.train() 480 | for i in range(1, train_accumulation_steps+1): 481 | # forward pass 482 | with ctx: 483 | _, loss = model(x, y, return_logits=False) 484 | train_loss = loss.detach() 485 | # advance the dataset for the next batch 486 | x, y = train_loader.next_batch() 487 | # backward pass 488 | if i < train_accumulation_steps: 489 | with model.no_sync(): # there's no need to sync gradients every accumulation step 490 | loss.backward() 491 | else: 492 | loss.backward() # just sync on the last step 493 | for p in model.parameters(): 494 | p.grad /= train_accumulation_steps 495 | # step the optimizers and schedulers 496 | for opt, sched in zip(optimizers, schedulers): 497 | opt.step() 498 | sched.step() 499 | # null the gradients 500 | model.zero_grad(set_to_none=True) 501 | # --------------- TRAINING SECTION END ------------------- 502 | # everything that follows now is just diagnostics, prints, logging, etc. 
503 | 504 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower 505 | if master_process: 506 | approx_time = training_time_ms + 1000 * (time.time() - t0) 507 | print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms") 508 | with open(logfile, "a") as f: 509 | f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n") 510 | 511 | if master_process: 512 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") 513 | 514 | # ------------------------------------------------------------------------- 515 | # clean up nice 516 | dist.destroy_process_group() 517 | -------------------------------------------------------------------------------- /records/102924_Optimizers/README.md: -------------------------------------------------------------------------------- 1 | # Optimizer comparison for NanoGPT speedrunning 2 | 3 | This is a comparison between the four best optimizers I am aware of for NanoGPT speedrunning. They are compared using the 10/18/24 NanoGPT speedrunning record. 4 | 5 | Reproducible logs: 6 | * [Adam](95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt) 7 | * [DistributedShampoo](8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt) 8 | * [SOAP](e21a2838-a0f2-46f2-a247-db0021165682.txt) 9 | * [Muon](8d6193f4-27fc-4e68-899f-af70019a4d54.txt) 10 | 11 | Results: 12 | ![1](nanogpt_speedrun81w.png) 13 | ![2](nanogpt_speedrun82w.png) 14 | 15 | ### General notes for all optimizers 16 | 17 | All optimizers are run using zero weight decay (which we found to be empirically optimal). 18 | 19 | They are all run with a warmup-stable-decay / trapezoidal schedule, which also seems to be optimal. That's what causes the kink in the loss curve ~75% of the way to the end. 20 | 21 | In addition, in all cases, we optimize the shared embedding/head layer just using Adam (which we also found to be empirically optimal). 22 | Note that in the following code snippets, `raw_model.transformer.h.parameters()` gives all parameters besides those two. 23 | 24 | In each case, the hyperparameters are the best ones I could find in around 20 attempts. 25 | 26 | ## [Adam](95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt) 27 | The optimizer here is equivalent to: 28 | ``` 29 | torch.optim.Adam(raw_model.transformer.h.parameters(), lr=0.0018, betas=(0.9, 0.95)) 30 | ``` 31 | 32 | 33 | ## [DistributedShampoo](8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt) 34 | Run as follows: 35 | ``` 36 | DistributedShampoo( 37 | raw_model.transformer.h.parameters(), 38 | lr=0.0018, 39 | betas=(0.95, 0.95), 40 | epsilon=1e-12, 41 | weight_decay=0, 42 | max_preconditioner_dim=8192, 43 | precondition_frequency=10, 44 | use_decoupled_weight_decay=True, 45 | grafting_config=AdamGraftingConfig( 46 | beta2=0.95, 47 | epsilon=1e-8, 48 | ), 49 | distributed_config=DDPShampooConfig( 50 | communication_dtype=CommunicationDType.FP32, 51 | num_trainers_per_group=8, 52 | communicate_params=False, 53 | ), 54 | ) 55 | ``` 56 | 57 | This is using the official `DistributedShampoo` implementation from [here](https://github.com/facebookresearch/optimizers/tree/ad2809a291c01859f68fcabbcb49a2aa75fd7827/distributed_shampoo). 58 | 59 | Things that turned out to be important: 60 | * Don't use epsilon above 1e-8; this loses performance.
Epsilon 1e-12 performs as well as 1e-15. 61 | * Betas=(0.95, 0.95) seemed optimal, which turns out to be the same setting that SOAP uses 62 | * Higher preconditioner update frequency is better but slower 63 | 64 | I'm open to hyperparameter suggestions; the experiment takes ~20-30 minutes to run on a fresh 8xH100 instance, so it's not hard for me to run more attempts. 65 | 66 | 67 | ## [SOAP](e21a2838-a0f2-46f2-a247-db0021165682.txt) 68 | ``` 69 | SOAP(raw_model.transformer.h.parameters(), lr=0.0018, betas=(.95, .95), precondition_frequency=10) 70 | ``` 71 | 72 | This is using the official SOAP implementation [here](https://github.com/nikhilvyas/SOAP/blob/bbce86e890d3b697380f4376acb600c2d6c3d203/soap.py). 73 | 74 | Based on conversations with the authors, it is likely that a future SOAP implementation will significantly reduce the wallclock overhead. 75 | 76 | 77 | ## [Muon](8d6193f4-27fc-4e68-899f-af70019a4d54.txt) 78 | ``` 79 | Muon(raw_model.transformer.h.parameters(), lr=0.02, momentum=0.95) 80 | ``` 81 | 82 | 83 | ## Openness 84 | 85 | These training logs are reproducible (just strip out everything besides the code, and run it using the `run.sh` in the top-level folder). They take 12-25 minutes to run. 86 | 87 | I tried to do a good job sweeping the hyperparameters for each optimizer, but I could easily have missed something, or simply not have performed enough runs. 88 | 89 | Therefore, I am interested in any better hyperparameter settings that other researchers can find for any of these optimizers. 90 | If you post or send me your own reproducible log with one of these optimizers, I will be very happy to boost it in any way I can. 91 | 92 | ## Appendix: Negative results 93 | 94 | I believe it was Shazeer who said something like "negative results in machine learning are not worth much, because your inability to make something work doesn't prove that it can't work". 95 | 96 | Given that disclaimer, here are some optimizers that I tried to make work, but was unable to get a significant boost over Adam with: 97 | * Sophia 98 | * Lion 99 | * AdamWScheduleFree 100 | * AdEmaMix (actually this was slightly better than Adam, just not enough to come close to competing with the three Shampoo-like optimizers) 101 | 102 | Of course, this is just for NanoGPT speedrunning (short train duration); it's quite possible they work better at longer training durations or for larger models.
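## Appendix: How the optimizers are wired in

For concreteness, here is a minimal sketch of how each optimizer above slots in next to the shared Adam described in the general notes. This is only a sketch: the names (`raw_model`, `args`, `get_lr`, `Muon`) follow the speedrun training script, any learning rate not quoted in the sections above is illustrative, and the exact setup for each run is in its reproducible log.

```
# shared (weight-tied) embedding/head layer: always optimized by Adam
head_opt = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate,
                             betas=(0.9, 0.95), weight_decay=0, fused=True)
# transformer body: the optimizer under test, e.g. Muon as quoted above
# (swap in the Adam / DistributedShampoo / SOAP calls from their sections to reproduce the other runs)
body_opt = Muon(raw_model.transformer.h.parameters(), lr=0.02, momentum=0.95)
optimizers = [head_opt, body_opt]
# every optimizer shares the warmup-stable-decay (trapezoidal) schedule
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
```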
103 | 104 | -------------------------------------------------------------------------------- /records/102924_Optimizers/nanogpt_speedrun81w.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/102924_Optimizers/nanogpt_speedrun81w.png -------------------------------------------------------------------------------- /records/102924_Optimizers/nanogpt_speedrun82w.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/102924_Optimizers/nanogpt_speedrun82w.png -------------------------------------------------------------------------------- /records/110324_UntieEmbed/README.md: -------------------------------------------------------------------------------- 1 | # New record 11/03/24 2 | 3 | 4 | New NanoGPT training speed record: 3.28 FineWeb val loss in 10.8 minutes on 8xH100 5 | 6 | Previous record: 12.0 minutes 7 | Changelog: 8 | - untied embed and head weights 9 | - added RMSNorm after embed 10 | - init head to zero 11 | 12 | Driven by @Grad62304977 13 | 14 | --- 15 | 16 | Technically, this is somewhat of an "any%" record, since untying the embedding and lm_head adds 39M parameters. 17 | 18 | However, it doesn't change the number of active parameters or the inference throughput. Future records will stay constrained to 124M active parameters. 19 | 20 | --- 21 | 22 | Like the last architectural change, this record was driven by @Grad62304977. I just finetuned some things and did bookkeeping. 23 | 24 | --- 25 | 26 | Shoutout to @cloneofsimo whose scaling guide already suggests initializing the head to zero. This works quite well and is a significant fraction of the record. 27 | 28 | -------------------------------------------------------------------------------- /records/110424_50Bruns/README.md: -------------------------------------------------------------------------------- 1 | # 50B-token runs 2 | 3 | This folder contains four runs generated by extending the 11/03/24 speedrun record to 50B FineWeb tokens. 4 | The goal is to test how the speedrun generalizes to long durations, and especially how well Muon does. 5 | 6 | We compare two things: 7 | 1. We compare Muon to Adam as the optimizer for the transformer body. (The head and embedding are always optimized by Adam.) 8 | 2. We compare training on 5 epochs of 10B tokens to training on 50B tokens. 
(Surprisingly this does about the same) 9 | 10 | The four resulting runs are as follows: 11 | 12 | * [Muon 50B tokens](./530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt) (HellaSwag=35.82) 13 | * [Adam 50B tokens](./69c33fc9-eabb-4a38-aa08-6922914eb405.txt) (HellaSwag=34.26) 14 | * [Muon 5x10B tokens](./4fbe61ec-f79a-4c19-836d-46d599deecce.txt) (HellaSwag=36.17) 15 | * [Adam 5x10B tokens](./3d715d41-453a-40d6-9506-421ba69766b2.txt) (HellaSwag=34.05) 16 | 17 | To get a sense of what a good HellaSwag score would be for this scale of model, here are some baselines: 18 | * Karpathy's baseline llm.c training (trained for 10B FineWeb tokens): 29.9 19 | * OpenAI GPT-2 (124M): 29.4 20 | * OpenAI GPT-3 (124M) (trained for 300B WebText tokens): 33.7 21 | * Huggingface SmolLM2-135M (trained for 2T FineWeb/DCLM/etc tokens): 42.1 22 | 23 | Note: I'm a little concerned that the learning rate schedule (WSD) and weight decay (zero), which are tuned for the speedrun duration, 24 | might become undertuned/suboptimal for trainings of this duration. 25 | It does look like the gap between Muon/Adam is too large to be closed by something like this, and the HellaSwag scores look quite reasonable, but you never know. 26 | 27 | -------------------------------------------------------------------------------- /records/110624_ShortcutsTweaks/README.md: -------------------------------------------------------------------------------- 1 | # New record 11/06/24 2 | 3 | 8.2 minutes on 8xH100 (previous record: 10.8 minutes) 4 | 5 | ![](nanogpt_speedrun110.png) 6 | ![](nanogpt_speedrun111.png) 7 | 8 | * [Old record 11/03/24](d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt) 9 | * [+shorten duration](4a71cc92-0f43-4058-a033-23e85c1e98f1.txt) 10 | * [+value residual](042f9e87-07e6-4504-bb04-4ec59a380211.txt) by @Grad62304977 following [1] 11 | * [+learnable lambda](43f60c4f-0448-4de7-83d9-643ca26f61e7.txt) @Grad62304977's innovation on top of [1] 12 | * [+embed shortcut](05b29e54-0be0-4a0f-a1e2-7d5317daedd3.txt) 13 | * [+momentum warmup](10119f53-7001-4248-bfd9-33d32427a912.txt) 14 | * [+tanh logit capping](dd7304a6-cc43-4d5e-adb8-c070111464a1.txt) by @Grad62304977 following [2] 15 | 16 | ## Code snippets 17 | 18 | ### Value residual 19 | 20 | In the attention layer: 21 | ``` 22 | def forward(self, x, v1=None): 23 | ... 24 | v = self.c_v(x).view(B, T, self.n_head, self.head_dim) 25 | if v1 is None: 26 | v1 = v 27 | v = 0.5 * v + 0.5 * v1.view_as(v) 28 | ``` 29 | Where the first block receives v1=None, and subsequent blocks receive v1 as the value produced by the first block. 30 | 31 | ### Learnable lambda 32 | 33 | In the attention block: 34 | ``` 35 | def __init__(self, config): 36 | ... 37 | self.lamb = nn.Parameter(torch.tensor(0.5)) 38 | 39 | def forward(self, x, v1=None): 40 | ... 41 | v = (1 - self.lamb) * v + self.lamb * v1.view_as(v) 42 | ``` 43 | That is, we just replace the fixed 0.5 constant used in standard value residual [1] with a learnable scalar (optimized by Adam(lr=0.02)). 
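Note that `self.lamb` is a 0-D parameter, so it cannot be handed to Muon (which assumes 2-D parameters); scalars like this get their own small Adam group. Below is a minimal sketch of that split, with names following the speedrun script; the learning rates are illustrative and the exact grouping is in the reproducible logs.

```
params = list(raw_model.transformer.h.parameters())
matrix_params = [p for p in params if p.ndim == 2]  # 2-D weight matrices: optimized by Muon
scalar_params = [p for p in params if p.ndim < 2]   # learnable lambdas and other scalars: optimized by Adam
opt_body = Muon(matrix_params, lr=0.02, momentum=0.95)   # lr illustrative
opt_scalars = torch.optim.Adam(scalar_params, lr=0.02)   # the Adam(lr=0.02) mentioned above
```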
44 | 45 | ### Embed shortcut 46 | 47 | Replaces the standard transformer block with this: 48 | 49 | ``` 50 | class Block(nn.Module): 51 | 52 | def __init__(self, config): 53 | super().__init__() 54 | self.attn = CausalSelfAttention(config) 55 | self.mlp = MLP(config) 56 | self.lambdas = nn.Parameter(torch.tensor([1., 0.])) 57 | 58 | def forward(self, x, v1, x0): 59 | x = self.lambdas[0] * x + self.lambdas[1] * x0 60 | x = x + self.attn(F.rms_norm(x, (x.size(-1),)), v1) 61 | x = x + self.mlp(F.rms_norm(x, (x.size(-1),))) 62 | return x 63 | ``` 64 | 65 | where the two scalars are optimized using Adam(lr=0.02), `v1` is the value residual described above, and `x0` is fed in from the initial embedding via: 66 | ``` 67 | ... 68 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 69 | x = F.rms_norm(x, (x.size(-1),)) 70 | x0 = x 71 | for block in self.transformer.h: 72 | x = block(x, v1, x0) 73 | ... 74 | ``` 75 | 76 | ### Momentum warmup 77 | 78 | Just adds the following two lines: 79 | ``` 80 | frac = min(step/500, 1) 81 | optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95 82 | ``` 83 | where `optimizer3` is the Muon for the body of the transformer. 84 | 85 | ### Tanh soft capping 86 | 87 | Just adds the following line: 88 | 89 | ``` 90 | logits = 30 * torch.tanh(logits / 30) 91 | ``` 92 | 93 | 94 | ## References 95 | 96 | 1. [Zhou, Zhanchao, et al. "Value Residual Learning For Alleviating Attention Concentration In Transformers." arXiv preprint arXiv:2410.17897 (2024).](https://arxiv.org/abs/2410.17897) 97 | 2. [Team, Gemma, et al. "Gemma 2: Improving open language models at a practical size." arXiv preprint arXiv:2408.00118 (2024).](https://arxiv.org/abs/2408.00118) 98 | 99 | -------------------------------------------------------------------------------- /records/110624_ShortcutsTweaks/nanogpt_speedrun110.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/110624_ShortcutsTweaks/nanogpt_speedrun110.png -------------------------------------------------------------------------------- /records/110624_ShortcutsTweaks/nanogpt_speedrun111.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/110624_ShortcutsTweaks/nanogpt_speedrun111.png -------------------------------------------------------------------------------- /records/110924_Replicateleloykun/README.md: -------------------------------------------------------------------------------- 1 | This is a replication attempt for the record attempt described [here](https://x.com/leloykun/status/1854557419768254915) by @leloykun. 2 | 3 | The original record could not be directly accepted because it showed a slower wallclock time than the previous record - 4 | however, this was plausibly due to hardware differences, as the competitor's hardware was slightly slower. 5 | 6 | Therefore, to certify this attempt as the new record, here I replicated it on my own hardware. 7 | This did successfully reduce the wallclock time compared to the 11/07/24 record by ~11 seconds; however, it also 8 | resulted in an invalid val loss of 3.2824, above the threshold of 3.28.
9 | 10 | The [original record attempt's reproducible log](https://github.com/leloykun/modded-nanogpt/blob/224f10d190677d9dc3c9c45da280078196a6fe40/records/110724_EmbeddingBetasCooldown/6c9d875b-ad91-46c9-9ede-2c7f998b9b16.txt) attained a val loss of 3.2798, just barely below the 3.28 threshold. So this difference is plausibly due to random inter-run variance. 11 | 12 | This indicates that the true average val loss of the run may be worse than 3.28, meaning I am **unable to certify it as the new record.** 13 | 14 | Ideally, all records should attain a low enough val loss such that >95% of runs attain below 3.28. Good evidence for this would be a single run 15 | attaining <= 3.278. Previous records have adhered to this rule, but admittedly it's hard to define precisely and is therefore mostly a matter of taste. 16 | 17 | -------------------------------------------------------------------------------- /records/111024_UNetDoubleLr/README.md: -------------------------------------------------------------------------------- 1 | This is a record by Brendan Hogan Rappazzo [@brendanh0gan](https://x.com/brendanh0gan). 2 | 3 | New record: 7.23 minutes 4 | 5 | Previous record: 7.8 minutes 6 | 7 | Changelog: 8 | - Added U-net-like skip connections into the transformer 9 | - Doubled the learning rate 10 | 11 | --- 12 | 13 | This record was first posted [here](https://x.com/brendanh0gan/status/1855273758681866352), & then a few iterations were required to benchmark it on 8x SXM H100s. 14 | Brendan's fork of modded-nanogpt is [here](https://github.com/brendanhogan/modded-nanogpt/tree/master). The code for the record can also be extracted from the reproducible log in this folder. 15 | 16 | -------------------------------------------------------------------------------- /records/111924_FlexAttention/README.md: -------------------------------------------------------------------------------- 1 | ## 11/19/24 FlexAttention record 2 | 3 | This training has significant variance of around 0.005 stddev between runs. So not all runs go beneath 3.28, though [the mean is around 3.279](https://x.com/YouJiacheng/status/1859876224639828068). 4 | 5 | This variance is probably caused by the previous record doubling the learning rate, plus this record significantly shortening the duration. 6 | -------------------------------------------------------------------------------- /records/120424_ValueEmbed/README.md: -------------------------------------------------------------------------------- 1 | ## Statistical tests 2 | 3 | ``` 4 | accs = [3.2759, 3.2781, 3.2791, 3.2771, 3.2838, 3.2749, 3.2793, 3.279, 3.2794, 3.2744, 3.2751, 5 | 3.2845, 3.2736, 3.2783, 3.2793, 3.2779, 3.2756, 3.281, 3.2803, 3.2766, 3.2851, 3.275, 6 | 3.2778, 3.2723, 3.2842, 3.2735, 3.275, 3.2796, 3.2782, 3.2758, 3.2763, 3.2751, 3.2791, 7 | 3.2804, 3.2725, 3.2898, 3.2718, 3.2764, 3.271, 3.2745] 8 | 9 | import scipy.stats 10 | print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue) 11 | # p=0.0003 (statistically significant) 12 | 13 | import torch 14 | print(torch.std_mean(torch.tensor(accs))) 15 | # (tensor(0.0040), tensor(3.2777)) 16 | ``` 17 | 18 | ## ChangeLog 19 | 20 | * Added 12 new embedding layers which get mixed into the value activations at each layer. 
(=463M new parameters, of which 9216 are active per token) 21 | 22 | ## Contributors 23 | 24 | * @KoszarskyB 25 | 26 | 27 | -------------------------------------------------------------------------------- /records/120424_ValueEmbed/train_gpt2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | with open(sys.argv[0]) as f: 4 | code = f.read() # read the code of this file ASAP, for logging 5 | import uuid 6 | import glob 7 | import time 8 | import contextlib 9 | from dataclasses import dataclass 10 | 11 | import numpy as np 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | import torch.distributed as dist 16 | import torch._inductor.config as config 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | # Use of FlexAttention contributed by @KoszarskyB 19 | from torch.nn.attention.flex_attention import flex_attention, create_block_mask 20 | flex_attention = torch.compile(flex_attention, dynamic=False) 21 | create_block_mask = torch.compile(create_block_mask, dynamic=False) 22 | 23 | # ----------------------------------------------------------------------------- 24 | # Muon optimizer 25 | 26 | def zeropower_via_svd(G, steps=None): 27 | U, S, V = G.svd() 28 | return U @ V.T 29 | 30 | @torch.compile 31 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 32 | """ 33 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 34 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 35 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 36 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 37 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 38 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 39 | performance at all relative to UV^T, where USV^T = G is the SVD. 40 | """ 41 | assert len(G.shape) == 2 42 | a, b, c = (3.4445, -4.7750, 2.0315) 43 | X = G.bfloat16() 44 | X /= (X.norm() + eps) # ensure top singular value <= 1 45 | if G.size(0) > G.size(1): 46 | X = X.T 47 | for _ in range(steps): 48 | A = X @ X.T 49 | B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 50 | X = a * X + B @ X 51 | if G.size(0) > G.size(1): 52 | X = X.T 53 | return X 54 | 55 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5) 56 | 57 | class Muon(torch.optim.Optimizer): 58 | """ 59 | Muon - MomentUm Orthogonalized by Newton-schulz 60 | 61 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 62 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 63 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 64 | the advantage that it can be stably run in bfloat16 on the GPU. 65 | 66 | Some warnings: 67 | - This optimizer assumes that all parameters passed in are 2D. 68 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D 69 | parameters; those should all be optimized by a standard method (e.g., AdamW). 70 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 71 | - We believe it is unlikely to work well for training with small batch size. 
72 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 73 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). 74 | 75 | Arguments: 76 | lr: The learning rate used by the internal SGD. 77 | momentum: The momentum used by the internal SGD. 78 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 79 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') 80 | backend_steps: The number of iteration steps to use in the backend, if it is iterative. 81 | """ 82 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, 83 | backend='newtonschulz5', backend_steps=5): 84 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps) 85 | super().__init__(params, defaults) 86 | 87 | def step(self): 88 | 89 | for group in self.param_groups: 90 | 91 | lr = group['lr'] 92 | momentum = group['momentum'] 93 | zeropower_backend = zeropower_backends[group['backend']] 94 | 95 | # generate weight updates in distributed fashion 96 | total_params = sum(p.numel() for p in group['params']) 97 | updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16) 98 | curr_idx = 0 99 | for i, p in enumerate(group['params']): 100 | # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs 101 | if i % int(os.environ['WORLD_SIZE']) == int(os.environ['RANK']): 102 | g = p.grad 103 | assert g is not None 104 | state = self.state[p] 105 | if 'momentum_buffer' not in state: 106 | state['momentum_buffer'] = torch.zeros_like(g) 107 | buf = state['momentum_buffer'] 108 | buf.mul_(momentum).add_(g) 109 | g = g.add(buf, alpha=momentum) if group['nesterov'] else buf 110 | g = zeropower_backend(g, steps=group['backend_steps']) 111 | g *= max(1, g.size(0)/g.size(1))**0.5 112 | updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten() 113 | curr_idx += p.numel() 114 | 115 | # sync updates across devices. 
we are not memory-constrained so can do this simple deserialization 116 | dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) 117 | 118 | # deserialize and apply updates 119 | curr_idx = 0 120 | for p in group['params']: 121 | g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data) 122 | p.data.add_(g, alpha=-lr) 123 | curr_idx += p.numel() 124 | 125 | # ----------------------------------------------------------------------------- 126 | # PyTorch nn.Module definitions for the GPT-2 model 127 | 128 | def norm(x): 129 | return F.rms_norm(x, (x.size(-1),)) 130 | 131 | class CastedLinear(nn.Linear): 132 | 133 | def __init__(self, in_features, out_features): 134 | super().__init__(in_features, out_features, bias=False) 135 | 136 | def forward(self, x): 137 | return F.linear(x, self.weight.to(x.dtype)) 138 | 139 | class Rotary(torch.nn.Module): 140 | 141 | def __init__(self, dim, base=10000): 142 | super().__init__() 143 | self.register_buffer('inv_freq', (1 / base) ** (torch.arange(0, dim, 2) / dim)) 144 | self.seq_len_cached = None 145 | self.cos_cached = None 146 | self.sin_cached = None 147 | 148 | def forward(self, x): 149 | seq_len = x.shape[1] 150 | if seq_len != self.seq_len_cached: 151 | t = torch.arange(seq_len, device=x.device) 152 | freqs = torch.outer(t, self.inv_freq) 153 | self.seq_len_cached = seq_len 154 | self.cos_cached = freqs.cos() 155 | self.sin_cached = freqs.sin() 156 | cos, sin = self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :] 157 | # apply_rotary_emb(x, cos, sin) 158 | x1, x2 = x.chunk(2, dim=3) 159 | y1 = x1 * cos + x2 * sin 160 | y2 = x1 * (-sin) + x2 * cos 161 | return torch.cat((y1, y2), 3).type_as(x) 162 | 163 | class CausalSelfAttention(nn.Module): 164 | 165 | def __init__(self, dim, n_head): 166 | super().__init__() 167 | assert dim % n_head == 0 168 | self.n_head = n_head 169 | self.c_q = CastedLinear(dim, dim) 170 | self.c_k = CastedLinear(dim, dim) 171 | self.c_v = CastedLinear(dim, dim) 172 | # value residual lambda 173 | self.lamb = nn.Parameter(torch.tensor(0.5)) # @Grad62304977 174 | # rotary embeddings 175 | self.rotary = Rotary(dim // n_head) # dim // n_head = head_dim 176 | # output projection 177 | self.c_proj = CastedLinear(dim, dim) 178 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 179 | 180 | def forward(self, x, vi, block_mask): 181 | B, T = x.size(0), x.size(1) # batch size, sequence length 182 | assert B == 1, "Must use batch size = 1 for FlexAttention" 183 | q = self.c_q(x).view(B, T, self.n_head, -1) 184 | k = self.c_k(x).view(B, T, self.n_head, -1) 185 | v = self.c_v(x).view(B, T, self.n_head, -1) 186 | v = (1 - self.lamb) * v + self.lamb * vi.view_as(v) # @Grad62304977 187 | q, k = norm(q), norm(k) # QK norm suggested by @Grad62304977 188 | q, k = self.rotary(q), self.rotary(k) 189 | y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask) 190 | y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side 191 | y = self.c_proj(y) 192 | return y 193 | 194 | class MLP(nn.Module): 195 | 196 | def __init__(self, dim): 197 | super().__init__() 198 | self.c_fc = CastedLinear(dim, 4 * dim) 199 | self.c_proj = CastedLinear(4 * dim, dim) 200 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 201 | 202 | def forward(self, x): 203 | x = self.c_fc(x) 204 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 205 | x = 
self.c_proj(x) 206 | return x 207 | 208 | class Block(nn.Module): 209 | 210 | def __init__(self, config): 211 | super().__init__() 212 | self.attn = CausalSelfAttention(config.n_embd, config.n_head) 213 | self.mlp = MLP(config.n_embd) 214 | self.lambdas = nn.Parameter(torch.tensor([1., 0.])) 215 | 216 | def forward(self, x, vi, x0, block_mask): 217 | x = self.lambdas[0] * x + self.lambdas[1] * x0 218 | x = x + self.attn(norm(x), vi, block_mask) 219 | x = x + self.mlp(norm(x)) 220 | return x 221 | 222 | # ----------------------------------------------------------------------------- 223 | # The main GPT-2 model 224 | 225 | @dataclass 226 | class GPTConfig: 227 | vocab_size : int = 50304 228 | n_layer : int = 12 229 | n_head : int = 6 # head dim 128 suggested by @Grad62304977 230 | n_embd : int = 768 231 | 232 | class GPT(nn.Module): 233 | 234 | def __init__(self, config): 235 | super().__init__() 236 | 237 | # U-net design by @brendanh0gan 238 | self.num_encoder_layers = config.n_layer // 2 # Half of the layers for encoder 239 | self.num_decoder_layers = config.n_layer - self.num_encoder_layers # Remaining for decoder 240 | # Add learnable skip connection weights for decoder layers 241 | self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers)) 242 | 243 | self.transformer = nn.ModuleDict(dict( 244 | wte = nn.Embedding(config.vocab_size, config.n_embd), 245 | # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning 246 | vte = nn.Embedding(config.vocab_size, config.n_embd*12), 247 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 248 | )) 249 | self.lm_head = CastedLinear(config.n_embd, config.vocab_size) 250 | self.lm_head.weight.data.zero_() # @Grad62304977 251 | 252 | def forward(self, idx, target, attn_blocksize): 253 | 254 | docs = (idx == 50256).cumsum(0) 255 | def document_causal_mask(b, h, q_idx, kv_idx): 256 | causal_mask = q_idx >= kv_idx 257 | document_mask = docs[q_idx] == docs[kv_idx] 258 | window_mask = q_idx - kv_idx < attn_blocksize 259 | return causal_mask & document_mask & window_mask 260 | 261 | S = len(idx) 262 | block_mask = create_block_mask(document_causal_mask, None, None, S, S, device="cuda", _compile=True) 263 | 264 | # forward the GPT model itself 265 | x = self.transformer.wte(idx[None]) # token embeddings of shape (b, t, n_embd) 266 | x = norm(x) # @Grad62304977 267 | x0 = x 268 | vi = self.transformer.vte(idx[None]).chunk(12, dim=-1) 269 | 270 | # Store outputs for U-Net skip connections 271 | skip_connections = [] 272 | # Encoder pass - process only the first half of the blocks 273 | for i in range(self.num_encoder_layers): 274 | x = self.transformer.h[i](x, vi[i], x0, block_mask) 275 | skip_connections.append(x) 276 | # Decoder pass - process the remaining blocks with weighted skip connections 277 | for i in range(self.num_decoder_layers): 278 | x = x + self.skip_weights[i] * skip_connections.pop() 279 | x = self.transformer.h[self.num_encoder_layers + i](x, vi[self.num_encoder_layers+i], x0, block_mask) 280 | 281 | x = norm(x) 282 | logits = self.lm_head(x) 283 | logits = 30 * torch.tanh(logits / 30) # @Grad62304977 284 | logits = logits.float() 285 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1)) 286 | return loss 287 | 288 | # ----------------------------------------------------------------------------- 289 | # Our own simple Distributed Data Loader 290 | 291 | def _peek_data_shard(filename): 292 | # only reads the header, returns header data 293 | with 
open(filename, "rb") as f: 294 | # first read the header, which is 256 int32 integers (4 bytes each) 295 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 296 | if header[0] != 20240520: 297 | print("ERROR: magic number mismatch in the data .bin file!") 298 | print("---> HINT: Are you passing in a correct file with --input_bin?") 299 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README") 300 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try") 301 | exit(1) 302 | assert header[1] == 1, "unsupported version" 303 | ntok = header[2] # number of tokens (claimed) 304 | return ntok # for now just return the number of tokens 305 | 306 | def _load_data_shard(filename): 307 | with open(filename, "rb") as f: 308 | # first read the header, which is 256 int32 integers (4 bytes each) 309 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 310 | assert header[0] == 20240520, "magic number mismatch in the data .bin file" 311 | assert header[1] == 1, "unsupported version" 312 | ntok = header[2] # number of tokens (claimed) 313 | # the rest of it are tokens, stored as uint16 314 | tokens = np.frombuffer(f.read(), dtype=np.uint16) 315 | assert len(tokens) == ntok, "number of tokens read does not match header?" 316 | return tokens 317 | 318 | class DistributedDataLoader: 319 | def __init__(self, filename_pattern, T, process_rank, num_processes): 320 | self.process_rank = process_rank 321 | self.num_processes = num_processes 322 | self.T = T 323 | 324 | # glob files that match the pattern 325 | self.files = sorted(glob.glob(filename_pattern)) 326 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}" 327 | 328 | # load and validate all data shards, count number of tokens in total 329 | ntok_total = 0 330 | for fname in self.files: 331 | shard_ntok = _peek_data_shard(fname) 332 | assert shard_ntok >= num_processes * T + 1 333 | ntok_total += int(shard_ntok) 334 | self.ntok_total = ntok_total 335 | 336 | self.reset() 337 | 338 | def reset(self): 339 | self.current_shard = -1 340 | self.advance() 341 | 342 | def advance(self): # advance to next data shard 343 | self.current_shard = (self.current_shard + 1) % len(self.files) 344 | self.current_position = self.process_rank * self.T 345 | self.tokens = _load_data_shard(self.files[self.current_shard]) 346 | 347 | def next_batch(self): 348 | batch_size = self.T * self.num_processes 349 | buf = self.tokens[self.current_position:self.current_position+self.T+1] 350 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long) 351 | x = buf[:-1] # inputs 352 | y = buf[1:] # targets 353 | # advance current position and load next shard if necessary 354 | self.current_position += batch_size 355 | if self.current_position + batch_size >= len(self.tokens): 356 | self.advance() 357 | return x.cuda(), y.cuda() 358 | 359 | # ----------------------------------------------------------------------------- 360 | # int main 361 | 362 | @dataclass 363 | class Hyperparameters: 364 | # data hyperparams 365 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on 366 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on 367 | # optimization hyperparams 368 | batch_size : int = 8 # batch size, in sequences, across all devices 369 | sequence_length : int = 64*1024 # sequence length, in tokens 370 | num_iterations : int = 1530 # number of iterations to run 371 | warmup_iters : int = 0 
372 | cooldown_iters : int = 600 # number of iterations of linear warmup/cooldown for triangular or trapezoidal schedule 373 | weight_decay : float = 0 374 | # evaluation and logging hyperparams 375 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end 376 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons 377 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end 378 | args = Hyperparameters() 379 | 380 | # set up DDP (distributed data parallel). torchrun sets this env variable 381 | assert torch.cuda.is_available() 382 | dist.init_process_group(backend='nccl') 383 | ddp_rank = int(os.environ['RANK']) 384 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 385 | ddp_world_size = int(os.environ['WORLD_SIZE']) 386 | device = f'cuda:{ddp_local_rank}' 387 | torch.cuda.set_device(device) 388 | print(f"using device: {device}") 389 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc. 390 | 391 | # begin logging 392 | logfile = None 393 | if master_process: 394 | run_id = str(uuid.uuid4()) 395 | logdir = 'logs/%s/' % run_id 396 | os.makedirs(logdir, exist_ok=True) 397 | logfile = 'logs/%s.txt' % run_id 398 | # create the log file 399 | with open(logfile, "w") as f: 400 | # begin the log by printing this file (the Python code) 401 | f.write(code) 402 | f.write('='*100 + '\n') 403 | def print0(s, logonly=False): 404 | if master_process: 405 | with open(logfile, "a") as f: 406 | if not logonly: 407 | print(s) 408 | f.write(s+'\n') 409 | # log information about the hardware/software environment this is running on 410 | # and print the full `nvidia-smi` to file 411 | print0(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:") 412 | import subprocess 413 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 414 | print0(f'{result.stdout}', logonly=True) 415 | print0('='*100, logonly=True) 416 | 417 | # convenience variables 418 | T = args.sequence_length 419 | # calculate the number of steps to take in the val loop. 420 | assert args.val_tokens % (T * ddp_world_size) == 0 421 | val_steps = args.val_tokens // (T * ddp_world_size) 422 | # calculate the steps of gradient accumulation required to attain the desired global batch size. 423 | assert args.batch_size % (ddp_world_size) == 0 424 | train_accumulation_steps = args.batch_size // ddp_world_size 425 | 426 | # load tokens 427 | train_loader = DistributedDataLoader(args.input_bin, T, ddp_rank, ddp_world_size) 428 | val_loader = DistributedDataLoader(args.input_val_bin, T, ddp_rank, ddp_world_size) 429 | print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files") 430 | print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files") 431 | print0('='*100, logonly=True) 432 | x, y = train_loader.next_batch() 433 | 434 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977. 435 | # this originates from Karpathy's experiments. 
436 | num_vocab = 50304 437 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768)) 438 | model = model.cuda().bfloat16() 439 | for m in model.modules(): 440 | if isinstance(m, CastedLinear): 441 | m.float() 442 | if hasattr(config, "coordinate_descent_tuning"): 443 | config.coordinate_descent_tuning = True # suggested by @Chillee 444 | model = torch.compile(model) 445 | # here we wrap model into DDP container 446 | model = DDP(model, device_ids=[ddp_local_rank]) 447 | raw_model = model.module # always contains the "raw" unwrapped model 448 | 449 | # init the optimizer(s) 450 | optimizer1 = torch.optim.Adam([raw_model.transformer.wte.weight, raw_model.transformer.vte.weight], lr=0.6, betas=(0.8, 0.95), fused=True) 451 | optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.008, betas=(0.8, 0.95), fused=True) 452 | params = list(raw_model.transformer.h.parameters()) 453 | matrix_params = [p for p in params if p.ndim == 2] 454 | scalar_params = [p for p in params if p.ndim < 2] + [raw_model.skip_weights] 455 | optimizer3 = Muon(matrix_params, lr=0.05, momentum=0.95) 456 | optimizer4 = torch.optim.Adam(scalar_params, lr=0.04, betas=(0.8, 0.95), fused=True) # note that this learning rate is neither sensitive nor tuned 457 | optimizers = [optimizer1, optimizer2, optimizer3, optimizer4] 458 | # learning rate decay scheduler (linear warmup and cooldown) 459 | def get_lr(it): 460 | assert it <= args.num_iterations 461 | # 1) linear warmup for warmup_iters steps 462 | if it < args.warmup_iters: 463 | return (it+1) / args.warmup_iters 464 | # 2) constant lr for a while 465 | elif it < args.num_iterations - args.cooldown_iters: 466 | return 1.0 467 | # 3) linear cooldown 468 | else: 469 | decay_ratio = (args.num_iterations - it) / args.cooldown_iters 470 | return decay_ratio 471 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] 472 | 473 | # Start training loop 474 | training_time_ms = 0 475 | # start the clock 476 | torch.cuda.synchronize() 477 | t0 = time.time() 478 | # begin training 479 | for step in range(args.num_iterations + 1): 480 | last_step = (step == args.num_iterations) 481 | # This effectively ignores timing first 10 steps, which are slower for weird reasons. 482 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10 483 | # steps with dummy data first, and then re-initialize the model and reset the loader. 484 | if step == 10: 485 | training_time_ms = 0 486 | t0 = time.time() 487 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val 488 | 489 | # Set the attention blocksize for the current step, in chunks of 64. 
By @fernbear.bsky.social 490 | attn_blocksize = torch.tensor(64*((step/args.num_iterations * (1792 - 64) + 64)//64), dtype=torch.int, device='cuda') 491 | 492 | # once in a while evaluate the validation dataset 493 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)): 494 | # stop the clock 495 | torch.cuda.synchronize() 496 | training_time_ms += 1000 * (time.time() - t0) 497 | # run validation batches 498 | model.eval() 499 | val_loader.reset() 500 | val_loss = 0.0 501 | for _ in range(val_steps): 502 | with torch.no_grad(): 503 | x_val, y_val = val_loader.next_batch() 504 | val_loss += model(x_val, y_val, attn_blocksize=attn_blocksize) 505 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) 506 | val_loss /= val_steps 507 | # log val loss to console and to logfile 508 | print0(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms') 509 | # start the clock again 510 | torch.cuda.synchronize() 511 | t0 = time.time() 512 | 513 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)): 514 | # stop the clock 515 | torch.cuda.synchronize() 516 | training_time_ms += 1000 * (time.time() - t0) 517 | # save the state of the training process 518 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) 519 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step)) 520 | # start the clock again 521 | torch.cuda.synchronize() 522 | t0 = time.time() 523 | 524 | # bit confusing: we want to make sure to eval on 0th iteration 525 | # but also after the very last iteration. so we loop for step <= num_iterations 526 | # instead of just < num_iterations (one extra due to <=), only to do 527 | # the validation/sampling one last time, and then we break right here as we're done. 528 | if last_step: 529 | break 530 | 531 | # --------------- TRAINING SECTION BEGIN ----------------- 532 | model.train() 533 | for i in range(1, train_accumulation_steps+1): 534 | ctx = model.no_sync() if i < train_accumulation_steps else contextlib.nullcontext() 535 | with ctx: # there's no need to sync gradients every accumulation step 536 | # forward pass 537 | loss = model(x, y, attn_blocksize=attn_blocksize) 538 | # advance the dataset for the next batch 539 | x, y = train_loader.next_batch() 540 | # backward pass 541 | loss.backward() 542 | train_loss = loss.detach() 543 | for p in model.parameters(): 544 | p.grad /= train_accumulation_steps 545 | # momentum warmup for Muon 546 | frac = min(step/300, 1) 547 | optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95 548 | # step the optimizers and schedulers 549 | for opt, sched in zip(optimizers, schedulers): 550 | opt.step() 551 | sched.step() 552 | # null the gradients 553 | model.zero_grad(set_to_none=True) 554 | # --------------- TRAINING SECTION END ------------------- 555 | # everything that follows now is just diagnostics, prints, logging, etc. 
556 | 557 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower 558 | approx_time = training_time_ms + 1000 * (time.time() - t0) 559 | print0(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms") 560 | 561 | if master_process: 562 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") 563 | 564 | # ------------------------------------------------------------------------- 565 | # clean up nice 566 | dist.destroy_process_group() 567 | -------------------------------------------------------------------------------- /records/121724_SparsifyEmbeds/README.md: -------------------------------------------------------------------------------- 1 | # Sparsify embeds 2 | 3 | Running this new record by Jiacheng You 1261 times yielded the following series of val losses: 4 | ``` 5 | accs = [3.2772, 3.2794, 3.2786, 3.2807, 3.2816, 3.2827, 3.2805, 3.2787, 3.2757, 3.2798, 3.2795, 3.2785, 3.2785, 3.2803, 3.2804, 3.2816, 3.2814, 3.2808, 3.2776, 3.2809, 3.2792, 3.2776, 3.2784, 3.2809, 3.2804, 3.2779, 3.2799, 3.2761, 3.2763, 3.2794, 3.2754, 3.2816, 3.2818, 3.2785, 3.2837, 3.2765, 3.2805, 3.2784, 3.2783, 3.28, 3.2791, 3.2777, 3.2815, 3.2789, 3.2796, 3.2804, 3.2793, 3.2817, 3.2799, 3.2803, 3.2773, 3.283, 3.2781, 3.2785, 3.2771, 3.2824, 3.2819, 3.2791, 3.2799, 3.2792, 3.2771, 3.2799, 3.2782, 3.2811, 3.2786, 3.2774, 3.2786, 3.2807, 3.2775, 3.2778, 3.2778, 3.2801, 3.2764, 3.2774, 3.2801, 3.28, 3.2785, 3.2813, 3.2799, 3.2787, 3.2802, 3.2776, 3.2818, 3.2783, 3.2774, 3.2779, 3.279, 3.2777, 3.2814, 3.2783, 3.2796, 3.2822, 3.2785, 3.2784, 3.283, 3.2799, 3.2786, 3.2799, 3.2797, 3.2791, 3.2761, 3.278, 3.2775, 3.28, 3.2804, 3.2781, 3.2778, 3.2806, 3.2767, 3.2787, 3.2769, 3.2794, 3.2856, 3.2764, 3.278, 3.2814, 3.2803, 3.2781, 3.28, 3.2787, 3.2787, 3.2797, 3.2799, 3.2815, 3.2777, 3.2792, 3.2799, 3.2789, 3.2772, 3.2762, 3.2792, 3.2769, 3.28, 3.2786, 3.2751, 3.2818, 3.2791, 3.2776, 3.2778, 3.2796, 3.2793, 3.2785, 3.2826, 3.281, 3.2781, 3.2796, 3.2783, 3.2788, 3.2811, 3.2818, 3.2803, 3.2794, 3.2757, 3.2793, 3.277, 3.2765, 3.2785, 3.2788, 3.2796, 3.2773, 3.2778, 3.2802, 3.2837, 3.2753, 3.2831, 3.276, 3.2773, 3.2762, 3.2789, 3.2769, 3.2805, 3.2816, 3.2761, 3.2788, 3.2787, 3.2785, 3.2818, 3.2787, 3.2838, 3.279, 3.2805, 3.2807, 3.2804, 3.2797, 3.2752, 3.2838, 3.2834, 3.2792, 3.2804, 3.2793, 3.282, 3.2829, 3.2796, 3.2789, 3.279, 3.2778, 3.2787, 3.279, 3.279, 3.2789, 3.2756, 3.281, 3.28, 3.2804, 3.2796, 3.2803, 3.2795, 3.2781, 3.2783, 3.2772, 3.2807, 3.279, 3.2787, 3.2777, 3.2781, 3.2818, 3.2748, 3.2786, 3.2758, 3.2762, 3.2801, 3.2817, 3.2807, 3.2804, 3.2772, 3.281, 3.2766, 3.278, 3.2753, 3.2803, 3.2787, 3.2799, 3.2797, 3.2794, 3.2823, 3.2769, 3.2789, 3.2769, 3.277, 3.2806, 3.2799, 3.2787, 3.2786, 3.28, 3.28, 3.2813, 3.279, 3.2795, 3.2792, 3.2807, 3.2806, 3.2779, 3.2783, 3.2796, 3.2778, 3.2808, 3.2778, 3.2785, 3.2781, 3.2808, 3.2802, 3.2819, 3.2794, 3.2784, 3.2819, 3.2824, 3.2814, 3.2791, 3.2779, 3.2788, 3.2788, 3.2796, 3.2798, 3.2782, 3.2782, 3.2768, 3.2785, 3.2788, 3.2812, 3.2811, 3.2793, 3.2812, 3.2824, 3.2786, 3.2787, 3.2806, 3.2807, 3.2771, 3.2825, 3.2791, 3.2761, 3.2831, 3.2803, 3.2807, 3.2793, 3.2795, 3.2825, 3.276, 3.279, 3.2817, 3.2808, 3.279, 3.2793, 3.282, 3.2835, 3.2789, 3.2791, 3.2792, 3.2797, 3.281, 3.2795, 3.2775, 3.2772, 3.2818, 3.2787, 3.2775, 3.2814, 3.2787, 3.2818, 3.2772, 3.2796, 3.2787, 3.2815, 3.2795, 
3.2799, 3.2785, 3.2772, 3.2788, 3.279, 3.2776, 3.2819, 3.2783, 3.2751, 3.2763, 3.2771, 3.2797, 3.2783, 3.2823, 3.2798, 3.277, 3.2813, 3.2774, 3.2801, 3.2821, 3.2806, 3.2833, 3.281, 3.2819, 3.2794, 3.2815, 3.279, 3.2837, 3.2779, 3.28, 3.2803, 3.2784, 3.2786, 3.2782, 3.2782, 3.2791, 3.279, 3.2806, 3.2801, 3.2807, 3.2797, 3.2767, 3.2796, 3.2798, 3.2816, 3.2766, 3.2823, 3.2772, 3.2765, 3.2784, 3.2775, 3.2779, 3.284, 3.2778, 3.2806, 3.2806, 3.281, 3.2787, 3.2823, 3.2771, 3.2768, 3.2782, 3.2822, 3.2785, 3.279, 3.2811, 3.2785, 3.2781, 3.2802, 3.2793, 3.2794, 3.2811, 3.2837, 3.2785, 3.2809, 3.283, 3.2813, 3.2805, 3.2769, 3.2806, 3.276, 3.2814, 3.28, 3.277, 3.2791, 3.2775, 3.279, 3.2802, 3.2809, 3.2815, 3.2763, 3.2862, 3.2791, 3.2791, 3.2763, 3.2789, 3.2792, 3.2816, 3.2792, 3.2775, 3.2803, 3.2809, 3.2828, 3.2805, 3.2794, 3.2801, 3.281, 3.2772, 3.2806, 3.2789, 3.2827, 3.2796, 3.2846, 3.2812, 3.2791, 3.2765, 3.2784, 3.2777, 3.2773, 3.2778, 3.2768, 3.2783, 3.2793, 3.2778, 3.2776, 3.2742, 3.2769, 3.2774, 3.2813, 3.2801, 3.2807, 3.2777, 3.2821, 3.2794, 3.2791, 3.279, 3.2763, 3.2804, 3.2803, 3.2795, 3.2805, 3.2815, 3.2801, 3.2823, 3.2798, 3.2802, 3.2784, 3.2801, 3.2792, 3.2856, 3.2805, 3.2782, 3.2808, 3.2793, 3.2804, 3.2775, 3.2798, 3.28, 3.2795, 3.2789, 3.2795, 3.2771, 3.2792, 3.2802, 3.2815, 3.2806, 3.2813, 3.2809, 3.2829, 3.2778, 3.282, 3.2825, 3.2789, 3.282, 3.2785, 3.2782, 3.2791, 3.2803, 3.276, 3.2777, 3.279, 3.278, 3.2781, 3.2815, 3.2776, 3.2823, 3.275, 3.2794, 3.2787, 3.2775, 3.2792, 3.2794, 3.2774, 3.2806, 3.2785, 3.28, 3.2758, 3.2795, 3.2778, 3.2796, 3.2817, 3.2785, 3.2815, 3.2797, 3.2749, 3.2785, 3.2804, 3.277, 3.2791, 3.2818, 3.2826, 3.2784, 3.2768, 3.2801, 3.2833, 3.2817, 3.2796, 3.2783, 3.2781, 3.2757, 3.2787, 3.2803, 3.2786, 3.2757, 3.2774, 3.2813, 3.2777, 3.2821, 3.2791, 3.2774, 3.2786, 3.2808, 3.2791, 3.279, 3.2813, 3.2818, 3.2771, 3.2861, 3.2805, 3.2789, 3.2769, 3.2809, 3.2823, 3.2854, 3.2819, 3.2789, 3.2796, 3.2815, 3.2781, 3.2819, 3.2802, 3.2788, 3.2767, 3.277, 3.2798, 3.2806, 3.2787, 3.2778, 3.2786, 3.2805, 3.2799, 3.2776, 3.2775, 3.2815, 3.2791, 3.28, 3.2789, 3.28, 3.2807, 3.2793, 3.2783, 3.2771, 3.2801, 3.2796, 3.2789, 3.2772, 3.2783, 3.2812, 3.2803, 3.2756, 3.2775, 3.2807, 3.2801, 3.2787, 3.28, 3.2827, 3.2801, 3.2798, 3.2785, 3.2798, 3.2787, 3.2785, 3.2796, 3.2762, 3.2808, 3.2788, 3.2775, 3.2765, 3.2792, 3.2784, 3.2787, 3.2793, 3.2793, 3.2784, 3.2773, 3.2812, 3.2785, 3.2759, 3.2781, 3.2786, 3.2783, 3.2804, 3.2791, 3.2791, 3.2772, 3.2803, 3.2773, 3.2778, 3.2809, 3.2815, 3.2784, 3.278, 3.2783, 3.2818, 3.2805, 3.2802, 3.2828, 3.2767, 3.2811, 3.2786, 3.2798, 3.2796, 3.2777, 3.2793, 3.277, 3.2762, 3.2773, 3.2796, 3.2786, 3.2809, 3.2797, 3.2796, 3.2815, 3.2803, 3.2833, 3.2793, 3.2773, 3.2761, 3.2832, 3.2798, 3.2801, 3.2806, 3.2803, 3.2797, 3.276, 3.2798, 3.2797, 3.2788, 3.2824, 3.2785, 3.2802, 3.2817, 3.2766, 3.2815, 3.2797, 3.279, 3.2808, 3.2776, 3.2789, 3.2783, 3.2772, 3.2803, 3.282, 3.2773, 3.2803, 3.28, 3.2772, 3.2827, 3.2804, 3.2776, 3.2794, 3.2815, 3.2836, 3.2813, 3.2794, 3.2795, 3.279, 3.2772, 3.2787, 3.2813, 3.2778, 3.2798, 3.2819, 3.2788, 3.2838, 3.2792, 3.2772, 3.2799, 3.2837, 3.2801, 3.2806, 3.2799, 3.2793, 3.2788, 3.2786, 3.2766, 3.2782, 3.281, 3.2783, 3.2789, 3.2801, 3.2759, 3.281, 3.2762, 3.2795, 3.2799, 3.2835, 3.2772, 3.2794, 3.2803, 3.2782, 3.2804, 3.2782, 3.28, 3.2766, 3.2823, 3.2771, 3.2775, 3.2811, 3.2789, 3.2808, 3.2787, 3.2805, 3.2812, 3.281, 3.2809, 3.2795, 3.2801, 3.2817, 3.2789, 3.2808, 3.2779, 3.2758, 3.2779, 3.276, 3.2779, 3.2823, 3.2818, 3.2816, 3.2806, 
3.2807, 3.2788, 3.2778, 3.2821, 3.2777, 3.2779, 3.2775, 3.2785, 3.2794, 3.2813, 3.2825, 3.2812, 3.2801, 3.2782, 3.2807, 3.2797, 3.2781, 3.2778, 3.2778, 3.2803, 3.2832, 3.2819, 3.2783, 3.279, 3.2785, 3.279, 3.2786, 3.2793, 3.2798, 3.282, 3.2794, 3.2818, 3.2796, 3.2795, 3.2796, 3.2779, 3.2788, 3.2776, 3.2787, 3.2766, 3.279, 3.2764, 3.2831, 3.2819, 3.2791, 3.2784, 3.2793, 3.2824, 3.28, 3.2812, 3.2773, 3.2777, 3.283, 3.2774, 3.278, 3.2801, 3.2817, 3.2791, 3.2811, 3.281, 3.2817, 3.2803, 3.2791, 3.2816, 3.2785, 3.2797, 3.2805, 3.2809, 3.2825, 3.2799, 3.2777, 3.2803, 3.2787, 3.2783, 3.2784, 3.2781, 3.2754, 3.2801, 3.2782, 3.2792, 3.2776, 3.2786, 3.283, 3.28, 3.2771, 3.2808, 3.2774, 3.2787, 3.2788, 3.2787, 3.278, 3.2805, 3.2783, 3.2814, 3.2785, 3.2794, 3.2825, 3.2767, 3.2781, 3.2812, 3.2792, 3.2807, 3.2785, 3.2833, 3.2763, 3.2834, 3.2798, 3.278, 3.2783, 3.2781, 3.28, 3.2793, 3.2768, 3.2786, 3.2797, 3.2775, 3.2786, 3.2776, 3.2792, 3.277, 3.2784, 3.2804, 3.2795, 3.2802, 3.2761, 3.2778, 3.2796, 3.278, 3.2776, 3.2781, 3.2791, 3.2768, 3.2797, 3.2835, 3.2774, 3.2783, 3.2814, 3.2799, 3.2799, 3.2812, 3.2787, 3.2815, 3.2786, 3.2771, 3.2796, 3.2812, 3.276, 3.2814, 3.2775, 3.28, 3.2824, 3.2806, 3.2806, 3.2786, 3.2839, 3.2794, 3.2787, 3.2809, 3.2838, 3.2794, 3.2812, 3.282, 3.2783, 3.2813, 3.2803, 3.2784, 3.2826, 3.279, 3.2784, 3.2835, 3.2811, 3.2843, 3.2805, 3.282, 3.2805, 3.2798, 3.2786, 3.2785, 3.2804, 3.276, 3.2764, 3.2774, 3.2783, 3.279, 3.2815, 3.2803, 3.2768, 3.2796, 3.2801, 3.2808, 3.2778, 3.2798, 3.2804, 3.2814, 3.2782, 3.2801, 3.2787, 3.2795, 3.2792, 3.2791, 3.2776, 3.2774, 3.2776, 3.2796, 3.2795, 3.2766, 3.2774, 3.2773, 3.2804, 3.2785, 3.2802, 3.2805, 3.2802, 3.2793, 3.281, 3.2763, 3.2834, 3.2803, 3.2781, 3.2768, 3.2771, 3.278, 3.2779, 3.2815, 3.2812, 3.2807, 3.2819, 3.2824, 3.2812, 3.2806, 3.2782, 3.2797, 3.2782, 3.2793, 3.2755, 3.2808, 3.2816, 3.2796, 3.2817, 3.2779, 3.2774, 3.2774, 3.2774, 3.2794, 3.278, 3.2836, 3.2828, 3.279, 3.2805, 3.2824, 3.2776, 3.2795, 3.2807, 3.2783, 3.2809, 3.2789, 3.2771, 3.2792, 3.2775, 3.2809, 3.2813, 3.2797, 3.2788, 3.2792, 3.2763, 3.282, 3.2762, 3.2787, 3.282, 3.2791, 3.2781, 3.2778, 3.279, 3.2788, 3.2791, 3.2829, 3.2769, 3.28, 3.2768, 3.277, 3.2774, 3.2841, 3.2777, 3.2749, 3.2785, 3.2805, 3.2814, 3.2768, 3.2767, 3.2803, 3.2785, 3.2808, 3.2811, 3.2805, 3.2794, 3.2772, 3.2791, 3.2809, 3.2807, 3.2815, 3.2793, 3.2784, 3.2794, 3.2819, 3.2812, 3.2799, 3.2805, 3.2782, 3.2805, 3.2793, 3.2788, 3.2783, 3.2804, 3.2795, 3.2785, 3.2808, 3.2823, 3.2787, 3.2787, 3.278, 3.2791, 3.2805, 3.2808, 3.2787, 3.2779, 3.2781, 3.2787, 3.2779, 3.2775, 3.2789, 3.2784, 3.2803, 3.2774, 3.2798, 3.2772, 3.28, 3.2816, 3.2792, 3.278, 3.2792, 3.2787, 3.2813, 3.2799, 3.2802, 3.281, 3.2768, 3.2811, 3.2772, 3.2802, 3.2822, 3.2789, 3.2762, 3.2775, 3.2799, 3.2792, 3.2795, 3.2792, 3.2793, 3.2817, 3.2784, 3.28, 3.2792, 3.2788, 3.2815, 3.2782, 3.2826, 3.28, 3.2782, 3.2792, 3.2757, 3.2766, 3.2788, 3.2778, 3.2788, 3.2797, 3.2797, 3.2777, 3.2783, 3.2778, 3.2799, 3.2812, 3.2813, 3.2802, 3.2818, 3.2801, 3.277, 3.2839, 3.2806, 3.2777, 3.2805, 3.278, 3.279, 3.2775, 3.28, 3.2774, 3.2789, 3.277, 3.2807, 3.2805, 3.2795, 3.2777, 3.2813, 3.2805, 3.2809, 3.2814, 3.2794, 3.2797, 3.2803, 3.2802, 3.2808, 3.278, 3.275, 3.283, 3.2791, 3.2761, 3.2787, 3.2797, 3.2781, 3.2754, 3.2775, 3.2797, 3.281, 3.2792, 3.2797, 3.2812, 3.2781, 3.2782, 3.2803, 3.2778, 3.2812, 3.2809, 3.2781, 3.2769, 3.2797, 3.2774, 3.2787, 3.2805, 3.2796, 3.2771, 3.2776, 3.2784, 3.2757, 3.2784, 3.2795, 3.2802, 3.2788, 3.2787, 3.2778, 3.2832, 3.2784, 3.2802, 
3.2805, 3.2787, 3.2786, 3.2801, 3.2811, 3.2809, 3.2793, 3.2809, 3.2786, 3.2777, 3.2801, 3.2787, 3.2838, 3.2751, 3.2777, 3.2806, 3.2786, 3.2772, 3.2797, 3.278, 3.281, 3.2808, 3.2769, 3.2774, 3.2779, 3.2836, 3.2792, 3.2786, 3.2834, 3.2786, 3.2781, 3.2809, 3.2788, 3.2799, 3.28, 3.2807] 6 | 7 | import scipy.stats 8 | print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue) 9 | # p=0.0000 10 | 11 | import torch 12 | print(torch.std_mean(torch.tensor(accs))) 13 | # (tensor(0.0018), tensor(3.2794)) 14 | ``` 15 | 16 | ![](loss_hist_121724.png) 17 | 18 | -------------------------------------------------------------------------------- /records/121724_SparsifyEmbeds/loss_hist_121724.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KellerJordan/modded-nanogpt/ca964e982191830eebbd155e185937077511a8aa/records/121724_SparsifyEmbeds/loss_hist_121724.png -------------------------------------------------------------------------------- /records/123124_Target350M/README.md: -------------------------------------------------------------------------------- 1 | The `train_gpt.py` in this folder attains ~2.95 loss in ~26min on 8xH100 2 | 3 | It can therefore form the first record for the task of matching the performance attained by Karpathy using llm.c to train a [350M parameter model on 30B tokens](https://github.com/karpathy/llm.c/discussions/481). 4 | 5 | I'd be happy to see / help boost competition on this task as well 6 | 7 | -------------------------------------------------------------------------------- /records/123124_Target350M/train_gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | with open(sys.argv[0]) as f: 4 | code = f.read() # read the code of this file ASAP, for logging 5 | import uuid 6 | import time 7 | import glob 8 | import contextlib 9 | from dataclasses import dataclass 10 | 11 | import torch 12 | torch.empty(1, device='cuda', requires_grad=True).backward() 13 | from torch import nn 14 | import torch.nn.functional as F 15 | import torch.distributed as dist 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | from torch.nn.attention.flex_attention import BlockMask, flex_attention #KoszarskyB 18 | 19 | # ----------------------------------------------------------------------------- 20 | # Muon optimizer 21 | 22 | @torch.compile 23 | def zeropower_via_newtonschulz5(G, steps): 24 | """ 25 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 26 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 27 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 28 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 29 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 30 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 31 | performance at all relative to UV^T, where USV^T = G is the SVD. 
32 | """ 33 | assert len(G.shape) == 2 34 | a, b, c = (3.4445, -4.7750, 2.0315) 35 | X = G.bfloat16() 36 | if G.size(0) > G.size(1): 37 | X = X.T 38 | 39 | # Ensure spectral norm is at most 1 40 | X = X / (X.norm() + 1e-7) 41 | # Perform the NS iterations 42 | for _ in range(steps): 43 | A = X @ X.T 44 | B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 45 | X = a * X + B @ X 46 | 47 | if G.size(0) > G.size(1): 48 | X = X.T 49 | return X 50 | 51 | class Muon(torch.optim.Optimizer): 52 | """ 53 | Muon - MomentUm Orthogonalized by Newton-schulz 54 | 55 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 56 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 57 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 58 | the advantage that it can be stably run in bfloat16 on the GPU. 59 | 60 | Some warnings: 61 | - This optimizer assumes that all parameters passed in are 2D. 62 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D 63 | parameters; those should all be optimized by a standard method (e.g., AdamW). 64 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 65 | - We believe it is unlikely to work well for training with small batch size. 66 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 67 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). 68 | 69 | Arguments: 70 | lr: The learning rate used by the internal SGD. 71 | momentum: The momentum used by the internal SGD. 72 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 73 | ns_steps: The number of Newton-Schulz iteration steps to use. 
74 | """ 75 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5): 76 | self.world_size = int(os.environ['WORLD_SIZE']) 77 | self.rank = int(os.environ['RANK']) 78 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) 79 | assert all(isinstance(p, torch.Tensor) for p in params) 80 | sizes = {p.numel() for p in params} 81 | param_groups = [dict(params=[p for p in params if p.numel() == size], 82 | update_buffer=[torch.empty(size, device='cuda', dtype=torch.bfloat16) for _ in range(self.world_size)]) 83 | for size in sizes] 84 | super().__init__(param_groups, defaults) 85 | 86 | def step(self): 87 | 88 | for group in self.param_groups: 89 | 90 | lr = group['lr'] 91 | momentum = group['momentum'] 92 | nesterov = group['nesterov'] 93 | ns_steps = group['ns_steps'] 94 | update_buffers = group['update_buffer'] 95 | # generate weight updates in distributed fashion 96 | params = group['params'] 97 | handle = None 98 | params_world = None 99 | def update_prev(): 100 | if params_world is None: 101 | return 102 | assert handle is not None 103 | handle.wait() 104 | for p_world, g_world in zip(params_world, update_buffers): 105 | p_world.data.add_( 106 | g_world.view_as(p_world), 107 | alpha=-lr * max(1, p_world.size(0) / p_world.size(1)) ** 0.5, 108 | ) 109 | for base_i in range(len(params))[::self.world_size]: 110 | if base_i + rank < len(params): 111 | p = params[base_i + self.rank] 112 | g = p.grad 113 | assert g is not None 114 | state = self.state[p] 115 | if 'momentum_buffer' not in state: 116 | state['momentum_buffer'] = torch.zeros_like(g) 117 | buf = state['momentum_buffer'] 118 | buf.lerp_(g, 1 - momentum) 119 | g = g.lerp_(buf, momentum) if nesterov else buf 120 | g = zeropower_via_newtonschulz5(g, steps=ns_steps).flatten() 121 | else: 122 | g = update_buffers[rank] 123 | update_prev() 124 | handle = dist.all_gather(update_buffers, g, async_op=True) 125 | params_world = params[base_i : base_i + self.world_size] 126 | update_prev() 127 | 128 | # ----------------------------------------------------------------------------- 129 | # PyTorch nn.Module definitions for the GPT-2 model 130 | 131 | def norm(x): 132 | return F.rms_norm(x, (x.size(-1),)) 133 | 134 | class CastedLinear(nn.Linear): 135 | 136 | def __init__(self, in_features, out_features): 137 | super().__init__(in_features, out_features, bias=False) 138 | 139 | def forward(self, x): 140 | return F.linear(x, self.weight.to(x.dtype)) 141 | 142 | class Rotary(nn.Module): 143 | 144 | def __init__(self, dim, max_seq_len=65536): 145 | super().__init__() 146 | inv_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) 147 | inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(dim//4)]) 148 | t = torch.arange(max_seq_len, dtype=torch.float32) 149 | theta = torch.einsum('i,j -> ij', t, inv_freq) 150 | self.cos = nn.Buffer(theta.cos(), persistent=False) 151 | self.sin = nn.Buffer(theta.sin(), persistent=False) 152 | 153 | def forward(self, x): 154 | cos, sin = self.cos[None, :x.size(-3), None, :], self.sin[None, :x.size(-3), None, :] 155 | x1, x2 = x.float().chunk(2, dim=-1) 156 | y1 = x1 * cos + x2 * sin 157 | y2 = x1 * (-sin) + x2 * cos 158 | return torch.cat((y1, y2), 3).type_as(x) 159 | 160 | class CausalSelfAttention(nn.Module): 161 | 162 | def __init__(self, dim, num_heads): 163 | super().__init__() 164 | assert dim % num_heads == 0 165 | self.num_heads = num_heads 166 | self.c_q = CastedLinear(dim, dim) 167 | self.c_k = CastedLinear(dim, dim) 168 | self.c_v = 
CastedLinear(dim, dim) 169 | self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5])) 170 | self.rotary = Rotary(dim // num_heads) # dim // num_heads = head_dim 171 | self.c_proj = CastedLinear(dim, dim) 172 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 173 | 174 | def forward(self, x, ve, block_mask): 175 | B, T = x.size(0), x.size(1) # batch size, sequence length 176 | assert B == 1, 'Must use batch size = 1 for FlexAttention' 177 | q = self.c_q(x).view(B, T, self.num_heads, -1) 178 | k = self.c_k(x).view(B, T, self.num_heads, -1) 179 | v = self.c_v(x).view(B, T, self.num_heads, -1) 180 | if ve is not None: 181 | v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 182 | else: 183 | v = self.lambdas[0] * v 184 | q, k = norm(q), norm(k) # QK norm @Grad62304977 185 | q, k = self.rotary(q), self.rotary(k) 186 | y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, enable_gqa=True) 187 | y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side 188 | y = self.c_proj(y) 189 | return y 190 | 191 | class MLP(nn.Module): 192 | 193 | def __init__(self, dim): 194 | super().__init__() 195 | self.c_fc = CastedLinear(dim, 4 * dim) 196 | self.c_proj = CastedLinear(4 * dim, dim) 197 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 198 | 199 | def forward(self, x): 200 | x = self.c_fc(x) 201 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 202 | x = self.c_proj(x) 203 | return x 204 | 205 | class Block(nn.Module): 206 | 207 | def __init__(self, model_dim, num_heads, use_attn=True): 208 | super().__init__() 209 | self.attn = CausalSelfAttention(model_dim, num_heads) if use_attn else None 210 | self.mlp = MLP(model_dim) 211 | self.lambdas = nn.Parameter(torch.tensor([1., 0.])) 212 | 213 | def forward(self, x, ve, x0, block_mask): 214 | x = self.lambdas[0] * x + self.lambdas[1] * x0 215 | if self.attn is not None: 216 | x = x + self.attn(norm(x), ve, block_mask) 217 | x = x + self.mlp(norm(x)) 218 | return x 219 | 220 | class ValueEmbedding(nn.Module): 221 | def __init__(self, vocab_size, model_dim): 222 | super().__init__() 223 | self.embed = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) 224 | 225 | def forward(self, inputs): 226 | ve = [emb(inputs) for emb in self.embed] 227 | ve = [ve[0], ve[1], ve[2], None, None, None, None, None, None, None, None, None, None, ve[0], ve[1], ve[2]] 228 | return ve 229 | 230 | # ----------------------------------------------------------------------------- 231 | # The main GPT-2 model 232 | 233 | class GPT(nn.Module): 234 | 235 | def __init__(self, vocab_size, num_layers, num_heads, model_dim): 236 | super().__init__() 237 | self.embed = nn.Embedding(vocab_size, model_dim) 238 | self.blocks = nn.ModuleList([Block(model_dim, num_heads, use_attn=(i != 7)) 239 | for i in range(num_layers)]) 240 | # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning 241 | # U-net structure on token value embeddings by @leloykun 242 | self.value_embeds = ValueEmbedding(vocab_size, model_dim) 243 | self.lm_head = CastedLinear(model_dim, vocab_size) 244 | self.lm_head.weight.data.zero_() # @Grad62304977 245 | # U-net design by @brendanh0gan 246 | self.num_encoder_layers = num_layers // 2 # Half of the layers for encoder 247 | self.num_decoder_layers = num_layers - self.num_encoder_layers # 
Remaining for decoder 248 | # Add learnable skip connection weights for decoder layers 249 | self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers)) 250 | 251 | def forward(self, inputs, targets, sliding_window_num_blocks): 252 | BLOCK_SIZE = 128 253 | seq_len = len(inputs) 254 | assert seq_len % BLOCK_SIZE == 0 255 | total_num_blocks = seq_len // BLOCK_SIZE 256 | assert inputs.ndim == 1 257 | docs = (inputs == 50256).cumsum(0) 258 | docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous() 259 | docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous() 260 | 261 | def document_causal(b, h, q_idx, kv_idx): 262 | causal_mask = q_idx >= kv_idx 263 | document_mask = docs[q_idx] == docs[kv_idx] 264 | return causal_mask & document_mask 265 | 266 | def dense_to_ordered(dense_mask): 267 | num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32) 268 | indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(torch.int32) 269 | return num_blocks[None, None].contiguous(), indices[None, None].contiguous() 270 | 271 | def create_doc_swc_block_mask(sliding_window_num_blocks): 272 | kv_idx = block_idx = torch.arange(total_num_blocks, dtype=torch.int32, device='cuda') 273 | q_idx = block_idx[:, None] 274 | causal_bm = q_idx >= kv_idx 275 | causal_full_bm = q_idx > kv_idx 276 | window_bm = q_idx - kv_idx < sliding_window_num_blocks 277 | window_full_bm = window_bm 278 | # document_bm = (docs_low[q_idx] <= docs_high[kv_idx]) & (docs_low[kv_idx] <= docs_high[q_idx]) 279 | document_bm = (docs_low[:, None] <= docs_high) & (docs_low <= docs_high[:, None]) 280 | document_full_bm = (docs_low[:, None] == docs_high) & (docs_low == docs_high[:, None]) 281 | nonzero_bm = causal_bm & window_bm & document_bm 282 | full_bm = causal_full_bm & window_full_bm & document_full_bm 283 | kv_num_blocks, kv_indices = dense_to_ordered(nonzero_bm & ~full_bm) 284 | full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm) 285 | return BlockMask.from_kv_blocks( 286 | kv_num_blocks, 287 | kv_indices, 288 | full_kv_num_blocks, 289 | full_kv_indices, 290 | BLOCK_SIZE=BLOCK_SIZE, 291 | mask_mod=document_causal, 292 | ) 293 | 294 | block_mask = create_doc_swc_block_mask(sliding_window_num_blocks) 295 | 296 | # forward the GPT model itself 297 | x = self.embed(inputs[None]) # token embeddings of shape (b, t, model_dim) 298 | x = norm(x) # @Grad62304977 299 | x0 = x 300 | ve = self.value_embeds(inputs) 301 | ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:] 302 | 303 | # Store outputs for U-Net skip connections 304 | skip_connections = [] 305 | # Encoder pass - process only the first half of the blocks 306 | for i in range(self.num_encoder_layers): 307 | x = self.blocks[i](x, ve_enc[i], x0, block_mask) 308 | skip_connections.append(x) 309 | # Decoder pass - process the remaining blocks with weighted skip connections 310 | for i in range(self.num_decoder_layers): 311 | x = x + self.skip_weights[i] * skip_connections.pop() 312 | # U-net structure on token value embeddings by @leloykun 313 | x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_mask) 314 | 315 | x = norm(x) 316 | logits = self.lm_head(x) 317 | logits = 30 * torch.tanh(logits / 30) # @Grad62304977 318 | logits = logits.float() 319 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 320 | return loss 321 | 322 | # ----------------------------------------------------------------------------- 323 | # Our own simple Distributed Data Loader 324 | 325 | def _load_data_shard(path): 326 | # 
only reads the header, returns header data 327 | # header is 256 int32 328 | header = torch.from_file(path, False, 256, dtype=torch.int32) 329 | assert header[0] == 20240520, 'magic number mismatch in the data .bin file' 330 | assert header[1] == 1, 'unsupported version' 331 | num_tokens = int(header[2]) # number of tokens (claimed) 332 | with open(path, 'rb', buffering=0) as f: 333 | tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) 334 | f.seek(256 * 4) 335 | nbytes = f.readinto(tokens.numpy()) 336 | assert nbytes == 2 * num_tokens, 'number of tokens read does not match header' 337 | return tokens 338 | 339 | class DistributedDataLoader: 340 | 341 | def __init__(self, filename_pattern): 342 | self.rank = int(os.environ['RANK']) 343 | self.world_size = int(os.environ['WORLD_SIZE']) 344 | self.files = sorted(glob.glob(filename_pattern)) 345 | self.reset() 346 | 347 | def reset(self): 348 | self.current_shard = -1 349 | self.advance() 350 | 351 | def advance(self): 352 | self.current_shard = (self.current_shard + 1) % len(self.files) 353 | self.current_position = 0 354 | self.tokens = _load_data_shard(self.files[self.current_shard]) 355 | 356 | def next_batch(self, batch_size): 357 | assert batch_size % self.world_size == 0 358 | device_batch_size = batch_size // self.world_size 359 | # load next shard if necessary 360 | if self.current_position + batch_size + 1 >= len(self.tokens): 361 | self.advance() 362 | pos = self.current_position + self.rank * device_batch_size 363 | device_batch_tokens = self.tokens[pos:pos+device_batch_size+1] 364 | # advance current position 365 | self.current_position += batch_size 366 | inputs = device_batch_tokens[:-1].to(device='cuda', dtype=torch.int32, non_blocking=True) 367 | targets = device_batch_tokens[1:].to(device='cuda', dtype=torch.int64, non_blocking=True) 368 | return inputs, targets 369 | 370 | # ----------------------------------------------------------------------------- 371 | # int main 372 | 373 | @dataclass 374 | class Hyperparameters: 375 | # data 376 | train_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on 377 | val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on 378 | # optimization 379 | batch_size : int = 8*64*1024 # batch size in tokens 380 | device_batch_size : int = 64*1024 # batch size per device in tokens 381 | num_iterations : int = 6000 # number of iterations to run 382 | cooldown_iters : int = 2000 # number of iterations of linear warmup/cooldown for triangular or trapezoidal schedule 383 | bf16_embeds : bool = True 384 | # evaluation and logging 385 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end 386 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons 387 | # implementation 388 | save_checkpoint : bool = False 389 | args = Hyperparameters() 390 | 391 | # set up DDP (distributed data parallel). torchrun sets this env variable 392 | rank = int(os.environ['RANK']) 393 | local_rank = int(os.environ['LOCAL_RANK']) 394 | world_size = int(os.environ['WORLD_SIZE']) 395 | assert torch.cuda.is_available() 396 | torch.cuda.set_device(local_rank) 397 | dist.init_process_group(backend='nccl', device_id=torch.device(local_rank)) 398 | dist.barrier() 399 | master_process = (rank == 0) # this process will do logging, checkpointing etc. 
400 | 401 | assert args.batch_size % args.device_batch_size == 0 402 | assert (args.batch_size // args.device_batch_size) == world_size 403 | 404 | # begin logging 405 | logfile = None 406 | if master_process: 407 | run_id = uuid.uuid4() 408 | os.makedirs('logs', exist_ok=True) 409 | logfile = f'logs/{run_id}.txt' 410 | print(logfile) 411 | 412 | def print0(s, console=False): 413 | if master_process: 414 | with open(logfile, 'a') as f: 415 | if console: 416 | print(s) 417 | print(s, file=f) 418 | 419 | # begin by printing this file (the Python code) 420 | print0(code) 421 | print0('='*100) 422 | # log information about the hardware/software environment this is running on 423 | print0(f'Running Python {sys.version}') 424 | print0(f'Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}') 425 | import subprocess 426 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 427 | print0(f'{result.stdout}') 428 | print0('='*100) 429 | 430 | # load data 431 | train_loader = DistributedDataLoader(args.train_bin) 432 | val_loader = DistributedDataLoader(args.val_bin) 433 | print0(f'Training dataloader files: {train_loader.files}') 434 | print0(f'Validation dataloader files: {val_loader.files}') 435 | print0('='*100) 436 | inputs_train, targets_train = train_loader.next_batch(args.batch_size) 437 | 438 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977. 439 | # this originates from Karpathy's experiments. 440 | model = GPT(vocab_size=50304, num_layers=16, num_heads=8, model_dim=1024) 441 | model = model.cuda() 442 | if args.bf16_embeds: 443 | for m in model.modules(): 444 | if isinstance(m, nn.Embedding): 445 | m.bfloat16() 446 | model = torch.compile(model) 447 | ddp_model = DDP(model, device_ids=[local_rank], broadcast_buffers=False, gradient_as_bucket_view=True) 448 | sliding_window_num_blocks = torch.tensor(1, dtype=torch.int32, device='cuda') 449 | sw_num_blocks_prev = 1 450 | 451 | # collect the parameters to optimize 452 | hidden_matrix_params = [p for p in model.blocks.parameters() if p.ndim == 2] 453 | embed_params = [model.embed.weight, *model.value_embeds.parameters()] 454 | scalar_params = [p for p in model.parameters() if p.ndim < 2] 455 | head_params = [model.lm_head.weight] 456 | 457 | # init the optimizer(s) 458 | optimizer1 = torch.optim.Adam([dict(params=embed_params, lr=0.35), 459 | dict(params=head_params, lr=0.004), 460 | dict(params=scalar_params, lr=0.02)], 461 | betas=(0.8, 0.95), fused=True) 462 | optimizer2 = Muon(hidden_matrix_params, lr=0.03, momentum=0.95) 463 | optimizers = [optimizer1, optimizer2] 464 | 465 | # learning rate decay scheduler (stable then decay) 466 | def get_lr(it): 467 | assert it <= args.num_iterations 468 | # 1) constant lr for first part of training 469 | if it < args.num_iterations - args.cooldown_iters: 470 | return 1.0 471 | # 2) then linear cooldown 472 | else: 473 | return (args.num_iterations - it) / args.cooldown_iters 474 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] 475 | 476 | # Start training loop 477 | training_time_ms = 0 478 | # start the clock 479 | torch.cuda.synchronize() 480 | t0 = time.perf_counter() 481 | # begin training 482 | train_steps = args.num_iterations 483 | for step in range(train_steps + 1): 484 | last_step = (step == train_steps) 485 | # This effectively ignores timing first 10 steps, which are slower for weird reasons. 
486 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10 487 | # steps with dummy data first, and then re-initialize the model and reset the loader. 488 | if step == 10: 489 | training_time_ms = 0 490 | t0 = time.perf_counter() 491 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val 492 | 493 | # Linearly increase the sliding window size over training in chunks of 128 from 128 -> 1856. By @fernbear.bsky.social 494 | frac_done = step / train_steps # training progress 495 | sw_num_blocks = int(((1 - frac_done) * 128 + frac_done * 1856) // 128) 496 | if sw_num_blocks != sw_num_blocks_prev: 497 | sliding_window_num_blocks.copy_(sw_num_blocks, non_blocking=True) 498 | sw_num_blocks_prev = sw_num_blocks 499 | 500 | # --------------- VALIDATION SECTION ----------------- 501 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)): 502 | # stop the clock 503 | torch.cuda.synchronize() 504 | training_time_ms += 1000 * (time.perf_counter() - t0) 505 | # run validation batches 506 | model.eval() 507 | val_loader.reset() 508 | val_loss = 0.0 509 | # calculate the number of steps to take in the val loop. 510 | assert args.val_tokens % args.batch_size == 0 511 | val_steps = args.val_tokens // args.batch_size 512 | for _ in range(val_steps): 513 | with torch.no_grad(): 514 | inputs_val, targets_val = val_loader.next_batch(args.batch_size) 515 | val_loss += ddp_model(inputs_val, targets_val, sliding_window_num_blocks) 516 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) 517 | val_loss /= val_steps 518 | # logging 519 | print0(f'step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms', console=True) 520 | # start the clock again 521 | torch.cuda.synchronize() 522 | t0 = time.perf_counter() 523 | 524 | if last_step: 525 | if master_process and args.save_checkpoint: 526 | log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) 527 | os.makedirs(f'logs/{run_id}', exist_ok=True) 528 | torch.save(log, f'logs/{run_id}/state_step{step:06d}.pt') 529 | # the last step only has the validation loop, so break to avoid training 530 | break 531 | 532 | # --------------- TRAINING SECTION ----------------- 533 | model.train() 534 | ddp_model(inputs_train, targets_train, sliding_window_num_blocks).backward() 535 | inputs_train, targets_train = train_loader.next_batch(args.batch_size) 536 | # momentum warmup for Muon 537 | frac = min(step/300, 1) 538 | for group in optimizer2.param_groups: 539 | group['momentum'] = (1 - frac) * 0.85 + frac * 0.95 540 | # step the optimizers and schedulers 541 | for opt, sched in zip(optimizers, schedulers): 542 | opt.step() 543 | sched.step() 544 | # null the gradients 545 | model.zero_grad(set_to_none=True) 546 | # logging 547 | approx_time = training_time_ms + 1000 * (time.perf_counter() - t0) 548 | print0(f'step:{step+1}/{train_steps} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms', console=True) 549 | 550 | print0(f'peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB') 551 | dist.destroy_process_group() 552 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | torch 4 | huggingface-hub 5 | 
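The requirements above are intentionally unpinned, so exact versions will float. A minimal setup sketch, assuming a CUDA-capable machine and that the FineWeb token shards the training scripts expect (data/fineweb10B/fineweb_train_*.bin and fineweb_val_*.bin) are already present under data/:

```
pip install -r requirements.txt
```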
-------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 train_gpt.py 2 | -------------------------------------------------------------------------------- /train_gpt_medium.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | with open(sys.argv[0]) as f: 4 | code = f.read() # read the code of this file ASAP, for logging 5 | import uuid 6 | import time 7 | import copy 8 | from dataclasses import dataclass 9 | from functools import lru_cache 10 | from pathlib import Path 11 | 12 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 13 | import torch 14 | torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems 15 | from torch import Tensor, nn 16 | import torch.nn.functional as F 17 | import torch.distributed as dist 18 | # use of FlexAttention contributed by @KoszarskyB 19 | from torch.nn.attention.flex_attention import BlockMask, flex_attention 20 | torch._inductor.config.coordinate_descent_tuning = True # we allow this flag for medium track 21 | torch._dynamo.config.compiled_autograd = True 22 | 23 | # ----------------------------------------------------------------------------- 24 | # Muon optimizer 25 | 26 | def zeropower_via_newtonschulz5(G: Tensor) -> Tensor: 27 | """ 28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 32 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 33 | where S' is diagonal with S_{ii}' ∈ [1 - l, 1 + r], which turns out not to hurt model 34 | performance at all relative to UV^T, where USV^T = G is the SVD. 
35 | """ 36 | assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng 37 | X = G.bfloat16() 38 | if G.size(-2) > G.size(-1): 39 | X = X.mT 40 | 41 | # Ensure spectral norm is at most 1 42 | X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) 43 | # Perform the NS iterations 44 | for a, b, c in [ 45 | (4.0848, -6.8946, 2.9270), 46 | (3.9505, -6.3029, 2.6377), 47 | (3.7418, -5.5913, 2.3037), 48 | (2.8769, -3.1427, 1.2046), 49 | (2.8366, -3.0525, 1.2012), 50 | ]: 51 | A = X @ X.mT 52 | B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 53 | X = a * X + B @ X 54 | 55 | if G.size(-2) > G.size(-1): 56 | X = X.mT 57 | return X 58 | 59 | @torch.compile 60 | def update(acc_bf16_view_u16: Tensor, mantissa: Tensor, momentum_buffer: Tensor, grad: Tensor, momentum: Tensor, eff_lr: Tensor, eff_weight_decay: Tensor): 61 | assert acc_bf16_view_u16.dtype == mantissa.dtype == torch.uint16 62 | grad = grad.float() 63 | momentum_buffer.copy_(momentum * momentum_buffer + (1 - momentum) * grad) 64 | v = zeropower_via_newtonschulz5(momentum * momentum_buffer + (1 - momentum) * grad) 65 | 66 | acc_m_u32 = (acc_bf16_view_u16.to(torch.uint32) << 16) | mantissa.to(torch.uint32) 67 | acc_m_u32.view(torch.float32).mul_(1 - eff_weight_decay) 68 | acc_m_u32.view(torch.float32).add_(other=v, alpha=-eff_lr) 69 | acc_bf16_view_u16.copy_((acc_m_u32 >> 16).to(torch.uint16)) 70 | mantissa.copy_(acc_m_u32.to(torch.uint16)) 71 | 72 | class Muon(torch.optim.Optimizer): 73 | """ 74 | Muon - MomentUm Orthogonalized by Newton-schulz 75 | 76 | https://kellerjordan.github.io/posts/muon/ 77 | 78 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 79 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 80 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 81 | the advantage that it can be stably run in bfloat16 on the GPU. 82 | 83 | Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, 84 | or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
85 | """ 86 | def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, rank=0, world_size=1): 87 | self.rank = rank 88 | self.world_size = world_size 89 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) 90 | super().__init__(params, defaults) 91 | assert all(p.dtype == torch.bfloat16 for group in self.param_groups for p in group["params"]) 92 | 93 | @torch.no_grad() 94 | def step(self): 95 | futures: list[torch.Future] = [] 96 | for group in self.param_groups: 97 | params: list[Tensor] = group["params"] 98 | params_pad = params + [torch.empty_like(params[-1])] * self.world_size 99 | momentum = torch._as_tensor_fullprec(group["momentum"]) 100 | for base_i in range(len(params))[::self.world_size]: 101 | if base_i + self.rank < len(params): 102 | p = params[base_i + self.rank] 103 | state = self.state[p] 104 | if len(state) == 0: 105 | state["mantissa"] = torch.zeros_like(p, dtype=torch.uint16) 106 | state["momentum_buffer"] = torch.zeros_like(p, dtype=torch.float32) 107 | update( 108 | p.view(torch.uint16), state["mantissa"], state["momentum_buffer"], 109 | p.grad, momentum, 110 | eff_lr=torch._as_tensor_fullprec(group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5), 111 | eff_weight_decay=torch._as_tensor_fullprec(group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0)), 112 | ) 113 | futures.append(dist.all_gather(params_pad[base_i:base_i + self.world_size], params_pad[base_i + self.rank], async_op=True).get_future()) 114 | torch.futures.collect_all(futures).wait() 115 | 116 | # ----------------------------------------------------------------------------- 117 | # PyTorch nn.Module definitions for the model 118 | 119 | def norm(x: Tensor): 120 | return F.rms_norm(x, (x.size(-1),)) 121 | 122 | @torch.no_grad() 123 | def init_linear(w: Tensor): 124 | std = 0.5 * (w.size(-1) ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) 125 | bound = (3 ** 0.5) * std 126 | return w.uniform_(-bound, bound) 127 | 128 | class Rotary(nn.Module): 129 | def __init__(self, dim: int, max_seq_len: int): 130 | super().__init__() 131 | # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) 132 | angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) 133 | angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) 134 | t = torch.arange(max_seq_len, dtype=torch.float32) 135 | theta = torch.einsum("i,j -> ij", t, angular_freq) 136 | self.cos = nn.Buffer(theta.cos(), persistent=False) 137 | self.sin = nn.Buffer(theta.sin(), persistent=False) 138 | 139 | def forward(self, x_BTHD: Tensor): 140 | assert self.cos.size(0) >= x_BTHD.size(-3) 141 | cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] 142 | x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) 143 | y1 = x1 * cos + x2 * sin 144 | y2 = x1 * (-sin) + x2 * cos 145 | return torch.cat((y1, y2), 3).type_as(x_BTHD) 146 | 147 | class CausalSelfAttention(nn.Module): 148 | def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): 149 | super().__init__() 150 | self.num_heads = num_heads 151 | self.head_dim = head_dim 152 | hdim = num_heads * head_dim 153 | # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng 154 | # https://x.com/hi_tysam/status/1879699187107033311 155 | self.qkvo_w = nn.Parameter(init_linear(torch.empty(4, hdim, dim)).bfloat16()) 156 | self.qkvo_w.detach()[3].zero_() # out zero init suggested by @Grad62304977 157 | self.rotary = 
Rotary(head_dim, max_seq_len) 158 | # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun 159 | # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 160 | self.attn_scale = 0.12 161 | 162 | def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask, lambdas: Tensor): 163 | B, T = x.size(0), x.size(1) # batch size, sequence length 164 | assert B == 1, "Must use batch size = 1 for FlexAttention" 165 | q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) 166 | q, k = norm(q), norm(k) # QK norm @Grad62304977 167 | q, k = self.rotary(q), self.rotary(k) 168 | v = norm(v) 169 | if ve is not None: 170 | v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 171 | else: # skip mid-layers token value embeddings by @YouJiacheng 172 | v = lambdas[0] * v 173 | y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale).transpose(1, 2) 174 | y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side 175 | y = F.linear(y, self.qkvo_w[3]) 176 | return y 177 | 178 | class MLP(nn.Module): 179 | def __init__(self, dim: int): 180 | super().__init__() 181 | hdim = 4 * dim 182 | self.fc_w = nn.Parameter(init_linear(torch.empty(hdim, dim)).bfloat16()) 183 | self.proj_w = nn.Parameter(torch.zeros(dim, hdim).bfloat16()) 184 | self.fc_w.wd_mul = 2.0 185 | self.proj_w.wd_mul = 2.0 186 | 187 | def forward(self, x: Tensor): 188 | x = F.linear(x, self.fc_w) 189 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 190 | x = F.linear(x, self.proj_w) 191 | return x 192 | 193 | class Block(nn.Module): 194 | def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): 195 | super().__init__() 196 | # skip attention of blocks.7 (the 8th layer) by @YouJiacheng 197 | self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None 198 | self.mlp = MLP(dim) 199 | 200 | def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask, lambdas: Tensor, sa_lambdas: Tensor): 201 | x = lambdas[0] * x + lambdas[1] * x0 202 | if self.attn is not None: 203 | x = x + self.attn(x, ve, block_mask, sa_lambdas) 204 | x = x + self.mlp(norm(x)) 205 | return x 206 | 207 | # ----------------------------------------------------------------------------- 208 | # The main model 209 | 210 | def next_multiple_of_n(v: float | int, *, n: int): 211 | return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) 212 | 213 | class GPT(nn.Module): 214 | def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): 215 | super().__init__() 216 | self.embed = nn.Embedding(vocab_size, model_dim) 217 | # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 218 | # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 219 | self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) 220 | self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) 221 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. 
222 | # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 223 | self.lm_head_w = nn.Parameter(torch.zeros(next_multiple_of_n(vocab_size, n=128), model_dim)) 224 | # Add learnable skip connection weights for decoder layers 225 | assert num_layers % 2 == 0 226 | self.scalars = nn.Parameter(torch.cat([ 227 | torch.ones(num_layers), # skip_weights 228 | *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas 229 | *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas 230 | ])) 231 | 232 | def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor): 233 | BLOCK_SIZE = 128 234 | docs = (input_seq == 50256).cumsum(0) 235 | 236 | def document_causal(b, h, q_idx, kv_idx): 237 | causal_mask = q_idx >= kv_idx 238 | document_mask = docs[q_idx] == docs[kv_idx] 239 | return causal_mask & document_mask 240 | 241 | def dense_to_ordered(dense_blockmask: Tensor): 242 | num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32) 243 | indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32) 244 | return num_blocks[None, None].contiguous(), indices[None, None].contiguous() 245 | 246 | # manual block mask creation by @YouJiacheng 247 | assert len(input_seq) % BLOCK_SIZE == 0 248 | NUM_BLOCKS = len(input_seq) // BLOCK_SIZE 249 | block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda") 250 | causal_blockmask_any = block_idx[:, None] >= block_idx 251 | causal_blockmask_all = block_idx[:, None] > block_idx 252 | docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous() 253 | docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous() 254 | document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low) 255 | document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low) 256 | blockmask_any = causal_blockmask_any & document_blockmask_any 257 | blockmask_all = causal_blockmask_all & document_blockmask_all 258 | partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all) 259 | full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all) 260 | def build_bm(window_size_blocks: Tensor) -> BlockMask: 261 | return BlockMask.from_kv_blocks( 262 | torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)), 263 | partial_kv_indices, 264 | torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1), 265 | full_kv_indices, 266 | BLOCK_SIZE=BLOCK_SIZE, 267 | mask_mod=document_causal, 268 | ) 269 | # Long-short SWA block masks by @leloykun & @YouJiacheng, adapated from suggestion by @Grad62304977, following Gemma 2 paper 270 | return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2) 271 | 272 | def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor): 273 | assert input_seq.ndim == 1 274 | 275 | ve = [value_embed(input_seq) for value_embed in self.value_embeds] 276 | # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure 277 | ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] 278 | assert len(ve) == len(self.blocks) 279 | 280 | long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks) 281 | block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] 282 | assert len(block_masks) == len(self.blocks) 283 | 284 | x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 285 | 286 | skip_connections = [] 287 | skip_map = { 288 | 9: 6, 289 | 10: 4, 290 | 11: 2, 291 | } 292 | skip_weights = self.scalars[:len(self.blocks)] 293 | lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) 294 | sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) 295 | for i in range(len(self.blocks)): 296 | if i in skip_map: 297 | x = x + skip_weights[skip_map[i]] * skip_connections[skip_map[i]] 298 | x = self.blocks[i](x, ve[i], x0, block_masks[i], lambdas[i], sa_lambdas[i]) 299 | skip_connections.append(x) 300 | 301 | x = norm(x) 302 | if self.training: 303 | logits: Tensor = F.linear(x.flatten(end_dim=1), self.lm_head_w.bfloat16()).float() 304 | loss = F.cross_entropy(15 * logits * torch.rsqrt(logits.square() + 225), target_seq) 305 | return loss 306 | 307 | loss = 0 308 | for i in range(4): 309 | logits: Tensor = F.linear(x.flatten(end_dim=1).chunk(4)[i], self.lm_head_w.bfloat16()).float() 310 | loss += F.cross_entropy(15 * logits * torch.rsqrt(logits.square() + 225), target_seq.chunk(4)[i]) / 4 311 | return loss 312 | 313 | # ----------------------------------------------------------------------------- 314 | # Our own simple Distributed Data Loader 315 | 316 | def _load_data_shard(file: Path): 317 | header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 318 | assert header[0] == 20240520, "magic number mismatch in the data .bin file" 319 | assert header[1] == 1, "unsupported version" 320 | num_tokens = int(header[2]) # number of tokens (claimed) 321 | with file.open("rb", buffering=0) as f: 322 | tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng 323 | f.seek(256 * 4) 324 | nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng 325 | assert nbytes == 2 * num_tokens, "number of tokens read does not match header" 326 | return tokens 327 | 328 | def distributed_data_generator(filename_pattern: str, batch_size: int, rank : int, world_size : int): 329 | files = sorted(Path.cwd().glob(filename_pattern)) 330 | assert batch_size % world_size == 0 331 | local_batch_size = batch_size // world_size 332 | file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training 333 | tokens, pos = _load_data_shard(next(file_iter)), 0 334 | while True: 335 | if pos + batch_size + 1 >= len(tokens): 336 | tokens, pos = _load_data_shard(next(file_iter)), 0 337 | buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1] 338 | inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; 339 | targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. 
340 | pos += batch_size 341 | yield inputs, targets 342 | 343 | # ----------------------------------------------------------------------------- 344 | # int main 345 | 346 | @dataclass 347 | class Hyperparameters: 348 | # data 349 | train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on 350 | val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on 351 | val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons 352 | train_seq_len = 64*1024 # FlexAttention sequence length 353 | val_seq_len = 4*64*1024 # FlexAttention sequence length for validation 354 | # optimization 355 | num_iterations = 5960 # number of iterations to run 356 | cooldown_frac = 0.7 # fraction of training spent cooling down the learning rate 357 | # architecture 358 | vocab_size = 50257 359 | # evaluation and logging 360 | val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end 361 | save_checkpoint = False 362 | args = Hyperparameters() 363 | 364 | run_id = int(os.environ.get("RUN_ID", 0)) 365 | # torchrun sets these env variables 366 | rank = int(os.environ["RANK"]) 367 | world_size = int(os.environ["WORLD_SIZE"]) 368 | assert world_size == 8 # this code is designed for 8xH100 369 | assert torch.cuda.is_available() 370 | device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) 371 | torch.cuda.set_device(device) 372 | dist.init_process_group(backend="nccl", device_id=device) 373 | dist.barrier() 374 | master_process = (rank == 0) # this process will do logging, checkpointing etc. 375 | 376 | # begin logging 377 | if master_process: 378 | run_id_full = f"{run_id:03d}_{uuid.uuid4()}" 379 | os.makedirs("logs", exist_ok=True) 380 | logfile = f"logs/{run_id_full}.txt" 381 | print(logfile) 382 | def print0(s, console=False): 383 | if master_process: 384 | with open(logfile, "a") as f: 385 | if console: 386 | print(s) 387 | print(s, file=f) 388 | from torch._logging._internal import trace_structured # noqa: E402 389 | import torch._inductor.codecache # noqa: E402 390 | import torch._inductor.graph # noqa: E402 391 | def _patched_trace_structured(name, metadata_fn, **kwargs): 392 | if name == "inductor_output_code": 393 | print0(f"inductor_output_code: {metadata_fn().get("filename", "Unknown")}") 394 | trace_structured(name, metadata_fn, **kwargs) 395 | torch._inductor.codecache.trace_structured = _patched_trace_structured 396 | torch._inductor.graph.trace_structured = _patched_trace_structured 397 | 398 | # begin by printing this file (the Python code) 399 | print0(code) 400 | print0("="*100) 401 | # log information about the hardware/software environment this is running on 402 | print0(f"Running Python {sys.version}") 403 | print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") 404 | def nvidia_smi(): 405 | import subprocess # avoid top level import 406 | return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout 407 | print0(nvidia_smi()) 408 | print0("="*100) 409 | 410 | ######################################## 411 | # Construct model and optimizer # 412 | ######################################## 413 | 414 | model: nn.Module = GPT(vocab_size=args.vocab_size, num_layers=16, num_heads=8, model_dim=1024, 415 | max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() 416 | for m in model.modules(): 417 | if isinstance(m, nn.Embedding): 418 | m.bfloat16() 419 | for param in model.parameters(): 
420 | dist.broadcast(param.detach(), 0) 421 | 422 | # collect the parameters to optimize 423 | hidden_matrix_params = sorted((p for p in model.blocks.parameters() if p.ndim >= 2), key=lambda x: x.size(), reverse=True) 424 | embed_params = [*model.embed.parameters(), *model.value_embeds.parameters()] 425 | scalar_params = [model.scalars] 426 | head_params: list[nn.Parameter] = [model.lm_head_w] 427 | # sanity check 428 | params_collections = [hidden_matrix_params, embed_params, scalar_params, head_params] 429 | optimized_parameters_set = {p for params in params_collections for p in params} 430 | assert optimized_parameters_set == {*model.parameters()} 431 | assert len(optimized_parameters_set) == sum(len(lst) for lst in params_collections) 432 | 433 | # init the optimizer(s) 434 | adam_param_groups = [dict(params=head_params, lr=1/320), dict(params=embed_params, lr=0.3), dict(params=scalar_params, lr=0.015)] 435 | # small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence 436 | # discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 437 | optimizer1 = torch.optim.AdamW(adam_param_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0, fused=True) 438 | optimizer2 = Muon(hidden_matrix_params, lr=0.025, momentum=0.95, rank=rank, world_size=world_size) 439 | optimizers: list[torch.optim.Optimizer] = [optimizer1, optimizer2] 440 | def opt_params(opt: torch.optim.Optimizer) -> list[nn.Parameter]: 441 | return [p for group in opt.param_groups for p in group["params"]] 442 | opt2params = {opt: opt_params(opt) for opt in optimizers} 443 | for opt in optimizers: 444 | for group in opt.param_groups: 445 | group["initial_lr"] = group["lr"] 446 | 447 | # learning rate schedule: stable then decay 448 | def get_lr(step: int): 449 | x = step / args.num_iterations # progress in training 450 | assert 0 <= x < 1 451 | if x < 1 - args.cooldown_frac: 452 | return 1.0 453 | else: 454 | return (1 - x) / args.cooldown_frac 455 | 456 | # attention window size schedule: linearly increase 457 | @lru_cache(1) 458 | def get_window_size_blocks_helper(window_size: int): 459 | return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) 460 | def get_window_size_blocks(step: int): 461 | x = step / args.num_iterations # progress in training 462 | assert 0 <= x <= 1 463 | # Linearly increase the block-wise sliding window size over training 128 -> 1792 464 | # increase by @fernbear.bsky.social; block-wise by @YouJiacheng 465 | factor = 4 * x ** 3 - 6 * x ** 2 + 3 * x # cubic schedule by @jadenj3o 466 | window_size = next_multiple_of_n(3456 * factor, n=128) 467 | return get_window_size_blocks_helper(window_size) 468 | 469 | model: nn.Module = torch.compile(model, dynamic=False) 470 | 471 | ######################################## 472 | # Warmup kernels # 473 | ######################################## 474 | 475 | # Warmup the training kernels, then re-initialize the state so we aren't cheating 476 | warmup_steps = 10 477 | initial_state = copy.deepcopy(dict(model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])) 478 | for _ in range(warmup_steps): 479 | inputs = targets = torch.randint(0, args.vocab_size, size=(args.train_seq_len,), device="cuda") 480 | model(inputs.to(torch.int32), targets, get_window_size_blocks(0)).backward() 481 | for param in model.parameters(): 482 | dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) 483 | for opt in optimizers: 484 | opt.step() 485 | 
model.zero_grad(set_to_none=True) 486 | model.load_state_dict(initial_state["model"]) 487 | for opt, opt_state in zip(optimizers, initial_state["optimizers"]): 488 | opt.load_state_dict(opt_state) 489 | del initial_state 490 | 491 | ######################################## 492 | # Training and validation # 493 | ######################################## 494 | 495 | torch.cuda.reset_peak_memory_stats() 496 | train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, rank, world_size) 497 | training_time_ms = 0 498 | # start the clock 499 | dist.barrier() 500 | t0 = time.perf_counter() 501 | # begin training 502 | train_steps = args.num_iterations 503 | for step in range(train_steps + 1): 504 | last_step = (step == train_steps) 505 | 506 | # --------------- VALIDATION SECTION ----------------- 507 | if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): 508 | # stop the clock 509 | dist.barrier() 510 | training_time_ms += 1000 * (time.perf_counter() - t0) 511 | model.eval() 512 | val_batch_size = world_size * args.val_seq_len 513 | assert args.val_tokens % val_batch_size == 0 514 | val_steps = args.val_tokens // val_batch_size 515 | val_loader = distributed_data_generator(args.val_files, val_batch_size, rank, world_size) 516 | val_loss = 0 517 | with torch.no_grad(): 518 | for _ in range(val_steps): 519 | inputs, targets = next(val_loader) 520 | val_loss += model(inputs, targets, get_window_size_blocks(step)) 521 | val_loss /= val_steps 522 | del val_loader 523 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) 524 | print0(f"step:{step}/{train_steps} val_loss:{val_loss:.6f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) 525 | model.train() 526 | # start the clock again 527 | dist.barrier() 528 | t0 = time.perf_counter() 529 | 530 | if last_step: 531 | if master_process and args.save_checkpoint: 532 | log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) 533 | os.makedirs(f"logs/{run_id_full}", exist_ok=True) 534 | torch.save(log, f"logs/{run_id_full}/state_step{step:06d}.pt") 535 | # the last step only has the validation loop, so break to avoid training 536 | break 537 | 538 | # --------------- TRAINING SECTION ----------------- 539 | inputs, targets = next(train_loader) 540 | model(inputs, targets, get_window_size_blocks(step)).backward() 541 | opt2futures = { 542 | opt: [dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True).get_future() for p in params] 543 | for opt, params in opt2params.items() 544 | } 545 | # set optimization hyperparameters 546 | for opt in optimizers: 547 | for group in opt.param_groups: 548 | group["lr"] = group["initial_lr"] * get_lr(step) 549 | for group in optimizer2.param_groups: 550 | frac = min(step / 300, 1) # momentum warmup for muon 551 | group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 552 | # step the optimizers 553 | for opt in optimizers: 554 | torch.futures.collect_all(opt2futures[opt]).wait() 555 | opt.step() 556 | # null the gradients 557 | model.zero_grad(set_to_none=True) 558 | # logging 559 | approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) 560 | print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) 561 | 562 | print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " 563 | f"reserved: {torch.cuda.max_memory_reserved() // 
1024 // 1024} MiB", console=True) 564 | dist.destroy_process_group() 565 | --------------------------------------------------------------------------------
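run.sh above covers the main speedrun script (train_gpt.py). A hypothetical launch for train_gpt_medium.py, mirroring the same single-node setup — the script itself asserts world_size == 8 (8xH100) and optionally reads a RUN_ID environment variable to prefix its log file name — would be:

```
RUN_ID=0 torchrun --standalone --nproc_per_node=8 train_gpt_medium.py
```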