├── image-1.png
├── results
│   ├── mlp_cifar10.png
│   ├── mlp_cifar10_ready.png
│   ├── loss_vs_log2lr_SPMLP_1.0_Adam.png
│   ├── loss_vs_log2lr_SPMLP_1.0_SGD.png
│   ├── loss_vs_log2lr_muMLP_1.0_Adam.png
│   ├── loss_vs_log2lr_muMLP_1.0_SGD.png
│   ├── loss_vs_log2lr_SPMLP_1.0_SGD.csv
│   ├── loss_vs_log2lr_muMLP_1.0_SGD.csv
│   ├── loss_vs_log2lr_SPMLP_1.0_Adam.csv
│   └── loss_vs_log2lr_muMLP_1.0_Adam.csv
├── run_mlp.sh
├── README.md
├── toy_cifar.ipynb
└── train_mlp.py

/image-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Laz4rz/mup/HEAD/image-1.png
--------------------------------------------------------------------------------
/results/mlp_cifar10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Laz4rz/mup/HEAD/results/mlp_cifar10.png
--------------------------------------------------------------------------------
/results/mlp_cifar10_ready.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Laz4rz/mup/HEAD/results/mlp_cifar10_ready.png
--------------------------------------------------------------------------------
/results/loss_vs_log2lr_SPMLP_1.0_Adam.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Laz4rz/mup/HEAD/results/loss_vs_log2lr_SPMLP_1.0_Adam.png
--------------------------------------------------------------------------------
/results/loss_vs_log2lr_SPMLP_1.0_SGD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Laz4rz/mup/HEAD/results/loss_vs_log2lr_SPMLP_1.0_SGD.png
--------------------------------------------------------------------------------
/results/loss_vs_log2lr_muMLP_1.0_Adam.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Laz4rz/mup/HEAD/results/loss_vs_log2lr_muMLP_1.0_Adam.png
--------------------------------------------------------------------------------
/results/loss_vs_log2lr_muMLP_1.0_SGD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Laz4rz/mup/HEAD/results/loss_vs_log2lr_muMLP_1.0_SGD.png
--------------------------------------------------------------------------------
/run_mlp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | echo "Running SPMLP Adam experiment..."
4 | python3 train_mlp.py --model SPMLP --subset 1 --optimizer Adam --lr_range -16 -4
5 | 
6 | echo "Running muMLP Adam experiment..."
7 | python3 train_mlp.py --model muMLP --subset 1 --optimizer Adam --lr_range -12 0
8 | 
9 | echo "Running SPMLP SGD experiment..."
10 | python3 train_mlp.py --model SPMLP --subset 1 --optimizer SGD --lr_range -12 0
11 | 
12 | echo "Running muMLP SGD experiment..."
13 | python3 train_mlp.py --model muMLP --subset 1 --optimizer SGD --lr_range -12 0
14 | 
15 | echo "All experiments completed."
16 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # muP made easy
2 | 
3 | A minimal (really) implementation of muP with SGD and Adam, following the Tensor Programs IV (TP4) and Tensor Programs V (TP5) papers. The classes `SP_MLP` and `muMLPTab9` implement the SP and muP parametrizations as shown in Table 1 of the TP4 paper or, equivalently, Table 9 of the TP5 paper. The rest of the code is just training utilities.
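4 | 
5 | Condensed, the whole parametrization looks like the sketch below (it mirrors `muMLPTab9` in `train_mlp.py`; the name `muMLP`, the `adam_param_groups` helper, and the width/learning-rate values are illustrative, and 3072 is the flattened CIFAR-10 input):
6 | 
7 | ```python
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.optim import SGD, Adam
11 | 
12 | class muMLP(nn.Module):
13 |     def __init__(self, width=128, num_classes=10):
14 |         super().__init__()
15 |         self.width = width
16 |         self.fc_1 = nn.Linear(3072, width, bias=False)
17 |         self.fc_2 = nn.Linear(width, width, bias=False)
18 |         self.fc_3 = nn.Linear(width, num_classes, bias=False)
19 |         for fc in (self.fc_1, self.fc_2, self.fc_3):
20 |             nn.init.normal_(fc.weight, std=width**-0.5)  # init variance 1/width
21 | 
22 |     def forward(self, x):
23 |         x = x.view(x.size(0), -1)
24 |         h = F.relu(self.width**0.5 * self.fc_1(x))  # input multiplier sqrt(width)
25 |         h = F.relu(self.fc_2(h))
26 |         return self.width**-0.5 * self.fc_3(h)  # output multiplier 1/sqrt(width)
27 | 
28 | # SGD takes the base learning rate for every layer; Adam gets per-layer
29 | # learning rates, restating muMLPTab9.get_parameter_groups:
30 | def adam_param_groups(model, lr):
31 |     return [
32 |         {"params": model.fc_1.parameters(), "lr": lr / model.width**0.5},
33 |         {"params": model.fc_2.parameters(), "lr": lr / model.width**0.5},
34 |         {"params": model.fc_3.parameters(), "lr": lr / model.width},
35 |     ]
36 | 
37 | model = muMLP(width=1024)
38 | opt = Adam(adam_param_groups(model, lr=2**-6))  # muP Adam
39 | # opt = SGD(model.parameters(), lr=2**-6)  # muP SGD: base LR everywhere
40 | ```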
41 | 
42 | This implementation does not rely on "setting shapes" or on optimizer tricks, unlike other implementations. There are also no tunable scaling hyperparameters.
43 | 
44 | Running `run_mlp.sh` reproduces the results below. The training script auto-runs on all GPUs with more than 16 GB of free memory and less than 50% utilization; feel free to change these thresholds to match your GPUs.
45 | 
46 | ![alt text](image-1.png)
47 | 
48 | Special thanks to dvruette for help deciphering the notation and for discussions while debugging.
49 | 
--------------------------------------------------------------------------------
/results/loss_vs_log2lr_SPMLP_1.0_SGD.csv:
--------------------------------------------------------------------------------
1 | ,128,256,512,1024,2048,4096,8192 2 | -12.0,1.8440642356872559,1.7113268375396729,1.5676597356796265,1.328279972076416,0.8689160943031311,0.29880619049072266,0.07484827935695648 3 | -11.692307692307692,1.8286209106445312,1.6929866075515747,1.5388480424880981,1.2743785381317139,0.7759881019592285,0.23671577870845795,0.06190268322825432 4 | -11.384615384615385,1.8151963949203491,1.6757885217666626,1.5088512897491455,1.2165743112564087,0.6816218495368958,0.1882523000240326,0.05198953300714493 5 | -11.076923076923077,1.8018076419830322,1.6580321788787842,1.4782907962799072,1.155057668685913,0.5882124304771423,0.15056705474853516,0.04354432597756386 6 | -10.76923076923077,1.788156509399414,1.6389607191085815,1.4446731805801392,1.0860258340835571,0.49929800629615784,0.11891654133796692,0.03606623783707619 7 | -10.461538461538462,1.774210810661316,1.6184409856796265,1.4074441194534302,1.008029818534851,0.40881481766700745,0.09342408180236816,0.029585061594843864 8 | -10.153846153846153,1.7586843967437744,1.5955761671066284,1.3660738468170166,0.9224177598953247,0.32746878266334534,0.0735900029540062,0.02361658215522766 9 | -9.846153846153847,1.740665316581726,1.5706031322479248,1.3201310634613037,0.8261221647262573,0.2550088167190552,0.056633591651916504,0.018631136044859886 10 | -9.538461538461538,1.721273422241211,1.5421451330184937,1.2633254528045654,0.7220906615257263,0.18785539269447327,0.0433524027466774,0.014225480146706104 11 | -9.23076923076923,1.7025381326675415,1.5108869075775146,1.2026528120040894,0.6121242046356201,0.14039936661720276,0.033717911690473557,0.01103190891444683 12 | -8.923076923076923,1.6819576025009155,1.476582407951355,1.1324946880340576,0.49607205390930176,0.10894399881362915,0.026458539068698883,0.009327193722128868 13 | -8.615384615384615,1.6615709066390991,1.4392845630645752,1.0494415760040283,0.3987891376018524,0.08830025792121887,0.02141677401959896,0.010574492625892162 14 | -8.307692307692307,1.6396470069885254,1.395561933517456,0.9699442982673645,0.31885868310928345,0.07306777685880661,0.019760945811867714,0.012901779264211655 15 | -8.0,1.6173704862594604,1.3510650396347046,0.8833211660385132,0.25523117184638977,0.06554372608661652,0.018744291737675667,0.020020386204123497 16 | -7.692307692307692,1.5966178178787231,1.3116792440414429,0.8149693012237549,0.21981625258922577,0.06152557209134102,0.026071544736623764,0.04174180328845978 17 | -7.384615384615384,1.578596591949463,1.2687950134277344,0.7188543677330017,0.18604262173175812,0.07703716307878494,0.04988740757107735,0.13199540972709656 18 | -7.076923076923077,1.5611035823822021,1.2266196012496948,0.6718690991401672,0.18129485845565796,0.10545295476913452,0.08102243393659592,0.18955223262310028 19 |
-6.769230769230769,1.5464657545089722,1.1877166032791138,0.6410055160522461,0.19681167602539062,0.1568695455789566,0.17867322266101837,0.32224202156066895 20 | -6.461538461538462,1.5466300249099731,1.1550594568252563,0.6654894948005676,0.25284621119499207,0.26765087246894836,0.2881878614425659,0.6152384281158447 21 | -6.153846153846153,1.5403950214385986,1.149572491645813,0.686809241771698,0.33354347944259644,0.3568597435951233,0.463738352060318,2.0364296436309814 22 | -5.846153846153846,1.547912359237671,1.1572850942611694,0.7571589946746826,0.4458096921443939,0.5457643866539001,1.550693154335022,2.286449909210205 23 | -5.538461538461538,1.5645866394042969,1.2088087797164917,0.842339813709259,0.5918247103691101,0.9878395199775696,2.2829031944274902,2.302102565765381 24 | -5.23076923076923,1.5988731384277344,1.2312849760055542,0.9523286819458008,0.8452523350715637,2.1589763164520264,2.2981975078582764,2.3014492988586426 25 | -4.9230769230769225,1.6232376098632812,1.2888154983520508,1.1796332597732544,1.4421889781951904,2.302072525024414,2.3012092113494873,2.3000526428222656 26 | -4.615384615384615,1.6773532629013062,1.336965560913086,1.619910717010498,2.2288551330566406,2.3018195629119873,2.3020787239074707,2.3020806312561035 27 | -4.3076923076923075,1.7209601402282715,1.4264695644378662,2.2440073490142822,2.2832462787628174,2.303180456161499,2.3024628162384033,2.302051067352295 28 | -4.0,1.771563172340393,1.7217445373535156,2.2937910556793213,2.302372455596924,2.3019800186157227,2.3025848865509033, 29 | -3.6923076923076916,2.0579164028167725,2.298804998397827,2.273745059967041,2.302447557449341,2.3024473190307617,2.301659107208252,2.302004098892212 30 | -3.3846153846153832,2.1106138229370117,2.2690813541412354,2.3025848865509033,2.3017444610595703,2.3021528720855713,1.9269826412200928, 31 | -3.0769230769230766,2.1083452701568604,2.2108614444732666,2.3019824028015137,2.290863513946533,2.302449941635132,, 32 | -2.7692307692307683,2.1214442253112793,2.1463217735290527,2.302345037460327,2.3025848865509033,2.277165412902832,, 33 | -2.4615384615384617,2.139286518096924,2.2738068103790283,2.3013126850128174,,,, 34 | -2.1538461538461533,2.1705551147460938,2.231416702270508,2.300826072692871,,,, 35 | -1.846153846153845,2.2995195388793945,,2.3025848865509033,,,, 36 | -1.5384615384615383,2.2999088764190674,,,,,, 37 | -1.23076923076923,2.3015480041503906,,,,,, 38 | -0.9230769230769234,2.3023102283477783,,,,,, 39 | -0.615384615384615,2.302419900894165,,,,,, 40 | -0.3076923076923066,2.3025848865509033,,,,,, 41 | 0.0,2.3025848865509033,,2.3025848865509033,,,, 42 | -------------------------------------------------------------------------------- /results/loss_vs_log2lr_muMLP_1.0_SGD.csv: -------------------------------------------------------------------------------- 1 | ,128,256,512,1024,2048,4096,8192 2 | -12.0,1.8580272197723389,1.7955725193023682,1.7542314529418945,1.7160221338272095,1.704688549041748,1.695602536201477,1.6919137239456177 3 | -11.692307692307692,1.8278729915618896,1.764150857925415,1.722434639930725,1.683312177658081,1.6702347993850708,1.661184310913086,1.6576658487319946 4 | -11.384615384615385,1.7981488704681396,1.7326247692108154,1.690285086631775,1.6503396034240723,1.6353532075881958,1.6264208555221558,1.6228634119033813 5 | -11.076923076923077,1.768265724182129,1.7005356550216675,1.6575931310653687,1.61680269241333,1.599884271621704,1.591048240661621,1.587283730506897 6 | 
-10.76923076923077,1.738079309463501,1.668017029762268,1.624210000038147,1.5824754238128662,1.5636532306671143,1.5547029972076416,1.5506350994110107 7 | -10.461538461538462,1.707590103149414,1.6350497007369995,1.5899714231491089,1.54691743850708,1.5262112617492676,1.5169098377227783,1.5124363899230957 8 | -10.153846153846153,1.6766124963760376,1.601531744003296,1.554632306098938,1.5096330642700195,1.4869287014007568,1.4770318269729614,1.4719737768173218 9 | -9.846153846153847,1.6456845998764038,1.5676156282424927,1.5179928541183472,1.469989538192749,1.4450749158859253,1.4342368841171265,1.4283745288848877 10 | -9.538461538461538,1.6148011684417725,1.5328624248504639,1.4792914390563965,1.427317500114441,1.399674415588379,1.3872405290603638,1.3804163932800293 11 | -9.23076923076923,1.583606481552124,1.4968191385269165,1.4377511739730835,1.3806359767913818,1.3494006395339966,1.3345924615859985,1.3266130685806274 12 | -8.923076923076923,1.551928997039795,1.4588520526885986,1.392361044883728,1.328897476196289,1.2927160263061523,1.2745672464370728,1.2651323080062866 13 | -8.615384615384615,1.5191245079040527,1.4183101654052734,1.3420827388763428,1.271266222000122,1.228369116783142,1.2060933113098145,1.1945050954818726 14 | -8.307692307692307,1.4853029251098633,1.3742140531539917,1.2866977453231812,1.2064118385314941,1.1552915573120117,1.127920150756836,1.1133151054382324 15 | -8.0,1.4498862028121948,1.3266054391860962,1.2252788543701172,1.1334775686264038,1.0719037055969238,1.0383689403533936,1.019554853439331 16 | -7.692307692307692,1.4119794368743896,1.2740989923477173,1.1577366590499878,1.0506478548049927,0.9767716526985168,0.9358946084976196,0.9119977951049805 17 | -7.384615384615384,1.3708109855651855,1.2166273593902588,1.082798957824707,0.9574782252311707,0.8699079155921936,0.8208462595939636,0.7917525768280029 18 | -7.076923076923077,1.3268284797668457,1.1528922319412231,0.9993859529495239,0.8542959094047546,0.7535517811775208,0.6978685259819031,0.6648340821266174 19 | -6.769230769230769,1.2791804075241089,1.0817902088165283,0.9055108428001404,0.743331789970398,0.6348782777786255,0.5770388245582581,0.5438244342803955 20 | -6.461538461538462,1.2272220849990845,1.0043426752090454,0.8045516610145569,0.6305358409881592,0.5237105488777161,0.4688855707645416,0.4369391202926636 21 | -6.153846153846153,1.172063946723938,0.9198054075241089,0.6964566111564636,0.5211473703384399,0.4245668351650238,0.37527844309806824,0.34710097312927246 22 | -5.846153846153846,1.1124546527862549,0.8318623900413513,0.588599443435669,0.4206116795539856,0.3345232605934143,0.29427438974380493,0.2656753659248352 23 | -5.538461538461538,1.053644061088562,0.740325927734375,0.4847997725009918,0.3239917755126953,0.24991457164287567,0.22392898797988892,0.19466005265712738 24 | -5.23076923076923,0.9908750057220459,0.6530495285987854,0.38929253816604614,0.24154706299304962,0.17776834964752197,0.16218794882297516,0.1313224732875824 25 | -4.9230769230769225,0.935684323310852,0.5756202340126038,0.30578160285949707,0.16527767479419708,0.11736927181482315,0.08576342463493347,0.08235286921262741 26 | -4.615384615384615,0.8802027106285095,0.5151736736297607,0.2356458604335785,0.10802765190601349,0.06283684074878693,0.05279918015003204,0.04966201260685921 27 | -4.3076923076923075,0.8398526906967163,0.46930158138275146,0.19047462940216064,0.06588290631771088,0.03510713577270508,0.025121109560132027,0.024278543889522552 28 | 
-4.0,0.8212745785713196,0.45320379734039307,0.18698552250862122,0.04801448807120323,0.022100500762462616,0.016465215012431145,0.013536302372813225 29 | -3.6923076923076916,0.8034840822219849,0.4452359974384308,0.1859789788722992,0.04416177421808243,0.014313205145299435,0.009670404717326164,0.009411800652742386 30 | -3.3846153846153832,0.8146777153015137,0.4650191366672516,0.21159303188323975,0.07169809192419052,0.012502077966928482,0.008421479724347591,0.009637289680540562 31 | -3.0769230769230766,0.8366654515266418,0.4978407323360443,0.25197693705558777,0.11994849890470505,0.028672680258750916,0.006545894779264927,0.005487229209393263 32 | -2.7692307692307683,0.8632645606994629,0.5410938262939453,0.3419749438762665,0.1841246485710144,0.0721115842461586,0.015094545669853687,0.006654795259237289 33 | -2.4615384615384617,0.9002864956855774,0.6353059411048889,0.40577542781829834,0.2567635774612427,0.15394043922424316,0.13624295592308044,0.022576766088604927 34 | -2.1538461538461533,0.986055850982666,0.7188214063644409,0.5378162264823914,0.3941771388053894,0.272704541683197,0.18555854260921478,0.1438663899898529 35 | -1.846153846153845,1.0685127973556519,0.8424038887023926,0.692351222038269,0.5451285243034363,0.45349690318107605,0.43611088395118713,0.2997572720050812 36 | -1.5384615384615383,1.228724718093872,1.0534017086029053,0.9120779633522034,0.753929853439331,0.6592607498168945,0.6222734451293945,0.5023761987686157 37 | -1.23076923076923,1.4400874376296997,1.3174989223480225,1.1809396743774414,1.0781378746032715,0.9565817713737488,0.8917664885520935,0.8202407956123352 38 | -0.9230769230769234,1.6379470825195312,1.5091272592544556,1.4227300882339478,1.3255043029785156,1.2311803102493286,1.108778953552246,1.098970651626587 39 | -0.615384615384615,1.8560960292816162,1.754832148551941,1.6691420078277588,1.62034273147583,,1.5216076374053955, 40 | -0.3076923076923066,1.999579906463623,1.9728907346725464,1.9164683818817139,,,, 41 | 0.0,,,,,,, 42 | -------------------------------------------------------------------------------- /results/loss_vs_log2lr_SPMLP_1.0_Adam.csv: -------------------------------------------------------------------------------- 1 | ,128,256,512,1024,2048,4096,8192 2 | -16.0,2.2172842025756836,2.0531539916992188,1.6485592126846313,0.9054889678955078,0.18011365830898285,0.1226733922958374,0.13490451872348785 3 | -15.692307692307692,2.0424652099609375,1.8729407787322998,1.4695404767990112,0.725061297416687,0.13771885633468628,0.14635971188545227,0.1403399556875229 4 | -15.384615384615385,1.9055242538452148,1.7182905673980713,1.3119683265686035,0.5599125027656555,0.10830141603946686,0.11395743489265442,0.11015798896551132 5 | -15.076923076923077,1.800491213798523,1.5904146432876587,1.162899374961853,0.41463586688041687,0.093971386551857,0.09955265372991562,0.14636898040771484 6 | -14.76923076923077,1.7165933847427368,1.486684799194336,1.027328372001648,0.2937505841255188,0.07283134758472443,0.10645565390586853,0.12349887192249298 7 | -14.461538461538462,1.6476198434829712,1.3970303535461426,0.8984345197677612,0.19827726483345032,0.06822749227285385,0.10727369040250778,0.13145329058170319 8 | -14.153846153846153,1.588948130607605,1.3160239458084106,0.7698894739151001,0.13011176884174347,0.0781206488609314,0.09537163376808167,0.11225146055221558 9 | -13.846153846153847,1.535818099975586,1.2381800413131714,0.6433331966400146,0.0888335257768631,0.07724593579769135,0.08261971920728683,0.09004729986190796 10 | 
-13.538461538461538,1.4861501455307007,1.1582425832748413,0.5092683434486389,0.06344333291053772,0.06666519492864609,0.08683225512504578,0.10965074598789215 11 | -13.23076923076923,1.4370448589324951,1.0717657804489136,0.384344220161438,0.05408918485045433,0.07632694393396378,0.09897515922784805,0.1081252470612526 12 | -12.923076923076923,1.3887946605682373,0.9828964471817017,0.2750627398490906,0.06882326304912567,0.07785709202289581,0.08628582209348679,0.08155211806297302 13 | -12.615384615384615,1.3393768072128296,0.889931857585907,0.19778092205524445,0.07497615367174149,0.08516484498977661,0.08544362336397171,0.1199835017323494 14 | -12.307692307692307,1.2904317378997803,0.7953230738639832,0.15585044026374817,0.06859895586967468,0.08530279994010925,0.1080525666475296,0.12976697087287903 15 | -12.0,1.2400676012039185,0.7091561555862427,0.14270219206809998,0.100722536444664,0.10399463027715683,0.11693568527698517,0.20668770372867584 16 | -11.692307692307692,1.1878174543380737,0.6263875365257263,0.14193350076675415,0.10907239466905594,0.12520325183868408,0.14222043752670288,0.21503686904907227 17 | -11.384615384615383,1.1399446725845337,0.5605972409248352,0.13882270455360413,0.11401907354593277,0.1258085072040558,0.19395098090171814,0.2815280556678772 18 | -11.076923076923077,1.0895793437957764,0.5200114250183105,0.16821712255477905,0.13182085752487183,0.16944557428359985,0.2619277238845825,0.34423062205314636 19 | -10.76923076923077,1.0441935062408447,0.49907612800598145,0.186610147356987,0.1818707138299942,0.22706100344657898,0.3303757607936859,0.38386106491088867 20 | -10.461538461538462,1.016004204750061,0.48552653193473816,0.22638855874538422,0.2141244113445282,0.32534533739089966,0.37645643949508667,0.4530380368232727 21 | -10.153846153846153,0.9861443638801575,0.5199493169784546,0.2964347004890442,0.31799617409706116,0.41664546728134155,0.46927231550216675,0.5506712794303894 22 | -9.846153846153847,0.9678689241409302,0.5809422135353088,0.41204753518104553,0.47975441813468933,0.5022143721580505,0.573781430721283,0.6760505437850952 23 | -9.538461538461538,0.9824527502059937,0.6738874912261963,0.5489313006401062,0.5681962966918945,0.6084248423576355,0.6813074350357056,0.8350532650947571 24 | -9.23076923076923,1.0072777271270752,0.7748966217041016,0.6797304153442383,0.6684000492095947,0.7052732110023499,0.7999793291091919,1.0257829427719116 25 | -8.923076923076923,1.0312284231185913,0.8571575284004211,0.7925387620925903,0.7779092788696289,0.8484315276145935,1.0600422620773315,1.3325321674346924 26 | -8.615384615384615,1.0502943992614746,0.9520278573036194,0.8773601651191711,0.8761265873908997,1.0404914617538452,1.3261096477508545,1.5929926633834839 27 | -8.307692307692307,1.0864661931991577,1.0161983966827393,0.9955116510391235,1.0630848407745361,1.2775232791900635,1.562719464302063,1.8970701694488525 28 | -8.0,1.1401340961456299,1.091988444328308,1.13106369972229,1.3421937227249146,1.5311068296432495,1.7876873016357422,2.0809197425842285 29 | -7.692307692307692,1.179654836654663,1.1909103393554688,1.3148751258850098,1.550895094871521,1.7884535789489746,2.0168116092681885,2.30249285697937 30 | -7.384615384615383,1.2829419374465942,1.393317699432373,1.5824270248413086,1.7812970876693726,1.9883919954299927,2.3025388717651367,2.302170515060425 31 | -7.076923076923077,1.400172472000122,1.5776114463806152,1.8017299175262451,1.9776321649551392,2.3025848865509033,2.302400588989258,2.302262544631958 32 | 
-6.769230769230768,1.576484203338623,1.9089133739471436,2.039346694946289,2.1362807750701904,2.3025388717651367,2.3025848865509033,2.3023085594177246 33 | -6.461538461538462,1.9230403900146484,2.0393056869506836,2.1505510807037354,2.3025848865509033,2.3025848865509033,2.302354574203491,2.30249285697937 34 | -6.153846153846153,2.079033851623535,2.2635655403137207,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.30249285697937 35 | -5.846153846153845,2.3025848865509033,2.3025388717651367,2.3025848865509033,2.3025848865509033,2.3025388717651367,2.3025388717651367,2.30249285697937 36 | -5.538461538461538,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025388717651367,2.3025848865509033,2.3025848865509033,2.3025848865509033 37 | -5.23076923076923,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025388717651367,2.3025848865509033,2.3025848865509033,2.30249285697937 38 | -4.923076923076923,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025388717651367,2.3025848865509033,2.3024468421936035 39 | -4.615384615384615,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025388717651367,2.3025388717651367 40 | -4.307692307692307,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.30249285697937 41 | -4.0,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.3025388717651367,2.3025848865509033,2.302400588989258,2.3025388717651367 42 | -------------------------------------------------------------------------------- /results/loss_vs_log2lr_muMLP_1.0_Adam.csv: -------------------------------------------------------------------------------- 1 | ,128,256,512,1024,2048,4096,8192 2 | -12.0,1.4206106662750244,1.3330944776535034,1.262108325958252,1.1972408294677734,1.1369351148605347,1.0785802602767944,1.009942650794983 3 | -11.692307692307692,1.3808192014694214,1.2885868549346924,1.210929274559021,1.1411852836608887,1.0741991996765137,1.0086736679077148,0.9305814504623413 4 | -11.384615384615385,1.3395483493804932,1.2403818368911743,1.1570433378219604,1.0805660486221313,1.0058432817459106,0.9312848448753357,0.8422932028770447 5 | -11.076923076923077,1.2961549758911133,1.1905239820480347,1.1001179218292236,1.0147615671157837,0.9303061962127686,0.8455123901367188,0.7442283034324646 6 | -10.76923076923077,1.2533414363861084,1.1385014057159424,1.0383803844451904,0.9429838061332703,0.8478613495826721,0.7507948875427246,0.6381991505622864 7 | -10.461538461538462,1.2059581279754639,1.0831193923950195,0.9728246331214905,0.8660746812820435,0.7576060891151428,0.6485912203788757,0.5277799963951111 8 | -10.153846153846153,1.1584067344665527,1.0243240594863892,0.9025211334228516,0.7827584147453308,0.6614481210708618,0.5422160625457764,0.4195397198200226 9 | -9.846153846153847,1.1097642183303833,0.9625787138938904,0.8277496099472046,0.6944910883903503,0.562453031539917,0.4388996958732605,0.3215731382369995 10 | -9.538461538461538,1.0618786811828613,0.8988134264945984,0.7497575283050537,0.6011410355567932,0.4662381410598755,0.3448903262615204,0.2428734302520752 11 | -9.23076923076923,1.0082736015319824,0.8310298919677734,0.6683062314987183,0.5116293430328369,0.3776470422744751,0.26768043637275696,0.18787062168121338 12 | -8.923076923076923,0.964329183101654,0.7635444402694702,0.5882040858268738,0.4279516041278839,0.3035544753074646,0.20958717167377472,0.1504085659980774 13 | 
-8.615384615384615,0.9058070182800293,0.6932271718978882,0.5074883699417114,0.34878072142601013,0.2448357492685318,0.16778014600276947,0.1288631409406662 14 | -8.307692307692307,0.8602041602134705,0.6229555010795593,0.4296429753303528,0.2841620445251465,0.2004416435956955,0.13999424874782562,0.10772936046123505 15 | -8.0,0.8180731534957886,0.5593947768211365,0.35821986198425293,0.22932715713977814,0.16112367808818817,0.12039078027009964,0.09717018902301788 16 | -7.692307692307692,0.7781124114990234,0.4993979334831238,0.2977942228317261,0.17987701296806335,0.13058295845985413,0.10658438503742218,0.08830660581588745 17 | -7.384615384615384,0.739459216594696,0.4368683397769928,0.24155765771865845,0.14631658792495728,0.11087410897016525,0.09157118946313858,0.07409659773111343 18 | -7.076923076923077,0.7052993774414062,0.3895774483680725,0.19712522625923157,0.12170979380607605,0.09846470504999161,0.08288329094648361,0.08126974105834961 19 | -6.769230769230769,0.6955474615097046,0.3609372675418854,0.17269060015678406,0.10012480616569519,0.09510697424411774,0.07308612763881683,0.07508502900600433 20 | -6.461538461538462,0.6880418062210083,0.34396785497665405,0.163906529545784,0.10825936496257782,0.09120301157236099,0.08454662561416626,0.07704904675483704 21 | -6.153846153846153,0.6981275677680969,0.34348782896995544,0.16927912831306458,0.11544906347990036,0.0903533473610878,0.08310854434967041,0.08408361673355103 22 | -5.846153846153846,0.7132565975189209,0.3703601658344269,0.18153375387191772,0.13205033540725708,0.1006687730550766,0.09757325798273087,0.10017269104719162 23 | -5.538461538461538,0.7714946269989014,0.4138732850551605,0.21563859283924103,0.1535515934228897,0.10002674162387848,0.09859314560890198,0.09829584509134293 24 | -5.23076923076923,0.805079996585846,0.4795529544353485,0.271258145570755,0.17450760304927826,0.13322791457176208,0.13302895426750183,0.10752154886722565 25 | -4.9230769230769225,0.8492432832717896,0.5503376722335815,0.3392457067966461,0.22261324524879456,0.16140975058078766,0.13579700887203217,0.12633121013641357 26 | -4.615384615384615,0.9064252376556396,0.646146297454834,0.4124026596546173,0.28480544686317444,0.22628505527973175,0.17477735877037048,0.14169351756572723 27 | -4.3076923076923075,0.931185245513916,0.7032446265220642,0.4969250559806824,0.36147207021713257,0.27314379811286926,0.20925986766815186,0.17702455818653107 28 | -4.0,0.9736283421516418,0.7584465146064758,0.5662205815315247,0.41167598962783813,0.3333348333835602,0.2741101086139679,0.21400032937526703 29 | -3.6923076923076916,1.0025300979614258,0.8218345046043396,0.6292194128036499,0.4831300377845764,0.40819427371025085,0.342590868473053,0.2830682396888733 30 | -3.3846153846153832,1.0743858814239502,0.8826375603675842,0.716712236404419,0.5793853998184204,0.4742085635662079,0.4113706648349762,0.349984347820282 31 | -3.0769230769230766,1.1711702346801758,0.9720792174339294,0.7740287184715271,0.6408830285072327,0.5279196500778198,0.47670018672943115,0.40762999653816223 32 | -2.7692307692307683,1.240645408630371,1.0380141735076904,0.8558477163314819,0.7198582887649536,0.5713950991630554,0.5478404760360718,0.47321709990501404 33 | -2.4615384615384617,1.3655222654342651,1.1677027940750122,0.9614454507827759,0.76822829246521,0.6735852956771851,0.5773382186889648,0.5364844799041748 34 | -2.1538461538461533,1.5980745553970337,1.2658027410507202,1.0558853149414062,0.8846914172172546,0.7336425185203552,0.6561277508735657,0.5872576832771301 35 | 
-1.846153846153845,1.8436317443847656,1.427786111831665,1.1869512796401978,0.9955216646194458,0.8649254441261292,0.7240926623344421,0.6836898922920227 36 | -1.5384615384615383,2.007619857788086,1.7296048402786255,1.3996155261993408,1.1520044803619385,0.9657474160194397,0.8317741751670837,0.7899918556213379 37 | -1.23076923076923,2.222170829772949,1.9180463552474976,1.6245970726013184,1.3098454475402832,1.1126508712768555,0.948315441608429,0.8932841420173645 38 | -0.9230769230769234,2.3025848865509033,2.192922830581665,1.8261533975601196,1.5591129064559937,1.2466591596603394,1.1122732162475586,1.021584391593933 39 | -0.615384615384615,2.3025848865509033,2.3025848865509033,2.027647018432617,1.7304290533065796,1.4968854188919067,1.226948618888855,1.1361500024795532 40 | -0.3076923076923066,2.3025848865509033,2.3026046752929688,2.199495553970337,1.9682382345199585,1.697022557258606,1.4554752111434937,1.2665451765060425 41 | 0.0,2.3025848865509033,2.3025848865509033,2.3025848865509033,2.1370961666107178,1.8561878204345703,1.617972731590271,1.488148808479309 42 | -------------------------------------------------------------------------------- /toy_cifar.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 23, 6 | "id": "80971829", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "import torch.nn.functional as F\n", 13 | "from torchvision import datasets, transforms\n", 14 | "from torch.utils.data import TensorDataset, DataLoader\n", 15 | "\n", 16 | "import numpy as np\n", 17 | "from pyhessian import hessian\n", 18 | "from tqdm import tqdm\n", 19 | "import pandas as pd\n", 20 | "import os\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "\n", 23 | "from train_mlp import muMLPTab9\n", 24 | "\n", 25 | "device = \"cuda\"\n", 26 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "e436cfc3", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "def get_cifar(batch_size=128, num_classes=10, MSE=False, on_gpu=False, device=None):\n", 37 | "    transform = transforms.Compose([\n", 38 | "        transforms.ToTensor(),\n", 39 | "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", 40 | "    ])\n", 41 | "    \n", 42 | "    train_ds = datasets.CIFAR10(root='/tmp', train=True, download=False, transform=transform)\n", 43 | "    targets = np.array(train_ds.targets)\n", 44 | "    mask = np.isin(targets, np.arange(num_classes))\n", 45 | "    indices = np.where(mask)[0]\n", 46 | "    assert np.unique(targets[indices]).shape[0] >= num_classes, f\"Number of classes {np.unique(targets[indices]).shape[0]} < {num_classes}\"\n", 47 | "\n", 48 | "\n", 49 | "    X, y = [], []\n", 50 | "    for i in tqdm(indices):\n", 51 | "        x, y_ = train_ds[i]\n", 52 | "        X.append(x)\n", 53 | "        y.append(y_)\n", 54 | "    X = torch.stack(X)\n", 55 | "    y = torch.tensor(y)\n", 56 | "\n", 57 | "    if MSE:\n", 58 | "        y = F.one_hot(y, num_classes=num_classes).float()\n", 59 | "\n", 60 | "    if on_gpu:\n", 61 | "        assert device is not None, \"Please provide a device=\"\n", 62 | "        X = X.to(device)\n", 63 | "        y = y.to(device)\n", 64 | "\n", 65 | "    tensor_ds = TensorDataset(X, y)\n", 66 | "    train_dl = DataLoader(tensor_ds, batch_size=batch_size, shuffle=True, pin_memory=not on_gpu)\n", 67 | "\n", 68 | "    if on_gpu:\n", 69 | "        print(f\"Estimated size of the dataset in MB: {(X.numel() * X.element_size() + y.numel() *
y.element_size()) / 1024 / 1024:.2f}\")\n", 70 | "\n", 71 | " return train_dl, tensor_ds\n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 42, 77 | "id": "7bc8f232", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "seed = 1\n", 82 | "epochs = 5\n", 83 | "classes = 2" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "id": "1f5f0bf8", 89 | "metadata": {}, 90 | "source": [ 91 | "# Tensors loaded on GPU per batch" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 43, 97 | "id": "e664fba2", 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "name": "stderr", 102 | "output_type": "stream", 103 | "text": [ 104 | "100%|██████████| 10000/10000 [00:02<00:00, 4600.45it/s]\n" 105 | ] 106 | }, 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "79\n", 112 | "torch.Size([128])\n", 113 | "0.6877436098098755\n", 114 | "0.6252798287391662\n", 115 | "0.5944725264549255\n", 116 | "0.5714316897392273\n", 117 | "0.553438679933548\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "dl, ds = get_cifar(batch_size=128, num_classes=classes, MSE=False, on_gpu=False)\n", 123 | "print(len(dl))\n", 124 | "\n", 125 | "torch.manual_seed(seed)\n", 126 | "np.random.seed(seed)\n", 127 | "print(next(iter(dl))[1].shape)\n", 128 | "model = muMLPTab9(128, classes).to(device)\n", 129 | "criterion = nn.CrossEntropyLoss()\n", 130 | "\n", 131 | "model.train()\n", 132 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.001)\n", 133 | "for epoch in range(epochs):\n", 134 | " epoch_loss = 0\n", 135 | " for i, (X, y) in enumerate(dl):\n", 136 | " X, y = X.to(device), y.to(device)\n", 137 | " optimizer.zero_grad()\n", 138 | " out = model(X)\n", 139 | " loss = criterion(out, y)\n", 140 | " loss.backward()\n", 141 | " optimizer.step()\n", 142 | " epoch_loss += loss.item() * X.size(0)\n", 143 | " print(epoch_loss / len(dl.dataset))" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 44, 149 | "id": "833799d3", 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "name": "stdout", 154 | "output_type": "stream", 155 | "text": [ 156 | "torch.Size([128])\n", 157 | "0.6877436098098755\n", 158 | "0.6252798287391662\n", 159 | "0.5944725264549255\n", 160 | "0.5714316897392273\n", 161 | "0.553438679933548\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "torch.manual_seed(seed)\n", 167 | "np.random.seed(seed)\n", 168 | "print(next(iter(dl))[1].shape)\n", 169 | "model = muMLPTab9(128, classes).to(device)\n", 170 | "criterion = nn.CrossEntropyLoss()\n", 171 | "\n", 172 | "model.train()\n", 173 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.001)\n", 174 | "for epoch in range(epochs):\n", 175 | " epoch_loss = 0\n", 176 | " for i, (X, y) in enumerate(dl):\n", 177 | " X, y = X.to(device), y.to(device)\n", 178 | " optimizer.zero_grad()\n", 179 | " out = model(X)\n", 180 | " loss = criterion(out, y)\n", 181 | " loss.backward()\n", 182 | " optimizer.step()\n", 183 | " epoch_loss += loss.item() * X.size(0)\n", 184 | " print(epoch_loss / len(dl.dataset))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "id": "bb2aca19", 190 | "metadata": {}, 191 | "source": [ 192 | "# Tensors on GPU" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 46, 198 | "id": "6fbdc7c7", 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stderr", 203 | "output_type": "stream", 204 | "text": [ 205 | "100%|██████████| 10000/10000 [00:02<00:00, 4246.00it/s]\n" 206 | ] 207 | }, 208 | { 209 
| "name": "stdout", 210 | "output_type": "stream", 211 | "text": [ 212 | "Estimated size of the dataset in MB: 117.26\n", 213 | "79\n", 214 | "torch.Size([128])\n", 215 | "0.6877436098098755\n", 216 | "0.6252798287391662\n", 217 | "0.5944725264549255\n", 218 | "0.5714316897392273\n", 219 | "0.553438679933548\n" 220 | ] 221 | } 222 | ], 223 | "source": [ 224 | "dl, ds = get_cifar(batch_size=128, num_classes=classes, MSE=False, on_gpu=True, device=device)\n", 225 | "print(len(dl))\n", 226 | "\n", 227 | "torch.manual_seed(seed)\n", 228 | "np.random.seed(seed)\n", 229 | "print(next(iter(dl))[1].shape)\n", 230 | "model = muMLPTab9(128, classes).to(device)\n", 231 | "criterion = nn.CrossEntropyLoss()\n", 232 | "\n", 233 | "model.train()\n", 234 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.001)\n", 235 | "for epoch in range(epochs):\n", 236 | " epoch_loss = 0\n", 237 | " for i, (X, y) in enumerate(dl):\n", 238 | " optimizer.zero_grad()\n", 239 | " out = model(X)\n", 240 | " loss = criterion(out, y)\n", 241 | " loss.backward()\n", 242 | " optimizer.step()\n", 243 | " epoch_loss += loss.item() * X.size(0)\n", 244 | " print(epoch_loss / len(dl.dataset))" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 47, 250 | "id": "885c8058", 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 257 | "torch.Size([128])\n", 258 | "0.6877436098098755\n", 259 | "0.6252798287391662\n", 260 | "0.5944725264549255\n", 261 | "0.5714316897392273\n", 262 | "0.553438679933548\n" 263 | ] 264 | } 265 | ], 266 | "source": [ 267 | "torch.manual_seed(seed)\n", 268 | "np.random.seed(seed)\n", 269 | "print(next(iter(dl))[1].shape)\n", 270 | "model = muMLPTab9(128, classes).to(device)\n", 271 | "criterion = nn.CrossEntropyLoss()\n", 272 | "\n", 273 | "model.train()\n", 274 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.001)\n", 275 | "for epoch in range(epochs):\n", 276 | " epoch_loss = 0\n", 277 | " for i, (X, y) in enumerate(dl):\n", 278 | " optimizer.zero_grad()\n", 279 | " out = model(X)\n", 280 | " loss = criterion(out, y)\n", 281 | " loss.backward()\n", 282 | " optimizer.step()\n", 283 | " epoch_loss += loss.item() * X.size(0)\n", 284 | " print(epoch_loss / len(dl.dataset))" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "id": "82cbb43f", 290 | "metadata": {}, 291 | "source": [ 292 | "# MSE + on GPU" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 48, 298 | "id": "bb6bf0a8", 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "name": "stderr", 303 | "output_type": "stream", 304 | "text": [ 305 | "100%|██████████| 10000/10000 [00:02<00:00, 4303.08it/s]\n" 306 | ] 307 | }, 308 | { 309 | "name": "stdout", 310 | "output_type": "stream", 311 | "text": [ 312 | "Estimated size of the dataset in MB: 117.26\n", 313 | "79\n", 314 | "torch.Size([128, 2])\n", 315 | "0.6868212818145752\n", 316 | "0.4809396454811096\n", 317 | "0.4108572193145752\n", 318 | "0.36817205924987795\n", 319 | "0.33871090376377105\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "dl, ds = get_cifar(batch_size=128, num_classes=classes, MSE=True, on_gpu=True, device=device)\n", 325 | "print(len(dl))\n", 326 | "\n", 327 | "torch.manual_seed(seed)\n", 328 | "np.random.seed(seed)\n", 329 | "print(next(iter(dl))[1].shape)\n", 330 | "model = muMLPTab9(128, classes).to(device)\n", 331 | "criterion = nn.MSELoss()\n", 332 | "\n", 333 | "model.train()\n", 334 | "optimizer = torch.optim.SGD(model.parameters(), 
lr=0.001)\n", 335 | "for epoch in range(epochs):\n", 336 | " epoch_loss = 0\n", 337 | " for i, (X, y) in enumerate(dl):\n", 338 | " X, y = X.to(device), y.to(device)\n", 339 | " optimizer.zero_grad()\n", 340 | " out = model(X)\n", 341 | " loss = criterion(out, y)\n", 342 | " loss.backward()\n", 343 | " optimizer.step()\n", 344 | " epoch_loss += loss.item() * X.size(0)\n", 345 | " print(epoch_loss / len(dl.dataset))" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 49, 351 | "id": "f4b6f840", 352 | "metadata": {}, 353 | "outputs": [ 354 | { 355 | "name": "stdout", 356 | "output_type": "stream", 357 | "text": [ 358 | "torch.Size([128, 2])\n", 359 | "0.6868212818145752\n", 360 | "0.4809396454811096\n", 361 | "0.4108572193145752\n", 362 | "0.36817205924987795\n", 363 | "0.33871090376377105\n" 364 | ] 365 | } 366 | ], 367 | "source": [ 368 | "torch.manual_seed(seed)\n", 369 | "np.random.seed(seed)\n", 370 | "print(next(iter(dl))[1].shape)\n", 371 | "model = muMLPTab9(128, classes).to(device)\n", 372 | "criterion = nn.MSELoss()\n", 373 | "\n", 374 | "model.train()\n", 375 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.001)\n", 376 | "for epoch in range(epochs):\n", 377 | " epoch_loss = 0\n", 378 | " for i, (X, y) in enumerate(dl):\n", 379 | " X, y = X.to(device), y.to(device)\n", 380 | " optimizer.zero_grad()\n", 381 | " out = model(X)\n", 382 | " loss = criterion(out, y)\n", 383 | " loss.backward()\n", 384 | " optimizer.step()\n", 385 | " epoch_loss += loss.item() * X.size(0)\n", 386 | " print(epoch_loss / len(dl.dataset))" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "id": "81cc3ae4", 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [] 396 | } 397 | ], 398 | "metadata": { 399 | "kernelspec": { 400 | "display_name": "mup-abc", 401 | "language": "python", 402 | "name": "python3" 403 | }, 404 | "language_info": { 405 | "codemirror_mode": { 406 | "name": "ipython", 407 | "version": 3 408 | }, 409 | "file_extension": ".py", 410 | "mimetype": "text/x-python", 411 | "name": "python", 412 | "nbconvert_exporter": "python", 413 | "pygments_lexer": "ipython3", 414 | "version": "3.11.12" 415 | } 416 | }, 417 | "nbformat": 4, 418 | "nbformat_minor": 5 419 | } 420 | -------------------------------------------------------------------------------- /train_mlp.py: -------------------------------------------------------------------------------- 1 | import math 2 | from time import sleep 3 | import subprocess 4 | import itertools 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torchvision import datasets, transforms 8 | import torch 9 | from torch import nn 10 | from torch.optim import SGD, Adam 11 | import matplotlib.pyplot as plt 12 | import pandas as pd 13 | from tqdm import tqdm 14 | import argparse 15 | import torch.multiprocessing as mp 16 | 17 | def chunk_jobs(jobs, n_chunks): 18 | """Split a list of jobs into n_chunks as evenly as possible, tagging each job with a unique index.""" 19 | chunk_sizes = [len(jobs) // n_chunks] * n_chunks 20 | for i in range(len(jobs) % n_chunks): 21 | chunk_sizes[i] += 1 22 | 23 | chunks = [] 24 | start = 0 25 | idx = 0 26 | for size in chunk_sizes: 27 | chunk = [] 28 | for job in jobs[start:start + size]: 29 | chunk.append((idx, job[0], job[1])) 30 | idx += 1 31 | chunks.append(chunk) 32 | start += size 33 | 34 | return chunks 35 | 36 | def get_available_gpus(min_free_mem_gb=4): 37 | """Returns a list of GPU IDs with at least min_free_mem_gb available.""" 38 | 
result = subprocess.run( 39 | ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'], 40 | stdout=subprocess.PIPE, 41 | stderr=subprocess.PIPE, 42 | text=True 43 | ) 44 | if result.returncode != 0: 45 | raise RuntimeError(f"nvidia-smi failed: {result.stderr}") 46 | 47 | free_memories = [int(x) for x in result.stdout.strip().split('\n')] 48 | return [i for i, mem in enumerate(free_memories) if mem >= min_free_mem_gb * 1024] 49 | 50 | 51 | def get_available_gpus(min_free_mem_gb=4, max_utilization=10): # redefines (and shadows) the simpler version above, adding a utilization check 52 | result = subprocess.run( 53 | [ 54 | 'nvidia-smi', 55 | '--query-gpu=memory.total,memory.used,utilization.gpu', 56 | '--format=csv,nounits,noheader' 57 | ], 58 | stdout=subprocess.PIPE, 59 | stderr=subprocess.PIPE, 60 | text=True 61 | ) 62 | 63 | if result.returncode != 0: 64 | raise RuntimeError(f"nvidia-smi failed: {result.stderr}") 65 | 66 | available_gpus = [] 67 | for i, line in enumerate(result.stdout.strip().split('\n')): 68 | total_str, used_str, util_str = map(str.strip, line.split(',')) 69 | total = int(total_str) # in MB 70 | used = int(used_str) # in MB 71 | util = int(util_str) # in % 72 | 73 | free_mem_gb = (total - used) / 1024 74 | if free_mem_gb >= min_free_mem_gb and util < max_utilization: 75 | available_gpus.append(i) 76 | 77 | return available_gpus 78 | 79 | 80 | def preload_subset(batch_size, subset_percentage, return_dataset=False): 81 | transform = transforms.Compose([ 82 | transforms.ToTensor(), 83 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 84 | ]) 85 | train_ds = datasets.CIFAR10(root='/tmp', train=True, download=True, transform=transform) 86 | 87 | torch.manual_seed(0) 88 | np.random.seed(0) 89 | subset_size = int(len(train_ds) * subset_percentage) 90 | indices = np.random.choice(len(train_ds), subset_size, replace=False) 91 | train_subset = torch.utils.data.Subset(train_ds, indices) 92 | xs = torch.stack([train_subset[i][0] for i in range(len(train_subset))]) 93 | ys = torch.tensor([train_subset[i][1] for i in range(len(train_subset))]) 94 | preloaded_dataset = torch.utils.data.TensorDataset(xs, ys) 95 | preloaded = torch.utils.data.DataLoader(preloaded_dataset, batch_size=batch_size, shuffle=True, num_workers=0) 96 | if return_dataset: 97 | return preloaded, preloaded_dataset 98 | 99 | return preloaded 100 | 101 | class SP_MLP(nn.Module): 102 | """Initialized according to Table1 from TP4 -- the training behavior most similar to the plots""" 103 | def __init__(self, width=128, num_classes=10): 104 | super().__init__() 105 | self.width = width 106 | self.fc_1 = nn.Linear(3072, width, bias=False) 107 | self.fc_2 = nn.Linear(width, width, bias=False) 108 | self.fc_3 = nn.Linear(width, num_classes, bias=False) 109 | self.reset_parameters() 110 | 111 | def reset_parameters(self): 112 | nn.init.normal_(self.fc_1.weight, std=1.0) 113 | nn.init.normal_(self.fc_2.weight, std=self.width**(-0.5)) 114 | nn.init.normal_(self.fc_3.weight, std=self.width**(-0.5)) 115 | 116 | def forward(self, x): 117 | x = x.view(x.size(0), -1) 118 | h = F.relu(self.fc_1(x)) 119 | h = F.relu(self.fc_2(h)) 120 | return self.fc_3(h) 121 | 122 | class NTK_MLP(nn.Module): 123 | """Initialized according to Table1 from TP4""" 124 | def __init__(self, width=128, num_classes=10): 125 | super().__init__() 126 | self.width = width 127 | self.fc_1 = nn.Linear(3072, width, bias=False) 128 | self.fc_2 = nn.Linear(width, width, bias=False) 129 | self.fc_3 = nn.Linear(width, num_classes, bias=False) 130 | self.reset_parameters() 131 | 132 | def reset_parameters(self): 133 |
nn.init.normal_(self.fc_1.weight, std=self.width**(0)) 134 | nn.init.normal_(self.fc_2.weight, std=self.width**(0)) 135 | nn.init.normal_(self.fc_3.weight, std=self.width**(0)) 136 | 137 | def forward(self, x): 138 | x = x.view(x.size(0), -1) 139 | h = F.relu(self.fc_1(x)) 140 | h = F.relu(self.fc_2(h) * self.width**(-0.5)) 141 | return self.fc_3(h) * self.width**(-0.5) 142 | 143 | class demoMLP(nn.Module): 144 | """SP model from the muP demo example Jupyter notebook -- does not show the expected training behavior""" 145 | def __init__(self, width=128, num_classes=10, nonlin=F.relu, output_mult=1.0, input_mult=1.0): 146 | super().__init__() 147 | self.nonlin = nonlin 148 | self.input_mult = input_mult 149 | self.output_mult = output_mult 150 | self.fc_1 = nn.Linear(3072, width, bias=False) 151 | self.fc_2 = nn.Linear(width, width, bias=False) 152 | self.fc_3 = nn.Linear(width, num_classes, bias=False) 153 | self.reset_parameters() 154 | 155 | def reset_parameters(self): 156 | nn.init.kaiming_normal_(self.fc_1.weight, a=1, mode='fan_in') 157 | self.fc_1.weight.data /= self.input_mult**0.5 158 | nn.init.kaiming_normal_(self.fc_2.weight, a=1, mode='fan_in') 159 | nn.init.zeros_(self.fc_3.weight) 160 | 161 | def forward(self, x): 162 | x = x.view(x.size(0), -1) 163 | out = self.nonlin(self.fc_1(x) * self.input_mult**0.5) 164 | out = self.nonlin(self.fc_2(out)) 165 | return self.fc_3(out) * self.output_mult 166 | 167 | class MLP(nn.Module): 168 | """Standard MLP model -- does not show the expected SP training behavior""" 169 | def __init__(self, width=128, num_classes=10): 170 | super().__init__() 171 | self.width = width 172 | self.fc_1 = nn.Linear(3072, width, bias=False) 173 | self.fc_2 = nn.Linear(width, width, bias=False) 174 | self.fc_3 = nn.Linear(width, num_classes, bias=False) 175 | 176 | def forward(self, x): 177 | x = x.view(x.size(0), -1) 178 | h = self.fc_1(x) 179 | h = F.relu(h) 180 | h = self.fc_2(h) 181 | h = F.relu(h) 182 | h = self.fc_3(h) 183 | return h 184 | 185 | class muMLPTab9(nn.Module): 186 | """muP initialized MLP model, according to Table9 from TP5 (thanks to dvruette)""" 187 | def __init__(self, width=128, num_classes=10): 188 | super().__init__() 189 | self.width = width 190 | self.input_mult = self.width**0.5 191 | self.output_mult = self.width**-0.5 192 | self.fc_1 = nn.Linear(3072, width, bias=False) 193 | self.fc_2 = nn.Linear(width, width, bias=False) 194 | self.fc_3 = nn.Linear(width, num_classes, bias=False) 195 | self.reset_parameters() 196 | 197 | def reset_parameters(self): 198 | nn.init.normal_(self.fc_1.weight, std=self.width**-0.5) # ?
1/fanout 199 | nn.init.normal_(self.fc_2.weight, std=self.width**-0.5) 200 | nn.init.normal_(self.fc_3.weight, std=self.width**-0.5) 201 | 202 | def forward(self, x): 203 | x = x.view(x.size(0), -1) 204 | h = self.input_mult * self.fc_1(x) 205 | h = self.fc_2(F.relu(h)) 206 | h = self.output_mult * self.fc_3(F.relu(h)) 207 | return h 208 | 209 | def get_parameter_groups(self, learning_rate, optimizer): 210 | ''' 211 | SGD specific muP learning rates (Table 9, TP5) 212 | *IMPORTANT* SGD in muP just takes the LR that you pass 213 | This is only here for implementation completeness 214 | ''' 215 | if optimizer == SGD: 216 | return [ 217 | {'params': self.fc_1.parameters(), 'lr': learning_rate}, 218 | {'params': self.fc_2.parameters(), 'lr': learning_rate}, 219 | {'params': self.fc_3.parameters(), 'lr': learning_rate} 220 | ] 221 | elif optimizer == Adam: 222 | '''Adam specific muP learning rates (Table 9, TP5)''' 223 | return [ 224 | {'params': self.fc_1.parameters(), 'lr': learning_rate/self.width**0.5}, 225 | {'params': self.fc_2.parameters(), 'lr': learning_rate/self.width**0.5}, 226 | {'params': self.fc_3.parameters(), 'lr': learning_rate/self.width} 227 | ] 228 | 229 | class customMLP(nn.Module): 230 | """muP initialized MLP model, according to Table9 from TP5 (thanks to dvruette)""" 231 | def __init__(self, width=128, num_classes=10): 232 | super().__init__() 233 | self.width = width 234 | self.input_mult = self.width**0.5 235 | self.output_mult = self.width**-0.5 236 | self.fc_1 = nn.Linear(3072, width, bias=False) 237 | self.fc_2 = nn.Linear(width, width, bias=False) 238 | self.fc_3 = nn.Linear(width, num_classes, bias=False) 239 | self.reset_parameters() 240 | 241 | def reset_parameters(self): 242 | nn.init.normal_(self.fc_1.weight, std=self.width**-0.5) # ? 
1/fanout 243 | nn.init.normal_(self.fc_2.weight, std=self.width**-0.5) 244 | nn.init.normal_(self.fc_3.weight, std=self.width**-0.5) 245 | 246 | def forward(self, x): 247 | x = x.view(x.size(0), -1) 248 | h = self.input_mult * self.fc_1(x) 249 | h = self.fc_2(F.relu(h)) 250 | h = self.output_mult * self.fc_3(F.relu(h)) 251 | return h 252 | 253 | def get_parameter_groups(self, learning_rate, optimizer): 254 | ''' 255 | SGD specific muP learning rates (Table 9, TP5) 256 | *IMPORTANT* SGD in muP just takes the LR that you pass 257 | This is only here for implementation completeness 258 | ''' 259 | if optimizer == SGD: 260 | return [ 261 | {'params': self.fc_1.parameters(), 'lr': learning_rate}, 262 | {'params': self.fc_2.parameters(), 'lr': learning_rate}, 263 | {'params': self.fc_3.parameters(), 'lr': learning_rate} 264 | ] 265 | elif optimizer == Adam: 266 | '''Adam specific muP learning rates (Table 9, TP5)''' 267 | return [ 268 | {'params': self.fc_1.parameters(), 'lr': learning_rate/self.width**0.5}, 269 | {'params': self.fc_2.parameters(), 'lr': learning_rate/self.width**0.5}, 270 | {'params': self.fc_3.parameters(), 'lr': learning_rate/self.width} 271 | ] 272 | 273 | def train(model, train_dl, optimizer, num_epochs, device): 274 | model.train() 275 | for epoch in range(num_epochs): 276 | train_loss = 0 277 | for batch_idx, (data, target) in enumerate(train_dl): 278 | data, target = data.to(device), target.to(device) 279 | optimizer.zero_grad() 280 | output = model(data) 281 | loss = F.cross_entropy(output, target) 282 | train_loss += loss.item() * data.size(0) 283 | loss.backward() 284 | optimizer.step() 285 | 286 | return train_loss / len(train_dl.dataset) # average loss over the final epoch 287 | 288 | def run_chunk(jobs, device, shared_tensor, preloaded, seeds, model_class, optimizer, epochs): 289 | torch.cuda.set_device(device) 290 | for job in jobs: 291 | job_id, log2lr, width = job 292 | run_experiment(log2lr, width, seeds, job_id, device, shared_tensor, preloaded, model_class, optimizer, epochs) 293 | 294 | def run_experiment(log2lr, width, seeds, job_id, device, shared_tensor, preloaded, model_class, optimizer, epochs): 295 | train_dl = preloaded 296 | losses = [] 297 | print(f"Running job {job_id} on device {device} with log2lr={log2lr}, width={width}") 298 | for seed in seeds: 299 | torch.manual_seed(seed) 300 | np.random.seed(seed) 301 | 302 | model = model_class(width=width).to(device) 303 | # custom parameter groups for muMLP, else just use model.parameters() 304 | parameters = model.get_parameter_groups(2**log2lr, optimizer) if hasattr(model, 'get_parameter_groups') else model.parameters() 305 | opt = optimizer(parameters, lr=2**log2lr) # do not rebind optimizer: the class is reused on the next seed 306 | loss = train(model, train_dl, opt, num_epochs=epochs, device=device) 307 | 308 | losses.append(loss) 309 | 310 | loss = np.mean(losses) 311 | shared_tensor[job_id] = loss 312 | print(f"Width: {width}, Log2LR: {log2lr}, Loss: {loss:.4f}, Losses: {[round(ls, 3) for ls in losses]}") 313 | 314 | if __name__ == '__main__': 315 | mp.set_start_method('spawn', force=True) 316 | 317 | parser = argparse.ArgumentParser(description="Train MLP or muMLP model.") 318 | parser.add_argument('--model', type=str, choices=['MLP', 'muMLP', 'demoMLP', 'SPMLP'], required=True, help="Choose the model type: 'MLP', 'muMLP', 'SPMLP' or 'demoMLP'") 319 | parser.add_argument('--subset', type=float, default=0.2, help="Percentage of dataset to use for training (default: 0.2)") 320 | parser.add_argument("--optimizer", type=str, default="SGD", choices=["SGD", "Adam"], help="Optimizer to
use: 'SGD' or 'Adam'") 321 | parser.add_argument("--lr_range", type=float, nargs=2, default=[-12, -4], help="Range of log2 learning rates to use (default: [-16, -4])") 322 | args = parser.parse_args() 323 | 324 | if args.model == 'MLP': 325 | model_class = MLP 326 | elif args.model == 'muMLP': 327 | model_class = muMLPTab9 328 | elif args.model == 'demoMLP': 329 | model_class = demoMLP 330 | elif args.model == 'SPMLP': 331 | model_class = SP_MLP 332 | else: 333 | raise ValueError("Invalid model type. Choose 'MLP' or 'muMLP'.") 334 | print(f"Using model: {args.model}, subset: {args.subset*100}%") 335 | 336 | optimizer = SGD if args.optimizer == "SGD" else Adam 337 | print(f"Using optimizer: {args.optimizer}: {optimizer}") 338 | 339 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 340 | batch_size = 64 341 | data_dir = '/tmp' 342 | 343 | preloaded = preload_subset(batch_size, args.subset) 344 | print(f"Preloaded dataset with {len(preloaded.dataset)} samples.") 345 | 346 | min_lr, max_lr = args.lr_range 347 | print(f"Log2 learning rate range: {min_lr} to {max_lr}") 348 | 349 | epochs = 20 350 | seeds = [2137] 351 | # seeds = [0, 1, 2, 3, 4] 352 | log2lrs = np.linspace(min_lr, max_lr, 40) 353 | widths = [128, 256, 512, 1024, 2048, 4096, 8192] 354 | # widths = [128] 355 | 356 | free_memory, max_utilization = 16, 50 357 | availage_gpus = get_available_gpus(min_free_mem_gb=free_memory, max_utilization=max_utilization) 358 | if len(availage_gpus) == 0: 359 | raise RuntimeError(f"No available GPUs found with at least {free_memory}GB free memory and utilization < {max_utilization}%") 360 | availage_gpus = [0, 1, 5, 6, 7] 361 | devices = [f"cuda:{i}" for i in availage_gpus] 362 | print(f"Available devices: {len(devices)}, {availage_gpus}") 363 | 364 | jobs = list(itertools.product(log2lrs, widths)) 365 | jobs_chunks = chunk_jobs(jobs, len(devices)) 366 | print(f"Jobs: {len(jobs)}, Chunks: {len(jobs_chunks)}") 367 | 368 | processes = [] 369 | shared_tensor = torch.zeros(len(jobs)).to(device).share_memory_() 370 | pbar = tqdm(total=shared_tensor.numel(), desc="Processing", unit="item") 371 | for enum, job_chunk in enumerate(jobs_chunks): 372 | device = devices[enum] 373 | 374 | print(f"Starting process {enum} on {device} with {len(job_chunk)} jobs") 375 | p = mp.Process(target=run_chunk, args=(job_chunk, device, shared_tensor, preloaded, seeds, model_class, optimizer, epochs)) 376 | processes.append(p) 377 | p.start() 378 | 379 | while any(p.is_alive() for p in processes): 380 | done = shared_tensor.count_nonzero().item() 381 | if done > pbar.n: 382 | pbar.n = shared_tensor.count_nonzero().item() 383 | pbar.set_postfix_str(f"Completed: {shared_tensor.count_nonzero().item()}/{len(shared_tensor)}") 384 | pbar.refresh() 385 | sleep(5) 386 | pbar.close() 387 | 388 | results_df = pd.DataFrame(index=log2lrs, columns=widths) 389 | for i, job in enumerate(jobs): 390 | log2lr, width = job 391 | loss = shared_tensor[i].item() 392 | results_df.loc[log2lr, width] = loss 393 | 394 | plt.figure(figsize=(8, 4)) 395 | for width in widths: 396 | plt.plot(results_df.index, results_df[width], label=f'Width {width}') 397 | plt.xlabel('Log2LR') 398 | plt.ylabel('Loss') 399 | plt.title(f'{args.model}, {args.subset*100}% of CIFAR\nLoss vs Log2LR for different widths') 400 | plt.xlim(np.floor(results_df.index.min())-0.5, np.ceil(results_df.index.max())+0.5) 401 | plt.legend() 402 | plt.grid() 403 | plt.savefig(f'results/loss_vs_log2lr_{args.model}_{args.subset}_{args.optimizer}.png') 404 | 
results_df.to_csv(f'results/loss_vs_log2lr_{args.model}_{args.subset}_{args.optimizer}.csv') 405 | plt.show() 406 | --------------------------------------------------------------------------------