├── .gitignore
├── cam.jpg
├── sample.jpg
├── data.py
├── README.md
├── update.py
├── train.py
├── main.py
└── inception.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
.idea
kaggle/*
result*/
result*
--------------------------------------------------------------------------------
/cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaeyoung-lee/pytorch-CAM/HEAD/cam.jpg
--------------------------------------------------------------------------------
/sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaeyoung-lee/pytorch-CAM/HEAD/sample.jpg
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import io

import requests
from PIL import Image


def read_data(img_num, txt, idx):
    """Read `img_num` image URLs from a text file; the URL sits at column `idx` of each line."""
    urls = []
    with open(txt, 'r') as f:
        for _ in range(img_num):
            line = f.readline()
            urls.append(line.split()[idx])
    return urls


def get_img(num, img_url, root):
    """Download an image and save it as `<root><num>.jpg`."""
    response = requests.get(img_url)
    img_pil = Image.open(io.BytesIO(response.content))
    img_pil.save(root + str(num) + '.jpg')
    return img_pil


"""
# Extract images from ImageNet
IMG_URLs = read_data(1071, 'data/ear.txt', 0)

for i in range(len(IMG_URLs)):
    try:
        print(i)
        img = get_img(i, IMG_URLs[i], root='image/train/ear/')
        print()
    except Exception:
        print('Error')
        print()
"""
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-CAM
This repository is an unofficial PyTorch implementation of Class Activation Mapping, modified for a simple use case.

## Class Activation Mapping (CAM)
Paper and architecture: [Learning Deep Features for Discriminative Localization][1]

Paper authors' implementation: [metalbubble/CAM][2]

From the paper:

*We propose a technique for generating class activation maps using the global average pooling (GAP) in CNNs. A class activation map for a particular category indicates the discriminative image regions used by the CNN to identify that category. The procedure for generating these maps is illustrated as follows:*

*(Figure: the CAM generation procedure, from the paper.)*
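
In equation form: with GAP, the score for class *c* is `S_c = Σ_k w_k^c · Σ_{x,y} f_k(x, y)`, so reordering the sums gives the class activation map `CAM_c(x, y) = Σ_k w_k^c · f_k(x, y)`, a weighted sum of the last convolutional feature maps. The NumPy sketch below illustrates just this step with made-up shapes (the feature tensor and weight matrix are random stand-ins, not values from this repo):

```
import numpy as np

# Stand-ins: 2048 feature maps of size 7x7 (the layer feeding GAP) and a
# 2-class weight matrix (the fc layer after GAP).
features = np.random.rand(2048, 7, 7)            # f_k(x, y)
weights = np.random.rand(2, 2048)                # w_k^c

c = 0                                            # class of interest
cam = weights[c] @ features.reshape(2048, -1)    # sum_k w_k^c * f_k(x, y)
cam = cam.reshape(7, 7)

# Normalize to [0, 255] for rendering as a heatmap, as update.py does.
cam = cam - cam.min()
cam = np.uint8(255 * cam / cam.max())
```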

*Class activation maps can be used to interpret the prediction decision made by the CNN. The left image below shows the class activation maps for the top-5 predictions; note that the CNN is triggered by different semantic regions of the image for different predictions. The right image below shows that the CNN learns to localize the common visual patterns for the same object class.*

*(Figures: left, CAMs for the top-5 predictions of one image; right, CAMs localizing the same class across several images.)*
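
In this repository the same recipe is wired up with a forward hook on the last convolutional block and the weights of the final fully-connected layer. The sketch below condenses what `main.py` and `update.py` do; it uses an untrained 2-class model and a random input as stand-ins so it runs without the Kaggle data:

```
import torch
from torch.nn import functional as F
from inception import inception_v3

net = inception_v3(pretrained=False, num_classes=2).eval()

features = []                                  # filled by the hook on each forward pass
net.Mixed_7c.register_forward_hook(
    lambda module, inp, out: features.append(out.detach()))

x = torch.randn(1, 3, 224, 224)                # stand-in for a preprocessed image
with torch.no_grad():
    probs = F.softmax(net(x), dim=1).squeeze(0)

weights = net.fc.weight.detach()               # (2, 2048); update.py reads this as params[-2]
c = int(probs.argmax())                        # top-1 class
fmap = features[0].squeeze(0)                  # (2048, 5, 5) for a 224x224 input
cam = (weights[c] @ fmap.flatten(1)).reshape(fmap.shape[1:])
```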

## Code Description
**Usage**: `python3 main.py`

**Network**: Inception v3

**Data**: [Kaggle dogs vs. cats][3]
- Download the `test1.zip` and `train.zip` files and unzip them.
- Divide the dataset into a train group and a test group. [Images must be arranged in this way][4], one subdirectory per class:
```
kaggle/train/cat/*.jpg
kaggle/train/dog/*.jpg
kaggle/test/cat/*.jpg
kaggle/test/dog/*.jpg
```

**Checkpoint**
- A checkpoint is written to the `checkpoint/` folder every ten epochs.
- By setting `RESUME = #`, you can resume from `checkpoint/#.pt`.

[1]: https://arxiv.org/abs/1512.04150
[2]: https://github.com/metalbubble/CAM
[3]: https://www.kaggle.com/c/dogs-vs-cats/data
[4]: http://pytorch.org/docs/master/torchvision/datasets.html#imagefolder
--------------------------------------------------------------------------------
/update.py:
--------------------------------------------------------------------------------
from torchvision import transforms
from torch.nn import functional as F
import numpy as np
import cv2, torch

# generate class activation maps for the given class indices
def returnCAM(feature_conv, weight_softmax, class_idx):
    # generate the class activation maps and upsample to 256x256
    size_upsample = (256, 256)
    bz, nc, h, w = feature_conv.shape
    output_cam = []
    for idx in class_idx:
        # CAM_c(x, y) = sum_k w_k^c * f_k(x, y)
        cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h * w)))
        cam = cam.reshape(h, w)
        # normalize to [0, 255] for rendering as a heatmap
        cam = cam - np.min(cam)
        cam_img = cam / np.max(cam)
        cam_img = np.uint8(255 * cam_img)
        output_cam.append(cv2.resize(cam_img, size_upsample))
    return output_cam

def get_cam(net, features_blobs, img_pil, classes, root_img):
    # params[-2] is the weight of the final fc layer, i.e. the class weights applied after GAP
    params = list(net.parameters())
    weight_softmax = np.squeeze(params[-2].data.cpu().numpy())

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize
    ])

    img_tensor = preprocess(img_pil).unsqueeze(0).cuda()
    with torch.no_grad():
        logit = net(img_tensor)

    h_x = F.softmax(logit, dim=1).squeeze()
    probs, idx = h_x.sort(0, True)

    # output: the predictions
    for i in range(0, 2):
        print('{:.3f} -> {}'.format(probs[i], classes[idx[i].item()]))

    CAMs = returnCAM(features_blobs[0], weight_softmax, [idx[0].item()])

    # render the CAM over the original image and write it out
    print('output cam.jpg for the top1 prediction: %s' % classes[idx[0].item()])
    img = cv2.imread(root_img)
    height, width, _ = img.shape
    CAM = cv2.resize(CAMs[0], (width, height))
    heatmap = cv2.applyColorMap(CAM, cv2.COLORMAP_JET)
    result = heatmap * 0.3 + img * 0.5
    cv2.imwrite('cam.jpg', result)
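
"""
# Example usage (a sketch, in the style of the commented example in data.py:
# it assumes a fine-tuned 2-class checkpoint; 'checkpoint/1.pt' is a
# hypothetical path, while 'sample.jpg' ships with this repo).
from PIL import Image
from inception import inception_v3

net = inception_v3(pretrained=False, num_classes=2).cuda().eval()
net.load_state_dict(torch.load('checkpoint/1.pt'))

# get_cam runs the forward pass itself, which fills features_blobs via the hook
features_blobs = []
net.Mixed_7c.register_forward_hook(
    lambda module, inp, out: features_blobs.append(out.data.cpu().numpy()))

get_cam(net, features_blobs, Image.open('sample.jpg'),
        {0: 'cat', 1: 'dog'}, 'sample.jpg')
"""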
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch

def retrain(trainloader, model, use_cuda, epoch, criterion, optimizer):
    model.train()
    correct, total = 0, 0
    loss_sum = 0
    for batch_idx, (data, target) in enumerate(trainloader):
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)

        # running accuracy over the epoch so far
        correct += (torch.max(output, 1)[1].view(target.size()) == target).sum().item()
        total += target.size(0)  # count the actual batch size (the last batch may be smaller)
        train_acc = 100. * correct / total

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()

        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.3f}\tTraining Accuracy: {:.3f}%'.format(
                epoch, batch_idx * len(data), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.item(), train_acc))

    loss_avg = loss_sum / len(trainloader)  # mean of the per-batch losses
    print()
    print('Train Epoch: {}\tAverage Loss: {:.3f}\tAverage Accuracy: {:.3f}%'.format(epoch, loss_avg, train_acc))

    with open('result/train_acc.txt', 'a') as f:
        f.write(str(train_acc) + '\n')
    with open('result/train_loss.txt', 'a') as f:
        f.write(str(loss_avg) + '\n')

def retest(testloader, model, use_cuda, criterion, epoch, RESUME):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in testloader:
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # sum up batch loss
            test_loss += criterion(output, target).item()
            # get the index of the max log-probability
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(testloader.dataset)
    test_acc = 100. * correct / len(testloader.dataset)
    result = '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        test_loss, correct, len(testloader.dataset), test_acc)
    print(result)

    # Save checkpoint.
    if epoch % 10 == 0:
        torch.save(model.state_dict(), 'checkpoint/' + str(RESUME + int(epoch / 10)) + '.pt')
        with open('result/result.txt', 'a') as f:
            f.write(result)

    with open('result/test_acc.txt', 'a') as f:
        f.write(str(test_acc) + '\n')
    with open('result/test_loss.txt', 'a') as f:
        f.write(str(test_loss) + '\n')
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""
Class Activation Mapping
Inception v3, Kaggle dogs vs. cats data
"""

from update import *
from data import *
from train import *
import torch, os
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from inception import inception_v3


# switches
CAM = 1         # generate cam.jpg for sample.jpg after training
USE_CUDA = 1
RESUME = 0      # resume from checkpoint/<RESUME>.pt if nonzero
PRETRAINED = 0  # fine-tune an ImageNet-pretrained model (train the fc layer only)


# hyperparameters
BATCH_SIZE = 32
IMG_SIZE = 224
LEARNING_RATE = 0.01
EPOCH = 0       # set > 0 to (re)train


# prepare data
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    normalize
])

train_data = datasets.ImageFolder('kaggle/train/', transform=transform_train)
trainloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_data = datasets.ImageFolder('kaggle/test/', transform=transform_test)
testloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


# classes
classes = {0: 'cat', 1: 'dog'}


# model: fine-tune the pretrained fc layer only, or train from scratch
final_conv = 'Mixed_7c'  # last conv block; its output feeds GAP, so it is what CAM visualizes

if PRETRAINED:
    net = inception_v3(pretrained=True)
    for param in net.parameters():
        param.requires_grad = False   # freeze the backbone
    net.fc = torch.nn.Linear(2048, 2)
else:
    net = inception_v3(pretrained=False, num_classes=len(classes))

if USE_CUDA:
    net.cuda()


# load checkpoint
if RESUME != 0:
    print("===> Resuming from checkpoint.")
    assert os.path.isfile('checkpoint/' + str(RESUME) + '.pt'), 'Error: no checkpoint found!'
    net.load_state_dict(torch.load('checkpoint/' + str(RESUME) + '.pt'))


# retrain
criterion = torch.nn.CrossEntropyLoss()

if PRETRAINED:
    optimizer = torch.optim.SGD(net.fc.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)
else:
    optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)

for epoch in range(1, EPOCH + 1):
    retrain(trainloader, net, USE_CUDA, epoch, criterion, optimizer)
    retest(testloader, net, USE_CUDA, criterion, epoch, RESUME)


# hook the feature extractor: stash the final conv block's output on every forward pass
features_blobs = []

def hook_feature(module, input, output):
    features_blobs.append(output.data.cpu().numpy())

net._modules.get(final_conv).register_forward_hook(hook_feature)


# CAM
if CAM:
    root = 'sample.jpg'
    img = Image.open(root)
    get_cam(net, features_blobs, img, classes, root)
--------------------------------------------------------------------------------
/inception.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo


__all__ = ['Inception3', 'inception_v3']


model_urls = {
    # Inception v3 ported from TensorFlow
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}


def inception_v3(pretrained=False, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        model = Inception3(**kwargs)
        model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
        return model

    return Inception3(**kwargs)


class Inception3(nn.Module):

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
        super(Inception3, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)
        self.fc = nn.Linear(2048, num_classes)
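
        # Initialize conv/linear weights with a truncated normal distribution:
        # stddev 0.1 by default, overridden per module via a `stddev` attribute
        # (see InceptionAux below).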
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                import scipy.stats as stats
                stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                X = stats.truncnorm(-2, 2, scale=stddev)
                values = torch.Tensor(X.rvs(m.weight.data.numel()))
                values = values.view(m.weight.data.size())
                m.weight.data.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        if self.transform_input:
            # remap from the ImageNet (0.485/0.456/0.406, 0.229/0.224/0.225)
            # normalization to the (0.5, 0.5) normalization the pretrained weights expect
            x = x.clone()
            x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        # NOTE: the spatial sizes annotated below assume the canonical 299 x 299
        # input; this project feeds 224 x 224 images, which yields 12 x 12 maps in
        # the Mixed_6* blocks and 5 x 5 maps in the Mixed_7* blocks.
        # 299 x 299 x 3
        x = self.Conv2d_1a_3x3(x)
        # 149 x 149 x 32
        x = self.Conv2d_2a_3x3(x)
        # 147 x 147 x 32
        x = self.Conv2d_2b_3x3(x)
        # 147 x 147 x 64
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 73 x 73 x 64
        x = self.Conv2d_3b_1x1(x)
        # 73 x 73 x 80
        x = self.Conv2d_4a_3x3(x)
        # 71 x 71 x 192
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 35 x 35 x 192
        x = self.Mixed_5b(x)
        # 35 x 35 x 256
        x = self.Mixed_5c(x)
        # 35 x 35 x 288
        x = self.Mixed_5d(x)
        # 35 x 35 x 288
        x = self.Mixed_6a(x)
        # 17 x 17 x 768
        x = self.Mixed_6b(x)
        # 17 x 17 x 768
        x = self.Mixed_6c(x)
        # 17 x 17 x 768
        x = self.Mixed_6d(x)
        # 17 x 17 x 768
        x = self.Mixed_6e(x)
        # 17 x 17 x 768
        # aux classifier disabled: with 224 x 224 inputs the feature map here is
        # 12 x 12 rather than 17 x 17, and InceptionAux's pooling/conv sizes no longer fit
        #if self.training and self.aux_logits:
        #    aux = self.AuxLogits(x)
        x = self.Mixed_7a(x)
        # 8 x 8 x 1280
        x = self.Mixed_7b(x)
        # 8 x 8 x 2048
        x = self.Mixed_7c(x)
        # 8 x 8 x 2048
        # global average pooling: kernel_size=5 matches the 5 x 5 map produced by
        # 224 x 224 inputs (torchvision's original uses kernel_size=8 for 299 x 299)
        x = F.avg_pool2d(x, kernel_size=5)
        # 1 x 1 x 2048
        x = F.dropout(x, training=self.training)
        # 1 x 1 x 2048
        x = x.view(x.size(0), -1)
        # 2048
        x = self.fc(x)
        # num_classes
        #if self.training and self.aux_logits:
        #    return x, aux
        return x


class InceptionA(nn.Module):

    def __init__(self, in_channels, pool_features):
        super(InceptionA, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)
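
    # Grid-size reduction block: both conv branches and the pooling branch use
    # stride 2, so the 35x35 grid becomes 17x17 (12x12 with 224x224 inputs).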
    def forward(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):

    def __init__(self, in_channels, channels_7x7):
        super(InceptionC, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):

    def __init__(self, in_channels):
        super(InceptionD, self).__init__()
        self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return torch.cat(outputs, 1)

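
# InceptionE widens each 3x3 path by applying 1x3 and 3x1 convolutions in
# parallel and concatenating the results (the expanded filter bank of the
# Inception v3 paper).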
class InceptionE(nn.Module):

    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


# Note: defined so the pretrained checkpoint's keys load, but unused at runtime
# (the aux-logits path is disabled in Inception3.forward above).
class InceptionAux(nn.Module):

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # sizes below assume the canonical 17 x 17 x 768 input
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # 5 x 5 x 768
        x = self.conv0(x)
        # 5 x 5 x 128
        x = self.conv1(x)
        # 1 x 1 x 768
        x = x.view(x.size(0), -1)
        # 768
        x = self.fc(x)
        # num_classes
        return x


class BasicConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)
--------------------------------------------------------------------------------