├── .gitignore
├── README.md
├── fig
│   ├── appendix
│   │   └── run_val.PNG
│   ├── result_fig-attempt3
│   │   ├── result_fig-attempt3_noLARS
│   │   │   ├── noLars-1024.jpg
│   │   │   ├── noLars-128.jpg
│   │   │   ├── noLars-2048.jpg
│   │   │   ├── noLars-256.jpg
│   │   │   ├── noLars-4096.jpg
│   │   │   ├── noLars-512.jpg
│   │   │   └── noLars-8192.jpg
│   │   └── result_fig-attempt3_withLARS
│   │       ├── withLars-1024.jpg
│   │       ├── withLars-128.jpg
│   │       ├── withLars-2048.jpg
│   │       ├── withLars-256.jpg
│   │       ├── withLars-4096.jpg
│   │       ├── withLars-512.jpg
│   │       └── withLars-8192.jpg
│   ├── result_fig-attempt4
│   │   ├── result_fig-noLARS
│   │   │   ├── noLars-1024.jpg
│   │   │   ├── noLars-128.jpg
│   │   │   ├── noLars-2048.jpg
│   │   │   ├── noLars-256.jpg
│   │   │   ├── noLars-4096.jpg
│   │   │   ├── noLars-512.jpg
│   │   │   └── noLars-8192.jpg
│   │   └── result_fig-withLARS
│   │       ├── withLars-1024.jpg
│   │       ├── withLars-128.jpg
│   │       ├── withLars-2048.jpg
│   │       ├── withLars-256.jpg
│   │       ├── withLars-4096.jpg
│   │       ├── withLars-512.jpg
│   │       └── withLars-8192.jpg
│   └── result_fig-attempt5
│       ├── result_fig-noLARS
│       │   ├── noLars-1024.jpg
│       │   ├── noLars-1024.pth
│       │   ├── noLars-128.jpg
│       │   ├── noLars-128.pth
│       │   ├── noLars-2048.jpg
│       │   ├── noLars-2048.pth
│       │   ├── noLars-256.jpg
│       │   ├── noLars-256.pth
│       │   ├── noLars-4096.jpg
│       │   ├── noLars-4096.pth
│       │   ├── noLars-512.jpg
│       │   ├── noLars-512.pth
│       │   ├── noLars-8192.jpg
│       │   └── noLars-8192.pth
│       └── result_fig-withLARS
│           ├── withLars-1024.jpg
│           ├── withLars-1024.pth
│           ├── withLars-128.jpg
│           ├── withLars-128.pth
│           ├── withLars-2048.jpg
│           ├── withLars-2048.pth
│           ├── withLars-256.jpg
│           ├── withLars-256.pth
│           ├── withLars-4096.jpg
│           ├── withLars-4096.pth
│           ├── withLars-512.jpg
│           ├── withLars-512.pth
│           ├── withLars-8192.jpg
│           └── withLars-8192.pth
├── hyperparams.py
├── optimizer.py
├── scheduler.py
├── test_lars.py
├── train.py
├── train_with_matplot.py
├── utils.py
└── val.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
core.*
data/*
.ipynb_checkpoints/*
__pycache__/*
checkpoint/*
*.swp

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Pytorch-LARS

## Objective

- Link: ["Large Batch Training of Convolutional Networks (LARS)"](https://arxiv.org/abs/1708.03888)
- Implement LARS, as introduced in the paper above, in PyTorch with CUDA
- Data: CIFAR10

## Requirements

- python == 3.6.8
- pytorch >= 1.1.0
- cuda >= 10
- matplotlib >= 3.1.0 (optional)
- etc.

## Usage

- Train

```bash
$ git clone https://github.com/cmpark0126/pytorch-LARS.git
$ cd pytorch-LARS/
$ vi hyperparams.py # edit the Base and Hyperparams classes for training
$ python train.py # start training on CIFAR10
```

- Evaluate

```bash
$ vi hyperparams.py # adjust the Hyperparams_for_val class to inspect training results; a specific checkpoint can be selected
$ python val.py # check training results, including the history of test-accuracy updates recorded during training
```

## Hyperparams (hyperparams.py)

- Base (class)

  - batch_size: the reference batch size. Every batch size used in the experiments is a multiple of this size.

  - lr: the reference learning rate, generally used as the base value for linear scaling.

  - multiples: the factor used as the exponent when computing k, described below (see the sketch right after this list).
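As a worked example, here is the arithmetic hyperparams.py performs to derive the per-run values from Base (the value multiples = 4 is illustrative; the checked-in default is 1):

```python
# Sketch of how hyperparams.py turns Base into the actual run settings.
base_batch_size = 128  # Base.batch_size
base_lr = 0.05         # Base.lr
multiples = 4          # illustrative; Base.multiples defaults to 1

k = 2 ** (multiples - 1)           # k = 8
batch_size = base_batch_size * k   # 1024
lr = base_lr * k                   # 0.4, per the linear scaling rule

print(k, batch_size, lr)           # 8 1024 0.4
```

These derived values match the Attempt 3 tables below, where batch size 1024 is paired with base LR 0.4.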
- Hyperparams (class)

  - batch_size: the batch size actually used for training

  - lr: the initial learning rate actually used for training

  - momentum

  - weight_decay

  - trust_coef: the trust coefficient, i.e. how much the local LR computed internally by LARS is trusted (see the sketch at the end of this section)

  - warmup_multiplier

  - warmup_epoch

  - max_decay_epoch: the number of epochs over which polynomial decay runs

  - end_learning_rate: the value the learning rate converges to once decay has finished

  - num_of_epoch: total number of epochs to train

  - with_lars

- Hyperparams_for_val (class)

  - checkpoint_folder_name: a checkpoint folder holding saved parameters must exist next to hyperparams.py; specify one of them by name (e.g. checkpoint_folder_name = 'checkpoint-attempt1')

  - with_lars: choose between checkpoints trained with or without LARS

  - batch_size: the batch size of the checkpoint to load

  - device: the CUDA device to run the model on for evaluation
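With LARS, the optimizer derives a per-layer local learning rate from the ratio of the weight norm to the gradient norm, scaled by trust_coef. A minimal sketch of the formula as implemented in optimizer.py (SGD_with_lars_ver2), for a single parameter tensor:

```python
import torch

def local_lr(p, grad, weight_decay, trust_coef, global_lr):
    """Per-layer LARS learning rate, mirroring SGD_with_lars_ver2.step."""
    p_norm = torch.norm(p, p=2)     # ||w||
    g_norm = torch.norm(grad, p=2)  # ||dw||
    # local lr = trust_coef * ||w|| / (||dw|| + weight_decay * ||w||)
    lr = trust_coef * p_norm / (g_norm + weight_decay * p_norm)
    return global_lr * lr           # scaled by the global (scheduled) lr

w = torch.randn(10, 10)
g = torch.randn(10, 10)
print(local_lr(w, g, weight_decay=5e-3, trust_coef=0.1, global_lr=0.4))
```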
## Demonstration

- Terminology
  - k
    - we increase the batch B by k
    - the starting batch size is 128
    - if we use 256 as the batch size, k is 2 in this case
    - **k = (2 \*\* (multiples - 1))**
  - (base line)
    - the target accuracy we want to reach when training the model with a large batch size and LARS

* * *

### Attempt 1

- Configuration

  - Hyperparams

    - momentum = 0.9

    - weight_decay

      - noLars -> 5e-04
      - withLARS -> 5e-03

    - warm-up for 5 epochs

      - warmup_multiplier = k
      - target lr follows the linear scaling rule

    - polynomial decay (power=2) LR policy (after warm-up)

      - for 200 epochs
      - minimum lr = 1.5e-05 \* k

    - number of epochs = 200

- Without LARS

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.15 | 89.15 % (base line) | 2113.52 sec |
| 256 | 0.15 | 88.43 % | 1433.38 sec |
| 512 | 0.15 | 88.72 % | 1820.35 sec |
| 1024 | 0.15 | 87.96 % | 1303.54 sec |
| 2048 | 0.15 | 87.05 % | 1827.90 sec |
| 4096 | 0.15 | 78.03 % | 2083.24 sec |
| 8192 | 0.15 | 14.59 % | 1459.81 sec |

- With LARS (closest one to the base line, for comparing time to train)

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.15 | 89.16 % | 3203.54 sec |
| 256 | 0.15 | 89.19 % | 2147.74 sec |
| 512 | 0.15 | 89.29 % | 1677.25 sec |
| 1024 | 0.15 | 89.17 % | 1604.91 sec |
| 2048 | 0.15 | 88.70 % | 1413.10 sec |
| 4096 | 0.15 | 86.78 % | 1609.08 sec |
| 8192 | 0.15 | 80.85 % | 1629.48 sec |

- With LARS (best accuracy)

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.15 | 89.62 % | 3606.08 sec |
| 256 | 0.15 | 89.78 % | 2675.04 sec |
| 512 | 0.15 | 89.38 % | 1712.90 sec |
| 1024 | 0.15 | 89.22 % | 1967.92 sec |
| 2048 | 0.15 | 88.70 % | 1413.10 sec |
| 4096 | 0.15 | 86.78 % | 1609.08 sec |
| 8192 | 0.15 | 80.85 % | 1629.48 sec |

* * *
### Attempt 2

- Configuration

  - Hyperparams

    - momentum = 0.9

    - weight_decay

      - noLars -> 5e-04
      - withLARS -> 5e-03

    - trust coefficient = 0.1

    - warm-up for 5 epochs

      - warmup_multiplier = 2 \* k
      - target lr follows the linear scaling rule

    - polynomial decay (power=2) LR policy (after warm-up)

      - for 200 epochs
      - minimum lr = 1e-05

    - number of epochs = 200

- Without LARS

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.05 | 90.40 % (base line) | 4232.56 sec |
| 256 | 0.05 | 90.00 % | 2968.43 sec |
| 512 | 0.05 | 89.50 % | 2707.79 sec |
| 1024 | 0.05 | 89.27 % | 2627.22 sec |
| 2048 | 0.05 | 89.21 % | 2500.02 sec |
| 4096 | 0.05 | 84.73 % | 2872.25 sec |
| 8192 | 0.05 | 20.85 % | 2923.95 sec |

- With LARS (closest one to the base line, for comparing time to train)

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.05 | 90.21 % | 6792.61 sec |
| 256 | 0.05 | 90.28 % | 4871.68 sec |
| 512 | 0.05 | 90.41 % | 3581.32 sec |
| 1024 | 0.05 | 90.27 % | 3030.45 sec |
| 2048 | 0.05 | 90.19 % | 2773.21 sec |
| 4096 | 0.05 | 88.49 % | 2866.02 sec |
| 8192 | 0.05 | 62.20 % | 1312.98 sec |

- With LARS (best accuracy)

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.05 | 90.21 % | 6792.61 sec |
| 256 | 0.05 | 90.28 % | 4871.68 sec |
| 512 | 0.05 | 90.41 % | 3581.32 sec |
| 1024 | 0.05 | 90.27 % | 3030.45 sec |
| 2048 | 0.05 | 90.19 % | 2773.21 sec |
| 4096 | 0.05 | 88.49 % | 2866.02 sec |
| 8192 | 0.05 | 62.20 % | 1312.98 sec |

* * *
### Attempt 3

- Configuration

  - Hyperparams

    - momentum = 0.9

    - weight_decay

      - noLars -> 5e-04
      - withLARS -> 5e-03

    - trust coefficient = 0.1

    - warm-up for 5 epochs

      - warmup_multiplier = 2

    - polynomial decay (power=2) LR policy (after warm-up)

      - for 200 epochs
      - minimum lr = 1e-05 \* k

    - number of epochs = 200

  - Additional Jobs

    - Use He initialization

    - base lr adjusted according to the linear scaling rule

- Without LARS

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.05 | 89.76 % | 3983.89 sec |
| 256 | 0.1 | 90.08 % (base line) | 3095.91 sec |
| 512 | 0.2 | 89.34 % | 2674.38 sec |
| 1024 | 0.4 | 88.82 % | 2581.19 sec |
| 2048 | 0.8 | 89.29 % | 2660.56 sec |
| 4096 | 1.6 | 85.02 % | 2871.04 sec |
| 8192 | 3.2 | 77.72 % | 3195.90 sec |

- With LARS (closest one to the base line, for comparing time to train)

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.05 | 90.11 % | 6880.76 sec |
| 256 | 0.1 | 90.12 % | 4262.83 sec |
| 512 | 0.2 | 90.11 % | 3548.07 sec |
| 1024 | 0.4 | 90.02 % | 2760.31 sec |
| 2048 | 0.8 | 90.09 % | 2877.81 sec |
| 4096 | 1.6 | 88.38 % | 2946.53 sec |
| 8192 | 3.2 | 86.40 % | 3260.45 sec |

- With LARS (best accuracy)

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.05 | 90.37 % | 7338.71 sec |
| 256 | 0.1 | 90.32 % | 4590.58 sec |
| 512 | 0.2 | 90.11 % | 3548.07 sec |
| 1024 | 0.4 | 90.50 % | 2897.45 sec |
| 2048 | 0.8 | 90.09 % | 2877.81 sec |
| 4096 | 1.6 | 88.38 % | 2946.53 sec |
| 8192 | 3.2 | 86.40 % | 3260.45 sec |

* * *
### Attempt 4

- Configuration

  - Hyperparams

    - momentum = 0.9

    - weight_decay

      - noLars -> 5e-04
      - withLARS -> 5e-03

    - trust coefficient = 0.1

    - warm-up for 5 epochs

      - warmup_multiplier = 5

    - polynomial decay (power=2) LR policy (after warm-up)

      - for 200 epochs
      - minimum lr = 1e-05 \* k

    - number of epochs = 200

  - Additional Jobs

    - Use He initialization

    - base lr adjusted according to the linear scaling rule

- Without LARS

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.02 | 89.84 % | 4146.52 sec |
| 256 | 0.04 | 90.22 % (base line) | 3023.48 sec |
| 512 | 0.08 | 89.42 % | 2588.01 sec |
| 1024 | 0.16 | 89.41 % | 2494.35 sec |
| 2048 | 0.32 | 88.97 % | 2616.32 sec |
| 4096 | 0.64 | 85.13 % | 2872.76 sec |
| 8192 | 1.28 | 75.99 % | 3226.53 sec |

- With LARS (closest one to the base line, for comparing time to train)

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.02 | 90.20 % | 6740.03 sec |
| 256 | 0.04 | 90.25 % | 4662.09 sec |
| 512 | 0.08 | 90.24 % | 3381.99 sec |
| 1024 | 0.16 | 90.07 % | 2929.32 sec |
| 2048 | 0.32 | 89.82 % | 2908.37 sec |
| 4096 | 0.64 | 88.09 % | 2980.63 sec |
| 8192 | 1.28 | 86.56 % | 3314.60 sec |

- With LARS (best accuracy)

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.02 | 90.69 % | 7003.00 sec |
| 256 | 0.04 | 90.32 % | 4808.80 sec |
| 512 | 0.08 | 90.40 % | 3615.13 sec |
| 1024 | 0.16 | 90.07 % | 2929.32 sec |
| 2048 | 0.32 | 89.82 % | 2908.37 sec |
| 4096 | 0.64 | 88.09 % | 2980.63 sec |
| 8192 | 1.28 | 86.56 % | 3314.60 sec |

* * *
### Attempt 5

- Configuration

  - Hyperparams

    - momentum = 0.9

    - weight_decay

      - noLars -> 5e-04
      - withLARS -> 5e-03

    - trust coefficient = 0.1

    - warm-up for 5 epochs

      - warmup_multiplier = 2

    - polynomial decay (power=2) LR policy (after warm-up)

      - **for 175 epochs**
      - minimum lr = 1e-05 \* k

    - **number of epochs = 175**

  - Additional Jobs

    - Use He initialization

    - base lr adjusted according to the linear scaling rule

- Without LARS

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.05 | 89.50 % (base line) | 3682.72 sec |
| 256 | 0.1 | 89.22 % | 2678.24 sec |
| 512 | 0.2 | 89.12 % | 2337.15 sec |
| 1024 | 0.4 | 88.70 % | 2282.48 sec |
| 2048 | 0.8 | 88.89 % | 2316.96 sec |
| 4096 | 1.6 | 86.87 % | 2515.56 sec |
| 8192 | 3.2 | 15.50 % | 2783.00 sec |

- With LARS (closest one to the base line, for comparing time to train)

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.05 | 89.56 % | 5445.55 sec |
| 256 | 0.1 | 89.52 % | 3461.59 sec |
| 512 | 0.2 | 89.60 % | 2738.91 sec |
| 1024 | 0.4 | 89.50 % | 2410.23 sec |
| 2048 | 0.8 | 89.42 % | 2474.93 sec |
| 4096 | 1.6 | 88.43 % | 2618.97 sec |
| 8192 | 3.2 | 74.96 % | 1835.32 sec |

- With LARS (best accuracy)

| Batch | Base LR | top-1 Accuracy, % | Time to train |
| :---: | :-----: | :---------------: | :-----------: |
| 128 | 0.05 | 90.36 % | 6377.71 sec |
| 256 | 0.1 | 90.18 % | 4219.26 sec |
| 512 | 0.2 | 90.08 % | 3130.41 sec |
| 1024 | 0.4 | 89.94 % | 2578.00 sec |
| 2048 | 0.8 | 89.42 % | 2474.93 sec |
| 4096 | 1.6 | 88.43 % | 2618.97 sec |
| 8192 | 3.2 | 74.96 % | 1835.32 sec |

* * *

## Visualization

<Fig1. Attempt4, Without LARS, Batch size = 8192>

<Fig2. Attempt4, With LARS, Batch size = 8192>

- Comparing \<Fig1\> and \<Fig2\>, training with LARS starts more stably, and accuracy increases more smoothly.

- The accuracy-history graphs produced while working on Attempts 3, 4, and 5 can be found at the links below.
  - [Attempt3](https://github.com/cmpark0126/pytorch-LARS/tree/master/fig/result_fig-attempt3)
  - [Attempt4](https://github.com/cmpark0126/pytorch-LARS/tree/master/fig/result_fig-attempt4)
  - [Attempt5](https://github.com/cmpark0126/pytorch-LARS/tree/master/fig/result_fig-attempt5)

## Analysis of Resnet50 Training With Large Batch (CIFAR10)

- Confirmed that, with LARS, the model can be trained to base-line accuracy with batch sizes up to 1024

- Confirmed that combining several techniques, including He initialization, matters more than relying on LARS alone

- Confirmed that LARS can not only match the base line but sometimes exceed it
  - This looks like a side effect of the behavior noted in the paper: the local learning rate mitigates the vanishing and exploding gradient problems

## Open Issue

- Training with LARS takes roughly twice as long; look into whether the training time can be reduced

## Reference

- Base code:

- warm-up LR scheduler:
  - the PolynomialLRDecay class was also implemented on top of this
    - a polynomial LR decay scheduler
    - see: scheduler.py

- Pytorch Doc / Optimizer:
  - Optimizer class
  - SGD class

## Appendix

### val.py example run

- Shows the history of how the best accuracy was updated over the course of training.
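### Scheduler wiring (sketch)

For reference, a runnable sketch of how train.py chains GradualWarmupScheduler and PolynomialLRDecay around the LARS optimizer. A stand-in linear model and random tensors replace ResNet-50 and CIFAR10 here; everything else follows the classes and loop structure defined in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from optimizer import SGD_with_lars_ver2
from scheduler import GradualWarmupScheduler, PolynomialLRDecay
from hyperparams import Hyperparams as hp

net = nn.Linear(8, 2)  # stand-in for models.resnet50()
trainloader = DataLoader(TensorDataset(torch.randn(64, 8),
                                       torch.zeros(64, dtype=torch.long)),
                         batch_size=16)

optimizer = SGD_with_lars_ver2(net.parameters(), lr=hp.lr, momentum=hp.momentum,
                               weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)
# Ramp the LR toward warmup_multiplier * base lr over the first warmup_epoch epochs,
# then decay it polynomially (power=2) once per batch, exactly as train.py does.
warmup = GradualWarmupScheduler(optimizer, multiplier=hp.warmup_multiplier,
                                total_epoch=hp.warmup_epoch)
decay = PolynomialLRDecay(optimizer, max_decay_steps=hp.max_decay_epoch * len(trainloader),
                          end_learning_rate=hp.end_learning_rate, power=2.0)

for epoch in range(hp.num_of_epoch):
    if epoch <= hp.warmup_epoch:
        warmup.step()
    else:
        decay.base_lrs = warmup.get_lr()  # hand the warmed-up LR over to the decay schedule
    for inputs, targets in trainloader:
        if epoch > hp.warmup_epoch:
            decay.step()
        optimizer.zero_grad()
        F.cross_entropy(net(inputs), targets).backward()
        optimizer.step()
```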
480 | -------------------------------------------------------------------------------- /fig/appendix/run_val.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/appendix/run_val.PNG -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-1024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-1024.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-128.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-2048.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-2048.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-256.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-4096.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-4096.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-512.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-512.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-8192.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-8192.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-1024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-1024.jpg 
-------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-128.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-2048.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-2048.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-256.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-4096.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-4096.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-512.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-512.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-8192.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-8192.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-noLARS/noLars-1024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-1024.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-noLARS/noLars-128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-128.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-noLARS/noLars-2048.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-2048.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-noLARS/noLars-256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-256.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-noLARS/noLars-4096.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-4096.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-noLARS/noLars-512.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-512.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-noLARS/noLars-8192.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-8192.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-withLARS/withLars-1024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-1024.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-withLARS/withLars-128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-128.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-withLARS/withLars-2048.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-2048.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-withLARS/withLars-256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-256.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-withLARS/withLars-4096.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-4096.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-withLARS/withLars-512.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-512.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt4/result_fig-withLARS/withLars-8192.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-8192.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-1024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-1024.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-1024.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-1024.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-128.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-128.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-128.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-2048.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-2048.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-2048.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-2048.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-256.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-256.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-256.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-256.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-4096.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-4096.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-4096.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-4096.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-512.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-512.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-512.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-512.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-8192.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-8192.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-noLARS/noLars-8192.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-8192.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-1024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-1024.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-1024.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-1024.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-128.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-128.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-128.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-2048.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-2048.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-2048.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-2048.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-256.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-256.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-256.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-4096.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-4096.jpg -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-4096.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-4096.pth -------------------------------------------------------------------------------- /fig/result_fig-attempt5/result_fig-withLARS/withLars-512.jpg: -------------------------------------------------------------------------------- 
/hyperparams.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

class Base:
    batch_size = 128  # reference batch size
    lr = 0.05         # reference learning rate
    multiples = 1     # exponent used to compute k

class Hyperparams:
    '''Hyper parameters'''
    device = [0]

    batch_size = Base.batch_size * (2 ** (Base.multiples - 1))  # k = (2 ** (Base.multiples - 1))
    lr = Base.lr * (2 ** (Base.multiples - 1))                  # LR linear scaling

    # optim
    momentum = 0.9
    weight_decay = 5e-4
    trust_coef = 0.1

    # warm-up step & Linear Scaling Rule
    warmup_multiplier = 2
    warmup_epoch = 5

    # decay lr step (polynomial)
    max_decay_epoch = 200
    end_learning_rate = 0.0001

    num_of_epoch = 200
    with_lars = False
    resume = False

    def print_hyperparms():
        print('batch_size: ' + str(Hyperparams.batch_size))
        print('lr: ' + str(Hyperparams.lr))
        print('momentum: ' + str(Hyperparams.momentum))
        print('trust_coef: ' + str(Hyperparams.trust_coef))
        print('warmup_multiplier: ' + str(Hyperparams.warmup_multiplier))
        print('warmup_epoch: ' + str(Hyperparams.warmup_epoch))
        print('max_decay_epoch: ' + str(Hyperparams.max_decay_epoch))
        print('end_learning_rate: ' + str(Hyperparams.end_learning_rate))
        print('num_of_epoch: ' + str(Hyperparams.num_of_epoch))
        print('device: ' + str(Hyperparams.device))
        print('resume: ' + str(Hyperparams.resume))
        print('with_lars: ' + str(Hyperparams.with_lars))
        print('weight_decay: ' + str(Hyperparams.weight_decay))

    def get_info_dict():
        return dict(batch_size=Hyperparams.batch_size,
                    lr=Hyperparams.lr,
                    momentum=Hyperparams.momentum,
                    trust_coef=Hyperparams.trust_coef,
                    warmup_multiplier=Hyperparams.warmup_multiplier,
                    warmup_epoch=Hyperparams.warmup_epoch,
                    max_decay_epoch=Hyperparams.max_decay_epoch,
                    end_learning_rate=Hyperparams.end_learning_rate,
                    num_of_epoch=Hyperparams.num_of_epoch,
                    device=Hyperparams.device,
                    resume=Hyperparams.resume,
                    with_lars=Hyperparams.with_lars,
                    weight_decay=Hyperparams.weight_decay)

class Hyperparams_for_val:
    checkpoint_folder_name = 'checkpoint'
    with_lars = False
    batch_size = 128
    device = [0]

--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
import torch
from torch.optim.optimizer import Optimizer, required

class SGD_without_lars(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum)."""

    def __init__(self, params, lr=required, momentum=0, weight_decay=0):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super(SGD_without_lars, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD_without_lars, self).__setstate__(state)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # NVTX ranges mark the phases of the update for profiling.
                torch.cuda.nvtx.range_push('weight decay')
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                torch.cuda.nvtx.range_pop()

                torch.cuda.nvtx.range_push('momentum')
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p)
                    d_p = buf
                torch.cuda.nvtx.range_pop()

                torch.cuda.nvtx.range_push('weight update')
                p.data.add_(-lr, d_p)
                torch.cuda.nvtx.range_pop()

        return loss


class SGD_with_lars(Optimizer):
    r"""Implements SGD with LARS: a per-layer local learning rate is derived
    from the ratio of the weight norm to the gradient norm."""

    def __init__(self, params, lr=required, momentum=0, weight_decay=0, trust_coef=1.):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if trust_coef < 0.0:
            raise ValueError("Invalid trust_coef value: {}".format(trust_coef))

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, trust_coef=trust_coef)

        super(SGD_with_lars, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD_with_lars, self).__setstate__(state)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            trust_coef = group['trust_coef']
            global_lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data

                # local lr = trust_coef * ||w|| / (||dw|| + weight_decay * ||w||)
                p_norm = torch.norm(p.data, p=2)
                d_p_norm = torch.norm(d_p, p=2).add_(weight_decay, p_norm)
                lr = torch.div(p_norm, d_p_norm).mul_(trust_coef)

                lr.mul_(global_lr)

                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)

                d_p.mul_(lr)

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p)
                    d_p = buf

                p.data.add_(-1, d_p)

        return loss


class SGD_with_lars_ver2(Optimizer):
    r"""Implements SGD with LARS. Variant of SGD_with_lars that folds the
    trust coefficient, global lr, and update sign into the local lr so the
    final update is a single in-place multiply and add."""

    def __init__(self, params, lr=required, momentum=0, weight_decay=0, trust_coef=1.):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if trust_coef < 0.0:
            raise ValueError("Invalid trust_coef value: {}".format(trust_coef))

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, trust_coef=trust_coef)

        super(SGD_with_lars_ver2, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD_with_lars_ver2, self).__setstate__(state)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            trust_coef = group['trust_coef']
            global_lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data

                # local lr = trust_coef * ||w|| / (||dw|| + weight_decay * ||w||)
                p_norm = torch.norm(p.data, p=2)
                d_p_norm = torch.norm(d_p, p=2).add_(weight_decay, p_norm)
                lr = torch.div(p_norm, d_p_norm)

                # fold the trust coefficient, global lr, and update sign into the local lr
                lr.mul_(-global_lr * trust_coef)

                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p)
                    d_p = buf

                d_p.mul_(lr)
                p.data.add_(d_p)

        return loss

--------------------------------------------------------------------------------
/scheduler.py:
--------------------------------------------------------------------------------
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warms up (increases) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of an epoch, whereas others are called at the beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:  # if the after-scheduler is not a ReduceLROnPlateau scheduler
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)


class PolynomialLRDecay(_LRScheduler):
    """Polynomially decays (decreases) the learning rate until the step count reaches max_decay_steps.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_decay_steps: after this step, we stop decreasing the learning rate
        end_learning_rate: once decay stops, the learning rate stays at this value
        power: exponent of the polynomial; power=2 gives the poly(2) policy used in train.py
    """

    def __init__(self, optimizer, max_decay_steps, end_learning_rate=0.0001, power=1.0):
        if max_decay_steps <= 1.:
            raise ValueError('max_decay_steps should be greater than 1.')
        self.max_decay_steps = max_decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.last_step = 0
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_step > self.max_decay_steps:
            return [self.end_learning_rate for _ in self.base_lrs]

        return [(base_lr - self.end_learning_rate) *
                ((1 - self.last_step / self.max_decay_steps) ** (self.power)) +
                self.end_learning_rate for base_lr in self.base_lrs]

    def step(self, step=None):
        if step is None:
            step = self.last_step + 1
        self.last_step = step if step != 0 else 1
        if self.last_step <= self.max_decay_steps:
            decay_lrs = [(base_lr - self.end_learning_rate) *
                         ((1 - self.last_step / self.max_decay_steps) ** (self.power)) +
                         self.end_learning_rate for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, decay_lrs):
                param_group['lr'] = lr

--------------------------------------------------------------------------------
torch.nn.DataParallel(net, device_ids=[0]) 23 | cudnn.benchmark = True 24 | 25 | # Loss & Optimizer 26 | criterion = nn.CrossEntropyLoss() 27 | # optimizer = SGD_with_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef) 28 | optimizer = SGD_with_lars_ver2(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef) 29 | # optimizer = SGD_without_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay) 30 | # optimizer = optim.SGD(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay) 31 | 32 | # Training 33 | net.train() 34 | 35 | inputs = torch.ones([2, 3, 32, 32]).cuda() 36 | targets = torch.ones([2], dtype=torch.long).cuda() 37 | optimizer.zero_grad() 38 | outputs = net(inputs) 39 | loss = criterion(outputs, targets) 40 | loss.backward() 41 | 42 | print('Complete Forward & Backward') 43 | 44 | for batch_idx in range(5): 45 | start_time = time.time() 46 | # torch.cuda.nvtx.range_push('trial') 47 | 48 | optimizer.step() 49 | 50 | # torch.cuda.nvtx.range_pop() 51 | print('time to optimize is %.3f' % (time.time() - start_time)) 52 | 53 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import torchvision.models as models 10 | 11 | import os 12 | import time 13 | 14 | from optimizer import SGD_without_lars, SGD_with_lars, SGD_with_lars_ver2 15 | from scheduler import GradualWarmupScheduler, PolynomialLRDecay 16 | from hyperparams import Hyperparams as hp 17 | from utils import progress_bar 18 | 19 | with torch.cuda.device(hp.device[0]): 20 | all_accs = [] 21 | best_acc = 0 # best test accuracy 22 | all_epochs = [] 23 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 24 | all_times = [] 25 | time_to_train = 0 26 | 27 | # Data 28 | print('==> Preparing data..') 29 | transform_train = transforms.Compose([ 30 | transforms.RandomCrop(32, padding=4), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 33 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 34 | ]) 35 | 36 | transform_test = transforms.Compose([ 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 39 | ]) 40 | 41 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 42 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=hp.batch_size, shuffle=True, num_workers=2) 43 | 44 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 45 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 46 | 47 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 48 | 49 | def init_weights(m): 50 | if type(m) == nn.Linear or type(m) == nn.Conv2d: 51 | torch.nn.init.kaiming_uniform_(m.weight) 52 | 53 | # Model 54 | print('==> Building model..') 55 | net = models.resnet50() 56 | net.apply(init_weights) 57 | net.cuda() 58 | net = torch.nn.DataParallel(net, device_ids=hp.device) 59 | cudnn.benchmark = True 60 | 61 | if hp.resume: 62 | # Load 
checkpoint. 63 | print('==> Resuming from checkpoint..') 64 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 65 | if hp.with_lars: 66 | checkpoint = torch.load('./checkpoint/withLars-' + str(hp.batch_size) + '.pth') 67 | else: 68 | checkpoint = torch.load('./checkpoint/noLars-' + str(hp.batch_size) + '.pth') 69 | net.load_state_dict(checkpoint['net']) 70 | best_acc = checkpoint['acc'] 71 | start_epoch = checkpoint['epoch'] 72 | time_to_train = checkpoint['time_to_train'] 73 | basic_info = checkpoint['basic_info'] 74 | 75 | # Loss & Optimizer 76 | criterion = nn.CrossEntropyLoss() 77 | optimizer = None 78 | if hp.with_lars: 79 | # optimizer = SGD_with_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef) 80 | optimizer = SGD_with_lars_ver2(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef) 81 | else: 82 | # optimizer = SGD_without_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay) 83 | optimizer = optim.SGD(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay) 84 | 85 | warmup_scheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=hp.warmup_multiplier, total_epoch=hp.warmup_epoch) 86 | poly_decay_scheduler = PolynomialLRDecay(optimizer=optimizer, max_decay_steps=hp.max_decay_epoch * len(trainloader), 87 | end_learning_rate=hp.end_learning_rate, power=2.0) # poly(2) 88 | 89 | # Training 90 | def train(epoch): 91 | global time_to_train 92 | net.train() 93 | train_loss = 0 94 | correct = 0 95 | total = 0 96 | 97 | start_time = time.time() 98 | for batch_idx, (inputs, targets) in enumerate(trainloader): 99 | if epoch > hp.warmup_epoch: # after warmup schduler step 100 | poly_decay_scheduler.step() 101 | inputs, targets = inputs.cuda(), targets.cuda() 102 | optimizer.zero_grad() 103 | outputs = net(inputs) 104 | loss = criterion(outputs, targets) 105 | loss.backward() 106 | optimizer.step() 107 | 108 | train_loss += loss.item() 109 | _, predicted = outputs.max(1) 110 | total += targets.size(0) 111 | correct += predicted.eq(targets).sum().item() 112 | 113 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 114 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 115 | time_to_train = time_to_train + (time.time() - start_time) 116 | 117 | def test(epoch): 118 | global best_acc 119 | net.eval() 120 | test_loss = 0 121 | correct = 0 122 | total = 0 123 | with torch.no_grad(): 124 | for batch_idx, (inputs, targets) in enumerate(testloader): 125 | inputs, targets = inputs.cuda(), targets.cuda() 126 | outputs = net(inputs) 127 | loss = criterion(outputs, targets) 128 | 129 | test_loss += loss.item() 130 | _, predicted = outputs.max(1) 131 | total += targets.size(0) 132 | correct += predicted.eq(targets).sum().item() 133 | 134 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 135 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 136 | 137 | # Save checkpoint. 
138 |         acc = 100.*correct/total
139 |         if acc > best_acc:
140 |             all_accs.append(acc)
141 |             all_epochs.append(epoch)
142 |             all_times.append(round(time_to_train, 2))
143 |             print('Saving..')
144 |             state = {
145 |                 'net': net.state_dict(),
146 |                 'acc': all_accs,
147 |                 'epoch': all_epochs,
148 |                 'time_to_train': all_times,
149 |                 'basic_info': hp.get_info_dict()
150 |             }
151 |             if not os.path.isdir('checkpoint'):
152 |                 os.mkdir('checkpoint')
153 |             if hp.with_lars:
154 |                 torch.save(state, './checkpoint/withLars-' + str(hp.batch_size) + '.pth')
155 |             else:
156 |                 torch.save(state, './checkpoint/noLars-' + str(hp.batch_size) + '.pth')
157 |             best_acc = acc
158 | 
159 |     if hp.with_lars:
160 |         print('Resnet50, data=cifar10, With LARS')
161 |     else:
162 |         print('Resnet50, data=cifar10, Without LARS')
163 |     hp.print_hyperparms()
164 |     for epoch in range(0, hp.num_of_epoch):
165 |         print('\nEpoch: %d' % epoch)
166 |         if epoch <= hp.warmup_epoch:  # step the warmup scheduler once per epoch until warmup ends
167 |             warmup_scheduler.step()
168 |         if epoch > hp.warmup_epoch:  # after warmup, start the decay scheduler from the warmed-up learning rate
169 |             poly_decay_scheduler.base_lrs = warmup_scheduler.get_lr()
170 |         for param_group in optimizer.param_groups:
171 |             print('lr: ' + str(param_group['lr']))
172 |         train(epoch)
173 |         test(epoch)
174 | 
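`GradualWarmupScheduler` and `PolynomialLRDecay` live in `scheduler.py`, which is not reproduced in this section. For reference, a minimal sketch of the learning-rate trajectory the two schedulers above are configured to produce, assuming linear warmup from `lr` to `warmup_multiplier * lr` over the warmup period followed by the standard power-2 polynomial decay; `expected_lr` is an illustrative name, not a function in the repo:

```python
# Illustrative sketch only; the real logic lives in scheduler.py.
# Assumption: warmup scales the LR linearly to multiplier * base_lr,
# then polynomial decay with power 2 takes over, step by step.
def expected_lr(step, base_lr, multiplier, warmup_steps, max_decay_steps, end_lr):
    if step <= warmup_steps:
        # linear warmup: base_lr -> multiplier * base_lr
        return base_lr * (1.0 + (multiplier - 1.0) * step / warmup_steps)
    decay_step = min(step - warmup_steps, max_decay_steps)
    remaining = 1.0 - decay_step / max_decay_steps
    # poly(2) decay from the warmed-up LR down to end_lr
    return (multiplier * base_lr - end_lr) * remaining ** 2 + end_lr
```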
--------------------------------------------------------------------------------
/train_with_matplot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | 
7 | import torchvision
8 | import torchvision.transforms as transforms
9 | import torchvision.models as models
10 | 
11 | import os
12 | import time
13 | 
14 | from optimizer import SGD_without_lars, SGD_with_lars, SGD_with_lars_ver2
15 | from scheduler import GradualWarmupScheduler, PolynomialLRDecay
16 | from hyperparams import Hyperparams as hp
17 | from utils import progress_bar
18 | 
19 | import matplotlib.pyplot as plt
20 | 
21 | with torch.cuda.device(hp.device[0]):
22 |     all_accs = []
23 |     best_acc = 0  # best test accuracy
24 |     all_epochs = []
25 |     start_epoch = 0  # start from epoch 0 or last checkpoint epoch
26 |     all_times = []
27 |     time_to_train = 0
28 | 
29 |     train_correct = 0
30 |     train_total = 0
31 |     test_correct = 0
32 |     test_total = 0
33 | 
34 |     epochs = []
35 |     train_accs = []
36 |     test_accs = []
37 | 
38 |     # Data
39 |     print('==> Preparing data..')
40 |     transform_train = transforms.Compose([
41 |         transforms.RandomCrop(32, padding=4),
42 |         transforms.RandomHorizontalFlip(),
43 |         transforms.ToTensor(),
44 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
45 |     ])
46 | 
47 |     transform_test = transforms.Compose([
48 |         transforms.ToTensor(),
49 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
50 |     ])
51 | 
52 |     trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
53 |     trainloader = torch.utils.data.DataLoader(trainset, batch_size=hp.batch_size, shuffle=True, num_workers=2)
54 | 
55 |     testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
56 |     testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
57 | 
58 |     classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
59 | 
60 |     def init_weights(m):
61 |         if type(m) == nn.Linear or type(m) == nn.Conv2d:
62 |             torch.nn.init.kaiming_uniform_(m.weight)
63 | 
64 |     # Model
65 |     print('==> Building model..')
66 |     net = models.resnet50()
67 |     net.apply(init_weights)
68 |     net.cuda()
69 |     net = torch.nn.DataParallel(net, device_ids=hp.device)
70 |     cudnn.benchmark = True
71 | 
72 |     if hp.resume:
73 |         # Load checkpoint.
74 |         print('==> Resuming from checkpoint..')
75 |         assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
76 |         if hp.with_lars:
77 |             checkpoint = torch.load('./checkpoint/withLars-' + str(hp.batch_size) + '.pth')
78 |         else:
79 |             checkpoint = torch.load('./checkpoint/noLars-' + str(hp.batch_size) + '.pth')
80 |         net.load_state_dict(checkpoint['net'])
81 |         best_acc = checkpoint['acc'][-1]  # the checkpoint stores histories (lists); resume from the latest entry
82 |         start_epoch = checkpoint['epoch'][-1]
83 |         time_to_train = checkpoint['time_to_train'][-1]
84 |         basic_info = checkpoint['basic_info']
85 | 
86 |     # Loss & Optimizer
87 |     criterion = nn.CrossEntropyLoss()
88 |     optimizer = None
89 |     if hp.with_lars:
90 |         # optimizer = SGD_with_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)
91 |         optimizer = SGD_with_lars_ver2(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)
92 |     else:
93 |         # optimizer = SGD_without_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
94 |         optimizer = optim.SGD(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
95 | 
96 |     warmup_scheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=hp.warmup_multiplier, total_epoch=hp.warmup_epoch)
97 |     poly_decay_scheduler = PolynomialLRDecay(optimizer=optimizer, max_decay_steps=hp.max_decay_epoch * len(trainloader),
98 |                                              end_learning_rate=hp.end_learning_rate, power=2.0)  # poly(2)
99 | 
100 |     # Training
101 |     def train(epoch):
102 |         global train_total
103 |         global train_correct
104 |         global time_to_train
105 |         net.train()
106 |         train_loss = 0
107 |         correct = 0
108 |         total = 0
109 | 
110 |         start_time = time.time()
111 |         for batch_idx, (inputs, targets) in enumerate(trainloader):
112 |             if epoch > hp.warmup_epoch:  # after warmup, step the decay scheduler once per batch
113 |                 poly_decay_scheduler.step()
114 |             inputs, targets = inputs.cuda(), targets.cuda()
115 |             optimizer.zero_grad()
116 |             outputs = net(inputs)
117 |             loss = criterion(outputs, targets)
118 |             loss.backward()
119 |             optimizer.step()
120 | 
121 |             train_loss += loss.item()
122 |             _, predicted = outputs.max(1)
123 |             total += targets.size(0)
124 |             correct += predicted.eq(targets).sum().item()
125 | 
126 |             progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
127 |                          % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
128 |         time_to_train = time_to_train + (time.time() - start_time)
129 | 
130 |         train_total = total
131 |         train_correct = correct
132 | 
133 |     def test(epoch):
134 |         global best_acc
135 |         global test_total
136 |         global test_correct
137 |         net.eval()
138 |         test_loss = 0
139 |         correct = 0
140 |         total = 0
141 |         with torch.no_grad():
142 |             for batch_idx, (inputs, targets) in enumerate(testloader):
143 |                 inputs, targets = inputs.cuda(), targets.cuda()
144 |                 outputs = net(inputs)
145 |                 loss = criterion(outputs, targets)
146 | 
147 |                 test_loss += loss.item()
148 |                 _, predicted = outputs.max(1)
149 |                 total += targets.size(0)
150 |                 correct += predicted.eq(targets).sum().item()
151 | 
152 |                 progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
153 |                              % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
154 | 
155 |         test_total = total
156 |         test_correct = correct
157 | 
158 |         # Save checkpoint.
159 |         acc = 100.*correct/total
160 |         if acc > best_acc:
161 |             all_accs.append(acc)
162 |             all_epochs.append(epoch)
163 |             all_times.append(round(time_to_train, 2))
164 |             print('Saving..')
165 |             state = {
166 |                 'net': net.state_dict(),
167 |                 'acc': all_accs,
168 |                 'epoch': all_epochs,
169 |                 'time_to_train': all_times,
170 |                 'basic_info': hp.get_info_dict()
171 |             }
172 |             if not os.path.isdir('checkpoint'):
173 |                 os.mkdir('checkpoint')
174 |             if hp.with_lars:
175 |                 torch.save(state, './checkpoint/withLars-' + str(hp.batch_size) + '.pth')
176 |             else:
177 |                 torch.save(state, './checkpoint/noLars-' + str(hp.batch_size) + '.pth')
178 |             best_acc = acc
179 | 
180 |     if hp.with_lars:
181 |         print('Resnet50, data=cifar10, With LARS')
182 |     else:
183 |         print('Resnet50, data=cifar10, Without LARS')
184 |     hp.print_hyperparms()
185 |     for epoch in range(0, hp.num_of_epoch):
186 |         print('\nEpoch: %d' % epoch)
187 |         if epoch <= hp.warmup_epoch:  # step the warmup scheduler once per epoch until warmup ends
188 |             warmup_scheduler.step()
189 |         if epoch > hp.warmup_epoch:  # after warmup, start the decay scheduler from the warmed-up learning rate
190 |             poly_decay_scheduler.base_lrs = warmup_scheduler.get_lr()
191 |         for param_group in optimizer.param_groups:
192 |             print('lr: ' + str(param_group['lr']))
193 |         train(epoch)
194 |         test(epoch)
195 | 
196 |         epochs.append(epoch)
197 |         train_accs.append(100.*train_correct/train_total)
198 |         test_accs.append(100.*test_correct/test_total)
199 | 
200 |     plt.plot(epochs, train_accs, epochs, test_accs, 'r-')
201 |     state = { 'test_acc': test_accs }
202 | 
203 |     if not os.path.isdir('result_fig'):
204 |         os.mkdir('result_fig')
205 | 
206 |     if hp.with_lars:
207 |         plt.title('Resnet50, data=cifar10, With LARS, batch_size: ' + str(hp.batch_size))
208 |         plt.savefig('./result_fig/withLars-' + str(hp.batch_size) + '.jpg')
209 |         torch.save(state, './result_fig/withLars-' + str(hp.batch_size) + '.pth')
210 |     else:
211 |         plt.title('Resnet50, data=cifar10, Without LARS, batch_size: ' + str(hp.batch_size))
212 |         plt.savefig('./result_fig/noLars-' + str(hp.batch_size) + '.jpg')
213 |         torch.save(state, './result_fig/noLars-' + str(hp.batch_size) + '.pth')
214 | 
215 |     plt.gcf().clear()
216 | 
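Both training scripts delegate the actual layer-wise scaling to `SGD_with_lars_ver2` in `optimizer.py`, which is not reproduced in this section. For reference, a minimal sketch of the local learning rate from the LARS paper that such an optimizer computes per layer; `lars_local_lr` is an illustrative name, not a function in the repo:

```python
import torch

def lars_local_lr(weight, grad, weight_decay, trust_coef):
    # local_lr = trust_coef * ||w|| / (||grad|| + weight_decay * ||w||),
    # as described in "Large Batch Training of Convolutional Networks" (LARS).
    w_norm = torch.norm(weight)
    g_norm = torch.norm(grad)
    if w_norm == 0 or g_norm == 0:
        return 1.0  # degenerate layer: fall back to the global LR alone
    return (trust_coef * w_norm / (g_norm + weight_decay * w_norm)).item()
```

The resulting local LR multiplies the global (warmed-up, decayed) LR for that layer's update, which is what lets very large batch sizes keep the linear-scaling rule stable.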
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 |     - get_mean_and_std: calculate the mean and std value of dataset.
3 |     - init_params: net parameter initialization.
4 |     - progress_bar: progress bar that mimics xlua.progress.
5 | '''
6 | import os
7 | import sys
8 | import time
9 | import math
10 | 
11 | import torch  # needed by get_mean_and_std
12 | import torch.nn as nn
13 | 
14 | 
15 | def get_mean_and_std(dataset):
16 |     '''Compute the mean and std value of dataset.'''
17 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
18 |     mean = torch.zeros(3)
19 |     std = torch.zeros(3)
20 |     print('==> Computing mean and std..')
21 |     for inputs, targets in dataloader:
22 |         for i in range(3):
23 |             mean[i] += inputs[:,i,:,:].mean()
24 |             std[i] += inputs[:,i,:,:].std()
25 |     mean.div_(len(dataset))
26 |     std.div_(len(dataset))
27 |     return mean, std
28 | 
29 | def init_params(net):
30 |     '''Init layer parameters.'''
31 |     for m in net.modules():
32 |         if isinstance(m, nn.Conv2d):
33 |             nn.init.kaiming_normal_(m.weight, mode='fan_out')
34 |             if m.bias is not None:
35 |                 nn.init.constant_(m.bias, 0)
36 |         elif isinstance(m, nn.BatchNorm2d):
37 |             nn.init.constant_(m.weight, 1)
38 |             nn.init.constant_(m.bias, 0)
39 |         elif isinstance(m, nn.Linear):
40 |             nn.init.normal_(m.weight, std=1e-3)
41 |             if m.bias is not None:
42 |                 nn.init.constant_(m.bias, 0)
43 | 
44 | 
45 | _, term_width = os.popen('stty size', 'r').read().split()
46 | term_width = int(term_width)
47 | 
48 | TOTAL_BAR_LENGTH = 65.
49 | last_time = time.time()
50 | begin_time = last_time
51 | def progress_bar(current, total, msg=None):
52 |     global last_time, begin_time
53 |     if current == 0:
54 |         begin_time = time.time()  # Reset for new bar.
55 | 
56 |     cur_len = int(TOTAL_BAR_LENGTH*current/total)
57 |     rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
58 | 
59 |     sys.stdout.write(' [')
60 |     for i in range(cur_len):
61 |         sys.stdout.write('=')
62 |     sys.stdout.write('>')
63 |     for i in range(rest_len):
64 |         sys.stdout.write('.')
65 |     sys.stdout.write(']')
66 | 
67 |     cur_time = time.time()
68 |     step_time = cur_time - last_time
69 |     last_time = cur_time
70 |     tot_time = cur_time - begin_time
71 | 
72 |     L = []
73 |     L.append(' Step: %s' % format_time(step_time))
74 |     L.append(' | Tot: %s' % format_time(tot_time))
75 |     if msg:
76 |         L.append(' | ' + msg)
77 | 
78 |     msg = ''.join(L)
79 |     sys.stdout.write(msg)
80 |     for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
81 |         sys.stdout.write(' ')
82 | 
83 |     # Go back to the center of the bar.
84 |     for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
85 |         sys.stdout.write('\b')
86 |     sys.stdout.write(' %d/%d ' % (current+1, total))
87 | 
88 |     if current < total-1:
89 |         sys.stdout.write('\r')
90 |     else:
91 |         sys.stdout.write('\n')
92 |     sys.stdout.flush()
93 | 
94 | def format_time(seconds):
95 |     days = int(seconds / 3600/24)
96 |     seconds = seconds - days*3600*24
97 |     hours = int(seconds / 3600)
98 |     seconds = seconds - hours*3600
99 |     minutes = int(seconds / 60)
100 |     seconds = seconds - minutes*60
101 |     secondsf = int(seconds)
102 |     seconds = seconds - secondsf
103 |     millis = int(seconds*1000)
104 | 
105 |     f = ''
106 |     i = 1
107 |     if days > 0:
108 |         f += str(days) + 'D'
109 |         i += 1
110 |     if hours > 0 and i <= 2:
111 |         f += str(hours) + 'h'
112 |         i += 1
113 |     if minutes > 0 and i <= 2:
114 |         f += str(minutes) + 'm'
115 |         i += 1
116 |     if secondsf > 0 and i <= 2:
117 |         f += str(secondsf) + 's'
118 |         i += 1
119 |     if millis > 0 and i <= 2:
120 |         f += str(millis) + 'ms'
121 |         i += 1
122 |     if f == '':
123 |         f = '0ms'
124 |     return f
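The `(0.4914, 0.4822, 0.4465)` / `(0.2023, 0.1994, 0.2010)` constants hard-coded in the training and validation transforms can be recomputed with `get_mean_and_std` above; a hypothetical one-off script (not part of the repo):

```python
import torchvision
import torchvision.transforms as transforms

from utils import get_mean_and_std

# Compute per-channel statistics of the CIFAR10 training set.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                        transform=transforms.ToTensor())
mean, std = get_mean_and_std(trainset)
print(mean, std)  # should land near the constants used in transforms.Normalize
```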
--------------------------------------------------------------------------------
/val.py:
--------------------------------------------------------------------------------
1 | '''Evaluate CIFAR10 checkpoints with PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import torch.backends.cudnn as cudnn
7 | 
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | import torchvision.models as models
11 | 
12 | import os
13 | import argparse
14 | 
15 | from hyperparams import Hyperparams_for_val as hp
16 | from utils import progress_bar
17 | 
18 | with torch.cuda.device(hp.device[0]):
19 |     # Data
20 |     print('==> Preparing data..')
21 |     transform_test = transforms.Compose([
22 |         transforms.ToTensor(),
23 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
24 |     ])
25 | 
26 |     testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
27 |     testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
28 | 
29 |     classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
30 | 
31 |     # Model
32 |     print('==> Building model..')
33 |     net = models.resnet50()
34 |     net.cuda()
35 |     net = torch.nn.DataParallel(net, device_ids=hp.device)
36 |     cudnn.benchmark = True
37 | 
38 |     # Load checkpoint.
39 |     print('==> Resuming from checkpoint..')
40 |     assert os.path.isdir(hp.checkpoint_folder_name), 'Error: no checkpoint directory found!'
41 |     if hp.with_lars:
42 |         checkpoint = torch.load('./' + hp.checkpoint_folder_name + '/withLars-' + str(hp.batch_size) + '.pth')
43 |     else:
44 |         checkpoint = torch.load('./' + hp.checkpoint_folder_name + '/noLars-' + str(hp.batch_size) + '.pth')
45 |     net.load_state_dict(checkpoint['net'])
46 |     best_acc = checkpoint['acc']
47 |     epoch = checkpoint['epoch']
48 |     time_to_train = checkpoint['time_to_train']  # saved from the 2nd training attempt onward
49 |     basic_info = checkpoint['basic_info']  # saved from the 3rd training attempt onward
50 | 
51 |     criterion = nn.CrossEntropyLoss()
52 | 
53 |     def test():
54 |         global best_acc
55 |         net.eval()
56 |         test_loss = 0
57 |         correct = 0
58 |         total = 0
59 |         with torch.no_grad():
60 |             for batch_idx, (inputs, targets) in enumerate(testloader):
61 |                 inputs, targets = inputs.cuda(), targets.cuda()
62 |                 outputs = net(inputs)
63 |                 loss = criterion(outputs, targets)
64 | 
65 |                 test_loss += loss.item()
66 |                 _, predicted = outputs.max(1)
67 |                 total += targets.size(0)
68 |                 correct += predicted.eq(targets).sum().item()
69 | 
70 |                 progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
71 |                              % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
72 | 
73 | 
74 |     if hp.with_lars:
75 |         print('Resnet50, data=cifar10, With LARS, Validation')
76 |     else:
77 |         print('Resnet50, data=cifar10, Without LARS, Validation')
78 |     print('basic_info=' + str(basic_info))
79 | 
80 |     for epo, acc, sec in zip(epoch, best_acc, time_to_train):
81 |         print(str(epo) + ' epoch | ' + str(acc) + ' % | ' + str(sec) + ' sec')
82 | 
83 |     test()
84 | 
85 | 
86 | 
--------------------------------------------------------------------------------
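For reference, a hedged sketch of inspecting the artifacts these scripts write: `train_with_matplot.py` stores `{'test_acc': [...]}` in the `result_fig` `.pth` files, and both trainers store `net`/`acc`/`epoch`/`time_to_train`/`basic_info` histories in the `checkpoint` `.pth` files (the `checkpoint/` directory is gitignored, so it only exists after a local run). The paths below assume the attempt-5 layout from the repo tree:

```python
import torch

# Accuracy history saved by train_with_matplot.py
history = torch.load('./result_fig-attempt5/result_fig-withLARS/withLars-8192.pth')
print(history['test_acc'][-5:])  # last few per-epoch test accuracies

# Checkpoint saved by train.py / train_with_matplot.py whenever best_acc improves
ckpt = torch.load('./checkpoint/withLars-8192.pth')
for epo, acc, sec in zip(ckpt['epoch'], ckpt['acc'], ckpt['time_to_train']):
    print('%d epoch | %.2f %% | %.2f sec' % (epo, acc, sec))
```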