├── .gitignore
├── cam.jpg
├── sample.jpg
├── data.py
├── README.md
├── update.py
├── train.py
├── main.py
└── inception.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .idea
3 | kaggle/*
4 | result*/
5 | result*
6 |
--------------------------------------------------------------------------------
/cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaeyoung-lee/pytorch-CAM/HEAD/cam.jpg
--------------------------------------------------------------------------------
/sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaeyoung-lee/pytorch-CAM/HEAD/sample.jpg
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import io
3 | import requests
4 | from PIL import Image
5 |
6 | def read_data(img_num, txt, idx):
7 |     """Read img_num whitespace-separated lines from txt; return column idx (the image URL) of each."""
8 |     IMG_URLs = []
9 |     with open(txt, 'r') as f:
10 |         for i in range(img_num):
11 |             line = f.readline()
12 |             IMG_URLs.append(line.split()[idx])
13 |     return np.array(IMG_URLs)
14 |
15 |
16 | def get_img(num, IMG_URL, root):
17 |     """Download IMG_URL and save it as <root><num>.jpg; return the PIL image."""
18 |     response = requests.get(IMG_URL)
19 |     img_pil = Image.open(io.BytesIO(response.content))
20 |     img_pil.save(root + str(num) + '.jpg')
21 |     return img_pil
22 |
23 | """
24 | # 이미지넷 이미지 추출
25 | IMG_URLs = read_data(1071, 'data/ear.txt', 0)
26 |
27 | for i in range(len(IMG_URLs)):
28 | try:
29 | print(i)
30 | IMG_URL = IMG_URLs[i]
31 | img = get_img(i, IMG_URL, root='image/train/ear/')
32 | if img == 0: continue
33 | print()
34 | except:
35 | print('에러')
36 | print()
37 | """
--------------------------------------------------------------------------------
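The commented-out scraper above shows how these helpers are meant to be combined. Below is a minimal, slightly more defensive sketch of the same loop; the `data/ear.txt` URL list and the output directory are taken from the demo above, while `os.makedirs` and the exception reporting are additions, not part of the original script:

```
import os
from data import read_data, get_img

os.makedirs('image/train/ear/', exist_ok=True)
urls = read_data(1071, 'data/ear.txt', 0)
for i, url in enumerate(urls):
    try:
        get_img(i, url, root='image/train/ear/')
    except Exception as exc:  # dead links are common in old ImageNet URL lists
        print('skipped {}: {}'.format(i, exc))
```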
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-CAM
2 | This repository is an unofficial PyTorch implementation of Class Activation Mapping (CAM), adapted for a simple use case.
3 |
4 | ## Class Activation Mapping (CAM)
5 | Paper and architecture: [Learning Deep Features for Discriminative Localization][1]
6 |
7 | Paper authors' implementation: [metalbubble/CAM][2]
8 |
9 | In the paper:
10 |
11 | *We propose a technique for generating class activation maps using the global average pooling (GAP) in CNNs. A class activation map for a particular category indicates the discriminative image regions used by the CNN to identify that category. The procedure for generating these maps is illustrated as follows:*
12 |
13 |
14 |
15 |
16 |
17 | *Class activation maps could be used to interpret the prediction decision made by the CNN. The left image below shows the class activation maps of the top-5 predictions, respectively; you can see that the CNN is triggered by different semantic regions of the image for different predictions. The right image below shows how the CNN learns to localize the common visual patterns for the same object class.*
18 |
19 |
20 |
21 |
22 |
23 |
24 | ## Code Description
25 | **Usage**: `python3 main.py`
26 |
27 | **Network**: Inception V3
28 |
29 | **Data**: [Kaggle dogs vs. cats][3]
30 | - Download the 'test1.zip' and 'train.zip' files and unzip them.
31 | - Divide the labeled images into a training set and a test set. As you do, [images must be arranged in this way][4], one subfolder per class (see the folder-split sketch after `main.py` below):
32 | ```
33 | kaggle/train/cat/*.jpg
34 | kaggle/test/cat/*.jpg
35 | ```
36 |
37 | **Checkpoint**
38 | - A checkpoint is saved to the `checkpoint/` folder every ten epochs.
39 | - By setting `RESUME = #` in `main.py`, you can resume from `checkpoint/#.pt`.
40 |
41 | [1]: https://arxiv.org/abs/1512.04150
42 | [2]: https://github.com/metalbubble/CAM
43 | [3]: https://www.kaggle.com/c/dogs-vs-cats/data
44 | [4]: http://pytorch.org/docs/master/torchvision/datasets.html#imagefolder
45 |
--------------------------------------------------------------------------------
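The CAM construction quoted in the README reduces to one formula. Writing $f_k(x, y)$ for the activation of channel $k$ of the last convolutional layer (`Mixed_7c` here) at spatial location $(x, y)$, and $w_k^c$ for the fc-layer weight from channel $k$ to class $c$ (the layer fed by global average pooling), the class activation map for class $c$ is

```
M_c(x, y) = \sum_k w_k^c \, f_k(x, y)
```

which is exactly the dot product computed in `returnCAM` in `update.py`; upsampling $M_c$ to the input resolution gives the heatmap that `get_cam` blends over the image.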
/update.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from torch.nn import functional as F
3 | import numpy as np
4 | import cv2
5 | import torch
6 |
7 | # generate the class activation maps for the given class indices
8 | def returnCAM(feature_conv, weight_softmax, class_idx):
9 |     # generate the class activation maps, upsampled to 256x256
10 |     size_upsample = (256, 256)
11 |     bz, nc, h, w = feature_conv.shape
12 |     output_cam = []
13 |     for idx in class_idx:
14 |         cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h * w)))
15 |         cam = cam.reshape(h, w)
16 |         cam = cam - np.min(cam)
17 |         cam_img = cam / np.max(cam)
18 |         cam_img = np.uint8(255 * cam_img)
19 |         output_cam.append(cv2.resize(cam_img, size_upsample))
20 |     return output_cam
21 |
22 | def get_cam(net, features_blobs, img_pil, classes, root_img):
23 |     # the CAM weights come from the fc layer (params[-1] is its bias, params[-2] its weight)
24 |     params = list(net.parameters())
25 |     weight_softmax = np.squeeze(params[-2].data.cpu().numpy())
26 |
27 |     normalize = transforms.Normalize(
28 |         mean=[0.485, 0.456, 0.406],
29 |         std=[0.229, 0.224, 0.225]
30 |     )
31 |     preprocess = transforms.Compose([
32 |         transforms.Resize((224, 224)),
33 |         transforms.ToTensor(),
34 |         normalize
35 |     ])
36 |
37 |     device = next(net.parameters()).device  # run on whatever device the model is on
38 |     img_tensor = preprocess(img_pil).unsqueeze(0).to(device)
39 |     with torch.no_grad():  # inference only; Variable/volatile are deprecated
40 |         logit = net(img_tensor)
41 |
42 |     h_x = F.softmax(logit, dim=1).squeeze()
43 |     probs, idx = h_x.sort(0, True)
44 |
45 |     # print the top-2 class probabilities
46 |     for i in range(0, 2):
47 |         print('{:.3f} -> {}'.format(probs[i].item(), classes[idx[i].item()]))
48 |
49 |     CAMs = returnCAM(features_blobs[0], weight_softmax, [idx[0].item()])
50 |
51 |     # render the CAM as a heatmap blended over the original image
52 |     print('output cam.jpg for the top1 prediction: %s' % classes[idx[0].item()])
53 |     img = cv2.imread(root_img)
54 |     height, width, _ = img.shape
55 |     CAM = cv2.resize(CAMs[0], (width, height))
56 |     heatmap = cv2.applyColorMap(CAM, cv2.COLORMAP_JET)
57 |     result = heatmap * 0.3 + img * 0.5
58 |     cv2.imwrite('cam.jpg', result)
--------------------------------------------------------------------------------
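A quick shape check for `returnCAM`, with random arrays standing in for real activations. The shapes are the ones this repository actually produces — hooking `Mixed_7c` on a 224x224 input yields a 1x2048x5x5 feature map, and the two-class fc layer has a 2x2048 weight matrix (see `inception.py` and `main.py`):

```
import numpy as np
from update import returnCAM

# stand-in feature map and fc weights with this repo's real shapes
feature_conv = np.random.rand(1, 2048, 5, 5).astype(np.float32)
weight_softmax = np.random.rand(2, 2048).astype(np.float32)

cams = returnCAM(feature_conv, weight_softmax, [1])
print(len(cams), cams[0].shape, cams[0].dtype)  # 1 (256, 256) uint8
```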
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def retrain(trainloader, model, use_cuda, epoch, criterion, optimizer):
5 |     model.train()
6 |     correct, total = 0, 0
7 |     acc_sum, loss_sum = 0, 0
8 |     for batch_idx, (data, target) in enumerate(trainloader):
9 |         if use_cuda:
10 |             data, target = data.cuda(), target.cuda()
11 |         optimizer.zero_grad()
12 |         output = model(data)
13 |
14 |         # calculate running accuracy (use target.size(0): the last batch may be smaller)
15 |         correct += (torch.max(output, 1)[1].view(target.size()) == target).sum().item()
16 |         total += target.size(0)
17 |         train_acc = 100. * correct / total
18 |         acc_sum += train_acc
19 |
20 |         loss = criterion(output, target)
21 |         loss.backward()
22 |         optimizer.step()
23 |         loss_sum += loss.item()
24 |
25 |         if batch_idx % 10 == 0:
26 |             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.3f}\tTraining Accuracy: {:.3f}%'.format(
27 |                 epoch, batch_idx * len(data), len(trainloader.dataset),
28 |                 100. * batch_idx / len(trainloader), loss.item(), train_acc))
29 |
30 |     acc_avg = acc_sum / len(trainloader)
31 |     loss_avg = loss_sum / len(trainloader.dataset)
32 |     print()
33 |     print('Train Epoch: {}\tAverage Loss: {:.3f}\tAverage Accuracy: {:.3f}%'.format(epoch, loss_avg, acc_avg))
34 |
35 |     # append one value per line so the logs stay easy to parse
36 |     with open('result/train_acc.txt', 'a') as f:
37 |         f.write(str(acc_avg) + '\n')
38 |     with open('result/train_loss.txt', 'a') as f:
39 |         f.write(str(loss_avg) + '\n')
40 |
41 |
42 | def retest(testloader, model, use_cuda, criterion, epoch, RESUME):
43 |     model.eval()
44 |     test_loss = 0
45 |     correct = 0
46 |     with torch.no_grad():  # inference only; replaces the deprecated volatile=True
47 |         for data, target in testloader:
48 |             if use_cuda:
49 |                 data, target = data.cuda(), target.cuda()
50 |             output = model(data)
51 |             # sum up batch loss
52 |             test_loss += criterion(output, target).item()
53 |             # get the index of the max log-probability
54 |             pred = output.max(1, keepdim=True)[1]
55 |             correct += pred.eq(target.view_as(pred)).cpu().sum().item()
56 |
57 |     test_loss /= len(testloader.dataset)
58 |     test_acc = 100. * correct / len(testloader.dataset)
59 |     result = '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
60 |         test_loss, correct, len(testloader.dataset), test_acc)
61 |     print(result)
62 |
63 |     # Save a checkpoint every ten epochs.
64 |     if epoch % 10 == 0:
65 |         torch.save(model.state_dict(), 'checkpoint/' + str(RESUME + int(epoch / 10)) + '.pt')
66 |         with open('result/result.txt', 'a') as f:
67 |             f.write(result)
68 |
69 |     with open('result/test_acc.txt', 'a') as f:
70 |         f.write(str(test_acc) + '\n')
71 |     with open('result/test_loss.txt', 'a') as f:
72 |         f.write(str(test_loss) + '\n')
73 |
--------------------------------------------------------------------------------
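Because `retrain` and `retest` append one value per line to the `result/*.txt` files, a finished run is easy to inspect. A minimal sketch, assuming at least one epoch has been logged:

```
# read the newline-delimited metrics written by train.py
with open('result/test_acc.txt') as f:
    accs = [float(line) for line in f if line.strip()]
print('epochs logged: {}, best test accuracy: {:.3f}%'.format(len(accs), max(accs)))
```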
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Class Activation Mapping
3 | Googlenet, Kaggle data
4 | """
5 |
6 | from update import *
7 | from data import *  # data.py's star import also provides PIL's Image, used below for sample.jpg
8 | from train import *
9 | import torch, os
10 | from torch.utils.data import DataLoader
11 | from torchvision import datasets, transforms
12 | from inception import inception_v3
13 |
14 |
15 | # functions
16 | CAM = 1
17 | USE_CUDA = 1
18 | RESUME = 0
19 | PRETRAINED = 0
20 |
21 |
22 | # hyperparameters
23 | BATCH_SIZE = 32
24 | IMG_SIZE = 224
25 | LEARNING_RATE = 0.01
26 | EPOCH = 0  # number of training epochs; 0 skips training and goes straight to CAM
27 |
28 |
29 | # prepare data
30 | normalize = transforms.Normalize(
31 | mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225]
33 | )
34 |
35 | transform_train = transforms.Compose([
36 | transforms.RandomResizedCrop(IMG_SIZE),
37 | transforms.RandomHorizontalFlip(),
38 | transforms.ToTensor(),
39 | normalize
40 | ])
41 |
42 | transform_test = transforms.Compose([
43 | transforms.Resize(256),
44 | transforms.CenterCrop(IMG_SIZE),
45 | transforms.ToTensor(),
46 | normalize
47 | ])
48 |
49 | train_data = datasets.ImageFolder('kaggle/train/', transform=transform_train)
50 | trainloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
51 |
52 | test_data = datasets.ImageFolder('kaggle/test/', transform=transform_test)
53 | testloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
54 |
55 |
56 | # class
57 | classes = {0: 'cat', 1: 'dog'}
58 |
59 |
60 | # fine-tune a pretrained backbone, or train from scratch
61 | if PRETRAINED:
62 |     net = inception_v3(pretrained=True)
63 |     for param in net.parameters():
64 |         param.requires_grad = False  # freeze the backbone; only the new fc layer is trained
65 |     net.fc = torch.nn.Linear(2048, 2)
66 | else:
67 |     net = inception_v3(pretrained=False, num_classes=len(classes))
68 | final_conv = 'Mixed_7c'  # set in both branches; hooked below for CAM
69 | if USE_CUDA:
70 |     net.cuda()
71 |
72 |
73 | # load checkpoint
74 | if RESUME != 0:
75 | print("===> Resuming from checkpoint.")
76 | assert os.path.isfile('checkpoint/'+ str(RESUME) + '.pt'), 'Error: no checkpoint found!'
77 | net.load_state_dict(torch.load('checkpoint/' + str(RESUME) + '.pt'))
78 |
79 |
80 | # retrain
81 | criterion = torch.nn.CrossEntropyLoss()
82 |
83 | if PRETRAINED:
84 | optimizer = torch.optim.SGD(net.fc.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)
85 | else:
86 | optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)
87 |
88 | for epoch in range(1, EPOCH + 1):
89 | retrain(trainloader, net, USE_CUDA, epoch, criterion, optimizer)
90 | retest(testloader, net, USE_CUDA, criterion, epoch, RESUME)
91 |
92 |
93 | # hook the feature extractor
94 | features_blobs = []
95 |
96 | def hook_feature(module, input, output):
97 | features_blobs.append(output.data.cpu().numpy())
98 |
99 | net._modules.get(final_conv).register_forward_hook(hook_feature)
100 |
101 |
102 | # CAM
103 | if CAM:
104 | root = 'sample.jpg'
105 | img = Image.open(root)
106 | get_cam(net, features_blobs, img, classes, root)
107 |
--------------------------------------------------------------------------------
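`ImageFolder` expects the `kaggle/train/` and `kaggle/test/` class-subfolder layout described in the README, which Kaggle's archives do not provide directly. A hypothetical preparation script, assuming `train.zip` was unzipped to `train/` with files named like `cat.0.jpg` and `dog.0.jpg` (Kaggle's labeled set; `test1.zip` is unlabeled, so the held-out split is carved from `train/`), and an arbitrary 90/10 split:

```
import os, random, shutil

random.seed(0)
for cls in ('cat', 'dog'):
    files = [f for f in os.listdir('train') if f.startswith(cls + '.')]
    random.shuffle(files)
    split = int(0.9 * len(files))  # 90/10 train/test split
    for group, names in (('train', files[:split]), ('test', files[split:])):
        dst = os.path.join('kaggle', group, cls)
        os.makedirs(dst, exist_ok=True)
        for name in names:
            shutil.copy(os.path.join('train', name), dst)
```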
/inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.model_zoo as model_zoo
5 |
6 |
7 | __all__ = ['Inception3', 'inception_v3']
8 |
9 |
10 | model_urls = {
11 | # Inception v3 ported from TensorFlow
12 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
13 | }
14 |
15 |
16 | def inception_v3(pretrained=False, **kwargs):
17 |     r"""Inception v3 model architecture from
18 |     `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
19 |
20 |     Args:
21 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
22 |     """
23 | if pretrained:
24 | if 'transform_input' not in kwargs:
25 | kwargs['transform_input'] = True
26 | model = Inception3(**kwargs)
27 | model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
28 | return model
29 |
30 | return Inception3(**kwargs)
31 |
32 |
33 | class Inception3(nn.Module):
34 |
35 | def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
36 | super(Inception3, self).__init__()
37 | self.aux_logits = aux_logits
38 | self.transform_input = transform_input
39 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
40 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
41 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
42 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
43 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
44 | self.Mixed_5b = InceptionA(192, pool_features=32)
45 | self.Mixed_5c = InceptionA(256, pool_features=64)
46 | self.Mixed_5d = InceptionA(288, pool_features=64)
47 | self.Mixed_6a = InceptionB(288)
48 | self.Mixed_6b = InceptionC(768, channels_7x7=128)
49 | self.Mixed_6c = InceptionC(768, channels_7x7=160)
50 | self.Mixed_6d = InceptionC(768, channels_7x7=160)
51 | self.Mixed_6e = InceptionC(768, channels_7x7=192)
52 | if aux_logits:
53 | self.AuxLogits = InceptionAux(768, num_classes)
54 | self.Mixed_7a = InceptionD(768)
55 | self.Mixed_7b = InceptionE(1280)
56 | self.Mixed_7c = InceptionE(2048)
57 | self.fc = nn.Linear(2048, num_classes)
58 |
59 | for m in self.modules():
60 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
61 | import scipy.stats as stats
62 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1
63 | X = stats.truncnorm(-2, 2, scale=stddev)
64 | values = torch.Tensor(X.rvs(m.weight.data.numel()))
65 | values = values.view(m.weight.data.size())
66 | m.weight.data.copy_(values)
67 | elif isinstance(m, nn.BatchNorm2d):
68 | m.weight.data.fill_(1)
69 | m.bias.data.zero_()
70 |
71 | def forward(self, x):
72 | if self.transform_input:
73 | x = x.clone()
74 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
75 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
76 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
77 | # 299 x 299 x 3 in the original torchvision model; this repo feeds 224 x 224 crops, so the spatial sizes below are smaller
78 | x = self.Conv2d_1a_3x3(x)
79 | # 149 x 149 x 32
80 | x = self.Conv2d_2a_3x3(x)
81 | # 147 x 147 x 32
82 | x = self.Conv2d_2b_3x3(x)
83 | # 147 x 147 x 64
84 | x = F.max_pool2d(x, kernel_size=3, stride=2)
85 | # 73 x 73 x 64
86 | x = self.Conv2d_3b_1x1(x)
87 | # 73 x 73 x 80
88 | x = self.Conv2d_4a_3x3(x)
89 | # 71 x 71 x 192
90 | x = F.max_pool2d(x, kernel_size=3, stride=2)
91 | # 35 x 35 x 192
92 | x = self.Mixed_5b(x)
93 | # 35 x 35 x 256
94 | x = self.Mixed_5c(x)
95 | # 35 x 35 x 288
96 | x = self.Mixed_5d(x)
97 | # 35 x 35 x 288
98 | x = self.Mixed_6a(x)
99 | # 17 x 17 x 768
100 | x = self.Mixed_6b(x)
101 | # 17 x 17 x 768
102 | x = self.Mixed_6c(x)
103 | # 17 x 17 x 768
104 | x = self.Mixed_6d(x)
105 | # 17 x 17 x 768
106 | x = self.Mixed_6e(x)
107 | # 17 x 17 x 768
108 | # aux head disabled: 224 x 224 inputs reach this point as 12 x 12, not 17 x 17, which breaks InceptionAux ("Error 12 -> 5")
109 | # if self.training and self.aux_logits: aux = self.AuxLogits(x)
110 | # 17 x 17 x 768
111 | x = self.Mixed_7a(x)
112 | # 8 x 8 x 1280
113 | x = self.Mixed_7b(x)
114 | # 8 x 8 x 2048
115 | x = self.Mixed_7c(x)
116 | # 8 x 8 x 2048
117 | x = F.avg_pool2d(x, kernel_size=5)  # kernel 8 in torchvision; 5 matches the 5 x 5 map from 224 x 224 inputs
118 | # 1 x 1 x 2048
119 | x = F.dropout(x, training=self.training)
120 | # 1 x 1 x 2048
121 | x = x.view(x.size(0), -1)
122 | # 2048
123 | x = self.fc(x)
124 | # 1000 (num_classes)
125 | #if self.training and self.aux_logits:
126 | # return x, aux
127 | return x
128 |
129 |
130 | class InceptionA(nn.Module):
131 |
132 | def __init__(self, in_channels, pool_features):
133 | super(InceptionA, self).__init__()
134 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)
135 |
136 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
137 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)
138 |
139 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
140 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
141 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)
142 |
143 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)
144 |
145 | def forward(self, x):
146 | branch1x1 = self.branch1x1(x)
147 |
148 | branch5x5 = self.branch5x5_1(x)
149 | branch5x5 = self.branch5x5_2(branch5x5)
150 |
151 | branch3x3dbl = self.branch3x3dbl_1(x)
152 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
153 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
154 |
155 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
156 | branch_pool = self.branch_pool(branch_pool)
157 |
158 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
159 | return torch.cat(outputs, 1)
160 |
161 |
162 | class InceptionB(nn.Module):
163 |
164 | def __init__(self, in_channels):
165 | super(InceptionB, self).__init__()
166 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
167 |
168 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
169 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
170 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)
171 |
172 | def forward(self, x):
173 | branch3x3 = self.branch3x3(x)
174 |
175 | branch3x3dbl = self.branch3x3dbl_1(x)
176 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
177 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
178 |
179 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
180 |
181 | outputs = [branch3x3, branch3x3dbl, branch_pool]
182 | return torch.cat(outputs, 1)
183 |
184 |
185 | class InceptionC(nn.Module):
186 |
187 | def __init__(self, in_channels, channels_7x7):
188 | super(InceptionC, self).__init__()
189 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)
190 |
191 | c7 = channels_7x7
192 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
193 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
194 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))
195 |
196 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
197 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
198 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
199 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
200 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
201 |
202 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)
203 |
204 | def forward(self, x):
205 | branch1x1 = self.branch1x1(x)
206 |
207 | branch7x7 = self.branch7x7_1(x)
208 | branch7x7 = self.branch7x7_2(branch7x7)
209 | branch7x7 = self.branch7x7_3(branch7x7)
210 |
211 | branch7x7dbl = self.branch7x7dbl_1(x)
212 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
213 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
214 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
215 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
216 |
217 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
218 | branch_pool = self.branch_pool(branch_pool)
219 |
220 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
221 | return torch.cat(outputs, 1)
222 |
223 |
224 | class InceptionD(nn.Module):
225 |
226 | def __init__(self, in_channels):
227 | super(InceptionD, self).__init__()
228 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
229 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)
230 |
231 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
232 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
233 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
234 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)
235 |
236 | def forward(self, x):
237 | branch3x3 = self.branch3x3_1(x)
238 | branch3x3 = self.branch3x3_2(branch3x3)
239 |
240 | branch7x7x3 = self.branch7x7x3_1(x)
241 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
242 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
243 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
244 |
245 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
246 | outputs = [branch3x3, branch7x7x3, branch_pool]
247 | return torch.cat(outputs, 1)
248 |
249 |
250 | class InceptionE(nn.Module):
251 |
252 | def __init__(self, in_channels):
253 | super(InceptionE, self).__init__()
254 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)
255 |
256 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
257 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
258 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
259 |
260 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
261 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
262 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
263 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
264 |
265 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)
266 |
267 | def forward(self, x):
268 | branch1x1 = self.branch1x1(x)
269 |
270 | branch3x3 = self.branch3x3_1(x)
271 | branch3x3 = [
272 | self.branch3x3_2a(branch3x3),
273 | self.branch3x3_2b(branch3x3),
274 | ]
275 | branch3x3 = torch.cat(branch3x3, 1)
276 |
277 | branch3x3dbl = self.branch3x3dbl_1(x)
278 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
279 | branch3x3dbl = [
280 | self.branch3x3dbl_3a(branch3x3dbl),
281 | self.branch3x3dbl_3b(branch3x3dbl),
282 | ]
283 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
284 |
285 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
286 | branch_pool = self.branch_pool(branch_pool)
287 |
288 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
289 | return torch.cat(outputs, 1)
290 |
291 |
292 | class InceptionAux(nn.Module):
293 |
294 | def __init__(self, in_channels, num_classes):
295 | super(InceptionAux, self).__init__()
296 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
297 | self.conv1 = BasicConv2d(128, 768, kernel_size=5)
298 | self.conv1.stddev = 0.01
299 | self.fc = nn.Linear(768, num_classes)
300 | self.fc.stddev = 0.001
301 |
302 | def forward(self, x):
303 | # 17 x 17 x 768
304 | x = F.avg_pool2d(x, kernel_size=5, stride=3)
305 | # 5 x 5 x 768
306 | x = self.conv0(x)
307 | # 5 x 5 x 128
308 | x = self.conv1(x)
309 | # 1 x 1 x 768
310 | x = x.view(x.size(0), -1)
311 | # 768
312 | x = self.fc(x)
313 | # 1000
314 | return x
315 |
316 |
317 | class BasicConv2d(nn.Module):
318 |
319 | def __init__(self, in_channels, out_channels, **kwargs):
320 | super(BasicConv2d, self).__init__()
321 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
322 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
323 |
324 | def forward(self, x):
325 | x = self.conv(x)
326 | x = self.bn(x)
327 | return F.relu(x, inplace=True)
--------------------------------------------------------------------------------
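A shape check that makes the 224-pixel modifications above concrete: the final pooling kernel is 5 rather than torchvision's 8, and the aux head is disabled, because a 224x224 input reaches `Mixed_7c` as a 5x5 map (and `Mixed_6e` as 12x12 rather than 17x17). A minimal sketch to verify this:

```
import torch
from inception import inception_v3

net = inception_v3(num_classes=2)
net.eval()

feats = {}
net.Mixed_7c.register_forward_hook(lambda m, i, o: feats.update(out=o))

with torch.no_grad():
    y = net(torch.randn(1, 3, 224, 224))

print(feats['out'].shape)  # torch.Size([1, 2048, 5, 5])
print(y.shape)             # torch.Size([1, 2])
```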