├── .gitignore
├── README.md
├── chart.xlsx
├── chart_pretty.xlsx
├── draw_chart.py
├── log_original.json
├── log_prune.json
├── log_prune_layer.json
├── main.py
├── model_refactor.py
├── models
│   ├── __init__.py
│   ├── densenet.py
│   ├── dpn.py
│   ├── googlenet.py
│   ├── lenet.py
│   ├── mobilenet.py
│   ├── mobilenetv2.py
│   ├── pnasnet.py
│   ├── preact_resnet.py
│   ├── resnet.py
│   ├── resnext.py
│   ├── senet.py
│   ├── shufflenet.py
│   └── vgg.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | *.gz
3 | *.meta
4 | .idea/
5 | *.xml
6 | *.iml
7 | *.pyc
8 | data/
9 | train/
10 | checkpoint/
11 | checkpoint_backup/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Train & Pruning with PyTorch
2 | by hou-yz, based on `kuangliu/pytorch-cifar`
3 | 
4 | improves inference speed and reduces intermediate feature sizes to favor distributed inference (the local device computes the first part of the model and uploads the intermediate feature to a stronger device or the cloud for the rest of the computation); see the split-point sketch at the end of this readme.
5 | 
6 | - pruning stage-1: prune the whole model to increase inference speed and slightly reduce intermediate feature sizes.
7 | 
8 | - pruning stage-2: (based on the stage-1 model) for each split-point (where the intermediate feature is transferred to another device for further computation), specifically prune the layer just before the split-point to reduce the intermediate feature size even further.
9 | 
10 | only supports python 3 with pytorch > 0.3.1;
11 | models are trained on cifar-10, and only vgg-16 has been tested.
12 | 
13 | auto-logging and automatic chart drawing are also included.
14 | 
15 | ## usage
16 | ### training:
17 | ```bash
18 | python main.py --train # train from scratch
19 | python main.py --resume # resume training
20 | ```
21 | 
22 | ### 2-stage pruning:
23 | 
24 | first, in stage-1, you can prune the whole model with
25 | ```bash
26 | python main.py --prune # prune the whole model
27 | ```
28 | 
29 | once you have finished stage-1, you can then prune each layer individually (stage-2) for the minimum bandwidth requirement with
30 | ```bash
31 | python main.py --prune_layer # prune layers and save models separately
32 | ```
33 | 
34 | ### chart drawing:
35 | 
36 | for logging and excel chart drawing, run
37 | ```bash
38 | python main.py --test_pruned # test the pruned models and save *.json logs
39 | python draw_chart.py
40 | ```
41 | which automatically generates the `chart.xlsx` file.
42 | 
43 | 
44 | ## updates
45 | - added pruning features;
46 | - added the 2-stage pruning method: `--prune` & `--prune_layer`;
47 | - added chart drawing via `draw_chart.py` with `openpyxl` (open in excel);
48 | - added cpu-only support and windows support.
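49 | 
50 | ### split-point sketch
51 | 
52 | as a rough illustration of the split-point idea — a minimal sketch in modern pytorch with a toy layer stack (the real model lives in `models/vgg.py`; the layers below are made up): the weak device runs everything before the split-point and uploads the resulting feature map, whose element count (channels × height × width) is exactly what the `bandwidth` fields in the `log_*.json` logs record.
53 | 
54 | ```python
55 | import torch
56 | import torch.nn as nn
57 | 
58 | # toy VGG-style feature extractor (layer sizes here are illustrative only)
59 | features = nn.Sequential(
60 |     nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
61 |     nn.MaxPool2d(2),
62 |     nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
63 |     nn.MaxPool2d(2),
64 | )
65 | 
66 | split_point = 3  # run modules [0, split_point) locally, the rest remotely
67 | local_part = nn.Sequential(*list(features.children())[:split_point])
68 | remote_part = nn.Sequential(*list(features.children())[split_point:])
69 | 
70 | x = torch.randn(1, 3, 32, 32)  # one CIFAR-10-sized input
71 | feature = local_part(x)        # this tensor is what gets uploaded
72 | print(feature.numel())         # 64 * 16 * 16 = 16384 elements to transfer
73 | out = remote_part(feature)     # finished on the stronger device
74 | ```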
--------------------------------------------------------------------------------
/chart.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hou-yz/pytorch-pruning-2step/484bda49c785e483973a225168ff09206241975c/chart.xlsx

--------------------------------------------------------------------------------
/chart_pretty.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hou-yz/pytorch-pruning-2step/484bda49c785e483973a225168ff09206241975c/chart_pretty.xlsx

--------------------------------------------------------------------------------
/draw_chart.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | from openpyxl import Workbook
4 | from openpyxl.chart import (
5 |     LineChart,
6 |     BarChart,
7 |     Reference,
8 |     Series,
9 | )
10 | 
11 | with open('./log_original.json') as fp:
12 |     original_data = json.load(fp)
13 | with open('./log_prune.json') as fp:
14 |     prune_data = json.load(fp)
15 | with open('./log_prune_layer.json') as fp:
16 |     prune_layer_data = json.load(fp)
17 | 
18 | bandwidth0 = np.power(32, 2) * 3  # raw 32x32 rgb input: 3072 elements
19 | 
20 | ###############################
21 | # original
22 | ###############################
23 | 
24 | layer_cfg = ['', '']
25 | acc = ['acc_original', '']
26 | # delta_t = ['time_original', 0]
27 | delta_ts = 0
28 | # delta_t_computations = ['', '']
29 | bandwidth = ['bandwidth_original', bandwidth0]
30 | # all_conv_computations = ['', '']
31 | for data in original_data:
32 |     layer_cfg.append(data['layer_cfg'])
33 |     acc.append(data['acc'])
34 |     # delta_t.append(data['delta_t'])
35 |     delta_ts += np.array(data['delta_ts'])
36 |     # delta_t_computations.append(data['delta_t_computations'])
37 |     bandwidth.append(data['bandwidth'])
38 |     # all_conv_computations.append(data['all_conv_computations'])
39 | delta_ts = ['time_original', 0] + list(delta_ts / len(original_data))
40 | 
41 | wb = Workbook()
42 | # ws = wb.create_sheet("original", 0)
43 | ws = wb.active
44 | ws.append(layer_cfg)
45 | ws.append(acc)
46 | # ws.append(delta_t)
47 | ws.append(delta_ts)
48 | # ws.append(delta_t_computations)
49 | ws.append(bandwidth)
50 | # ws.append(all_conv_computations)
51 | 
52 | ###############################
53 | # prune
54 | ###############################
55 | 
56 | layer_cfg = ['', '']
57 | acc = ['acc_prune_s1', '']
58 | # delta_t = ['time_prune', 0]
59 | delta_ts = 0
60 | # delta_t_computations = ['', '']
61 | bandwidth = ['bandwidth_prune_s1', bandwidth0]
62 | # all_conv_computations = ['', '']
63 | for data in prune_data:
64 |     layer_cfg.append(data['layer_cfg'])
65 |     acc.append(data['acc'])
66 |     # delta_t.append(data['delta_t'])
67 |     delta_ts += np.array(data['delta_ts'])
68 |     # delta_t_computations.append(data['delta_t_computations'])
69 |     bandwidth.append(data['bandwidth'])
70 |     # all_conv_computations.append(data['all_conv_computations'])
71 | delta_ts = ['time_prune_s1', 0] + list(delta_ts / len(prune_data))
72 | 
73 | # ws = wb.create_sheet("prune_layer", 0)
74 | for i in range(3):
75 |     ws.append(list())
76 | ws.append(layer_cfg)
77 | ws.append(acc)
78 | # ws.append(delta_t)
79 | ws.append(delta_ts)
80 | # ws.append(delta_t_computations)
81 | ws.append(bandwidth)
82 | # ws.append(all_conv_computations)
83 | 
84 | ###############################
85 | # prune_layer
86 | ###############################
87 | 
88 | layer_cfg = ['', '']
89 | acc = ['acc_prune_s2', '']
90 | # delta_t = ['time_prune_layer', 0]
91 | delta_ts = 0
92 | # delta_t_computations = ['', '']
93 | bandwidth = ['bandwidth_prune_s2', bandwidth0]
94 | # all_conv_computations = ['', '']
95 | for data in prune_layer_data:
96 |     layer_cfg.append(data['layer_cfg'])
97 |     acc.append(data['acc'])
98 |     # delta_t.append(data['delta_t'])
99 |     delta_ts += np.array(data['delta_ts'])
100 |     # delta_t_computations.append(data['delta_t_computations'])
101 |     bandwidth.append(data['bandwidth'])
102 |     # all_conv_computations.append(data['all_conv_computations'])
103 | delta_ts = ['time_prune_s2', 0] + list(delta_ts / len(prune_layer_data))  # average over the prune_layer runs
104 | 
105 | # ws = wb.create_sheet("prune_layer", 0)
106 | for i in range(3):
107 |     ws.append(list())
108 | ws.append(layer_cfg)
109 | ws.append(acc)
110 | # ws.append(delta_t)
111 | ws.append(delta_ts)
112 | # ws.append(delta_t_computations)
113 | ws.append(bandwidth)
114 | # ws.append(all_conv_computations)
115 | 
116 | for i in range(3):
117 |     ws.append(list())
118 | for index in range(len(layer_cfg)):
119 |     if isinstance(layer_cfg[index], int):
120 |         layer_cfg[index] = 'conv'
121 |     elif layer_cfg[index] == '' and index == 1:
122 |         layer_cfg[index] = 'og image'
123 |     elif layer_cfg[index] == '':
124 |         pass
125 |     else:
126 |         layer_cfg[index] = 'pool'
127 | 
128 | ws.append(layer_cfg)
129 | 
130 | # draw chart: rows 3/10/17 hold the three time series, rows 4/11/18 the bandwidths (see the ws.append calls above)
131 | time_original = Reference(ws, min_col=1, min_row=3, max_col=20)
132 | time_prune = Reference(ws, min_col=1, min_row=10, max_col=20)
133 | time_prune_layer = Reference(ws, min_col=1, min_row=17, max_col=20)
134 | bandwidth_original = Reference(ws, min_col=1, min_row=4, max_col=20)
135 | bandwidth_prune = Reference(ws, min_col=1, min_row=11, max_col=20)
136 | bandwidth_prune_layer = Reference(ws, min_col=1, min_row=18, max_col=20)
137 | # line chart for time
138 | c1 = LineChart()
139 | c1.add_data(time_original, titles_from_data=True, from_rows=True)
140 | c1.add_data(time_prune, titles_from_data=True, from_rows=True)
141 | c1.add_data(time_prune_layer, titles_from_data=True, from_rows=True)
142 | 
143 | cats = Reference(ws, min_col=2, min_row=22, max_col=20)
144 | c1.set_categories(cats)
145 | 
146 | c1.x_axis.title = 'layers'
147 | c1.y_axis.title = 'time elapsed (ms)'
148 | c1.y_axis.majorGridlines = None
149 | c1.title = 'bandwidth/time'
150 | 
151 | # bar chart for bandwidth
152 | c2 = BarChart()
153 | c2.add_data(bandwidth_original, titles_from_data=True, from_rows=True)
154 | c2.add_data(bandwidth_prune, titles_from_data=True, from_rows=True)
155 | c2.add_data(bandwidth_prune_layer, titles_from_data=True, from_rows=True)
156 | c2.y_axis.axId = 200  # secondary axis id; must differ from the primary axis id (100)
157 | c2.y_axis.title = 'data volume (B)'
158 | 
159 | # Display y-axis of the second chart on the right by setting it to cross the x-axis at its maximum
160 | c1.y_axis.crosses = "max"
161 | c1 += c2
162 | 
163 | ws.add_chart(c1, "D4")
164 | 
165 | wb.save('chart.xlsx')
166 | 

--------------------------------------------------------------------------------
/log_original.json:
--------------------------------------------------------------------------------
1 | [ 2 | { 3 | "config": [ 4 | 64, 5 | 64, 6 | "M", 7 | 128, 8 | 128, 9 | "M", 10 | 256, 11 | 256, 12 | 256, 13 | "M", 14 | 512, 15 | 512, 16 | 512, 17 | "M", 18 | 512, 19 | 512, 20 | 512, 21 | "M" 22 | ], 23 | "acc": 88.23, 24 | "index": 0, 25 | "delta_ts": [ 26 | 1.079476125, 27 | 3.954809, 28 | 4.3483299375, 29 | 6.06126546875, 30 | 8.8047388125, 31 | 9.082637125, 32 | 10.29483225, 33 | 12.85673659375, 34 | 15.64298046875, 35 | 15.7937188125, 36 | 
17.00898821875, 37 | 19.35287284375, 38 | 21.41161803125, 39 | 21.492163375, 40 | 23.1957933125, 41 | 25.04465759375, 42 | 26.93542715625, 43 | 26.95500021875 44 | ], 45 | "delta_t_prof": 1.079476125, 46 | "layer_cfg": 64, 47 | "bandwidths": [ 48 | 65536, 49 | 65536, 50 | 16384, 51 | 32768, 52 | 32768, 53 | 8192, 54 | 16384, 55 | 16384, 56 | 16384, 57 | 4096, 58 | 8192, 59 | 8192, 60 | 8192, 61 | 2048, 62 | 2048, 63 | 2048, 64 | 2048, 65 | 512 66 | ], 67 | "bandwidth": 65536 68 | }, 69 | { 70 | "config": [ 71 | 64, 72 | 64, 73 | "M", 74 | 128, 75 | 128, 76 | "M", 77 | 256, 78 | 256, 79 | 256, 80 | "M", 81 | 512, 82 | 512, 83 | 512, 84 | "M", 85 | 512, 86 | 512, 87 | 512, 88 | "M" 89 | ], 90 | "acc": 88.23, 91 | "index": 1, 92 | "delta_ts": [ 93 | 0.36192209375, 94 | 3.3214393125, 95 | 3.7927945625, 96 | 4.952026375, 97 | 7.0860606875, 98 | 7.3639601875, 99 | 8.4734473125, 100 | 10.905641625, 101 | 13.332847125, 102 | 13.46183965625, 103 | 14.76172778125, 104 | 16.96781453125, 105 | 19.3109675, 106 | 19.3933250625, 107 | 21.238050375, 108 | 22.8482285, 109 | 24.63402640625, 110 | 24.6562094375 111 | ], 112 | "delta_t_prof": 3.3214393125, 113 | "layer_cfg": 64, 114 | "bandwidths": [ 115 | 65536, 116 | 65536, 117 | 16384, 118 | 32768, 119 | 32768, 120 | 8192, 121 | 16384, 122 | 16384, 123 | 16384, 124 | 4096, 125 | 8192, 126 | 8192, 127 | 8192, 128 | 2048, 129 | 2048, 130 | 2048, 131 | 2048, 132 | 512 133 | ], 134 | "bandwidth": 65536 135 | }, 136 | { 137 | "config": [ 138 | 64, 139 | 64, 140 | "M", 141 | 128, 142 | 128, 143 | "M", 144 | 256, 145 | 256, 146 | 256, 147 | "M", 148 | 512, 149 | 512, 150 | 512, 151 | "M", 152 | 512, 153 | 512, 154 | 512, 155 | "M" 156 | ], 157 | "acc": 88.23, 158 | "index": 2, 159 | "delta_ts": [ 160 | 0.4172368125, 161 | 3.37490240625, 162 | 3.87276778125, 163 | 5.0941424375, 164 | 7.62177146875, 165 | 7.871181875, 166 | 9.0241135625, 167 | 11.359977625, 168 | 13.57742009375, 169 | 13.7394043125, 170 | 14.90913103125, 171 | 17.14248490625, 172 | 19.30123565625, 173 | 19.40126084375, 174 | 21.280373125, 175 | 23.21622325, 176 | 25.0704575625, 177 | 25.093628375 178 | ], 179 | "delta_t_prof": 3.87276778125, 180 | "layer_cfg": "M", 181 | "bandwidths": [ 182 | 65536, 183 | 65536, 184 | 16384, 185 | 32768, 186 | 32768, 187 | 8192, 188 | 16384, 189 | 16384, 190 | 16384, 191 | 4096, 192 | 8192, 193 | 8192, 194 | 8192, 195 | 2048, 196 | 2048, 197 | 2048, 198 | 2048, 199 | 512 200 | ], 201 | "bandwidth": 16384 202 | }, 203 | { 204 | "config": [ 205 | 64, 206 | 64, 207 | "M", 208 | 128, 209 | 128, 210 | "M", 211 | 256, 212 | 256, 213 | 256, 214 | "M", 215 | 512, 216 | 512, 217 | 512, 218 | "M", 219 | 512, 220 | 512, 221 | 512, 222 | "M" 223 | ], 224 | "acc": 88.23, 225 | "index": 3, 226 | "delta_ts": [ 227 | 0.26607765625, 228 | 3.002022625, 229 | 3.50503571875, 230 | 4.8123344375, 231 | 7.2982241875, 232 | 7.56252884375, 233 | 8.6598173125, 234 | 10.6310014375, 235 | 12.76463640625, 236 | 12.92665621875, 237 | 14.0271485, 238 | 16.0472400625, 239 | 18.1458263125, 240 | 18.22918003125, 241 | 20.1045819375, 242 | 22.0091935625, 243 | 23.89325046875, 244 | 23.91619309375 245 | ], 246 | "delta_t_prof": 4.8123344375, 247 | "layer_cfg": 128, 248 | "bandwidths": [ 249 | 65536, 250 | 65536, 251 | 16384, 252 | 32768, 253 | 32768, 254 | 8192, 255 | 16384, 256 | 16384, 257 | 16384, 258 | 4096, 259 | 8192, 260 | 8192, 261 | 8192, 262 | 2048, 263 | 2048, 264 | 2048, 265 | 2048, 266 | 512 267 | ], 268 | "bandwidth": 32768 269 | }, 270 | { 271 | "config": [ 272 | 64, 273 | 64, 274 | 
"M", 275 | 128, 276 | 128, 277 | "M", 278 | 256, 279 | 256, 280 | 256, 281 | "M", 282 | 512, 283 | 512, 284 | 512, 285 | "M", 286 | 512, 287 | 512, 288 | 512, 289 | "M" 290 | ], 291 | "acc": 88.23, 292 | "index": 4, 293 | "delta_ts": [ 294 | 0.24667825, 295 | 3.482759375, 296 | 3.99284359375, 297 | 5.2927899375, 298 | 7.77811003125, 299 | 8.04561675, 300 | 9.119533625, 301 | 11.3275408125, 302 | 13.55115334375, 303 | 13.68344796875, 304 | 14.8595066875, 305 | 17.15848171875, 306 | 19.5909834375, 307 | 19.6812113125, 308 | 21.482073875, 309 | 23.302907875, 310 | 25.15411528125, 311 | 25.1735260625 312 | ], 313 | "delta_t_prof": 7.77811003125, 314 | "layer_cfg": 128, 315 | "bandwidths": [ 316 | 65536, 317 | 65536, 318 | 16384, 319 | 32768, 320 | 32768, 321 | 8192, 322 | 16384, 323 | 16384, 324 | 16384, 325 | 4096, 326 | 8192, 327 | 8192, 328 | 8192, 329 | 2048, 330 | 2048, 331 | 2048, 332 | 2048, 333 | 512 334 | ], 335 | "bandwidth": 32768 336 | }, 337 | { 338 | "config": [ 339 | 64, 340 | 64, 341 | "M", 342 | 128, 343 | 128, 344 | "M", 345 | 256, 346 | 256, 347 | 256, 348 | "M", 349 | 512, 350 | 512, 351 | 512, 352 | "M", 353 | 512, 354 | 512, 355 | 512, 356 | "M" 357 | ], 358 | "acc": 88.23, 359 | "index": 5, 360 | "delta_ts": [ 361 | 0.211049625, 362 | 3.150714625, 363 | 3.65589796875, 364 | 5.14470846875, 365 | 7.8652533125, 366 | 8.107731625, 367 | 9.42291384375, 368 | 11.74241421875, 369 | 14.1808945, 370 | 14.33600871875, 371 | 15.5394396875, 372 | 17.96579040625, 373 | 20.4020926875, 374 | 20.47760275, 375 | 22.441936375, 376 | 24.30585921875, 377 | 26.1940496875, 378 | 26.21521890625 379 | ], 380 | "delta_t_prof": 8.107731625, 381 | "layer_cfg": "M", 382 | "bandwidths": [ 383 | 65536, 384 | 65536, 385 | 16384, 386 | 32768, 387 | 32768, 388 | 8192, 389 | 16384, 390 | 16384, 391 | 16384, 392 | 4096, 393 | 8192, 394 | 8192, 395 | 8192, 396 | 2048, 397 | 2048, 398 | 2048, 399 | 2048, 400 | 512 401 | ], 402 | "bandwidth": 8192 403 | }, 404 | { 405 | "config": [ 406 | 64, 407 | 64, 408 | "M", 409 | 128, 410 | 128, 411 | "M", 412 | 256, 413 | 256, 414 | 256, 415 | "M", 416 | 512, 417 | 512, 418 | 512, 419 | "M", 420 | 512, 421 | 512, 422 | 512, 423 | "M" 424 | ], 425 | "acc": 88.23, 426 | "index": 6, 427 | "delta_ts": [ 428 | 0.209134875, 429 | 3.0037641875, 430 | 3.340062375, 431 | 4.57748253125, 432 | 7.18409109375, 433 | 7.45105746875, 434 | 8.76287046875, 435 | 11.2532826875, 436 | 14.1825244375, 437 | 14.3348414375, 438 | 15.7962983125, 439 | 18.7019816875, 440 | 21.16464834375, 441 | 21.238951, 442 | 23.2121635625, 443 | 25.198928125, 444 | 27.14013090625, 445 | 27.15969615625 446 | ], 447 | "delta_t_prof": 8.76287046875, 448 | "layer_cfg": 256, 449 | "bandwidths": [ 450 | 65536, 451 | 65536, 452 | 16384, 453 | 32768, 454 | 32768, 455 | 8192, 456 | 16384, 457 | 16384, 458 | 16384, 459 | 4096, 460 | 8192, 461 | 8192, 462 | 8192, 463 | 2048, 464 | 2048, 465 | 2048, 466 | 2048, 467 | 512 468 | ], 469 | "bandwidth": 16384 470 | }, 471 | { 472 | "config": [ 473 | 64, 474 | 64, 475 | "M", 476 | 128, 477 | 128, 478 | "M", 479 | 256, 480 | 256, 481 | 256, 482 | "M", 483 | 512, 484 | 512, 485 | 512, 486 | "M", 487 | 512, 488 | 512, 489 | 512, 490 | "M" 491 | ], 492 | "acc": 88.23, 493 | "index": 7, 494 | "delta_ts": [ 495 | 0.26585403125, 496 | 3.13119534375, 497 | 3.50398034375, 498 | 4.742295625, 499 | 6.90907025, 500 | 7.1549385625, 501 | 8.42726553125, 502 | 10.849104375, 503 | 13.4641250625, 504 | 13.590415875, 505 | 14.82141740625, 506 | 17.1564373125, 507 | 19.47245825, 508 | 
19.5528114375, 509 | 21.55806303125, 510 | 23.70490071875, 511 | 25.6725584375, 512 | 25.69850521875 513 | ], 514 | "delta_t_prof": 10.849104375, 515 | "layer_cfg": 256, 516 | "bandwidths": [ 517 | 65536, 518 | 65536, 519 | 16384, 520 | 32768, 521 | 32768, 522 | 8192, 523 | 16384, 524 | 16384, 525 | 16384, 526 | 4096, 527 | 8192, 528 | 8192, 529 | 8192, 530 | 2048, 531 | 2048, 532 | 2048, 533 | 2048, 534 | 512 535 | ], 536 | "bandwidth": 16384 537 | }, 538 | { 539 | "config": [ 540 | 64, 541 | 64, 542 | "M", 543 | 128, 544 | 128, 545 | "M", 546 | 256, 547 | 256, 548 | 256, 549 | "M", 550 | 512, 551 | 512, 552 | 512, 553 | "M", 554 | 512, 555 | 512, 556 | 512, 557 | "M" 558 | ], 559 | "acc": 88.23, 560 | "index": 8, 561 | "delta_ts": [ 562 | 0.21252484375, 563 | 3.46893690625, 564 | 3.91000121875, 565 | 5.51080971875, 566 | 8.63835259375, 567 | 8.93069625, 568 | 10.3580884375, 569 | 13.03069003125, 570 | 15.69355284375, 571 | 15.8586594375, 572 | 17.2061185, 573 | 19.334315125, 574 | 21.81936725, 575 | 21.883176625, 576 | 23.84942834375, 577 | 25.8906224375, 578 | 27.7998233125, 579 | 27.82133934375 580 | ], 581 | "delta_t_prof": 15.69355284375, 582 | "layer_cfg": 256, 583 | "bandwidths": [ 584 | 65536, 585 | 65536, 586 | 16384, 587 | 32768, 588 | 32768, 589 | 8192, 590 | 16384, 591 | 16384, 592 | 16384, 593 | 4096, 594 | 8192, 595 | 8192, 596 | 8192, 597 | 2048, 598 | 2048, 599 | 2048, 600 | 2048, 601 | 512 602 | ], 603 | "bandwidth": 16384 604 | }, 605 | { 606 | "config": [ 607 | 64, 608 | 64, 609 | "M", 610 | 128, 611 | 128, 612 | "M", 613 | 256, 614 | 256, 615 | 256, 616 | "M", 617 | 512, 618 | 512, 619 | 512, 620 | "M", 621 | 512, 622 | 512, 623 | 512, 624 | "M" 625 | ], 626 | "acc": 88.23, 627 | "index": 9, 628 | "delta_ts": [ 629 | 0.2119936875, 630 | 2.79953890625, 631 | 3.3059246875, 632 | 4.538957, 633 | 7.19269221875, 634 | 7.44203846875, 635 | 8.6525339375, 636 | 10.99116878125, 637 | 13.27812603125, 638 | 13.4389594375, 639 | 14.62817509375, 640 | 17.089379625, 641 | 19.76650071875, 642 | 19.85676746875, 643 | 21.86523753125, 644 | 23.69818825, 645 | 25.6130354375, 646 | 25.63380275 647 | ], 648 | "delta_t_prof": 13.4389594375, 649 | "layer_cfg": "M", 650 | "bandwidths": [ 651 | 65536, 652 | 65536, 653 | 16384, 654 | 32768, 655 | 32768, 656 | 8192, 657 | 16384, 658 | 16384, 659 | 16384, 660 | 4096, 661 | 8192, 662 | 8192, 663 | 8192, 664 | 2048, 665 | 2048, 666 | 2048, 667 | 2048, 668 | 512 669 | ], 670 | "bandwidth": 4096 671 | }, 672 | { 673 | "config": [ 674 | 64, 675 | 64, 676 | "M", 677 | 128, 678 | 128, 679 | "M", 680 | 256, 681 | 256, 682 | 256, 683 | "M", 684 | 512, 685 | 512, 686 | 512, 687 | "M", 688 | 512, 689 | 512, 690 | 512, 691 | "M" 692 | ], 693 | "acc": 88.23, 694 | "index": 10, 695 | "delta_ts": [ 696 | 0.211794625, 697 | 3.12214246875, 698 | 3.58992553125, 699 | 4.9420725625, 700 | 7.2513660625, 701 | 7.41544446875, 702 | 8.33462684375, 703 | 10.0990430625, 704 | 12.030658, 705 | 12.189252625, 706 | 13.184183625, 707 | 15.102331625, 708 | 17.36350925, 709 | 17.42069978125, 710 | 19.094894875, 711 | 20.86794215625, 712 | 22.43994921875, 713 | 22.4638346875 714 | ], 715 | "delta_t_prof": 13.184183625, 716 | "layer_cfg": 512, 717 | "bandwidths": [ 718 | 65536, 719 | 65536, 720 | 16384, 721 | 32768, 722 | 32768, 723 | 8192, 724 | 16384, 725 | 16384, 726 | 16384, 727 | 4096, 728 | 8192, 729 | 8192, 730 | 8192, 731 | 2048, 732 | 2048, 733 | 2048, 734 | 2048, 735 | 512 736 | ], 737 | "bandwidth": 8192 738 | }, 739 | { 740 | "config": [ 741 | 64, 742 | 64, 743 | "M", 
744 | 128, 745 | 128, 746 | "M", 747 | 256, 748 | 256, 749 | 256, 750 | "M", 751 | 512, 752 | 512, 753 | 512, 754 | "M", 755 | 512, 756 | 512, 757 | 512, 758 | "M" 759 | ], 760 | "acc": 88.23, 761 | "index": 11, 762 | "delta_ts": [ 763 | 0.2622781875, 764 | 3.14403575, 765 | 3.4121335, 766 | 4.69494596875, 767 | 7.216147, 768 | 7.4365556875, 769 | 8.59483484375, 770 | 10.683376625, 771 | 13.1957175, 772 | 13.35641490625, 773 | 14.566357625, 774 | 17.2731430625, 775 | 19.90783203125, 776 | 19.98781975, 777 | 22.001437625, 778 | 24.0147763125, 779 | 25.93400353125, 780 | 25.95686640625 781 | ], 782 | "delta_t_prof": 17.2731430625, 783 | "layer_cfg": 512, 784 | "bandwidths": [ 785 | 65536, 786 | 65536, 787 | 16384, 788 | 32768, 789 | 32768, 790 | 8192, 791 | 16384, 792 | 16384, 793 | 16384, 794 | 4096, 795 | 8192, 796 | 8192, 797 | 8192, 798 | 2048, 799 | 2048, 800 | 2048, 801 | 2048, 802 | 512 803 | ], 804 | "bandwidth": 8192 805 | }, 806 | { 807 | "config": [ 808 | 64, 809 | 64, 810 | "M", 811 | 128, 812 | 128, 813 | "M", 814 | 256, 815 | 256, 816 | 256, 817 | "M", 818 | 512, 819 | 512, 820 | 512, 821 | "M", 822 | 512, 823 | 512, 824 | 512, 825 | "M" 826 | ], 827 | "acc": 88.23, 828 | "index": 12, 829 | "delta_ts": [ 830 | 0.2392026875, 831 | 3.02808634375, 832 | 3.5384660625, 833 | 4.9235505, 834 | 7.03007840625, 835 | 7.2265114375, 836 | 8.19167478125, 837 | 9.92577740625, 838 | 12.083673125, 839 | 12.2387815, 840 | 13.23973815625, 841 | 15.1574694375, 842 | 17.40155915625, 843 | 17.4836526875, 844 | 19.1626485625, 845 | 21.19844328125, 846 | 23.25766775, 847 | 23.273987625 848 | ], 849 | "delta_t_prof": 17.40155915625, 850 | "layer_cfg": 512, 851 | "bandwidths": [ 852 | 65536, 853 | 65536, 854 | 16384, 855 | 32768, 856 | 32768, 857 | 8192, 858 | 16384, 859 | 16384, 860 | 16384, 861 | 4096, 862 | 8192, 863 | 8192, 864 | 8192, 865 | 2048, 866 | 2048, 867 | 2048, 868 | 2048, 869 | 512 870 | ], 871 | "bandwidth": 8192 872 | }, 873 | { 874 | "config": [ 875 | 64, 876 | 64, 877 | "M", 878 | 128, 879 | 128, 880 | "M", 881 | 256, 882 | 256, 883 | 256, 884 | "M", 885 | 512, 886 | 512, 887 | 512, 888 | "M", 889 | 512, 890 | 512, 891 | 512, 892 | "M" 893 | ], 894 | "acc": 88.23, 895 | "index": 13, 896 | "delta_ts": [ 897 | 0.38205603125, 898 | 3.1368751875, 899 | 3.401371875, 900 | 4.68972234375, 901 | 7.14045325, 902 | 7.4064506875, 903 | 8.62990471875, 904 | 11.142618375, 905 | 13.46417459375, 906 | 13.60499290625, 907 | 14.738694625, 908 | 17.027095625, 909 | 19.48182128125, 910 | 19.57386471875, 911 | 21.85103915625, 912 | 23.81573528125, 913 | 25.63742690625, 914 | 25.65683725 915 | ], 916 | "delta_t_prof": 19.57386471875, 917 | "layer_cfg": "M", 918 | "bandwidths": [ 919 | 65536, 920 | 65536, 921 | 16384, 922 | 32768, 923 | 32768, 924 | 8192, 925 | 16384, 926 | 16384, 927 | 16384, 928 | 4096, 929 | 8192, 930 | 8192, 931 | 8192, 932 | 2048, 933 | 2048, 934 | 2048, 935 | 2048, 936 | 512 937 | ], 938 | "bandwidth": 2048 939 | }, 940 | { 941 | "config": [ 942 | 64, 943 | 64, 944 | "M", 945 | 128, 946 | 128, 947 | "M", 948 | 256, 949 | 256, 950 | 256, 951 | "M", 952 | 512, 953 | 512, 954 | 512, 955 | "M", 956 | 512, 957 | 512, 958 | 512, 959 | "M" 960 | ], 961 | "acc": 88.23, 962 | "index": 14, 963 | "delta_ts": [ 964 | 0.22899646875, 965 | 3.07934540625, 966 | 3.51239440625, 967 | 4.97563734375, 968 | 7.1877573125, 969 | 7.34565009375, 970 | 8.43095528125, 971 | 10.16363340625, 972 | 11.86939834375, 973 | 11.9588823125, 974 | 12.9962240625, 975 | 14.87955746875, 976 | 17.45216759375, 977 | 
17.5048839375, 978 | 19.61883196875, 979 | 21.30612025, 980 | 22.8939265625, 981 | 22.915834 982 | ], 983 | "delta_t_prof": 19.61883196875, 984 | "layer_cfg": 512, 985 | "bandwidths": [ 986 | 65536, 987 | 65536, 988 | 16384, 989 | 32768, 990 | 32768, 991 | 8192, 992 | 16384, 993 | 16384, 994 | 16384, 995 | 4096, 996 | 8192, 997 | 8192, 998 | 8192, 999 | 2048, 1000 | 2048, 1001 | 2048, 1002 | 2048, 1003 | 512 1004 | ], 1005 | "bandwidth": 2048 1006 | }, 1007 | { 1008 | "config": [ 1009 | 64, 1010 | 64, 1011 | "M", 1012 | 128, 1013 | 128, 1014 | "M", 1015 | 256, 1016 | 256, 1017 | 256, 1018 | "M", 1019 | 512, 1020 | 512, 1021 | 512, 1022 | "M", 1023 | 512, 1024 | 512, 1025 | 512, 1026 | "M" 1027 | ], 1028 | "acc": 88.23, 1029 | "index": 15, 1030 | "delta_ts": [ 1031 | 0.32147771875, 1032 | 3.25671384375, 1033 | 3.69145253125, 1034 | 5.080956625, 1035 | 7.58786090625, 1036 | 7.855957625, 1037 | 9.158178875, 1038 | 11.42036725, 1039 | 13.9902993125, 1040 | 14.162632875, 1041 | 15.389342875, 1042 | 17.43671934375, 1043 | 19.72048296875, 1044 | 19.7969619375, 1045 | 21.50698684375, 1046 | 23.2030594375, 1047 | 24.80739934375, 1048 | 24.82306690625 1049 | ], 1050 | "delta_t_prof": 23.2030594375, 1051 | "layer_cfg": 512, 1052 | "bandwidths": [ 1053 | 65536, 1054 | 65536, 1055 | 16384, 1056 | 32768, 1057 | 32768, 1058 | 8192, 1059 | 16384, 1060 | 16384, 1061 | 16384, 1062 | 4096, 1063 | 8192, 1064 | 8192, 1065 | 8192, 1066 | 2048, 1067 | 2048, 1068 | 2048, 1069 | 2048, 1070 | 512 1071 | ], 1072 | "bandwidth": 2048 1073 | }, 1074 | { 1075 | "config": [ 1076 | 64, 1077 | 64, 1078 | "M", 1079 | 128, 1080 | 128, 1081 | "M", 1082 | 256, 1083 | 256, 1084 | 256, 1085 | "M", 1086 | 512, 1087 | 512, 1088 | 512, 1089 | "M", 1090 | 512, 1091 | 512, 1092 | 512, 1093 | "M" 1094 | ], 1095 | "acc": 88.23, 1096 | "index": 16, 1097 | "delta_ts": [ 1098 | 0.252151, 1099 | 3.0825285, 1100 | 3.41873259375, 1101 | 4.8367830625, 1102 | 7.027621375, 1103 | 7.27149221875, 1104 | 8.4693051875, 1105 | 10.36620015625, 1106 | 12.48876375, 1107 | 12.6267745625, 1108 | 13.75824859375, 1109 | 15.5987240625, 1110 | 17.7003794375, 1111 | 17.783305, 1112 | 19.45704734375, 1113 | 21.025388375, 1114 | 22.46170653125, 1115 | 22.47684553125 1116 | ], 1117 | "delta_t_prof": 22.46170653125, 1118 | "layer_cfg": 512, 1119 | "bandwidths": [ 1120 | 65536, 1121 | 65536, 1122 | 16384, 1123 | 32768, 1124 | 32768, 1125 | 8192, 1126 | 16384, 1127 | 16384, 1128 | 16384, 1129 | 4096, 1130 | 8192, 1131 | 8192, 1132 | 8192, 1133 | 2048, 1134 | 2048, 1135 | 2048, 1136 | 2048, 1137 | 512 1138 | ], 1139 | "bandwidth": 2048 1140 | }, 1141 | { 1142 | "config": [ 1143 | 64, 1144 | 64, 1145 | "M", 1146 | 128, 1147 | 128, 1148 | "M", 1149 | 256, 1150 | 256, 1151 | 256, 1152 | "M", 1153 | 512, 1154 | 512, 1155 | 512, 1156 | "M", 1157 | 512, 1158 | 512, 1159 | 512, 1160 | "M" 1161 | ], 1162 | "acc": 88.23, 1163 | "index": 17, 1164 | "delta_ts": [ 1165 | 0.24863228125, 1166 | 2.98439903125, 1167 | 3.35028140625, 1168 | 4.5425425, 1169 | 6.6659229375, 1170 | 6.82027840625, 1171 | 7.9001390625, 1172 | 9.72041934375, 1173 | 11.5635606875, 1174 | 11.66123371875, 1175 | 12.62021071875, 1176 | 14.39232978125, 1177 | 16.29025909375, 1178 | 16.36242540625, 1179 | 17.83738878125, 1180 | 19.406784875, 1181 | 20.9314828125, 1182 | 20.94411003125 1183 | ], 1184 | "delta_t_prof": 20.94411003125, 1185 | "layer_cfg": "M", 1186 | "bandwidths": [ 1187 | 65536, 1188 | 65536, 1189 | 16384, 1190 | 32768, 1191 | 32768, 1192 | 8192, 1193 | 16384, 1194 | 16384, 1195 | 16384, 1196 | 
4096, 1197 | 8192, 1198 | 8192, 1199 | 8192, 1200 | 2048, 1201 | 2048, 1202 | 2048, 1203 | 2048, 1204 | 512 1205 | ], 1206 | "bandwidth": 512 1207 | } 1208 | ] -------------------------------------------------------------------------------- /log_prune.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "config": [ 4 | 23, 5 | 30, 6 | "M", 7 | 57, 8 | 56, 9 | "M", 10 | 85, 11 | 84, 12 | 71, 13 | "M", 14 | 118, 15 | 103, 16 | 109, 17 | "M", 18 | 110, 19 | 105, 20 | 105, 21 | "M" 22 | ], 23 | "acc": 85.09, 24 | "index": 0, 25 | "delta_ts": [ 26 | 0.11358825, 27 | 0.6524008125, 28 | 0.78052915625, 29 | 1.02967765625, 30 | 1.46185646875, 31 | 1.519256, 32 | 1.67551621875, 33 | 1.897288625, 34 | 2.0822316875, 35 | 2.10281871875, 36 | 2.17942565625, 37 | 2.284428, 38 | 2.3702856875, 39 | 2.37935471875, 40 | 2.45145740625, 41 | 2.52506684375, 42 | 2.59117959375, 43 | 2.59310628125 44 | ], 45 | "delta_t_prof": 0.11358825, 46 | "layer_cfg": 23, 47 | "bandwidths": [ 48 | 23552, 49 | 30720, 50 | 7680, 51 | 14592, 52 | 14336, 53 | 3584, 54 | 5440, 55 | 5376, 56 | 4544, 57 | 1136, 58 | 1888, 59 | 1648, 60 | 1744, 61 | 436, 62 | 440, 63 | 420, 64 | 420, 65 | 105 66 | ], 67 | "bandwidth": 23552 68 | }, 69 | { 70 | "config": [ 71 | 23, 72 | 30, 73 | "M", 74 | 57, 75 | 56, 76 | "M", 77 | 85, 78 | 84, 79 | 71, 80 | "M", 81 | 118, 82 | 103, 83 | 109, 84 | "M", 85 | 110, 86 | 105, 87 | 105, 88 | "M" 89 | ], 90 | "acc": 85.09, 91 | "index": 1, 92 | "delta_ts": [ 93 | 0.16648715625, 94 | 0.7412968125, 95 | 0.86761065625, 96 | 1.128629875, 97 | 1.63559815625, 98 | 1.698540125, 99 | 1.85947390625, 100 | 2.12187459375, 101 | 2.34924540625, 102 | 2.37135334375, 103 | 2.453211625, 104 | 2.5607616875, 105 | 2.64559834375, 106 | 2.65526340625, 107 | 2.73318090625, 108 | 2.80541953125, 109 | 2.87597209375, 110 | 2.8778923125 111 | ], 112 | "delta_t_prof": 0.7412968125, 113 | "layer_cfg": 30, 114 | "bandwidths": [ 115 | 23552, 116 | 30720, 117 | 7680, 118 | 14592, 119 | 14336, 120 | 3584, 121 | 5440, 122 | 5376, 123 | 4544, 124 | 1136, 125 | 1888, 126 | 1648, 127 | 1744, 128 | 436, 129 | 440, 130 | 420, 131 | 420, 132 | 105 133 | ], 134 | "bandwidth": 30720 135 | }, 136 | { 137 | "config": [ 138 | 23, 139 | 30, 140 | "M", 141 | 57, 142 | 56, 143 | "M", 144 | 85, 145 | 84, 146 | 71, 147 | "M", 148 | 118, 149 | 103, 150 | 109, 151 | "M", 152 | 110, 153 | 105, 154 | 105, 155 | "M" 156 | ], 157 | "acc": 85.09, 158 | "index": 2, 159 | "delta_ts": [ 160 | 0.20093009375, 161 | 0.72621334375, 162 | 0.854012125, 163 | 1.221297375, 164 | 1.7189039375, 165 | 1.778899125, 166 | 1.97942165625, 167 | 2.24895665625, 168 | 2.4238008125, 169 | 2.44501946875, 170 | 2.54007571875, 171 | 2.67186065625, 172 | 2.75497121875, 173 | 2.763864, 174 | 2.84323803125, 175 | 2.91912134375, 176 | 2.98403071875, 177 | 2.98590321875 178 | ], 179 | "delta_t_prof": 0.854012125, 180 | "layer_cfg": "M", 181 | "bandwidths": [ 182 | 23552, 183 | 30720, 184 | 7680, 185 | 14592, 186 | 14336, 187 | 3584, 188 | 5440, 189 | 5376, 190 | 4544, 191 | 1136, 192 | 1888, 193 | 1648, 194 | 1744, 195 | 436, 196 | 440, 197 | 420, 198 | 420, 199 | 105 200 | ], 201 | "bandwidth": 7680 202 | }, 203 | { 204 | "config": [ 205 | 23, 206 | 30, 207 | "M", 208 | 57, 209 | 56, 210 | "M", 211 | 85, 212 | 84, 213 | 71, 214 | "M", 215 | 118, 216 | 103, 217 | 109, 218 | "M", 219 | 110, 220 | 105, 221 | 105, 222 | "M" 223 | ], 224 | "acc": 85.09, 225 | "index": 3, 226 | "delta_ts": [ 227 | 0.105847625, 228 | 0.64889725, 229 | 
0.78045390625, 230 | 1.05999475, 231 | 1.51695759375, 232 | 1.57741509375, 233 | 1.74643221875, 234 | 1.9798571875, 235 | 2.18829715625, 236 | 2.20971290625, 237 | 2.285893375, 238 | 2.39140646875, 239 | 2.487009625, 240 | 2.49656165625, 241 | 2.5705013125, 242 | 2.6446568125, 243 | 2.71337875, 244 | 2.71542528125 245 | ], 246 | "delta_t_prof": 1.05999475, 247 | "layer_cfg": 57, 248 | "bandwidths": [ 249 | 23552, 250 | 30720, 251 | 7680, 252 | 14592, 253 | 14336, 254 | 3584, 255 | 5440, 256 | 5376, 257 | 4544, 258 | 1136, 259 | 1888, 260 | 1648, 261 | 1744, 262 | 436, 263 | 440, 264 | 420, 265 | 420, 266 | 105 267 | ], 268 | "bandwidth": 14592 269 | }, 270 | { 271 | "config": [ 272 | 23, 273 | 30, 274 | "M", 275 | 57, 276 | 56, 277 | "M", 278 | 85, 279 | 84, 280 | 71, 281 | "M", 282 | 118, 283 | 103, 284 | 109, 285 | "M", 286 | 110, 287 | 105, 288 | 105, 289 | "M" 290 | ], 291 | "acc": 85.09, 292 | "index": 4, 293 | "delta_ts": [ 294 | 0.196623875, 295 | 0.757689875, 296 | 0.9350411875, 297 | 1.20698053125, 298 | 1.6778705, 299 | 1.754519625, 300 | 1.95127028125, 301 | 2.20606434375, 302 | 2.3986604375, 303 | 2.4217500625, 304 | 2.53059984375, 305 | 2.65049459375, 306 | 2.74077396875, 307 | 2.7502116875, 308 | 2.8365283125, 309 | 2.91111121875, 310 | 2.98903065625, 311 | 2.99104221875 312 | ], 313 | "delta_t_prof": 1.6778705, 314 | "layer_cfg": 56, 315 | "bandwidths": [ 316 | 23552, 317 | 30720, 318 | 7680, 319 | 14592, 320 | 14336, 321 | 3584, 322 | 5440, 323 | 5376, 324 | 4544, 325 | 1136, 326 | 1888, 327 | 1648, 328 | 1744, 329 | 436, 330 | 440, 331 | 420, 332 | 420, 333 | 105 334 | ], 335 | "bandwidth": 14336 336 | }, 337 | { 338 | "config": [ 339 | 23, 340 | 30, 341 | "M", 342 | 57, 343 | 56, 344 | "M", 345 | 85, 346 | 84, 347 | 71, 348 | "M", 349 | 118, 350 | 103, 351 | 109, 352 | "M", 353 | 110, 354 | 105, 355 | 105, 356 | "M" 357 | ], 358 | "acc": 85.09, 359 | "index": 5, 360 | "delta_ts": [ 361 | 0.17345740625, 362 | 0.63864290625, 363 | 0.76896453125, 364 | 1.00580925, 365 | 1.44612390625, 366 | 1.50453146875, 367 | 1.6572635, 368 | 1.8816075, 369 | 2.07590396875, 370 | 2.09684878125, 371 | 2.17170465625, 372 | 2.2763654375, 373 | 2.370477625, 374 | 2.3795623125, 375 | 2.454071875, 376 | 2.53123590625, 377 | 2.60262325, 378 | 2.60456346875 379 | ], 380 | "delta_t_prof": 1.50453146875, 381 | "layer_cfg": "M", 382 | "bandwidths": [ 383 | 23552, 384 | 30720, 385 | 7680, 386 | 14592, 387 | 14336, 388 | 3584, 389 | 5440, 390 | 5376, 391 | 4544, 392 | 1136, 393 | 1888, 394 | 1648, 395 | 1744, 396 | 436, 397 | 440, 398 | 420, 399 | 420, 400 | 105 401 | ], 402 | "bandwidth": 3584 403 | }, 404 | { 405 | "config": [ 406 | 23, 407 | 30, 408 | "M", 409 | 57, 410 | 56, 411 | "M", 412 | 85, 413 | 84, 414 | 71, 415 | "M", 416 | 118, 417 | 103, 418 | 109, 419 | "M", 420 | 110, 421 | 105, 422 | 105, 423 | "M" 424 | ], 425 | "acc": 85.09, 426 | "index": 6, 427 | "delta_ts": [ 428 | 0.21753346875, 429 | 0.737766625, 430 | 0.866268125, 431 | 1.133596125, 432 | 1.62475284375, 433 | 1.68542575, 434 | 1.86381603125, 435 | 2.1209463125, 436 | 2.34632578125, 437 | 2.36777571875, 438 | 2.45723184375, 439 | 2.57705303125, 440 | 2.6854775625, 441 | 2.69476521875, 442 | 2.77726290625, 443 | 2.856190625, 444 | 2.9355424375, 445 | 2.93762328125 446 | ], 447 | "delta_t_prof": 1.86381603125, 448 | "layer_cfg": 85, 449 | "bandwidths": [ 450 | 23552, 451 | 30720, 452 | 7680, 453 | 14592, 454 | 14336, 455 | 3584, 456 | 5440, 457 | 5376, 458 | 4544, 459 | 1136, 460 | 1888, 461 | 1648, 462 | 1744, 463 | 436, 464 | 
440, 465 | 420, 466 | 420, 467 | 105 468 | ], 469 | "bandwidth": 5440 470 | }, 471 | { 472 | "config": [ 473 | 23, 474 | 30, 475 | "M", 476 | 57, 477 | 56, 478 | "M", 479 | 85, 480 | 84, 481 | 71, 482 | "M", 483 | 118, 484 | 103, 485 | 109, 486 | "M", 487 | 110, 488 | 105, 489 | 105, 490 | "M" 491 | ], 492 | "acc": 85.09, 493 | "index": 7, 494 | "delta_ts": [ 495 | 0.11802159375, 496 | 0.71383421875, 497 | 0.8901209375, 498 | 1.166493125, 499 | 1.66576696875, 500 | 1.75190859375, 501 | 1.9433455625, 502 | 2.1787820625, 503 | 2.39880003125, 504 | 2.4203941875, 505 | 2.51077046875, 506 | 2.6355696875, 507 | 2.72263040625, 508 | 2.7324204375, 509 | 2.80708509375, 510 | 2.8791105, 511 | 2.958710125, 512 | 2.960613625 513 | ], 514 | "delta_t_prof": 2.1787820625, 515 | "layer_cfg": 84, 516 | "bandwidths": [ 517 | 23552, 518 | 30720, 519 | 7680, 520 | 14592, 521 | 14336, 522 | 3584, 523 | 5440, 524 | 5376, 525 | 4544, 526 | 1136, 527 | 1888, 528 | 1648, 529 | 1744, 530 | 436, 531 | 440, 532 | 420, 533 | 420, 534 | 105 535 | ], 536 | "bandwidth": 5376 537 | }, 538 | { 539 | "config": [ 540 | 23, 541 | 30, 542 | "M", 543 | 57, 544 | 56, 545 | "M", 546 | 85, 547 | 84, 548 | 71, 549 | "M", 550 | 118, 551 | 103, 552 | 109, 553 | "M", 554 | 110, 555 | 105, 556 | 105, 557 | "M" 558 | ], 559 | "acc": 85.09, 560 | "index": 8, 561 | "delta_ts": [ 562 | 0.164475, 563 | 0.7031471875, 564 | 0.88555325, 565 | 1.1362460625, 566 | 1.61940840625, 567 | 1.7008634375, 568 | 1.86077875, 569 | 2.10103775, 570 | 2.30417265625, 571 | 2.3301905625, 572 | 2.4104764375, 573 | 2.51121521875, 574 | 2.60969159375, 575 | 2.62037065625, 576 | 2.6952423125, 577 | 2.7674815, 578 | 2.83903628125, 579 | 2.8412394375 580 | ], 581 | "delta_t_prof": 2.30417265625, 582 | "layer_cfg": 71, 583 | "bandwidths": [ 584 | 23552, 585 | 30720, 586 | 7680, 587 | 14592, 588 | 14336, 589 | 3584, 590 | 5440, 591 | 5376, 592 | 4544, 593 | 1136, 594 | 1888, 595 | 1648, 596 | 1744, 597 | 436, 598 | 440, 599 | 420, 600 | 420, 601 | 105 602 | ], 603 | "bandwidth": 4544 604 | }, 605 | { 606 | "config": [ 607 | 23, 608 | 30, 609 | "M", 610 | 57, 611 | 56, 612 | "M", 613 | 85, 614 | 84, 615 | 71, 616 | "M", 617 | 118, 618 | 103, 619 | 109, 620 | "M", 621 | 110, 622 | 105, 623 | 105, 624 | "M" 625 | ], 626 | "acc": 85.09, 627 | "index": 9, 628 | "delta_ts": [ 629 | 0.13080984375, 630 | 0.7091984375, 631 | 0.861831375, 632 | 1.18126071875, 633 | 1.71372759375, 634 | 1.78030796875, 635 | 1.97100134375, 636 | 2.358423, 637 | 2.58714471875, 638 | 2.610667375, 639 | 2.7058868125, 640 | 2.83107096875, 641 | 2.9410485, 642 | 2.951260375, 643 | 3.02990325, 644 | 3.11188534375, 645 | 3.190381875, 646 | 3.19264434375 647 | ], 648 | "delta_t_prof": 2.610667375, 649 | "layer_cfg": "M", 650 | "bandwidths": [ 651 | 23552, 652 | 30720, 653 | 7680, 654 | 14592, 655 | 14336, 656 | 3584, 657 | 5440, 658 | 5376, 659 | 4544, 660 | 1136, 661 | 1888, 662 | 1648, 663 | 1744, 664 | 436, 665 | 440, 666 | 420, 667 | 420, 668 | 105 669 | ], 670 | "bandwidth": 1136 671 | }, 672 | { 673 | "config": [ 674 | 23, 675 | 30, 676 | "M", 677 | 57, 678 | 56, 679 | "M", 680 | 85, 681 | 84, 682 | 71, 683 | "M", 684 | 118, 685 | 103, 686 | 109, 687 | "M", 688 | 110, 689 | 105, 690 | 105, 691 | "M" 692 | ], 693 | "acc": 85.09, 694 | "index": 10, 695 | "delta_ts": [ 696 | 0.137620125, 697 | 0.68872875, 698 | 0.82794590625, 699 | 1.11183840625, 700 | 1.60877696875, 701 | 1.67172246875, 702 | 1.85142065625, 703 | 2.09939234375, 704 | 2.2981294375, 705 | 2.32068871875, 706 | 2.40251084375, 707 | 
2.510551, 708 | 2.605893625, 709 | 2.6157908125, 710 | 2.68786478125, 711 | 2.75797584375, 712 | 2.8265541875, 713 | 2.82869090625 714 | ], 715 | "delta_t_prof": 2.40251084375, 716 | "layer_cfg": 118, 717 | "bandwidths": [ 718 | 23552, 719 | 30720, 720 | 7680, 721 | 14592, 722 | 14336, 723 | 3584, 724 | 5440, 725 | 5376, 726 | 4544, 727 | 1136, 728 | 1888, 729 | 1648, 730 | 1744, 731 | 436, 732 | 440, 733 | 420, 734 | 420, 735 | 105 736 | ], 737 | "bandwidth": 1888 738 | }, 739 | { 740 | "config": [ 741 | 23, 742 | 30, 743 | "M", 744 | 57, 745 | 56, 746 | "M", 747 | 85, 748 | 84, 749 | 71, 750 | "M", 751 | 118, 752 | 103, 753 | 109, 754 | "M", 755 | 110, 756 | 105, 757 | 105, 758 | "M" 759 | ], 760 | "acc": 85.09, 761 | "index": 11, 762 | "delta_ts": [ 763 | 0.219060125, 764 | 0.77152878125, 765 | 0.90845471875, 766 | 1.2438391875, 767 | 1.76259640625, 768 | 1.8296179375, 769 | 2.03861240625, 770 | 2.30987246875, 771 | 2.5106599375, 772 | 2.53445234375, 773 | 2.6342744375, 774 | 2.7615138125, 775 | 2.8567420625, 776 | 2.86694, 777 | 2.9575230625, 778 | 3.0433628125, 779 | 3.12710240625, 780 | 3.129416375 781 | ], 782 | "delta_t_prof": 2.7615138125, 783 | "layer_cfg": 103, 784 | "bandwidths": [ 785 | 23552, 786 | 30720, 787 | 7680, 788 | 14592, 789 | 14336, 790 | 3584, 791 | 5440, 792 | 5376, 793 | 4544, 794 | 1136, 795 | 1888, 796 | 1648, 797 | 1744, 798 | 436, 799 | 440, 800 | 420, 801 | 420, 802 | 105 803 | ], 804 | "bandwidth": 1648 805 | }, 806 | { 807 | "config": [ 808 | 23, 809 | 30, 810 | "M", 811 | 57, 812 | 56, 813 | "M", 814 | 85, 815 | 84, 816 | 71, 817 | "M", 818 | 118, 819 | 103, 820 | 109, 821 | "M", 822 | 110, 823 | 105, 824 | 105, 825 | "M" 826 | ], 827 | "acc": 85.09, 828 | "index": 12, 829 | "delta_ts": [ 830 | 0.11083040625, 831 | 0.666358125, 832 | 0.804948375, 833 | 1.08153290625, 834 | 1.5548213125, 835 | 1.61504896875, 836 | 1.80998896875, 837 | 2.045148375, 838 | 2.23686865625, 839 | 2.25766771875, 840 | 2.34479840625, 841 | 2.4495084375, 842 | 2.54687325, 843 | 2.556028875, 844 | 2.63047896875, 845 | 2.70671396875, 846 | 2.78401128125, 847 | 2.78599825 848 | ], 849 | "delta_t_prof": 2.54687325, 850 | "layer_cfg": 109, 851 | "bandwidths": [ 852 | 23552, 853 | 30720, 854 | 7680, 855 | 14592, 856 | 14336, 857 | 3584, 858 | 5440, 859 | 5376, 860 | 4544, 861 | 1136, 862 | 1888, 863 | 1648, 864 | 1744, 865 | 436, 866 | 440, 867 | 420, 868 | 420, 869 | 105 870 | ], 871 | "bandwidth": 1744 872 | }, 873 | { 874 | "config": [ 875 | 23, 876 | 30, 877 | "M", 878 | 57, 879 | 56, 880 | "M", 881 | 85, 882 | 84, 883 | 71, 884 | "M", 885 | 118, 886 | 103, 887 | 109, 888 | "M", 889 | 110, 890 | 105, 891 | 105, 892 | "M" 893 | ], 894 | "acc": 85.09, 895 | "index": 13, 896 | "delta_ts": [ 897 | 0.104472125, 898 | 0.6647065, 899 | 0.8077999375, 900 | 1.1177865625, 901 | 1.68590478125, 902 | 1.79190128125, 903 | 1.96875121875, 904 | 2.23612775, 905 | 2.457427125, 906 | 2.49211965625, 907 | 2.57577021875, 908 | 2.6983085625, 909 | 2.80297953125, 910 | 2.816973, 911 | 2.89812290625, 912 | 2.98071928125, 913 | 3.05437046875, 914 | 3.05722084375 915 | ], 916 | "delta_t_prof": 2.816973, 917 | "layer_cfg": "M", 918 | "bandwidths": [ 919 | 23552, 920 | 30720, 921 | 7680, 922 | 14592, 923 | 14336, 924 | 3584, 925 | 5440, 926 | 5376, 927 | 4544, 928 | 1136, 929 | 1888, 930 | 1648, 931 | 1744, 932 | 436, 933 | 440, 934 | 420, 935 | 420, 936 | 105 937 | ], 938 | "bandwidth": 436 939 | }, 940 | { 941 | "config": [ 942 | 23, 943 | 30, 944 | "M", 945 | 57, 946 | 56, 947 | "M", 948 | 85, 949 | 84, 950 
| 71, 951 | "M", 952 | 118, 953 | 103, 954 | 109, 955 | "M", 956 | 110, 957 | 105, 958 | 105, 959 | "M" 960 | ], 961 | "acc": 85.09, 962 | "index": 14, 963 | "delta_ts": [ 964 | 0.29025315625, 965 | 0.75345184375, 966 | 0.89207803125, 967 | 1.121283375, 968 | 1.5698099375, 969 | 1.63668696875, 970 | 1.805445, 971 | 2.043355, 972 | 2.25163134375, 973 | 2.2753558125, 974 | 2.35576221875, 975 | 2.4598324375, 976 | 2.5584845, 977 | 2.56785534375, 978 | 2.64912175, 979 | 2.72340684375, 980 | 2.8056718125, 981 | 2.80776221875 982 | ], 983 | "delta_t_prof": 2.64912175, 984 | "layer_cfg": 110, 985 | "bandwidths": [ 986 | 23552, 987 | 30720, 988 | 7680, 989 | 14592, 990 | 14336, 991 | 3584, 992 | 5440, 993 | 5376, 994 | 4544, 995 | 1136, 996 | 1888, 997 | 1648, 998 | 1744, 999 | 436, 1000 | 440, 1001 | 420, 1002 | 420, 1003 | 105 1004 | ], 1005 | "bandwidth": 440 1006 | }, 1007 | { 1008 | "config": [ 1009 | 23, 1010 | 30, 1011 | "M", 1012 | 57, 1013 | 56, 1014 | "M", 1015 | 85, 1016 | 84, 1017 | 71, 1018 | "M", 1019 | 118, 1020 | 103, 1021 | 109, 1022 | "M", 1023 | 110, 1024 | 105, 1025 | 105, 1026 | "M" 1027 | ], 1028 | "acc": 85.09, 1029 | "index": 15, 1030 | "delta_ts": [ 1031 | 0.1343285625, 1032 | 0.71285915625, 1033 | 0.85016946875, 1034 | 1.18777665625, 1035 | 1.7240135, 1036 | 1.79218621875, 1037 | 2.0239770625, 1038 | 2.29661846875, 1039 | 2.54156471875, 1040 | 2.5795464375, 1041 | 2.67105909375, 1042 | 2.801212875, 1043 | 2.89718575, 1044 | 2.91217834375, 1045 | 2.99169771875, 1046 | 3.06911415625, 1047 | 3.13969103125, 1048 | 3.1428933125 1049 | ], 1050 | "delta_t_prof": 3.06911415625, 1051 | "layer_cfg": 105, 1052 | "bandwidths": [ 1053 | 23552, 1054 | 30720, 1055 | 7680, 1056 | 14592, 1057 | 14336, 1058 | 3584, 1059 | 5440, 1060 | 5376, 1061 | 4544, 1062 | 1136, 1063 | 1888, 1064 | 1648, 1065 | 1744, 1066 | 436, 1067 | 440, 1068 | 420, 1069 | 420, 1070 | 105 1071 | ], 1072 | "bandwidth": 420 1073 | }, 1074 | { 1075 | "config": [ 1076 | 23, 1077 | 30, 1078 | "M", 1079 | 57, 1080 | 56, 1081 | "M", 1082 | 85, 1083 | 84, 1084 | 71, 1085 | "M", 1086 | 118, 1087 | 103, 1088 | 109, 1089 | "M", 1090 | 110, 1091 | 105, 1092 | 105, 1093 | "M" 1094 | ], 1095 | "acc": 85.09, 1096 | "index": 16, 1097 | "delta_ts": [ 1098 | 0.101868875, 1099 | 0.678419875, 1100 | 0.8174533125, 1101 | 1.1241566875, 1102 | 1.621369875, 1103 | 1.69011525, 1104 | 1.8666296875, 1105 | 2.1127453125, 1106 | 2.32561303125, 1107 | 2.34711965625, 1108 | 2.4247016875, 1109 | 2.53164003125, 1110 | 2.62870446875, 1111 | 2.6377295, 1112 | 2.7136711875, 1113 | 2.7904385, 1114 | 2.86945815625, 1115 | 2.87149959375 1116 | ], 1117 | "delta_t_prof": 2.86945815625, 1118 | "layer_cfg": 105, 1119 | "bandwidths": [ 1120 | 23552, 1121 | 30720, 1122 | 7680, 1123 | 14592, 1124 | 14336, 1125 | 3584, 1126 | 5440, 1127 | 5376, 1128 | 4544, 1129 | 1136, 1130 | 1888, 1131 | 1648, 1132 | 1744, 1133 | 436, 1134 | 440, 1135 | 420, 1136 | 420, 1137 | 105 1138 | ], 1139 | "bandwidth": 420 1140 | }, 1141 | { 1142 | "config": [ 1143 | 23, 1144 | 30, 1145 | "M", 1146 | 57, 1147 | 56, 1148 | "M", 1149 | 85, 1150 | 84, 1151 | 71, 1152 | "M", 1153 | 118, 1154 | 103, 1155 | 109, 1156 | "M", 1157 | 110, 1158 | 105, 1159 | 105, 1160 | "M" 1161 | ], 1162 | "acc": 85.09, 1163 | "index": 17, 1164 | "delta_ts": [ 1165 | 0.21816859375, 1166 | 0.6997705625, 1167 | 0.8375525, 1168 | 1.06860471875, 1169 | 1.52064996875, 1170 | 1.5900520625, 1171 | 1.7428996875, 1172 | 1.9716084375, 1173 | 2.15746409375, 1174 | 2.18117134375, 1175 | 2.25775359375, 1176 | 2.360581125, 
1177 | 2.460398, 1178 | 2.4705750625, 1179 | 2.54335234375, 1180 | 2.6143379375, 1181 | 2.6817635625, 1182 | 2.68396921875 1183 | ], 1184 | "delta_t_prof": 2.68396921875, 1185 | "layer_cfg": "M", 1186 | "bandwidths": [ 1187 | 23552, 1188 | 30720, 1189 | 7680, 1190 | 14592, 1191 | 14336, 1192 | 3584, 1193 | 5440, 1194 | 5376, 1195 | 4544, 1196 | 1136, 1197 | 1888, 1198 | 1648, 1199 | 1744, 1200 | 436, 1201 | 440, 1202 | 420, 1203 | 420, 1204 | 105 1205 | ], 1206 | "bandwidth": 105 1207 | } 1208 | ] -------------------------------------------------------------------------------- /log_prune_layer.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "config": [ 4 | 3, 5 | 30, 6 | "M", 7 | 57, 8 | 56, 9 | "M", 10 | 85, 11 | 84, 12 | 71, 13 | "M", 14 | 118, 15 | 103, 16 | 109, 17 | "M", 18 | 110, 19 | 105, 20 | 105, 21 | "M" 22 | ], 23 | "acc": 86.57, 24 | "index": 0, 25 | "delta_ts": [ 26 | 0.1220933125, 27 | 0.25864446875, 28 | 0.39611240625, 29 | 0.7114073125, 30 | 1.22712721875, 31 | 1.29468459375, 32 | 1.481201, 33 | 1.7446034375, 34 | 1.968209625, 35 | 1.99154840625, 36 | 2.07867228125, 37 | 2.1886398125, 38 | 2.28305225, 39 | 2.2933705, 40 | 2.37075915625, 41 | 2.4494961875, 42 | 2.51927378125, 43 | 2.5214129375 44 | ], 45 | "delta_t_prof": 0.1220933125, 46 | "layer_cfg": 3, 47 | "bandwidths": [ 48 | 3072, 49 | 30720, 50 | 7680, 51 | 14592, 52 | 14336, 53 | 3584, 54 | 5440, 55 | 5376, 56 | 4544, 57 | 1136, 58 | 1888, 59 | 1648, 60 | 1744, 61 | 436, 62 | 440, 63 | 420, 64 | 420, 65 | 105 66 | ], 67 | "bandwidth": 3072 68 | }, 69 | { 70 | "config": [ 71 | 23, 72 | 6, 73 | "M", 74 | 57, 75 | 56, 76 | "M", 77 | 85, 78 | 84, 79 | 71, 80 | "M", 81 | 118, 82 | 103, 83 | 109, 84 | "M", 85 | 110, 86 | 105, 87 | 105, 88 | "M" 89 | ], 90 | "acc": 85.47, 91 | "index": 1, 92 | "delta_ts": [ 93 | 0.0975885, 94 | 0.290354125, 95 | 0.317774625, 96 | 0.39971715625, 97 | 0.86406790625, 98 | 0.92332121875, 99 | 1.0919138125, 100 | 1.3489255, 101 | 1.556215375, 102 | 1.5776726875, 103 | 1.65965834375, 104 | 1.77065621875, 105 | 1.8713810625, 106 | 1.88050409375, 107 | 1.9585281875, 108 | 2.0344486875, 109 | 2.10610509375, 110 | 2.10805471875 111 | ], 112 | "delta_t_prof": 0.290354125, 113 | "layer_cfg": 6, 114 | "bandwidths": [ 115 | 23552, 116 | 6144, 117 | 1536, 118 | 14592, 119 | 14336, 120 | 3584, 121 | 5440, 122 | 5376, 123 | 4544, 124 | 1136, 125 | 1888, 126 | 1648, 127 | 1744, 128 | 436, 129 | 440, 130 | 420, 131 | 420, 132 | 105 133 | ], 134 | "bandwidth": 6144 135 | }, 136 | { 137 | "config": [ 138 | 23, 139 | 6, 140 | "M", 141 | 57, 142 | 56, 143 | "M", 144 | 85, 145 | 84, 146 | 71, 147 | "M", 148 | 118, 149 | 103, 150 | 109, 151 | "M", 152 | 110, 153 | 105, 154 | 105, 155 | "M" 156 | ], 157 | "acc": 85.47, 158 | "index": 2, 159 | "delta_ts": [ 160 | 0.09153596875, 161 | 0.36373290625, 162 | 0.40905903125, 163 | 0.4929178125, 164 | 1.06294275, 165 | 1.15050309375, 166 | 1.345450625, 167 | 1.617535375, 168 | 1.86558284375, 169 | 1.89330884375, 170 | 2.0049671875, 171 | 2.150731375, 172 | 2.25134728125, 173 | 2.26164990625, 174 | 2.34249384375, 175 | 2.419188625, 176 | 2.502344125, 177 | 2.5062683125 178 | ], 179 | "delta_t_prof": 0.40905903125, 180 | "layer_cfg": "M", 181 | "bandwidths": [ 182 | 23552, 183 | 6144, 184 | 1536, 185 | 14592, 186 | 14336, 187 | 3584, 188 | 5440, 189 | 5376, 190 | 4544, 191 | 1136, 192 | 1888, 193 | 1648, 194 | 1744, 195 | 436, 196 | 440, 197 | 420, 198 | 420, 199 | 105 200 | ], 201 | "bandwidth": 1536 202 | }, 203 | { 
204 | "config": [ 205 | 23, 206 | 30, 207 | "M", 208 | 9, 209 | 56, 210 | "M", 211 | 85, 212 | 84, 213 | 71, 214 | "M", 215 | 118, 216 | 103, 217 | 109, 218 | "M", 219 | 110, 220 | 105, 221 | 105, 222 | "M" 223 | ], 224 | "acc": 86.02, 225 | "index": 3, 226 | "delta_ts": [ 227 | 0.31451128125, 228 | 0.78561878125, 229 | 0.91668040625, 230 | 0.99882846875, 231 | 1.08730934375, 232 | 1.14510646875, 233 | 1.3039059375, 234 | 1.53299203125, 235 | 1.74705634375, 236 | 1.76801375, 237 | 1.84598565625, 238 | 1.9570494375, 239 | 2.05581778125, 240 | 2.0650960625, 241 | 2.1391909375, 242 | 2.2119311875, 243 | 2.27955840625, 244 | 2.28160684375 245 | ], 246 | "delta_t_prof": 0.99882846875, 247 | "layer_cfg": 9, 248 | "bandwidths": [ 249 | 23552, 250 | 30720, 251 | 7680, 252 | 2304, 253 | 14336, 254 | 3584, 255 | 5440, 256 | 5376, 257 | 4544, 258 | 1136, 259 | 1888, 260 | 1648, 261 | 1744, 262 | 436, 263 | 440, 264 | 420, 265 | 420, 266 | 105 267 | ], 268 | "bandwidth": 2304 269 | }, 270 | { 271 | "config": [ 272 | 23, 273 | 30, 274 | "M", 275 | 57, 276 | 8, 277 | "M", 278 | 85, 279 | 84, 280 | 71, 281 | "M", 282 | 118, 283 | 103, 284 | 109, 285 | "M", 286 | 110, 287 | 105, 288 | 105, 289 | "M" 290 | ], 291 | "acc": 84.13, 292 | "index": 4, 293 | "delta_ts": [ 294 | 0.15428021875, 295 | 0.70036034375, 296 | 0.90634684375, 297 | 1.1656844375, 298 | 1.35072478125, 299 | 1.36030840625, 300 | 1.415749875, 301 | 1.66319553125, 302 | 1.8728339375, 303 | 1.90163571875, 304 | 1.98016775, 305 | 2.079678625, 306 | 2.19145675, 307 | 2.2029879375, 308 | 2.2804146875, 309 | 2.35464059375, 310 | 2.42021190625, 311 | 2.42252096875 312 | ], 313 | "delta_t_prof": 1.35072478125, 314 | "layer_cfg": 8, 315 | "bandwidths": [ 316 | 23552, 317 | 30720, 318 | 7680, 319 | 14592, 320 | 2048, 321 | 512, 322 | 5440, 323 | 5376, 324 | 4544, 325 | 1136, 326 | 1888, 327 | 1648, 328 | 1744, 329 | 436, 330 | 440, 331 | 420, 332 | 420, 333 | 105 334 | ], 335 | "bandwidth": 2048 336 | }, 337 | { 338 | "config": [ 339 | 23, 340 | 30, 341 | "M", 342 | 57, 343 | 8, 344 | "M", 345 | 85, 346 | 84, 347 | 71, 348 | "M", 349 | 118, 350 | 103, 351 | 109, 352 | "M", 353 | 110, 354 | 105, 355 | 105, 356 | "M" 357 | ], 358 | "acc": 84.13, 359 | "index": 5, 360 | "delta_ts": [ 361 | 0.08774184375, 362 | 0.617821375, 363 | 0.82938590625, 364 | 1.07947346875, 365 | 1.25763275, 366 | 1.26680846875, 367 | 1.30182253125, 368 | 1.5305945625, 369 | 1.746352, 370 | 1.77472140625, 371 | 1.852979375, 372 | 1.9618808125, 373 | 2.05708475, 374 | 2.0694246875, 375 | 2.1414258125, 376 | 2.21438028125, 377 | 2.27877178125, 378 | 2.2813055 379 | ], 380 | "delta_t_prof": 1.26680846875, 381 | "layer_cfg": "M", 382 | "bandwidths": [ 383 | 23552, 384 | 30720, 385 | 7680, 386 | 14592, 387 | 2048, 388 | 512, 389 | 5440, 390 | 5376, 391 | 4544, 392 | 1136, 393 | 1888, 394 | 1648, 395 | 1744, 396 | 436, 397 | 440, 398 | 420, 399 | 420, 400 | 105 401 | ], 402 | "bandwidth": 512 403 | }, 404 | { 405 | "config": [ 406 | 23, 407 | 30, 408 | "M", 409 | 57, 410 | 56, 411 | "M", 412 | 13, 413 | 84, 414 | 71, 415 | "M", 416 | 118, 417 | 103, 418 | 109, 419 | "M", 420 | 110, 421 | 105, 422 | 105, 423 | "M" 424 | ], 425 | "acc": 85.78, 426 | "index": 6, 427 | "delta_ts": [ 428 | 0.17304409375, 429 | 0.68838553125, 430 | 0.83677965625, 431 | 1.0934670625, 432 | 1.54930878125, 433 | 1.60730434375, 434 | 1.6652258125, 435 | 1.7166780625, 436 | 1.9023009375, 437 | 1.9230566875, 438 | 2.008898875, 439 | 2.1068039375, 440 | 2.2083715, 441 | 2.217857, 442 | 2.29326996875, 443 | 
2.36206575, 444 | 2.43262078125, 445 | 2.43449859375 446 | ], 447 | "delta_t_prof": 1.6652258125, 448 | "layer_cfg": 13, 449 | "bandwidths": [ 450 | 23552, 451 | 30720, 452 | 7680, 453 | 14592, 454 | 14336, 455 | 3584, 456 | 832, 457 | 5376, 458 | 4544, 459 | 1136, 460 | 1888, 461 | 1648, 462 | 1744, 463 | 436, 464 | 440, 465 | 420, 466 | 420, 467 | 105 468 | ], 469 | "bandwidth": 832 470 | }, 471 | { 472 | "config": [ 473 | 23, 474 | 30, 475 | "M", 476 | 57, 477 | 56, 478 | "M", 479 | 85, 480 | 12, 481 | 71, 482 | "M", 483 | 118, 484 | 103, 485 | 109, 486 | "M", 487 | 110, 488 | 105, 489 | 105, 490 | "M" 491 | ], 492 | "acc": 86.26, 493 | "index": 7, 494 | "delta_ts": [ 495 | 0.28446875, 496 | 0.766517, 497 | 0.90534621875, 498 | 1.1423765, 499 | 1.59044834375, 500 | 1.65824590625, 501 | 1.81192940625, 502 | 1.88423378125, 503 | 1.92798296875, 504 | 1.9512331875, 505 | 2.032000625, 506 | 2.13472990625, 507 | 2.243549125, 508 | 2.25405984375, 509 | 2.32791640625, 510 | 2.45888334375, 511 | 2.53542715625, 512 | 2.53736565625 513 | ], 514 | "delta_t_prof": 1.88423378125, 515 | "layer_cfg": 12, 516 | "bandwidths": [ 517 | 23552, 518 | 30720, 519 | 7680, 520 | 14592, 521 | 14336, 522 | 3584, 523 | 5440, 524 | 768, 525 | 4544, 526 | 1136, 527 | 1888, 528 | 1648, 529 | 1744, 530 | 436, 531 | 440, 532 | 420, 533 | 420, 534 | 105 535 | ], 536 | "bandwidth": 768 537 | }, 538 | { 539 | "config": [ 540 | 23, 541 | 30, 542 | "M", 543 | 57, 544 | 56, 545 | "M", 546 | 85, 547 | 84, 548 | 11, 549 | "M", 550 | 118, 551 | 103, 552 | 109, 553 | "M", 554 | 110, 555 | 105, 556 | 105, 557 | "M" 558 | ], 559 | "acc": 86.19, 560 | "index": 8, 561 | "delta_ts": [ 562 | 0.111757875, 563 | 0.6802498125, 564 | 0.85052209375, 565 | 1.1400180625, 566 | 1.66384584375, 567 | 1.734045875, 568 | 1.91516646875, 569 | 2.17453659375, 570 | 2.24900128125, 571 | 2.2530855, 572 | 2.2811150625, 573 | 2.3873805, 574 | 2.49538534375, 575 | 2.50619546875, 576 | 2.64107078125, 577 | 2.7155539375, 578 | 2.78515009375, 579 | 2.78725696875 580 | ], 581 | "delta_t_prof": 2.24900128125, 582 | "layer_cfg": 11, 583 | "bandwidths": [ 584 | 23552, 585 | 30720, 586 | 7680, 587 | 14592, 588 | 14336, 589 | 3584, 590 | 5440, 591 | 5376, 592 | 704, 593 | 176, 594 | 1888, 595 | 1648, 596 | 1744, 597 | 436, 598 | 440, 599 | 420, 600 | 420, 601 | 105 602 | ], 603 | "bandwidth": 704 604 | }, 605 | { 606 | "config": [ 607 | 23, 608 | 30, 609 | "M", 610 | 57, 611 | 56, 612 | "M", 613 | 85, 614 | 84, 615 | 11, 616 | "M", 617 | 118, 618 | 103, 619 | 109, 620 | "M", 621 | 110, 622 | 105, 623 | 105, 624 | "M" 625 | ], 626 | "acc": 86.19, 627 | "index": 9, 628 | "delta_ts": [ 629 | 0.1105239375, 630 | 0.58618653125, 631 | 0.72540678125, 632 | 0.95993578125, 633 | 1.403244, 634 | 1.4701740625, 635 | 1.62411525, 636 | 1.84193028125, 637 | 1.90949603125, 638 | 1.913054875, 639 | 1.94607015625, 640 | 2.05138503125, 641 | 2.14724428125, 642 | 2.1573071875, 643 | 2.23395525, 644 | 2.31005978125, 645 | 2.38157746875, 646 | 2.38337575 647 | ], 648 | "delta_t_prof": 1.913054875, 649 | "layer_cfg": "M", 650 | "bandwidths": [ 651 | 23552, 652 | 30720, 653 | 7680, 654 | 14592, 655 | 14336, 656 | 3584, 657 | 5440, 658 | 5376, 659 | 704, 660 | 176, 661 | 1888, 662 | 1648, 663 | 1744, 664 | 436, 665 | 440, 666 | 420, 667 | 420, 668 | 105 669 | ], 670 | "bandwidth": 176 671 | }, 672 | { 673 | "config": [ 674 | 23, 675 | 30, 676 | "M", 677 | 57, 678 | 56, 679 | "M", 680 | 85, 681 | 84, 682 | 71, 683 | "M", 684 | 22, 685 | 103, 686 | 109, 687 | "M", 688 | 110, 689 | 105, 690 
| 105, 691 | "M" 692 | ], 693 | "acc": 86.91, 694 | "index": 10, 695 | "delta_ts": [ 696 | 0.10086003125, 697 | 0.7031173125, 698 | 0.838317625, 699 | 1.15011571875, 700 | 1.68945828125, 701 | 1.7485925, 702 | 1.946967375, 703 | 2.21865453125, 704 | 2.4516130625, 705 | 2.47238575, 706 | 2.504342875, 707 | 2.53925815625, 708 | 2.66005978125, 709 | 2.66973878125, 710 | 2.7567239375, 711 | 2.8388484375, 712 | 2.9172165625, 713 | 2.9191306875 714 | ], 715 | "delta_t_prof": 2.504342875, 716 | "layer_cfg": 22, 717 | "bandwidths": [ 718 | 23552, 719 | 30720, 720 | 7680, 721 | 14592, 722 | 14336, 723 | 3584, 724 | 5440, 725 | 5376, 726 | 4544, 727 | 1136, 728 | 352, 729 | 1648, 730 | 1744, 731 | 436, 732 | 440, 733 | 420, 734 | 420, 735 | 105 736 | ], 737 | "bandwidth": 352 738 | }, 739 | { 740 | "config": [ 741 | 23, 742 | 30, 743 | "M", 744 | 57, 745 | 56, 746 | "M", 747 | 85, 748 | 84, 749 | 71, 750 | "M", 751 | 118, 752 | 19, 753 | 109, 754 | "M", 755 | 110, 756 | 105, 757 | 105, 758 | "M" 759 | ], 760 | "acc": 87.35, 761 | "index": 11, 762 | "delta_ts": [ 763 | 0.14823421875, 764 | 0.58341440625, 765 | 0.721981, 766 | 0.938638, 767 | 1.353975, 768 | 1.41174215625, 769 | 1.5831458125, 770 | 1.82101665625, 771 | 2.00036834375, 772 | 2.02036453125, 773 | 2.11181075, 774 | 2.15129390625, 775 | 2.17903959375, 776 | 2.18791034375, 777 | 2.2652345625, 778 | 2.36532125, 779 | 2.43835825, 780 | 2.4401229375 781 | ], 782 | "delta_t_prof": 2.15129390625, 783 | "layer_cfg": 19, 784 | "bandwidths": [ 785 | 23552, 786 | 30720, 787 | 7680, 788 | 14592, 789 | 14336, 790 | 3584, 791 | 5440, 792 | 5376, 793 | 4544, 794 | 1136, 795 | 1888, 796 | 304, 797 | 1744, 798 | 436, 799 | 440, 800 | 420, 801 | 420, 802 | 105 803 | ], 804 | "bandwidth": 304 805 | }, 806 | { 807 | "config": [ 808 | 23, 809 | 30, 810 | "M", 811 | 57, 812 | 56, 813 | "M", 814 | 85, 815 | 84, 816 | 71, 817 | "M", 818 | 118, 819 | 103, 820 | 18, 821 | "M", 822 | 110, 823 | 105, 824 | 105, 825 | "M" 826 | ], 827 | "acc": 87.56, 828 | "index": 12, 829 | "delta_ts": [ 830 | 0.21317409375, 831 | 0.71314196875, 832 | 0.842275875, 833 | 1.10395859375, 834 | 1.56893678125, 835 | 1.6265519375, 836 | 1.79181265625, 837 | 2.02849375, 838 | 2.22569053125, 839 | 2.246134375, 840 | 2.31960890625, 841 | 2.42451853125, 842 | 2.45926009375, 843 | 2.4608299375, 844 | 2.48308875, 845 | 2.55968309375, 846 | 2.6288260625, 847 | 2.63054246875 848 | ], 849 | "delta_t_prof": 2.45926009375, 850 | "layer_cfg": 18, 851 | "bandwidths": [ 852 | 23552, 853 | 30720, 854 | 7680, 855 | 14592, 856 | 14336, 857 | 3584, 858 | 5440, 859 | 5376, 860 | 4544, 861 | 1136, 862 | 1888, 863 | 1648, 864 | 288, 865 | 72, 866 | 440, 867 | 420, 868 | 420, 869 | 105 870 | ], 871 | "bandwidth": 288 872 | }, 873 | { 874 | "config": [ 875 | 23, 876 | 30, 877 | "M", 878 | 57, 879 | 56, 880 | "M", 881 | 85, 882 | 84, 883 | 71, 884 | "M", 885 | 118, 886 | 103, 887 | 18, 888 | "M", 889 | 110, 890 | 105, 891 | 105, 892 | "M" 893 | ], 894 | "acc": 87.56, 895 | "index": 13, 896 | "delta_ts": [ 897 | 0.1063354375, 898 | 0.59197696875, 899 | 0.73006878125, 900 | 0.9725333125, 901 | 1.4577054375, 902 | 1.52449659375, 903 | 1.6868990625, 904 | 1.92619140625, 905 | 2.13064134375, 906 | 2.15416234375, 907 | 2.22973628125, 908 | 2.33018875, 909 | 2.3635266875, 910 | 2.36506778125, 911 | 2.38610540625, 912 | 2.4534654375, 913 | 2.52396403125, 914 | 2.52565021875 915 | ], 916 | "delta_t_prof": 2.36506778125, 917 | "layer_cfg": "M", 918 | "bandwidths": [ 919 | 23552, 920 | 30720, 921 | 7680, 922 | 14592, 923 | 
14336, 924 | 3584, 925 | 5440, 926 | 5376, 927 | 4544, 928 | 1136, 929 | 1888, 930 | 1648, 931 | 288, 932 | 72, 933 | 440, 934 | 420, 935 | 420, 936 | 105 937 | ], 938 | "bandwidth": 72 939 | }, 940 | { 941 | "config": [ 942 | 23, 943 | 30, 944 | "M", 945 | 57, 946 | 56, 947 | "M", 948 | 85, 949 | 84, 950 | 71, 951 | "M", 952 | 118, 953 | 103, 954 | 109, 955 | "M", 956 | 19, 957 | 105, 958 | 105, 959 | "M" 960 | ], 961 | "acc": 87.87, 962 | "index": 14, 963 | "delta_ts": [ 964 | 0.1741975, 965 | 0.851050625, 966 | 1.065806875, 967 | 1.41882309375, 968 | 1.955418, 969 | 2.0408065, 970 | 2.247144125, 971 | 2.516310875, 972 | 2.74876821875, 973 | 2.77344790625, 974 | 2.94988409375, 975 | 3.1604771875, 976 | 3.2936014375, 977 | 3.303463125, 978 | 3.33267325, 979 | 3.35084146875, 980 | 3.46478971875, 981 | 3.4668980625 982 | ], 983 | "delta_t_prof": 3.33267325, 984 | "layer_cfg": 19, 985 | "bandwidths": [ 986 | 23552, 987 | 30720, 988 | 7680, 989 | 14592, 990 | 14336, 991 | 3584, 992 | 5440, 993 | 5376, 994 | 4544, 995 | 1136, 996 | 1888, 997 | 1648, 998 | 1744, 999 | 436, 1000 | 76, 1001 | 420, 1002 | 420, 1003 | 105 1004 | ], 1005 | "bandwidth": 76 1006 | }, 1007 | { 1008 | "config": [ 1009 | 23, 1010 | 30, 1011 | "M", 1012 | 57, 1013 | 56, 1014 | "M", 1015 | 85, 1016 | 84, 1017 | 71, 1018 | "M", 1019 | 118, 1020 | 103, 1021 | 109, 1022 | "M", 1023 | 110, 1024 | 21, 1025 | 105, 1026 | "M" 1027 | ], 1028 | "acc": 87.74, 1029 | "index": 15, 1030 | "delta_ts": [ 1031 | 0.36522203125, 1032 | 0.86526453125, 1033 | 0.99425790625, 1034 | 1.25646390625, 1035 | 1.745508, 1036 | 1.80287996875, 1037 | 1.9669035, 1038 | 2.209551875, 1039 | 2.40764246875, 1040 | 2.42850753125, 1041 | 2.5079519375, 1042 | 2.61232978125, 1043 | 2.7300215625, 1044 | 2.73921765625, 1045 | 2.81909709375, 1046 | 2.84468584375, 1047 | 2.867084625, 1048 | 2.8692801875 1049 | ], 1050 | "delta_t_prof": 2.84468584375, 1051 | "layer_cfg": 21, 1052 | "bandwidths": [ 1053 | 23552, 1054 | 30720, 1055 | 7680, 1056 | 14592, 1057 | 14336, 1058 | 3584, 1059 | 5440, 1060 | 5376, 1061 | 4544, 1062 | 1136, 1063 | 1888, 1064 | 1648, 1065 | 1744, 1066 | 436, 1067 | 440, 1068 | 84, 1069 | 420, 1070 | 105 1071 | ], 1072 | "bandwidth": 84 1073 | }, 1074 | { 1075 | "config": [ 1076 | 23, 1077 | 30, 1078 | "M", 1079 | 57, 1080 | 56, 1081 | "M", 1082 | 85, 1083 | 84, 1084 | 71, 1085 | "M", 1086 | 118, 1087 | 103, 1088 | 109, 1089 | "M", 1090 | 110, 1091 | 105, 1092 | 21, 1093 | "M" 1094 | ], 1095 | "acc": 87.78, 1096 | "index": 16, 1097 | "delta_ts": [ 1098 | 0.0899058125, 1099 | 0.55838415625, 1100 | 0.69053909375, 1101 | 0.924883875, 1102 | 1.3655166875, 1103 | 1.4245826875, 1104 | 1.5785116875, 1105 | 1.80529865625, 1106 | 2.00144540625, 1107 | 2.02110278125, 1108 | 2.0973020625, 1109 | 2.1997386875, 1110 | 2.28759665625, 1111 | 2.2967286875, 1112 | 2.37336146875, 1113 | 2.44474121875, 1114 | 2.470488625, 1115 | 2.47092728125 1116 | ], 1117 | "delta_t_prof": 2.470488625, 1118 | "layer_cfg": 21, 1119 | "bandwidths": [ 1120 | 23552, 1121 | 30720, 1122 | 7680, 1123 | 14592, 1124 | 14336, 1125 | 3584, 1126 | 5440, 1127 | 5376, 1128 | 4544, 1129 | 1136, 1130 | 1888, 1131 | 1648, 1132 | 1744, 1133 | 436, 1134 | 440, 1135 | 420, 1136 | 84, 1137 | 21 1138 | ], 1139 | "bandwidth": 84 1140 | }, 1141 | { 1142 | "config": [ 1143 | 23, 1144 | 30, 1145 | "M", 1146 | 57, 1147 | 56, 1148 | "M", 1149 | 85, 1150 | 84, 1151 | 71, 1152 | "M", 1153 | 118, 1154 | 103, 1155 | 109, 1156 | "M", 1157 | 110, 1158 | 105, 1159 | 21, 1160 | "M" 1161 | ], 1162 | "acc": 87.78, 
1163 | "index": 17, 1164 | "delta_ts": [ 1165 | 0.09321615625, 1166 | 0.6091264375, 1167 | 0.74014378125, 1168 | 1.02503734375, 1169 | 1.53554490625, 1170 | 1.59557003125, 1171 | 1.78007659375, 1172 | 2.044504375, 1173 | 2.26794346875, 1174 | 2.28910225, 1175 | 2.4091314375, 1176 | 2.5449301875, 1177 | 2.6697039375, 1178 | 2.68792396875, 1179 | 2.7769520625, 1180 | 2.8629965, 1181 | 2.89086028125, 1182 | 2.89165009375 1183 | ], 1184 | "delta_t_prof": 2.89165009375, 1185 | "layer_cfg": "M", 1186 | "bandwidths": [ 1187 | 23552, 1188 | 30720, 1189 | 7680, 1190 | 14592, 1191 | 14336, 1192 | 3584, 1193 | 5440, 1194 | 5376, 1195 | 4544, 1196 | 1136, 1197 | 1888, 1198 | 1648, 1199 | 1744, 1200 | 436, 1201 | 440, 1202 | 420, 1203 | 84, 1204 | 21 1205 | ], 1206 | "bandwidth": 21 1207 | } 1208 | ] -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Train & Pruning with PyTorch by hou-yz. 3 | ''' 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | from torch.autograd import Variable 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | import os 11 | import math 12 | from heapq import nsmallest 13 | from operator import itemgetter 14 | import json 15 | import numpy as np 16 | import argparse 17 | from models import * 18 | from model_refactor import * 19 | 20 | if os.name == 'nt': # windows 21 | num_workers = 0 22 | else: # linux 23 | num_workers = 8 24 | os.environ["CUDA_VISIBLE_DEVICES"] = '1' 25 | 26 | use_cuda = torch.cuda.is_available() 27 | start_epoch = 1 # start from epoch 0 or last checkpoint epoch 28 | total_filter_num_pre_prune = 0 29 | batch_size = 32 30 | 31 | # Data 32 | print('==> Preparing data..') 33 | transform_train = transforms.Compose([ 34 | transforms.RandomCrop(32, padding=4), 35 | transforms.RandomHorizontalFlip(), 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 38 | ]) 39 | 40 | transform_test = transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 43 | ]) 44 | 45 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 46 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 47 | 48 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 49 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 50 | 51 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 52 | 53 | 54 | # Training 55 | def train(optimizer=None, rankfilters=False): 56 | if optimizer is None: 57 | optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 58 | # net.train() 59 | train_loss = 0 60 | correct = 0 61 | total = 0 62 | for batch_idx, (inputs, targets) in enumerate(trainloader): 63 | if use_cuda: 64 | inputs, targets = inputs.cuda(), targets.cuda() 65 | optimizer.zero_grad() 66 | inputs, targets = Variable(inputs), Variable(targets) 67 | if rankfilters: 68 | outputs = pruner.forward(inputs) 69 | loss = criterion(outputs, targets) 70 | loss.backward() 71 | else: 72 | outputs = net(inputs) 73 | loss = criterion(outputs, targets) 74 | loss.backward() 75 | optimizer.step() 76 | 77 | train_loss += loss.data[0] # 
item() 78 | _, predicted = torch.max(outputs.data, 1) 79 | total += targets.size(0) 80 | correct += predicted.eq(targets.data).cpu().sum() 81 | 82 | print('Train Loss: %.3f | Acc: %.3f%% (%d/%d)' 83 | % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 84 | 85 | 86 | # test 87 | def test(log_index=-1): 88 | # net.eval() 89 | test_loss = 0 90 | correct = 0 91 | total = 0 92 | if log_index == -1 or use_cuda: 93 | for batch_idx, (inputs, targets) in enumerate(testloader): 94 | if use_cuda: 95 | inputs, targets = inputs.cuda(), targets.cuda() 96 | inputs, targets = Variable(inputs, volatile=True), Variable(targets) 97 | outputs = net(inputs) 98 | loss = criterion(outputs, targets) 99 | 100 | test_loss += loss.data[0] # loss.item() 101 | _, predicted = torch.max(outputs.data, 1) 102 | total += targets.size(0) 103 | correct += predicted.eq(targets.data).cpu().sum() 104 | 105 | print('Test Loss: %.3f | Acc: %.3f%% (%d/%d)' % ( 106 | test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 107 | acc = 100. * correct / total 108 | 109 | if log_index != -1: 110 | (inputs, targets) = list(testloader)[0] 111 | if use_cuda: 112 | inputs, targets = inputs.cuda(), targets.cuda() 113 | # get profile 114 | with torch.autograd.profiler.profile() as prof: 115 | net(Variable(inputs)) 116 | # print(next(net.parameters()).is_cuda) 117 | pruner.forward_n_track(Variable(inputs), log_index) 118 | cfg = pruner.get_cfg() 119 | 120 | # get log for time/bandwidth 121 | delta_ts = [] 122 | bandwidths = [] 123 | for i in range(len(cfg)): 124 | delta_ts.append( 125 | sum(item.cpu_time for item in prof.function_events[:pruner.conv_n_pool_to_layer[i]]) / 126 | np.power(10, 6) / batch_size) 127 | if isinstance(cfg[i], int): 128 | bandwidths.append( 129 | int(cfg[i] * (inputs.shape[2] * inputs.shape[3]) / np.power(4, cfg[:i + 1].count('M')))) 130 | else: 131 | bandwidths.append( 132 | int(cfg[i - 1] * (inputs.shape[2] * inputs.shape[3]) / np.power(4, cfg[:i + 1].count('M')))) 133 | 134 | data = { 135 | 'acc': acc if use_cuda else -1, 136 | 'index': log_index, 137 | 'delta_t_prof': delta_ts[log_index], 138 | 'delta_ts': delta_ts, 139 | 'bandwidth': bandwidths[log_index], 140 | 'bandwidths': bandwidths, 141 | 'layer_cfg': cfg[log_index], 142 | 'config': cfg 143 | } 144 | return data 145 | 146 | return acc 147 | 148 | 149 | # save 150 | def save(acc, conv_index=-1, epoch=-1): 151 | print('Saving..') 152 | try: 153 | # save the cpu model 154 | model = net.module if isinstance(net, torch.nn.DataParallel) else net 155 | state = { 156 | 'net': model.cpu() if use_cuda else model, 157 | 'acc': acc, 158 | 'conv_index': conv_index, 159 | 'epoch': epoch, 160 | } 161 | except: 162 | pass 163 | if not os.path.isdir('checkpoint'): 164 | os.mkdir('checkpoint') 165 | if args.prune: 166 | torch.save(state, './checkpoint/ckpt.prune') 167 | elif args.prune_layer and conv_index != -1: 168 | torch.save(state, './checkpoint/ckpt.prune_layer_%d' % conv_index) 169 | elif epoch != -1: 170 | torch.save(state, './checkpoint/ckpt.train.epoch_' + str(epoch)) 171 | else: 172 | torch.save(state, './checkpoint/ckpt.train') 173 | 174 | # restore the cuda or cpu model 175 | if use_cuda: 176 | net.cuda() 177 | 178 | 179 | class FilterPruner: 180 | def __init__(self, model): 181 | self.model = model 182 | self.reset() 183 | 184 | def reset(self): 185 | self.filter_ranks = {} 186 | 187 | # forward method that gives "compute_rank" a hook 188 | def forward(self, x): 189 | self.activations = [] 190 | self.gradients = [] 191 | 
self.grad_index = 0 192 | self.activation_to_layer = {} 193 | 194 | conv_index = 0 195 | for layer, (name, module) in enumerate(self.model.features._modules.items()): 196 | x = module(x) 197 | if isinstance(module, torch.nn.modules.Conv2d): 198 | x.register_hook(self.compute_rank) 199 | self.activations.append(x) 200 | self.activation_to_layer[conv_index] = layer 201 | conv_index += 1 202 | 203 | return self.model.classifier(x.view(x.size(0), -1)) 204 | 205 | # forward method that tracks computation info 206 | def forward_n_track(self, x, log_index=-1): 207 | self.conv_n_pool_to_layer = {} 208 | 209 | index = 0 210 | delta_t_computations = 0 211 | all_conv_computations = 0 # num of conv computations to the given layer 212 | t0 = time.time() 213 | for layer, (name, module) in enumerate(self.model.features._modules.items()): 214 | x = module(x) 215 | if isinstance(module, torch.nn.modules.ReLU) or isinstance(module, torch.nn.modules.MaxPool2d): 216 | all_conv_computations += np.prod(x.data.shape[1:]) 217 | self.conv_n_pool_to_layer[index] = layer 218 | if log_index == index: 219 | delta_t = time.time() - t0 220 | delta_t_computations = all_conv_computations 221 | bandwidth = np.prod(x.data.shape[1:]) 222 | index += 1 223 | 224 | return delta_t, delta_t_computations, bandwidth, all_conv_computations 225 | 226 | # for all the conv layers 227 | def get_conv_index_max(self): 228 | conv_index = 0 229 | for layer, (name, module) in enumerate(self.model.features._modules.items()): 230 | if isinstance(module, torch.nn.modules.Conv2d): 231 | conv_index += 1 232 | return conv_index 233 | 234 | # for all the relu layers and pool2d layers 235 | def get_cfg(self): 236 | cfg = [] 237 | for layer, (name, module) in enumerate(self.model.features._modules.items()): 238 | if isinstance(module, torch.nn.modules.Conv2d): 239 | cfg.append(module.out_channels) 240 | elif isinstance(module, torch.nn.modules.MaxPool2d): 241 | cfg.append('M') 242 | return cfg 243 | 244 | def compute_rank(self, grad): 245 | conv_index = len(self.activations) - self.grad_index - 1 246 | activation = self.activations[conv_index] 247 | values = torch.sum((activation * grad), dim=0, keepdim=True).sum(dim=2, keepdim=True).sum(dim=3, keepdim=True)[ 248 | 0, :, 0, 0].data # compute the total 1st order taylor for each filters in a given layer 249 | 250 | # Normalize the rank by the filter dimensions 251 | values = values / (activation.size(0) * activation.size(2) * activation.size(3)) 252 | 253 | if conv_index not in self.filter_ranks: # set self.filter_ranks[conv_index] 254 | self.filter_ranks[conv_index] = torch.FloatTensor(activation.size(1)).zero_() 255 | if use_cuda: 256 | self.filter_ranks[conv_index] = self.filter_ranks[conv_index].cuda() 257 | 258 | self.filter_ranks[conv_index] += values 259 | self.grad_index += 1 260 | 261 | def lowest_ranking_filters(self, num, conv_index): 262 | data = [] 263 | if conv_index == -1: 264 | for i in sorted(self.filter_ranks.keys()): 265 | for j in range(self.filter_ranks[i].size(0)): 266 | data.append((self.activation_to_layer[i], j, self.filter_ranks[i][j])) 267 | else: 268 | for j in range(self.filter_ranks[conv_index].size(0)): 269 | data.append((self.activation_to_layer[conv_index], j, self.filter_ranks[conv_index][j])) 270 | return nsmallest(num, data, itemgetter(2)) # find the minimum of data[_][2], aka, self.filter_ranks[i][j] 271 | 272 | def normalize_ranks_per_layer(self): 273 | for i in self.filter_ranks: 274 | v = torch.abs(self.filter_ranks[i]) 275 | v = v / np.sqrt(torch.sum(v * v)) 276 
| self.filter_ranks[i] = v.cpu() 277 | 278 | def get_pruning_plan(self, num_filters_to_prune, conv_index): 279 | filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune, conv_index) 280 | 281 | # After each of the k filters are pruned, 282 | # the filter index of the next filters change since the model is smaller. 283 | filters_to_prune_per_layer = {} 284 | for (l, f, _) in filters_to_prune: 285 | if l not in filters_to_prune_per_layer: 286 | filters_to_prune_per_layer[l] = [] 287 | filters_to_prune_per_layer[l].append(f) 288 | 289 | for l in filters_to_prune_per_layer: 290 | filters_to_prune_per_layer[l] = sorted(filters_to_prune_per_layer[l]) 291 | for i in range(len(filters_to_prune_per_layer[l])): 292 | filters_to_prune_per_layer[l][i] = filters_to_prune_per_layer[l][i] - i 293 | 294 | filters_to_prune = [] 295 | for l in filters_to_prune_per_layer: 296 | for i in filters_to_prune_per_layer[l]: 297 | filters_to_prune.append((l, i)) 298 | 299 | return filters_to_prune 300 | 301 | def get_candidates_to_prune(self, num_filters_to_prune, conv_index): 302 | self.reset() 303 | train(rankfilters=True) 304 | self.normalize_ranks_per_layer() 305 | 306 | return self.get_pruning_plan(num_filters_to_prune, conv_index) 307 | 308 | def total_num_filters(self, conv_index): 309 | filters = 0 310 | i = 0 311 | for name, module in list(self.model.features._modules.items()): 312 | if isinstance(module, torch.nn.modules.Conv2d): 313 | if conv_index == -1: 314 | filters = filters + module.out_channels 315 | elif conv_index == i: 316 | filters = filters + module.out_channels 317 | i = i + 1 318 | 319 | return filters 320 | 321 | def prune(self, conv_index=-1): 322 | # Get the accuracy before pruning 323 | acc_pre_prune = test() 324 | acc = acc_pre_prune 325 | 326 | # train(rankfilters=True) 327 | 328 | # Make sure all the layers are trainable 329 | for param in self.model.features.parameters(): 330 | param.requires_grad = True 331 | 332 | number_of_filters = pruner.total_num_filters(conv_index) 333 | 334 | num_filters_to_prune_per_iteration = math.ceil(number_of_filters / 16) 335 | while acc > acc_pre_prune * 0.95 and pruner.total_num_filters(conv_index) / number_of_filters > 0.2: 336 | # print("Ranking filters.. ") 337 | 338 | prune_targets = pruner.get_candidates_to_prune(num_filters_to_prune_per_iteration, conv_index) 339 | num_layers_pruned = {} # filters to be pruned in each layer 340 | for layer_index, filter_index in prune_targets: 341 | if layer_index not in num_layers_pruned: 342 | num_layers_pruned[layer_index] = 0 343 | num_layers_pruned[layer_index] = num_layers_pruned[layer_index] + 1 344 | 345 | print("Layers that will be pruned", num_layers_pruned) 346 | print("..............Pruning filters............. ") 347 | if use_cuda: 348 | self.model.cpu() 349 | 350 | for layer_index, filter_index in prune_targets: 351 | prune_conv_layer(self.model, layer_index, filter_index) 352 | 353 | if use_cuda: 354 | self.model.cuda() 355 | # self.model = torch.nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count())) 356 | # cudnn.benchmark = True 357 | 358 | optimizer = torch.optim.SGD(self.model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 359 | 360 | print("%d / %d Filters remain." 
% (pruner.total_num_filters(conv_index), number_of_filters)) 361 | # test() 362 | print("Fine tuning to recover from pruning iteration.") 363 | for epoch in range(2): 364 | train(optimizer) 365 | acc = test() 366 | pass 367 | if acc <= acc_pre_prune * 0.95: 368 | pass 369 | 370 | print("Finished. Going to fine tune the model a bit more") 371 | for epoch in range(5): 372 | train(optimizer) 373 | test() 374 | pass 375 | 376 | 377 | def get_args(): 378 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 379 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 380 | parser.add_argument('--epoch', default=10, type=int, help='epoch') 381 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 382 | parser.add_argument("--train", dest="train", action="store_true") 383 | parser.add_argument("--prune", dest="prune", action="store_true") 384 | parser.add_argument("--prune_layer", dest="prune_layer", action="store_true") 385 | parser.add_argument("--test_pruned", dest="test_pruned", action="store_true") 386 | args = parser.parse_args() 387 | return args 388 | 389 | 390 | if __name__ == '__main__': 391 | args = get_args() 392 | 393 | # Model 394 | if args.train: 395 | print('==> Building model..') 396 | net = VGG('VGG16') 397 | # net = ResNet18() 398 | # net = PreActResNet18() 399 | # net = GoogLeNet() 400 | # net = DenseNet121() 401 | # net = ResNeXt29_2x64d() 402 | # net = MobileNet() 403 | # net = MobileNetV2() 404 | # net = DPN92() 405 | # net = ShuffleNetG2() 406 | # net = SENet18() 407 | else: 408 | # Load checkpoint. 409 | print('==> Resuming from checkpoint..') 410 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 411 | checkpoint = torch.load('./checkpoint/ckpt.train') 412 | net = checkpoint['net'] 413 | acc = checkpoint['acc'] 414 | start_epoch = checkpoint['epoch'] + 1 415 | 416 | if use_cuda: 417 | net.cuda() 418 | # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) 419 | # cudnn.benchmark = True 420 | 421 | criterion = nn.CrossEntropyLoss() 422 | optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 423 | 424 | pruner = FilterPruner(net.module if isinstance(net, torch.nn.DataParallel) else net) 425 | total_filter_num_pre_prune = pruner.total_num_filters(conv_index=-1) 426 | 427 | if args.prune: 428 | pruner.prune() 429 | acc = test() 430 | save(acc) 431 | pass 432 | elif args.prune_layer: 433 | # this is after --prune the whole model 434 | conv_index_max = pruner.get_conv_index_max() 435 | for conv_index in range(conv_index_max): 436 | print('==> Resuming from checkpoint..') 437 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
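# note: stage-2 always restarts from the stage-1 result (ckpt.prune); each pass of this loop
# prunes only the single conv layer `conv_index` and saves it as its own checkpoint
# (ckpt.prune_layer_<conv_index>), so every candidate split point gets a model whose layer
# just before the split is extra-thin.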
438 | checkpoint = torch.load('./checkpoint/ckpt.prune') 439 | net = checkpoint['net'] 440 | acc = checkpoint['acc'] 441 | if use_cuda: 442 | net.cuda() 443 | # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) 444 | # cudnn.benchmark = True 445 | # create new pruner in each iteration 446 | pruner = FilterPruner(net.module if isinstance(net, torch.nn.DataParallel) else net) 447 | total_filter_num_pre_prune = pruner.total_num_filters(conv_index=-1) 448 | # prune given layer 449 | pruner.prune(conv_index) 450 | acc = test() 451 | save(acc, conv_index) 452 | pass 453 | elif args.train or args.resume: 454 | for epoch in range(start_epoch, start_epoch + args.epoch): 455 | print('\nEpoch: %d' % epoch) 456 | train() 457 | acc = test() 458 | if epoch % 10 == 0: 459 | save(acc, -1, epoch) 460 | pass 461 | save(acc) 462 | elif args.test_pruned: 463 | use_cuda = 0 464 | cfg = pruner.get_cfg() 465 | conv_index_max = pruner.get_conv_index_max() 466 | original_data = [] 467 | prune_data = [] 468 | prune_layer_data = [] 469 | 470 | last_conv_index = 0 # log for checkpoint restoring, nearest conv layer 471 | for index in range(len(cfg)): 472 | # original 473 | print('==> Resuming from checkpoint..') 474 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 475 | checkpoint = torch.load('./checkpoint/ckpt.train') 476 | net = checkpoint['net'] 477 | acc = checkpoint['acc'] 478 | if use_cuda: 479 | net.cuda() 480 | # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) 481 | # cudnn.benchmark = True 482 | # create new pruner in each iteration 483 | pruner = FilterPruner(net.module if isinstance(net, torch.nn.DataParallel) else net) 484 | total_filter_num_pre_prune = pruner.total_num_filters(conv_index=-1) 485 | data = test(index) 486 | if data['acc'] == -1: 487 | data['acc'] = acc 488 | original_data.append(data) 489 | 490 | # prune 491 | print('==> Resuming from checkpoint..') 492 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 493 | checkpoint = torch.load('./checkpoint/ckpt.prune') 494 | net = checkpoint['net'] 495 | acc = checkpoint['acc'] 496 | if use_cuda: 497 | net.cuda() 498 | # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) 499 | # cudnn.benchmark = True 500 | # create new pruner in each iteration 501 | pruner = FilterPruner(net.module if isinstance(net, torch.nn.DataParallel) else net) 502 | total_filter_num_pre_prune = pruner.total_num_filters(conv_index=-1) 503 | data = test(index) 504 | if data['acc'] == -1: 505 | data['acc'] = acc 506 | prune_data.append(data) 507 | 508 | # prune_layer 509 | print('==> Resuming from checkpoint..') 510 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
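# for split point `index`, load the stage-2 checkpoint of the nearest conv layer at or
# before the split; `last_conv_index` is advanced at the bottom of this loop whenever the
# next cfg entry is a channel count rather than 'M' (a pooling layer).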
511 | checkpoint = torch.load('./checkpoint/ckpt.prune_layer_' + str(last_conv_index)) 512 | # checkpoint = torch.load('./checkpoint/ckpt.prune') 513 | net = checkpoint['net'] 514 | acc = checkpoint['acc'] 515 | if use_cuda: 516 | net.cuda() 517 | # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) 518 | # cudnn.benchmark = True 519 | # create new pruner in each iteration 520 | pruner = FilterPruner(net.module if isinstance(net, torch.nn.DataParallel) else net) 521 | total_filter_num_pre_prune = pruner.total_num_filters(conv_index=-1) 522 | data = test(index) 523 | if data['acc'] == -1: 524 | data['acc'] = acc 525 | prune_layer_data.append(data) 526 | 527 | if index + 1 < len(cfg): 528 | if not isinstance(cfg[index + 1], str): 529 | last_conv_index += 1 530 | 531 | with open('./log_original.json', 'w') as fp: 532 | json.dump(original_data, fp, indent=2) 533 | with open('./log_prune.json', 'w') as fp: 534 | json.dump(prune_data, fp, indent=2) 535 | with open('./log_prune_layer.json', 'w') as fp: 536 | json.dump(prune_layer_data, fp, indent=2) 537 | -------------------------------------------------------------------------------- /model_refactor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torchvision import models 4 | import sys 5 | import numpy as np 6 | import os 7 | import time 8 | 9 | 10 | def replace_layers(model, i, indexes, layers): 11 | if i in indexes: 12 | return layers[indexes.index(i)] 13 | return model[i] 14 | 15 | 16 | def prune_conv_layer(model, layer_index, filter_index): 17 | _, conv = list(model.features._modules.items())[layer_index] 18 | batchnorm = None 19 | next_conv = None 20 | offset = 1 21 | 22 | while layer_index + offset < len(list(model.features._modules.items())): # get next conv 23 | res = list(model.features._modules.items())[layer_index + offset] 24 | if isinstance(res[1], torch.nn.modules.conv.Conv2d): 25 | _, next_conv = res 26 | break 27 | offset = offset + 1 28 | 29 | res = list(model.features._modules.items())[layer_index + 1] 30 | if isinstance(res[1], torch.nn.modules.BatchNorm2d): 31 | _, batchnorm = res 32 | 33 | is_bias_present = False 34 | if conv.bias is not None: 35 | is_bias_present = True 36 | 37 | new_conv = \ 38 | torch.nn.Conv2d(in_channels=conv.in_channels, 39 | out_channels=conv.out_channels - 1, 40 | kernel_size=conv.kernel_size, 41 | stride=conv.stride, 42 | padding=conv.padding, 43 | dilation=conv.dilation, 44 | groups=conv.groups, 45 | bias=is_bias_present) 46 | 47 | old_weights = conv.weight.data.cpu().numpy() 48 | new_weights = new_conv.weight.data.cpu().numpy() 49 | 50 | new_weights[: filter_index, :, :, :] = old_weights[: filter_index, :, :, :] 51 | new_weights[filter_index:, :, :, :] = old_weights[filter_index + 1:, :, :, :] 52 | new_conv.weight.data = torch.from_numpy(new_weights).cuda() 53 | 54 | bias_numpy = conv.bias.data.cpu().numpy() 55 | 56 | bias = np.zeros(shape=(bias_numpy.shape[0] - 1), dtype=np.float32) 57 | bias[:filter_index] = bias_numpy[:filter_index] 58 | bias[filter_index:] = bias_numpy[filter_index + 1:] 59 | new_conv.bias.data = torch.from_numpy(bias).cuda() 60 | 61 | if next_conv is not None: 62 | is_bias_present = False 63 | if next_conv.bias is not None: 64 | is_bias_present = True 65 | next_new_conv = \ 66 | torch.nn.Conv2d(in_channels=next_conv.in_channels - 1, 67 | out_channels=next_conv.out_channels, 68 | kernel_size=next_conv.kernel_size, 69 | stride=next_conv.stride, 70 | 
padding=next_conv.padding, 71 | dilation=next_conv.dilation, 72 | groups=next_conv.groups, 73 | bias=is_bias_present) 74 | 75 | old_weights = next_conv.weight.data.cpu().numpy() 76 | new_weights = next_new_conv.weight.data.cpu().numpy() 77 | 78 | new_weights[:, : filter_index, :, :] = old_weights[:, : filter_index, :, :] 79 | new_weights[:, filter_index:, :, :] = old_weights[:, filter_index + 1:, :, :] 80 | next_new_conv.weight.data = torch.from_numpy(new_weights).cuda() 81 | if is_bias_present: next_new_conv.bias.data = next_conv.bias.data 82 | 83 | if batchnorm is not None: 84 | new_batchnorm = \ 85 | torch.nn.BatchNorm2d(conv.out_channels - 1) 86 | 87 | try: 88 | old_weights = batchnorm.weight.data.cpu().numpy() 89 | new_weights = new_batchnorm.weight.data.cpu().numpy() 90 | new_weights[:filter_index] = old_weights[:filter_index] 91 | new_weights[filter_index:] = old_weights[filter_index + 1:] 92 | new_batchnorm.weight.data = torch.from_numpy(new_weights).cuda() 93 | 94 | bias_numpy = batchnorm.bias.data.cpu().numpy() 95 | bias = np.zeros(shape=(bias_numpy.shape[0] - 1), dtype=np.float32) 96 | bias[:filter_index] = bias_numpy[:filter_index] 97 | bias[filter_index:] = bias_numpy[filter_index + 1:] 98 | new_batchnorm.bias.data = torch.from_numpy(bias).cuda() 99 | except ValueError: 100 | pass 101 | 102 | 103 | if batchnorm is not None: 104 | features = torch.nn.Sequential( 105 | *(replace_layers(model.features, i, [layer_index + 1], 106 | [new_batchnorm]) for i, _ in enumerate(model.features))) 107 | del model.features 108 | model.features = features 109 | 110 | 111 | if next_conv is not None: 112 | features = torch.nn.Sequential( 113 | *(replace_layers(model.features, i, [layer_index, layer_index + offset], 114 | [new_conv, next_new_conv]) for i, _ in enumerate(model.features))) 115 | 116 | del model.features 117 | del conv 118 | model.features = features 119 | 120 | else: 121 | # Pruning the last conv layer. This affects the first linear layer of the classifier.
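# removing one output channel of the final conv deletes `params_per_input_channel`
# consecutive input features of the classifier's first nn.Linear; the code below rebuilds
# that layer with a smaller in_features and copies every weight column except those that
# belonged to the pruned filter. (for the 32x32 CIFAR-10 inputs used here, the five 'M'
# pools leave a 1x1 feature map, so params_per_input_channel works out to 1.)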
122 | model.features = torch.nn.Sequential( 123 | *(replace_layers(model.features, i, [layer_index], 124 | [new_conv]) for i, _ in enumerate(model.features))) 125 | layer_index = 0 126 | old_linear_layer = None 127 | one_layer_classifier = False 128 | for _, module in list(model.classifier._modules.items()): 129 | if isinstance(module, torch.nn.Linear): 130 | old_linear_layer = module 131 | break 132 | layer_index = layer_index + 1 133 | 134 | if isinstance(model.classifier, torch.nn.Linear): 135 | old_linear_layer = model.classifier 136 | one_layer_classifier = True 137 | layer_index = layer_index + 1 138 | 139 | if old_linear_layer is None: 140 | raise BaseException("No linear layer found in classifier") 141 | params_per_input_channel = round(old_linear_layer.in_features / conv.out_channels) 142 | 143 | new_linear_layer = \ 144 | torch.nn.Linear(old_linear_layer.in_features - params_per_input_channel, 145 | old_linear_layer.out_features) 146 | 147 | old_weights = old_linear_layer.weight.data.cpu().numpy() 148 | new_weights = new_linear_layer.weight.data.cpu().numpy() 149 | 150 | new_weights[:, : filter_index * params_per_input_channel] = \ 151 | old_weights[:, : filter_index * params_per_input_channel] 152 | new_weights[:, filter_index * params_per_input_channel:] = \ 153 | old_weights[:, (filter_index + 1) * params_per_input_channel:] 154 | 155 | new_linear_layer.bias.data = old_linear_layer.bias.data 156 | 157 | new_linear_layer.weight.data = torch.from_numpy(new_weights).cuda() 158 | 159 | if one_layer_classifier: 160 | classifier = new_linear_layer 161 | else: 162 | classifier = torch.nn.Sequential( 163 | *(replace_layers(model.classifier, i, [layer_index], 164 | [new_linear_layer]) for i, _ in enumerate(model.classifier))) 165 | 166 | del model.classifier 167 | del next_conv 168 | del conv 169 | model.classifier = classifier 170 | 171 | return # model 172 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .dpn import * 3 | from .lenet import * 4 | from .senet import * 5 | from .pnasnet import * 6 | from .densenet import * 7 | from .googlenet import * 8 | from .shufflenet import * 9 | from .resnet import * 10 | from .resnext import * 11 | from .preact_resnet import * 12 | from .mobilenet import * 13 | from .mobilenetv2 import * 14 | -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torch.autograd import Variable 9 | 10 | 11 | class Bottleneck(nn.Module): 12 | def __init__(self, in_planes, growth_rate): 13 | super(Bottleneck, self).__init__() 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 17 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 18 | 19 | def forward(self, x): 20 | out = self.conv1(F.relu(self.bn1(x))) 21 | out = self.conv2(F.relu(self.bn2(out))) 22 | out = torch.cat([out,x], 1) 23 | return out 24 | 25 | 26 | class Transition(nn.Module): 27 | def __init__(self, in_planes, out_planes): 28 | super(Transition, self).__init__() 29 | self.bn = nn.BatchNorm2d(in_planes) 30 | self.conv 
= nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 31 | 32 | def forward(self, x): 33 | out = self.conv(F.relu(self.bn(x))) 34 | out = F.avg_pool2d(out, 2) 35 | return out 36 | 37 | 38 | class DenseNet(nn.Module): 39 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 40 | super(DenseNet, self).__init__() 41 | self.growth_rate = growth_rate 42 | 43 | num_planes = 2*growth_rate 44 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 45 | 46 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 47 | num_planes += nblocks[0]*growth_rate 48 | out_planes = int(math.floor(num_planes*reduction)) 49 | self.trans1 = Transition(num_planes, out_planes) 50 | num_planes = out_planes 51 | 52 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 53 | num_planes += nblocks[1]*growth_rate 54 | out_planes = int(math.floor(num_planes*reduction)) 55 | self.trans2 = Transition(num_planes, out_planes) 56 | num_planes = out_planes 57 | 58 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 59 | num_planes += nblocks[2]*growth_rate 60 | out_planes = int(math.floor(num_planes*reduction)) 61 | self.trans3 = Transition(num_planes, out_planes) 62 | num_planes = out_planes 63 | 64 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 65 | num_planes += nblocks[3]*growth_rate 66 | 67 | self.bn = nn.BatchNorm2d(num_planes) 68 | self.linear = nn.Linear(num_planes, num_classes) 69 | 70 | def _make_dense_layers(self, block, in_planes, nblock): 71 | layers = [] 72 | for i in range(nblock): 73 | layers.append(block(in_planes, self.growth_rate)) 74 | in_planes += self.growth_rate 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = self.conv1(x) 79 | out = self.trans1(self.dense1(out)) 80 | out = self.trans2(self.dense2(out)) 81 | out = self.trans3(self.dense3(out)) 82 | out = self.dense4(out) 83 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 84 | out = out.view(out.size(0), -1) 85 | out = self.linear(out) 86 | return out 87 | 88 | def DenseNet121(): 89 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 90 | 91 | def DenseNet169(): 92 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 93 | 94 | def DenseNet201(): 95 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 96 | 97 | def DenseNet161(): 98 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 99 | 100 | def densenet_cifar(): 101 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 102 | 103 | def test_densenet(): 104 | net = densenet_cifar() 105 | x = torch.randn(1,3,32,32) 106 | y = net(Variable(x)) 107 | print(y) 108 | 109 | # test_densenet() 110 | -------------------------------------------------------------------------------- /models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 11 | super(Bottleneck, self).__init__() 12 | self.out_planes = out_planes 13 | self.dense_depth = dense_depth 14 | 15 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 18 | 
self.bn2 = nn.BatchNorm2d(in_planes) 19 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 20 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 21 | 22 | self.shortcut = nn.Sequential() 23 | if first_layer: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(out_planes+dense_depth) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = F.relu(self.bn2(self.conv2(out))) 32 | out = self.bn3(self.conv3(out)) 33 | x = self.shortcut(x) 34 | d = self.out_planes 35 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class DPN(nn.Module): 41 | def __init__(self, cfg): 42 | super(DPN, self).__init__() 43 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 44 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 45 | 46 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(64) 48 | self.last_planes = 64 49 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 50 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 51 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 52 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 53 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 54 | 55 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for i,stride in enumerate(strides): 59 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 60 | self.last_planes = out_planes + (i+2) * dense_depth 61 | return nn.Sequential(*layers) 62 | 63 | def forward(self, x): 64 | out = F.relu(self.bn1(self.conv1(x))) 65 | out = self.layer1(out) 66 | out = self.layer2(out) 67 | out = self.layer3(out) 68 | out = self.layer4(out) 69 | out = F.avg_pool2d(out, 4) 70 | out = out.view(out.size(0), -1) 71 | out = self.linear(out) 72 | return out 73 | 74 | 75 | def DPN26(): 76 | cfg = { 77 | 'in_planes': (96,192,384,768), 78 | 'out_planes': (256,512,1024,2048), 79 | 'num_blocks': (2,2,2,2), 80 | 'dense_depth': (16,32,24,128) 81 | } 82 | return DPN(cfg) 83 | 84 | def DPN92(): 85 | cfg = { 86 | 'in_planes': (96,192,384,768), 87 | 'out_planes': (256,512,1024,2048), 88 | 'num_blocks': (3,4,20,3), 89 | 'dense_depth': (16,32,24,128) 90 | } 91 | return DPN(cfg) 92 | 93 | 94 | def test(): 95 | net = DPN92() 96 | x = Variable(torch.randn(1,3,32,32)) 97 | y = net(x) 98 | print(y) 99 | 100 | # test() 101 | -------------------------------------------------------------------------------- /models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | class Inception(nn.Module): 10 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 11 | super(Inception, self).__init__() 12 | # 1x1 conv branch 13 | self.b1 = nn.Sequential( 14 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 15 | nn.BatchNorm2d(n1x1), 16 | nn.ReLU(True), 17 | ) 18 | 19 
| # 1x1 conv -> 3x3 conv branch 20 | self.b2 = nn.Sequential( 21 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 22 | nn.BatchNorm2d(n3x3red), 23 | nn.ReLU(True), 24 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 25 | nn.BatchNorm2d(n3x3), 26 | nn.ReLU(True), 27 | ) 28 | 29 | # 1x1 conv -> 5x5 conv branch 30 | self.b3 = nn.Sequential( 31 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 32 | nn.BatchNorm2d(n5x5red), 33 | nn.ReLU(True), 34 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(n5x5), 36 | nn.ReLU(True), 37 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 38 | nn.BatchNorm2d(n5x5), 39 | nn.ReLU(True), 40 | ) 41 | 42 | # 3x3 pool -> 1x1 conv branch 43 | self.b4 = nn.Sequential( 44 | nn.MaxPool2d(3, stride=1, padding=1), 45 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 46 | nn.BatchNorm2d(pool_planes), 47 | nn.ReLU(True), 48 | ) 49 | 50 | def forward(self, x): 51 | y1 = self.b1(x) 52 | y2 = self.b2(x) 53 | y3 = self.b3(x) 54 | y4 = self.b4(x) 55 | return torch.cat([y1,y2,y3,y4], 1) 56 | 57 | 58 | class GoogLeNet(nn.Module): 59 | def __init__(self): 60 | super(GoogLeNet, self).__init__() 61 | self.pre_layers = nn.Sequential( 62 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 63 | nn.BatchNorm2d(192), 64 | nn.ReLU(True), 65 | ) 66 | 67 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 68 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 69 | 70 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 71 | 72 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 73 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 74 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 75 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 76 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 77 | 78 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 79 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 80 | 81 | self.avgpool = nn.AvgPool2d(8, stride=1) 82 | self.linear = nn.Linear(1024, 10) 83 | 84 | def forward(self, x): 85 | out = self.pre_layers(x) 86 | out = self.a3(out) 87 | out = self.b3(out) 88 | out = self.maxpool(out) 89 | out = self.a4(out) 90 | out = self.b4(out) 91 | out = self.c4(out) 92 | out = self.d4(out) 93 | out = self.e4(out) 94 | out = self.maxpool(out) 95 | out = self.a5(out) 96 | out = self.b5(out) 97 | out = self.avgpool(out) 98 | out = out.view(out.size(0), -1) 99 | out = self.linear(out) 100 | return out 101 | 102 | # net = GoogLeNet() 103 | # x = torch.randn(1,3,32,32) 104 | # y = net(Variable(x)) 105 | # print(y.size()) 106 | -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /models/mobilenet.py: 
-------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch.autograd import Variable 11 | 12 | 13 | class Block(nn.Module): 14 | '''Depthwise conv + Pointwise conv''' 15 | def __init__(self, in_planes, out_planes, stride=1): 16 | super(Block, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 20 | self.bn2 = nn.BatchNorm2d(out_planes) 21 | 22 | def forward(self, x): 23 | out = F.relu(self.bn1(self.conv1(x))) 24 | out = F.relu(self.bn2(self.conv2(out))) 25 | return out 26 | 27 | 28 | class MobileNet(nn.Module): 29 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 30 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 31 | 32 | def __init__(self, num_classes=10): 33 | super(MobileNet, self).__init__() 34 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 35 | self.bn1 = nn.BatchNorm2d(32) 36 | self.layers = self._make_layers(in_planes=32) 37 | self.linear = nn.Linear(1024, num_classes) 38 | 39 | def _make_layers(self, in_planes): 40 | layers = [] 41 | for x in self.cfg: 42 | out_planes = x if isinstance(x, int) else x[0] 43 | stride = 1 if isinstance(x, int) else x[1] 44 | layers.append(Block(in_planes, out_planes, stride)) 45 | in_planes = out_planes 46 | return nn.Sequential(*layers) 47 | 48 | def forward(self, x): 49 | out = F.relu(self.bn1(self.conv1(x))) 50 | out = self.layers(out) 51 | out = F.avg_pool2d(out, 2) 52 | out = out.view(out.size(0), -1) 53 | out = self.linear(out) 54 | return out 55 | 56 | 57 | def test(): 58 | net = MobileNet() 59 | x = torch.randn(1,3,32,32) 60 | y = net(Variable(x)) 61 | print(y.size()) 62 | 63 | # test() 64 | -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch.autograd import Variable 11 | 12 | 13 | class Block(nn.Module): 14 | '''expand + depthwise + pointwise''' 15 | def __init__(self, in_planes, out_planes, expansion, stride): 16 | super(Block, self).__init__() 17 | self.stride = stride 18 | 19 | planes = expansion * in_planes 20 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 25 | self.bn3 = nn.BatchNorm2d(out_planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride == 1 and in_planes != out_planes: 29 | self.shortcut = nn.Sequential( 30 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 31 | nn.BatchNorm2d(out_planes), 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = F.relu(self.bn2(self.conv2(out))) 37 | out = self.bn3(self.conv3(out)) 38 | out = out + self.shortcut(x) if self.stride==1 else out 39 | return out 40 | 41 | 42 | class MobileNetV2(nn.Module): 43 | # (expansion, out_planes, num_blocks, stride) 44 | cfg = [(1, 16, 1, 1), 45 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 46 | (6, 32, 3, 2), 47 | (6, 64, 4, 2), 48 | (6, 96, 3, 1), 49 | (6, 160, 3, 2), 50 | (6, 320, 1, 1)] 51 | 52 | def __init__(self, num_classes=10): 53 | super(MobileNetV2, self).__init__() 54 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 55 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 56 | self.bn1 = nn.BatchNorm2d(32) 57 | self.layers = self._make_layers(in_planes=32) 58 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 59 | self.bn2 = nn.BatchNorm2d(1280) 60 | self.linear = nn.Linear(1280, num_classes) 61 | 62 | def _make_layers(self, in_planes): 63 | layers = [] 64 | for expansion, out_planes, num_blocks, stride in self.cfg: 65 | strides = [stride] + [1]*(num_blocks-1) 66 | for stride in strides: 67 | layers.append(Block(in_planes, out_planes, expansion, stride)) 68 | in_planes = out_planes 69 | return nn.Sequential(*layers) 70 | 71 | def forward(self, x): 72 | out = F.relu(self.bn1(self.conv1(x))) 73 | out = self.layers(out) 74 | out = F.relu(self.bn2(self.conv2(out))) 75 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 76 | out = F.avg_pool2d(out, 4) 77 | out = out.view(out.size(0), -1) 78 | out = self.linear(out) 79 | return out 80 | 81 | 82 | def test(): 83 | net = MobileNetV2() 84 | x = Variable(torch.randn(2,3,32,32)) 85 | y = net(x) 86 | print(y.size()) 87 | 88 | # test() 89 | -------------------------------------------------------------------------------- /models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 
2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torch.autograd import Variable 10 | 11 | 12 | class SepConv(nn.Module): 13 | '''Separable Convolution.''' 14 | def __init__(self, in_planes, out_planes, kernel_size, stride): 15 | super(SepConv, self).__init__() 16 | self.conv1 = nn.Conv2d(in_planes, out_planes, 17 | kernel_size, stride, 18 | padding=(kernel_size-1)//2, 19 | bias=False, groups=in_planes) 20 | self.bn1 = nn.BatchNorm2d(out_planes) 21 | 22 | def forward(self, x): 23 | return self.bn1(self.conv1(x)) 24 | 25 | 26 | class CellA(nn.Module): 27 | def __init__(self, in_planes, out_planes, stride=1): 28 | super(CellA, self).__init__() 29 | self.stride = stride 30 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 31 | if stride==2: 32 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 33 | self.bn1 = nn.BatchNorm2d(out_planes) 34 | 35 | def forward(self, x): 36 | y1 = self.sep_conv1(x) 37 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 38 | if self.stride==2: 39 | y2 = self.bn1(self.conv1(y2)) 40 | return F.relu(y1+y2) 41 | 42 | class CellB(nn.Module): 43 | def __init__(self, in_planes, out_planes, stride=1): 44 | super(CellB, self).__init__() 45 | self.stride = stride 46 | # Left branch 47 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 48 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 49 | # Right branch 50 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 51 | if stride==2: 52 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 53 | self.bn1 = nn.BatchNorm2d(out_planes) 54 | # Reduce channels 55 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 56 | self.bn2 = nn.BatchNorm2d(out_planes) 57 | 58 | def forward(self, x): 59 | # Left branch 60 | y1 = self.sep_conv1(x) 61 | y2 = self.sep_conv2(x) 62 | # Right branch 63 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 64 | if self.stride==2: 65 | y3 = self.bn1(self.conv1(y3)) 66 | y4 = self.sep_conv3(x) 67 | # Concat & reduce channels 68 | b1 = F.relu(y1+y2) 69 | b2 = F.relu(y3+y4) 70 | y = torch.cat([b1,b2], 1) 71 | return F.relu(self.bn2(self.conv2(y))) 72 | 73 | class PNASNet(nn.Module): 74 | def __init__(self, cell_type, num_cells, num_planes): 75 | super(PNASNet, self).__init__() 76 | self.in_planes = num_planes 77 | self.cell_type = cell_type 78 | 79 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(num_planes) 81 | 82 | self.layer1 = self._make_layer(num_planes, num_cells=6) 83 | self.layer2 = self._downsample(num_planes*2) 84 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 85 | self.layer4 = self._downsample(num_planes*4) 86 | self.layer5 = self._make_layer(num_planes*4, num_cells=6) 87 | 88 | self.linear = nn.Linear(num_planes*4, 10) 89 | 90 | def _make_layer(self, planes, num_cells): 91 | layers = [] 92 | for _ in range(num_cells): 93 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 94 | self.in_planes = planes 95 | return nn.Sequential(*layers) 96 | 97 | def _downsample(self, planes): 98 | layer = self.cell_type(self.in_planes, planes, stride=2) 99 | self.in_planes = planes 100 | return layer 101 | 102 | def forward(self, x): 103 
| out = F.relu(self.bn1(self.conv1(x))) 104 | out = self.layer1(out) 105 | out = self.layer2(out) 106 | out = self.layer3(out) 107 | out = self.layer4(out) 108 | out = self.layer5(out) 109 | out = F.avg_pool2d(out, 8) 110 | out = self.linear(out.view(out.size(0), -1)) 111 | return out 112 | 113 | 114 | def PNASNetA(): 115 | return PNASNet(CellA, num_cells=6, num_planes=44) 116 | 117 | def PNASNetB(): 118 | return PNASNet(CellB, num_cells=6, num_planes=32) 119 | 120 | 121 | def test(): 122 | net = PNASNetB() 123 | print(net) 124 | x = Variable(torch.randn(1,3,32,32)) 125 | y = net(x) 126 | print(y) 127 | 128 | # test() 129 | -------------------------------------------------------------------------------- /models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from torch.autograd import Variable 12 | 13 | 14 | class PreActBlock(nn.Module): 15 | '''Pre-activation version of the BasicBlock.''' 16 | expansion = 1 17 | 18 | def __init__(self, in_planes, planes, stride=1): 19 | super(PreActBlock, self).__init__() 20 | self.bn1 = nn.BatchNorm2d(in_planes) 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(x)) 32 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 33 | out = self.conv1(out) 34 | out = self.conv2(F.relu(self.bn2(out))) 35 | out += shortcut 36 | return out 37 | 38 | 39 | class PreActBottleneck(nn.Module): 40 | '''Pre-activation version of the original Bottleneck module.''' 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(PreActBottleneck, self).__init__() 45 | self.bn1 = nn.BatchNorm2d(in_planes) 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 51 | 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(x)) 59 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 60 | out = self.conv1(out) 61 | out = self.conv2(F.relu(self.bn2(out))) 62 | out = self.conv3(F.relu(self.bn3(out))) 63 | out += shortcut 64 | return out 65 | 66 | 67 | class PreActResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(PreActResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 74 | self.layer2 = self._make_layer(block, 128, 
num_blocks[1], stride=2) 75 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 76 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 77 | self.linear = nn.Linear(512*block.expansion, num_classes) 78 | 79 | def _make_layer(self, block, planes, num_blocks, stride): 80 | strides = [stride] + [1]*(num_blocks-1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_planes, planes, stride)) 84 | self.in_planes = planes * block.expansion 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | out = self.conv1(x) 89 | out = self.layer1(out) 90 | out = self.layer2(out) 91 | out = self.layer3(out) 92 | out = self.layer4(out) 93 | out = F.avg_pool2d(out, 4) 94 | out = out.view(out.size(0), -1) 95 | out = self.linear(out) 96 | return out 97 | 98 | 99 | def PreActResNet18(): 100 | return PreActResNet(PreActBlock, [2,2,2,2]) 101 | 102 | def PreActResNet34(): 103 | return PreActResNet(PreActBlock, [3,4,6,3]) 104 | 105 | def PreActResNet50(): 106 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 107 | 108 | def PreActResNet101(): 109 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 110 | 111 | def PreActResNet152(): 112 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 113 | 114 | 115 | def test(): 116 | net = PreActResNet18() 117 | y = net(Variable(torch.randn(1,3,32,32))) 118 | print(y.size()) 119 | 120 | # test() 121 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch.autograd import Variable 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion*planes) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = self.bn2(self.conv2(out)) 36 | out += self.shortcut(x) 37 | out = F.relu(out) 38 | return out 39 | 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, in_planes, planes, stride=1): 45 | super(Bottleneck, self).__init__() 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 52 | 53 | self.shortcut = nn.Sequential() 54 | if stride != 1 or in_planes != self.expansion*planes: 55 | self.shortcut = nn.Sequential( 56 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 57 | nn.BatchNorm2d(self.expansion*planes) 58 | ) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.bn1(self.conv1(x))) 62 | out = F.relu(self.bn2(self.conv2(out))) 63 | out = self.bn3(self.conv3(out)) 64 | out += self.shortcut(x) 65 | out = F.relu(out) 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, block, num_blocks, num_classes=10): 71 | super(ResNet, self).__init__() 72 | self.in_planes = 64 73 | 74 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 77 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 78 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 79 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 80 | self.linear = nn.Linear(512*block.expansion, num_classes) 81 | 82 | def _make_layer(self, block, planes, num_blocks, stride): 83 | strides = [stride] + [1]*(num_blocks-1) 84 | layers = [] 85 | for stride in strides: 86 | layers.append(block(self.in_planes, planes, stride)) 87 | self.in_planes = planes * block.expansion 88 | return nn.Sequential(*layers) 89 | 90 | def forward(self, x): 91 | out = F.relu(self.bn1(self.conv1(x))) 92 | out = self.layer1(out) 93 | out = self.layer2(out) 94 | out = self.layer3(out) 95 | out = self.layer4(out) 96 | out = F.avg_pool2d(out, 4) 97 | out = out.view(out.size(0), -1) 98 | out = self.linear(out) 99 | return out 100 | 101 | 102 | def ResNet18(): 103 | return ResNet(BasicBlock, [2,2,2,2]) 104 | 105 | def ResNet34(): 106 | return ResNet(BasicBlock, [3,4,6,3]) 107 | 108 | def ResNet50(): 109 | return ResNet(Bottleneck, [3,4,6,3]) 
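# Usage sketch (hedged; mirrors the commented-out test() helper at the bottom of
# this file): the ResNet-N factories only vary the block type and per-stage block
# counts, and the classifier input width is 512*block.expansion
# (512 for BasicBlock, 2048 for Bottleneck). For example:
#   net = ResNet50()
#   y = net(Variable(torch.randn(1, 3, 32, 32)))  # -> torch.Size([1, 10])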
110 | 111 | def ResNet101(): 112 | return ResNet(Bottleneck, [3,4,23,3]) 113 | 114 | def ResNet152(): 115 | return ResNet(Bottleneck, [3,8,36,3]) 116 | 117 | 118 | def test(): 119 | net = ResNet18() 120 | y = net(Variable(torch.randn(1,3,32,32))) 121 | print(y.size()) 122 | 123 | # test() 124 | -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torch.autograd import Variable 10 | 11 | 12 | class Block(nn.Module): 13 | '''Grouped convolution block.''' 14 | expansion = 2 15 | 16 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 17 | super(Block, self).__init__() 18 | group_width = cardinality * bottleneck_width 19 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(group_width) 21 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 22 | self.bn2 = nn.BatchNorm2d(group_width) 23 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 24 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*group_width: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion*group_width) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = F.relu(self.bn2(self.conv2(out))) 36 | out = self.bn3(self.conv3(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class ResNeXt(nn.Module): 43 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 44 | super(ResNeXt, self).__init__() 45 | self.cardinality = cardinality 46 | self.bottleneck_width = bottleneck_width 47 | self.in_planes = 64 48 | 49 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(64) 51 | self.layer1 = self._make_layer(num_blocks[0], 1) 52 | self.layer2 = self._make_layer(num_blocks[1], 2) 53 | self.layer3 = self._make_layer(num_blocks[2], 2) 54 | # self.layer4 = self._make_layer(num_blocks[3], 2) 55 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 56 | 57 | def _make_layer(self, num_blocks, stride): 58 | strides = [stride] + [1]*(num_blocks-1) 59 | layers = [] 60 | for stride in strides: 61 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 62 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 63 | # Increase bottleneck_width by 2 after each stage. 
        self.bottleneck_width *= 2
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # out = self.layer4(out)
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNeXt29_2x64d():
    return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64)

def ResNeXt29_4x64d():
    return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64)

def ResNeXt29_8x64d():
    return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64)

def ResNeXt29_32x4d():
    return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4)

def test_resnext():
    net = ResNeXt29_2x64d()
    x = torch.randn(1,3,32,32)
    y = net(Variable(x))
    print(y.size())

# test_resnext()
--------------------------------------------------------------------------------
/models/senet.py:
--------------------------------------------------------------------------------
'''SENet in PyTorch.

SENet is the winner of the ImageNet-2017 (ILSVRC 2017) classification task.
Reference: "Squeeze-and-Excitation Networks", arXiv:1709.01507.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


class BasicBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

        # SE layers
        self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)  # Use nn.Conv2d instead of nn.Linear
        self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Squeeze: global average pooling collapses each channel to one value.
        w = F.avg_pool2d(out, out.size(2))
        w = F.relu(self.fc1(w))
        w = F.sigmoid(self.fc2(w))
        # Excitation: rescale each channel by its learned gate.
        out = out * w  # New broadcasting feature from v0.2!

        out += self.shortcut(x)
        out = F.relu(out)
        return out


class PreActBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
            )

        # SE layers
        self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
        self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))

        # Squeeze
        w = F.avg_pool2d(out, out.size(2))
        w = F.relu(self.fc1(w))
        w = F.sigmoid(self.fc2(w))
        # Excitation
        out = out * w

        out += shortcut
        return out


class SENet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(SENet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def SENet18():
    return SENet(PreActBlock, [2,2,2,2])


def test():
    net = SENet18()
    y = net(Variable(torch.randn(1,3,32,32)))
    print(y.size())

# test()
--------------------------------------------------------------------------------
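The squeeze-and-excitation step in the SENet blocks above is easy to sanity-check in isolation. The following is a minimal sketch, assuming the same `fc1`/`fc2` shapes as `BasicBlock`'s SE layers; the 64-channel 8x8 feature map is an arbitrary example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

planes = 64
fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)  # squeeze to planes/16 channels
fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)  # excite back to planes channels

out = torch.randn(1, planes, 8, 8)
w = F.avg_pool2d(out, out.size(2))      # global average pool -> [1, 64, 1, 1]
w = torch.sigmoid(fc2(F.relu(fc1(w))))  # per-channel gates in (0, 1)
print((out * w).size())                 # broadcasting keeps [1, 64, 8, 8]
```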
/models/shufflenet.py:
--------------------------------------------------------------------------------
'''ShuffleNet in PyTorch.

See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


class ShuffleBlock(nn.Module):
    def __init__(self, groups):
        super(ShuffleBlock, self).__init__()
        self.groups = groups

    def forward(self, x):
        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
        N,C,H,W = x.size()
        g = self.groups
        # Use integer division: C/g yields a float under Python 3 and breaks view().
        return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).contiguous().view(N,C,H,W)


class Bottleneck(nn.Module):
    def __init__(self, in_planes, out_planes, stride, groups):
        super(Bottleneck, self).__init__()
        self.stride = stride

        # Integer division keeps mid_planes an int under Python 3.
        mid_planes = out_planes//4
        g = 1 if in_planes==24 else groups
        self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.shuffle1 = ShuffleBlock(groups=g)
        self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_planes)
        self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride == 2:
            self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.shuffle1(out)
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        res = self.shortcut(x)
        out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res)
        return out


class ShuffleNet(nn.Module):
    def __init__(self, cfg):
        super(ShuffleNet, self).__init__()
        out_planes = cfg['out_planes']
        num_blocks = cfg['num_blocks']
        groups = cfg['groups']

        self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.in_planes = 24
        self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
        self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
        self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
        self.linear = nn.Linear(out_planes[2], 10)

    def _make_layer(self, out_planes, num_blocks, groups):
        layers = []
        for i in range(num_blocks):
            stride = 2 if i == 0 else 1
            cat_planes = self.in_planes if i == 0 else 0
            layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups))
            self.in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ShuffleNetG2():
    cfg = {
        'out_planes': [200,400,800],
        'num_blocks': [4,8,4],
        'groups': 2
    }
    return ShuffleNet(cfg)

def ShuffleNetG3():
    cfg = {
        'out_planes': [240,480,960],
        'num_blocks': [4,8,4],
        'groups': 3
    }
    return ShuffleNet(cfg)


def test():
    net = ShuffleNetG2()
    x = Variable(torch.randn(1,3,32,32))
    y = net(x)
    print(y)

# test()
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
'''VGG11/13/16/19 in PyTorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable


cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                # Each integer entry is a conv layer's output width; 'M' is a 2x2 max-pool.
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of a dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimicking xlua.progress.
'''
import os
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal(m.weight, mode='fan_out')
            # Check against None: truth-testing a parameter tensor is ambiguous.
            if m.bias is not None:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant(m.bias, 0)
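# Usage sketch (hedged; VGG is defined in models/vgg.py above, not imported here):
#   net = VGG('VGG16')
#   init_params(net)  # Kaiming init for convs, constants for BN, small-normal for linears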

try:
    _, term_width = os.popen('stty size', 'r').read().split()
    term_width = int(term_width)
except ValueError:
    # 'stty size' is unavailable on Windows and in non-tty shells; fall back to a fixed width.
    term_width = 80

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    # Emit at most the two largest non-zero units, e.g. '1h2m' or '3s141ms'.
    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
--------------------------------------------------------------------------------
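Putting the pieces together, a minimal smoke test might look like the sketch below. This is a hedged example, assuming it is run from the repo root so that `models` and `utils` are importable; the batch size and step count are illustrative only:

```python
# Minimal smoke test: build the VGG-16 used in this repo, run a dummy
# CIFAR-sized batch through it, and exercise the helpers from utils.py.
import time

import torch
from torch.autograd import Variable  # pytorch 0.3.x-era API, matching this repo

from models.vgg import VGG
from utils import progress_bar, format_time

net = VGG('VGG16')                       # cfg-driven construction; 'M' entries are max-pools
x = Variable(torch.randn(2, 3, 32, 32))  # CIFAR-10-shaped dummy batch
start = time.time()
y = net(x)
print(y.size())                          # expected: torch.Size([2, 10])
print('forward took %s' % format_time(time.time() - start))

for i in range(10):                      # fake 10-step "epoch" to show the bar
    progress_bar(i, 10, msg='dummy step')
```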