├── .gitignore ├── Pipfile ├── Pipfile.lock ├── README.md ├── requirements.txt ├── sample ├── main.py └── net.py └── sweep.yaml /.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | __pycache__ 3 | dataset -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | torch = "*" 10 | torchvision = "*" 11 | pandas = "*" 12 | numpy = "*" 13 | wandb = "*" 14 | 15 | [requires] 16 | python_version = "3.8" 17 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "a42b9192b0b76b4525bc5862fe64e3ce0bb456bff131ba6bd78eccb1b49506d2" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.8" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "certifi": { 20 | "hashes": [ 21 | "sha256:78884e7c1d4b00ce3cea67b44566851c4343c120abd683433ce934a68ea58872", 22 | "sha256:d62a0163eb4c2344ac042ab2bdf75399a71a2d8c7d47eac2e2ee91b9d6339569" 23 | ], 24 | "version": "==2021.10.8" 25 | }, 26 | "charset-normalizer": { 27 | "hashes": [ 28 | "sha256:876d180e9d7432c5d1dfd4c5d26b72f099d503e8fcc0feb7532c9289be60fcbd", 29 | "sha256:cb957888737fc0bbcd78e3df769addb41fd1ff8cf950dc9e7ad7793f1bf44455" 30 | ], 31 | "markers": "python_version >= '3'", 32 | "version": "==2.0.10" 33 | }, 34 | "click": { 35 | "hashes": [ 36 | "sha256:353f466495adaeb40b6b5f592f9f91cb22372351c84caeb068132442a4518ef3", 37 | "sha256:410e932b050f5eed773c4cda94de75971c89cdb3155a72a0831139a79e5ecb5b" 38 | ], 39 | "markers": "python_version >= '3.6'", 40 | "version": "==8.0.3" 41 | }, 42 | "configparser": { 43 | "hashes": [ 44 | "sha256:1b35798fdf1713f1c3139016cfcbc461f09edbf099d1fb658d4b7479fcaa3daa", 45 | "sha256:e8b39238fb6f0153a069aa253d349467c3c4737934f253ef6abac5fe0eca1e5d" 46 | ], 47 | "markers": "python_version >= '3.6'", 48 | "version": "==5.2.0" 49 | }, 50 | "docker-pycreds": { 51 | "hashes": [ 52 | "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4", 53 | "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49" 54 | ], 55 | "version": "==0.4.0" 56 | }, 57 | "gitdb": { 58 | "hashes": [ 59 | "sha256:8033ad4e853066ba6ca92050b9df2f89301b8fc8bf7e9324d412a63f8bf1a8fd", 60 | "sha256:bac2fd45c0a1c9cf619e63a90d62bdc63892ef92387424b855792a6cabe789aa" 61 | ], 62 | "markers": "python_version >= '3.6'", 63 | "version": "==4.0.9" 64 | }, 65 | "gitpython": { 66 | "hashes": [ 67 | "sha256:26ac35c212d1f7b16036361ca5cff3ec66e11753a0d677fb6c48fa4e1a9dd8d6", 68 | "sha256:fc8868f63a2e6d268fb25f481995ba185a85a66fcad126f039323ff6635669ee" 69 | ], 70 | "markers": "python_version >= '3.7'", 71 | "version": "==3.1.26" 72 | }, 73 | "gql": { 74 | "hashes": [ 75 | "sha256:ad0f0b8226428d727c8e1d1cac4e521d83ed024d814921bd55b8adb997dadf4b" 76 | ], 77 | "version": "==0.2.0" 78 | }, 79 | "graphql-core": { 80 | "hashes": [ 81 | "sha256:63bb8593aeeadb0a53e14207b910027fe51158d017927fad87326dac806185ee" 82 | ], 83 | "version": "==1.1" 84 | }, 85 | "idna": { 86 | "hashes": [ 87 | "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff", 88 | "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d" 89 | ], 90 | "markers": "python_version >= '3'", 91 | "version": "==3.3" 92 | }, 93 | "numpy": { 94 | "hashes": [ 95 | "sha256:1598a6de323508cfeed6b7cd6c4efb43324f4692e20d1f76e1feec7f59013448", 96 | "sha256:1b0ece94018ae21163d1f651b527156e1f03943b986188dd81bc7e066eae9d1c", 97 | "sha256:2e40be731ad618cb4974d5ba60d373cdf4f1b8dcbf1dcf4d9dff5e212baf69c5", 98 | "sha256:4ba59db1fcc27ea31368af524dcf874d9277f21fd2e1f7f1e2e0c75ee61419ed", 99 | "sha256:59ca9c6592da581a03d42cc4e270732552243dc45e87248aa8d636d53812f6a5", 100 | "sha256:5e0feb76849ca3e83dd396254e47c7dba65b3fa9ed3df67c2556293ae3e16de3", 101 | "sha256:6d205249a0293e62bbb3898c4c2e1ff8a22f98375a34775a259a0523111a8f6c", 102 | "sha256:6fcc5a3990e269f86d388f165a089259893851437b904f422d301cdce4ff25c8", 103 | "sha256:82847f2765835c8e5308f136bc34018d09b49037ec23ecc42b246424c767056b", 104 | "sha256:87902e5c03355335fc5992a74ba0247a70d937f326d852fc613b7f53516c0963", 105 | "sha256:9ab21d1cb156a620d3999dd92f7d1c86824c622873841d6b080ca5495fa10fef", 106 | "sha256:a1baa1dc8ecd88fb2d2a651671a84b9938461e8a8eed13e2f0a812a94084d1fa", 107 | "sha256:a244f7af80dacf21054386539699ce29bcc64796ed9850c99a34b41305630286", 108 | "sha256:a35af656a7ba1d3decdd4fae5322b87277de8ac98b7d9da657d9e212ece76a61", 109 | "sha256:b1fe1a6f3a6f355f6c29789b5927f8bd4f134a4bd9a781099a7c4f66af8850f5", 110 | "sha256:b5ad0adb51b2dee7d0ee75a69e9871e2ddfb061c73ea8bc439376298141f77f5", 111 | "sha256:ba3c7a2814ec8a176bb71f91478293d633c08582119e713a0c5351c0f77698da", 112 | "sha256:cd77d58fb2acf57c1d1ee2835567cd70e6f1835e32090538f17f8a3a99e5e34b", 113 | "sha256:cdb3a70285e8220875e4d2bc394e49b4988bdb1298ffa4e0bd81b2f613be397c", 114 | "sha256:deb529c40c3f1e38d53d5ae6cd077c21f1d49e13afc7936f7f868455e16b64a0", 115 | "sha256:e7894793e6e8540dbeac77c87b489e331947813511108ae097f1715c018b8f3d" 116 | ], 117 | "index": "pypi", 118 | "version": "==1.18.2" 119 | }, 120 | "nvidia-ml-py3": { 121 | "hashes": [ 122 | "sha256:390f02919ee9d73fe63a98c73101061a6b37fa694a793abf56673320f1f51277" 123 | ], 124 | "version": "==7.352.0" 125 | }, 126 | "pandas": { 127 | "hashes": [ 128 | "sha256:07c1b58936b80eafdfe694ce964ac21567b80a48d972879a359b3ebb2ea76835", 129 | "sha256:0ebe327fb088df4d06145227a4aa0998e4f80a9e6aed4b61c1f303bdfdf7c722", 130 | "sha256:11c7cb654cd3a0e9c54d81761b5920cdc86b373510d829461d8f2ed6d5905266", 131 | "sha256:12f492dd840e9db1688126216706aa2d1fcd3f4df68a195f9479272d50054645", 132 | "sha256:167a1315367cea6ec6a5e11e791d9604f8e03f95b57ad227409de35cf850c9c5", 133 | "sha256:1a7c56f1df8d5ad8571fa251b864231f26b47b59cbe41aa5c0983d17dbb7a8e4", 134 | "sha256:1fa4bae1a6784aa550a1c9e168422798104a85bf9c77a1063ea77ee6f8452e3a", 135 | "sha256:32f42e322fb903d0e189a4c10b75ba70d90958cc4f66a1781ed027f1a1d14586", 136 | "sha256:387dc7b3c0424327fe3218f81e05fc27832772a5dffbed385013161be58df90b", 137 | "sha256:6597df07ea361231e60c00692d8a8099b519ed741c04e65821e632bc9ccb924c", 138 | "sha256:743bba36e99d4440403beb45a6f4f3a667c090c00394c176092b0b910666189b", 139 | "sha256:858a0d890d957ae62338624e4aeaf1de436dba2c2c0772570a686eaca8b4fc85", 140 | "sha256:863c3e4b7ae550749a0bb77fa22e601a36df9d2905afef34a6965bed092ba9e5", 141 | "sha256:a210c91a02ec5ff05617a298ad6f137b9f6f5771bf31f2d6b6367d7f71486639", 142 | "sha256:ca84a44cf727f211752e91eab2d1c6c1ab0f0540d5636a8382a3af428542826e", 143 | "sha256:d234bcf669e8b4d6cbcd99e3ce7a8918414520aeb113e2a81aeb02d0a533d7f7" 144 | ], 145 | "index": "pypi", 146 | "version": "==1.0.3" 147 | }, 148 | "pillow": { 149 | "hashes": [ 150 | "sha256:03b27b197deb4ee400ed57d8d4e572d2d8d80f825b6634daf6e2c18c3c6ccfa6", 151 | "sha256:0b281fcadbb688607ea6ece7649c5d59d4bbd574e90db6cd030e9e85bde9fecc", 152 | "sha256:0ebd8b9137630a7bbbff8c4b31e774ff05bbb90f7911d93ea2c9371e41039b52", 153 | "sha256:113723312215b25c22df1fdf0e2da7a3b9c357a7d24a93ebbe80bfda4f37a8d4", 154 | "sha256:2d16b6196fb7a54aff6b5e3ecd00f7c0bab1b56eee39214b2b223a9d938c50af", 155 | "sha256:2fd8053e1f8ff1844419842fd474fc359676b2e2a2b66b11cc59f4fa0a301315", 156 | "sha256:31b265496e603985fad54d52d11970383e317d11e18e856971bdbb86af7242a4", 157 | "sha256:3586e12d874ce2f1bc875a3ffba98732ebb12e18fb6d97be482bd62b56803281", 158 | "sha256:47f5cf60bcb9fbc46011f75c9b45a8b5ad077ca352a78185bd3e7f1d294b98bb", 159 | "sha256:490e52e99224858f154975db61c060686df8a6b3f0212a678e5d2e2ce24675c9", 160 | "sha256:500d397ddf4bbf2ca42e198399ac13e7841956c72645513e8ddf243b31ad2128", 161 | "sha256:52abae4c96b5da630a8b4247de5428f593465291e5b239f3f843a911a3cf0105", 162 | "sha256:6579f9ba84a3d4f1807c4aab4be06f373017fc65fff43498885ac50a9b47a553", 163 | "sha256:68e06f8b2248f6dc8b899c3e7ecf02c9f413aab622f4d6190df53a78b93d97a5", 164 | "sha256:6c5439bfb35a89cac50e81c751317faea647b9a3ec11c039900cd6915831064d", 165 | "sha256:72c3110228944019e5f27232296c5923398496b28be42535e3b2dc7297b6e8b6", 166 | "sha256:72f649d93d4cc4d8cf79c91ebc25137c358718ad75f99e99e043325ea7d56100", 167 | "sha256:7aaf07085c756f6cb1c692ee0d5a86c531703b6e8c9cae581b31b562c16b98ce", 168 | "sha256:80fe92813d208ce8aa7d76da878bdc84b90809f79ccbad2a288e9bcbeac1d9bd", 169 | "sha256:95545137fc56ce8c10de646074d242001a112a92de169986abd8c88c27566a05", 170 | "sha256:97b6d21771da41497b81652d44191489296555b761684f82b7b544c49989110f", 171 | "sha256:98cb63ca63cb61f594511c06218ab4394bf80388b3d66cd61d0b1f63ee0ea69f", 172 | "sha256:9f3b4522148586d35e78313db4db0df4b759ddd7649ef70002b6c3767d0fdeb7", 173 | "sha256:a09a9d4ec2b7887f7a088bbaacfd5c07160e746e3d47ec5e8050ae3b2a229e9f", 174 | "sha256:b5050d681bcf5c9f2570b93bee5d3ec8ae4cf23158812f91ed57f7126df91762", 175 | "sha256:bb47a548cea95b86494a26c89d153fd31122ed65255db5dcbc421a2d28eb3379", 176 | "sha256:bc462d24500ba707e9cbdef436c16e5c8cbf29908278af053008d9f689f56dee", 177 | "sha256:c2067b3bb0781f14059b112c9da5a91c80a600a97915b4f48b37f197895dd925", 178 | "sha256:d154ed971a4cc04b93a6d5b47f37948d1f621f25de3e8fa0c26b2d44f24e3e8f", 179 | "sha256:d5dcea1387331c905405b09cdbfb34611050cc52c865d71f2362f354faee1e9f", 180 | "sha256:ee6e2963e92762923956fe5d3479b1fdc3b76c83f290aad131a2f98c3df0593e", 181 | "sha256:fd0e5062f11cb3e730450a7d9f323f4051b532781026395c4323b8ad055523c4" 182 | ], 183 | "index": "pypi", 184 | "version": "==9.0.0" 185 | }, 186 | "promise": { 187 | "hashes": [ 188 | "sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0" 189 | ], 190 | "version": "==2.3" 191 | }, 192 | "psutil": { 193 | "hashes": [ 194 | "sha256:072664401ae6e7c1bfb878c65d7282d4b4391f1bc9a56d5e03b5a490403271b5", 195 | "sha256:1070a9b287846a21a5d572d6dddd369517510b68710fca56b0e9e02fd24bed9a", 196 | "sha256:1d7b433519b9a38192dfda962dd8f44446668c009833e1429a52424624f408b4", 197 | "sha256:3151a58f0fbd8942ba94f7c31c7e6b310d2989f4da74fcbf28b934374e9bf841", 198 | "sha256:32acf55cb9a8cbfb29167cd005951df81b567099295291bcfd1027365b36591d", 199 | "sha256:3611e87eea393f779a35b192b46a164b1d01167c9d323dda9b1e527ea69d697d", 200 | "sha256:3d00a664e31921009a84367266b35ba0aac04a2a6cad09c550a89041034d19a0", 201 | "sha256:4e2fb92e3aeae3ec3b7b66c528981fd327fb93fd906a77215200404444ec1845", 202 | "sha256:539e429da49c5d27d5a58e3563886057f8fc3868a5547b4f1876d9c0f007bccf", 203 | "sha256:55ce319452e3d139e25d6c3f85a1acf12d1607ddedea5e35fb47a552c051161b", 204 | "sha256:58c7d923dc209225600aec73aa2c4ae8ea33b1ab31bc11ef8a5933b027476f07", 205 | "sha256:7336292a13a80eb93c21f36bde4328aa748a04b68c13d01dfddd67fc13fd0618", 206 | "sha256:742c34fff804f34f62659279ed5c5b723bb0195e9d7bd9907591de9f8f6558e2", 207 | "sha256:7641300de73e4909e5d148e90cc3142fb890079e1525a840cf0dfd39195239fd", 208 | "sha256:76cebf84aac1d6da5b63df11fe0d377b46b7b500d892284068bacccf12f20666", 209 | "sha256:7779be4025c540d1d65a2de3f30caeacc49ae7a2152108adeaf42c7534a115ce", 210 | "sha256:7d190ee2eaef7831163f254dc58f6d2e2a22e27382b936aab51c835fc080c3d3", 211 | "sha256:8293942e4ce0c5689821f65ce6522ce4786d02af57f13c0195b40e1edb1db61d", 212 | "sha256:869842dbd66bb80c3217158e629d6fceaecc3a3166d3d1faee515b05dd26ca25", 213 | "sha256:90a58b9fcae2dbfe4ba852b57bd4a1dded6b990a33d6428c7614b7d48eccb492", 214 | "sha256:9b51917c1af3fa35a3f2dabd7ba96a2a4f19df3dec911da73875e1edaf22a40b", 215 | "sha256:b2237f35c4bbae932ee98902a08050a27821f8f6dfa880a47195e5993af4702d", 216 | "sha256:c3400cae15bdb449d518545cbd5b649117de54e3596ded84aacabfbb3297ead2", 217 | "sha256:c51f1af02334e4b516ec221ee26b8fdf105032418ca5a5ab9737e8c87dafe203", 218 | "sha256:cb8d10461c1ceee0c25a64f2dd54872b70b89c26419e147a05a10b753ad36ec2", 219 | "sha256:d62a2796e08dd024b8179bd441cb714e0f81226c352c802fca0fd3f89eeacd94", 220 | "sha256:df2c8bd48fb83a8408c8390b143c6a6fa10cb1a674ca664954de193fdcab36a9", 221 | "sha256:e5c783d0b1ad6ca8a5d3e7b680468c9c926b804be83a3a8e95141b05c39c9f64", 222 | "sha256:e9805fed4f2a81de98ae5fe38b75a74c6e6ad2df8a5c479594c7629a1fe35f56", 223 | "sha256:ea42d747c5f71b5ccaa6897b216a7dadb9f52c72a0fe2b872ef7d3e1eacf3ba3", 224 | "sha256:ef216cc9feb60634bda2f341a9559ac594e2eeaadd0ba187a4c2eb5b5d40b91c", 225 | "sha256:ff0d41f8b3e9ebb6b6110057e40019a432e96aae2008951121ba4e56040b84f3" 226 | ], 227 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", 228 | "version": "==5.9.0" 229 | }, 230 | "python-dateutil": { 231 | "hashes": [ 232 | "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86", 233 | "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9" 234 | ], 235 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 236 | "version": "==2.8.2" 237 | }, 238 | "pytz": { 239 | "hashes": [ 240 | "sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c", 241 | "sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326" 242 | ], 243 | "version": "==2021.3" 244 | }, 245 | "pyyaml": { 246 | "hashes": [ 247 | "sha256:0283c35a6a9fbf047493e3a0ce8d79ef5030852c51e9d911a27badfde0605293", 248 | "sha256:055d937d65826939cb044fc8c9b08889e8c743fdc6a32b33e2390f66013e449b", 249 | "sha256:07751360502caac1c067a8132d150cf3d61339af5691fe9e87803040dbc5db57", 250 | "sha256:0b4624f379dab24d3725ffde76559cff63d9ec94e1736b556dacdfebe5ab6d4b", 251 | "sha256:0ce82d761c532fe4ec3f87fc45688bdd3a4c1dc5e0b4a19814b9009a29baefd4", 252 | "sha256:1e4747bc279b4f613a09eb64bba2ba602d8a6664c6ce6396a4d0cd413a50ce07", 253 | "sha256:213c60cd50106436cc818accf5baa1aba61c0189ff610f64f4a3e8c6726218ba", 254 | "sha256:231710d57adfd809ef5d34183b8ed1eeae3f76459c18fb4a0b373ad56bedcdd9", 255 | "sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287", 256 | "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513", 257 | "sha256:40527857252b61eacd1d9af500c3337ba8deb8fc298940291486c465c8b46ec0", 258 | "sha256:473f9edb243cb1935ab5a084eb238d842fb8f404ed2193a915d1784b5a6b5fc0", 259 | "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92", 260 | "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f", 261 | "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2", 262 | "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc", 263 | "sha256:819b3830a1543db06c4d4b865e70ded25be52a2e0631ccd2f6a47a2822f2fd7c", 264 | "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86", 265 | "sha256:98c4d36e99714e55cfbaaee6dd5badbc9a1ec339ebfc3b1f52e293aee6bb71a4", 266 | "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c", 267 | "sha256:9fa600030013c4de8165339db93d182b9431076eb98eb40ee068700c9c813e34", 268 | "sha256:a80a78046a72361de73f8f395f1f1e49f956c6be882eed58505a15f3e430962b", 269 | "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c", 270 | "sha256:b5b9eccad747aabaaffbc6064800670f0c297e52c12754eb1d976c57e4f74dcb", 271 | "sha256:c5687b8d43cf58545ade1fe3e055f70eac7a5a1a0bf42824308d868289a95737", 272 | "sha256:cba8c411ef271aa037d7357a2bc8f9ee8b58b9965831d9e51baf703280dc73d3", 273 | "sha256:d15a181d1ecd0d4270dc32edb46f7cb7733c7c508857278d3d378d14d606db2d", 274 | "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53", 275 | "sha256:d4eccecf9adf6fbcc6861a38015c2a64f38b9d94838ac1810a9023a0609e1b78", 276 | "sha256:d67d839ede4ed1b28a4e8909735fc992a923cdb84e618544973d7dfc71540803", 277 | "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a", 278 | "sha256:e61ceaab6f49fb8bdfaa0f92c4b57bcfbea54c09277b1b4f7ac376bfb7a7c174", 279 | "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5" 280 | ], 281 | "markers": "python_version >= '3.6'", 282 | "version": "==6.0" 283 | }, 284 | "requests": { 285 | "hashes": [ 286 | "sha256:68d7c56fd5a8999887728ef304a6d12edc7be74f1cfa47714fc8b414525c9a61", 287 | "sha256:f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d" 288 | ], 289 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'", 290 | "version": "==2.27.1" 291 | }, 292 | "sentry-sdk": { 293 | "hashes": [ 294 | "sha256:2cec50166bcb67e1965f8073541b2321e3864cd6fd42a526bcde9f0c4e4cc3f8", 295 | "sha256:7bbaa32bba806ec629962f207b597e86831c7ee2c1f287c21ba7de7fea9a9c46" 296 | ], 297 | "version": "==1.5.2" 298 | }, 299 | "shortuuid": { 300 | "hashes": [ 301 | "sha256:44a7a86bcf24dbaba2e626cf80c779926b7c3a0d31a3a013e0d3cd1077707d23", 302 | "sha256:9435e87e5a64f3b92f7110c81f989a3b7bdb9358e22d2359829167da476cfc23" 303 | ], 304 | "markers": "python_version >= '3.5'", 305 | "version": "==1.0.8" 306 | }, 307 | "six": { 308 | "hashes": [ 309 | "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", 310 | "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254" 311 | ], 312 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 313 | "version": "==1.16.0" 314 | }, 315 | "smmap": { 316 | "hashes": [ 317 | "sha256:2aba19d6a040e78d8b09de5c57e96207b09ed71d8e55ce0959eeee6c8e190d94", 318 | "sha256:c840e62059cd3be204b0c9c9f74be2c09d5648eddd4580d9314c3ecde0b30936" 319 | ], 320 | "markers": "python_version >= '3.6'", 321 | "version": "==5.0.0" 322 | }, 323 | "subprocess32": { 324 | "hashes": [ 325 | "sha256:88e37c1aac5388df41cc8a8456bb49ebffd321a3ad4d70358e3518176de3a56b", 326 | "sha256:e45d985aef903c5b7444d34350b05da91a9e0ea015415ab45a21212786c649d0", 327 | "sha256:eb2937c80497978d181efa1b839ec2d9622cf9600a039a79d0e108d1f9aec79d" 328 | ], 329 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3' and python_version < '4'", 330 | "version": "==3.5.4" 331 | }, 332 | "torch": { 333 | "hashes": [ 334 | "sha256:271d4d1e44df6ed57c530f8849b028447c62b8a19b8e8740dd9baa56e7f682c1", 335 | "sha256:30ce089475b287a37d6fbb8d71853e672edaf66699e3dd2eb19be6ce6296732a", 336 | "sha256:405b9eb40e44037d2525b3ddb5bc4c66b519cd742bff249d4207d23f83e88ea5", 337 | "sha256:504915c6bc6051ba6a4c2a43c446463dff04411e352f1e26fe13debeae431778", 338 | "sha256:54d06a0e8ee85e5a437c24f4af9f4196c819294c23ffb5914e177756f55f1829", 339 | "sha256:6f2fd9eb8c7eaf38a982ab266dbbfba0f29fb643bc74e677d045d6f2595e4692", 340 | "sha256:8856f334aa9ecb742c1504bd2563d0ffb8dceb97149c8d72a04afa357f667dbc", 341 | "sha256:8fff03bf7b474c16e4b50da65ea14200cc64553b67b9b2307f9dc7e8c69b9d28", 342 | "sha256:9a1b1db73d8dcfd94b2eee24b939368742aa85f1217c55b8f5681e76c581e99a", 343 | "sha256:bb1e87063661414e1149bef2e3a2499ce0b5060290799d7e26bc5578037075ba", 344 | "sha256:d7b34a78f021935ad727a3bede56a8a8d4fda0b0272314a04c5f6890bbe7bb29" 345 | ], 346 | "index": "pypi", 347 | "version": "==1.4.0" 348 | }, 349 | "torchvision": { 350 | "hashes": [ 351 | "sha256:0ca9cae9ddf1784737493e201aa9411abe62a4479b2e67d1d51b4b7acf16f6eb", 352 | "sha256:1a68d3d98e074d995f3d42a492cca716b0d94605a6fadddf0ce9665425968669", 353 | "sha256:1af6d7b0a515d2a83fe9b6e7969b57ba94ba87a3333e7ed707324a5be1ef5f60", 354 | "sha256:2bf1dc1e16c73c5810d96e4ea463e61129e890100740cd57724413a84d301e41", 355 | "sha256:323500d349d8d91ce2662de41212e8eb1845c68dbf5d4f215ca1e94c7f20723b", 356 | "sha256:358967343eaba74fd748a87f40ea75ca23757e947dbef9a11cd53414d707f793", 357 | "sha256:35e9483858cf8a38debc647c74741605c5c12448d314aa96961082380aadf7e5", 358 | "sha256:4dd05cbc497210928ae3d4d6194561985263c879c3554e9f1823a0fa43d35746", 359 | "sha256:517425af7d41b64caae0f5d9e6b14eeb48d6e62d45f302b73a11a9ec5ee3b6c8", 360 | "sha256:78d455a1da7d10bd38f2e2a0d2ac285e4845c9e7e28aafdf068472cc96bd156b", 361 | "sha256:9e85ba17ff93a0cf6afd39b9a0ad56ca7321db4f1eb90d2034d3b0ecd79be47b", 362 | "sha256:a696ec5009eb52356508eb9b23ddb977043fb82ff7b204459e4c81aca1e5affe", 363 | "sha256:aa4354d339de2c5ea2633a6c94294c68bae3e42a4b099624299e2a50c9e97a85", 364 | "sha256:ec7e4cd54f5ff3a889b90f24b33da1fa9fe3f78d17348965678d9503de1e4a49", 365 | "sha256:fea3d431bf639c0719afff5972eb568ebe143eba447c1c8bb491c7dfb0025ed6" 366 | ], 367 | "index": "pypi", 368 | "version": "==0.5.0" 369 | }, 370 | "urllib3": { 371 | "hashes": [ 372 | "sha256:000ca7f471a233c2251c6c7023ee85305721bfdf18621ebff4fd17a8653427ed", 373 | "sha256:0e7c33d9a63e7ddfcb86780aac87befc2fbddf46c58dbb487e0855f7ceec283c" 374 | ], 375 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' and python_version < '4'", 376 | "version": "==1.26.8" 377 | }, 378 | "wandb": { 379 | "hashes": [ 380 | "sha256:8615337150b852d08828d7d1a9615431313c0f53dcd7856b3ca7dfacb63a8f4d", 381 | "sha256:8c97124104b85407619c5c355626fecf937a77cff3f1c3ce91aa3583d786681d" 382 | ], 383 | "index": "pypi", 384 | "version": "==0.8.31" 385 | }, 386 | "watchdog": { 387 | "hashes": [ 388 | "sha256:25fb5240b195d17de949588628fdf93032ebf163524ef08933db0ea1f99bd685", 389 | "sha256:3386b367e950a11b0568062b70cc026c6f645428a698d33d39e013aaeda4cc04", 390 | "sha256:3becdb380d8916c873ad512f1701f8a92ce79ec6978ffde92919fd18d41da7fb", 391 | "sha256:4ae38bf8ba6f39d5b83f78661273216e7db5b00f08be7592062cb1fc8b8ba542", 392 | "sha256:8047da932432aa32c515ec1447ea79ce578d0559362ca3605f8e9568f844e3c6", 393 | "sha256:8f1c00aa35f504197561060ca4c21d3cc079ba29cf6dd2fe61024c70160c990b", 394 | "sha256:922a69fa533cb0c793b483becaaa0845f655151e7256ec73630a1b2e9ebcb660", 395 | "sha256:9693f35162dc6208d10b10ddf0458cc09ad70c30ba689d9206e02cd836ce28a3", 396 | "sha256:a0f1c7edf116a12f7245be06120b1852275f9506a7d90227648b250755a03923", 397 | "sha256:a36e75df6c767cbf46f61a91c70b3ba71811dfa0aca4a324d9407a06a8b7a2e7", 398 | "sha256:aba5c812f8ee8a3ff3be51887ca2d55fb8e268439ed44110d3846e4229eb0e8b", 399 | "sha256:ad6f1796e37db2223d2a3f302f586f74c72c630b48a9872c1e7ae8e92e0ab669", 400 | "sha256:ae67501c95606072aafa865b6ed47343ac6484472a2f95490ba151f6347acfc2", 401 | "sha256:b2fcf9402fde2672545b139694284dc3b665fd1be660d73eca6805197ef776a3", 402 | "sha256:b52b88021b9541a60531142b0a451baca08d28b74a723d0c99b13c8c8d48d604", 403 | "sha256:b7d336912853d7b77f9b2c24eeed6a5065d0a0cc0d3b6a5a45ad6d1d05fb8cd8", 404 | "sha256:bd9ba4f332cf57b2c1f698be0728c020399ef3040577cde2939f2e045b39c1e5", 405 | "sha256:be9be735f827820a06340dff2ddea1fb7234561fa5e6300a62fe7f54d40546a0", 406 | "sha256:cca7741c0fcc765568350cb139e92b7f9f3c9a08c4f32591d18ab0a6ac9e71b6", 407 | "sha256:d0d19fb2441947b58fbf91336638c2b9f4cc98e05e1045404d7a4cb7cddc7a65", 408 | "sha256:e02794ac791662a5eafc6ffeaf9bcc149035a0e48eb0a9d40a8feb4622605a3d", 409 | "sha256:e0f30db709c939cabf64a6dc5babb276e6d823fd84464ab916f9b9ba5623ca15", 410 | "sha256:e92c2d33858c8f560671b448205a268096e17870dcf60a9bb3ac7bfbafb7f5f9" 411 | ], 412 | "markers": "python_version >= '3.6'", 413 | "version": "==2.1.6" 414 | } 415 | }, 416 | "develop": {} 417 | } 418 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Weights & Biases の使い方 2 | 3 | 4 | 5 | - [Weights & Biases の使い方](#weights--biases-%E3%81%AE%E4%BD%BF%E3%81%84%E6%96%B9) 6 | - [まずは動かしてみる](#%E3%81%BE%E3%81%9A%E3%81%AF%E5%8B%95%E3%81%8B%E3%81%97%E3%81%A6%E3%81%BF%E3%82%8B) 7 | - [学習の監視・可視化](#%E5%AD%A6%E7%BF%92%E3%81%AE%E7%9B%A3%E8%A6%96%E3%83%BB%E5%8F%AF%E8%A6%96%E5%8C%96) 8 | - [最も基本的な使い方](#%E6%9C%80%E3%82%82%E5%9F%BA%E6%9C%AC%E7%9A%84%E3%81%AA%E4%BD%BF%E3%81%84%E6%96%B9) 9 | - [wandb.init()](#wandbinit) 10 | - [wandb.config](#wandbconfig) 11 | - [代入による設定](#%E4%BB%A3%E5%85%A5%E3%81%AB%E3%82%88%E3%82%8B%E8%A8%AD%E5%AE%9A) 12 | - [更新](#%E6%9B%B4%E6%96%B0) 13 | - [yaml から設定を読む](#yaml-%E3%81%8B%E3%82%89%E8%A8%AD%E5%AE%9A%E3%82%92%E8%AA%AD%E3%82%80) 14 | - [wandb.log()](#wandblog) 15 | - [基本](#%E5%9F%BA%E6%9C%AC) 16 | - [ステップを指定して記録](#%E3%82%B9%E3%83%86%E3%83%83%E3%83%97%E3%82%92%E6%8C%87%E5%AE%9A%E3%81%97%E3%81%A6%E8%A8%98%E9%8C%B2) 17 | - [Summary Metrics](#summary-metrics) 18 | - [ウェブ UI で見るプロットの x 軸を変える](#%E3%82%A6%E3%82%A7%E3%83%96-ui-%E3%81%A7%E8%A6%8B%E3%82%8B%E3%83%97%E3%83%AD%E3%83%83%E3%83%88%E3%81%AE-x-%E8%BB%B8%E3%82%92%E5%A4%89%E3%81%88%E3%82%8B) 19 | - [ステップ数が多くなるとサンプリングされる](#%E3%82%B9%E3%83%86%E3%83%83%E3%83%97%E6%95%B0%E3%81%8C%E5%A4%9A%E3%81%8F%E3%81%AA%E3%82%8B%E3%81%A8%E3%82%B5%E3%83%B3%E3%83%97%E3%83%AA%E3%83%B3%E3%82%B0%E3%81%95%E3%82%8C%E3%82%8B) 20 | - [その他記録できること(一部)](#%E3%81%9D%E3%81%AE%E4%BB%96%E8%A8%98%E9%8C%B2%E3%81%A7%E3%81%8D%E3%82%8B%E3%81%93%E3%81%A8%E4%B8%80%E9%83%A8) 21 | - [ログインについての詳細](#%E3%83%AD%E3%82%B0%E3%82%A4%E3%83%B3%E3%81%AB%E3%81%A4%E3%81%84%E3%81%A6%E3%81%AE%E8%A9%B3%E7%B4%B0) 22 | - [API 経由で W&B を使う方法](#api-%E7%B5%8C%E7%94%B1%E3%81%A7-wb-%E3%82%92%E4%BD%BF%E3%81%86%E6%96%B9%E6%B3%95) 23 | - [W&B Sweeps でハイパラサーチ](#wb-sweeps-%E3%81%A7%E3%83%8F%E3%82%A4%E3%83%91%E3%83%A9%E3%82%B5%E3%83%BC%E3%83%81) 24 | - [概要](#%E6%A6%82%E8%A6%81) 25 | - [とりあえず Sweeps を使ってみる](#%E3%81%A8%E3%82%8A%E3%81%82%E3%81%88%E3%81%9A-sweeps-%E3%82%92%E4%BD%BF%E3%81%A3%E3%81%A6%E3%81%BF%E3%82%8B) 26 | - [Under the hood](#under-the-hood) 27 | - [やっていること](#%E3%82%84%E3%81%A3%E3%81%A6%E3%81%84%E3%82%8B%E3%81%93%E3%81%A8) 28 | - [設定ファイル例とコマンドライン引数](#%E8%A8%AD%E5%AE%9A%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%E4%BE%8B%E3%81%A8%E3%82%B3%E3%83%9E%E3%83%B3%E3%83%89%E3%83%A9%E3%82%A4%E3%83%B3%E5%BC%95%E6%95%B0) 29 | - [Caveat](#caveat) 30 | - [設定ファイルの各項目について](#%E8%A8%AD%E5%AE%9A%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%E3%81%AE%E5%90%84%E9%A0%85%E7%9B%AE%E3%81%AB%E3%81%A4%E3%81%84%E3%81%A6) 31 | - [metric](#metric) 32 | - [method](#method) 33 | - [parameters](#parameters) 34 | - [earlyterminate](#earlyterminate) 35 | - [その他細かいテクニック](#%E3%81%9D%E3%81%AE%E4%BB%96%E7%B4%B0%E3%81%8B%E3%81%84%E3%83%86%E3%82%AF%E3%83%8B%E3%83%83%E3%82%AF) 36 | - [Ray Tune](#ray-tune) 37 | 38 | 39 | 40 | ## まずは動かしてみる 41 | 42 | PyTorch を使った MNIST によるサンプルコードが `./sample` 以下に同梱。 43 | 44 | 1. W&B のアカウント作る 45 | 46 | 2. プロジェクトを作成する。名前は `sample-pytorch-mnist` にする(サンプルコードの中で指定している) 47 | 48 | 3. クライアントのインストールとログインをする。ブラウザが開くので API キーをコピペする。ログインの詳細は後述。 49 | 50 | ``` 51 | pip3 install wandb 52 | wandb login 53 | ``` 54 | 55 | 4. Run 🚀 56 | 57 | ``` 58 | pip3 install -r requirements.txt 59 | python3 sample/main.py 60 | ``` 61 | 62 | 5. ウェブ UI のダッシュボードで経過を確認 🔍 63 | 64 | ## 学習の監視・可視化 65 | 66 | ### 最も基本的な使い方 67 | 68 | `wandb.log({ 'loss': 0.2 })` などすると、リアルタイムで記録が送信され、ウェブ UI で確認できる。 69 | また、記録に関する情報が `./wandb` ディレクトリに諸々が保存されていく。 70 | 71 | ```python 72 | import wandb 73 | 74 | default_hyperparams = { 75 | 'some_hyperparam1': val1, 76 | 'some_hyperparam2': val2 77 | } 78 | 79 | wandb.init( 80 | config=default_hyperparams, 81 | project="project-name", 82 | name="name-of-this-run" 83 | ) 84 | 85 | # ...some ML code 86 | 87 | wandb.log({ 'loss': loss }) 88 | ``` 89 | 90 | ### wandb.init() 91 | 92 | よく使いそうな引数は以下の通り 93 | 94 | - project: プロジェクトの名前(str) 95 | - name: 実行(run と呼ばれる)ごとに名前をつけられる。name とは別にユニークな ID が割り振られるので重複してもよい。与えなかった場合適当な名前が自動で割り振られる(str) 96 | - notes: その実行に関する備考などを書いておくと、ウェブ UI に表示される(str) 97 | - config: その実行に関する設定。これもウェブ UI に表示される。ハイパーパラメータなどを記録しておくと良い(dict-like) 98 | - id: 自分で ID を指定することもできる。(str) 99 | 100 | その他は以下に 101 | https://docs.wandb.com/library/init 102 | 103 | ### wandb.config 104 | 105 | `wantdb.init()` の `config` 引数に渡す以外にも `wandb.config` でも設定できる。 106 | 107 | #### 代入による設定 108 | 109 | ```python 110 | wandb.config.epochs = 4 111 | wandb.config.batch_size = 32 112 | ``` 113 | 114 | #### 更新 115 | 116 | ```python 117 | wandb.config.update({"epochs": 8, "batch_size": 64}) 118 | ``` 119 | 120 | #### yaml から設定を読む 121 | 122 | デフォルトでは `config-defaults.yaml` に書いておくと wandb が勝手に読んでくれる。 123 | 124 | ```yaml 125 | epochs: 126 | desc: Number of epochs to train over 127 | value: 100 128 | batch_size: 129 | desc: Size of each mini-batch 130 | value: 32 131 | ``` 132 | 133 | dict を使った設定との共存もできる。 134 | 135 | ```python 136 | hyperparameter_defaults = dict( 137 | dropout = 0.5, 138 | batch_size = 100, 139 | learning_rate = 0.001, 140 | ) 141 | 142 | config_dictionary = dict( 143 | yaml=my_yaml_file, 144 | params=hyperparameter_defaults, 145 | ) 146 | 147 | wandb.init(config=config_dictionary) 148 | ``` 149 | 150 | コマンドライン引数から渡すとか他の使い方は以下に 151 | https://docs.wandb.com/library/config 152 | 153 | ### wandb.log() 154 | 155 | #### 基本 156 | 157 | `wandb` は `history` (多分 `dict` の `list`)を持っており、`wandb.log()` が呼ばれるたびに引数に渡した `dict` がこれに `append` されていく。 158 | 159 | ```python 160 | wandb.log({ 'accuracy': 0.9, 'epoch': 5 }) 161 | ``` 162 | 163 | #### ステップを指定して記録 164 | 165 | 一つのステップの中で数カ所に分けて `wandb.log()` を呼びたい場合は `step` を明示的に指定する。 166 | 167 | ```python 168 | wandb.log({ 'accuracy': 0.9 }, step=10) 169 | wandb.log({ 'epoch': 5 }, step=10) 170 | ``` 171 | 172 | または `commit=False` を渡す。 173 | 174 | ```python 175 | # まだ記録されない 176 | wandb.log({ 'accuracy': 0.9 }, commit=False) 177 | # ここで { 'accuracy': 0.9, 'loss': 0.2 } が記録される 178 | wandb.log({ 'loss': 0.2 }) 179 | ``` 180 | 181 | #### Summary Metrics 182 | 183 | `wandb.log()` で記録したメトリクスの最後の値がそれぞれ自動で保存され、ウェブ UI ダッシュボードの Summary 欄に表示される。また、以下のように明示的に保存することもできる。 184 | 185 | ```python 186 | # loss: 0.1 が Summary に記録される 187 | wandb.log({ 'loss': 0.3 }) 188 | wandb.log({ 'loss': 0.2 }) 189 | wandb.log({ 'loss': 0.1 }) 190 | 191 | # 明示的に保存 192 | wandb.run.summary["test_accuracy"] = test_accuracy 193 | ``` 194 | 195 | #### ウェブ UI で見るプロットの x 軸を変える 196 | 197 | ウェブ UI でプロットを見るとき、x 軸を自由に設定できる。例えばバッチを x 軸にして経過を見たいときはバッチを記録に入れておくなどした上でダッシュボードで x 軸をそのキーで指定する。 198 | 199 | ```python 200 | wandb.log({ 'batch': 5, ... }) 201 | ``` 202 | 203 | #### ステップ数が多くなるとサンプリングされる 204 | 205 | ウェブ UI で見れるプロットはデータが 1000 を超えると 1000 個だけランダムにサンプリングされるので注意。見るたびに微妙にプロットが違うということが起きうる。 206 | 207 | #### その他記録できること(一部) 208 | 209 | matplotlib.pyplot オブジェクトを渡すと[ploty](https://plot.ly/) に変換して記録するらしい(要検証) 210 | 211 | ```python 212 | plt.plot( ... ) 213 | wandb.log( { 'chart': plt } ] ) 214 | ``` 215 | 216 | - 画像 217 | - 動画 218 | - 音声 219 | - テキスト/ひょう/HTML 220 | - 点群データ 221 | 222 | など。詳しくは以下を参照 223 | https://docs.wandb.com/library/log 224 | 225 | ## ログインについての詳細 226 | 227 | `wandb` は W&B のプログラマブルなクライアントという感じなのでログインが必要。アカウントを持っていない場合は先にサインアップする。 228 | 229 | ``` 230 | wandb login 231 | ``` 232 | 233 | ブラウザが開いて API キーが出るのでそれをコピーして入力する。 234 | ブラウザが使えない環境の場合、 235 | https://app.wandb.ai/authorize 236 | に行くと API キーが払い出されるので、これを入力する。ここで入力された API キーは `~/.netrc` に保存される。 237 | 238 | ``` 239 | machine api.wandb.ai 240 | login user 241 | password XXXXXXXXXXXXXXXXXXXXXXXXXXX(API key) 242 | ``` 243 | 244 | もしくは環境変数 `WANDB_API_KEY` に API キーをセットするとそれを読んでくれる。 245 | 246 | ## API 経由で W&B を使う方法 247 | 248 | 記録されたメトリクスを取得してきてスクリプトでなにかやるとかに使える。 249 | https://docs.wandb.com/library/api/examples 250 | 251 | ## W&B Sweeps でハイパラサーチ 252 | 253 | ### 概要 254 | 255 | Sweeps は自動でハイパラサーチをやるためのツール。 256 | サーチを管理するためのサーバがあり、ここに学習を行うマシン(複数可能)が学習の結果を報告し、管理サーバはそれを受けて学習のスケジューリングとか割り当てを行う。 257 | 258 | 大まかなフローは 259 | 260 | 1. yaml に探索範囲を記述し、それを Sweep サーバに送る 261 | 2. Sweep ID が返ってくるので、学習用のマシンでこれを引数に渡して Sweep エージェント起動 262 | 3. 学習を始めてくれる 263 | 264 | https://docs.wandb.com/sweeps 265 | 266 | ### とりあえず Sweeps を使ってみる 267 | 268 | 1. (optional) `python` コマンドが 3 系でないなどの場合、virtualenv や pipenv を使う。pipenv を使う場合、`Pipenv`/`Pipenv.lock` ファイルが同梱されているので `pipenv sync` しておく。 269 | 270 | 2. まずはログインしておく 271 | 272 | 3. `wandb sweep ./sweep.yaml -p {project-name (e.g. sample-pytorch-mnist)}` 273 | これによって探索範囲や最適化したいメトリクスを W&B Sweeps の管理サーバに送信する。Sweep ID が払い出されるのでこれをコピーする。 274 | 275 | 4. `wandb agent {project-name}/{Sweep ID}` もしくは pipenv 使用の場合、 276 | `pipenv run wandb agent {project-name}/{Sweep ID}` 277 | これでエージェントが立ち上がり、サーチが開始される。 278 | 279 | 5. (optional) 分散してサーチしたい場合、別のマシンでステップ 3 を行うと Sweep サーバがよしなに仕事を割り当ててくれる。 280 | 281 | 6. ウェブ UI のダッシュボードで探索の進捗を確認。 282 | 283 | ### Under the hood 284 | 285 | #### やっていること 286 | 287 | yaml の設定ファイルでは主に 288 | 289 | - 走らせるスクリプトのパス 290 | - 最適化したいメトリクス 291 | - 探索したいハイパーパラメータとその範囲 292 | - サーチアルゴリズム e.g. ベイズ最適化/グリッドサーチ/ランダムサーチ 293 | 294 | を指定するが、wandb は単純に以下のようなコマンドを 295 | 296 | ```sh 297 | python path/to/script.py --hyperparam1=val1 --hyperparam2=val2 298 | ``` 299 | 300 | メトリクスを見てハイパーパラメータを調整しながら逐次実行している。 301 | 302 | よって `argparse` などを使ってコマンドライン引数から設定を読み込むようにするのが良さそう。また Python のバージョン指定などは `pipenv` などを使うのが丸い。 [argparse も pipenv も使わない方法](https://docs.wandb.com/sweeps/faq#sweep-with-custom-commands)もあるが、いまいち挙動がはっきりしないので大人しくこれらを使った方が良い。 303 | 304 | #### 設定ファイル例とコマンドライン引数 305 | 306 | 以下の設定ファイルは `wadb.log()` で記録される `val_loss` を最小化するようにサーチを行う。調整されるパラメータは `lr` と `optimizer` の二つで、それぞれ同じ名前でコマンドライン引数として渡される。 307 | 308 | ```yaml 309 | program: ./path/to/script.py 310 | method: bayes 311 | metric: 312 | name: val_loss 313 | goal: minimize 314 | parameters: 315 | lr: 316 | min: 0.001 317 | max: 0.1 318 | optimizer: 319 | values: ["adam", "sgd"] 320 | ``` 321 | 322 | よって、スクリプト側では 323 | 324 | ```python 325 | import argparse 326 | 327 | parser = argparse.ArgumentParser() 328 | parser.add_argument('--lr', type=float, default=0.01) 329 | parser.add_argument('--optimizer', default='sgd', choices=['adam', 'sgd']) 330 | args = parser.parse_args() 331 | 332 | params = { 'lr': args.lr } 333 | wandb.init(config=params, project='sample-pytorch-mnist', 334 | name='wandb-test-run') 335 | 336 | if args.optimizer == 'sgd': 337 | optimizer = optim.SGD(net.parameters(), lr=args.lr) 338 | else: 339 | optimizer = optim.Adam(lr=args.lr) 340 | ``` 341 | 342 | などとする。 343 | 344 | #### Caveat 345 | 346 | - _ハマりポイントとして、`wandb.init()`でパラメータの初期化が行われなければならない点がある。wandb.config.update()を使うとエラーになるので注意_ 347 | - _最適化したいメトリクスは `wandb.log()`で記録されるようにしないといけない点に注意_ 348 | - _また`grid` (グリッドサーチ) を使うとき、パラメータは `values` で候補を与えなければエラーになる(当然だが)_ 349 | 350 | ### 設定ファイルの各項目について 351 | 352 | #### metric 353 | 354 | - name: (`str`) 最適化するメトリクス 355 | - goal: (`maximize` | `minimize`) 356 | - target: (`float`) ここで指定した値を達成したら探索を終了する 357 | 358 | 例 359 | 360 | ```yaml 361 | metric: 362 | name: val_loss 363 | goal: maximize 364 | target: 0.1 365 | ``` 366 | 367 | #### method 368 | 369 | サーチアルゴリズムを以下から指定する。 370 | 371 | - `bayes` (ベイズ最適化) 372 | - `grid` (グリッドサーチ) 373 | - `random` (ランダムサーチ) 374 | 375 | ランダムサーチは止めない限り探索し続けるが、 376 | `wandb agent --count N SWEEPID` 377 | などとすると `N` 回だけ探索する。 378 | 379 | #### parameters 380 | 381 | 探索されるべきハイパーパラメータを記述する。複数のハイパーパラメータを記述でき、そのそれぞれに対して範囲か候補のリストを指定する。 382 | よく使われるものは以下のとおり 383 | 384 | - min,max: (`int`,`int` | `float`,`float`) 最小値と最大値。 385 | min,max ともに `int` だった場合 `min`と`max` 間の整数からなる離散的な範囲になり、`float` だった場合連続な範囲になる。 386 | - values: (`List[float]`) 候補のリスト 387 | 388 | その他は以下より。分布の指定などができる模様。 389 | https://docs.wandb.com/sweeps/configuration#parameters 390 | 391 | 例 392 | 393 | ```yaml 394 | parameters: 395 | param1: 396 | min: 1 397 | max: 20 398 | param2: 399 | distribution: "normal" 400 | min: -1.0 401 | max: 1.0 402 | param3: 403 | values: ["sgd", "adadelta", "adam"] 404 | ``` 405 | 406 | #### early_terminate 407 | 408 | Hyperband を使い、パフォーマンスが高くない設定を途中で止めて次に進むことでサーチにかかる時間短縮を測るための設定(未検証) 409 | 詳細は https://docs.wandb.com/sweeps/configuration#stopping-criteria を参照。 410 | 411 | Hyperband の詳細 ↓ 412 | [Hyperband: A Novel Bandit-Based Approach to Hyperparameter Optimization](https://arxiv.org/abs/1603.06560) 413 | 414 | 例 415 | 416 | ```yaml 417 | early_terminate: 418 | type: hyperband 419 | min_iter: 3 420 | ``` 421 | 422 | ```yaml 423 | early_terminate: 424 | type: hyperband 425 | max_iter: 27 426 | s: 2 427 | ``` 428 | 429 | ### その他細かいテクニック 430 | 431 | - グリッドサーチを行なったあと、いくつの設定だけやり直したい場合、該当する run をダッシュボードから削除して再度走らせると、その削除された設定だけ再び探索される。 432 | 433 | ### Ray Tune 434 | 435 | (未検証 & まだベータ版) [Ray Tune](https://ray.readthedocs.io/en/latest/tune.html) が統合されているのでこれを使ってサーチもできる模様。 436 | https://docs.wandb.com/sweeps/ray-tune 437 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | pandas 4 | numpy 5 | wandb 6 | -------------------------------------------------------------------------------- /sample/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | import wandb 8 | import argparse 9 | 10 | from net import Net 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Sample NMIST to demonstrate how to use W&B') 14 | parser.add_argument('--lr', type=float, default=0.01) 15 | parser.add_argument('--epochs', type=int, default=2) 16 | parser.add_argument('--batch_size', type=int, default=256) 17 | parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam']) 18 | 19 | args = parser.parse_args() 20 | 21 | # wandb setup 22 | hyperparams = { 23 | 'epochs': args.epochs, 24 | 'batch_size': args.batch_size, 25 | 'lr': args.lr 26 | } 27 | wandb.init(config=hyperparams, project='sample-pytorch-mnist', 28 | name='wandb-test-run') 29 | 30 | # dataset preparation 31 | normalization = transforms.Compose( 32 | [transforms.ToTensor(), 33 | transforms.Normalize((0.5,), (0.5,))]) 34 | 35 | dataset_train = torchvision.datasets.MNIST( 36 | root='./dataset', train=True, download=True, transform=normalization) 37 | dataloader_train = torch.utils.data.DataLoader( 38 | dataset_train, batch_size=args.batch_size, shuffle=True) 39 | 40 | dataset_test = torchvision.datasets.MNIST( 41 | root='./dataset', train=False, download=True, transform=normalization) 42 | dataloader_test = torch.utils.data.DataLoader( 43 | dataset_test, batch_size=args.batch_size, shuffle=True) 44 | 45 | classes = tuple(np.linspace(0, 9, 10, dtype=np.uint8)) 46 | 47 | # model setup 48 | net = Net() 49 | criterion = nn.CrossEntropyLoss() 50 | if args.optimizer == 'sgd': 51 | optimizer = optim.SGD(net.parameters(), lr=args.lr) 52 | else: 53 | optimizer = optim.Adam(lr=args.lr) 54 | 55 | # training 56 | for epoch in range(args.epochs): 57 | for i, (X, Y) in enumerate(dataloader_train, 0): 58 | optimizer.zero_grad() 59 | 60 | Yhat = net(X) 61 | loss = criterion(Yhat, Y) 62 | loss.backward() 63 | optimizer.step() 64 | 65 | wandb.log({'epoch': epoch+1, 'batch': i+1, 'loss': loss.item()}) 66 | print('[{:d}, {:5d}] loss: {:.3f}'.format( 67 | epoch + 1, i + 1, loss.item())) 68 | 69 | 70 | # testing 71 | correct = 0 72 | total = 0 73 | with torch.no_grad(): 74 | for (X, Y) in iter(dataloader_test): 75 | Yhat = net(X) 76 | _, predicted = torch.max(Yhat.data, 1) 77 | total += Y.size(0) 78 | correct += (predicted == Y).sum().item() 79 | loss = criterion(Yhat, Y).item() 80 | 81 | test_accuracy = float(correct/total) 82 | print('Accuracy: {:.2f} %%'.format(100 * test_accuracy)) 83 | 84 | wandb.run.summary["test_accuracy"] = test_accuracy 85 | -------------------------------------------------------------------------------- /sample/net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class Net(nn.Module): 6 | def __init__(self): 7 | super(Net, self).__init__() 8 | self.conv1 = nn.Conv2d(1, 32, 3) 9 | self.conv2 = nn.Conv2d(32, 64, 3) 10 | self.pool = nn.MaxPool2d(2, 2) 11 | self.dropout1 = nn.Dropout2d() 12 | self.fc1 = nn.Linear(12*12*64, 128) 13 | self.dropout2 = nn.Dropout2d() 14 | self.fc2 = nn.Linear(128, 10) 15 | 16 | def forward(self, x): 17 | x = F.relu(self.conv1(x)) 18 | x = self.pool(F.relu(self.conv2(x))) 19 | x = self.dropout1(x) 20 | x = x.view(-1, 12*12*64) 21 | x = F.relu(self.fc1(x)) 22 | x = self.dropout2(x) 23 | x = self.fc2(x) 24 | return x 25 | -------------------------------------------------------------------------------- /sweep.yaml: -------------------------------------------------------------------------------- 1 | program: ./sample/main.py 2 | method: bayes 3 | metric: 4 | name: loss 5 | goal: minimize 6 | parameters: 7 | lr: 8 | min: 0.001 9 | max: 0.1 10 | optimizer: 11 | values: ["adam", "sgd"] 12 | --------------------------------------------------------------------------------