├── LICENSE ├── README.md └── inception_v1.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 antspy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # inception_v1.pytorch 2 | An implementation of inception_v1 on pytorch with pretrained weights. 3 | 4 | This code is a pytorch translation of the soumith torch repo: https://github.com/soumith/inception.torch 5 | It implements the original version of the inception architecture; what is known as GoogLeNet. 6 | 7 | Pretrained weights can be found at https://mega.nz/#!4RJE1SSY!kcCDyhkum6EQqVtqTc-deHnQuckM3zYSYq16bADbfww 8 | 9 | # Disclaimer 10 | Test accuracy of the pretrained model on imagenet is only 26.38%.
If I am not mistaken, this is an issue of the original torch repo - the data loading is done correctly. If you train this model to better accuracy, I would love to get the new set of weights! 11 | 12 | # License 13 | The code is licensed under the MIT Licence. See the [LICENSE](LICENSE) file for detail. 14 | -------------------------------------------------------------------------------- /inception_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import os 6 | 7 | # Data handling: 8 | # normalize = transforms.Normalize(mean=[0.4588, 0.4588, 0.4588], 9 | # std=[1, 1, 1]) 10 | # ... 11 | # val_loader = torch.utils.data.DataLoader( 12 | # datasets.ImageFolder(valdir, transforms.Compose([ 13 | # transforms.Scale(256), 14 | # transforms.CenterCrop(224), 15 | # transforms.ToTensor(), 16 | # normalize 17 | # ])), 18 | # batch_size=args.batch_size, shuffle=False, 19 | # num_workers=args.workers, pin_memory=True) 20 | 21 | def layer_init(m): 22 | classname = m.__class__.__name__ 23 | classname = classname.lower() 24 | if classname.find('conv') != -1 or classname.find('linear') != -1: 25 | gain = nn.init.calculate_gain(classname) 26 | nn.init.xavier_uniform(m.weight, gain=gain) 27 | if m.bias is not None: 28 | nn.init.constant(m.bias, 0) 29 | elif classname.find('batchnorm') != -1: 30 | nn.init.constant(m.weight, 1) 31 | if m.bias is not None: 32 | nn.init.constant(m.bias, 0) 33 | elif classname.find('embedding') != -1: 34 | # The default initializer in the TensorFlow embedding layer is a truncated normal with mean 0 and 35 | # standard deviation 1/sqrt(sparse_id_column.length). 
Here we use a normal truncated with 3 std dev 36 | num_columns = m.weight.size(1) 37 | sigma = 1/(num_columns**0.5) 38 | m.weight.data.normal_(0, sigma).clamp_(-3*sigma, 3*sigma) 39 | 40 | class LRN(nn.Module): 41 | 42 | ''' 43 | Implementing Local Response Normalization layer. Implemention adapted 44 | from https://github.com/jiecaoyu/pytorch_imagenet/blob/master/networks/model_list/alexnet.py 45 | ''' 46 | 47 | def __init__(self, local_size=1, alpha=1.0, beta=0.75, k=1, ACROSS_CHANNELS=False): 48 | super(LRN, self).__init__() 49 | self.ACROSS_CHANNELS = ACROSS_CHANNELS 50 | if ACROSS_CHANNELS: 51 | self.average=nn.AvgPool3d(kernel_size=(local_size, 1, 1), 52 | stride=1, 53 | padding=(int((local_size-1.0)/2), 0, 0)) 54 | else: 55 | self.average=nn.AvgPool2d(kernel_size=local_size, 56 | stride=1, 57 | padding=int((local_size-1.0)/2)) 58 | self.alpha = alpha 59 | self.beta = beta 60 | self.k = k 61 | 62 | def forward(self, x): 63 | if self.ACROSS_CHANNELS: 64 | div = x.pow(2).unsqueeze(1) 65 | div = self.average(div).squeeze(1) 66 | div = div.mul(self.alpha).add(self.k).pow(self.beta) 67 | else: 68 | div = x.pow(2) 69 | div = self.average(div) 70 | div = div.mul(self.alpha).add(self.k).pow(self.beta) 71 | x = x.div(div) 72 | return x 73 | 74 | class Inception_base(nn.Module): 75 | def __init__(self, depth_dim, input_size, config): 76 | super(Inception_base, self).__init__() 77 | 78 | self.depth_dim = depth_dim 79 | 80 | #mixed 'name'_1x1 81 | self.conv1 = nn.Conv2d(input_size, out_channels=config[0][0], kernel_size=1, stride=1, padding=0) 82 | 83 | #mixed 'name'_3x3_bottleneck 84 | self.conv3_1 = nn.Conv2d(input_size, out_channels=config[1][0], kernel_size=1, stride=1, padding=0) 85 | #mixed 'name'_3x3 86 | self.conv3_3 = nn.Conv2d(config[1][0], config[1][1], kernel_size=3, stride=1, padding=1) 87 | 88 | # mixed 'name'_5x5_bottleneck 89 | self.conv5_1 = nn.Conv2d(input_size, out_channels=config[2][0], kernel_size=1, stride=1, padding=0) 90 | # mixed 'name'_5x5 91 | 
self.conv5_5 = nn.Conv2d(config[2][0], config[2][1], kernel_size=5, stride=1, padding=2) 92 | 93 | self.max_pool_1 = nn.MaxPool2d(kernel_size=config[3][0], stride=1, padding=1) 94 | #mixed 'name'_pool_reduce 95 | self.conv_max_1 = nn.Conv2d(input_size, out_channels=config[3][1], kernel_size=1, stride=1, padding=0) 96 | 97 | self.apply(layer_init) 98 | 99 | def forward(self, input): 100 | 101 | output1 = F.relu(self.conv1(input)) 102 | 103 | output2 = F.relu(self.conv3_1(input)) 104 | output2 = F.relu(self.conv3_3(output2)) 105 | 106 | output3 = F.relu(self.conv5_1(input)) 107 | output3 = F.relu(self.conv5_5(output3)) 108 | 109 | output4 = F.relu(self.conv_max_1(self.max_pool_1(input))) 110 | 111 | return torch.cat([output1, output2, output3, output4], dim=self.depth_dim) 112 | 113 | # weights available at t https://github.com/antspy/inception_v1.pytorch 114 | class Inception_v1(nn.Module): 115 | def __init__(self, num_classes=1000): 116 | super(Inception_v1, self).__init__() 117 | 118 | #conv2d0 119 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 120 | self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 121 | self.lrn1 = LRN(local_size=11, alpha=0.00109999999404, beta=0.5, k=2) 122 | 123 | #conv2d1 124 | self.conv2 = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0) 125 | 126 | #conv2d2 127 | self.conv3 = nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1) 128 | self.lrn3 = LRN(local_size=11, alpha=0.00109999999404, beta=0.5, k=2) 129 | self.max_pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 130 | 131 | self.inception_3a = Inception_base(1, 192, [[64], [96,128], [16, 32], [3, 32]]) #3a 132 | self.inception_3b = Inception_base(1, 256, [[128], [128,192], [32, 96], [3, 64]]) #3b 133 | self.max_pool_inc3= nn.MaxPool2d(kernel_size=3, stride=2, padding=0) 134 | 135 | self.inception_4a = Inception_base(1, 480, [[192], [ 96,204], [16, 48], [3, 64]]) #4a 136 | self.inception_4b = Inception_base(1, 508, [[160], [112,224], 
[24, 64], [3, 64]]) #4b 137 | self.inception_4c = Inception_base(1, 512, [[128], [128,256], [24, 64], [3, 64]]) #4c 138 | self.inception_4d = Inception_base(1, 512, [[112], [144,288], [32, 64], [3, 64]]) #4d 139 | self.inception_4e = Inception_base(1, 528, [[256], [160,320], [32,128], [3,128]]) #4e 140 | self.max_pool_inc4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 141 | 142 | self.inception_5a = Inception_base(1, 832, [[256], [160,320], [48,128], [3,128]]) #5a 143 | self.inception_5b = Inception_base(1, 832, [[384], [192,384], [48,128], [3,128]]) #5b 144 | self.avg_pool5 = nn.AvgPool2d(kernel_size=7, stride=1, padding=0) 145 | 146 | self.dropout_layer = nn.Dropout(0.4) 147 | self.fc = nn.Linear(1024, num_classes) 148 | 149 | self.apply(layer_init) 150 | 151 | def forward(self, input): 152 | 153 | output = self.max_pool1(F.relu(self.conv1(input))) 154 | output = self.lrn1(output) 155 | 156 | output = F.relu(self.conv2(output)) 157 | output = F.relu(self.conv3(output)) 158 | output = self.max_pool3(self.lrn3(output)) 159 | 160 | output = self.inception_3a(output) 161 | output = self.inception_3b(output) 162 | output = self.max_pool_inc3(output) 163 | 164 | output = self.inception_4a(output) 165 | output = self.inception_4b(output) 166 | output = self.inception_4c(output) 167 | output = self.inception_4d(output) 168 | output = self.inception_4e(output) 169 | output = self.max_pool_inc4(output) 170 | 171 | output = self.inception_5a(output) 172 | output = self.inception_5b(output) 173 | output = self.avg_pool5(output) 174 | 175 | output = output.view(-1, 1024) 176 | 177 | if self.fc is not None: 178 | output = self.dropout_layer(output) 179 | output = self.fc(output) 180 | 181 | return output 182 | 183 | 184 | def inception_v1_pretrained(path_to_weights='default'): 185 | model = Inception_v1(num_classes=1000) 186 | if path_to_weights.lower() == 'default': 187 | _currDir = os.path.dirname(os.path.abspath(__file__)) 188 | path_to_weights = os.path.join(_currDir, 
'inception_v1_weights.pth') 189 | model.load_state_dict(torch.load(path_to_weights)) 190 | return model 191 | 192 | 193 | 194 | 195 | # ==================== Code used to load the weights from torch dump ========================== 196 | ind = {0: 278, 1: 212, 2: 250, 3: 193, 4: 217, 5: 147, 6: 387, 7: 285, 8: 350, 9: 283, 10: 286, 11: 353, 197 | 12: 334, 13: 150, 14: 249, 15: 362, 16: 246, 17: 166, 18: 218, 19: 172, 20: 177, 21: 148, 22: 357, 198 | 23: 386, 24: 178, 25: 202, 26: 194, 27: 271, 28: 229, 29: 290, 30: 175, 31: 163, 32: 191, 33: 276, 199 | 34: 299, 35: 197, 36: 380, 37: 364, 38: 339, 39: 359, 40: 251, 41: 165, 42: 157, 43: 361, 44: 179, 200 | 45: 268, 46: 233, 47: 356, 48: 266, 49: 264, 50: 225, 51: 349, 52: 335, 53: 375, 54: 282, 55: 204, 201 | 56: 352, 57: 272, 58: 187, 59: 256, 60: 294, 61: 277, 62: 174, 63: 234, 64: 351, 65: 176, 66: 280, 202 | 67: 223, 68: 154, 69: 262, 70: 203, 71: 190, 72: 370, 73: 298, 74: 384, 75: 292, 76: 170, 77: 342, 203 | 78: 241, 79: 340, 80: 348, 81: 245, 82: 365, 83: 253, 84: 288, 85: 239, 86: 153, 87: 185, 88: 158, 204 | 89: 211, 90: 192, 91: 382, 92: 224, 93: 216, 94: 284, 95: 367, 96: 228, 97: 160, 98: 152, 99: 376, 205 | 100: 338, 101: 270, 102: 296, 103: 366, 104: 169, 105: 265, 106: 183, 107: 345, 108: 199, 109: 244, 206 | 110: 381, 111: 236, 112: 195, 113: 238, 114: 240, 115: 155, 116: 221, 117: 259, 118: 181, 119: 343, 207 | 120: 354, 121: 369, 122: 196, 123: 231, 124: 207, 125: 184, 126: 252, 127: 232, 128: 331, 129: 242, 208 | 130: 201, 131: 162, 132: 255, 133: 210, 134: 371, 135: 274, 136: 372, 137: 373, 138: 209, 139: 243, 209 | 140: 222, 141: 378, 142: 254, 143: 206, 144: 186, 145: 205, 146: 341, 147: 261, 148: 248, 149: 215, 210 | 150: 267, 151: 189, 152: 289, 153: 214, 154: 273, 155: 198, 156: 333, 157: 200, 158: 279, 159: 188, 211 | 160: 161, 161: 346, 162: 295, 163: 332, 164: 347, 165: 379, 166: 344, 167: 260, 168: 388, 169: 180, 212 | 170: 230, 171: 257, 172: 151, 173: 281, 174: 377, 175: 208, 176: 247, 
177: 363, 178: 258, 179: 164, 213 | 180: 168, 181: 358, 182: 336, 183: 227, 184: 368, 185: 355, 186: 237, 187: 330, 188: 171, 189: 291, 214 | 190: 219, 191: 213, 192: 149, 193: 385, 194: 337, 195: 220, 196: 263, 197: 156, 198: 383, 199: 159, 215 | 200: 287, 201: 275, 202: 374, 203: 173, 204: 269, 205: 293, 206: 167, 207: 226, 208: 297, 209: 182, 216 | 210: 235, 211: 360, 212: 105, 213: 101, 214: 102, 215: 104, 216: 103, 217: 106, 218: 763, 219: 879, 217 | 220: 780, 221: 805, 222: 401, 223: 310, 224: 327, 225: 117, 226: 579, 227: 620, 228: 949, 229: 404, 218 | 230: 895, 231: 405, 232: 417, 233: 812, 234: 554, 235: 576, 236: 814, 237: 625, 238: 472, 239: 914, 219 | 240: 484, 241: 871, 242: 510, 243: 628, 244: 724, 245: 403, 246: 833, 247: 913, 248: 586, 249: 847, 220 | 250: 657, 251: 450, 252: 537, 253: 444, 254: 671, 255: 565, 256: 705, 257: 428, 258: 791, 259: 670, 221 | 260: 561, 261: 547, 262: 820, 263: 408, 264: 407, 265: 436, 266: 468, 267: 511, 268: 609, 269: 627, 222 | 270: 656, 271: 661, 272: 751, 273: 817, 274: 573, 275: 575, 276: 665, 277: 803, 278: 555, 279: 569, 223 | 280: 717, 281: 864, 282: 867, 283: 675, 284: 734, 285: 757, 286: 829, 287: 802, 288: 866, 289: 660, 224 | 290: 870, 291: 880, 292: 603, 293: 612, 294: 690, 295: 431, 296: 516, 297: 520, 298: 564, 299: 453, 225 | 300: 495, 301: 648, 302: 493, 303: 846, 304: 553, 305: 703, 306: 423, 307: 857, 308: 559, 309: 765, 226 | 310: 831, 311: 861, 312: 526, 313: 736, 314: 532, 315: 548, 316: 894, 317: 948, 318: 950, 319: 951, 227 | 320: 952, 321: 953, 322: 954, 323: 955, 324: 956, 325: 957, 326: 988, 327: 989, 328: 998, 329: 984, 228 | 330: 987, 331: 990, 332: 687, 333: 881, 334: 494, 335: 541, 336: 577, 337: 641, 338: 642, 339: 822, 229 | 340: 420, 341: 486, 342: 889, 343: 594, 344: 402, 345: 546, 346: 513, 347: 566, 348: 875, 349: 593, 230 | 350: 684, 351: 699, 352: 432, 353: 683, 354: 776, 355: 558, 356: 985, 357: 986, 358: 972, 359: 979, 231 | 360: 970, 361: 980, 362: 976, 363: 977, 364: 973, 365: 
975, 366: 978, 367: 974, 368: 596, 369: 499, 232 | 370: 623, 371: 726, 372: 740, 373: 621, 374: 587, 375: 512, 376: 473, 377: 731, 378: 784, 379: 792, 233 | 380: 730, 381: 491, 382: 7, 383: 8, 384: 9, 385: 10, 386: 11, 387: 12, 388: 13, 389: 14, 390: 15, 234 | 391: 16, 392: 17, 393: 18, 394: 19, 395: 20, 396: 21, 397: 22, 398: 23, 399: 24, 400: 80, 401: 81, 235 | 402: 82, 403: 83, 404: 84, 405: 85, 406: 86, 407: 87, 408: 88, 409: 89, 410: 90, 411: 91, 412: 92, 236 | 413: 93, 414: 94, 415: 95, 416: 96, 417: 97, 418: 98, 419: 99, 420: 100, 421: 127, 422: 128, 237 | 423: 129, 424: 130, 425: 132, 426: 131, 427: 133, 428: 134, 429: 135, 430: 137, 431: 138, 432: 139, 238 | 433: 140, 434: 141, 435: 142, 436: 143, 437: 136, 438: 144, 439: 145, 440: 146, 441: 2, 442: 3, 239 | 443: 4, 444: 5, 445: 6, 446: 389, 447: 391, 448: 0, 449: 1, 450: 390, 451: 392, 452: 393, 453: 396, 240 | 454: 397, 455: 394, 456: 395, 457: 33, 458: 34, 459: 35, 460: 36, 461: 37, 462: 38, 463: 39, 464: 40, 241 | 465: 41, 466: 42, 467: 43, 468: 44, 469: 45, 470: 46, 471: 47, 472: 48, 473: 51, 474: 49, 475: 50, 242 | 476: 52, 477: 53, 478: 54, 479: 55, 480: 56, 481: 57, 482: 58, 483: 59, 484: 60, 485: 61, 486: 62, 243 | 487: 63, 488: 64, 489: 65, 490: 66, 491: 67, 492: 68, 493: 25, 494: 26, 495: 27, 496: 28, 497: 29, 244 | 498: 30, 499: 31, 500: 32, 501: 902, 502: 908, 503: 696, 504: 589, 505: 691, 506: 801, 507: 632, 245 | 508: 650, 509: 782, 510: 673, 511: 545, 512: 686, 513: 828, 514: 811, 515: 827, 516: 583, 517: 426, 246 | 518: 769, 519: 685, 520: 778, 521: 409, 522: 530, 523: 892, 524: 604, 525: 835, 526: 704, 527: 826, 247 | 528: 531, 529: 823, 530: 845, 531: 635, 532: 447, 533: 745, 534: 837, 535: 633, 536: 755, 537: 456, 248 | 538: 471, 539: 413, 540: 764, 541: 744, 542: 508, 543: 878, 544: 517, 545: 626, 546: 398, 547: 480, 249 | 548: 798, 549: 527, 550: 590, 551: 681, 552: 916, 553: 595, 554: 856, 555: 742, 556: 800, 557: 886, 250 | 558: 786, 559: 613, 560: 844, 561: 600, 562: 479, 563: 694, 
564: 723, 565: 739, 566: 571, 567: 476, 251 | 568: 843, 569: 758, 570: 753, 571: 746, 572: 592, 573: 836, 574: 714, 575: 475, 576: 807, 577: 761, 252 | 578: 535, 579: 464, 580: 584, 581: 616, 582: 507, 583: 695, 584: 677, 585: 772, 586: 783, 587: 676, 253 | 588: 785, 589: 795, 590: 470, 591: 607, 592: 818, 593: 862, 594: 678, 595: 718, 596: 872, 597: 645, 254 | 598: 674, 599: 815, 600: 69, 601: 70, 602: 71, 603: 72, 604: 73, 605: 74, 606: 75, 607: 76, 608: 77, 255 | 609: 78, 610: 79, 611: 126, 612: 118, 613: 119, 614: 120, 615: 121, 616: 122, 617: 123, 618: 124, 256 | 619: 125, 620: 300, 621: 301, 622: 302, 623: 303, 624: 304, 625: 305, 626: 306, 627: 307, 628: 308, 257 | 629: 309, 630: 311, 631: 312, 632: 313, 633: 314, 634: 315, 635: 316, 636: 317, 637: 318, 638: 319, 258 | 639: 320, 640: 321, 641: 322, 642: 323, 643: 324, 644: 325, 645: 326, 646: 107, 647: 108, 648: 109, 259 | 649: 110, 650: 111, 651: 112, 652: 113, 653: 114, 654: 115, 655: 116, 656: 328, 657: 329, 658: 606, 260 | 659: 550, 660: 651, 661: 544, 662: 766, 663: 859, 664: 891, 665: 882, 666: 534, 667: 760, 668: 897, 261 | 669: 521, 670: 567, 671: 909, 672: 469, 673: 505, 674: 849, 675: 813, 676: 406, 677: 873, 678: 706, 262 | 679: 821, 680: 839, 681: 888, 682: 425, 683: 580, 684: 698, 685: 663, 686: 624, 687: 410, 688: 449, 263 | 689: 497, 690: 668, 691: 832, 692: 727, 693: 762, 694: 498, 695: 598, 696: 634, 697: 506, 698: 682, 264 | 699: 863, 700: 483, 701: 743, 702: 582, 703: 415, 704: 424, 705: 454, 706: 467, 707: 509, 708: 788, 265 | 709: 860, 710: 865, 711: 562, 712: 500, 713: 915, 714: 536, 715: 458, 716: 649, 717: 421, 718: 460, 266 | 719: 525, 720: 489, 721: 716, 722: 912, 723: 825, 724: 581, 725: 799, 726: 877, 727: 672, 728: 781, 267 | 729: 599, 730: 729, 731: 708, 732: 437, 733: 935, 734: 945, 735: 936, 736: 937, 737: 938, 738: 939, 268 | 739: 940, 740: 941, 741: 942, 742: 943, 743: 944, 744: 946, 745: 947, 746: 794, 747: 608, 748: 478, 269 | 749: 591, 750: 774, 751: 412, 752: 771, 753: 
923, 754: 679, 755: 522, 756: 568, 757: 855, 758: 697, 270 | 759: 770, 760: 503, 761: 492, 762: 640, 763: 662, 764: 876, 765: 868, 766: 416, 767: 931, 768: 741, 271 | 769: 614, 770: 926, 771: 901, 772: 615, 773: 921, 774: 816, 775: 796, 776: 440, 777: 518, 778: 455, 272 | 779: 858, 780: 643, 781: 638, 782: 712, 783: 560, 784: 433, 785: 850, 786: 597, 787: 737, 788: 713, 273 | 789: 887, 790: 918, 791: 574, 792: 927, 793: 834, 794: 900, 795: 552, 796: 501, 797: 966, 798: 542, 274 | 799: 787, 800: 496, 801: 601, 802: 922, 803: 819, 804: 452, 805: 962, 806: 429, 807: 551, 808: 777, 275 | 809: 838, 810: 441, 811: 996, 812: 924, 813: 619, 814: 911, 815: 958, 816: 457, 817: 636, 818: 899, 276 | 819: 463, 820: 533, 821: 809, 822: 969, 823: 666, 824: 869, 825: 693, 826: 488, 827: 840, 828: 659, 277 | 829: 964, 830: 907, 831: 789, 832: 465, 833: 540, 834: 446, 835: 474, 836: 841, 837: 738, 838: 448, 278 | 839: 588, 840: 722, 841: 709, 842: 707, 843: 925, 844: 411, 845: 747, 846: 414, 847: 982, 848: 439, 279 | 849: 710, 850: 462, 851: 669, 852: 399, 853: 667, 854: 735, 855: 523, 856: 732, 857: 810, 858: 968, 280 | 859: 752, 860: 920, 861: 749, 862: 754, 863: 961, 864: 524, 865: 652, 866: 629, 867: 793, 868: 664, 281 | 869: 688, 870: 658, 871: 459, 872: 930, 873: 883, 874: 653, 875: 768, 876: 700, 877: 995, 878: 549, 282 | 879: 655, 880: 515, 881: 874, 882: 711, 883: 435, 884: 934, 885: 991, 886: 466, 887: 721, 888: 999, 283 | 889: 481, 890: 477, 891: 618, 892: 994, 893: 631, 894: 585, 895: 400, 896: 538, 897: 519, 898: 903, 284 | 899: 965, 900: 720, 901: 490, 902: 854, 903: 905, 904: 427, 905: 896, 906: 418, 907: 430, 908: 434, 285 | 909: 514, 910: 578, 911: 904, 912: 992, 913: 487, 914: 680, 915: 422, 916: 637, 917: 617, 918: 556, 286 | 919: 654, 920: 692, 921: 646, 922: 733, 923: 602, 924: 808, 925: 715, 926: 756, 927: 893, 928: 482, 287 | 929: 917, 930: 719, 931: 919, 932: 442, 933: 563, 934: 906, 935: 890, 936: 689, 937: 775, 938: 748, 288 | 939: 451, 940: 443, 941: 701, 
942: 797, 943: 851, 944: 842, 945: 647, 946: 967, 947: 963, 948: 461, 289 | 949: 790, 950: 910, 951: 773, 952: 960, 953: 981, 954: 572, 955: 993, 956: 830, 957: 898, 958: 528, 290 | 959: 804, 960: 610, 961: 779, 962: 611, 963: 728, 964: 759, 965: 529, 966: 419, 967: 929, 968: 885, 291 | 969: 852, 970: 570, 971: 539, 972: 630, 973: 928, 974: 932, 975: 750, 976: 639, 977: 848, 978: 502, 292 | 979: 605, 980: 997, 981: 983, 982: 725, 983: 644, 984: 445, 985: 806, 986: 485, 987: 622, 988: 853, 293 | 989: 884, 990: 438, 991: 971, 992: 933, 993: 702, 994: 557, 995: 504, 996: 767, 997: 824, 998: 959, 294 | 999: 543} 295 | ind = {val:key for key, val in ind.items()} #actually need the inverse indices 296 | 297 | dictionary_default_pytorch_names_to_correct_names_full = { 298 | 'conv1':'conv2d0', 299 | 'conv2':'conv2d1', 300 | 'conv3':'conv2d2', 301 | 'fc':'softmax2' 302 | } 303 | 304 | dictionary_default_pytorch_names_to_correct_names_base = { 305 | 'conv1': 'mixed{}_1x1', 306 | 'conv3_1': 'mixed{}_3x3_bottleneck', 307 | 'conv3_3': 'mixed{}_3x3', 308 | 'conv5_1': 'mixed{}_5x5_bottleneck', 309 | 'conv5_5': 'mixed{}_5x5', 310 | 'conv_max_1': 'mixed{}_pool_reduce' 311 | } 312 | 313 | def load_weights_from_dump(model, dump_folder): 314 | # For this to work we need the h5py package 315 | import h5py 316 | import numpy as np 317 | 318 | 'Loads the weights saved as h5py files in the soumith repo linked above. 
Just here for completeness' 319 | 320 | dump_folder = os.path.abspath(dump_folder) 321 | files_list = [os.path.join(dump_folder, x) for x in os.listdir(dump_folder)] 322 | 323 | for name, layer in model.named_parameters(): 324 | # get path from name 325 | if 'inception' in name: 326 | first_dot = name.find('.') 327 | name_inception = name[:first_dot].replace('inception_', '') 328 | name_layer = name[first_dot + 1:name.find('.', first_dot + 1)] 329 | name_layer = dictionary_default_pytorch_names_to_correct_names_base[name_layer].format(name_inception) 330 | else: 331 | name_layer = name[:name.find('.')] 332 | name_layer = dictionary_default_pytorch_names_to_correct_names_full[name_layer] 333 | if 'weight' in name: 334 | filename = name_layer + '_w.h5' 335 | else: 336 | filename = name_layer + '_b.h5' 337 | 338 | filename = os.path.join(dump_folder, filename) 339 | if filename in files_list: 340 | files_list.remove(filename) 341 | else: 342 | print('file {} not found in files list'.format(filename)) 343 | 344 | # print(filename, 'exists', os.path.isfile(filename)) 345 | 346 | f = h5py.File(filename, 'r') 347 | a_group_key = list(f.keys())[0] 348 | w = np.asarray(list(f[a_group_key])) 349 | f.close() 350 | 351 | w = torch.from_numpy(w) 352 | if 'weight' in name: 353 | w = w.transpose(1, 3).transpose(2, 3).clone() 354 | w = w.type_as(layer.data) 355 | if name_layer == 'softmax2': 356 | #Adjust the size - because google has 1008 classes, class 1 - 1000 are valid 357 | if 'weight' in name: 358 | w = w[1:1001, :] 359 | else: 360 | w = w[1:1001] 361 | 362 | #and re-arrange the indices - the torch repo had another order for the indices 363 | ind_list = [ind[idx] for idx in range(1000)] 364 | idx_t = torch.FloatTensor(ind_list).long() 365 | w = w.squeeze() 366 | w = torch.index_select(w, dim=0, index=idx_t) 367 | 368 | if layer.data.size() != w.size(): 369 | raise ValueError('Incompatible sizes') 370 | 371 | layer.data = w 372 | 373 | print('Number of unused files: 
{}'.format(len(files_list))) 374 | return model 375 | --------------------------------------------------------------------------------