├── README.md ├── main.py └── models ├── alexnet.py ├── resnet.py ├── utils.py └── vgg.py /README.md: -------------------------------------------------------------------------------- 1 | # Hierarchical Group Sparse Regularization for Deep Convolutional Neural Networks 2 | Implementation with PyTorch. 3 | 4 | [Accepted to The 2020 International Joint Conference on Neural Networks (IJCNN)](https://arxiv.org/abs/2004.04394) 5 | 6 | ## Requirements 7 | - Python 3.7.5 8 | - PyTorch 1.3.1 9 | - TorchVision 0.4.2 10 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torchvision import datasets, transforms 9 | from models import alexnet, resnet, vgg 10 | 11 | # Training settings 12 | parser = argparse.ArgumentParser(description='Train AlexNet on CIFAR or STL dataset') 13 | parser.add_argument('data_path', type=str, 14 | help='path to dataset') 15 | parser.add_argument('--dataset', type=str, choices = ['cifar10','cifar100','stl10'], default='cifar10', 16 | help='choose dataset from cifar10/100 or stl10') 17 | parser.add_argument('--trained_model', action='store_true', default=False, 18 | help='load trained model') 19 | parser.add_argument('--trained_model_path', type=str, 20 | help='path to trained model') 21 | parser.add_argument('--batch_size', type=int, default=128, 22 | help='batch size (default: 128)') 23 | parser.add_argument('--workers', type=int, default=2, 24 | help='number of data loading workers (default: 2)') 25 | parser.add_argument('--epochs', type=int, default=100, 26 | help='number of epochs to train (default: 100)') 27 | parser.add_argument('--lr', type=float, default=0.01, 28 | help='learning rate (default: 0.01)') 29 | parser.add_argument('--momentum', type=float, default=0.9, 30 | help='momentum (default: 0.9)') 31 | parser.add_argument('--decay', type=float, default=0.0005, 32 | help='Weight decay (L2) (default: 0.0005)') 33 | parser.add_argument('--gamma', type=float, default=0.2, 34 | help='Learning rate step gamma (default: 0.2)') 35 | parser.add_argument('--save-model', action='store_true', default=True, 36 | help='save the trained model (default: True)') 37 | parser.add_argument('--no-cuda', action='store_true', default=False, 38 | help='disables CUDA training') 39 | 40 | parser.add_argument('--zero_threshold', type=float, default=0.001, 41 | help='threshold to define zero weight (default: 0.001)') 42 | parser.add_argument('--_lambda', type=float, default=0.001, 43 | help='hyperparameter for the regularization term (default: 0.001)') 44 | parser.add_argument('--_lambda2', type=float, default=0.5, 45 | help='balancing parameter between the regularization terms (default: 0.5)') 46 | 47 | 48 | args = parser.parse_args() 49 | 50 | def main(): 51 | 52 | use_cuda = not args.no_cuda and torch.cuda.is_available() 53 | device = torch.device("cuda" if use_cuda else "cpu") 54 | 55 | transform_test = transforms.Compose([ 56 | transforms.ToTensor(), 57 | transforms.Normalize((0.485, 0.456, 0.406), 58 | (0.229, 0.224, 0.225)), 59 | ]) 60 | 61 | if args.dataset == 'cifar10': 62 | transform_train = transforms.Compose([ 63 | transforms.RandomCrop(32,padding = 4), 64 | transforms.RandomHorizontalFlip(), 65 | transforms.ToTensor(), 66 | transforms.Normalize((0.485, 0.456, 0.406), 67 |
(0.229, 0.224, 0.225)), 68 | ]) 69 | trainset = datasets.CIFAR10(root=args.data_path,train=True,download=False,transform=transform_train) 70 | testset = datasets.CIFAR10(root=args.data_path,train=False,download=False,transform=transform_test) 71 | num_classes = 10 72 | elif args.dataset == 'cifar100': 73 | transform_train = transforms.Compose([ 74 | transforms.RandomCrop(32,padding = 4), 75 | transforms.RandomHorizontalFlip(), 76 | transforms.ToTensor(), 77 | transforms.Normalize((0.485, 0.456, 0.406), 78 | (0.229, 0.224, 0.225)), 79 | ]) 80 | trainset = datasets.CIFAR100(root=args.data_path,train=True,download=False,transform=transform_train) 81 | testset = datasets.CIFAR100(root=args.data_path,train=False,download=False,transform=transform_test) 82 | num_classes = 100 83 | elif args.dataset == 'stl10': 84 | transform_train = transforms.Compose([ 85 | transforms.RandomCrop(96,padding = 4), 86 | transforms.RandomHorizontalFlip(), 87 | transforms.ToTensor(), 88 | transforms.Normalize((0.485, 0.456, 0.406), 89 | (0.229, 0.224, 0.225)), 90 | ]) 91 | trainset = datasets.STL10(root=args.data_path,split='train',download=False,transform=transform_train) 92 | testset = datasets.STL10(root=args.data_path,split='test',download=False,transform=transform_test) 93 | num_classes = 10 94 | 95 | trainloader = torch.utils.data.DataLoader(trainset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers) 96 | testloader = torch.utils.data.DataLoader(testset,batch_size=args.batch_size,shuffle=False, num_workers=args.workers) 97 | 98 | net = alexnet.alexnet(num_classes = num_classes).to(device) 99 | if args.trained_model: 100 | ckpt = torch.load(args.trained_model_path, map_location= device) 101 | net.load_state_dict(ckpt) 102 | 103 | criterion = nn.CrossEntropyLoss() 104 | 105 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.decay) 106 | 107 | #scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60,120,160], gamma=args.gamma) 108 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(args.epochs/3), gamma=args.gamma) 109 | 110 | regularization = sparse_regularization(net,device) 111 | 112 | for epoch in range(1, args.epochs + 1): 113 | train(args, net, device, trainloader, optimizer, criterion, epoch, regularization) 114 | test(args, net, device, testloader, criterion) 115 | scheduler.step() 116 | 117 | if args.save_model: 118 | torch.save(net.state_dict(), str(args.dataset)+"_alexnet.pt") 119 | 120 | 121 | def train(args, model, device, trainloader, optimizer, criterion, epoch, regularization): 122 | 123 | sum_loss = 0.0 124 | sum_correct = 0 125 | sum_total = 0 126 | model.train() 127 | for batch_idx, (data, target) in enumerate(trainloader): 128 | data, target = data.to(device), target.to(device) 129 | optimizer.zero_grad() 130 | output = model(data) 131 | loss = criterion(output, target) 132 | sum_loss += loss.item() 133 | loss += args._lambda2*regularization.hierarchical_squared_group_l12_regularization(args._lambda) 134 | loss += (1-args._lambda2)*regularization.l1_regularization(args._lambda) 135 | loss.backward() 136 | optimizer.step() 137 | _, predicted = output.max(1) 138 | sum_total += target.size(0) 139 | sum_correct += (predicted == target).sum().item() 140 | print("train mean loss={}, accuracy={}, sparsity={}" 141 | .format(sum_loss*args.batch_size/len(trainloader.dataset), float(sum_correct/sum_total), sparsity(model))) 142 | 143 | 144 | def test(args, model, device, testloader, criterion): 145 |
model.eval() 146 | sum_loss = 0.0 147 | sum_correct = 0 148 | sum_total = 0 149 | with torch.no_grad(): 150 | for data, target in testloader: 151 | data, target = data.to(device), target.to(device) 152 | output = model(data) 153 | loss = criterion(output, target) 154 | sum_loss += loss.item() 155 | _, predicted = output.max(1) 156 | sum_total += target.size(0) 157 | sum_correct += (predicted == target).sum().item() 158 | print("test mean loss={}, accuracy={}" 159 | .format(sum_loss*args.batch_size/len(testloader.dataset), float(sum_correct/sum_total))) 160 | 161 | def sparsity(model): 162 | number_of_conv_weight = 0 163 | number_of_zero_conv_weight = 0 164 | for n, _module in model.named_modules(): 165 | if isinstance(_module, nn.Conv2d) and (not 'downsample' in n): 166 | p = torch.flatten(_module.weight.data) 167 | number_of_conv_weight += len(p) 168 | number_of_zero_conv_weight += len(p[torch.abs(p)`_ paper. 54 | Args: 55 | pretrained (bool): If True, returns a model pre-trained on ImageNet 56 | progress (bool): If True, displays a progress bar of the download to stderr 57 | """ 58 | model = AlexNet(**kwargs) 59 | if pretrained: 60 | state_dict = load_state_dict_from_url(model_urls['alexnet'], 61 | progress=progress) 62 | model.load_state_dict(state_dict) 63 | return model -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | __constants__ = ['downsample'] 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 40 | base_width=64, dilation=1, norm_layer=None): 41 | super(BasicBlock, self).__init__() 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if groups != 1 or base_width != 64: 45 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 46 | if dilation > 1: 47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 48 | # 
Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = norm_layer(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 78 | __constants__ = ['downsample'] 79 | 80 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 81 | base_width=64, dilation=1, norm_layer=None): 82 | super(Bottleneck, self).__init__() 83 | if norm_layer is None: 84 | norm_layer = nn.BatchNorm2d 85 | width = int(planes * (base_width / 64.)) * groups 86 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 87 | self.conv1 = conv1x1(inplanes, width) 88 | self.bn1 = norm_layer(width) 89 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 90 | self.bn2 = norm_layer(width) 91 | self.conv3 = conv1x1(width, planes * self.expansion) 92 | self.bn3 = norm_layer(planes * self.expansion) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.downsample = downsample 95 | self.stride = stride 96 | 97 | def forward(self, x): 98 | identity = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | 111 | if self.downsample is not None: 112 | identity = self.downsample(x) 113 | 114 | out += identity 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class ResNet(nn.Module): 121 | 122 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 123 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 124 | norm_layer=None): 125 | super(ResNet, self).__init__() 126 | if norm_layer is None: 127 | norm_layer = nn.BatchNorm2d 128 | self._norm_layer = norm_layer 129 | 130 | self.inplanes = 8 131 | self.dilation = 1 132 | if replace_stride_with_dilation is None: 133 | # each element in the tuple indicates if we should replace 134 | # the 2x2 stride with a dilated convolution instead 135 | replace_stride_with_dilation = [False, False, False] 136 | if len(replace_stride_with_dilation) != 3: 137 | raise ValueError("replace_stride_with_dilation should be None " 138 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 139 | self.groups = groups 140 | self.base_width = width_per_group 141 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 142 | bias=False) 143 | self.bn1 = norm_layer(self.inplanes) 144 | self.relu = nn.ReLU(inplace=True) 145 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 146 | self.layer1 = self._make_layer(block, 8, layers[0]) 147 | self.layer2 = self._make_layer(block, 16, layers[1], stride=2, 148 | dilate=replace_stride_with_dilation[0]) 149 | self.layer3 = self._make_layer(block, 32, layers[2], stride=2, 150 | dilate=replace_stride_with_dilation[1]) 151 | self.layer4 = self._make_layer(block, 64, layers[3], stride=2, 152 | dilate=replace_stride_with_dilation[2]) 153 | 
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 154 | self.fc = nn.Linear(64 * block.expansion, num_classes) 155 | 156 | for m in self.modules(): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 159 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 160 | nn.init.constant_(m.weight, 1) 161 | nn.init.constant_(m.bias, 0) 162 | 163 | # Zero-initialize the last BN in each residual branch, 164 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 165 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 166 | if zero_init_residual: 167 | for m in self.modules(): 168 | if isinstance(m, Bottleneck): 169 | nn.init.constant_(m.bn3.weight, 0) 170 | elif isinstance(m, BasicBlock): 171 | nn.init.constant_(m.bn2.weight, 0) 172 | 173 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 174 | norm_layer = self._norm_layer 175 | downsample = None 176 | previous_dilation = self.dilation 177 | if dilate: 178 | self.dilation *= stride 179 | stride = 1 180 | if stride != 1 or self.inplanes != planes * block.expansion: 181 | downsample = nn.Sequential( 182 | conv1x1(self.inplanes, planes * block.expansion, stride), 183 | norm_layer(planes * block.expansion), 184 | ) 185 | 186 | layers = [] 187 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 188 | self.base_width, previous_dilation, norm_layer)) 189 | self.inplanes = planes * block.expansion 190 | for _ in range(1, blocks): 191 | layers.append(block(self.inplanes, planes, groups=self.groups, 192 | base_width=self.base_width, dilation=self.dilation, 193 | norm_layer=norm_layer)) 194 | 195 | return nn.Sequential(*layers) 196 | 197 | def _forward(self, x): 198 | x = self.conv1(x) 199 | x = self.bn1(x) 200 | x = self.relu(x) 201 | x = self.maxpool(x) 202 | 203 | x = self.layer1(x) 204 | x = self.layer2(x) 205 | x = self.layer3(x) 206 | x = self.layer4(x) 207 | x = self.avgpool(x) 208 | x = torch.flatten(x, 1) 209 | x = self.fc(x) 210 | 211 | return x 212 | 213 | # Allow for accessing forward method in an inherited class 214 | forward = _forward 215 | 216 | 217 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 218 | model = ResNet(block, layers, **kwargs) 219 | if pretrained: 220 | state_dict = load_state_dict_from_url(model_urls[arch], 221 | progress=progress) 222 | model.load_state_dict(state_dict) 223 | return model 224 | 225 | 226 | def resnet18(pretrained=False, progress=True, **kwargs): 227 | r"""ResNet-18 model from 228 | `"Deep Residual Learning for Image Recognition" `_ 229 | Args: 230 | pretrained (bool): If True, returns a model pre-trained on ImageNet 231 | progress (bool): If True, displays a progress bar of the download to stderr 232 | """ 233 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 234 | **kwargs) 235 | 236 | 237 | def resnet34(pretrained=False, progress=True, **kwargs): 238 | r"""ResNet-34 model from 239 | `"Deep Residual Learning for Image Recognition" `_ 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | """ 244 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 245 | **kwargs) 246 | 247 | 248 | def resnet50(pretrained=False, progress=True, **kwargs): 249 | r"""ResNet-50 model from 250 | `"Deep Residual Learning for Image Recognition" `_ 251 | Args: 252 |
pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | progress (bool): If True, displays a progress bar of the download to stderr 254 | """ 255 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 256 | **kwargs) 257 | 258 | 259 | def resnet101(pretrained=False, progress=True, **kwargs): 260 | r"""ResNet-101 model from 261 | `"Deep Residual Learning for Image Recognition" `_ 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | progress (bool): If True, displays a progress bar of the download to stderr 265 | """ 266 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 267 | **kwargs) 268 | 269 | 270 | def resnet152(pretrained=False, progress=True, **kwargs): 271 | r"""ResNet-152 model from 272 | `"Deep Residual Learning for Image Recognition" `_ 273 | Args: 274 | pretrained (bool): If True, returns a model pre-trained on ImageNet 275 | progress (bool): If True, displays a progress bar of the download to stderr 276 | """ 277 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 278 | **kwargs) 279 | 280 | 281 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 282 | r"""ResNeXt-50 32x4d model from 283 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 284 | Args: 285 | pretrained (bool): If True, returns a model pre-trained on ImageNet 286 | progress (bool): If True, displays a progress bar of the download to stderr 287 | """ 288 | kwargs['groups'] = 32 289 | kwargs['width_per_group'] = 4 290 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 291 | pretrained, progress, **kwargs) 292 | 293 | 294 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 295 | r"""ResNeXt-101 32x8d model from 296 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | progress (bool): If True, displays a progress bar of the download to stderr 300 | """ 301 | kwargs['groups'] = 32 302 | kwargs['width_per_group'] = 8 303 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 304 | pretrained, progress, **kwargs) 305 | 306 | 307 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 308 | r"""Wide ResNet-50-2 model from 309 | `"Wide Residual Networks" `_ 310 | The model is the same as ResNet except for the bottleneck number of channels 311 | which is twice larger in every block. The number of channels in outer 1x1 312 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 313 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | progress (bool): If True, displays a progress bar of the download to stderr 317 | """ 318 | kwargs['width_per_group'] = 64 * 2 319 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 320 | pretrained, progress, **kwargs) 321 | 322 | 323 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 324 | r"""Wide ResNet-101-2 model from 325 | `"Wide Residual Networks" `_ 326 | The model is the same as ResNet except for the bottleneck number of channels 327 | which is twice larger in every block. The number of channels in outer 1x1 328 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 329 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
330 | Args: 331 | pretrained (bool): If True, returns a model pre-trained on ImageNet 332 | progress (bool): If True, displays a progress bar of the download to stderr 333 | """ 334 | kwargs['width_per_group'] = 64 * 2 335 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 336 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch.hub import load_state_dict_from_url 3 | except ImportError: 4 | from torch.utils.model_zoo import load_url as load_state_dict_from_url -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 8 | 'vgg19_bn', 'vgg19', 9 | ] 10 | 11 | 12 | model_urls = { 13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 17 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 18 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 19 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 20 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 21 | } 22 | 23 | 24 | class VGG(nn.Module): 25 | 26 | def __init__(self, features, num_classes=1000, init_weights=True): 27 | super(VGG, self).__init__() 28 | self.features = features 29 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 30 | self.classifier = nn.Sequential( 31 | nn.Linear(128 * 1 * 1, 128), 32 | nn.ReLU(True), 33 | nn.Dropout(), 34 | nn.Linear(128, 128), 35 | nn.ReLU(True), 36 | nn.Dropout(), 37 | nn.Linear(128, num_classes), 38 | ) 39 | if init_weights: 40 | self._initialize_weights() 41 | 42 | def forward(self, x): 43 | x = self.features(x) 44 | x = self.avgpool(x) 45 | x = torch.flatten(x, 1) 46 | x = self.classifier(x) 47 | return x 48 | 49 | def _initialize_weights(self): 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 53 | if m.bias is not None: 54 | nn.init.constant_(m.bias, 0) 55 | elif isinstance(m, nn.BatchNorm2d): 56 | nn.init.constant_(m.weight, 1) 57 | nn.init.constant_(m.bias, 0) 58 | elif isinstance(m, nn.Linear): 59 | nn.init.normal_(m.weight, 0, 0.01) 60 | nn.init.constant_(m.bias, 0) 61 | 62 | 63 | def make_layers(cfg, batch_norm=False): 64 | layers = [] 65 | in_channels = 3 66 | for v in cfg: 67 | if v == 'M': 68 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 69 | else: 70 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 71 | if batch_norm: 72 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 73 | else: 74 | layers += [conv2d, nn.ReLU(inplace=True)] 75 | in_channels = v 76 | return nn.Sequential(*layers) 77 | 78 | 79 | cfgs = { 80 | 'A': [16, 'M', 32, 'M', 64, 64, 'M', 128, 128, 'M', 128, 128, 'M'], 81 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 82 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 
512, 'M'], 83 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 84 | } 85 | 86 | 87 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 88 | if pretrained: 89 | kwargs['init_weights'] = False 90 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 91 | if pretrained: 92 | state_dict = load_state_dict_from_url(model_urls[arch], 93 | progress=progress) 94 | model.load_state_dict(state_dict) 95 | return model 96 | 97 | 98 | def vgg11(pretrained=False, progress=True, **kwargs): 99 | r"""VGG 11-layer model (configuration "A") from 100 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 101 | Args: 102 | pretrained (bool): If True, returns a model pre-trained on ImageNet 103 | progress (bool): If True, displays a progress bar of the download to stderr 104 | """ 105 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 106 | 107 | 108 | def vgg11_bn(pretrained=False, progress=True, **kwargs): 109 | r"""VGG 11-layer model (configuration "A") with batch normalization 110 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 111 | Args: 112 | pretrained (bool): If True, returns a model pre-trained on ImageNet 113 | progress (bool): If True, displays a progress bar of the download to stderr 114 | """ 115 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 116 | 117 | 118 | def vgg13(pretrained=False, progress=True, **kwargs): 119 | r"""VGG 13-layer model (configuration "B") 120 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 121 | Args: 122 | pretrained (bool): If True, returns a model pre-trained on ImageNet 123 | progress (bool): If True, displays a progress bar of the download to stderr 124 | """ 125 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 126 | 127 | 128 | def vgg13_bn(pretrained=False, progress=True, **kwargs): 129 | r"""VGG 13-layer model (configuration "B") with batch normalization 130 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 131 | Args: 132 | pretrained (bool): If True, returns a model pre-trained on ImageNet 133 | progress (bool): If True, displays a progress bar of the download to stderr 134 | """ 135 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 136 | 137 | 138 | def vgg16(pretrained=False, progress=True, **kwargs): 139 | r"""VGG 16-layer model (configuration "D") 140 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 141 | Args: 142 | pretrained (bool): If True, returns a model pre-trained on ImageNet 143 | progress (bool): If True, displays a progress bar of the download to stderr 144 | """ 145 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 146 | 147 | 148 | def vgg16_bn(pretrained=False, progress=True, **kwargs): 149 | r"""VGG 16-layer model (configuration "D") with batch normalization 150 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 151 | Args: 152 | pretrained (bool): If True, returns a model pre-trained on ImageNet 153 | progress (bool): If True, displays a progress bar of the download to stderr 154 | """ 155 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 156 | 157 | 158 | def vgg19(pretrained=False, progress=True, **kwargs): 159 | r"""VGG 19-layer model (configuration "E") 160 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 161 | Args: 162 | pretrained (bool): If True, returns a model 
pre-trained on ImageNet 163 | progress (bool): If True, displays a progress bar of the download to stderr 164 | """ 165 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 166 | 167 | 168 | def vgg19_bn(pretrained=False, progress=True, **kwargs): 169 | r"""VGG 19-layer model (configuration 'E') with batch normalization 170 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 171 | Args: 172 | pretrained (bool): If True, returns a model pre-trained on ImageNet 173 | progress (bool): If True, displays a progress bar of the download to stderr 174 | """ 175 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) --------------------------------------------------------------------------------
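
Note on the regularization term: `main.py` instantiates a `sparse_regularization` class and calls its `hierarchical_squared_group_l12_regularization` and `l1_regularization` methods, but that class is not part of this excerpt. The sketch below is a minimal, non-authoritative stand-in that only mirrors the interface used in `train()`; the grouping is a single-level, per-filter group penalty rather than the paper's full hierarchical scheme, and every name in it is taken from the calls in `main.py` rather than from the original implementation.

```python
# Hypothetical sketch: the real sparse_regularization used by main.py is not
# included in this excerpt, so the grouping below is an assumption.
import torch
import torch.nn as nn


class sparse_regularization:
    """Penalty terms computed over the convolutional weights of a model."""

    def __init__(self, model, device):
        self.model = model
        self.device = device

    def _conv_weights(self):
        # Skip the 1x1 downsample convolutions, as sparsity() in main.py does.
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d) and 'downsample' not in name:
                yield module.weight

    def l1_regularization(self, strength):
        # Elementwise L1 penalty over all convolutional weights.
        penalty = torch.tensor(0., device=self.device)
        for w in self._conv_weights():
            penalty = penalty + w.abs().sum()
        return strength * penalty

    def hierarchical_squared_group_l12_regularization(self, strength):
        # Single-level stand-in for the hierarchical penalty: square the sum of
        # per-filter L2 norms, which drives whole output filters toward zero.
        penalty = torch.tensor(0., device=self.device)
        for w in self._conv_weights():                    # shape (out, in, kH, kW)
            filter_norms = w.flatten(1).norm(p=2, dim=1)  # one L2 norm per filter
            penalty = penalty + filter_norms.sum() ** 2
        return strength * penalty
```

With a definition like this visible to `main.py` (the original likely keeps it in its own module), training follows the argument parser at the top of `main.py`, e.g. `python main.py ./data --dataset cifar10 --epochs 100 --_lambda 0.001 --_lambda2 0.5`, where `--_lambda` sets the overall penalty strength and `--_lambda2` balances the group term against the plain L1 term.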