├── .gitignore
├── README.md
├── fig
│   ├── appendix
│   │   └── run_val.PNG
│   ├── result_fig-attempt3
│   │   ├── result_fig-attempt3_noLARS
│   │   │   ├── noLars-1024.jpg
│   │   │   ├── noLars-128.jpg
│   │   │   ├── noLars-2048.jpg
│   │   │   ├── noLars-256.jpg
│   │   │   ├── noLars-4096.jpg
│   │   │   ├── noLars-512.jpg
│   │   │   └── noLars-8192.jpg
│   │   └── result_fig-attempt3_withLARS
│   │       ├── withLars-1024.jpg
│   │       ├── withLars-128.jpg
│   │       ├── withLars-2048.jpg
│   │       ├── withLars-256.jpg
│   │       ├── withLars-4096.jpg
│   │       ├── withLars-512.jpg
│   │       └── withLars-8192.jpg
│   ├── result_fig-attempt4
│   │   ├── result_fig-noLARS
│   │   │   ├── noLars-1024.jpg
│   │   │   ├── noLars-128.jpg
│   │   │   ├── noLars-2048.jpg
│   │   │   ├── noLars-256.jpg
│   │   │   ├── noLars-4096.jpg
│   │   │   ├── noLars-512.jpg
│   │   │   └── noLars-8192.jpg
│   │   └── result_fig-withLARS
│   │       ├── withLars-1024.jpg
│   │       ├── withLars-128.jpg
│   │       ├── withLars-2048.jpg
│   │       ├── withLars-256.jpg
│   │       ├── withLars-4096.jpg
│   │       ├── withLars-512.jpg
│   │       └── withLars-8192.jpg
│   └── result_fig-attempt5
│       ├── result_fig-noLARS
│       │   ├── noLars-1024.jpg
│       │   ├── noLars-1024.pth
│       │   ├── noLars-128.jpg
│       │   ├── noLars-128.pth
│       │   ├── noLars-2048.jpg
│       │   ├── noLars-2048.pth
│       │   ├── noLars-256.jpg
│       │   ├── noLars-256.pth
│       │   ├── noLars-4096.jpg
│       │   ├── noLars-4096.pth
│       │   ├── noLars-512.jpg
│       │   ├── noLars-512.pth
│       │   ├── noLars-8192.jpg
│       │   └── noLars-8192.pth
│       └── result_fig-withLARS
│           ├── withLars-1024.jpg
│           ├── withLars-1024.pth
│           ├── withLars-128.jpg
│           ├── withLars-128.pth
│           ├── withLars-2048.jpg
│           ├── withLars-2048.pth
│           ├── withLars-256.jpg
│           ├── withLars-256.pth
│           ├── withLars-4096.jpg
│           ├── withLars-4096.pth
│           ├── withLars-512.jpg
│           ├── withLars-512.pth
│           ├── withLars-8192.jpg
│           └── withLars-8192.pth
├── hyperparams.py
├── optimizer.py
├── scheduler.py
├── test_lars.py
├── train.py
├── train_with_matplot.py
├── utils.py
└── val.py
/.gitignore:
--------------------------------------------------------------------------------
1 | core.*
2 | data/*
3 | .ipynb_checkpoints/*
4 | __pycache__/*
5 | checkpoint/*
6 | *.swp
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch-LARS
2 |
3 | ## Objective
4 |
5 | - link: ["Large Batch Training of Convolutional Networks (LARS)"](https://arxiv.org/abs/1708.03888)
6 | - Implements LARS, as introduced in the paper above, in PyTorch with CUDA
7 | - Data: CIFAR10
8 |
9 | ## Requirements
10 |
11 | - python == 3.6.8
12 | - pytorch >= 1.1.0
13 | - cuda >= 10
14 | - matplotlib >= 3.1.0 (optional)
15 | - etc.
16 |
17 | ## Usage
18 |
19 | - Train
20 |
21 | ```bash
22 | $ git clone https://github.com/cmpark0126/pytorch-LARS.git
23 | $ cd pytorch-LARS/
24 | $ vi hyperparams.py # edit the Base and Hyperparams classes for your training setup
25 | $ python train.py # start training on CIFAR10
26 | ```
27 |
28 | - Evaluate
29 |
30 | ```bash
31 | $ vi hyperparams.py # adjust the Hyperparams_for_val class to pick a specific checkpoint to evaluate
32 | $ python val.py # evaluate the results; prints the history of test-accuracy updates recorded during training
33 | ```
34 |
35 | ## Hyperparams (hyperparams.py)
36 |
37 | - Base (class)
38 |
39 | - batch_size: the reference batch size. Every batch size used in the experiments is a multiple of this value.
40 |
41 | - lr: the reference learning rate, used as the base value for linear scaling.
42 |
43 | - multiples: used as the exponent when computing k, described below.
44 |
45 | - Hyperparams (class)
46 |
47 | - batch_size: the batch size actually used for training
48 |
49 | - lr: the initial learning rate actually used for training
50 |
51 | - momentum
52 |
53 | - weight_decay
54 |
55 | - trust_coef: the trust coefficient, i.e. how much we trust the local LR computed inside LARS (see the sketch at the end of this section)
56 |
57 | - warmup_multiplier: the target LR at the end of warm-up is lr * warmup_multiplier
58 |
59 | - warmup_epoch: the number of epochs over which the LR is warmed up
60 |
61 | - max_decay_epoch: the maximum number of epochs over which polynomial decay runs
62 |
63 | - end_learning_rate: the value the learning rate converges to once decay is complete
64 |
65 | - num_of_epoch: the total number of epochs to train for
66 |
67 | - with_lars
68 |
69 | - Hyperparams_for_val (class)
70 |
71 | - checkpoint_folder_name: a checkpoint folder holding saved parameters must exist next to hyperparams.py; specify the name of one of them (e.g. checkpoint_folder_name = 'checkpoint-attempt1')
72 |
73 | - with_lars: choose between checkpoints trained with or without LARS
74 |
75 | - batch_size: specify the batch size the checkpoint was trained with
76 |
77 | - device: the CUDA device on which to run the model for evaluation
78 |
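For reference, a minimal sketch of the local learning rate LARS computes per layer, mirroring the formula in optimizer.py (the helper name here is ours):

```python
import torch

def lars_local_lr(weight, grad, weight_decay, trust_coef):
    # local LR = trust_coef * ||w|| / (||grad|| + weight_decay * ||w||)
    w_norm = torch.norm(weight, p=2)
    g_norm = torch.norm(grad, p=2)
    return trust_coef * w_norm / (g_norm + weight_decay * w_norm)
```
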
79 | ## Demonstration
80 |
81 | - Terminology
82 | - k
83 | - the factor by which we scale the base batch size B
84 | - the starting batch size is 128
85 | - e.g. if we use 256 as the batch size, k is 2
86 | - **k = (2 \*\* (multiples - 1))**, as sketched below
87 | - (base line)
88 | - the target accuracy we want to reach when training the model with a large batch size and LARS
89 |
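A minimal sketch of how k scales the batch size and base LR under the linear scaling rule, mirroring hyperparams.py (multiples = 2 is just an example value):

```python
base_batch_size = 128  # Base.batch_size
base_lr = 0.05         # Base.lr
multiples = 2          # Base.multiples (example value)

k = 2 ** (multiples - 1)           # k = 2
batch_size = base_batch_size * k   # 256
lr = base_lr * k                   # 0.1, per the linear scaling rule
```
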
90 | * * *
91 |
92 | ### Attempt 1
93 |
94 | - Configuration
95 |
96 | - Hyperparams
97 |
98 | - momentum = 0.9
99 |
100 | - weight_decay
101 |
102 | - noLars -> 5e-04
103 | - withLARS -> 5e-03
104 |
105 | - warm-up for 5 epochs
106 |
107 | - warmup_multiplier = k
108 | - target lr follows the linear scaling rule
109 |
110 | - polynomial decay (power=2) LR policy (after warm-up); see the scheduler sketch below
111 |
112 | - for 200 epochs
113 | - minimum lr = 1.5e-05 \* k
114 |
115 | - number of epochs = 200
116 |
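A minimal sketch of how the two schedulers from scheduler.py are chained with the LARS optimizer, mirroring train.py (net and trainloader are assumed to be set up as in that script):

```python
from optimizer import SGD_with_lars_ver2
from scheduler import GradualWarmupScheduler, PolynomialLRDecay
from hyperparams import Hyperparams as hp

optimizer = SGD_with_lars_ver2(net.parameters(), lr=hp.lr, momentum=hp.momentum,
                               weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)

# LR rises from hp.lr toward hp.lr * hp.warmup_multiplier over the first hp.warmup_epoch
# epochs, then decays polynomially (power=2) toward hp.end_learning_rate, stepped per batch.
warmup_scheduler = GradualWarmupScheduler(optimizer, multiplier=hp.warmup_multiplier,
                                          total_epoch=hp.warmup_epoch)
poly_decay_scheduler = PolynomialLRDecay(optimizer,
                                         max_decay_steps=hp.max_decay_epoch * len(trainloader),
                                         end_learning_rate=hp.end_learning_rate, power=2.0)
```
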
117 | - Without LARS
118 |
119 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
120 | | :---: | :-----: | :--------------------: | :-----------: |
121 | | 128 | 0.15 | 89.15 % (base line) | 2113.52 sec |
122 | | 256 | 0.15 | 88.43 % | 1433.38 sec |
123 | | 512 | 0.15 | 88.72 % | 1820.35 sec |
124 | | 1024 | 0.15 | 87.96 % | 1303.54 sec |
125 | | 2048 | 0.15 | 87.05 % | 1827.90 sec |
126 | | 4096 | 0.15 | 78.03 % | 2083.24 sec |
127 | | 8192 | 0.15 | 14.59 % | 1459.81 sec |
128 |
129 | - With LARS (closest one to base line, for comparing time to train)
130 |
131 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
132 | | :---: | :-----: | :---------------: | :-----------: |
133 | | 128 | 0.15 | 89.16 % | 3203.54 sec |
134 | | 256 | 0.15 | 89.19 % | 2147.74 sec |
135 | | 512 | 0.15 | 89.29 % | 1677.25 sec |
136 | | 1024 | 0.15 | 89.17 % | 1604.91 sec |
137 | | 2048 | 0.15 | 88.70 % | 1413.10 sec |
138 | | 4096 | 0.15 | 86.78 % | 1609.08 sec |
139 | | 8192 | 0.15 | 80.85 % | 1629.48 sec |
140 |
141 | - With LARS (best accuracy)
142 |
143 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
144 | | :---: | :-----: | :---------------: | :-----------: |
145 | | 128 | 0.15 | 89.62 % | 3606.08 sec |
146 | | 256 | 0.15 | 89.78 % | 2675.04 sec |
147 | | 512 | 0.15 | 89.38 % | 1712.90 sec |
148 | | 1024 | 0.15 | 89.22 % | 1967.92 sec |
149 | | 2048 | 0.15 | 88.70 % | 1413.10 sec |
150 | | 4096 | 0.15 | 86.78 % | 1609.08 sec |
151 | | 8192 | 0.15 | 80.85 % | 1629.48 sec |
152 |
153 | * * *
154 |
155 | ### Attempt 2
156 |
157 | - Configuration
158 |
159 | - Hyperparams
160 |
161 | - momentum = 0.9
162 |
163 | - weight_decay
164 |
165 | - noLars -> 5e-04
166 | - withLARS -> 5e-03
167 |
168 | - trust coefficient = 0.1
169 |
170 | - warm-up for 5 epochs
171 |
172 | - warmup_multiplier = 2 \* k
173 | - target lr follows the linear scaling rule
174 |
175 | - polynomial decay (power=2) LR policy (after warm-up)
176 |
177 | - for 200 epochs
178 | - minimum lr = 1e-05
179 |
180 | - number of epochs = 200
181 |
182 | - Without LARS
183 |
184 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
185 | | :---: | :-----: | :--------------------: | :-----------: |
186 | | 128 | 0.05 | 90.40 % (base line) | 4232.56 sec |
187 | | 256 | 0.05 | 90.00 % | 2968.43 sec |
188 | | 512 | 0.05 | 89.50 % | 2707.79 sec |
189 | | 1024 | 0.05 | 89.27 % | 2627.22 sec |
190 | | 2048 | 0.05 | 89.21 % | 2500.02 sec |
191 | | 4096 | 0.05 | 84.73 % | 2872.25 sec |
192 | | 8192 | 0.05 | 20.85 % | 2923.95 sec |
193 |
194 | - With LARS (closest one to base line, for comparing time to train)
195 |
196 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
197 | | :---: | :-----: | :---------------: | :-----------: |
198 | | 128 | 0.05 | 90.21 % | 6792.61 sec |
199 | | 256 | 0.05 | 90.28 % | 4871.68 sec |
200 | | 512 | 0.05 | 90.41 % | 3581.32 sec |
201 | | 1024 | 0.05 | 90.27 % | 3030.45 sec |
202 | | 2048 | 0.05 | 90.19 % | 2773.21 sec |
203 | | 4096 | 0.05 | 88.49 % | 2866.02 sec |
204 | | 8192 | 0.05 | 62.20 % | 1312.98 sec |
205 |
206 | - With LARS (best accuracy)
207 |
208 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
209 | | :---: | :-----: | :---------------: | :-----------: |
210 | | 128 | 0.05 | 90.21 % | 6792.61 sec |
211 | | 256 | 0.05 | 90.28 % | 4871.68 sec |
212 | | 512 | 0.05 | 90.41 % | 3581.32 sec |
213 | | 1024 | 0.05 | 90.27 % | 3030.45 sec |
214 | | 2048 | 0.05 | 90.19 % | 2773.21 sec |
215 | | 4096 | 0.05 | 88.49 % | 2866.02 sec |
216 | | 8192 | 0.05 | 62.20 % | 1312.98 sec |
217 |
218 | * * *
219 |
220 | ### Attempt 3
221 |
222 | - Configuration
223 |
224 | - Hyperparams
225 |
226 | - momentum = 0.9
227 |
228 | - weight_decay
229 |
230 | - noLars -> 5e-04
231 | - withLARS -> 5e-03
232 |
233 | - trust coefficient = 0.1
234 |
235 | - warm-up for 5 epochs
236 |
237 | - warmup_multiplier = 2
238 |
239 | - polynomial decay (power=2) LR policy (after warm-up)
240 |
241 | - for 200 epochs
242 | - minimum lr = 1e-05 \* k
243 |
244 | - number of epochs = 200
245 |
246 | - Additional Jobs
247 |
248 | - Use He initialization
249 |
250 | - base lr adjusted according to the linear scaling rule
251 |
252 | - Without LARS
253 |
254 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
255 | | :---: | :-----: | :--------------------: | :-----------: |
256 | | 128 | 0.05 | 89.76 % | 3983.89 sec |
257 | | 256 | 0.1 | 90.08 % (base line) | 3095.91 sec |
258 | | 512 | 0.2 | 89.34 % | 2674.38 sec |
259 | | 1024 | 0.4 | 88.82 % | 2581.19 sec |
260 | | 2048 | 0.8 | 89.29 % | 2660.56 sec |
261 | | 4096 | 1.6 | 85.02 % | 2871.04 sec |
262 | | 8192 | 3.2 | 77.72 % | 3195.90 sec |
263 |
264 | - With LARS (closest one to base line, for comparing time to train)
265 |
266 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
267 | | :---: | :-----: | :---------------: | :-----------: |
268 | | 128 | 0.05 | 90.11 % | 6880.76 sec |
269 | | 256 | 0.1 | 90.12 % | 4262.83 sec |
270 | | 512 | 0.2 | 90.11 % | 3548.07 sec |
271 | | 1024 | 0.4 | 90.02 % | 2760.31 sec |
272 | | 2048 | 0.8 | 90.09 % | 2877.81 sec |
273 | | 4096 | 1.6 | 88.38 % | 2946.53 sec |
274 | | 8192 | 3.2 | 86.40 % | 3260.45 sec |
275 |
276 | - With LARS (best accuracy)
277 |
278 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
279 | | :---: | :-----: | :---------------: | :-----------: |
280 | | 128 | 0.05 | 90.37 % | 7338.71 sec |
281 | | 256 | 0.1 | 90.32 % | 4590.58 sec |
282 | | 512 | 0.2 | 90.11 % | 3548.07 sec |
283 | | 1024 | 0.4 | 90.50 % | 2897.45 sec |
284 | | 2048 | 0.8 | 90.09 % | 2877.81 sec |
285 | | 4096 | 1.6 | 88.38 % | 2946.53 sec |
286 | | 8192 | 3.2 | 86.40 % | 3260.45 sec |
287 |
288 | * * *
289 |
290 | ### Attempt 4
291 |
292 | - Configuration
293 |
294 | - Hyperparams
295 |
296 | - momentum = 0.9
297 |
298 | - weight_decay
299 |
300 | - noLars -> 5e-04
301 | - withLARS -> 5e-03
302 |
303 | - trust coefficient = 0.1
304 |
305 | - warm-up for 5 epochs
306 |
307 | - warmup_multiplier = 5
308 |
309 | - polynomial decay (power=2) LR policy (after warm-up)
310 |
311 | - for 200 epochs
312 | - minimum lr = 1e-05 \* k
313 |
314 | - number of epochs = 200
315 |
316 | - Additional Jobs
317 |
318 | - Use He initialization
319 |
320 | - base lr adjusted according to the linear scaling rule
321 |
322 | - Without LARS
323 |
324 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
325 | | :---: | :-----: | :--------------------: | :-----------: |
326 | | 128 | 0.02 | 89.84 % | 4146.52 sec |
327 | | 256 | 0.04 | 90.22 % (base line) | 3023.48 sec |
328 | | 512 | 0.08 | 89.42 % | 2588.01 sec |
329 | | 1024 | 0.16 | 89.41 % | 2494.35 sec |
330 | | 2048 | 0.32 | 88.97 % | 2616.32 sec |
331 | | 4096 | 0.64 | 85.13 % | 2872.76 sec |
332 | | 8192 | 1.28 | 75.99 % | 3226.53 sec |
333 |
334 | - With LARS (closest one to base line, for comparing time to train)
335 |
336 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
337 | | :---: | :-----: | :---------------: | :-----------: |
338 | | 128 | 0.02 | 90.20 % | 6740.03 sec |
339 | | 256 | 0.04 | 90.25 % | 4662.09 sec |
340 | | 512 | 0.08 | 90.24 % | 3381.99 sec |
341 | | 1024 | 0.16 | 90.07 % | 2929.32 sec |
342 | | 2048 | 0.32 | 89.82 % | 2908.37 sec |
343 | | 4096 | 0.64 | 88.09 % | 2980.63 sec |
344 | | 8192 | 1.28 | 86.56 % | 3314.60 sec |
345 |
346 | - With LARS (best accuracy)
347 |
348 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
349 | | :---: | :-----: | :---------------: | :-----------: |
350 | | 128 | 0.02 | 90.69 % | 7003.00 sec |
351 | | 256 | 0.04 | 90.32 % | 4808.80 sec |
352 | | 512 | 0.08 | 90.40 % | 3615.13 sec |
353 | | 1024 | 0.16 | 90.07 % | 2929.32 sec |
354 | | 2048 | 0.32 | 89.82 % | 2908.37 sec |
355 | | 4096 | 0.64 | 88.09 % | 2980.63 sec |
356 | | 8192 | 1.28 | 86.56 % | 3314.60 sec |
357 |
358 | * * *
359 |
360 | ### Attempt 5
361 |
362 | - Configuration
363 |
364 | - Hyperparams
365 |
366 | - momentum = 0.9
367 |
368 | - weight_decay
369 |
370 | - noLars -> 5e-04
371 | - withLARS -> 5e-03
372 |
373 | - trust coefficient = 0.1
374 |
375 | - warm-up for 5 epochs
376 |
377 | - warmup_multiplier = 2
378 |
379 | - polynomial decay (power=2) LR policy (after warm-up)
380 |
381 | - **for 175 epochs**
382 | - minimum lr = 1e-05 \* k
383 |
384 | - **number of epochs = 175**
385 |
386 | - Additional Jobs
387 |
388 | - Use He initialization
389 |
390 | - base lr adjusted according to the linear scaling rule
391 |
392 | - Without LARS
393 |
394 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
395 | | :---: | :-----: | :--------------------: | :-----------: |
396 | | 128 | 0.05 | 89.50 % (base line) | 3682.72 sec |
397 | | 256 | 0.1 | 89.22 % | 2678.24 sec |
398 | | 512 | 0.2 | 89.12 % | 2337.15 sec |
399 | | 1024 | 0.4 | 88.70 % | 2282.48 sec |
400 | | 2048 | 0.8 | 88.89 % | 2316.96 sec |
401 | | 4096 | 1.6 | 86.87 % | 2515.56 sec |
402 | | 8192 | 3.2 | 15.50 % | 2783.00 sec |
403 |
404 | - With LARS (closest one to base line, for comparing time to train)
405 |
406 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
407 | | :---: | :-----: | :---------------: | :-----------: |
408 | | 128 | 0.05 | 89.56 % | 5445.55 sec |
409 | | 256 | 0.1 | 89.52 % | 3461.59 sec |
410 | | 512 | 0.2 | 89.60 % | 2738.91 sec |
411 | | 1024 | 0.4 | 89.50 % | 2410.23 sec |
412 | | 2048 | 0.8 | 89.42 % | 2474.93 sec |
413 | | 4096 | 1.6 | 88.43 % | 2618.97 sec |
414 | | 8192 | 3.2 | 74.96 % | 1835.32 sec |
415 |
416 | - With LARS (best accuracy)
417 |
418 | | Batch | Base LR | top-1 Accuracy, % | Time to train |
419 | | :---: | :-----: | :---------------: | :-----------: |
420 | | 128 | 0.05 | 90.36 % | 6377.71 sec |
421 | | 256 | 0.1 | 90.18 % | 4219.26 sec |
422 | | 512 | 0.2 | 90.08 % | 3130.41 sec |
423 | | 1024 | 0.4 | 89.94 % | 2578.00 sec |
424 | | 2048 | 0.8 | 89.42 % | 2474.93 sec |
425 | | 4096 | 1.6 | 88.43 % | 2618.97 sec |
426 | | 8192 | 3.2 | 74.96 % | 1835.32 sec |
427 |
428 | * * *
429 |
430 | ## Visualization
431 |
432 | ![noLars-8192](./fig/result_fig-attempt4/result_fig-noLARS/noLars-8192.jpg)
433 |
434 | <Fig1. Attempt4, Without LARS, Batch size = 8192>
435 |
436 | ![withLars-8192](./fig/result_fig-attempt4/result_fig-withLARS/withLars-8192.jpg)
437 |
438 | <Fig2. Attempt4, With LARS, Batch size = 8192>
439 |
440 | - Comparing Fig1 and Fig2 shows that with LARS, training starts more stably and accuracy rises more smoothly.
441 |
442 | - The accuracy progression graphs created while working on Attempts 3, 4, and 5 can be found at the links below.
443 | - [Attempt3](https://github.com/cmpark0126/pytorch-LARS/tree/master/fig/result_fig-attempt3)
444 | - [Attempt4](https://github.com/cmpark0126/pytorch-LARS/tree/master/fig/result_fig-attempt4)
445 | - [Attempt5](https://github.com/cmpark0126/pytorch-LARS/tree/master/fig/result_fig-attempt5)
446 |
447 | ## Analysis of Resnet50 Training With Large Batch (CIFAR10)
448 |
449 | - Confirmed that with LARS, batch sizes up to 1024 can be used while the model still reaches the baseline accuracy
450 |
451 | - Confirmed that it is important to combine LARS with other techniques, such as He initialization, rather than relying on LARS alone
452 |
453 | - Confirmed that LARS can not only match the baseline but sometimes exceed it
454 | - This appears to be a side effect of the local learning rate mitigating the vanishing and exploding gradient problems, as mentioned in the paper
455 |
456 | ## Open Issue
457 |
458 | - Observed that training with LARS takes roughly twice as long; look into ways to reduce the training time
459 |
460 | ## Reference
461 |
462 | - Base code:
463 |
464 | - warm-up LR scheduler:
465 | - also implemented the PolynomialLRDecay class based on this
466 | - polynomial LR decay scheduler
467 | - see: scheduler.py
468 |
469 | - Pytorch Doc / Optimizer:
470 | - Optimizer class
471 | - SGD class
472 |
473 | ## Appendix
474 |
475 | ### Screenshot of running val.py
476 |
477 | ![run_val](./fig/appendix/run_val.PNG)
478 |
479 | - Shows the history of how the best accuracy was updated over the course of training; a loading sketch follows below.
480 |
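A minimal sketch of reading that history directly from a saved checkpoint, based on the state dict written by train.py (the path is an example):

```python
import torch

# 'acc', 'epoch', and 'time_to_train' are the history lists saved by train.py
checkpoint = torch.load('./checkpoint/withLars-1024.pth', map_location='cpu')
for epoch, acc, t in zip(checkpoint['epoch'], checkpoint['acc'], checkpoint['time_to_train']):
    print('epoch %3d | best acc %.2f%% | %.2f sec of training so far' % (epoch, acc, t))
```
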
--------------------------------------------------------------------------------
/fig/appendix/run_val.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/appendix/run_val.PNG
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-1024.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-1024.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-128.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-128.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-2048.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-2048.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-256.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-4096.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-4096.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-512.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-512.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-8192.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_noLARS/noLars-8192.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-1024.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-1024.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-128.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-128.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-2048.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-2048.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-256.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-4096.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-4096.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-512.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-512.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-8192.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt3/result_fig-attempt3_withLARS/withLars-8192.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-noLARS/noLars-1024.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-1024.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-noLARS/noLars-128.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-128.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-noLARS/noLars-2048.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-2048.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-noLARS/noLars-256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-256.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-noLARS/noLars-4096.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-4096.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-noLARS/noLars-512.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-512.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-noLARS/noLars-8192.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-noLARS/noLars-8192.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-withLARS/withLars-1024.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-1024.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-withLARS/withLars-128.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-128.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-withLARS/withLars-2048.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-2048.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-withLARS/withLars-256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-256.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-withLARS/withLars-4096.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-4096.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-withLARS/withLars-512.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-512.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt4/result_fig-withLARS/withLars-8192.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt4/result_fig-withLARS/withLars-8192.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-1024.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-1024.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-1024.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-1024.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-128.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-128.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-128.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-128.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-2048.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-2048.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-2048.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-2048.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-256.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-256.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-256.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-4096.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-4096.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-4096.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-4096.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-512.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-512.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-512.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-512.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-8192.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-8192.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-noLARS/noLars-8192.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-noLARS/noLars-8192.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-1024.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-1024.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-1024.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-1024.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-128.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-128.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-128.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-128.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-2048.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-2048.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-2048.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-2048.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-256.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-256.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-256.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-4096.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-4096.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-4096.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-4096.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-512.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-512.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-512.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-512.pth
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-8192.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-8192.jpg
--------------------------------------------------------------------------------
/fig/result_fig-attempt5/result_fig-withLARS/withLars-8192.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmpark0126/pytorch-LARS/f3dede6caebd3bbb683e1f34b4ca730df0b0be40/fig/result_fig-attempt5/result_fig-withLARS/withLars-8192.pth
--------------------------------------------------------------------------------
/hyperparams.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | class Base:
4 | batch_size = 128 # initial batch size
5 | lr = 0.05 # Initial learning rate
6 | multiples = 1 # used to compute k = 2 ** (multiples - 1)
7 |
8 | class Hyperparams:
9 | '''Hyper parameters'''
10 | device = [0]
11 |
12 | batch_size = Base.batch_size * (2 ** (Base.multiples - 1)) # k = 2 ** (Base.multiples - 1)
13 | lr = Base.lr * (2 ** (Base.multiples - 1)) # LR linear scaling
14 |
15 | # optim
16 | momentum = 0.9
17 | weight_decay = 5e-4
18 | trust_coef = 0.1
19 |
20 | # warm-up step & Linear Scaling Rule
21 | warmup_multiplier = 2
22 | warmup_epoch = 5
23 |
24 | # decay lr step (polynomial)
25 | max_decay_epoch = 200
26 | end_learning_rate = 0.0001
27 |
28 | num_of_epoch = 200
29 | with_lars = False
30 | resume = False
31 |
32 | def print_hyperparms():
33 | print('batch_size: ' + str(Hyperparams.batch_size))
34 | print('lr: ' + str(Hyperparams.lr))
35 | print('momentum: ' + str(Hyperparams.momentum))
36 | print('trust_coef: ' + str(Hyperparams.trust_coef))
37 | print('warmup_multiplier: ' + str(Hyperparams.warmup_multiplier))
38 | print('warmup_epoch: ' + str(Hyperparams.warmup_epoch))
39 | print('max_decay_epoch: ' + str(Hyperparams.max_decay_epoch))
40 | print('end_learning_rate: ' + str(Hyperparams.end_learning_rate))
41 | print('num_of_epoch: ' + str(Hyperparams.num_of_epoch))
42 | print('device: ' + str(Hyperparams.device))
43 | print('resume: ' + str(Hyperparams.resume))
44 | print('with_lars: ' + str(Hyperparams.with_lars))
45 | print('weight_decay: ' + str(Hyperparams.weight_decay))
46 |
47 | def get_info_dict():
48 | return dict(batch_size=Hyperparams.batch_size,
49 | lr=Hyperparams.lr,
50 | momentum=Hyperparams.momentum,
51 | trust_coef=Hyperparams.trust_coef,
52 | warmup_multiplier=Hyperparams.warmup_multiplier,
53 | warmup_epoch=Hyperparams.warmup_epoch,
54 | max_decay_epoch=Hyperparams.max_decay_epoch,
55 | end_learning_rate=Hyperparams.end_learning_rate,
56 | num_of_epoch=Hyperparams.num_of_epoch,
57 | device=Hyperparams.device,
58 | resume=Hyperparams.resume,
59 | with_lars=Hyperparams.with_lars,
60 | weight_decay=Hyperparams.weight_decay)
61 |
62 | class Hyperparams_for_val:
63 | checkpoint_folder_name = 'checkpoint'
64 | with_lars = False
65 | batch_size = 128
66 | device = [0]
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer, required
3 |
4 | class SGD_without_lars(Optimizer):
5 | r"""Implements stochastic gradient descent (optionally with momentum).
6 | """
7 |
8 | def __init__(self, params, lr=required, momentum=0, weight_decay=0):
9 | if lr is not required and lr < 0.0:
10 | raise ValueError("Invalid learning rate: {}".format(lr))
11 | if momentum < 0.0:
12 | raise ValueError("Invalid momentum value: {}".format(momentum))
13 | if weight_decay < 0.0:
14 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
15 |
16 | defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
17 | super(SGD_without_lars, self).__init__(params, defaults)
18 |
19 | def __setstate__(self, state):
20 | super(SGD_without_lars, self).__setstate__(state)
21 |
22 | def step(self, closure=None):
23 | """Performs a single optimization step.
24 |
25 | Arguments:
26 | closure (callable, optional): A closure that reevaluates the model
27 | and returns the loss.
28 | """
29 | loss = None
30 | if closure is not None:
31 | loss = closure()
32 |
33 | for group in self.param_groups:
34 | weight_decay = group['weight_decay']
35 | momentum = group['momentum']
36 | lr = group['lr']
37 |
38 | for p in group['params']:
39 | #torch.cuda.nvtx.range_push('trial')
40 | if p.grad is None:
41 | continue
42 | d_p = p.grad.data
43 | torch.cuda.nvtx.range_push('weight decay')
44 | if weight_decay != 0:
45 | d_p.add_(weight_decay, p.data)
46 | torch.cuda.nvtx.range_pop()
47 | # d_p.mul_(lr)
48 |
49 | torch.cuda.nvtx.range_push('momentum')
50 | if momentum != 0:
51 | param_state = self.state[p]
52 | if 'momentum_buffer' not in param_state:
53 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
54 | else:
55 | buf = param_state['momentum_buffer']
56 | buf.mul_(momentum).add_(d_p)
57 | d_p = buf
58 | torch.cuda.nvtx.range_pop()
59 |
60 | torch.cuda.nvtx.range_push('weight update')
61 | p.data.add_(-lr, d_p)
62 | torch.cuda.nvtx.range_pop()
63 |
64 | # torch.cuda.nvtx.range_pop()
65 | return loss
66 |
67 |
68 | class SGD_with_lars(Optimizer):
69 | r"""Implements stochastic gradient descent (optionally with momentum).
70 | """
71 |
72 | def __init__(self, params, lr=required, momentum=0, weight_decay=0, trust_coef=1.):
73 | if lr is not required and lr < 0.0:
74 | raise ValueError("Invalid learning rate: {}".format(lr))
75 | if momentum < 0.0:
76 | raise ValueError("Invalid momentum value: {}".format(momentum))
77 | if weight_decay < 0.0:
78 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
79 | if trust_coef < 0.0:
80 | raise ValueError("Invalid trust_coef value: {}".format(trust_coef))
81 |
82 | defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, trust_coef=trust_coef)
83 |
84 | super(SGD_with_lars, self).__init__(params, defaults)
85 |
86 | def __setstate__(self, state):
87 | super(SGD_with_lars, self).__setstate__(state)
88 |
89 | def step(self, closure=None):
90 | """Performs a single optimization step.
91 |
92 | Arguments:
93 | closure (callable, optional): A closure that reevaluates the model
94 | and returns the loss.
95 | """
96 | loss = None
97 | if closure is not None:
98 | loss = closure()
99 |
100 | for group in self.param_groups:
101 | weight_decay = group['weight_decay']
102 | momentum = group['momentum']
103 | trust_coef = group['trust_coef']
104 | global_lr = group['lr']
105 |
106 | for p in group['params']:
107 | if p.grad is None:
108 | continue
109 | d_p = p.grad.data
110 |
111 | p_norm = torch.norm(p.data, p=2)
112 | d_p_norm = torch.norm(d_p, p=2).add_(weight_decay, p_norm) # ||d_p|| + weight_decay * ||p||
113 | lr = torch.div(p_norm, d_p_norm).mul_(trust_coef) # LARS local LR
114 |
115 | lr.mul_(global_lr)
116 |
117 | if weight_decay != 0:
118 | d_p.add_(weight_decay, p.data)
119 |
120 | d_p.mul_(lr)
121 |
122 | if momentum != 0:
123 | param_state = self.state[p]
124 | if 'momentum_buffer' not in param_state:
125 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
126 | else:
127 | buf = param_state['momentum_buffer']
128 | buf.mul_(momentum).add_(d_p)
129 | d_p = buf
130 |
131 | p.data.add_(-1, d_p)
132 |
133 | return loss
134 |
135 |
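# Compared to SGD_with_lars above, this variant applies momentum to the raw
# (weight-decayed) gradient and folds the negated global LR and trust coefficient
# into the local LR, so the weight update becomes a single in-place multiply-add.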
136 | class SGD_with_lars_ver2(Optimizer):
137 | r"""Implements stochastic gradient descent (optionally with momentum).
138 | """
139 |
140 | def __init__(self, params, lr=required, momentum=0, weight_decay=0, trust_coef=1.):
141 | if lr is not required and lr < 0.0:
142 | raise ValueError("Invalid learning rate: {}".format(lr))
143 | if momentum < 0.0:
144 | raise ValueError("Invalid momentum value: {}".format(momentum))
145 | if weight_decay < 0.0:
146 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
147 | if trust_coef < 0.0:
148 | raise ValueError("Invalid trust_coef value: {}".format(trust_coef))
149 |
150 | defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, trust_coef=trust_coef)
151 |
152 | super(SGD_with_lars_ver2, self).__init__(params, defaults)
153 |
154 | def __setstate__(self, state):
155 | super(SGD_with_lars_ver2, self).__setstate__(state)
156 |
157 | def step(self, closure=None):
158 | """Performs a single optimization step.
159 |
160 | Arguments:
161 | closure (callable, optional): A closure that reevaluates the model
162 | and returns the loss.
163 | """
164 | loss = None
165 | if closure is not None:
166 | loss = closure()
167 |
168 | for group in self.param_groups:
169 | weight_decay = group['weight_decay']
170 | momentum = group['momentum']
171 | trust_coef = group['trust_coef']
172 | global_lr = group['lr']
173 |
174 | for p in group['params']:
175 | if p.grad is None:
176 | continue
177 | d_p = p.grad.data
178 |
179 | # LARS local LR: trust_coef * ||p|| / (||d_p|| + weight_decay * ||p||)
180 | p_norm = torch.norm(p.data, p=2)
181 | d_p_norm = torch.norm(d_p, p=2).add_(weight_decay, p_norm)
182 | lr = torch.div(p_norm, d_p_norm)
183 |
184 | # fold the global LR and trust coefficient into the local LR;
185 | # negated so the weight update below can be a plain in-place add
186 | lr.mul_(-global_lr*trust_coef)
187 |
188 | # weight decay on the raw gradient
189 | if weight_decay != 0:
190 | d_p.add_(weight_decay, p.data)
191 |
192 | # momentum is applied before the LR scaling in this variant
193 | if momentum != 0:
194 | param_state = self.state[p]
195 | if 'momentum_buffer' not in param_state:
196 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
197 | else:
198 | buf = param_state['momentum_buffer']
199 | buf.mul_(momentum).add_(d_p)
200 | d_p = buf
201 |
202 | # scale by the signed LR and update the weights in place
203 | d_p.mul_(lr)
204 | p.data.add_(d_p)
221 |
222 |
223 | return loss
224 |
--------------------------------------------------------------------------------
/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 | class GradualWarmupScheduler(_LRScheduler):
5 | """ Gradually warm-up(increasing) learning rate in optimizer.
6 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
7 | Args:
8 | optimizer (Optimizer): Wrapped optimizer.
9 | multiplier: target learning rate = base lr * multiplier
10 | total_epoch: target learning rate is reached at total_epoch, gradually
11 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
12 | """
13 |
14 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
15 | self.multiplier = multiplier
16 | if self.multiplier < 1.:
17 | raise ValueError('multiplier should be greater than or equal to 1.')
18 | self.total_epoch = total_epoch
19 | self.after_scheduler = after_scheduler
20 | self.finished = False
21 | super().__init__(optimizer)
22 |
23 | def get_lr(self):
24 | if self.last_epoch > self.total_epoch:
25 | if self.after_scheduler:
26 | if not self.finished:
27 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
28 | self.finished = True
29 | return self.after_scheduler.get_lr()
30 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
31 |
32 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
33 |
34 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
35 | if epoch is None:
36 | epoch = self.last_epoch + 1
37 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
38 | if self.last_epoch <= self.total_epoch:
39 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
40 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
41 | param_group['lr'] = lr
42 | else:
43 | if epoch is None:
44 | self.after_scheduler.step(metrics, None)
45 | else:
46 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
47 |
48 | def step(self, epoch=None, metrics=None):
49 | if type(self.after_scheduler) != ReduceLROnPlateau: # the after-scheduler is not ReduceLROnPlateau
50 | if self.finished and self.after_scheduler:
51 | if epoch is None:
52 | self.after_scheduler.step(None)
53 | else:
54 | self.after_scheduler.step(epoch - self.total_epoch)
55 | else:
56 | return super(GradualWarmupScheduler, self).step(epoch)
57 | else:
58 | self.step_ReduceLROnPlateau(metrics, epoch)
59 |
60 |
61 | class PolynomialLRDecay(_LRScheduler):
62 | """Polynomial decay(decrease) learning rate until step reach to max_decay_step
63 |
64 | Args:
65 | optimizer (Optimizer): Wrapped optimizer.
66 | max_decay_steps: after this step, we stop decreasing learning rate
67 | end_learning_rate: scheduler stoping learning rate decay, value of learning rate must be this value
68 | power: TBW
69 | """
70 |
71 | def __init__(self, optimizer, max_decay_steps, end_learning_rate=0.0001, power=1.0):
72 | if max_decay_steps <= 1.:
73 | raise ValueError('max_decay_steps should be greater than 1.')
74 | self.max_decay_steps = max_decay_steps
75 | self.end_learning_rate = end_learning_rate
76 | self.power = power
77 | self.last_step = 0
78 | super().__init__(optimizer)
79 |
80 | def get_lr(self):
81 | if self.last_step > self.max_decay_steps:
82 | return [self.end_learning_rate for _ in self.base_lrs]
83 |
84 | return [(base_lr - self.end_learning_rate) *
85 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) +
86 | self.end_learning_rate for base_lr in self.base_lrs]
87 |
88 | def step(self, step=None):
89 | if step is None:
90 | step = self.last_step + 1
91 | self.last_step = step if step != 0 else 1
92 | if self.last_step <= self.max_decay_steps:
93 | decay_lrs = [(base_lr - self.end_learning_rate) *
94 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) +
95 | self.end_learning_rate for base_lr in self.base_lrs]
96 | for param_group, lr in zip(self.optimizer.param_groups, decay_lrs):
97 | param_group['lr'] = lr
--------------------------------------------------------------------------------
/test_lars.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 |
7 | import torchvision.models as models
8 |
9 | import os
10 | import time
11 |
12 | from optimizer import SGD_without_lars, SGD_with_lars, SGD_with_lars_ver2
13 | from scheduler import GradualWarmupScheduler, PolynomialLRDecay
14 | from hyperparams import Hyperparams as hp
15 | from utils import progress_bar
16 |
17 | with torch.cuda.device(0):
18 | # Model
19 | print('==> Building model..')
20 | net = models.resnet50()
21 | net.cuda()
22 | net = torch.nn.DataParallel(net, device_ids=[0])
23 | cudnn.benchmark = True
24 |
25 | # Loss & Optimizer
26 | criterion = nn.CrossEntropyLoss()
27 | # optimizer = SGD_with_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)
28 | optimizer = SGD_with_lars_ver2(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)
29 | # optimizer = SGD_without_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
30 | # optimizer = optim.SGD(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
31 |
32 | # Training
33 | net.train()
34 |
35 | inputs = torch.ones([2, 3, 32, 32]).cuda()
36 | targets = torch.ones([2], dtype=torch.long).cuda()
37 | optimizer.zero_grad()
38 | outputs = net(inputs)
39 | loss = criterion(outputs, targets)
40 | loss.backward()
41 |
42 | print('Complete Forward & Backward')
43 |
44 | for batch_idx in range(5):
45 | start_time = time.time()
46 | # torch.cuda.nvtx.range_push('trial')
47 |
48 | optimizer.step()
49 |
50 | # torch.cuda.nvtx.range_pop()
51 | print('time to optimize is %.3f' % (time.time() - start_time))
52 |
53 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 |
7 | import torchvision
8 | import torchvision.transforms as transforms
9 | import torchvision.models as models
10 |
11 | import os
12 | import time
13 |
14 | from optimizer import SGD_without_lars, SGD_with_lars, SGD_with_lars_ver2
15 | from scheduler import GradualWarmupScheduler, PolynomialLRDecay
16 | from hyperparams import Hyperparams as hp
17 | from utils import progress_bar
18 |
19 | with torch.cuda.device(hp.device[0]):
20 | all_accs = []
21 | best_acc = 0 # best test accuracy
22 | all_epochs = []
23 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
24 | all_times = []
25 | time_to_train = 0
26 |
27 | # Data
28 | print('==> Preparing data..')
29 | transform_train = transforms.Compose([
30 | transforms.RandomCrop(32, padding=4),
31 | transforms.RandomHorizontalFlip(),
32 | transforms.ToTensor(),
33 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
34 | ])
35 |
36 | transform_test = transforms.Compose([
37 | transforms.ToTensor(),
38 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
39 | ])
40 |
41 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
42 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=hp.batch_size, shuffle=True, num_workers=2)
43 |
44 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
45 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
46 |
47 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
48 |
49 | def init_weights(m):
50 | if type(m) == nn.Linear or type(m) == nn.Conv2d:
51 | torch.nn.init.kaiming_uniform_(m.weight)
52 |
53 | # Model
54 | print('==> Building model..')
55 | net = models.resnet50()
56 | net.apply(init_weights)
57 | net.cuda()
58 | net = torch.nn.DataParallel(net, device_ids=hp.device)
59 | cudnn.benchmark = True
60 |
61 | if hp.resume:
62 | # Load checkpoint.
63 | print('==> Resuming from checkpoint..')
64 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
65 | if hp.with_lars:
66 | checkpoint = torch.load('./checkpoint/withLars-' + str(hp.batch_size) + '.pth')
67 | else:
68 | checkpoint = torch.load('./checkpoint/noLars-' + str(hp.batch_size) + '.pth')
69 | net.load_state_dict(checkpoint['net'])
70 |         best_acc = checkpoint['acc'][-1]  # 'acc' is saved as an accuracy-history list
71 |         start_epoch = checkpoint['epoch'][-1]  # likewise a history list; take the latest entry
72 |         time_to_train = checkpoint['time_to_train'][-1]
73 | basic_info = checkpoint['basic_info']
74 |
75 | # Loss & Optimizer
76 | criterion = nn.CrossEntropyLoss()
77 | optimizer = None
78 | if hp.with_lars:
79 | # optimizer = SGD_with_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)
80 | optimizer = SGD_with_lars_ver2(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)
81 | else:
82 | # optimizer = SGD_without_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
83 | optimizer = optim.SGD(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
84 |
85 | warmup_scheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=hp.warmup_multiplier, total_epoch=hp.warmup_epoch)
86 | poly_decay_scheduler = PolynomialLRDecay(optimizer=optimizer, max_decay_steps=hp.max_decay_epoch * len(trainloader),
87 | end_learning_rate=hp.end_learning_rate, power=2.0) # poly(2)
88 |
89 | # Training
90 | def train(epoch):
91 | global time_to_train
92 | net.train()
93 | train_loss = 0
94 | correct = 0
95 | total = 0
96 |
97 | start_time = time.time()
98 | for batch_idx, (inputs, targets) in enumerate(trainloader):
99 |             if epoch > hp.warmup_epoch: # after warmup, step the decay scheduler every batch
100 | poly_decay_scheduler.step()
101 | inputs, targets = inputs.cuda(), targets.cuda()
102 | optimizer.zero_grad()
103 | outputs = net(inputs)
104 | loss = criterion(outputs, targets)
105 | loss.backward()
106 | optimizer.step()
107 |
108 | train_loss += loss.item()
109 | _, predicted = outputs.max(1)
110 | total += targets.size(0)
111 | correct += predicted.eq(targets).sum().item()
112 |
113 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
114 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
115 | time_to_train = time_to_train + (time.time() - start_time)
116 |
117 | def test(epoch):
118 | global best_acc
119 | net.eval()
120 | test_loss = 0
121 | correct = 0
122 | total = 0
123 | with torch.no_grad():
124 | for batch_idx, (inputs, targets) in enumerate(testloader):
125 | inputs, targets = inputs.cuda(), targets.cuda()
126 | outputs = net(inputs)
127 | loss = criterion(outputs, targets)
128 |
129 | test_loss += loss.item()
130 | _, predicted = outputs.max(1)
131 | total += targets.size(0)
132 | correct += predicted.eq(targets).sum().item()
133 |
134 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
135 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
136 |
137 | # Save checkpoint.
138 | acc = 100.*correct/total
139 | if acc > best_acc:
140 | all_accs.append(acc)
141 | all_epochs.append(epoch)
142 | all_times.append(round(time_to_train, 2))
143 | print('Saving..')
144 | state = {
145 | 'net': net.state_dict(),
146 | 'acc': all_accs,
147 | 'epoch': all_epochs,
148 | 'time_to_train': all_times,
149 | 'basic_info': hp.get_info_dict()
150 | }
151 | if not os.path.isdir('checkpoint'):
152 | os.mkdir('checkpoint')
153 | if hp.with_lars:
154 | torch.save(state, './checkpoint/withLars-' + str(hp.batch_size) + '.pth')
155 | else:
156 | torch.save(state, './checkpoint/noLars-' + str(hp.batch_size) + '.pth')
157 | best_acc = acc
158 |
159 | if hp.with_lars:
160 | print('Resnet50, data=cifar10, With LARS')
161 | else:
162 | print('Resnet50, data=cifar10, Without LARS')
163 | hp.print_hyperparms()
164 |     for epoch in range(start_epoch, hp.num_of_epoch):  # honor the resumed epoch when hp.resume is set
165 | print('\nEpoch: %d' % epoch)
166 |         if epoch <= hp.warmup_epoch: # warmup phase (separate ifs kept for readability)
167 | warmup_scheduler.step()
168 |         if epoch > hp.warmup_epoch: # after warmup, seed the decay scheduler with the warmed-up learning rate
169 | poly_decay_scheduler.base_lrs = warmup_scheduler.get_lr()
170 | for param_group in optimizer.param_groups:
171 | print('lr: ' + str(param_group['lr']))
172 | train(epoch)
173 | test(epoch)
174 |
--------------------------------------------------------------------------------
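`train.py` drives the learning rate in two phases: `GradualWarmupScheduler` ramps it for the first `warmup_epoch` epochs, after which `PolynomialLRDecay` takes over from the warmed-up value via the `base_lrs` handoff in the epoch loop. Assuming `PolynomialLRDecay` in `scheduler.py` follows the standard poly form (the `# poly(2)` comment suggests `power=2.0`), the decayed LR at step `t` looks like this sketch:

```python
def poly_lr(base_lr: float, end_lr: float, step: int,
            max_steps: int, power: float = 2.0) -> float:
    # Standard polynomial decay:
    #   lr(t) = (base_lr - end_lr) * (1 - t / max_steps) ** power + end_lr
    step = min(step, max_steps)
    return (base_lr - end_lr) * (1.0 - step / max_steps) ** power + end_lr
```

Since `max_decay_steps` is `hp.max_decay_epoch * len(trainloader)`, the decay advances once per batch, matching the `poly_decay_scheduler.step()` call inside the training loop.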
/train_with_matplot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 |
7 | import torchvision
8 | import torchvision.transforms as transforms
9 | import torchvision.models as models
10 |
11 | import os
12 | import time
13 |
14 | from optimizer import SGD_without_lars, SGD_with_lars, SGD_with_lars_ver2
15 | from scheduler import GradualWarmupScheduler, PolynomialLRDecay
16 | from hyperparams import Hyperparams as hp
17 | from utils import progress_bar
18 |
19 | import matplotlib.pyplot as plt
20 |
21 | with torch.cuda.device(hp.device[0]):
22 | all_accs = []
23 | best_acc = 0 # best test accuracy
24 | all_epochs = []
25 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
26 | all_times = []
27 | time_to_train = 0
28 |
29 | train_correct = 0
30 | train_total = 0
31 | test_correct = 0
32 | test_total = 0
33 |
34 | epochs = []
35 | train_accs = []
36 | test_accs = []
37 |
38 | # Data
39 | print('==> Preparing data..')
40 | transform_train = transforms.Compose([
41 | transforms.RandomCrop(32, padding=4),
42 | transforms.RandomHorizontalFlip(),
43 | transforms.ToTensor(),
44 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
45 | ])
46 |
47 | transform_test = transforms.Compose([
48 | transforms.ToTensor(),
49 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
50 | ])
51 |
52 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
53 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=hp.batch_size, shuffle=True, num_workers=2)
54 |
55 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
56 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
57 |
58 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
59 |
60 | def init_weights(m):
61 |         if isinstance(m, (nn.Linear, nn.Conv2d)):
62 | torch.nn.init.kaiming_uniform_(m.weight)
63 |
64 | # Model
65 | print('==> Building model..')
66 | net = models.resnet50()
67 | net.apply(init_weights)
68 | net.cuda()
69 | net = torch.nn.DataParallel(net, device_ids=hp.device)
70 | cudnn.benchmark = True
71 |
72 | if hp.resume:
73 | # Load checkpoint.
74 | print('==> Resuming from checkpoint..')
75 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
76 | if hp.with_lars:
77 | checkpoint = torch.load('./checkpoint/withLars-' + str(hp.batch_size) + '.pth')
78 | else:
79 | checkpoint = torch.load('./checkpoint/noLars-' + str(hp.batch_size) + '.pth')
80 | net.load_state_dict(checkpoint['net'])
81 |         best_acc = checkpoint['acc'][-1]  # 'acc' is saved as an accuracy-history list
82 |         start_epoch = checkpoint['epoch'][-1]  # likewise a history list; take the latest entry
83 |         time_to_train = checkpoint['time_to_train'][-1]
84 | basic_info = checkpoint['basic_info']
85 |
86 | # Loss & Optimizer
87 | criterion = nn.CrossEntropyLoss()
88 | optimizer = None
89 | if hp.with_lars:
90 | # optimizer = SGD_with_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)
91 | optimizer = SGD_with_lars_ver2(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)
92 | else:
93 | # optimizer = SGD_without_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
94 | optimizer = optim.SGD(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
95 |
96 | warmup_scheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=hp.warmup_multiplier, total_epoch=hp.warmup_epoch)
97 | poly_decay_scheduler = PolynomialLRDecay(optimizer=optimizer, max_decay_steps=hp.max_decay_epoch * len(trainloader),
98 | end_learning_rate=hp.end_learning_rate, power=2.0) # poly(2)
99 |
100 | # Training
101 | def train(epoch):
102 | global train_total
103 | global train_correct
104 | global time_to_train
105 | net.train()
106 | train_loss = 0
107 | correct = 0
108 | total = 0
109 |
110 | start_time = time.time()
111 | for batch_idx, (inputs, targets) in enumerate(trainloader):
112 |             if epoch > hp.warmup_epoch: # after warmup, step the decay scheduler every batch
113 | poly_decay_scheduler.step()
114 | inputs, targets = inputs.cuda(), targets.cuda()
115 | optimizer.zero_grad()
116 | outputs = net(inputs)
117 | loss = criterion(outputs, targets)
118 | loss.backward()
119 | optimizer.step()
120 |
121 | train_loss += loss.item()
122 | _, predicted = outputs.max(1)
123 | total += targets.size(0)
124 | correct += predicted.eq(targets).sum().item()
125 |
126 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
127 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
128 | time_to_train = time_to_train + (time.time() - start_time)
129 |
130 | train_total = total
131 | train_correct = correct
132 |
133 | def test(epoch):
134 | global best_acc
135 | global test_total
136 | global test_correct
137 | net.eval()
138 | test_loss = 0
139 | correct = 0
140 | total = 0
141 | with torch.no_grad():
142 | for batch_idx, (inputs, targets) in enumerate(testloader):
143 | inputs, targets = inputs.cuda(), targets.cuda()
144 | outputs = net(inputs)
145 | loss = criterion(outputs, targets)
146 |
147 | test_loss += loss.item()
148 | _, predicted = outputs.max(1)
149 | total += targets.size(0)
150 | correct += predicted.eq(targets).sum().item()
151 |
152 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
153 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
154 |
155 | test_total = total
156 | test_correct = correct
157 |
158 | # Save checkpoint.
159 | acc = 100.*correct/total
160 | if acc > best_acc:
161 | all_accs.append(acc)
162 | all_epochs.append(epoch)
163 | all_times.append(round(time_to_train, 2))
164 | print('Saving..')
165 | state = {
166 | 'net': net.state_dict(),
167 | 'acc': all_accs,
168 | 'epoch': all_epochs,
169 | 'time_to_train': all_times,
170 | 'basic_info': hp.get_info_dict()
171 | }
172 | if not os.path.isdir('checkpoint'):
173 | os.mkdir('checkpoint')
174 | if hp.with_lars:
175 | torch.save(state, './checkpoint/withLars-' + str(hp.batch_size) + '.pth')
176 | else:
177 | torch.save(state, './checkpoint/noLars-' + str(hp.batch_size) + '.pth')
178 | best_acc = acc
179 |
180 | if hp.with_lars:
181 | print('Resnet50, data=cifar10, With LARS')
182 | else:
183 | print('Resnet50, data=cifar10, Without LARS')
184 | hp.print_hyperparms()
185 |     for epoch in range(start_epoch, hp.num_of_epoch):  # honor the resumed epoch when hp.resume is set
186 | print('\nEpoch: %d' % epoch)
187 |         if epoch <= hp.warmup_epoch: # warmup phase (separate ifs kept for readability)
188 | warmup_scheduler.step()
189 |         if epoch > hp.warmup_epoch: # after warmup, seed the decay scheduler with the warmed-up learning rate
190 | poly_decay_scheduler.base_lrs = warmup_scheduler.get_lr()
191 | for param_group in optimizer.param_groups:
192 | print('lr: ' + str(param_group['lr']))
193 | train(epoch)
194 | test(epoch)
195 |
196 | epochs.append(epoch)
197 | train_accs.append(100.*train_correct/train_total)
198 | test_accs.append(100.*test_correct/test_total)
199 |
200 | plt.plot(epochs, train_accs, epochs, test_accs, 'r-')
201 | state = { 'test_acc': test_accs }
202 |
203 | if not os.path.isdir('result_fig'):
204 | os.mkdir('result_fig')
205 |
206 | if hp.with_lars:
207 | plt.title('Resnet50, data=cifar10, With LARS, batch_size: ' + str(hp.batch_size))
208 | plt.savefig('./result_fig/withLars-' + str(hp.batch_size) + '.jpg')
209 | torch.save(state, './result_fig/withLars-' + str(hp.batch_size) + '.pth')
210 | else:
211 | plt.title('Resnet50, data=cifar10, Without LARS, batch_size: ' + str(hp.batch_size))
212 | plt.savefig('./result_fig/noLars-' + str(hp.batch_size) + '.jpg')
213 | torch.save(state, './result_fig/noLars-' + str(hp.batch_size) + '.pth')
214 |
215 | plt.gcf().clear()
216 |
--------------------------------------------------------------------------------
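Besides the checkpoints, `train_with_matplot.py` saves a small `{'test_acc': [...]}` dict next to each figure, so the per-epoch accuracy history in the `result_fig*` folders can be inspected without rebuilding the model. A minimal sketch (the path is illustrative):

```python
import torch

# train_with_matplot.py stores {'test_acc': [...]} per run; pick any saved file.
state = torch.load('./result_fig/withLars-1024.pth')
for epoch, acc in enumerate(state['test_acc']):
    print('epoch %d: test acc %.2f%%' % (epoch, acc))
```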
/utils.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std of a dataset.
3 | - init_params: net parameter initialization.
4 | - progress_bar: progress bar that mimics xlua.progress.
5 | '''
6 | import os
7 | import sys
8 | import time
9 | import math
10 | import shutil  # used below for a TTY-safe terminal width
11 |
12 | import torch  # needed by get_mean_and_std
13 | import torch.nn as nn
14 | import torch.nn.init as init
15 | def get_mean_and_std(dataset):
16 | '''Compute the mean and std value of dataset.'''
17 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
18 | mean = torch.zeros(3)
19 | std = torch.zeros(3)
20 | print('==> Computing mean and std..')
21 | for inputs, targets in dataloader:
22 | for i in range(3):
23 | mean[i] += inputs[:,i,:,:].mean()
24 | std[i] += inputs[:,i,:,:].std()
25 | mean.div_(len(dataset))
26 | std.div_(len(dataset))
27 | return mean, std
28 |
29 | def init_params(net):
30 | '''Init layer parameters.'''
31 | for m in net.modules():
32 |         if isinstance(m, nn.Conv2d):
33 |             init.kaiming_normal_(m.weight, mode='fan_out')
34 |             if m.bias is not None:  # truth-testing a tensor raises an error
35 |                 init.constant_(m.bias, 0)
36 |         elif isinstance(m, nn.BatchNorm2d):
37 |             init.constant_(m.weight, 1)
38 |             init.constant_(m.bias, 0)
39 |         elif isinstance(m, nn.Linear):
40 |             init.normal_(m.weight, std=1e-3)
41 |             if m.bias is not None:
42 |                 init.constant_(m.bias, 0)
43 |
44 |
45 | # `stty size` fails when stdout is not a terminal; shutil handles that case
46 | term_width = shutil.get_terminal_size(fallback=(80, 24)).columns
47 |
48 | TOTAL_BAR_LENGTH = 65.
49 | last_time = time.time()
50 | begin_time = last_time
51 | def progress_bar(current, total, msg=None):
52 | global last_time, begin_time
53 | if current == 0:
54 | begin_time = time.time() # Reset for new bar.
55 |
56 | cur_len = int(TOTAL_BAR_LENGTH*current/total)
57 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
58 |
59 | sys.stdout.write(' [')
60 | for i in range(cur_len):
61 | sys.stdout.write('=')
62 | sys.stdout.write('>')
63 | for i in range(rest_len):
64 | sys.stdout.write('.')
65 | sys.stdout.write(']')
66 |
67 | cur_time = time.time()
68 | step_time = cur_time - last_time
69 | last_time = cur_time
70 | tot_time = cur_time - begin_time
71 |
72 | L = []
73 | L.append(' Step: %s' % format_time(step_time))
74 | L.append(' | Tot: %s' % format_time(tot_time))
75 | if msg:
76 | L.append(' | ' + msg)
77 |
78 | msg = ''.join(L)
79 | sys.stdout.write(msg)
80 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
81 | sys.stdout.write(' ')
82 |
83 | # Go back to the center of the bar.
84 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
85 | sys.stdout.write('\b')
86 | sys.stdout.write(' %d/%d ' % (current+1, total))
87 |
88 | if current < total-1:
89 | sys.stdout.write('\r')
90 | else:
91 | sys.stdout.write('\n')
92 | sys.stdout.flush()
93 |
94 | def format_time(seconds):
95 | days = int(seconds / 3600/24)
96 | seconds = seconds - days*3600*24
97 | hours = int(seconds / 3600)
98 | seconds = seconds - hours*3600
99 | minutes = int(seconds / 60)
100 | seconds = seconds - minutes*60
101 | secondsf = int(seconds)
102 | seconds = seconds - secondsf
103 | millis = int(seconds*1000)
104 |
105 | f = ''
106 | i = 1
107 | if days > 0:
108 | f += str(days) + 'D'
109 | i += 1
110 | if hours > 0 and i <= 2:
111 | f += str(hours) + 'h'
112 | i += 1
113 | if minutes > 0 and i <= 2:
114 | f += str(minutes) + 'm'
115 | i += 1
116 | if secondsf > 0 and i <= 2:
117 | f += str(secondsf) + 's'
118 | i += 1
119 | if millis > 0 and i <= 2:
120 | f += str(millis) + 'ms'
121 | i += 1
122 | if f == '':
123 | f = '0ms'
124 | return f
--------------------------------------------------------------------------------
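The normalization constants hard-coded in `train.py` and `val.py` can be reproduced with `get_mean_and_std` above. A hedged usage sketch follows; note that with `batch_size=1` the function averages per-image statistics, which is how std values near (0.2023, 0.1994, 0.2010) arise rather than the whole-dataset std.

```python
import torchvision
import torchvision.transforms as transforms

from utils import get_mean_and_std

# Plain ToTensor so the statistics are computed on raw [0, 1] pixels.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True,
                                        transform=transforms.ToTensor())
mean, std = get_mean_and_std(trainset)
print(mean)  # expected near (0.4914, 0.4822, 0.4465)
print(std)   # expected near (0.2023, 0.1994, 0.2010)
```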
/val.py:
--------------------------------------------------------------------------------
1 | '''Evaluate CIFAR10 checkpoints with PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import torch.backends.cudnn as cudnn
7 |
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | import torchvision.models as models
11 |
12 | import os
13 | import argparse
14 |
15 | from hyperparams import Hyperparams_for_val as hp
16 | from utils import progress_bar
17 |
18 | with torch.cuda.device(hp.device[0]):
19 | # Data
20 | print('==> Preparing data..')
21 | transform_test = transforms.Compose([
22 | transforms.ToTensor(),
23 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
24 | ])
25 |
26 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
27 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
28 |
29 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
30 |
31 | # Model
32 | print('==> Building model..')
33 | net = models.resnet50()
34 | net.cuda()
35 | net = torch.nn.DataParallel(net, device_ids=hp.device)
36 | cudnn.benchmark = True
37 |
38 | # Load checkpoint.
39 | print('==> Resuming from checkpoint..')
40 | assert os.path.isdir(hp.checkpoint_folder_name), 'Error: no checkpoint directory found!'
41 | if hp.with_lars:
42 | checkpoint = torch.load('./' + hp.checkpoint_folder_name + '/withLars-' + str(hp.batch_size) + '.pth')
43 | else:
44 | checkpoint = torch.load('./' + hp.checkpoint_folder_name + '/noLars-' + str(hp.batch_size) + '.pth')
45 | net.load_state_dict(checkpoint['net'])
46 | best_acc = checkpoint['acc']
47 | epoch = checkpoint['epoch']
48 |     time_to_train = checkpoint['time_to_train'] # field added in the 2nd checkpoint format
49 |     basic_info = checkpoint['basic_info'] # field added in the 3rd checkpoint format
50 |
51 | criterion = nn.CrossEntropyLoss()
52 |
53 | def test():
54 | global best_acc
55 | net.eval()
56 | test_loss = 0
57 | correct = 0
58 | total = 0
59 | with torch.no_grad():
60 | for batch_idx, (inputs, targets) in enumerate(testloader):
61 | inputs, targets = inputs.cuda(), targets.cuda()
62 | outputs = net(inputs)
63 | loss = criterion(outputs, targets)
64 |
65 | test_loss += loss.item()
66 | _, predicted = outputs.max(1)
67 | total += targets.size(0)
68 | correct += predicted.eq(targets).sum().item()
69 |
70 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
71 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
72 |
73 |
74 | if hp.with_lars:
75 | print('Resnet50, data=cifar10, With LARS, Validation')
76 | else:
77 | print('Resnet50, data=cifar10, Without LARS, Validation')
78 | print('basic_info=' + str(basic_info))
79 |
80 |     for epo, acc, sec in zip(epoch, best_acc, time_to_train):
81 |         print(str(epo) + ' epoch | ' + str(acc) + ' % | ' + str(sec) + ' sec')
82 |
83 | test()
84 |
85 |
86 |
--------------------------------------------------------------------------------