├── core
│   ├── __init__.py
│   ├── grad_noise.py
│   ├── grad_constraint.py
│   ├── grad_calculate.py
│   └── image_sift.py
├── utils
│   ├── __init__.py
│   ├── file_util.py
│   └── train_util.py
├── preprocessing
│   ├── __init__.py
│   ├── cifar
│   │   ├── __init__.py
│   │   ├── cifar10_gen.py
│   │   └── cifar100_gen.py
│   ├── mnin
│   │   ├── __init__.py
│   │   ├── mask_gen.py
│   │   ├── image_gen.py
│   │   └── image_split.py
│   ├── mnist
│   │   ├── __init__.py
│   │   └── image_gen.py
│   ├── stl10
│   │   ├── __init__.py
│   │   └── image_gen.py
│   ├── labelme_anno_track.py
│   └── labelme_to_mask.py
├── framework.png
├── metrics
│   ├── __init__.py
│   └── accuracy.py
├── loaders
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── image_dataset.py
│   │   ├── image_mask_dataset.py
│   │   └── image_mask_transforms.py
│   ├── svhn_loader.py
│   ├── mnist_loader.py
│   ├── fashion_mnist_loader.py
│   ├── cifar10_loader.py
│   ├── cifar100_loader.py
│   ├── __init__.py
│   ├── mnin_loader.py
│   └── stl10_loader.py
├── scripts
│   ├── test.sh
│   ├── train.sh
│   ├── grad_calculate.sh
│   ├── image_sift.sh
│   └── train_model_doctor.sh
├── models
│   ├── simnet.py
│   ├── alexnet.py
│   ├── vgg.py
│   ├── squeezenet.py
│   ├── __init__.py
│   ├── wideresnet.py
│   ├── mobilenetv2.py
│   ├── simplenetv1.py
│   ├── resnext.py
│   ├── googlenet.py
│   ├── mnasnet.py
│   ├── shufflenetv2.py
│   ├── densenet.py
│   ├── senet.py
│   ├── resnet.py
│   ├── mobilenet.py
│   ├── xception.py
│   ├── efficientnetv2.py
│   ├── shufflenet.py
│   ├── inceptionv3.py
│   └── inceptionv4.py
├── .gitignore
├── README.md
├── test.py
├── train.py
└── train_model_doctor.py

/core/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/preprocessing/cifar/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/preprocessing/mnin/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/preprocessing/mnist/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/preprocessing/stl10/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaconghu/Model-Doctor/HEAD/framework.png
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from metrics.accuracy import accuracy
2 | from metrics.accuracy import ClassAccuracy
3 | 
--------------------------------------------------------------------------------
/loaders/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from loaders.datasets.image_dataset import ImageDataset
2 | from loaders.datasets.image_mask_dataset import ImageMaskDataset
3 | 
-------------------------------------------------------------------------------- /core/grad_noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class GradNoise: 5 | def __init__(self, module): 6 | self.module = module 7 | self.hook = None 8 | 9 | def add_noise(self): 10 | self.hook = self.module.register_forward_hook(_modify_feature_map) 11 | 12 | def remove_noise(self): 13 | self.hook.remove() 14 | 15 | 16 | # keep forward after modify 17 | def _modify_feature_map(module, inputs, outputs): 18 | noise = torch.randn(outputs.shape).to(outputs.device) 19 | # noise = torch.normal(mean=0, std=3, size=outputs.shape).to(outputs.device) 20 | 21 | outputs += noise 22 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export result_path='/nfs3-p1/hjc/md/output/' 3 | export exp_name='vgg16_cifar10_202206072149' 4 | export model_name='vgg16' 5 | export data_name='cifar10' 6 | export in_channels=3 7 | export num_classes=10 8 | export model_path=${result_path}${exp_name}'/models/model_ori.pth' 9 | export data_path='/nfs3-p1/hjc/datasets/cifar10/test' 10 | export device_index='2' 11 | python test.py \ 12 | --model_name ${model_name} \ 13 | --data_name ${data_name} \ 14 | --in_channels ${in_channels} \ 15 | --num_classes ${num_classes} \ 16 | --model_path ${model_path} \ 17 | --data_path ${data_path} \ 18 | --device_index ${device_index} 19 | -------------------------------------------------------------------------------- /utils/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | def walk_file(path): 6 | count = 0 7 | for root, dirs, files in os.walk(path): 8 | print(root) 9 | 10 | for f in files: 11 | count += 1 12 | # print(os.path.join(root, f)) 13 | 14 | for d in dirs: 15 | print(os.path.join(root, d)) 16 | print(count) 17 | 18 | 19 | def count_files(path): 20 | for root, dirs, files in os.walk(path): 21 | print(root, len(files)) 22 | 23 | 24 | def copy_file(src, dst): 25 | path, name = os.path.split(dst) 26 | if not os.path.exists(path): 27 | os.makedirs(path) 28 | shutil.copyfile(src, dst) 29 | 30 | 31 | if __name__ == '__main__': 32 | count_files('') 33 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export result_path='/nfs3-p1/hjc/md/output/' 3 | export exp_name='vgg16_cifar10_202206072149' 4 | export model_name='vgg16' 5 | export data_name='cifar10' 6 | export in_channels=3 7 | export num_classes=10 8 | export num_epochs=200 9 | export model_dir=${result_path}${exp_name}'/models' 10 | export data_dir='/nfs3-p1/hjc/datasets/cifar10' 11 | export log_dir=${result_path}'/runs/'${exp_name} 12 | export device_index='1' 13 | python train.py \ 14 | --model_name ${model_name} \ 15 | --data_name ${data_name} \ 16 | --in_channels ${in_channels} \ 17 | --num_classes ${num_classes} \ 18 | --num_epochs ${num_epochs} \ 19 | --model_dir ${model_dir} \ 20 | --data_dir ${data_dir} \ 21 | --log_dir ${log_dir} \ 22 | --device_index ${device_index} 23 | -------------------------------------------------------------------------------- /preprocessing/labelme_anno_track.py: -------------------------------------------------------------------------------- 1 | import os 
2 | from configs import config 3 | from utils import file_util 4 | 5 | 6 | def track(track_path, result_path): 7 | for root, _, files in os.walk(track_path): 8 | for file in files: 9 | if os.path.splitext(file)[1] == '.json': 10 | class_name = os.path.split(root)[-1] 11 | src = os.path.join(root, file) 12 | dst = os.path.join(result_path, class_name, file) 13 | file_util.copy_file(src, dst) 14 | 15 | 16 | def main(): 17 | track_path = os.path.join(config.output_result, 'vgg16_08241356', 'images') 18 | result_path = os.path.join(config.output_result, 'mnim_lc_images') 19 | track(track_path, result_path) 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /scripts/grad_calculate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export result_path='/nfs3-p1/hjc/md/output/' 3 | export exp_name='vgg16_cifar10_202206072149' 4 | export model_name='vgg16' 5 | export data_name='cifar10' 6 | export in_channels=3 7 | export num_classes=10 8 | export model_path=${result_path}${exp_name}'/models/model_ori.pth' 9 | export data_path=${result_path}${exp_name}'/images_50' 10 | export grad_path=${result_path}${exp_name}'/grads_50' 11 | export theta=0.2 12 | export device_index='2' 13 | python core/grad_calculate.py \ 14 | --model_name ${model_name} \ 15 | --data_name ${data_name} \ 16 | --in_channels ${in_channels} \ 17 | --num_classes ${num_classes} \ 18 | --model_path ${model_path} \ 19 | --data_path ${data_path} \ 20 | --grad_path ${grad_path} \ 21 | --theta ${theta} \ 22 | --device_index ${device_index} 23 | -------------------------------------------------------------------------------- /scripts/image_sift.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export result_path='/nfs3-p1/hjc/md/output/' 3 | export exp_name='vgg16_cifar10_202206072149' 4 | export model_name='vgg16' 5 | export data_name='cifar10' 6 | export in_channels=3 7 | export num_classes=10 8 | export model_path=${result_path}${exp_name}'/models/model_ori.pth' 9 | export data_path='/nfs3-p1/hjc/datasets/cifar10/train' 10 | export image_path=${result_path}${exp_name}'/images_50' 11 | export num_images=50 12 | export device_index='2' 13 | python core/image_sift.py \ 14 | --model_name ${model_name} \ 15 | --data_name ${data_name} \ 16 | --in_channels ${in_channels} \ 17 | --num_classes ${num_classes} \ 18 | --model_path ${model_path} \ 19 | --data_path ${data_path} \ 20 | --image_path ${image_path} \ 21 | --num_images ${num_images} \ 22 | --device_index ${device_index} 23 | -------------------------------------------------------------------------------- /scripts/train_model_doctor.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export result_path='/nfs3-p1/hjc/md/output/' 3 | export exp_name='vgg16_cifar10_202206072149' 4 | export model_name='vgg16' 5 | export data_name='cifar10' 6 | export in_channels=3 7 | export num_classes=10 8 | export num_epochs=200 9 | export ori_model_path=${result_path}${exp_name}'/models/model_ori.pth' 10 | export res_model_path=${result_path}${exp_name}'/models/model_optim.pth' 11 | export data_dir='/nfs3-p1/hjc/datasets/cifar10' 12 | export log_dir=${result_path}'/runs/'${exp_name} 13 | export mask_dir=${result_path}${exp_name}'/masks' 14 | export grad_dir=${result_path}${exp_name}'/grads_50' 15 | export alpha=10.0 16 | export beta=0.0 17 | export 
device_index='2' 18 | python train_model_doctor.py \ 19 | --model_name ${model_name} \ 20 | --data_name ${data_name} \ 21 | --in_channels ${in_channels} \ 22 | --num_classes ${num_classes} \ 23 | --num_epochs ${num_epochs} \ 24 | --ori_model_path ${ori_model_path} \ 25 | --res_model_path ${res_model_path} \ 26 | --data_dir ${data_dir} \ 27 | --log_dir ${log_dir} \ 28 | --mask_dir ${mask_dir} \ 29 | --grad_dir ${grad_dir} \ 30 | --alpha ${alpha} \ 31 | --beta ${beta} \ 32 | --device_index ${device_index} 33 | -------------------------------------------------------------------------------- /metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def accuracy(outputs, labels, topk=(1,)): 6 | with torch.no_grad(): 7 | maxk = max(topk) 8 | batch_size = labels.size(0) 9 | 10 | _, pred = outputs.topk(maxk, 1, True, True) # [batch_size, topk] 11 | pred = pred.t() # [topk, batch_size] 12 | correct = pred.eq(labels.view(1, -1).expand_as(pred)) # [topk, batch_size] 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].float().sum() 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res 19 | 20 | 21 | class ClassAccuracy: 22 | def __init__(self, num_classes): 23 | self.sum = np.zeros(num_classes) 24 | self.count = np.zeros(num_classes) 25 | 26 | def accuracy(self, outputs, labels): 27 | _, pred = outputs.max(dim=1) 28 | correct = pred.eq(labels) 29 | 30 | for b, label in enumerate(labels): 31 | self.count[label] += 1 32 | self.sum[label] += correct[b] 33 | 34 | def __str__(self): 35 | fmtstr = '{}:{:6.2f}' 36 | avg = (self.sum / self.count) * 100 37 | result = '\n'.join([fmtstr.format(l, a) for l, a in enumerate(avg)]) 38 | return result 39 | -------------------------------------------------------------------------------- /utils/train_util.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | def __init__(self, name, fmt=':f'): 3 | self.name = name 4 | self.fmt = fmt 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val * n 16 | self.count += n 17 | self.avg = self.sum / self.count 18 | 19 | def __str__(self): 20 | fmtstr = '{name}[VAL:{val' + self.fmt + '} AVG:{avg' + self.fmt + '}]' 21 | return fmtstr.format(**self.__dict__) 22 | 23 | 24 | class ProgressMeter(object): 25 | def __init__(self, total, step, prefix, meters): 26 | self._fmtstr = self._get_fmtstr(total) 27 | self.meters = meters 28 | self.prefix = prefix 29 | 30 | self.step = step 31 | 32 | def display(self, running): 33 | if running % self.step == 0: 34 | entries = [self.prefix + self._fmtstr.format(running)] # [prefix xx.xx/xx.xx] 35 | entries += [str(meter) for meter in self.meters] 36 | print(' '.join(entries)) 37 | 38 | def _get_fmtstr(self, total): 39 | num_digits = len(str(total // 1)) 40 | fmt = '{:' + str(num_digits) + 'd}' 41 | return '[' + fmt + '/' + fmt.format(total) + ']' # [prefix xx.xx/xx.xx] 42 | -------------------------------------------------------------------------------- /models/simnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from collections import OrderedDict 4 | 5 | 6 | class SimNet(nn.Module): 7 | def __init__(self): 8 | super(SimNet, self).__init__() 9 | self.features = nn.Sequential( 10 | 
OrderedDict([ 11 | ('c1', nn.Conv2d(3, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 12 | ('relu1', nn.ReLU()), 13 | ('s1', nn.MaxPool2d(kernel_size=2, stride=2)), 14 | ('c2', nn.Conv2d(9, 27, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 15 | ('relu2', nn.ReLU()), 16 | ('s2', nn.MaxPool2d(kernel_size=2, stride=2)), 17 | ('c3', nn.Conv2d(27, 81, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 18 | ('relu3', nn.ReLU()) 19 | ]) 20 | ) 21 | self.classifier = nn.Sequential( 22 | OrderedDict([ 23 | ('f4', nn.Linear(254016, 12)) 24 | ]) 25 | ) 26 | 27 | def forward(self, x): 28 | x = self.features(x) 29 | x = x.view(x.size(0), -1) 30 | x = self.classifier(x) 31 | return x 32 | 33 | 34 | def simnet(): 35 | return SimNet() 36 | 37 | 38 | if __name__ == '__main__': 39 | from torchsummary import summary 40 | 41 | model = simnet() 42 | summary(model, (3, 224, 224)) 43 | print(model) 44 | -------------------------------------------------------------------------------- /loaders/svhn_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torchvision import transforms 3 | from loaders.datasets import ImageDataset 4 | 5 | 6 | def _get_train_set(data_path): 7 | return ImageDataset(image_dir=data_path, 8 | transform=transforms.Compose([ 9 | transforms.Resize((32, 32)), 10 | transforms.RandomHorizontalFlip(), 11 | transforms.ToTensor(), 12 | transforms.Normalize((0.286,), (0.353,)) 13 | ])) 14 | 15 | 16 | def _get_test_set(data_path): 17 | return ImageDataset(image_dir=data_path, 18 | transform=transforms.Compose([ 19 | transforms.Resize((32, 32)), 20 | transforms.ToTensor(), 21 | transforms.Normalize((0.286,), (0.353,)) 22 | ])) 23 | 24 | 25 | def load_images(data_path, data_type=None): 26 | assert data_type is None or data_type in ['train', 'test'] 27 | if data_type == 'train': 28 | data_set = _get_train_set(data_path) 29 | else: 30 | data_set = _get_test_set(data_path) 31 | 32 | data_loader = DataLoader(dataset=data_set, 33 | batch_size=128, 34 | num_workers=4, 35 | shuffle=True) 36 | 37 | return data_loader 38 | -------------------------------------------------------------------------------- /loaders/mnist_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torchvision import transforms 3 | from loaders.datasets import ImageDataset 4 | 5 | 6 | def _get_train_set(data_path): 7 | return ImageDataset(image_dir=data_path, 8 | transform=transforms.Compose([ 9 | transforms.Resize((32, 32)), 10 | transforms.RandomCrop(32, padding=2), 11 | transforms.ToTensor(), 12 | transforms.Normalize((0.1307,), (0.3081,)) 13 | ])) 14 | 15 | 16 | def _get_test_set(data_path): 17 | return ImageDataset(image_dir=data_path, 18 | transform=transforms.Compose([ 19 | transforms.Resize((32, 32)), 20 | transforms.ToTensor(), 21 | transforms.Normalize((0.1307,), (0.3081,)) 22 | ])) 23 | 24 | 25 | def load_images(data_path, data_type=None): 26 | assert data_type is None or data_type in ['train', 'test'] 27 | if data_type == 'train': 28 | data_set = _get_train_set(data_path) 29 | else: 30 | data_set = _get_test_set(data_path) 31 | 32 | data_loader = DataLoader(dataset=data_set, 33 | batch_size=128, 34 | num_workers=4, 35 | shuffle=True) 36 | 37 | return data_loader 38 | -------------------------------------------------------------------------------- /loaders/fashion_mnist_loader.py: 
-------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torchvision import transforms 3 | from loaders.datasets import ImageDataset 4 | 5 | 6 | def _get_train_set(data_path): 7 | return ImageDataset(image_dir=data_path, 8 | transform=transforms.Compose([ 9 | transforms.Resize((32, 32)), 10 | transforms.RandomCrop(32, padding=2), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.ToTensor(), 13 | transforms.Normalize((0.286,), (0.353,)) 14 | ])) 15 | 16 | 17 | def _get_test_set(data_path): 18 | return ImageDataset(image_dir=data_path, 19 | transform=transforms.Compose([ 20 | transforms.Resize((32, 32)), 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.286,), (0.353,)) 23 | ])) 24 | 25 | 26 | def load_images(data_path, data_type=None): 27 | assert data_type is None or data_type in ['train', 'test'] 28 | if data_type == 'train': 29 | data_set = _get_train_set(data_path) 30 | else: 31 | data_set = _get_test_set(data_path) 32 | 33 | data_loader = DataLoader(dataset=data_set, 34 | batch_size=128, 35 | num_workers=4, 36 | shuffle=True) 37 | 38 | return data_loader 39 | -------------------------------------------------------------------------------- /preprocessing/labelme_to_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import cv2 5 | 6 | from configs import config 7 | from utils import image_util 8 | 9 | 10 | def parse_json(json_path): 11 | data = json.load(open(json_path)) 12 | shapes = data['shapes'][0] 13 | points = shapes['points'] 14 | return points 15 | 16 | 17 | def polygons_to_mask(img_shape, polygons): 18 | """ 19 | 边界点生成mask 20 | :param img_shape: [h,w] 21 | :param polygons: labelme JSON中的边界点格式 [[x1,y1],[x2,y2],[x3,y3],...[xn,yn]] 22 | :return: mask 0-1 23 | """ 24 | mask = np.zeros(img_shape, dtype=np.uint8) 25 | polygons = np.asarray([polygons], np.int32) # 这里必须是int32,其他类型使用fillPoly会报错 26 | cv2.fillPoly(mask, polygons, 1) # 非int32 会报错 27 | return mask 28 | 29 | 30 | def main(): 31 | images_dir = os.path.join(config.output_result, 'vgg16_09012200', 'images') 32 | for root, _, files in os.walk(images_dir): 33 | for file in files: 34 | if os.path.splitext(file)[1] == '.json': 35 | img = cv2.imread(os.path.join(root, file.replace('json', 'png'))) 36 | 37 | json_name = os.path.splitext(file)[0] 38 | json_path = os.path.join(root, file) 39 | print(json_path) 40 | mask = polygons_to_mask([img.shape[0], img.shape[1]], parse_json(json_path)) * 255 41 | mask_path = os.path.join(config.result_masks_stl10, json_name + '.png') 42 | image_util.save_cv(mask, mask_path) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # no LRN 5 | class AlexNet(nn.Module): 6 | def __init__(self, in_channels=3, num_classes=10): 7 | super(AlexNet, self).__init__() 8 | self.features = nn.Sequential( 9 | nn.Conv2d(in_channels, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 10 | nn.ReLU(inplace=True), 11 | nn.MaxPool2d(kernel_size=2), 12 | nn.Conv2d(64, 192, kernel_size=(3, 3), padding=(1, 1)), 13 | nn.ReLU(inplace=True), 14 | nn.MaxPool2d(kernel_size=2), 15 | nn.Conv2d(192, 384, kernel_size=(3, 3), padding=(1, 1)), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 
1)), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1)), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(kernel_size=2), 22 | ) 23 | self.classifier = nn.Sequential( 24 | nn.Dropout(), 25 | nn.Linear(256 * 2 * 2, 4096), 26 | nn.ReLU(inplace=True), 27 | nn.Dropout(), 28 | nn.Linear(4096, 4096), 29 | nn.ReLU(inplace=True), 30 | nn.Linear(4096, num_classes), 31 | ) 32 | 33 | def forward(self, x): 34 | x = self.features(x) 35 | x = x.view(x.size(0), 256 * 2 * 2) 36 | x = self.classifier(x) 37 | return x 38 | 39 | 40 | def alexnet(in_channels=3, num_classes=10): 41 | return AlexNet(in_channels=in_channels, num_classes=num_classes) 42 | -------------------------------------------------------------------------------- /loaders/cifar10_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torchvision import transforms 3 | from loaders.datasets import ImageDataset 4 | 5 | 6 | def _get_train_set(data_path): 7 | return ImageDataset(image_dir=data_path, 8 | transform=transforms.Compose([ 9 | transforms.Resize((32, 32)), 10 | transforms.RandomCrop(32, padding=4), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.ToTensor(), 13 | transforms.Normalize((0.4914, 0.4822, 0.4465), 14 | (0.2023, 0.1994, 0.2010)) 15 | ])) 16 | 17 | 18 | def _get_test_set(data_path): 19 | return ImageDataset(image_dir=data_path, 20 | transform=transforms.Compose([ 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.4914, 0.4822, 0.4465), 23 | (0.2023, 0.1994, 0.2010)) 24 | ])) 25 | 26 | 27 | def load_images(data_path, data_type=None): 28 | assert data_type is None or data_type in ['train', 'test'] 29 | if data_type == 'train': 30 | data_set = _get_train_set(data_path) 31 | else: 32 | data_set = _get_test_set(data_path) 33 | 34 | data_loader = DataLoader(dataset=data_set, 35 | batch_size=128, 36 | num_workers=4, 37 | shuffle=True) 38 | 39 | return data_loader 40 | -------------------------------------------------------------------------------- /loaders/cifar100_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torchvision import transforms 3 | from loaders.datasets import ImageDataset 4 | 5 | 6 | def _get_train_set(data_path): 7 | return ImageDataset(image_dir=data_path, 8 | transform=transforms.Compose([ 9 | transforms.RandomCrop(32, padding=4), 10 | transforms.RandomHorizontalFlip(), 11 | transforms.RandomRotation(15), 12 | transforms.ToTensor(), 13 | transforms.Normalize((0.5070751592371323, 0.48654887331495095, 0.4409178433670343), 14 | (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)) 15 | ])) 16 | 17 | 18 | def _get_test_set(data_path): 19 | return ImageDataset(image_dir=data_path, 20 | transform=transforms.Compose([ 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.5070751592371323, 0.48654887331495095, 0.4409178433670343), 23 | (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)) 24 | ])) 25 | 26 | 27 | def load_images(data_path, data_type=None): 28 | assert data_type is None or data_type in ['train', 'test'] 29 | if data_type == 'train': 30 | data_set = _get_train_set(data_path) 31 | else: 32 | data_set = _get_test_set(data_path) 33 | 34 | data_loader = DataLoader(dataset=data_set, 35 | batch_size=128, 36 | num_workers=4, 37 | shuffle=True) 38 | 39 | return data_loader 40 | -------------------------------------------------------------------------------- 
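A minimal usage sketch for the loaders above (hedged: the dataset path is a placeholder, and the directory is assumed to hold one sub-directory per class, as the preprocessing scripts produce):

```python
# Sketch only: '/path/to/cifar10/train' is a placeholder path.
from loaders.cifar10_loader import load_images

train_loader = load_images('/path/to/cifar10/train', data_type='train')
for images, labels, names in train_loader:  # ImageDataset yields (image, label, filename)
    print(images.shape, labels.shape, names[0])  # e.g. torch.Size([128, 3, 32, 32])
    break
```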
/loaders/datasets/image_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL.Image as Image 3 | from torch.utils.data import Dataset 4 | 5 | 6 | def _img_loader(path, mode='RGB'): 7 | assert mode in ['RGB', 'L'] 8 | with open(path, 'rb') as f: 9 | img = Image.open(f) 10 | return img.convert(mode) 11 | 12 | 13 | def _find_classes(root): 14 | class_names = [d.name for d in os.scandir(root) if d.is_dir()] 15 | class_names.sort() 16 | classes_indices = {class_names[i]: i for i in range(len(class_names))} 17 | # print(classes_indices) 18 | return class_names, classes_indices # 'class_name':index 19 | 20 | 21 | def _make_dataset(image_dir): 22 | samples = [] # image_path, class_idx 23 | 24 | class_names, class_indices = _find_classes(image_dir) 25 | 26 | for class_name in sorted(class_names): 27 | class_idx = class_indices[class_name] 28 | target_dir = os.path.join(image_dir, class_name) 29 | 30 | if not os.path.isdir(target_dir): 31 | continue 32 | 33 | for root, _, files in sorted(os.walk(target_dir)): 34 | for file in sorted(files): 35 | image_path = os.path.join(root, file) 36 | item = image_path, class_idx 37 | samples.append(item) 38 | return samples 39 | 40 | 41 | class ImageDataset(Dataset): 42 | def __init__(self, image_dir, transform=None): 43 | self.image_dir = image_dir 44 | self.transform = transform 45 | self.samples = _make_dataset(self.image_dir) 46 | self.targets = [s[1] for s in self.samples] 47 | 48 | def __getitem__(self, index): 49 | image_path, target = self.samples[index] 50 | image = _img_loader(image_path, mode='RGB') 51 | name = os.path.split(image_path)[1] 52 | 53 | if self.transform is not None: 54 | image = self.transform(image) 55 | 56 | return image, target, name 57 | 58 | def __len__(self): 59 | return len(self.samples) 60 | -------------------------------------------------------------------------------- /preprocessing/mnist/image_gen.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('/workspace/classification/code/') # zjl 4 | 5 | import numpy as np 6 | import cv2 7 | import os 8 | 9 | from configs import config 10 | from utils import image_util 11 | 12 | 13 | def save_mnist_to_jpg(mnist_image_file, mnist_label_file, save_dir): 14 | if 'train' in os.path.basename(mnist_image_file): 15 | num_file = 60000 16 | prefix = 'train' 17 | else: 18 | num_file = 10000 19 | prefix = 'test' 20 | with open(mnist_image_file, 'rb') as f1: 21 | image_file = f1.read() 22 | with open(mnist_label_file, 'rb') as f2: 23 | label_file = f2.read() 24 | image_file = image_file[16:] 25 | label_file = label_file[8:] 26 | for i in range(num_file): 27 | label = int(label_file[i]) 28 | image_list = [int(item) for item in image_file[i * 784:i * 784 + 784]] 29 | image_np = np.array(image_list, dtype=np.uint8).reshape(28, 28) 30 | save_name = os.path.join(save_dir, str(label), '{}_{}.jpg'.format(prefix, i)) 31 | # gray = cv2.cvtColor(image_np, cv2.COLOR_BGR2GRAY) 32 | image_util.save_cv(image_np, save_name) 33 | print('{} ==> {}_{}_{}.jpg'.format(i, prefix, i, label)) 34 | 35 | 36 | if __name__ == '__main__': 37 | datasets_path = config.datasets_FASHION_MNIST 38 | train_image_file = datasets_path + '/train-images-idx3-ubyte' 39 | train_label_file = datasets_path + '/train-labels-idx1-ubyte' 40 | test_image_file = datasets_path + '/t10k-images-idx3-ubyte' 41 | test_label_file = datasets_path + '/t10k-labels-idx1-ubyte' 42 | 43 | save_train_dir = 
config.data_fashion_mnist + '/train'
44 |     save_test_dir = config.data_fashion_mnist + '/test'
45 | 
46 |     if not os.path.exists(save_train_dir):
47 |         os.makedirs(save_train_dir)
48 |     if not os.path.exists(save_test_dir):
49 |         os.makedirs(save_test_dir)
50 | 
51 |     save_mnist_to_jpg(train_image_file, train_label_file, save_train_dir)
52 |     save_mnist_to_jpg(test_image_file, test_label_file, save_test_dir)
53 | 
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | cfg = {
4 |     'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
5 |     'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
6 |     'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
7 |     'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
8 | }
9 | 
10 | 
11 | class VGG(nn.Module):
12 | 
13 |     def __init__(self, features, num_classes=10):
14 |         super().__init__()
15 |         self.features = features
16 | 
17 |         self.classifier = nn.Sequential(
18 |             nn.Linear(512, 4096),
19 |             nn.ReLU(inplace=True),
20 |             nn.Dropout(),
21 |             nn.Linear(4096, 4096),
22 |             nn.ReLU(inplace=True),
23 |             nn.Dropout(),
24 |             nn.Linear(4096, num_classes)
25 |         )
26 | 
27 |     def forward(self, x):
28 |         output = self.features(x)
29 |         output = output.view(output.size()[0], -1)
30 |         output = self.classifier(output)
31 | 
32 |         return output
33 | 
34 | 
35 | def make_layers(cfg, in_channels=3, batch_norm=False):
36 |     layers = []
37 | 
38 |     input_channel = in_channels
39 |     for l in cfg:
40 |         if l == 'M':
41 |             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
42 |             continue
43 | 
44 |         layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]
45 | 
46 |         if batch_norm:
47 |             layers += [nn.BatchNorm2d(l)]
48 | 
49 |         layers += [nn.ReLU(inplace=True)]
50 |         input_channel = l
51 | 
52 |     return nn.Sequential(*layers)
53 | 
54 | 
55 | def vgg11_bn(num_classes=100):
56 |     return VGG(make_layers(cfg['A'], batch_norm=True), num_classes=num_classes)  # forward num_classes to the classifier head
57 | 
58 | 
59 | def vgg13_bn():
60 |     return VGG(make_layers(cfg['B'], batch_norm=True))
61 | 
62 | 
63 | def vgg16_bn(in_channels=3, num_classes=10):
64 |     return VGG(make_layers(cfg['D'], batch_norm=True, in_channels=in_channels), num_classes=num_classes)
65 | 
66 | 
67 | def vgg19_bn():
68 |     return VGG(make_layers(cfg['E'], batch_norm=True))
69 | 
--------------------------------------------------------------------------------
/preprocessing/mnin/mask_gen.py:
--------------------------------------------------------------------------------
1 | import sys
2 | 
3 | sys.path.append('/disk2/hjc/classification/')  # 210
4 | import os
5 | import numpy as np
6 | import cv2
7 | 
8 | from configs import config
9 | from utils import image_util
10 | 
11 | 
12 | def iter_img_to_gray():
13 |     for root, _, files in os.walk(config.result_images):
14 |         for file in files:
15 |             path = os.path.join(root, file)
16 |             img = cv2.imread(path, cv2.IMREAD_COLOR)
17 |             img = rgb_to_gray(img)
18 |             filename = path.replace(config.result_images, config.result_images + '_gray')
19 |             image_util.save_cv(img, filename)
20 | 
21 | 
22 | def iter_img_to_mask():
23 |     for root, _, files in os.walk(config.result_images + '_gray'):
24 |         for file in files:
25 |             path = os.path.join(root, file)
26 |             img = cv2.imread(path, cv2.IMREAD_COLOR)
27 |             img = gen_mask(img)
28 |             filename = path.replace(config.result_images + '_gray', config.result_masks)
29 | 
image_util.save_cv(img, filename) 30 | 31 | 32 | def rgb_to_gray(img): 33 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 34 | return gray 35 | 36 | 37 | def gen_mask(img): 38 | img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 39 | 40 | lower_red = np.array([156, 43, 46]) 41 | upper_red = np.array([180, 255, 255]) 42 | mask_red = cv2.inRange(img_hsv, lower_red, upper_red) 43 | 44 | lower_blue = np.array([78, 43, 46]) 45 | upper_blue = np.array([99, 255, 255]) 46 | mask_blue = cv2.inRange(img_hsv, lower_blue, upper_blue) 47 | 48 | lower_yellow = np.array([26, 43, 46]) 49 | upper_yellow = np.array([34, 255, 255]) 50 | mask_yellow = cv2.inRange(img_hsv, lower_yellow, upper_yellow) 51 | 52 | outline = mask_red | mask_blue | mask_yellow 53 | 54 | fill = outline.copy() 55 | h, w = img.shape[:2] 56 | zeros = np.zeros((h + 2, w + 2), np.uint8) 57 | cv2.floodFill(fill, zeros, (0, 0), 255) # flood fill 58 | 59 | fill = cv2.bitwise_not(fill) 60 | mask = outline | fill 61 | return mask 62 | 63 | 64 | if __name__ == '__main__': 65 | iter_img_to_gray() 66 | # iter_img_to_mask() 67 | -------------------------------------------------------------------------------- /preprocessing/mnin/image_gen.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('/workspace/classification/code/') # zjl 4 | import csv 5 | import os 6 | from PIL import Image 7 | 8 | from configs import config 9 | 10 | data_dir = '/workspace/classification/datasets/mini-imagenet/' 11 | train_csv_path = data_dir + '/train.csv' 12 | val_csv_path = data_dir + '/val.csv' 13 | test_csv_path = data_dir + '/test.csv' 14 | images_path = data_dir + '/images' 15 | 16 | output_images = config.data_mini_imagenet 17 | 18 | train_label = {} 19 | val_label = {} 20 | test_label = {} 21 | with open(train_csv_path) as csv_file: 22 | csv_reader = csv.reader(csv_file) 23 | birth_header = next(csv_reader) 24 | for row in csv_reader: 25 | train_label[row[0]] = row[1] 26 | 27 | with open(val_csv_path) as csv_file: 28 | csv_reader = csv.reader(csv_file) 29 | birth_header = next(csv_reader) 30 | for row in csv_reader: 31 | val_label[row[0]] = row[1] 32 | 33 | with open(test_csv_path) as csv_file: 34 | csv_reader = csv.reader(csv_file) 35 | birth_header = next(csv_reader) 36 | for row in csv_reader: 37 | test_label[row[0]] = row[1] 38 | 39 | for filename in os.listdir(images_path): 40 | if not filename.endswith('jpg'): 41 | continue 42 | 43 | path = images_path + '/' + filename 44 | print(path) 45 | im = Image.open(path) 46 | 47 | if filename in train_label.keys(): 48 | tmp = train_label[filename] 49 | temp_path = output_images + '/train' + '/' + tmp 50 | if not os.path.exists(temp_path): 51 | os.makedirs(temp_path) 52 | t = temp_path + '/' + filename 53 | im.save(t) 54 | 55 | elif filename in val_label.keys(): 56 | tmp = val_label[filename] 57 | temp_path = output_images + '/val' + '/' + tmp 58 | if not os.path.exists(temp_path): 59 | os.makedirs(temp_path) 60 | t = temp_path + '/' + filename 61 | im.save(t) 62 | 63 | elif filename in test_label.keys(): 64 | tmp = test_label[filename] 65 | temp_path = output_images + '/test' + '/' + tmp 66 | if not os.path.exists(temp_path): 67 | os.makedirs(temp_path) 68 | t = temp_path + '/' + filename 69 | im.save(t) 70 | -------------------------------------------------------------------------------- /loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from loaders.cifar10_loader import load_images as 
load_cifar10 2 | from loaders.cifar100_loader import load_images as load_cifar100 3 | from loaders.mnist_loader import load_images as load_mnist 4 | from loaders.fashion_mnist_loader import load_images as load_fashion_mnist 5 | from loaders.svhn_loader import load_images as load_svhn 6 | from loaders.stl10_loader import load_images as load_stl10 7 | from loaders.stl10_loader import load_images_masks as load_stl10_masks 8 | from loaders.mnin_loader import load_images as load_mnin 9 | from loaders.mnin_loader import load_images_masks as load_mnin_masks 10 | 11 | 12 | def load_data(data_name, data_path, data_type=None): 13 | print('-' * 50) 14 | print('DATA NAME:', data_name) 15 | print('DATA PATH:', data_path) 16 | print('DATA TYPE:', data_type) 17 | print('-' * 50) 18 | 19 | assert data_name in ['cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'svhn', 'stl10', 'mnin'] 20 | 21 | data_loader = None 22 | if data_name == 'cifar10': 23 | data_loader = load_cifar10(data_path, data_type) 24 | elif data_name == 'cifar100': 25 | data_loader = load_cifar100(data_path, data_type) 26 | elif data_name == 'mnist': 27 | data_loader = load_mnist(data_path, data_type) 28 | elif data_name == 'fashion_mnist': 29 | data_loader = load_fashion_mnist(data_path, data_type) 30 | elif data_name == 'svhn': 31 | data_loader = load_svhn(data_path, data_type) 32 | elif data_name == 'stl10': 33 | data_loader = load_stl10(data_path, data_type) 34 | elif data_name == 'mnin': 35 | data_loader = load_mnin(data_path, data_type) 36 | return data_loader 37 | 38 | 39 | def load_data_mask(data_name, data_path, mask_path, data_type=None): 40 | print('-' * 50) 41 | print('DATA NAME:', data_name) 42 | print('DATA PATH:', data_path) 43 | print('DATA TYPE:', data_type) 44 | print('-' * 50) 45 | 46 | assert data_name in ['stl10', 'mnin'] 47 | 48 | data_loader = None 49 | if data_name == 'stl10': 50 | data_loader = load_stl10_masks(data_path, mask_path, data_type) 51 | elif data_name == 'mnin': 52 | data_loader = load_mnin_masks(data_path, mask_path, data_type) 53 | return data_loader 54 | -------------------------------------------------------------------------------- /preprocessing/cifar/cifar10_gen.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | # sys.path.append('/disk1/hjc/classification/') # 205 4 | sys.path.append('/workspace/classification/code/') # zjlab 5 | import numpy as np 6 | from PIL import Image 7 | import pickle 8 | import os 9 | from configs import config 10 | 11 | 12 | def unpickle(paths, types): 13 | CHANNEL = 3 14 | WIDTH = 32 15 | HEIGHT = 32 16 | 17 | data = [] 18 | labels = [] 19 | filenames = [] 20 | classification = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 21 | 22 | check_path(classification, types) 23 | 24 | for path in paths: 25 | path = '{}/{}'.format(config.datasets_CIFAR_10, path) 26 | with open(path, mode='rb') as file: 27 | # 数据集在当脚本前文件夹下 28 | data_dict = pickle.load(file, encoding='bytes') 29 | data += list(data_dict[b'data']) 30 | labels += list(data_dict[b'labels']) 31 | filenames += list(data_dict[b'filenames']) 32 | 33 | img = np.reshape(data, [-1, CHANNEL, WIDTH, HEIGHT]) 34 | 35 | for i in range(img.shape[0]): 36 | r = img[i][0] 37 | g = img[i][1] 38 | b = img[i][2] 39 | 40 | ir = Image.fromarray(r) 41 | ig = Image.fromarray(g) 42 | ib = Image.fromarray(b) 43 | rgb = Image.merge("RGB", (ir, ig, ib)) 44 | 45 | filename = '{}/{}/{}/{}.png'.format(config.data_cifar10, types, 
classification[labels[i]], i) 46 | rgb.save(filename, "PNG") 47 | 48 | 49 | def check_path(classification, types): 50 | for cls in classification: 51 | data_path = '{}/{}/{}'.format(config.data_cifar10, types, cls) 52 | if not os.path.exists(data_path): 53 | os.makedirs(data_path) 54 | 55 | 56 | def main(): 57 | train_paths = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5'] 58 | test_paths = ['test_batch'] 59 | unpickle(train_paths, types='train') 60 | unpickle(test_paths, types='test') 61 | 62 | 63 | def _test(): 64 | import pickle 65 | with open("{}/data_batch_1".format(config.datasets_CIFAR_10), 'rb') as fo: 66 | dict = pickle.load(fo, encoding='bytes') 67 | print(dict) 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | # _test() 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Model Doctor: A Simple Gradient Aggregation Strategy for Diagnosing and Treating CNN Classifiers
2 | 
3 | ![](framework.png)
4 | 
5 | This is an official PyTorch implementation of the [Model Doctor](https://arxiv.org/pdf/2112.04934.pdf):
6 | ```
7 | @article{feng2021model,
8 |   title={Model Doctor: A Simple Gradient Aggregation Strategy for Diagnosing and Treating CNN Classifiers},
9 |   author={Feng, Zunlei and Hu, Jiacong and Wu, Sai and Yu, Xiaotian and Song, Jie and Song, Mingli},
10 |   journal={AAAI},
11 |   year={2021}
12 | }
13 | ```
14 | ### Environment
15 | + python 3.8
16 | 
17 | ### Repository structure
18 | ```
19 | .
20 | ├── core                  // the core code of the paper
21 | ├── loaders               // the data loaders for CIFAR-10, CIFAR-100, etc.
22 | ├── metrics               // evaluation metrics for the model, such as accuracy
23 | ├── models                // various CNN models, including AlexNet, VGGNet, etc.
24 | ├── preprocessing         // preprocessing for the different datasets and labelme annotations
25 | ├── scripts               // the shell scripts that run the code
26 | ├── utils                 // various tools, including file and training utilities
27 | ├── test.py               // test a classification model
28 | ├── train.py              // train a plain classification model (without the model doctor)
29 | └── train_model_doctor.py // train a classification model with the model doctor
30 | ```
31 | 
32 | ### Command
33 | #### 1. Train the original model
34 | ```shell
35 | bash scripts/train.sh
36 | ```
37 | #### 2. Prepare for channel constraints
38 | 1. Sift high-confidence images
39 | ```shell
40 | bash scripts/image_sift.sh
41 | ```
42 | 2. Calculate the gradient for each class
43 | ```shell
44 | bash scripts/grad_calculate.sh
45 | ```
46 | 
47 | #### 3. Prepare for spatial constraints
48 | 1. Sift low-confidence images
49 | ```shell
50 | bash scripts/image_sift.sh
51 | ```
52 | 2. Label the foreground area with [labelme](https://github.com/wkentaro/labelme)
53 | 3. Convert labelme annotation files into masks
54 | ```shell
55 | python preprocessing/labelme_to_mask.py
56 | ```
57 | 
58 | #### 4. Train the model with the model doctor
59 | ```shell
60 | bash scripts/train_model_doctor.sh
61 | ```
62 | 
63 | ### Tips
64 | As described in the paper, small-size datasets (such as MNIST, Fashion-MNIST, CIFAR-10, and CIFAR-100) are not suitable for the spatial constraint, so if you use a small-size dataset or want the channel constraint only, set `beta = 0` in `scripts/train_model_doctor.sh`. 
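For reference, the evaluation that `test.py` drives can be sketched with the repo's own entry points (a sketch only: the model and data paths are placeholders, and loading the checkpoint as a `state_dict` is an assumption):

```python
# Hedged sketch: paths are placeholders; adjust the loading step if train.py
# saves the full model object rather than a state_dict.
import torch
from models import load_model
from loaders import load_data
from metrics import accuracy, ClassAccuracy

model = load_model('vgg16', in_channels=3, num_classes=10)
model.load_state_dict(torch.load('/path/to/model_ori.pth', map_location='cpu'))
model.eval()

test_loader = load_data('cifar10', '/path/to/cifar10/test', data_type='test')
class_acc = ClassAccuracy(num_classes=10)
with torch.no_grad():
    for images, labels, _ in test_loader:  # loaders yield (image, label, filename)
        outputs = model(images)
        top1, = accuracy(outputs, labels, topk=(1,))  # per-batch top-1 (%)
        class_acc.accuracy(outputs, labels)
print(class_acc)  # per-class accuracy, one line per class
```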
65 | 66 | 67 | -------------------------------------------------------------------------------- /preprocessing/mnin/image_split.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('/workspace/classification/code/') # zjl 4 | # sys.path.append('/nfs3-p1/hjc/classification/code/') # vipa 5 | 6 | import os 7 | import random 8 | 9 | from configs import config 10 | from utils import file_util 11 | 12 | 13 | def split(): 14 | SplitData(input_path=config.data_mini_imagenet_temp, 15 | output_path=config.data_mini_imagenet) 16 | 17 | 18 | class SplitData: 19 | def __init__(self, input_path, output_path, percentage=0.2): 20 | self.input_path = input_path 21 | self.output_path = output_path 22 | self.percentage = percentage # this ratio represents the proportion of the test set 23 | self._split_data(self._data_idx()) 24 | 25 | def _data_idx(self): 26 | data_idx = {} 27 | 28 | for root, _, files in sorted(os.walk(self.input_path)): 29 | if len(files) != 0: 30 | class_name = os.path.split(root)[1] 31 | data_idx[class_name] = self._random_idx(start=0, end=len(files)) 32 | # if len(data_random_idx) == 10: # select 10 classes 33 | # break 34 | return data_idx 35 | 36 | def _split_data(self, data_random_idx): 37 | for root, _, files in sorted(os.walk(self.input_path)): 38 | data_idx = 0 39 | for file in sorted(files): 40 | class_name = os.path.split(root)[1] 41 | if class_name in data_random_idx.keys(): 42 | if data_idx in data_random_idx[class_name]: 43 | data_type = 'test' 44 | else: 45 | data_type = 'train' 46 | 47 | original_path = os.path.join(root, file) 48 | output_path = os.path.join(self.output_path, data_type, class_name, file) 49 | file_util.copy_file(original_path, output_path) 50 | 51 | data_idx += 1 52 | print(root, data_idx, data_type) 53 | 54 | def _random_idx(self, start, end): 55 | random.seed(0) 56 | start, end = (int(start), int(end)) if start <= end else (int(end), int(start)) 57 | length = int((end - start) * self.percentage) 58 | ran_list = [] 59 | while len(ran_list) < length: 60 | x = random.randint(start, end) 61 | if x not in ran_list: 62 | ran_list.append(x) 63 | return ran_list 64 | 65 | 66 | if __name__ == '__main__': 67 | split() 68 | -------------------------------------------------------------------------------- /loaders/datasets/image_mask_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL.Image as Image 3 | from torch.utils.data import Dataset 4 | 5 | 6 | def _img_loader(path, mode='RGB', default_mask_path=None): 7 | assert mode in ['RGB', 'L'] 8 | # -------------------------------------- 9 | if not os.path.exists(path): 10 | path = default_mask_path 11 | # -------------------------------------- 12 | with open(path, 'rb') as f: 13 | img = Image.open(f) 14 | return img.convert(mode) 15 | 16 | 17 | def _find_classes(root): 18 | class_names = [d.name for d in os.scandir(root) if d.is_dir()] 19 | class_names.sort() 20 | classes_indices = {class_names[i]: i for i in range(len(class_names))} 21 | # print(classes_indices) 22 | return class_names, classes_indices # 'class_name':index 23 | 24 | 25 | def _make_dataset(image_dir, mask_dir): 26 | samples = [] # image_path, mask_path, class_idx 27 | 28 | class_names, class_indices = _find_classes(image_dir) 29 | 30 | for class_name in sorted(class_names): 31 | class_idx = class_indices[class_name] 32 | target_dir = os.path.join(image_dir, class_name) 33 | 34 | if not os.path.isdir(target_dir): 
35 | continue 36 | 37 | for root, _, files in sorted(os.walk(target_dir)): 38 | for file in sorted(files): 39 | image_path = os.path.join(root, file) 40 | mask_path = os.path.join(mask_dir, file.replace('jpg', 'png')) 41 | item = image_path, mask_path, class_idx 42 | samples.append(item) 43 | return samples 44 | 45 | 46 | class ImageMaskDataset(Dataset): 47 | def __init__(self, image_dir, mask_dir, default_mask_path, transform=None): 48 | self.image_dir = image_dir 49 | self.mask_dir = mask_dir 50 | self.default_mask_path = default_mask_path 51 | self.transform = transform 52 | self.samples = _make_dataset(self.image_dir, self.mask_dir) 53 | self.targets = [s[2] for s in self.samples] 54 | 55 | def __getitem__(self, index): 56 | image_path, mask_path, target = self.samples[index] 57 | image = _img_loader(image_path, mode='RGB') 58 | mask = _img_loader(mask_path, mode='L', default_mask_path=self.default_mask_path) 59 | 60 | images = [image, mask] 61 | if self.transform is not None: 62 | images = self.transform(images) 63 | 64 | return images[0], target, images[1] 65 | 66 | def __len__(self): 67 | return len(self.samples) 68 | -------------------------------------------------------------------------------- /core/grad_constraint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision import transforms 4 | from core.grad_calculate import HookModule 5 | 6 | 7 | class GradConstraint: 8 | 9 | def __init__(self, module, grad_path, alpha, beta): 10 | self.module = HookModule(module) 11 | self.channels = torch.from_numpy(np.load(grad_path)).cuda() 12 | self.alpha = alpha 13 | self.beta = beta 14 | 15 | def loss_channel(self, outputs, labels): 16 | loss = 0 17 | 18 | # high response channel loss 19 | probs = torch.argsort(-outputs, dim=1) 20 | labels_ = [] 21 | for i in range(len(labels)): 22 | if probs[i][0] == labels[i]: 23 | labels_.append(probs[i][1]) # TP rank2 24 | else: 25 | labels_.append(probs[i][0]) # FP rank1 26 | labels_ = torch.tensor(labels_).cuda() 27 | nll_loss_ = torch.nn.NLLLoss()(outputs, labels_) 28 | loss += _loss_channel(channels=self.channels, 29 | grads=self.module.grads(outputs=-nll_loss_), 30 | labels=labels_, 31 | is_high=True) 32 | 33 | # low response channel loss 34 | nll_loss = torch.nn.NLLLoss()(outputs, labels) 35 | loss += _loss_channel(channels=self.channels, 36 | grads=self.module.grads(outputs=-nll_loss), 37 | labels=labels, 38 | is_high=False) 39 | return loss * self.alpha 40 | 41 | def loss_spatial(self, outputs, labels, masks): 42 | if isinstance(masks, torch.Tensor): # masks is masks 43 | nll_loss = torch.nn.NLLLoss()(outputs, labels) 44 | grads = self.module.grads(outputs=-nll_loss) 45 | masks = transforms.Resize((grads.shape[2], grads.shape[3]))(masks) 46 | masks_bg = 1 - masks 47 | grads_bg = torch.abs(masks_bg * grads) 48 | 49 | loss = grads_bg.sum() 50 | return loss * self.beta 51 | else: 52 | return torch.tensor(0) 53 | 54 | 55 | def _loss_channel(channels, grads, labels, is_high=True): 56 | grads = torch.relu(grads) 57 | channel_grads = torch.sum(grads, dim=(2, 3)) # [batch_size, channels] 58 | 59 | loss = 0 60 | if is_high: 61 | for b, l in enumerate(labels): 62 | loss += (channel_grads[b] * channels[l]).sum() 63 | else: 64 | for b, l in enumerate(labels): 65 | loss += (channel_grads[b] * (1 - channels[l])).sum() 66 | loss = loss / len(labels) 67 | return loss 68 | -------------------------------------------------------------------------------- 
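A hedged sketch of how `GradConstraint` plugs into a training step (see `train_model_doctor.py` for the authoritative wiring; the constrained layer, the gradient `.npy` path, and the optimizer settings below are placeholders, and a CUDA device is assumed because `GradConstraint` moves the channel mask to the GPU):

```python
# Sketch under assumptions: the layer choice, the .npy path, and alpha/beta
# are placeholders, not values prescribed by the repository.
import torch
from core.grad_constraint import GradConstraint
from loaders import load_data
from models import load_model

model = load_model('vgg16', in_channels=3, num_classes=10).cuda()
# Placeholder layer choice: simply take the last Conv2d in the network.
last_conv = [m for m in model.modules() if isinstance(m, torch.nn.Conv2d)][-1]
constraint = GradConstraint(module=last_conv,
                            grad_path='/path/to/grads_50/channel_grads.npy',
                            alpha=10.0, beta=0.0)  # beta=0: channel constraint only
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

train_loader = load_data('cifar10', '/path/to/cifar10/train', data_type='train')
for images, labels, _ in train_loader:
    images, labels = images.cuda(), labels.cuda()
    outputs = model(images)
    loss = criterion(outputs, labels) + constraint.loss_channel(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

When `beta > 0`, the batch must also carry foreground masks (see `load_data_mask` in `loaders/__init__.py`), and `constraint.loss_spatial(outputs, labels, masks)` is added to the loss as well.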
/loaders/mnin_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torchvision import transforms 3 | from loaders.datasets import ImageDataset 4 | from loaders.datasets import ImageMaskDataset 5 | 6 | 7 | def _get_train_set(data_path): 8 | return ImageDataset(image_dir=data_path, 9 | transform=transforms.Compose([ 10 | transforms.Resize((224, 224)), 11 | transforms.RandomHorizontalFlip(0.5), 12 | transforms.RandomVerticalFlip(0.5), 13 | transforms.ToTensor(), 14 | transforms.Normalize([0.485, 0.456, 0.406], 15 | [0.229, 0.224, 0.225]) 16 | ])) 17 | 18 | 19 | def _get_test_set(data_path): 20 | return ImageDataset(image_dir=data_path, 21 | transform=transforms.Compose([ 22 | transforms.Resize((224, 224)), 23 | transforms.ToTensor(), 24 | transforms.Normalize([0.485, 0.456, 0.406], 25 | [0.229, 0.224, 0.225]) 26 | ])) 27 | 28 | 29 | def _get_train_set_mask(data_path, mask_path): 30 | return ImageMaskDataset(image_dir=data_path, 31 | mask_dir=mask_path, 32 | default_mask_path='xxx', 33 | transform=transforms.Compose([ 34 | transforms.Resize((224, 224)), 35 | transforms.RandomHorizontalFlip(0.5), 36 | transforms.RandomVerticalFlip(0.5), 37 | transforms.ToTensor(), 38 | transforms.Normalize([0.485, 0.456, 0.406], 39 | [0.229, 0.224, 0.225]) 40 | ])) 41 | 42 | 43 | def load_images(data_path, data_type=None): 44 | assert data_type is None or data_type in ['train', 'test'] 45 | if data_type == 'train': 46 | data_set = _get_train_set(data_path) 47 | else: 48 | data_set = _get_test_set(data_path) 49 | 50 | data_loader = DataLoader(dataset=data_set, 51 | batch_size=32, 52 | num_workers=4, 53 | shuffle=True) 54 | 55 | return data_loader 56 | 57 | 58 | def load_images_masks(data_path, mask_path, data_type=None): 59 | assert data_type is None or data_type in ['train'] 60 | 61 | data_set = _get_train_set_mask(data_path, mask_path) 62 | 63 | data_loader = DataLoader(dataset=data_set, 64 | batch_size=32, 65 | num_workers=4, 66 | shuffle=True) 67 | 68 | return data_loader 69 | -------------------------------------------------------------------------------- /loaders/stl10_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torchvision import transforms 3 | from loaders.datasets import ImageDataset 4 | from loaders.datasets import ImageMaskDataset 5 | 6 | 7 | def _get_train_set(data_path): 8 | return ImageDataset(image_dir=data_path, 9 | transform=transforms.Compose([ 10 | transforms.Resize((224, 224)), 11 | transforms.RandomHorizontalFlip(0.5), 12 | transforms.RandomVerticalFlip(0.5), 13 | transforms.ToTensor(), 14 | transforms.Normalize([0.485, 0.456, 0.406], 15 | [0.229, 0.224, 0.225]) 16 | ])) 17 | 18 | 19 | def _get_test_set(data_path): 20 | return ImageDataset(image_dir=data_path, 21 | transform=transforms.Compose([ 22 | transforms.Resize((224, 224)), 23 | transforms.ToTensor(), 24 | transforms.Normalize([0.485, 0.456, 0.406], 25 | [0.229, 0.224, 0.225]) 26 | ])) 27 | 28 | 29 | def _get_train_set_mask(data_path, mask_path): 30 | return ImageMaskDataset(image_dir=data_path, 31 | mask_dir=mask_path, 32 | default_mask_path='xxx', 33 | transform=transforms.Compose([ 34 | transforms.Resize((224, 224)), 35 | transforms.RandomHorizontalFlip(0.5), 36 | transforms.RandomVerticalFlip(0.5), 37 | transforms.ToTensor(), 38 | transforms.Normalize([0.485, 0.456, 0.406], 39 | [0.229, 0.224, 0.225]) 40 | ])) 41 | 42 | 43 | def 
load_images(data_path, data_type=None):
44 |     assert data_type is None or data_type in ['train', 'test']
45 |     if data_type == 'train':
46 |         data_set = _get_train_set(data_path)
47 |     else:
48 |         data_set = _get_test_set(data_path)
49 | 
50 |     data_loader = DataLoader(dataset=data_set,
51 |                              batch_size=32,
52 |                              num_workers=4,
53 |                              shuffle=True)
54 | 
55 |     return data_loader
56 | 
57 | 
58 | def load_images_masks(data_path, mask_path, data_type=None):
59 |     assert data_type is None or data_type in ['train']
60 | 
61 |     data_set = _get_train_set_mask(data_path, mask_path)
62 | 
63 |     data_loader = DataLoader(dataset=data_set,
64 |                              batch_size=32,
65 |                              num_workers=4,
66 |                              shuffle=True)
67 | 
68 |     return data_loader
69 | 
--------------------------------------------------------------------------------
/models/squeezenet.py:
--------------------------------------------------------------------------------
1 | """squeezenet in pytorch
2 | 
3 | 
4 | 
5 | [1] Forrest N. Iandola, Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J. Dally, Kurt Keutzer
6 | 
7 |     SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size
8 |     https://arxiv.org/abs/1602.07360
9 | """
10 | 
11 | import torch
12 | import torch.nn as nn
13 | 
14 | 
15 | class Fire(nn.Module):
16 | 
17 |     def __init__(self, in_channel, out_channel, squzee_channel):
18 |         super().__init__()
19 |         self.squeeze = nn.Sequential(
20 |             nn.Conv2d(in_channel, squzee_channel, 1),
21 |             nn.BatchNorm2d(squzee_channel),
22 |             nn.ReLU(inplace=True)
23 |         )
24 | 
25 |         self.expand_1x1 = nn.Sequential(
26 |             nn.Conv2d(squzee_channel, int(out_channel / 2), 1),
27 |             nn.BatchNorm2d(int(out_channel / 2)),
28 |             nn.ReLU(inplace=True)
29 |         )
30 | 
31 |         self.expand_3x3 = nn.Sequential(
32 |             nn.Conv2d(squzee_channel, int(out_channel / 2), 3, padding=1),
33 |             nn.BatchNorm2d(int(out_channel / 2)),
34 |             nn.ReLU(inplace=True)
35 |         )
36 | 
37 |     def forward(self, x):
38 |         x = self.squeeze(x)
39 |         x = torch.cat([
40 |             self.expand_1x1(x),
41 |             self.expand_3x3(x)
42 |         ], 1)
43 | 
44 |         return x
45 | 
46 | 
47 | class SqueezeNet(nn.Module):
48 |     """squeezenet with simple bypass"""
49 | 
50 |     def __init__(self, in_channels=3, num_classes=10):
51 |         super().__init__()
52 |         self.stem = nn.Sequential(
53 |             nn.Conv2d(in_channels, 96, 3, padding=1),
54 |             nn.BatchNorm2d(96),
55 |             nn.ReLU(inplace=True),
56 |             nn.MaxPool2d(2, 2)
57 |         )
58 | 
59 |         self.fire2 = Fire(96, 128, 16)
60 |         self.fire3 = Fire(128, 128, 16)
61 |         self.fire4 = Fire(128, 256, 32)
62 |         self.fire5 = Fire(256, 256, 32)
63 |         self.fire6 = Fire(256, 384, 48)
64 |         self.fire7 = Fire(384, 384, 48)
65 |         self.fire8 = Fire(384, 512, 64)
66 |         self.fire9 = Fire(512, 512, 64)
67 | 
68 |         self.conv10 = nn.Conv2d(512, num_classes, 1)
69 |         self.avg = nn.AdaptiveAvgPool2d(1)
70 |         self.maxpool = nn.MaxPool2d(2, 2)
71 | 
72 |     def forward(self, x):
73 |         x = self.stem(x)
74 | 
75 |         f2 = self.fire2(x)
76 |         f3 = self.fire3(f2) + f2
77 |         f4 = self.fire4(f3)
78 |         f4 = self.maxpool(f4)
79 | 
80 |         f5 = self.fire5(f4) + f4
81 |         f6 = self.fire6(f5)
82 |         f7 = self.fire7(f6) + f6
83 |         f8 = self.fire8(f7)
84 |         f8 = self.maxpool(f8)
85 | 
86 |         f9 = self.fire9(f8)
87 |         c10 = self.conv10(f9)
88 | 
89 |         x = self.avg(c10)
90 |         x = x.view(x.size(0), -1)
91 | 
92 |         return x
93 | 
94 | 
95 | def squeezenet(in_channels=3, num_classes=10):
96 |     return SqueezeNet(in_channels=in_channels, num_classes=num_classes)
97 | 
98 | 
--------------------------------------------------------------------------------
/preprocessing/cifar/cifar100_gen.py:
--------------------------------------------------------------------------------
1 | import sys
2 
3 | sys.path.append('/workspace/classification/code/')  # zjlab
4 | import os
5 | import pickle
6 | import cv2
7 | import numpy as np
8 | 
9 | from configs import config
10 | from utils import image_util
11 | from PIL import Image
12 | 
13 | # source directory
14 | CIFAR100_DIR = config.datasets_CIFAR_100
15 | 
16 | # extract cifar images into here.
17 | CIFAR100_TRAIN_DIR = config.data_cifar100 + '/train'
18 | CIFAR100_TEST_DIR = config.data_cifar100 + '/test'
19 | 
20 | dir_list = [CIFAR100_TRAIN_DIR, CIFAR100_TEST_DIR]
21 | 
22 | 
23 | # extract the binaries; the encoding must be 'bytes'!
24 | def unpickle(file):
25 |     fo = open(file, 'rb')
26 |     data = pickle.load(fo, encoding='bytes')
27 |     fo.close()
28 |     return data
29 | 
30 | 
31 | def gen_cifar_100():
32 |     # generate the training data set.
33 | 
34 |     data_dir = CIFAR100_DIR + '/train'
35 |     train_data = unpickle(data_dir)
36 |     print(data_dir + " is loading...")
37 | 
38 |     for i in range(0, 50000):
39 |         # binary files are converted to images.
40 |         img = np.reshape(train_data[b'data'][i], (3, 32, 32))
41 |         # img = img.transpose(1, 2, 0)
42 |         # img_path = CIFAR100_TRAIN_DIR + '/' + str(train_data[b'fine_labels'][i]) + '/' + str(i) + '.jpg'
43 |         # print(img_path)
44 |         # image_util.save_cv(img, img_path)
45 | 
46 |         r = img[0]
47 |         g = img[1]
48 |         b = img[2]
49 | 
50 |         ir = Image.fromarray(r)
51 |         ig = Image.fromarray(g)
52 |         ib = Image.fromarray(b)
53 |         rgb = Image.merge("RGB", (ir, ig, ib))
54 |         img_path = CIFAR100_TRAIN_DIR + '/' + str(train_data[b'fine_labels'][i])
55 |         if not os.path.exists(img_path):
56 |             os.makedirs(img_path)
57 |         filename = img_path + '/' + str(i) + '.png'
58 |         print(filename)
59 |         rgb.save(filename, "PNG")
60 |     print(data_dir + " loaded.")
61 | 
62 |     print("test_batch is loading...")
63 | 
64 |     # generate the validation data set.
65 |     val_data = CIFAR100_DIR + '/test'
66 |     val_data = unpickle(val_data)
67 |     for i in range(0, 10000):
68 |         # binary files are converted to images
69 |         img = np.reshape(val_data[b'data'][i], (3, 32, 32))
70 |         # img = img.transpose(1, 2, 0)
71 |         # img_path = CIFAR100_TEST_DIR + '/' + str(val_data[b'fine_labels'][i]) + '/' + str(i) + '.jpg'
72 |         # print(img_path)
73 |         # image_util.save_cv(img, img_path)
74 | 
75 |         r = img[0]
76 |         g = img[1]
77 |         b = img[2]
78 | 
79 |         ir = Image.fromarray(r)
80 |         ig = Image.fromarray(g)
81 |         ib = Image.fromarray(b)
82 |         rgb = Image.merge("RGB", (ir, ig, ib))
83 |         img_path = CIFAR100_TEST_DIR + '/' + str(val_data[b'fine_labels'][i])
84 |         if not os.path.exists(img_path):
85 |             os.makedirs(img_path)
86 |         filename = img_path + '/' + str(i) + '.png'
87 |         print(filename)
88 |         rgb.save(filename, "PNG")
89 |     print("test_batch loaded.")
90 |     return
91 | 
92 | 
93 | if __name__ == '__main__':
94 |     gen_cifar_100()
95 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import simnet, alexnet, vgg, resnet, \
3 |     senet, resnext, densenet, simplenetv1, \
4 |     efficientnetv2, googlenet, xception, mobilenetv2, \
5 |     inceptionv3, wideresnet, shufflenetv2, squeezenet, mnasnet
6 | 
7 | 
8 | def load_model(model_name, in_channels=3, num_classes=10):
9 |     print('-' * 50)
10 |     print('LOAD MODEL:', model_name)
11 |     print('-' * 50)
12 | 
13 |     model = None
14 |     if model_name == 'simnet':
15 |         model = simnet.simnet()
16 |     elif model_name == 'alexnet':
17 |         model = alexnet.alexnet(in_channels, num_classes)
18 |     elif model_name == 'vgg16':
19 |         model = vgg.vgg16_bn(in_channels, num_classes)
20 |     elif model_name == 'resnet34':
21 |         model = resnet.resnet34(in_channels, num_classes)
22 |     elif model_name == 'resnet50':
23 |         model = resnet.resnet50(in_channels, num_classes)
24 |     elif model_name == 'senet34':
25 |         model = senet.seresnet34(in_channels, num_classes)
26 |     elif model_name == 'wideresnet28':
27 |         model = wideresnet.wide_resnet28_10(in_channels, num_classes)
28 |     elif model_name == 'resnext50':
29 |         model = resnext.resnext50(in_channels, num_classes)
30 |     elif model_name == 'densenet121':
31 |         model = densenet.densenet121(in_channels, num_classes)
32 |     elif model_name == 'simplenetv1':
33 |         model = simplenetv1.simplenet(in_channels, num_classes)
34 |     elif model_name == 'efficientnetv2s':
35 |         model = efficientnetv2.effnetv2_s(in_channels, num_classes)
36 |     elif model_name == 'efficientnetv2l':
37 |         model = efficientnetv2.effnetv2_l(in_channels, num_classes)
38 |     elif model_name == 'googlenet':
39 |         model = googlenet.googlenet(in_channels, num_classes)
40 |     elif model_name == 'xception':
41 |         model = xception.xception(in_channels, num_classes)
42 |     elif model_name == 'mobilenetv2':
43 |         model = mobilenetv2.mobilenetv2(in_channels, num_classes)
44 |     elif model_name == 'inceptionv3':
45 |         model = inceptionv3.inceptionv3(in_channels, num_classes)
46 |     elif model_name == 'shufflenetv2':
47 |         model = shufflenetv2.shufflenetv2(in_channels, num_classes)
48 |     elif model_name == 'squeezenet':
49 |         model = squeezenet.squeezenet(in_channels, num_classes)
50 |     elif model_name == 'mnasnet':
51 |         model = mnasnet.mnasnet(in_channels, num_classes)
52 |     return model
53 | 
54 | 
55 | def load_modules(model, model_layers=None):
56 |     assert model_layers is None or type(model_layers) is list
57 | 
58 |     modules = []
59 |     for module in model.modules():
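        # keep only the Conv2d layers; the list is reversed below so that
        # index 0 refers to the conv layer closest to the output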
60 | if isinstance(module, torch.nn.Conv2d): 61 | modules.append(module) 62 | 63 | modules.reverse() # reverse order 64 | if model_layers is None: 65 | model_modules = modules 66 | else: 67 | model_modules = [] 68 | for layer in model_layers: 69 | model_modules.append(modules[layer]) 70 | 71 | print('-' * 50) 72 | print('Model Layers:', model_layers) 73 | print('Model Modules:', model_modules) 74 | print('Model Modules Length:', len(model_modules)) 75 | print('-' * 50) 76 | 77 | return model_modules 78 | -------------------------------------------------------------------------------- /models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 10 | 11 | 12 | def conv_init(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Conv') != -1: 15 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 16 | init.constant_(m.bias, 0) 17 | elif classname.find('BatchNorm') != -1: 18 | init.constant_(m.weight, 1) 19 | init.constant_(m.bias, 0) 20 | 21 | 22 | class wide_basic(nn.Module): 23 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 24 | super(wide_basic, self).__init__() 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 27 | self.dropout = nn.Dropout(p=dropout_rate) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 30 | 31 | self.shortcut = nn.Sequential() 32 | if stride != 1 or in_planes != planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 35 | ) 36 | 37 | def forward(self, x): 38 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 39 | out = self.conv2(F.relu(self.bn2(out))) 40 | out += self.shortcut(x) 41 | 42 | return out 43 | 44 | 45 | class WideResNet(nn.Module): 46 | def __init__(self, depth, widen_factor, dropout_rate, in_channels=3, num_classes=10): 47 | super(WideResNet, self).__init__() 48 | self.in_planes = 16 49 | 50 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 51 | n = (depth - 4) / 6 52 | k = widen_factor 53 | 54 | nStages = [16, 16 * k, 32 * k, 64 * k] 55 | 56 | self.conv1 = conv3x3(in_channels, nStages[0]) 57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 61 | self.linear = nn.Linear(nStages[3] * 49, num_classes) 62 | 63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 64 | strides = [stride] + [1] * (int(num_blocks) - 1) 65 | layers = [] 66 | 67 | for stride in strides: 68 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 69 | self.in_planes = planes 70 | 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | out = self.conv1(x) 75 | out = self.layer1(out) 76 | out = self.layer2(out) 77 | out = self.layer3(out) 78 | out = F.relu(self.bn1(out)) 79 | out = F.avg_pool2d(out, 8) 80 | out = out.view(out.size(0), -1) 81 | out = self.linear(out) 82 | 83 | return out 84 | 85 
| 86 | def wide_resnet28_10(in_channels=1, num_classes=10): 87 | return WideResNet(28, 10, 0.3, in_channels=in_channels, num_classes=num_classes) 88 | -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """mobilenetv2 in pytorch 2 | 3 | 4 | 5 | [1] Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen 6 | 7 | MobileNetV2: Inverted Residuals and Linear Bottlenecks 8 | https://arxiv.org/abs/1801.04381 9 | """ 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class LinearBottleNeck(nn.Module): 16 | 17 | def __init__(self, in_channels, out_channels, stride, t=6, class_num=100): 18 | super().__init__() 19 | 20 | self.residual = nn.Sequential( 21 | nn.Conv2d(in_channels, in_channels * t, 1), 22 | nn.BatchNorm2d(in_channels * t), 23 | nn.ReLU6(inplace=True), 24 | 25 | nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t), 26 | nn.BatchNorm2d(in_channels * t), 27 | nn.ReLU6(inplace=True), 28 | 29 | nn.Conv2d(in_channels * t, out_channels, 1), 30 | nn.BatchNorm2d(out_channels) 31 | ) 32 | 33 | self.stride = stride 34 | self.in_channels = in_channels 35 | self.out_channels = out_channels 36 | 37 | def forward(self, x): 38 | residual = self.residual(x) 39 | 40 | if self.stride == 1 and self.in_channels == self.out_channels: 41 | residual += x 42 | 43 | return residual 44 | 45 | 46 | class MobileNetV2(nn.Module): 47 | 48 | def __init__(self, in_channels=3, num_classes=10): 49 | super().__init__() 50 | 51 | self.pre = nn.Sequential( 52 | nn.Conv2d(in_channels, 32, 1, padding=1), 53 | nn.BatchNorm2d(32), 54 | nn.ReLU6(inplace=True) 55 | ) 56 | 57 | self.stage1 = LinearBottleNeck(32, 16, 1, 1) 58 | self.stage2 = self._make_stage(2, 16, 24, 2, 6) 59 | self.stage3 = self._make_stage(3, 24, 32, 2, 6) 60 | self.stage4 = self._make_stage(4, 32, 64, 2, 6) 61 | self.stage5 = self._make_stage(3, 64, 96, 1, 6) 62 | self.stage6 = self._make_stage(3, 96, 160, 1, 6) 63 | self.stage7 = LinearBottleNeck(160, 320, 1, 6) 64 | 65 | self.conv1 = nn.Sequential( 66 | nn.Conv2d(320, 1280, 1), 67 | nn.BatchNorm2d(1280), 68 | nn.ReLU6(inplace=True) 69 | ) 70 | 71 | self.conv2 = nn.Conv2d(1280, num_classes, 1) 72 | 73 | def forward(self, x): 74 | x = self.pre(x) 75 | x = self.stage1(x) 76 | x = self.stage2(x) 77 | x = self.stage3(x) 78 | x = self.stage4(x) 79 | x = self.stage5(x) 80 | x = self.stage6(x) 81 | x = self.stage7(x) 82 | x = self.conv1(x) 83 | x = F.adaptive_avg_pool2d(x, 1) 84 | x = self.conv2(x) 85 | x = x.view(x.size(0), -1) 86 | 87 | return x 88 | 89 | def _make_stage(self, repeat, in_channels, out_channels, stride, t): 90 | layers = [] 91 | layers.append(LinearBottleNeck(in_channels, out_channels, stride, t)) 92 | 93 | while repeat - 1: 94 | layers.append(LinearBottleNeck(out_channels, out_channels, 1, t)) 95 | repeat -= 1 96 | 97 | return nn.Sequential(*layers) 98 | 99 | 100 | def mobilenetv2(in_channels=3, num_classes=10): 101 | return MobileNetV2(in_channels=in_channels, num_classes=num_classes) 102 | 103 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import argparse 4 | import time 5 | 6 | import torch 7 | from torch import nn 8 | 9 | import loaders 10 | import models 11 | import metrics 12 | from 
utils.train_util import AverageMeter, ProgressMeter
13 | 
14 | 
15 | def main():
16 |     parser = argparse.ArgumentParser(description='')
17 |     parser.add_argument('--model_name', default='', type=str, help='model name')
18 |     parser.add_argument('--data_name', default='', type=str, help='data name')
19 |     parser.add_argument('--in_channels', default='', type=int, help='in channels')
20 |     parser.add_argument('--num_classes', default='', type=int, help='num classes')
21 |     parser.add_argument('--model_path', default='', type=str, help='model path')
22 |     parser.add_argument('--data_path', default='', type=str, help='data path')
23 |     parser.add_argument('--device_index', default='0', type=str, help='device index')
24 |     args = parser.parse_args()
25 | 
26 |     # ----------------------------------------
27 |     # basic configuration
28 |     # ----------------------------------------
29 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index
30 |     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
31 | 
32 |     print('-' * 50)
33 |     print('TEST ON:', device)
34 |     print('MODEL PATH:', args.model_path)
35 |     print('-' * 50)
36 | 
37 |     # ----------------------------------------
38 |     # trainer configuration
39 |     # ----------------------------------------
40 |     model = models.load_model(model_name=args.model_name, in_channels=args.in_channels, num_classes=args.num_classes)
41 |     model.load_state_dict(torch.load(args.model_path))
42 |     model.to(device)
43 |     model.eval()
44 | 
45 |     test_loader = loaders.load_data(args.data_name, args.data_path, data_type='test')
46 | 
47 |     criterion = nn.CrossEntropyLoss()
48 | 
49 |     # ----------------------------------------
50 |     # each epoch
51 |     # ----------------------------------------
52 |     since = time.time()
53 | 
54 |     loss, acc1, acc5, class_acc = test(test_loader, model, criterion, device, args)
55 | 
56 |     print('-' * 50)
57 |     print('COMPLETE !!!')
58 |     print(class_acc)
59 |     print(loss, acc1, acc5)
60 |     print('TIME CONSUMED', time.time() - since)
61 | 
62 | 
63 | def test(test_loader, model, criterion, device, args):
64 |     loss_meter = AverageMeter('Loss', ':.4e')
65 |     acc1_meter = AverageMeter('Acc@1', ':6.2f')
66 |     acc5_meter = AverageMeter('Acc@5', ':6.2f')
67 |     progress = ProgressMeter(total=len(test_loader), step=20, prefix='Test',
68 |                              meters=[loss_meter, acc1_meter, acc5_meter])
69 |     class_acc = metrics.ClassAccuracy(args.num_classes)
70 |     model.eval()
71 | 
72 |     for i, samples in enumerate(test_loader):
73 |         inputs, labels, _ = samples
74 |         inputs = inputs.to(device)
75 |         labels = labels.to(device)
76 | 
77 |         with torch.set_grad_enabled(False):
78 |             outputs = model(inputs)
79 |             loss = criterion(outputs, labels)
80 |             acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5))
81 |             class_acc.accuracy(outputs, labels)
82 | 
83 |             loss_meter.update(loss.item(), inputs.size(0))
84 |             acc1_meter.update(acc1.item(), inputs.size(0))
85 |             acc5_meter.update(acc5.item(), inputs.size(0))
86 | 
87 |         progress.display(i)
88 | 
89 |     return loss_meter, acc1_meter, acc5_meter, class_acc
90 | 
91 | 
92 | if __name__ == '__main__':
93 |     main()
94 | 
--------------------------------------------------------------------------------
/models/simplenetv1.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | 
4 | 
5 | class SimpleNet(nn.Module):
6 |     def __init__(self, in_channels=3, num_classes=10):
7 |         super(SimpleNet, self).__init__()
8 |         # print(simpnet_name)
9 |         self.features = self._make_layers(in_channels)  #
self._make_layers(cfg[simpnet_name]) 10 | self.classifier = nn.Linear(256, num_classes) 11 | self.drp = nn.Dropout(0.1) 12 | 13 | def forward(self, x): 14 | out = self.features(x) 15 | 16 | # Global Max Pooling 17 | out = F.max_pool2d(out, kernel_size=out.size()[2:]) 18 | # out = F.dropout2d(out, 0.1, training=True) 19 | out = self.drp(out) 20 | 21 | out = out.view(out.size(0), -1) 22 | out = self.classifier(out) 23 | return out 24 | 25 | def _make_layers(self, in_channels): 26 | 27 | model = nn.Sequential( 28 | nn.Conv2d(in_channels, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 29 | nn.BatchNorm2d(64, eps=1e-05, momentum=0.05, affine=True), 30 | nn.ReLU(inplace=True), 31 | 32 | nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 33 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True), 34 | nn.ReLU(inplace=True), 35 | 36 | nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 37 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True), 38 | nn.ReLU(inplace=True), 39 | 40 | nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 41 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True), 42 | nn.ReLU(inplace=True), 43 | 44 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False), 45 | nn.Dropout2d(p=0.1), 46 | 47 | nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 48 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True), 49 | nn.ReLU(inplace=True), 50 | 51 | nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 52 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True), 53 | nn.ReLU(inplace=True), 54 | 55 | nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 56 | nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True), 57 | nn.ReLU(inplace=True), 58 | 59 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False), 60 | nn.Dropout2d(p=0.1), 61 | 62 | nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 63 | nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True), 64 | nn.ReLU(inplace=True), 65 | 66 | nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 67 | nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True), 68 | nn.ReLU(inplace=True), 69 | 70 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False), 71 | nn.Dropout2d(p=0.1), 72 | 73 | nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 74 | nn.BatchNorm2d(512, eps=1e-05, momentum=0.05, affine=True), 75 | nn.ReLU(inplace=True), 76 | 77 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False), 78 | nn.Dropout2d(p=0.1), 79 | 80 | nn.Conv2d(512, 2048, kernel_size=[1, 1], stride=(1, 1), padding=(0, 0)), 81 | nn.BatchNorm2d(2048, eps=1e-05, momentum=0.05, affine=True), 82 | nn.ReLU(inplace=True), 83 | 84 | nn.Conv2d(2048, 256, kernel_size=[1, 1], stride=(1, 1), padding=(0, 0)), 85 | nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True), 86 | nn.ReLU(inplace=True), 87 | 88 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False), 89 | nn.Dropout2d(p=0.1), 90 | 91 | nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)), 92 | nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True), 93 | nn.ReLU(inplace=True), 94 | 95 | ) 96 | 97 | for m in model.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu')) 100 | 
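        # BatchNorm layers keep their default initialisation; only the
        # convolution weights are re-initialised with Xavier in the loop above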
101 | return model 102 | 103 | 104 | def simplenet(in_channels=3, num_classes=10): 105 | return SimpleNet(in_channels=in_channels, num_classes=num_classes) 106 | -------------------------------------------------------------------------------- /loaders/datasets/image_mask_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections.abc import Sequence, Iterable 3 | from torchvision import transforms 4 | import torchvision.transforms.functional as F 5 | 6 | 7 | class Compose(object): 8 | 9 | def __init__(self, transforms): 10 | self.transforms = transforms 11 | 12 | def __call__(self, images): 13 | for t in self.transforms: 14 | images = t(images) 15 | return images 16 | 17 | 18 | class ToTensor(object): 19 | 20 | def __call__(self, images): 21 | trans = [] 22 | # TODO mask to tensor? 23 | for img in images: 24 | img = F.to_tensor(img) 25 | trans.append(img) 26 | return trans 27 | 28 | 29 | class Normalize(object): 30 | def __init__(self, mean, std, inplace=False): 31 | self.mean = mean 32 | self.std = std 33 | self.inplace = inplace 34 | 35 | def __call__(self, tensors): 36 | norms = [F.normalize(tensors[0], self.mean, self.std, self.inplace), tensors[1]] 37 | return norms 38 | 39 | def __repr__(self): 40 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 41 | 42 | 43 | class Resize(object): 44 | # def __init__(self, size, interpolation=2): 45 | def __init__(self, size, interpolation=transforms.InterpolationMode.BILINEAR): 46 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) 47 | self.size = size 48 | self.interpolation = interpolation 49 | 50 | def __call__(self, images): 51 | trans = [] 52 | for img in images: 53 | img = F.resize(img, self.size, self.interpolation) 54 | trans.append(img) 55 | return trans 56 | 57 | 58 | class RandomHorizontalFlip(object): 59 | def __init__(self, p=0.5): 60 | self.p = p 61 | 62 | def __call__(self, images): 63 | if random.random() < self.p: 64 | trans = [] 65 | for img in images: 66 | img = F.hflip(img) 67 | trans.append(img) 68 | return trans 69 | return images 70 | 71 | 72 | class RandomVerticalFlip(object): 73 | def __init__(self, p=0.5): 74 | self.p = p 75 | 76 | def __call__(self, images): 77 | if random.random() < self.p: 78 | trans = [] 79 | for img in images: 80 | img = F.vflip(img) 81 | trans.append(img) 82 | return trans 83 | return images 84 | 85 | 86 | class RandomRotation(transforms.RandomRotation): 87 | 88 | def __init__(self, degrees): 89 | super(RandomRotation, self).__init__(degrees) 90 | 91 | def __call__(self, images): 92 | angle = self.get_params(self.degrees) 93 | trans = [] 94 | for img in images: 95 | img = F.rotate(img, angle, self.resample, self.expand, self.center, self.fill) 96 | trans.append(img) 97 | return trans 98 | 99 | 100 | class RandomResizedCrop(transforms.RandomResizedCrop): 101 | def __init__(self, size, scale=(0.08, 1.0)): 102 | super(RandomResizedCrop, self).__init__(size, scale) 103 | 104 | def __call__(self, images): 105 | i, j, h, w = self.get_params(images[0], self.scale, self.ratio) 106 | trans = [] 107 | for img in images: 108 | img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 109 | trans.append(img) 110 | return trans 111 | 112 | 113 | class RandomCrop(transforms.RandomCrop): 114 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): 115 | super().__init__(size, padding, pad_if_needed, fill, padding_mode) 116 | 
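    # one set of crop parameters is sampled from the first image and then
    # applied to every image in the list, so image and mask stay aligned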
117 |     def forward(self, images):
118 |         if self.padding is not None:
119 |             for i, img in enumerate(images):
120 |                 images[i] = F.pad(img, self.padding, self.fill, self.padding_mode)
121 | 
122 |         width, height = F._get_image_size(images[0])  # torchvision's private helper; returns (width, height)
123 |         # pad the width if needed
124 |         if self.pad_if_needed and width < self.size[1]:
125 |             padding = [self.size[1] - width, 0]
126 | 
127 |             for i, img in enumerate(images):
128 |                 images[i] = F.pad(img, padding, self.fill, self.padding_mode)
129 |         # pad the height if needed
130 |         if self.pad_if_needed and height < self.size[0]:
131 |             padding = [0, self.size[0] - height]
132 |             for i, img in enumerate(images):
133 |                 images[i] = F.pad(img, padding, self.fill, self.padding_mode)
134 | 
135 |         i, j, h, w = self.get_params(images[0], self.size)
136 | 
137 |         for idx, img in enumerate(images):  # "idx" must not shadow "i", the crop's top offset sampled above
138 |             images[idx] = F.crop(img, i, j, h, w)
139 |         return images
--------------------------------------------------------------------------------
/models/resnext.py:
--------------------------------------------------------------------------------
1 | """resnext in pytorch
2 | 
3 | 
4 | 
5 | [1] Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, Kaiming He.
6 | 
7 |     Aggregated Residual Transformations for Deep Neural Networks
8 |     https://arxiv.org/abs/1611.05431
9 | """
10 | 
11 | import math
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | 
16 | # only implements ResNext bottleneck c
17 | 
18 | 
19 | # """This strategy exposes a new dimension, which we call "cardinality"
20 | # (the size of the set of transformations), as an essential factor
21 | # in addition to the dimensions of depth and width."""
22 | CARDINALITY = 32
23 | DEPTH = 4
24 | BASEWIDTH = 64
25 | 
26 | 
27 | # """The grouped convolutional layer in Fig. 3(c) performs 32 groups
28 | # of convolutions whose input and output channels are 4-dimensional.
29 | # The grouped convolutional layer concatenates them as the outputs
30 | # of the layer."""
31 | 
32 | class ResNextBottleNeckC(nn.Module):
33 | 
34 |     def __init__(self, in_channels, out_channels, stride):
35 |         super().__init__()
36 | 
37 |         C = CARDINALITY  # how many groups a feature map is split into
38 | 
39 |         # """We note that the input/output width of the template is fixed as
40 |         # 256-d (Fig. 3), and all widths are doubled each time
41 |         # when the feature map is subsampled (see Table 1)."""
42 | 
43 |         D = int(DEPTH * out_channels / BASEWIDTH)  # number of channels per group
44 |         self.split_transforms = nn.Sequential(
45 |             nn.Conv2d(in_channels, C * D, kernel_size=1, groups=C, bias=False),
46 |             nn.BatchNorm2d(C * D),
47 |             nn.ReLU(inplace=True),
48 |             nn.Conv2d(C * D, C * D, kernel_size=3, stride=stride, groups=C, padding=1, bias=False),
49 |             nn.BatchNorm2d(C * D),
50 |             nn.ReLU(inplace=True),
51 |             nn.Conv2d(C * D, out_channels * 4, kernel_size=1, bias=False),
52 |             nn.BatchNorm2d(out_channels * 4),
53 |         )
54 | 
55 |         self.shortcut = nn.Sequential()
56 | 
57 |         if stride != 1 or in_channels != out_channels * 4:
58 |             self.shortcut = nn.Sequential(
59 |                 nn.Conv2d(in_channels, out_channels * 4, stride=stride, kernel_size=1, bias=False),
60 |                 nn.BatchNorm2d(out_channels * 4)
61 |             )
62 | 
63 |     def forward(self, x):
64 |         return F.relu(self.split_transforms(x) + self.shortcut(x))
65 | 
66 | 
67 | class ResNext(nn.Module):
68 | 
69 |     def __init__(self, block, num_blocks, in_channels=3, num_classes=100):
70 |         super().__init__()
71 |         self.in_channels = 64
72 | 
73 |         self.conv1 = nn.Sequential(
74 |             nn.Conv2d(in_channels, 64, 3, stride=1, padding=1, bias=False),
75 |             nn.BatchNorm2d(64),
76 |             nn.ReLU(inplace=True)
77 |         )
78 | 
79 |         self.conv2 = self._make_layer(block, num_blocks[0], 64, 1)
80 |         self.conv3 = self._make_layer(block, num_blocks[1], 128, 2)
81 |         self.conv4 = self._make_layer(block, num_blocks[2], 256, 2)
82 |         self.conv5 = self._make_layer(block, num_blocks[3], 512, 2)
83 |         self.avg = nn.AdaptiveAvgPool2d((1, 1))
84 |         self.fc = nn.Linear(512 * 4, num_classes)
85 | 
86 |     def forward(self, x):
87 |         x = self.conv1(x)
88 |         x = self.conv2(x)
89 |         x = self.conv3(x)
90 |         x = self.conv4(x)
91 |         x = self.conv5(x)
92 |         x = self.avg(x)
93 |         x = x.view(x.size(0), -1)
94 |         x = self.fc(x)
95 |         return x
96 | 
97 |     def _make_layer(self, block, num_block, out_channels, stride):
98 |         """Building resnext block
99 |         Args:
100 |             block: block type (default resnext bottleneck c)
101 |             num_block: number of blocks per layer
102 |             out_channels: output channels per block
103 |             stride: block stride
104 | 
105 |         Returns:
106 |             a resnext layer
107 |         """
108 |         strides = [stride] + [1] * (num_block - 1)
109 |         layers = []
110 |         for stride in strides:
111 |             layers.append(block(self.in_channels, out_channels, stride))
112 |             self.in_channels = out_channels * 4
113 | 
114 |         return nn.Sequential(*layers)
115 | 
116 | 
117 | def resnext50(in_channels=3, num_classes=10):
118 |     """ return a resnext50(c32x4d) network
119 |     """
120 |     return ResNext(ResNextBottleNeckC, [3, 4, 6, 3], in_channels=in_channels, num_classes=num_classes)
121 | 
122 | 
123 | def resnext101():
124 |     """ return a resnext101(c32x4d) network
125 |     """
126 |     return ResNext(ResNextBottleNeckC, [3, 4, 23, 3])
127 | 
128 | 
129 | def resnext152():
130 |     """ return a resnext152(c32x4d) network
131 |     """
132 |     return ResNext(ResNextBottleNeckC, [3, 4, 36, 3])
133 | 
134 | 
--------------------------------------------------------------------------------
/models/googlenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class Inception(nn.Module):
6 |     def __init__(self, input_channels, n1x1, n3x3_reduce, n3x3, n5x5_reduce, n5x5, pool_proj):
7 |         super().__init__()
8 | 
9 |         # 1x1conv branch
10 |         self.b1 = nn.Sequential(
11 |             nn.Conv2d(input_channels, n1x1, kernel_size=1),
12 |             nn.BatchNorm2d(n1x1),
13 |             nn.ReLU(inplace=True)
14 |         )
15 | 
16 |         # 1x1conv -> 3x3conv branch
17 |         self.b2 = nn.Sequential(
18 |             nn.Conv2d(input_channels, n3x3_reduce, kernel_size=1),
19 |             nn.BatchNorm2d(n3x3_reduce),
20 |             nn.ReLU(inplace=True),
21 |             nn.Conv2d(n3x3_reduce, n3x3, kernel_size=3, padding=1),
22 |             nn.BatchNorm2d(n3x3),
23 |             nn.ReLU(inplace=True)
24 |         )
25 | 
26 |         # 1x1conv -> 5x5conv branch
27 |         # we use 2 3x3 conv filters stacked instead
28 |         # of one 5x5 filter to obtain the same receptive
29 |         # field with fewer parameters
30 |         self.b3 = nn.Sequential(
31 |             nn.Conv2d(input_channels, n5x5_reduce, kernel_size=1),
32 |             nn.BatchNorm2d(n5x5_reduce),
33 |             nn.ReLU(inplace=True),
34 |             nn.Conv2d(n5x5_reduce, n5x5, kernel_size=3, padding=1),
35 |             nn.BatchNorm2d(n5x5),
36 |             nn.ReLU(inplace=True),
37 |             nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
38 |             nn.BatchNorm2d(n5x5),
39 |             nn.ReLU(inplace=True)
40 |         )
41 | 
42 |         # 3x3pooling -> 1x1conv
43 |         # same conv
44 |         self.b4 = nn.Sequential(
45 |             nn.MaxPool2d(3, stride=1, padding=1),
46 |             nn.Conv2d(input_channels, pool_proj, kernel_size=1),
47 |             nn.BatchNorm2d(pool_proj),
48 |             nn.ReLU(inplace=True)
49 |         )
50 | 
51 |     def forward(self, x):
52 |         return torch.cat([self.b1(x), self.b2(x), self.b3(x), self.b4(x)], dim=1)
53 | 
54 | 
55 | class GoogleNet(nn.Module):
56 | 
57 |     def __init__(self, in_channels=3, num_classes=10):
58 |         super().__init__()
59 |         self.prelayer = nn.Sequential(
60 |             nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False),
61 |             nn.BatchNorm2d(64),
62 |             nn.ReLU(inplace=True),
63 |             nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
64 |             nn.BatchNorm2d(64),
65 |             nn.ReLU(inplace=True),
66 |             nn.Conv2d(64, 192, kernel_size=3, padding=1, bias=False),
67 |             nn.BatchNorm2d(192),
68 |             nn.ReLU(inplace=True),
69 |         )
70 | 
71 |         # although we only use 1 conv layer as prelayer,
72 |         # we still keep the names a3, b3, ...
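        # each Inception block outputs n1x1 + n3x3 + n5x5 + pool_proj channels,
        # e.g. a3 below emits 64 + 128 + 32 + 32 = 256 channels, matching b3's input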
73 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 74 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 75 | 76 | ##"""In general, an Inception network is a network consisting of 77 | ##modules of the above type stacked upon each other, with occasional 78 | ##max-pooling layers with stride 2 to halve the resolution of the 79 | ##grid""" 80 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 81 | 82 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 83 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 84 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 85 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 86 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 87 | 88 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 89 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 90 | 91 | # input feature size: 8*8*1024 92 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 93 | self.dropout = nn.Dropout2d(p=0.4) 94 | self.linear = nn.Linear(1024, num_classes) 95 | 96 | def forward(self, x): 97 | x = self.prelayer(x) 98 | x = self.maxpool(x) 99 | x = self.a3(x) 100 | x = self.b3(x) 101 | 102 | x = self.maxpool(x) 103 | 104 | x = self.a4(x) 105 | x = self.b4(x) 106 | x = self.c4(x) 107 | x = self.d4(x) 108 | x = self.e4(x) 109 | 110 | x = self.maxpool(x) 111 | 112 | x = self.a5(x) 113 | x = self.b5(x) 114 | 115 | # """It was found that a move from fully connected layers to 116 | # average pooling improved the top-1 accuracy by about 0.6%, 117 | # however the use of dropout remained essential even after 118 | # removing the fully connected layers.""" 119 | x = self.avgpool(x) 120 | x = self.dropout(x) 121 | x = x.view(x.size()[0], -1) 122 | x = self.linear(x) 123 | 124 | return x 125 | 126 | 127 | def googlenet(in_channels=3, num_classes=10): 128 | return GoogleNet(in_channels=in_channels, num_classes=num_classes) 129 | 130 | -------------------------------------------------------------------------------- /models/mnasnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def Conv_3x3(inp, oup, stride): 6 | return nn.Sequential( 7 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 8 | nn.BatchNorm2d(oup), 9 | nn.ReLU6(inplace=True) 10 | ) 11 | 12 | 13 | def Conv_1x1(inp, oup): 14 | return nn.Sequential( 15 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 16 | nn.BatchNorm2d(oup), 17 | nn.ReLU6(inplace=True) 18 | ) 19 | 20 | 21 | def SepConv_3x3(inp, oup): # input=32, output=16 22 | return nn.Sequential( 23 | # dw 24 | nn.Conv2d(inp, inp, 3, 1, 1, groups=inp, bias=False), 25 | nn.BatchNorm2d(inp), 26 | nn.ReLU6(inplace=True), 27 | # pw-linear 28 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 29 | nn.BatchNorm2d(oup), 30 | ) 31 | 32 | 33 | class InvertedResidual(nn.Module): 34 | def __init__(self, inp, oup, stride, expand_ratio, kernel): 35 | super(InvertedResidual, self).__init__() 36 | self.stride = stride 37 | assert stride in [1, 2] 38 | 39 | self.use_res_connect = self.stride == 1 and inp == oup 40 | 41 | self.conv = nn.Sequential( 42 | # pw 43 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(inp * expand_ratio), 45 | nn.ReLU6(inplace=True), 46 | # dw 47 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, kernel, stride, kernel // 2, groups=inp * expand_ratio, 48 | bias=False), 49 | nn.BatchNorm2d(inp * expand_ratio), 50 | nn.ReLU6(inplace=True), 51 | # pw-linear 52 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 53 | nn.BatchNorm2d(oup), 54 | ) 55 | 
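    # the block follows the MBConv pattern: a 1x1 conv expands the input to
    # inp * expand_ratio channels, a depthwise conv filters them, and a final
    # 1x1 conv projects back down; the skip connection is only added when the
    # spatial size and channel count are unchanged (stride 1, inp == oup)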
56 | def forward(self, x): 57 | if self.use_res_connect: 58 | return x + self.conv(x) 59 | else: 60 | return self.conv(x) 61 | 62 | 63 | class MnasNet(nn.Module): 64 | def __init__(self, in_channels=3, num_classes=10, input_size=224, width_mult=1.): 65 | super(MnasNet, self).__init__() 66 | 67 | # setting of inverted residual blocks 68 | self.interverted_residual_setting = [ 69 | # t, c, n, s, k 70 | [3, 24, 3, 2, 3], # -> 56x56 71 | [3, 40, 3, 2, 5], # -> 28x28 72 | [6, 80, 3, 2, 5], # -> 14x14 73 | [6, 96, 2, 1, 3], # -> 14x14 74 | [6, 192, 4, 2, 5], # -> 7x7 75 | [6, 320, 1, 1, 3], # -> 7x7 76 | ] 77 | 78 | assert input_size % 32 == 0 79 | input_channel = int(32 * width_mult) 80 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 81 | 82 | # building first two layer 83 | self.features = [Conv_3x3(in_channels, input_channel, 2), SepConv_3x3(input_channel, 16)] 84 | input_channel = 16 85 | 86 | # building inverted residual blocks (MBConv) 87 | for t, c, n, s, k in self.interverted_residual_setting: 88 | output_channel = int(c * width_mult) 89 | for i in range(n): 90 | if i == 0: 91 | self.features.append(InvertedResidual(input_channel, output_channel, s, t, k)) 92 | else: 93 | self.features.append(InvertedResidual(input_channel, output_channel, 1, t, k)) 94 | input_channel = output_channel 95 | 96 | # building last several layers 97 | self.features.append(Conv_1x1(input_channel, self.last_channel)) 98 | self.features.append(nn.AdaptiveAvgPool2d(1)) 99 | 100 | # make it nn.Sequential 101 | self.features = nn.Sequential(*self.features) 102 | 103 | # building classifier 104 | self.classifier = nn.Sequential( 105 | nn.Dropout(), 106 | nn.Linear(self.last_channel, num_classes), 107 | ) 108 | 109 | self._initialize_weights() 110 | 111 | def forward(self, x): 112 | x = self.features(x) 113 | x = x.view(-1, self.last_channel) 114 | x = self.classifier(x) 115 | return x 116 | 117 | def _initialize_weights(self): 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 122 | if m.bias is not None: 123 | m.bias.data.zero_() 124 | elif isinstance(m, nn.BatchNorm2d): 125 | m.weight.data.fill_(1) 126 | m.bias.data.zero_() 127 | elif isinstance(m, nn.Linear): 128 | n = m.weight.size(1) 129 | m.weight.data.normal_(0, 0.01) 130 | m.bias.data.zero_() 131 | 132 | 133 | def mnasnet(in_channels=3, num_classes=10): 134 | return MnasNet(in_channels=in_channels, num_classes=num_classes) 135 | 136 | -------------------------------------------------------------------------------- /core/grad_calculate.py: -------------------------------------------------------------------------------- 1 | """ 2 | activation: 3 | each layer 4 | 5 | gradient: 6 | output to each layer 7 | """ 8 | import sys 9 | 10 | sys.path.append('.') 11 | 12 | import argparse 13 | import torch 14 | import numpy as np 15 | import os 16 | from tqdm import tqdm 17 | 18 | import models 19 | import loaders 20 | 21 | 22 | class HookModule: 23 | def __init__(self, module): 24 | self.inputs = None 25 | self.outputs = None 26 | module.register_forward_hook(self._hook) 27 | 28 | def grads(self, outputs, inputs=None, retain_graph=True, create_graph=True): 29 | if inputs is None: 30 | inputs = self.outputs # default the output dim 31 | 32 | return torch.autograd.grad(outputs=outputs, 33 | inputs=inputs, 34 | retain_graph=retain_graph, 35 | create_graph=create_graph)[0] 36 | 37 | def _hook(self, module, inputs, outputs): 38 | self.inputs = inputs[0] 39 | self.outputs = outputs 40 | 41 | 42 | def _normalization(data, axis=None, bot=False): 43 | assert axis in [None, 0, 1] 44 | _max = np.max(data, axis=axis) 45 | if bot: 46 | _min = np.zeros(_max.shape) 47 | else: 48 | _min = np.min(data, axis=axis) 49 | _range = _max - _min 50 | if axis == 1: 51 | _norm = ((data.T - _min) / (_range + 1e-5)).T 52 | else: 53 | _norm = (data - _min) / (_range + 1e-5) 54 | return _norm 55 | 56 | 57 | class GradCalculate: 58 | def __init__(self, modules, num_classes): 59 | self.modules = [HookModule(module) for module in modules] 60 | 61 | self.values = [[[] for _ in range(num_classes)] for _ in range(len(modules))] 62 | # [num_modules, num_classes, num_images, channels] 63 | 64 | def __call__(self, outputs, labels): 65 | nll_loss = torch.nn.NLLLoss()(outputs, labels) 66 | for layer, module in enumerate(self.modules): 67 | 68 | values = module.grads(-nll_loss, module.outputs) 69 | values = torch.relu(values) 70 | 71 | values = values.detach().cpu().numpy() 72 | 73 | for b in range(len(labels)): 74 | self.values[layer][labels[b]].append(values[b]) 75 | 76 | def sift(self, result_path, threshold): 77 | for layer, values in enumerate(tqdm(self.values)): 78 | values = np.asarray(values) 79 | if len(values.shape) > 3: 80 | values = np.sum(values, axis=(3, 4)) # [num_classes, num_images, channels] 81 | values = np.sum(values, axis=1) # [num_classes, channels] 82 | 83 | values = _normalization(values, axis=1) 84 | 85 | mask = np.zeros(values.shape) 86 | mask[np.where(values > threshold)] = 1 87 | mask_path = os.path.join(result_path, 'layer_{}.npy'.format(layer)) 88 | np.save(mask_path, mask) 89 | 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser(description='') 93 | parser.add_argument('--model_name', default='', type=str, help='model name') 94 | parser.add_argument('--data_name', default='', type=str, help='data name') 95 | parser.add_argument('--in_channels', default='', type=int, help='in channels') 96 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 97 | parser.add_argument('--model_path', 
default='', type=str, help='model path') 98 | parser.add_argument('--data_path', default='', type=str, help='data path') 99 | parser.add_argument('--grad_path', default='', type=str, help='grad path') 100 | parser.add_argument('--theta', default='', type=float, help='theta') 101 | parser.add_argument('--device_index', default='0', type=str, help='device index') 102 | args = parser.parse_args() 103 | 104 | # ---------------------------------------- 105 | # basic configuration 106 | # ---------------------------------------- 107 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index 108 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 109 | 110 | if not os.path.exists(args.grad_path): 111 | os.makedirs(args.grad_path) 112 | 113 | print('-' * 50) 114 | print('TRAIN ON:', device) 115 | print('DATA PATH:', args.data_path) 116 | print('RESULT PATH:', args.grad_path) 117 | print('-' * 50) 118 | 119 | # ---------------------------------------- 120 | # model/data configuration 121 | # ---------------------------------------- 122 | model = models.load_model(model_name=args.model_name, num_classes=args.num_classes) 123 | model.load_state_dict(torch.load(args.model_path)) 124 | model.to(device) 125 | model.eval() 126 | 127 | data_loader = loaders.load_data(data_name=args.data_name, data_path=args.data_path) 128 | 129 | modules = models.load_modules(model=model) 130 | 131 | grad_calculate = GradCalculate(modules=modules, num_classes=args.num_classes) 132 | 133 | # ---------------------------------------- 134 | # forward 135 | # ---------------------------------------- 136 | for i, samples in enumerate(tqdm(data_loader)): 137 | inputs, labels, _ = samples 138 | inputs = inputs.to(device) 139 | labels = labels.to(device) 140 | outputs = model(inputs) 141 | 142 | grad_calculate(outputs, labels) 143 | 144 | grad_calculate.sift(result_path=args.grad_path, threshold=args.theta) 145 | 146 | 147 | if __name__ == '__main__': 148 | np.set_printoptions(threshold=np.inf) 149 | main() 150 | -------------------------------------------------------------------------------- /models/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | """shufflenetv2 in pytorch 2 | 3 | 4 | 5 | [1] Ningning Ma, Xiangyu Zhang, Hai-Tao Zheng, Jian Sun 6 | 7 | ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design 8 | https://arxiv.org/abs/1807.11164 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def channel_split(x, split): 17 | """split a tensor into two pieces along channel dimension 18 | Args: 19 | x: input tensor 20 | split:(int) channel size for each pieces 21 | """ 22 | assert x.size(1) == split * 2 23 | return torch.split(x, split, dim=1) 24 | 25 | 26 | def channel_shuffle(x, groups): 27 | """channel shuffle operation 28 | Args: 29 | x: input tensor 30 | groups: input branch number 31 | """ 32 | 33 | batch_size, channels, height, width = x.size() 34 | channels_per_group = int(channels // groups) 35 | 36 | x = x.view(batch_size, groups, channels_per_group, height, width) 37 | x = x.transpose(1, 2).contiguous() 38 | x = x.view(batch_size, -1, height, width) 39 | 40 | return x 41 | 42 | 43 | class ShuffleUnit(nn.Module): 44 | 45 | def __init__(self, in_channels, out_channels, stride): 46 | super().__init__() 47 | 48 | self.stride = stride 49 | self.in_channels = in_channels 50 | self.out_channels = out_channels 51 | 52 | if stride != 1 or in_channels != out_channels: 53 | 
            self.residual = nn.Sequential(
54 |                 nn.Conv2d(in_channels, in_channels, 1),
55 |                 nn.BatchNorm2d(in_channels),
56 |                 nn.ReLU(inplace=True),
57 |                 nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels),
58 |                 nn.BatchNorm2d(in_channels),
59 |                 nn.Conv2d(in_channels, int(out_channels / 2), 1),
60 |                 nn.BatchNorm2d(int(out_channels / 2)),
61 |                 nn.ReLU(inplace=True)
62 |             )
63 | 
64 |             self.shortcut = nn.Sequential(
65 |                 nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels),
66 |                 nn.BatchNorm2d(in_channels),
67 |                 nn.Conv2d(in_channels, int(out_channels / 2), 1),
68 |                 nn.BatchNorm2d(int(out_channels / 2)),
69 |                 nn.ReLU(inplace=True)
70 |             )
71 |         else:
72 |             self.shortcut = nn.Sequential()
73 | 
74 |             in_channels = int(in_channels / 2)
75 |             self.residual = nn.Sequential(
76 |                 nn.Conv2d(in_channels, in_channels, 1),
77 |                 nn.BatchNorm2d(in_channels),
78 |                 nn.ReLU(inplace=True),
79 |                 nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels),
80 |                 nn.BatchNorm2d(in_channels),
81 |                 nn.Conv2d(in_channels, in_channels, 1),
82 |                 nn.BatchNorm2d(in_channels),
83 |                 nn.ReLU(inplace=True)
84 |             )
85 | 
86 |     def forward(self, x):
87 | 
88 |         if self.stride == 1 and self.out_channels == self.in_channels:
89 |             shortcut, residual = channel_split(x, int(self.in_channels / 2))
90 |         else:
91 |             shortcut = x
92 |             residual = x
93 | 
94 |         shortcut = self.shortcut(shortcut)
95 |         residual = self.residual(residual)
96 |         x = torch.cat([shortcut, residual], dim=1)
97 |         x = channel_shuffle(x, 2)
98 | 
99 |         return x
100 | 
101 | 
102 | class ShuffleNetV2(nn.Module):
103 | 
104 |     def __init__(self, ratio=1, in_channels=3, num_classes=10):
105 |         super().__init__()
106 |         if ratio == 0.5:
107 |             out_channels = [48, 96, 192, 1024]
108 |         elif ratio == 1:
109 |             out_channels = [116, 232, 464, 1024]
110 |         elif ratio == 1.5:
111 |             out_channels = [176, 352, 704, 1024]
112 |         elif ratio == 2:
113 |             out_channels = [244, 488, 976, 2048]
114 |         else:
115 |             raise ValueError('unsupported ratio number')
116 | 
117 |         self.pre = nn.Sequential(
118 |             nn.Conv2d(in_channels, 24, 3, padding=1),
119 |             nn.BatchNorm2d(24)
120 |         )
121 | 
122 |         self.stage2 = self._make_stage(24, out_channels[0], 3)
123 |         self.stage3 = self._make_stage(out_channels[0], out_channels[1], 7)
124 |         self.stage4 = self._make_stage(out_channels[1], out_channels[2], 3)
125 |         self.conv5 = nn.Sequential(
126 |             nn.Conv2d(out_channels[2], out_channels[3], 1),
127 |             nn.BatchNorm2d(out_channels[3]),
128 |             nn.ReLU(inplace=True)
129 |         )
130 | 
131 |         self.fc = nn.Linear(out_channels[3], num_classes)
132 | 
133 |     def forward(self, x):
134 |         x = self.pre(x)
135 |         x = self.stage2(x)
136 |         x = self.stage3(x)
137 |         x = self.stage4(x)
138 |         x = self.conv5(x)
139 |         x = F.adaptive_avg_pool2d(x, 1)
140 |         x = x.view(x.size(0), -1)
141 |         x = self.fc(x)
142 | 
143 |         return x
144 | 
145 |     def _make_stage(self, in_channels, out_channels, repeat):
146 |         layers = []
147 |         layers.append(ShuffleUnit(in_channels, out_channels, 2))
148 | 
149 |         while repeat:
150 |             layers.append(ShuffleUnit(out_channels, out_channels, 1))
151 |             repeat -= 1
152 | 
153 |         return nn.Sequential(*layers)
154 | 
155 | 
156 | def shufflenetv2(in_channels=3, num_classes=10):
157 |     return ShuffleNetV2(in_channels=in_channels, num_classes=num_classes)
158 | 
--------------------------------------------------------------------------------
/preprocessing/stl10/image_gen.py:
--------------------------------------------------------------------------------
1 | import sys
2 | 
3 | sys.path.append('/workspace/classification/code/')  # zjl
4 | # sys.path.append('/nfs3-p1/hjc/classification/code/')  # vipa
5 | 
6 | import os, sys, tarfile, errno
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import urllib.request
10 | from imageio import imsave
11 | 
12 | from configs import config
13 | 
14 | # image shape
15 | HEIGHT = 96
16 | WIDTH = 96
17 | DEPTH = 3
18 | 
19 | # size of a single image in bytes
20 | SIZE = HEIGHT * WIDTH * DEPTH
21 | 
22 | # path to the directory with the data
23 | DATA_DIR = config.datasets_STL_10
24 | DATA_TYPE = 'test'
25 | 
26 | # path to the binary train file with image data
27 | DATA_PATH = DATA_DIR + '/{}_X.bin'.format(DATA_TYPE)
28 | 
29 | # path to the binary train file with labels
30 | LABEL_PATH = DATA_DIR + '/{}_y.bin'.format(DATA_TYPE)
31 | 
32 | # url of the binary data
33 | DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'
34 | 
35 | 
36 | def read_labels(path_to_labels):
37 |     """
38 |     :param path_to_labels: path to the binary file containing labels from the STL-10 dataset
39 |     :return: an array containing the labels
40 |     """
41 |     with open(path_to_labels, 'rb') as f:
42 |         labels = np.fromfile(f, dtype=np.uint8)
43 |     return labels
44 | 
45 | 
46 | def read_all_images(path_to_data):
47 |     """
48 |     :param path_to_data: the file containing the binary images from the STL-10 dataset
49 |     :return: an array containing all the images
50 |     """
51 | 
52 |     with open(path_to_data, 'rb') as f:
53 |         # read whole file in uint8 chunks
54 |         everything = np.fromfile(f, dtype=np.uint8)
55 | 
56 |         # We force the data into 3x96x96 chunks, since the
57 |         # images are stored in "column-major order", meaning
58 |         # that "the first 96*96 values are the red channel,
59 |         # the next 96*96 are green, and the last are blue."
60 |         # The -1 is since the size of the pictures depends
61 |         # on the input file, and this way numpy determines
62 |         # the size on its own.
63 | 
64 |         images = np.reshape(everything, (-1, 3, 96, 96))
65 | 
66 |         # Now transpose the images into a standard image format
67 |         # readable by, for example, matplotlib.imshow
68 |         # You might want to comment this line or reverse the shuffle
69 |         # if you will use a learning algorithm like CNN, since they like
70 |         # their channels separated.
71 |         images = np.transpose(images, (0, 3, 2, 1))
72 |     return images
73 | 
74 | 
75 | def read_single_image(image_file):
76 |     """
77 |     CAREFUL! - this method uses a file as input instead of the path - so the
78 |     position of the reader will be remembered outside of the context of this method.
79 |     :param image_file: the open file containing the images
80 |     :return: a single image
81 |     """
82 |     # read a single image, count determines the number of uint8's to read
83 |     image = np.fromfile(image_file, dtype=np.uint8, count=SIZE)
84 |     # force into image matrix
85 |     image = np.reshape(image, (3, 96, 96))
86 |     # transpose to standard format
87 |     # You might want to comment this line or reverse the shuffle
88 |     # if you will use a learning algorithm like CNN, since they like
89 |     # their channels separated.
90 |     image = np.transpose(image, (2, 1, 0))
91 |     return image
92 | 
93 | 
94 | def plot_image(image):
95 |     """
96 |     :param image: the image to be plotted in a 3-D matrix format
97 |     :return: None
98 |     """
99 |     plt.imshow(image)
100 |     plt.show()
101 | 
102 | 
103 | def save_image(image, name):
104 |     imsave("%s.png" % name, image, format="png")
105 | 
106 | 
107 | def download_and_extract():
108 |     """
109 |     Download and extract the STL-10 dataset
110 |     :return: None
111 |     """
112 |     dest_directory = DATA_DIR
113 |     if not os.path.exists(dest_directory):
114 |         os.makedirs(dest_directory)
115 |     filename = DATA_URL.split('/')[-1]
116 |     filepath = os.path.join(dest_directory, filename)
117 |     if not os.path.exists(filepath):
118 |         def _progress(count, block_size, total_size):
119 |             sys.stdout.write('\rDownloading %s %.2f%%' % (filename,
120 |                              float(count * block_size) / float(total_size) * 100.0))
121 |             sys.stdout.flush()
122 | 
123 |         filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, reporthook=_progress)
124 |         print('Downloaded', filename)
125 |         tarfile.open(filepath, 'r:gz').extractall(dest_directory)
126 | 
127 | 
128 | def save_images(images, labels):
129 |     print("Saving images to disk")
130 |     i = 0
131 |     for image in images:
132 |         label = labels[i]
133 |         directory = config.data_stl10 + '/' + DATA_TYPE + '/' + str(label) + '/'
134 |         try:
135 |             os.makedirs(directory, exist_ok=True)
136 |         except OSError as exc:
137 |             if exc.errno == errno.EEXIST:
138 |                 pass
139 |         filename = directory + str(i)
140 |         print(filename)
141 |         save_image(image, filename)
142 |         i = i + 1
143 | 
144 | 
145 | if __name__ == "__main__":
146 |     # download data if needed
147 |     # download_and_extract()
148 | 
149 |     # # test to check if the image is read correctly
150 |     # with open(DATA_PATH) as f:
151 |     #     image = read_single_image(f)
152 |     #     plot_image(image)
153 | 
154 |     # test to check if the whole dataset is read correctly
155 |     images = read_all_images(DATA_PATH)
156 |     print(images.shape)
157 | 
158 |     labels = read_labels(LABEL_PATH)
159 |     print(labels.shape)
160 | 
161 |     # save images to disk
162 |     save_images(images, labels)
163 | 
--------------------------------------------------------------------------------
/models/densenet.py:
--------------------------------------------------------------------------------
1 | """dense net in pytorch
2 | 
3 | 
4 | 
5 | [1] Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger.
6 | 
7 |     Densely Connected Convolutional Networks
8 |     https://arxiv.org/abs/1608.06993v5
9 | """
10 | 
11 | import torch
12 | import torch.nn as nn
13 | 
14 | 
15 | # """Bottleneck layers. Although each layer only produces k
16 | # output feature-maps, it typically has many more inputs.
It 17 | # has been noted in [37, 11] that a 1×1 convolution can be in- 18 | # troduced as bottleneck layer before each 3×3 convolution 19 | # to reduce the number of input feature-maps, and thus to 20 | # improve computational efficiency.""" 21 | class Bottleneck(nn.Module): 22 | def __init__(self, in_channels, growth_rate): 23 | super().__init__() 24 | # """In our experiments, we let each 1×1 convolution 25 | # produce 4k feature-maps.""" 26 | inner_channel = 4 * growth_rate 27 | 28 | # """We find this design especially effective for DenseNet and 29 | # we refer to our network with such a bottleneck layer, i.e., 30 | # to the BN-ReLU-Conv(1×1)-BN-ReLU-Conv(3×3) version of H ` , 31 | # as DenseNet-B.""" 32 | self.bottle_neck = nn.Sequential( 33 | nn.BatchNorm2d(in_channels), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(in_channels, inner_channel, kernel_size=1, bias=False), 36 | nn.BatchNorm2d(inner_channel), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(inner_channel, growth_rate, kernel_size=3, padding=1, bias=False) 39 | ) 40 | 41 | def forward(self, x): 42 | return torch.cat([x, self.bottle_neck(x)], 1) 43 | 44 | 45 | # """We refer to layers between blocks as transition 46 | # layers, which do convolution and pooling.""" 47 | class Transition(nn.Module): 48 | def __init__(self, in_channels, out_channels): 49 | super().__init__() 50 | # """The transition layers used in our experiments 51 | # consist of a batch normalization layer and an 1×1 52 | # convolutional layer followed by a 2×2 average pooling 53 | # layer""". 54 | self.down_sample = nn.Sequential( 55 | nn.BatchNorm2d(in_channels), 56 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 57 | nn.AvgPool2d(2, stride=2) 58 | ) 59 | 60 | def forward(self, x): 61 | return self.down_sample(x) 62 | 63 | 64 | # DesneNet-BC 65 | # B stands for bottleneck layer(BN-RELU-CONV(1x1)-BN-RELU-CONV(3x3)) 66 | # C stands for compression factor(0<=theta<=1) 67 | class DenseNet(nn.Module): 68 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, in_channels=3, num_classes=10): 69 | super().__init__() 70 | self.growth_rate = growth_rate 71 | 72 | # """Before entering the first dense block, a convolution 73 | # with 16 (or twice the growth rate for DenseNet-BC) 74 | # output channels is performed on the input images.""" 75 | inner_channels = 2 * growth_rate 76 | 77 | # For convolutional layers with kernel size 3×3, each 78 | # side of the inputs is zero-padded by one pixel to keep 79 | # the feature-map size fixed. 80 | self.conv1 = nn.Conv2d(in_channels, inner_channels, kernel_size=3, padding=1, bias=False) 81 | 82 | self.features = nn.Sequential() 83 | 84 | for index in range(len(nblocks) - 1): 85 | self.features.add_module("dense_block_layer_{}".format(index), 86 | self._make_dense_layers(block, inner_channels, nblocks[index])) 87 | inner_channels += growth_rate * nblocks[index] 88 | 89 | # """If a dense block contains m feature-maps, we let the 90 | # following transition layer generate θm output feature- 91 | # maps, where 0 < θ ≤ 1 is referred to as the compression 92 | # fac-tor. 
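            # e.g. for densenet121 (growth_rate=32): after the first dense block
            # inner_channels = 2*32 + 6*32 = 256, and with reduction=0.5 the
            # transition layer narrows it to 128 channels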
93 | out_channels = int(reduction * inner_channels) # int() will automatic floor the value 94 | self.features.add_module("transition_layer_{}".format(index), Transition(inner_channels, out_channels)) 95 | inner_channels = out_channels 96 | 97 | self.features.add_module("dense_block{}".format(len(nblocks) - 1), 98 | self._make_dense_layers(block, inner_channels, nblocks[len(nblocks) - 1])) 99 | inner_channels += growth_rate * nblocks[len(nblocks) - 1] 100 | self.features.add_module('bn', nn.BatchNorm2d(inner_channels)) 101 | self.features.add_module('relu', nn.ReLU(inplace=True)) 102 | 103 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 104 | 105 | self.linear = nn.Linear(inner_channels, num_classes) 106 | 107 | def forward(self, x): 108 | output = self.conv1(x) 109 | output = self.features(output) 110 | output = self.avgpool(output) 111 | output = output.view(output.size()[0], -1) 112 | output = self.linear(output) 113 | return output 114 | 115 | def _make_dense_layers(self, block, in_channels, nblocks): 116 | dense_block = nn.Sequential() 117 | for index in range(nblocks): 118 | dense_block.add_module('bottle_neck_layer_{}'.format(index), block(in_channels, self.growth_rate)) 119 | in_channels += self.growth_rate 120 | return dense_block 121 | 122 | 123 | def densenet121(in_channels=3, num_classes=10): 124 | return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32, in_channels=in_channels, num_classes=num_classes) 125 | 126 | 127 | def densenet169(): 128 | return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32) 129 | 130 | 131 | def densenet201(): 132 | return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32) 133 | 134 | 135 | def densenet161(): 136 | return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48) 137 | 138 | -------------------------------------------------------------------------------- /models/senet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BasicResidualSEBlock(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, in_channels, out_channels, stride, r=16): 10 | super().__init__() 11 | 12 | self.residual = nn.Sequential( 13 | nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1), 14 | nn.BatchNorm2d(out_channels), 15 | nn.ReLU(inplace=True), 16 | 17 | nn.Conv2d(out_channels, out_channels * self.expansion, 3, padding=1), 18 | nn.BatchNorm2d(out_channels * self.expansion), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_channels != out_channels * self.expansion: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride=stride), 26 | nn.BatchNorm2d(out_channels * self.expansion) 27 | ) 28 | 29 | self.squeeze = nn.AdaptiveAvgPool2d(1) 30 | self.excitation = nn.Sequential( 31 | nn.Linear(out_channels * self.expansion, out_channels * self.expansion // r), 32 | nn.ReLU(inplace=True), 33 | nn.Linear(out_channels * self.expansion // r, out_channels * self.expansion), 34 | nn.Sigmoid() 35 | ) 36 | 37 | def forward(self, x): 38 | shortcut = self.shortcut(x) 39 | residual = self.residual(x) 40 | 41 | squeeze = self.squeeze(residual) 42 | squeeze = squeeze.view(squeeze.size(0), -1) 43 | excitation = self.excitation(squeeze) 44 | excitation = excitation.view(residual.size(0), residual.size(1), 1, 1) 45 | 46 | x = residual * excitation.expand_as(residual) + shortcut 47 | 48 | return F.relu(x) 49 | 50 | 51 | class 
BottleneckResidualSEBlock(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, in_channels, out_channels, stride, r=16): 55 | super().__init__() 56 | 57 | self.residual = nn.Sequential( 58 | nn.Conv2d(in_channels, out_channels, 1), 59 | nn.BatchNorm2d(out_channels), 60 | nn.ReLU(inplace=True), 61 | 62 | nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1), 63 | nn.BatchNorm2d(out_channels), 64 | nn.ReLU(inplace=True), 65 | 66 | nn.Conv2d(out_channels, out_channels * self.expansion, 1), 67 | nn.BatchNorm2d(out_channels * self.expansion), 68 | nn.ReLU(inplace=True) 69 | ) 70 | 71 | self.squeeze = nn.AdaptiveAvgPool2d(1) 72 | self.excitation = nn.Sequential( 73 | nn.Linear(out_channels * self.expansion, out_channels * self.expansion // r), 74 | nn.ReLU(inplace=True), 75 | nn.Linear(out_channels * self.expansion // r, out_channels * self.expansion), 76 | nn.Sigmoid() 77 | ) 78 | 79 | self.shortcut = nn.Sequential() 80 | if stride != 1 or in_channels != out_channels * self.expansion: 81 | self.shortcut = nn.Sequential( 82 | nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride=stride), 83 | nn.BatchNorm2d(out_channels * self.expansion) 84 | ) 85 | 86 | def forward(self, x): 87 | shortcut = self.shortcut(x) 88 | 89 | residual = self.residual(x) 90 | squeeze = self.squeeze(residual) 91 | squeeze = squeeze.view(squeeze.size(0), -1) 92 | excitation = self.excitation(squeeze) 93 | excitation = excitation.view(residual.size(0), residual.size(1), 1, 1) 94 | 95 | x = residual * excitation.expand_as(residual) + shortcut 96 | 97 | return F.relu(x) 98 | 99 | 100 | class SEResNet(nn.Module): 101 | 102 | def __init__(self, block, block_num, in_channels=3, num_classes=10): 103 | super().__init__() 104 | 105 | self.in_channels = 64 106 | 107 | self.pre = nn.Sequential( 108 | nn.Conv2d(in_channels, 64, 3, padding=1), 109 | nn.BatchNorm2d(64), 110 | nn.ReLU(inplace=True) 111 | ) 112 | 113 | self.stage1 = self._make_stage(block, block_num[0], 64, 1) 114 | self.stage2 = self._make_stage(block, block_num[1], 128, 2) 115 | self.stage3 = self._make_stage(block, block_num[2], 256, 2) 116 | self.stage4 = self._make_stage(block, block_num[3], 512, 2) 117 | 118 | self.linear = nn.Linear(self.in_channels, num_classes) 119 | 120 | def forward(self, x): 121 | x = self.pre(x) 122 | 123 | x = self.stage1(x) 124 | x = self.stage2(x) 125 | x = self.stage3(x) 126 | x = self.stage4(x) 127 | 128 | x = F.adaptive_avg_pool2d(x, 1) 129 | x = x.view(x.size(0), -1) 130 | 131 | x = self.linear(x) 132 | 133 | return x 134 | 135 | def _make_stage(self, block, num, out_channels, stride): 136 | layers = [] 137 | layers.append(block(self.in_channels, out_channels, stride)) 138 | self.in_channels = out_channels * block.expansion 139 | 140 | while num - 1: 141 | layers.append(block(self.in_channels, out_channels, 1)) 142 | num -= 1 143 | 144 | return nn.Sequential(*layers) 145 | 146 | 147 | def seresnet18(): 148 | return SEResNet(BasicResidualSEBlock, [2, 2, 2, 2]) 149 | 150 | 151 | def seresnet34(in_channels=3, num_classes=10): 152 | return SEResNet(BasicResidualSEBlock, [3, 4, 6, 3], in_channels=in_channels, num_classes=num_classes) 153 | 154 | 155 | def seresnet50(): 156 | return SEResNet(BottleneckResidualSEBlock, [3, 4, 6, 3]) 157 | 158 | 159 | def seresnet101(): 160 | return SEResNet(BottleneckResidualSEBlock, [3, 4, 23, 3]) 161 | 162 | 163 | def seresnet152(): 164 | return SEResNet(BottleneckResidualSEBlock, [3, 8, 36, 3]) 165 | 166 | 
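# Annotation (usage sketch, not part of senet.py): the factories above build
# CIFAR-style SE-ResNets (stride-1 stem plus adaptive pooling), so a quick
# 32x32 smoke test works:
#   model = seresnet34(in_channels=3, num_classes=10)
#   logits = model(torch.randn(2, 3, 32, 32))  # -> torch.Size([2, 10])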
-------------------------------------------------------------------------------- /core/image_sift.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import os 6 | import torch 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | import loaders 11 | import models 12 | from utils import file_util 13 | 14 | 15 | class ImageSift:  # collects the num_images most (or, for is_high_confidence=False, least) confident samples per class 16 | def __init__(self, num_classes, num_images, is_high_confidence=True): 17 | self.names = [[None for j in range(num_images)] for i in range(num_classes)] 18 | self.scores = torch.zeros((num_classes, num_images)) 19 | self.nums = torch.zeros(num_classes, dtype=torch.long) 20 | self.is_high_confidence = is_high_confidence 21 | 22 | def __call__(self, outputs, labels, names): 23 | softmax = torch.nn.Softmax(dim=1)(outputs.detach()) 24 | scores, predicts = torch.max(softmax, dim=1) 25 | # print(scores) 26 | 27 | if self.is_high_confidence: 28 | for i, label in enumerate(labels): 29 | if label == predicts[i]: 30 | if self.nums[label] == self.scores.shape[1]: 31 | score_min, index = torch.min(self.scores[label], dim=0) 32 | if scores[i] > score_min: 33 | self.scores[label][index] = scores[i] 34 | self.names[label.item()][index.item()] = names[i] 35 | else: 36 | self.scores[label][self.nums[label]] = scores[i] 37 | self.names[label.item()][self.nums[label].item()] = names[i] 38 | self.nums[label] += 1 39 | else: 40 | for i, label in enumerate(labels): 41 | if self.nums[label] == self.scores.shape[1]: 42 | score_max, index = torch.max(self.scores[label], dim=0) 43 | if label == predicts[i]: # TP-LS: true prediction with low score 44 | if scores[i] < score_max: 45 | self.scores[label][index] = scores[i] 46 | self.names[label.item()][index.item()] = names[i] 47 | else: # TN-HS: false prediction with high score (stored negated) 48 | if -scores[i] < score_max: 49 | self.scores[label][index] = -scores[i] 50 | self.names[label.item()][index.item()] = names[i] 51 | else: 52 | if label == predicts[i]: # TP-LS: true prediction with low score 53 | self.scores[label][self.nums[label]] = scores[i] 54 | self.names[label.item()][self.nums[label].item()] = names[i] 55 | self.nums[label] += 1 56 | else: # TN-HS: false prediction with high score (stored negated) 57 | self.scores[label][self.nums[label]] = -scores[i] 58 | self.names[label.item()][self.nums[label].item()] = names[i] 59 | self.nums[label] += 1 60 | 61 | def save_image(self, input_path, output_path): 62 | print(self.scores) 63 | print(self.nums) 64 | 65 | class_names = sorted([d.name for d in os.scandir(input_path) if d.is_dir()]) 66 | 67 | for label, image_list in enumerate(self.names): 68 | for image in tqdm(image_list): 69 | if image is None: continue  # skip slots that were never filled (fewer qualifying samples than num_images) 70 | class_name = class_names[label] 71 | src_path = os.path.join(input_path, class_name, str(image)) 72 | dst_path = os.path.join(output_path, class_name, str(image)) 73 | file_util.copy_file(src_path, dst_path) 74 | 75 | 76 | def main(): 77 | parser = argparse.ArgumentParser(description='') 78 | parser.add_argument('--model_name', default='', type=str, help='model name') 79 | parser.add_argument('--data_name', default='', type=str, help='data name') 80 | parser.add_argument('--in_channels', default='', type=int, help='in channels') 81 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 82 | parser.add_argument('--model_path', default='', type=str, help='model path') 83 | parser.add_argument('--data_path', default='', type=str, help='data path') 84 | parser.add_argument('--image_path', default='', type=str, help='image path') 85 | parser.add_argument('--num_images', default=10, type=int, help='num images') 86 | parser.add_argument('--device_index', 
default='0', type=str, help='device index') 87 | args = parser.parse_args() 88 | 89 | # ---------------------------------------- 90 | # basic configuration 91 | # ---------------------------------------- 92 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index 93 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 94 | 95 | if not os.path.exists(args.image_path): 96 | os.makedirs(args.image_path) 97 | 98 | print('-' * 50) 99 | print('RUN ON:', device) 100 | print('MODEL PATH:', args.model_path) 101 | print('DATA PATH:', args.data_path) 102 | print('RESULT PATH:', args.image_path) 103 | print('-' * 50) 104 | 105 | # ---------------------------------------- 106 | # model/data configuration 107 | # ---------------------------------------- 108 | model = models.load_model(model_name=args.model_name, num_classes=args.num_classes) 109 | model.load_state_dict(torch.load(args.model_path)) 110 | model.to(device) 111 | model.eval() 112 | 113 | data_loader = loaders.load_data(args.data_name, args.data_path, data_type='test') 114 | 115 | image_sift = ImageSift(num_classes=args.num_classes, num_images=args.num_images, is_high_confidence=True) 116 | 117 | # ---------------------------------------- 118 | # forward 119 | # ---------------------------------------- 120 | for samples in tqdm(data_loader): 121 | inputs, labels, names = samples 122 | inputs = inputs.to(device) 123 | labels = labels.to(device) 124 | outputs = model(inputs) 125 | 126 | image_sift(outputs=outputs, labels=labels, names=names) 127 | 128 | image_sift.save_image(args.data_path, args.image_path) 129 | 130 | 131 | if __name__ == '__main__': 132 | main() 133 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BasicBlock(nn.Module): 5 | """Basic Block for resnet 18 and resnet 34 6 | """ 7 | 8 | # BasicBlock and BottleNeck blocks 9 | # have different output sizes; 10 | # we use the class attribute expansion 11 | # to distinguish them 12 | expansion = 1 13 | 14 | def __init__(self, in_channels, out_channels, stride=1): 15 | super().__init__() 16 | 17 | # residual function 18 | self.residual_function = nn.Sequential( 19 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), 23 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 24 | ) 25 | 26 | # shortcut 27 | self.shortcut = nn.Sequential() 28 | 29 | # when the shortcut output dimension does not match the residual function's, 30 | # use a 1*1 convolution to match the dimension 31 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 32 | self.shortcut = nn.Sequential( 33 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), 34 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 35 | ) 36 | 37 | def forward(self, x): 38 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 39 | 40 | 41 | class BottleNeck(nn.Module): 42 | """Residual block for resnet over 50 layers 43 | """ 44 | expansion = 4 45 | 46 | def __init__(self, in_channels, out_channels, stride=1): 47 | super().__init__() 48 | self.residual_function = nn.Sequential( 49 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 50 | 
nn.BatchNorm2d(out_channels), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), 53 | nn.BatchNorm2d(out_channels), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), 56 | nn.BatchNorm2d(out_channels * BottleNeck.expansion), 57 | ) 58 | 59 | self.shortcut = nn.Sequential() 60 | 61 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion: 62 | self.shortcut = nn.Sequential( 63 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), 64 | nn.BatchNorm2d(out_channels * BottleNeck.expansion) 65 | ) 66 | 67 | def forward(self, x): 68 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 69 | 70 | 71 | class ResNet(nn.Module): 72 | 73 | def __init__(self, block, num_block, in_channels=3, num_classes=10): 74 | super().__init__() 75 | 76 | self.in_channels = 64 77 | 78 | self.conv1 = nn.Sequential( 79 | nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False), 80 | nn.BatchNorm2d(64), 81 | nn.ReLU(inplace=True)) 82 | # we use a different input size than the original paper 83 | # so conv2_x's stride is 1 84 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1) 85 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 86 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 87 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2) 88 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 89 | self.fc = nn.Linear(512 * block.expansion, num_classes) 90 | 91 | def _make_layer(self, block, out_channels, num_blocks, stride): 92 | """make resnet layers (by "layer" we do not mean a single 93 | neural network layer such as a 
conv layer), one layer may 94 | contain more than one residual block 95 | Args: 96 | block: block type, basic block or bottle neck block 97 | out_channels: output depth channel number of this layer 98 | num_blocks: how many blocks per layer 99 | stride: the stride of the first block of this layer 100 | Return: 101 | return a resnet layer 102 | """ 103 | 104 | # we have num_blocks blocks per layer; the stride of the first 105 | # block can be 1 or 2, all following blocks use stride 1 106 | strides = [stride] + [1] * (num_blocks - 1) 107 | layers = [] 108 | for stride in strides: 109 | layers.append(block(self.in_channels, out_channels, stride)) 110 | self.in_channels = out_channels * block.expansion 111 | 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x): 115 | output = self.conv1(x) 116 | output = self.conv2_x(output) 117 | output = self.conv3_x(output) 118 | output = self.conv4_x(output) 119 | output = self.conv5_x(output) 120 | output = self.avg_pool(output) 121 | output = output.view(output.size(0), -1) 122 | output = self.fc(output) 123 | 124 | return output 125 | 126 | 127 | def resnet18(): 128 | """ return a ResNet 18 object 129 | """ 130 | return ResNet(BasicBlock, [2, 2, 2, 2]) 131 | 132 | 133 | def resnet34(in_channels=3, num_classes=10): 134 | """ return a ResNet 34 object 135 | """ 136 | return ResNet(BasicBlock, [3, 4, 6, 3], in_channels=in_channels, num_classes=num_classes) 137 | 138 | 139 | def resnet50(in_channels=3, num_classes=10): 140 | """ return a ResNet 50 object 141 | """ 142 | return ResNet(BottleNeck, [3, 4, 6, 3], in_channels=in_channels, num_classes=num_classes) 143 | 144 | 145 | def resnet101(): 146 | """ return a ResNet 101 object 147 | """ 148 | return ResNet(BottleNeck, [3, 4, 23, 3]) 149 | 150 | 151 | def resnet152(): 152 | """ return a ResNet 152 object 153 | """ 154 | return ResNet(BottleNeck, [3, 8, 36, 3]) 155 | 156 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import shutil 5 | from tqdm import tqdm 6 | 7 | import torch 8 | from torch import nn 9 | from torch import optim 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | import loaders 13 | import models 14 | import metrics 15 | from utils.train_util import AverageMeter, ProgressMeter 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser(description='') 20 | parser.add_argument('--model_name', default='', type=str, help='model name') 21 | parser.add_argument('--data_name', default='', type=str, help='data name') 22 | parser.add_argument('--in_channels', default='', type=int, help='in channels') 23 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 24 | parser.add_argument('--num_epochs', default=200, type=int, help='num epochs') 25 | parser.add_argument('--model_dir', default='', type=str, help='model dir') 26 | parser.add_argument('--data_dir', default='', type=str, help='data dir') 27 | parser.add_argument('--log_dir', default='', type=str, help='log dir') 28 | parser.add_argument('--device_index', default='0', type=str, help='device index') 29 | args = parser.parse_args() 30 | 31 | # ---------------------------------------- 32 | # basic configuration 33 | # ---------------------------------------- 34 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index 35 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 36 | 37 | train_path = 
os.path.join(args.data_dir, 'train') 38 | test_path = os.path.join(args.data_dir, 'test') 39 | 40 | if not os.path.exists(args.model_dir): 41 | os.makedirs(args.model_dir) 42 | if os.path.exists(args.log_dir): 43 | shutil.rmtree(args.log_dir) 44 | 45 | print('-' * 50) 46 | print('TRAIN ON:', device) 47 | print('MODEL DIR:', args.model_dir) 48 | print('LOG DIR:', args.log_dir) 49 | print('-' * 50) 50 | 51 | # ---------------------------------------- 52 | # trainer configuration 53 | # ---------------------------------------- 54 | model = models.load_model(args.model_name, num_classes=args.num_classes) 55 | model.to(device) 56 | 57 | train_loader = loaders.load_data(args.data_name, train_path, data_type='train') 58 | test_loader = loaders.load_data(args.data_name, test_path, data_type='test') 59 | 60 | criterion = nn.CrossEntropyLoss() 61 | optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) 62 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.num_epochs) 63 | 64 | writer = SummaryWriter(args.log_dir) 65 | 66 | # ---------------------------------------- 67 | # each epoch 68 | # ---------------------------------------- 69 | since = time.time() 70 | 71 | best_acc = None 72 | best_epoch = None 73 | 74 | for epoch in tqdm(range(args.num_epochs)): 75 | loss, acc1, acc5 = train(train_loader, model, criterion, optimizer, device) 76 | writer.add_scalar(tag='training loss', scalar_value=loss.avg, global_step=epoch) 77 | writer.add_scalar(tag='training acc1', scalar_value=acc1.avg, global_step=epoch) 78 | loss, acc1, acc5 = test(test_loader, model, criterion, device) 79 | writer.add_scalar(tag='test loss', scalar_value=loss.avg, global_step=epoch) 80 | writer.add_scalar(tag='test acc1', scalar_value=acc1.avg, global_step=epoch) 81 | 82 | # ---------------------------------------- 83 | # save best model 84 | # ---------------------------------------- 85 | if best_acc is None or best_acc < acc1.avg: 86 | best_acc = acc1.avg 87 | best_epoch = epoch 88 | torch.save(model.state_dict(), os.path.join(args.model_dir, 'model_ori.pth')) 89 | 90 | scheduler.step() 91 | 92 | print('COMPLETE !!!') 93 | print('BEST ACC', best_acc) 94 | print('BEST EPOCH', best_epoch) 95 | print('TIME CONSUMED', time.time() - since) 96 | 97 | 98 | def train(train_loader, model, criterion, optimizer, device): 99 | loss_meter = AverageMeter('Loss', ':.4e') 100 | acc1_meter = AverageMeter('Acc@1', ':6.2f') 101 | acc5_meter = AverageMeter('Acc@5', ':6.2f') 102 | progress = ProgressMeter(total=len(train_loader), step=20, prefix='Training', 103 | meters=[loss_meter, acc1_meter, acc5_meter]) 104 | 105 | model.train() 106 | 107 | for i, samples in enumerate(train_loader): 108 | inputs, labels, _ = samples 109 | inputs = inputs.to(device) 110 | labels = labels.to(device) 111 | 112 | outputs = model(inputs) 113 | loss = criterion(outputs, labels) 114 | acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5)) 115 | 116 | loss_meter.update(loss.item(), inputs.size(0)) 117 | acc1_meter.update(acc1.item(), inputs.size(0)) 118 | acc5_meter.update(acc5.item(), inputs.size(0)) 119 | 120 | optimizer.zero_grad() # 1 121 | loss.backward() # 2 122 | optimizer.step() # 3 123 | 124 | progress.display(i) 125 | 126 | return loss_meter, acc1_meter, acc5_meter 127 | 128 | 129 | def test(test_loader, model, criterion, device): 130 | loss_meter = AverageMeter('Loss', ':.4e') 131 | acc1_meter = AverageMeter('Acc@1', ':6.2f') 132 | acc5_meter = AverageMeter('Acc@5', ':6.2f') 133 | progress 
= ProgressMeter(total=len(test_loader), step=20, prefix='Test', 134 | meters=[loss_meter, acc1_meter, acc5_meter]) 135 | model.eval() 136 | 137 | for i, samples in enumerate(test_loader): 138 | inputs, labels, _ = samples 139 | inputs = inputs.to(device) 140 | labels = labels.to(device) 141 | 142 | with torch.set_grad_enabled(False): 143 | outputs = model(inputs) 144 | loss = criterion(outputs, labels) 145 | acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5)) 146 | 147 | loss_meter.update(loss.item(), inputs.size(0)) 148 | acc1_meter.update(acc1.item(), inputs.size(0)) 149 | acc5_meter.update(acc5.item(), inputs.size(0)) 150 | 151 | progress.display(i) 152 | 153 | return loss_meter, acc1_meter, acc5_meter 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | """mobilenet in pytorch 2 | 3 | 4 | 5 | [1] Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam 6 | 7 | MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications 8 | https://arxiv.org/abs/1704.04861 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class DepthSeperabelConv2d(nn.Module): 16 | 17 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 18 | super().__init__() 19 | self.depthwise = nn.Sequential( 20 | nn.Conv2d( 21 | input_channels, 22 | input_channels, 23 | kernel_size, 24 | groups=input_channels, 25 | **kwargs), 26 | nn.BatchNorm2d(input_channels), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | self.pointwise = nn.Sequential( 31 | nn.Conv2d(input_channels, output_channels, 1), 32 | nn.BatchNorm2d(output_channels), 33 | nn.ReLU(inplace=True) 34 | ) 35 | 36 | def forward(self, x): 37 | x = self.depthwise(x) 38 | x = self.pointwise(x) 39 | 40 | return x 41 | 42 | 43 | class BasicConv2d(nn.Module): 44 | 45 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 46 | super().__init__() 47 | self.conv = nn.Conv2d( 48 | input_channels, output_channels, kernel_size, **kwargs) 49 | self.bn = nn.BatchNorm2d(output_channels) 50 | self.relu = nn.ReLU(inplace=True) 51 | 52 | def forward(self, x): 53 | x = self.conv(x) 54 | x = self.bn(x) 55 | x = self.relu(x) 56 | 57 | return x 58 | 59 | 60 | class MobileNet(nn.Module): 61 | """ 62 | Args: 63 | width_multiplier: The role of the width multiplier α is to thin 64 | a network uniformly at each layer. For a given 65 | layer and width multiplier α, the number of 66 | input channels M becomes αM and the number of 67 | output channels N becomes αN. 
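For example (annotation, not in the original docstring): with α = 0.5 the 32-channel stem below becomes int(32 * 0.5) = 16 channels, and the classifier input shrinks from 1024 to int(1024 * 0.5) = 512 features.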
68 | """ 69 | 70 | def __init__(self, width_multiplier=1, class_num=100): 71 | super().__init__() 72 | 73 | alpha = width_multiplier 74 | self.stem = nn.Sequential( 75 | BasicConv2d(3, int(32 * alpha), 3, padding=1, bias=False), 76 | DepthSeperabelConv2d( 77 | int(32 * alpha), 78 | int(64 * alpha), 79 | 3, 80 | padding=1, 81 | bias=False 82 | ) 83 | ) 84 | 85 | # downsample 86 | self.conv1 = nn.Sequential( 87 | DepthSeperabelConv2d( 88 | int(64 * alpha), 89 | int(128 * alpha), 90 | 3, 91 | stride=2, 92 | padding=1, 93 | bias=False 94 | ), 95 | DepthSeperabelConv2d( 96 | int(128 * alpha), 97 | int(128 * alpha), 98 | 3, 99 | padding=1, 100 | bias=False 101 | ) 102 | ) 103 | 104 | # downsample 105 | self.conv2 = nn.Sequential( 106 | DepthSeperabelConv2d( 107 | int(128 * alpha), 108 | int(256 * alpha), 109 | 3, 110 | stride=2, 111 | padding=1, 112 | bias=False 113 | ), 114 | DepthSeperabelConv2d( 115 | int(256 * alpha), 116 | int(256 * alpha), 117 | 3, 118 | padding=1, 119 | bias=False 120 | ) 121 | ) 122 | 123 | # downsample 124 | self.conv3 = nn.Sequential( 125 | DepthSeperabelConv2d( 126 | int(256 * alpha), 127 | int(512 * alpha), 128 | 3, 129 | stride=2, 130 | padding=1, 131 | bias=False 132 | ), 133 | 134 | DepthSeperabelConv2d( 135 | int(512 * alpha), 136 | int(512 * alpha), 137 | 3, 138 | padding=1, 139 | bias=False 140 | ), 141 | DepthSeperabelConv2d( 142 | int(512 * alpha), 143 | int(512 * alpha), 144 | 3, 145 | padding=1, 146 | bias=False 147 | ), 148 | DepthSeperabelConv2d( 149 | int(512 * alpha), 150 | int(512 * alpha), 151 | 3, 152 | padding=1, 153 | bias=False 154 | ), 155 | DepthSeperabelConv2d( 156 | int(512 * alpha), 157 | int(512 * alpha), 158 | 3, 159 | padding=1, 160 | bias=False 161 | ), 162 | DepthSeperabelConv2d( 163 | int(512 * alpha), 164 | int(512 * alpha), 165 | 3, 166 | padding=1, 167 | bias=False 168 | ) 169 | ) 170 | 171 | # downsample 172 | self.conv4 = nn.Sequential( 173 | DepthSeperabelConv2d( 174 | int(512 * alpha), 175 | int(1024 * alpha), 176 | 3, 177 | stride=2, 178 | padding=1, 179 | bias=False 180 | ), 181 | DepthSeperabelConv2d( 182 | int(1024 * alpha), 183 | int(1024 * alpha), 184 | 3, 185 | padding=1, 186 | bias=False 187 | ) 188 | ) 189 | 190 | self.fc = nn.Linear(int(1024 * alpha), class_num) 191 | self.avg = nn.AdaptiveAvgPool2d(1) 192 | 193 | def forward(self, x): 194 | x = self.stem(x) 195 | 196 | x = self.conv1(x) 197 | x = self.conv2(x) 198 | x = self.conv3(x) 199 | x = self.conv4(x) 200 | 201 | x = self.avg(x) 202 | x = x.view(x.size(0), -1) 203 | x = self.fc(x) 204 | return x 205 | 206 | 207 | def mobilenet(alpha=1, class_num=100): 208 | return MobileNet(alpha, class_num) 209 | -------------------------------------------------------------------------------- /models/xception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SeperableConv2d(nn.Module): 6 | 7 | # ***Figure 4. 
An “extreme” version of our Inception module, 8 | # with one spatial convolution per output channel of the 1x1 9 | # convolution.""" 10 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 11 | super().__init__() 12 | self.depthwise = nn.Conv2d( 13 | input_channels, 14 | input_channels, 15 | kernel_size, 16 | groups=input_channels, 17 | bias=False, 18 | **kwargs 19 | ) 20 | 21 | self.pointwise = nn.Conv2d(input_channels, output_channels, 1, bias=False) 22 | 23 | def forward(self, x): 24 | x = self.depthwise(x) 25 | x = self.pointwise(x) 26 | 27 | return x 28 | 29 | 30 | class EntryFlow(nn.Module): 31 | 32 | def __init__(self, in_channels=3): 33 | super().__init__() 34 | self.conv1 = nn.Sequential( 35 | nn.Conv2d(in_channels, 32, 3, padding=1, bias=False), 36 | nn.BatchNorm2d(32), 37 | nn.ReLU(inplace=True) 38 | ) 39 | 40 | self.conv2 = nn.Sequential( 41 | nn.Conv2d(32, 64, 3, padding=1, bias=False), 42 | nn.BatchNorm2d(64), 43 | nn.ReLU(inplace=True) 44 | ) 45 | 46 | self.conv3_residual = nn.Sequential( 47 | SeperableConv2d(64, 128, 3, padding=1), 48 | nn.BatchNorm2d(128), 49 | nn.ReLU(inplace=True), 50 | SeperableConv2d(128, 128, 3, padding=1), 51 | nn.BatchNorm2d(128), 52 | nn.MaxPool2d(3, stride=2, padding=1), 53 | ) 54 | 55 | self.conv3_shortcut = nn.Sequential( 56 | nn.Conv2d(64, 128, 1, stride=2), 57 | nn.BatchNorm2d(128), 58 | ) 59 | 60 | self.conv4_residual = nn.Sequential( 61 | nn.ReLU(inplace=True), 62 | SeperableConv2d(128, 256, 3, padding=1), 63 | nn.BatchNorm2d(256), 64 | nn.ReLU(inplace=True), 65 | SeperableConv2d(256, 256, 3, padding=1), 66 | nn.BatchNorm2d(256), 67 | nn.MaxPool2d(3, stride=2, padding=1) 68 | ) 69 | 70 | self.conv4_shortcut = nn.Sequential( 71 | nn.Conv2d(128, 256, 1, stride=2), 72 | nn.BatchNorm2d(256), 73 | ) 74 | 75 | # no downsampling 76 | self.conv5_residual = nn.Sequential( 77 | nn.ReLU(inplace=True), 78 | SeperableConv2d(256, 728, 3, padding=1), 79 | nn.BatchNorm2d(728), 80 | nn.ReLU(inplace=True), 81 | SeperableConv2d(728, 728, 3, padding=1), 82 | nn.BatchNorm2d(728), 83 | nn.MaxPool2d(3, 1, padding=1) 84 | ) 85 | 86 | # no downsampling 87 | self.conv5_shortcut = nn.Sequential( 88 | nn.Conv2d(256, 728, 1), 89 | nn.BatchNorm2d(728) 90 | ) 91 | 92 | def forward(self, x): 93 | x = self.conv1(x) 94 | x = self.conv2(x) 95 | residual = self.conv3_residual(x) 96 | shortcut = self.conv3_shortcut(x) 97 | x = residual + shortcut 98 | residual = self.conv4_residual(x) 99 | shortcut = self.conv4_shortcut(x) 100 | x = residual + shortcut 101 | residual = self.conv5_residual(x) 102 | shortcut = self.conv5_shortcut(x) 103 | x = residual + shortcut 104 | 105 | return x 106 | 107 | 108 | class MiddleFLowBlock(nn.Module): 109 | 110 | def __init__(self): 111 | super().__init__() 112 | 113 | self.shortcut = nn.Sequential() 114 | self.conv1 = nn.Sequential( 115 | nn.ReLU(inplace=True), 116 | SeperableConv2d(728, 728, 3, padding=1), 117 | nn.BatchNorm2d(728) 118 | ) 119 | self.conv2 = nn.Sequential( 120 | nn.ReLU(inplace=True), 121 | SeperableConv2d(728, 728, 3, padding=1), 122 | nn.BatchNorm2d(728) 123 | ) 124 | self.conv3 = nn.Sequential( 125 | nn.ReLU(inplace=True), 126 | SeperableConv2d(728, 728, 3, padding=1), 127 | nn.BatchNorm2d(728) 128 | ) 129 | 130 | def forward(self, x): 131 | residual = self.conv1(x) 132 | residual = self.conv2(residual) 133 | residual = self.conv3(residual) 134 | 135 | shortcut = self.shortcut(x) 136 | 137 | return shortcut + residual 138 | 139 | 140 | class MiddleFlow(nn.Module): 141 | def __init__(self, block): 142 
| super().__init__() 143 | 144 | # """then through the middle flow which is repeated eight times""" 145 | self.middel_block = self._make_flow(block, 8) 146 | 147 | def forward(self, x): 148 | x = self.middel_block(x) 149 | return x 150 | 151 | def _make_flow(self, block, times): 152 | flows = [] 153 | for i in range(times): 154 | flows.append(block()) 155 | 156 | return nn.Sequential(*flows) 157 | 158 | 159 | class ExitFLow(nn.Module): 160 | 161 | def __init__(self): 162 | super().__init__() 163 | self.residual = nn.Sequential( 164 | nn.ReLU(), 165 | SeperableConv2d(728, 728, 3, padding=1), 166 | nn.BatchNorm2d(728), 167 | nn.ReLU(), 168 | SeperableConv2d(728, 1024, 3, padding=1), 169 | nn.BatchNorm2d(1024), 170 | nn.MaxPool2d(3, stride=2, padding=1) 171 | ) 172 | 173 | self.shortcut = nn.Sequential( 174 | nn.Conv2d(728, 1024, 1, stride=2), 175 | nn.BatchNorm2d(1024) 176 | ) 177 | 178 | self.conv = nn.Sequential( 179 | SeperableConv2d(1024, 1536, 3, padding=1), 180 | nn.BatchNorm2d(1536), 181 | nn.ReLU(inplace=True), 182 | SeperableConv2d(1536, 2048, 3, padding=1), 183 | nn.BatchNorm2d(2048), 184 | nn.ReLU(inplace=True) 185 | ) 186 | 187 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 188 | 189 | def forward(self, x): 190 | shortcut = self.shortcut(x) 191 | residual = self.residual(x) 192 | output = shortcut + residual 193 | output = self.conv(output) 194 | output = self.avgpool(output) 195 | 196 | return output 197 | 198 | 199 | class Xception(nn.Module): 200 | 201 | def __init__(self, block, in_channels=3, num_classes=10): 202 | super().__init__() 203 | self.entry_flow = EntryFlow(in_channels) 204 | self.middel_flow = MiddleFlow(block) 205 | self.exit_flow = ExitFLow() 206 | 207 | self.fc = nn.Linear(2048, num_classes) 208 | 209 | def forward(self, x): 210 | x = self.entry_flow(x) 211 | x = self.middel_flow(x) 212 | x = self.exit_flow(x) 213 | x = x.view(x.size(0), -1) 214 | x = self.fc(x) 215 | 216 | return x 217 | 218 | 219 | def xception(in_channels=3, num_classes=10): 220 | return Xception(MiddleFLowBlock, in_channels=in_channels, num_classes=num_classes) 221 | -------------------------------------------------------------------------------- /models/efficientnetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | __all__ = ['effnetv2_s', 'effnetv2_m', 'effnetv2_l', 'effnetv2_xl'] 6 | 7 | 8 | def _make_divisible(v, divisor, min_value=None): 9 | """ 10 | This function is taken from the original tf repo. 11 | It ensures that all layers have a channel number that is divisible by 8 12 | It can be seen here: 13 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 14 | :param v: 15 | :param divisor: 16 | :param min_value: 17 | :return: 18 | """ 19 | if min_value is None: 20 | min_value = divisor 21 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 22 | # Make sure that round down does not go down by more than 10%. 
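# Annotation (worked example, not in the original file): _make_divisible(10, 8) first rounds to 8, but 8 < 0.9 * 10, so the guard below bumps it to 16; _make_divisible(100, 8) rounds to 104 and the guard does not trigger.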
23 | if new_v < 0.9 * v: 24 | new_v += divisor 25 | return new_v 26 | 27 | 28 | # SiLU (Swish) activation function 29 | if hasattr(nn, 'SiLU'): 30 | SiLU = nn.SiLU 31 | else: 32 | # For compatibility with old PyTorch versions 33 | class SiLU(nn.Module): 34 | def forward(self, x): 35 | return x * torch.sigmoid(x) 36 | 37 | 38 | class SELayer(nn.Module): 39 | def __init__(self, inp, oup, reduction=4): 40 | super(SELayer, self).__init__() 41 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 42 | self.fc = nn.Sequential( 43 | nn.Linear(oup, _make_divisible(inp // reduction, 8)), 44 | SiLU(), 45 | nn.Linear(_make_divisible(inp // reduction, 8), oup), 46 | nn.Sigmoid() 47 | ) 48 | 49 | def forward(self, x): 50 | b, c, _, _ = x.size() 51 | y = self.avg_pool(x).view(b, c) 52 | y = self.fc(y).view(b, c, 1, 1) 53 | return x * y 54 | 55 | 56 | def conv_3x3_bn(inp, oup, stride): 57 | return nn.Sequential( 58 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 59 | nn.BatchNorm2d(oup), 60 | SiLU() 61 | ) 62 | 63 | 64 | def conv_1x1_bn(inp, oup): 65 | return nn.Sequential( 66 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 67 | nn.BatchNorm2d(oup), 68 | SiLU() 69 | ) 70 | 71 | 72 | class MBConv(nn.Module): 73 | def __init__(self, inp, oup, stride, expand_ratio, use_se): 74 | super(MBConv, self).__init__() 75 | assert stride in [1, 2] 76 | 77 | hidden_dim = round(inp * expand_ratio) 78 | self.identity = stride == 1 and inp == oup 79 | if use_se: 80 | self.conv = nn.Sequential( 81 | # pw 82 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 83 | nn.BatchNorm2d(hidden_dim), 84 | SiLU(), 85 | # dw 86 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 87 | nn.BatchNorm2d(hidden_dim), 88 | SiLU(), 89 | SELayer(inp, hidden_dim), 90 | # pw-linear 91 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 92 | nn.BatchNorm2d(oup), 93 | ) 94 | else: 95 | self.conv = nn.Sequential( 96 | # fused 97 | nn.Conv2d(inp, hidden_dim, 3, stride, 1, bias=False), 98 | nn.BatchNorm2d(hidden_dim), 99 | SiLU(), 100 | # pw-linear 101 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 102 | nn.BatchNorm2d(oup), 103 | ) 104 | 105 | def forward(self, x): 106 | if self.identity: 107 | return x + self.conv(x) 108 | else: 109 | return self.conv(x) 110 | 111 | 112 | class EffNetV2(nn.Module): 113 | def __init__(self, cfgs, in_channels=3, num_classes=1000, width_mult=1.): 114 | super(EffNetV2, self).__init__() 115 | self.cfgs = cfgs 116 | 117 | # building first layer 118 | input_channel = _make_divisible(24 * width_mult, 8) 119 | layers = [conv_3x3_bn(in_channels, input_channel, 2)] 120 | # building inverted residual blocks 121 | block = MBConv 122 | for t, c, n, s, use_se in self.cfgs: 123 | output_channel = _make_divisible(c * width_mult, 8) 124 | for i in range(n): 125 | layers.append(block(input_channel, output_channel, s if i == 0 else 1, t, use_se)) 126 | input_channel = output_channel 127 | self.features = nn.Sequential(*layers) 128 | # building last several layers 129 | output_channel = _make_divisible(1792 * width_mult, 8) if width_mult > 1.0 else 1792 130 | self.conv = conv_1x1_bn(input_channel, output_channel) 131 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 132 | self.classifier = nn.Linear(output_channel, num_classes) 133 | 134 | self._initialize_weights() 135 | 136 | def forward(self, x): 137 | x = self.features(x) 138 | x = self.conv(x) 139 | x = self.avgpool(x) 140 | x = x.view(x.size(0), -1) 141 | x = self.classifier(x) 142 | return x 143 | 144 | def _initialize_weights(self): 145 | for m in self.modules(): 
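# Annotation (not in the original file): conv weights get He-style init, N(0, sqrt(2 / fan_out)); batch-norm layers are set to identity (weight 1, bias 0); linear layers use a small N(0, 0.001) with zero bias.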
146 | if isinstance(m, nn.Conv2d): 147 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 148 | m.weight.data.normal_(0, math.sqrt(2. / n)) 149 | if m.bias is not None: 150 | m.bias.data.zero_() 151 | elif isinstance(m, nn.BatchNorm2d): 152 | m.weight.data.fill_(1) 153 | m.bias.data.zero_() 154 | elif isinstance(m, nn.Linear): 155 | m.weight.data.normal_(0, 0.001) 156 | m.bias.data.zero_() 157 | 158 | 159 | def effnetv2_s(in_channels=3, num_classes=10): 160 | """ 161 | Constructs a EfficientNetV2-S model 162 | """ 163 | cfgs = [ 164 | # t, c, n, s, SE 165 | [1, 24, 2, 1, 0], 166 | [4, 48, 4, 2, 0], 167 | [4, 64, 4, 2, 0], 168 | [4, 128, 6, 2, 1], 169 | [6, 160, 9, 1, 1], 170 | [6, 256, 15, 2, 1], 171 | ] 172 | return EffNetV2(cfgs, in_channels=in_channels, num_classes=num_classes) 173 | 174 | 175 | def effnetv2_m(**kwargs): 176 | """ 177 | Constructs a EfficientNetV2-M model 178 | """ 179 | cfgs = [ 180 | # t, c, n, s, SE 181 | [1, 24, 3, 1, 0], 182 | [4, 48, 5, 2, 0], 183 | [4, 80, 5, 2, 0], 184 | [4, 160, 7, 2, 1], 185 | [6, 176, 14, 1, 1], 186 | [6, 304, 18, 2, 1], 187 | [6, 512, 5, 1, 1], 188 | ] 189 | return EffNetV2(cfgs, **kwargs) 190 | 191 | 192 | def effnetv2_l(in_channels=3, num_classes=10): 193 | """ 194 | Constructs a EfficientNetV2-L model 195 | """ 196 | cfgs = [ 197 | # t, c, n, s, SE 198 | [1, 32, 4, 1, 0], 199 | [4, 64, 7, 2, 0], 200 | [4, 96, 7, 2, 0], 201 | [4, 192, 10, 2, 1], 202 | [6, 224, 19, 1, 1], 203 | [6, 384, 25, 2, 1], 204 | [6, 640, 7, 1, 1], 205 | ] 206 | return EffNetV2(cfgs, in_channels=in_channels, num_classes=num_classes) 207 | 208 | 209 | def effnetv2_xl(**kwargs): 210 | """ 211 | Constructs a EfficientNetV2-XL model 212 | """ 213 | cfgs = [ 214 | # t, c, n, s, SE 215 | [1, 32, 4, 1, 0], 216 | [4, 64, 8, 2, 0], 217 | [4, 96, 8, 2, 0], 218 | [4, 192, 16, 2, 1], 219 | [6, 256, 24, 1, 1], 220 | [6, 512, 32, 2, 1], 221 | [6, 640, 8, 1, 1], 222 | ] 223 | return EffNetV2(cfgs, **kwargs) 224 | -------------------------------------------------------------------------------- /models/shufflenet.py: -------------------------------------------------------------------------------- 1 | """shufflenet in pytorch 2 | 3 | 4 | 5 | [1] Xiangyu Zhang, Xinyu Zhou, Mengxiao Lin, Jian Sun. 
6 | 7 | ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices 8 | https://arxiv.org/abs/1707.01083v2 9 | """ 10 | 11 | from functools import partial 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | class BasicConv2d(nn.Module): 18 | 19 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 20 | super().__init__() 21 | self.conv = nn.Conv2d(input_channels, output_channels, kernel_size, **kwargs) 22 | self.bn = nn.BatchNorm2d(output_channels) 23 | self.relu = nn.ReLU(inplace=True) 24 | 25 | def forward(self, x): 26 | x = self.conv(x) 27 | x = self.bn(x) 28 | x = self.relu(x) 29 | return x 30 | 31 | 32 | class ChannelShuffle(nn.Module): 33 | 34 | def __init__(self, groups): 35 | super().__init__() 36 | self.groups = groups 37 | 38 | def forward(self, x): 39 | batchsize, channels, height, width = x.data.size() 40 | channels_per_group = int(channels / self.groups) 41 | 42 | # """suppose a convolutional layer with g groups whose output has 43 | # g x n channels; we first reshape the output channel dimension 44 | # into (g, n)""" 45 | x = x.view(batchsize, self.groups, channels_per_group, height, width) 46 | 47 | # """transposing and then flattening it back as the input of next layer.""" 48 | x = x.transpose(1, 2).contiguous() 49 | x = x.view(batchsize, -1, height, width) 50 | 51 | return x 52 | 53 | 54 | class DepthwiseConv2d(nn.Module): 55 | 56 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 57 | super().__init__() 58 | self.depthwise = nn.Sequential( 59 | nn.Conv2d(input_channels, output_channels, kernel_size, **kwargs), 60 | nn.BatchNorm2d(output_channels) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.depthwise(x) 65 | 66 | 67 | class PointwiseConv2d(nn.Module): 68 | def __init__(self, input_channels, output_channels, **kwargs): 69 | super().__init__() 70 | self.pointwise = nn.Sequential( 71 | nn.Conv2d(input_channels, output_channels, 1, **kwargs), 72 | nn.BatchNorm2d(output_channels) 73 | ) 74 | 75 | def forward(self, x): 76 | return self.pointwise(x) 77 | 78 | 79 | class ShuffleNetUnit(nn.Module): 80 | 81 | def __init__(self, input_channels, output_channels, stage, stride, groups): 82 | super().__init__() 83 | 84 | # """Similar to [9], we set the number of bottleneck channels to 1/4 85 | # of the output channels for each ShuffleNet unit.""" 86 | self.bottlneck = nn.Sequential( 87 | PointwiseConv2d( 88 | input_channels, 89 | int(output_channels / 4), 90 | groups=groups 91 | ), 92 | nn.ReLU(inplace=True) 93 | ) 94 | 95 | # """Note that for Stage 2, we do not apply group convolution on the first pointwise 96 | # layer because the number of input channels is relatively small.""" 97 | if stage == 2: 98 | self.bottlneck = nn.Sequential( 99 | PointwiseConv2d( 100 | input_channels, 101 | int(output_channels / 4), 102 | groups=groups 103 | ), 104 | nn.ReLU(inplace=True) 105 | ) 106 | 107 | self.channel_shuffle = ChannelShuffle(groups) 108 | 109 | self.depthwise = DepthwiseConv2d( 110 | int(output_channels / 4), 111 | int(output_channels / 4), 112 | 3, 113 | groups=int(output_channels / 4), 114 | stride=stride, 115 | padding=1 116 | ) 117 | 118 | self.expand = PointwiseConv2d( 119 | int(output_channels / 4), 120 | output_channels, 121 | groups=groups 122 | ) 123 | 124 | self.relu = nn.ReLU(inplace=True) 125 | self.fusion = self._add 126 | self.shortcut = nn.Sequential() 127 | 128 | # """As for the case where ShuffleNet is applied with stride, 129 | # we simply make two modifications (see Fig 2 
(c)): 130 | # (i) add a 3 × 3 average pooling on the shortcut path; 131 | # (ii) replace the element-wise addition with channel concatenation, 132 | # which makes it easy to enlarge channel dimension with little extra 133 | # computation cost.""" 134 | if stride != 1 or input_channels != output_channels: 135 | self.shortcut = nn.AvgPool2d(3, stride=2, padding=1) 136 | 137 | self.expand = PointwiseConv2d( 138 | int(output_channels / 4), 139 | output_channels - input_channels, 140 | groups=groups 141 | ) 142 | 143 | self.fusion = self._cat 144 | 145 | def _add(self, x, y): 146 | return torch.add(x, y) 147 | 148 | def _cat(self, x, y): 149 | return torch.cat([x, y], dim=1) 150 | 151 | def forward(self, x): 152 | shortcut = self.shortcut(x) 153 | 154 | shuffled = self.bottlneck(x) 155 | shuffled = self.channel_shuffle(shuffled) 156 | shuffled = self.depthwise(shuffled) 157 | shuffled = self.expand(shuffled) 158 | 159 | output = self.fusion(shortcut, shuffled) 160 | output = self.relu(output) 161 | 162 | return output 163 | 164 | 165 | class ShuffleNet(nn.Module): 166 | 167 | def __init__(self, num_blocks, num_classes=10, groups=3): 168 | super().__init__() 169 | 170 | if groups == 1: 171 | out_channels = [24, 144, 288, 567] 172 | elif groups == 2: 173 | out_channels = [24, 200, 400, 800] 174 | elif groups == 3: 175 | out_channels = [24, 240, 480, 960] 176 | elif groups == 4: 177 | out_channels = [24, 272, 544, 1088] 178 | elif groups == 8: 179 | out_channels = [24, 384, 768, 1536] 180 | 181 | self.conv1 = BasicConv2d(3, out_channels[0], 3, padding=1, stride=1) 182 | self.input_channels = out_channels[0] 183 | 184 | self.stage2 = self._make_stage( 185 | ShuffleNetUnit, 186 | num_blocks[0], 187 | out_channels[1], 188 | stride=2, 189 | stage=2, 190 | groups=groups 191 | ) 192 | 193 | self.stage3 = self._make_stage( 194 | ShuffleNetUnit, 195 | num_blocks[1], 196 | out_channels[2], 197 | stride=2, 198 | stage=3, 199 | groups=groups 200 | ) 201 | 202 | self.stage4 = self._make_stage( 203 | ShuffleNetUnit, 204 | num_blocks[2], 205 | out_channels[3], 206 | stride=2, 207 | stage=4, 208 | groups=groups 209 | ) 210 | 211 | self.avg = nn.AdaptiveAvgPool2d((1, 1)) 212 | self.fc = nn.Linear(out_channels[3], num_classes) 213 | 214 | def forward(self, x): 215 | x = self.conv1(x) 216 | x = self.stage2(x) 217 | x = self.stage3(x) 218 | x = self.stage4(x) 219 | x = self.avg(x) 220 | x = x.view(x.size(0), -1) 221 | x = self.fc(x) 222 | 223 | return x 224 | 225 | def _make_stage(self, block, num_blocks, output_channels, stride, stage, groups): 226 | """make shufflenet stage 227 | 228 | Args: 229 | block: block type, shuffle unit 230 | output_channels: output depth channel number of this stage 231 | num_blocks: how many blocks per stage 232 | stride: the stride of the first block of this stage 233 | stage: stage index 234 | groups: group number of group convolution 235 | Return: 236 | return a shuffle net stage 237 | """ 238 | strides = [stride] + [1] * (num_blocks - 1) 239 | 240 | layers = []  # do not name this list 'stage': the stage index must be passed through to each block 241 | 242 | for stride in strides: 243 | layers.append( 244 | block( 245 | self.input_channels, 246 | output_channels, 247 | stride=stride, 248 | stage=stage, 249 | groups=groups 250 | ) 251 | ) 252 | self.input_channels = output_channels 253 | 254 | return nn.Sequential(*layers) 255 | 256 | 257 | def shufflenet(): 258 | return ShuffleNet([4, 8, 4]) 259 | -------------------------------------------------------------------------------- /train_model_doctor.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | from tqdm import tqdm 5 | 6 | import torch 7 | from torch import nn 8 | from torch import optim 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | import loaders 12 | import models 13 | import metrics 14 | from utils.train_util import AverageMeter, ProgressMeter 15 | from core.grad_constraint import GradConstraint 16 | from core.grad_noise import GradNoise 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser(description='') 21 | parser.add_argument('--model_name', default='', type=str, help='model name') 22 | parser.add_argument('--data_name', default='', type=str, help='data name') 23 | parser.add_argument('--in_channels', default='', type=int, help='in channels') 24 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 25 | parser.add_argument('--num_epochs', default=200, type=int, help='num epochs') 26 | parser.add_argument('--ori_model_path', default='', type=str, help='original model path') 27 | parser.add_argument('--res_model_path', default='', type=str, help='result model path') 28 | parser.add_argument('--data_dir', default='', type=str, help='data dir') 29 | parser.add_argument('--log_dir', default='', type=str, help='log dir') 30 | parser.add_argument('--mask_dir', default=None, type=str, help='mask dir') 31 | parser.add_argument('--grad_dir', default='', type=str, help='grad dir') 32 | parser.add_argument('--alpha', default=0, type=float, help='weight coefficient for channel loss') 33 | parser.add_argument('--beta', default=0, type=float, help='weight coefficient for spatial loss') 34 | parser.add_argument('--device_index', default='0', type=str, help='device index') 35 | args = parser.parse_args() 36 | 37 | # ---------------------------------------- 38 | # basic configuration 39 | # ---------------------------------------- 40 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index 41 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 42 | 43 | train_path = os.path.join(args.data_dir, 'train') 44 | test_path = os.path.join(args.data_dir, 'test') 45 | mask_path = args.mask_dir # for train set 46 | grad_path = os.path.join(args.grad_dir, 'layer_0.npy') 47 | 48 | print('-' * 50) 49 | print('TRAIN ON:', device) 50 | print('ORI PATH:', args.ori_model_path) 51 | print('RES PATH:', args.res_model_path) 52 | print('LOG DIR:', args.log_dir) 53 | print('-' * 50) 54 | 55 | # ---------------------------------------- 56 | # trainer configuration 57 | # ---------------------------------------- 58 | model = models.load_model(model_name=args.model_name, in_channels=args.in_channels, num_classes=args.num_classes) 59 | model.to(device) 60 | module = models.load_modules(model=model)[0] 61 | 62 | if args.beta == 0: 63 | train_loader = loaders.load_data(args.data_name, train_path, data_type='train') 64 | test_loader = loaders.load_data(args.data_name, test_path, data_type='test') 65 | else: # load training set with mask 66 | train_loader = loaders.load_data_mask(args.data_name, train_path, mask_path, data_type='train') 67 | test_loader = loaders.load_data(args.data_name, test_path, data_type='test') 68 | 69 | criterion = nn.CrossEntropyLoss() 70 | optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) 71 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.num_epochs) 72 | 73 | writer = SummaryWriter(args.log_dir) 74 | 75 | # 
---------------------------------------- 76 | # model doctor configuration 77 | # ---------------------------------------- 78 | constraint = GradConstraint(module=module, grad_path=grad_path, alpha=args.alpha, beta=args.beta) 79 | noise = GradNoise(module=module) 80 | 81 | # ---------------------------------------- 82 | # each epoch 83 | # ---------------------------------------- 84 | since = time.time() 85 | 86 | best_acc = None 87 | best_epoch = None 88 | 89 | for epoch in tqdm(range(args.num_epochs)): 90 | # noise.add_noise() 91 | loss_cls, loss_c, loss_s, acc1, acc5 = train(train_loader, model, criterion, constraint, optimizer, device) 92 | writer.add_scalar(tag='training loss cls', scalar_value=loss_cls.avg, global_step=epoch) 93 | writer.add_scalar(tag='training loss c', scalar_value=loss_c.avg, global_step=epoch) 94 | writer.add_scalar(tag='training loss s', scalar_value=loss_s.avg, global_step=epoch) 95 | writer.add_scalar(tag='training acc1', scalar_value=acc1.avg, global_step=epoch) 96 | # noise.remove_noise() 97 | loss_cls, loss_c, loss_s, acc1, acc5 = test(test_loader, model, criterion, constraint, device) 98 | writer.add_scalar(tag='test loss cls', scalar_value=loss_cls.avg, global_step=epoch) 99 | writer.add_scalar(tag='test loss c', scalar_value=loss_c.avg, global_step=epoch) 100 | writer.add_scalar(tag='test loss s', scalar_value=loss_s.avg, global_step=epoch) 101 | writer.add_scalar(tag='test acc1', scalar_value=acc1.avg, global_step=epoch) 102 | 103 | # ---------------------------------------- 104 | # save best model 105 | # ---------------------------------------- 106 | if best_acc is None or best_acc < acc1.avg: 107 | best_acc = acc1.avg 108 | best_epoch = epoch 109 | torch.save(model.state_dict(), args.res_model_path) 110 | 111 | scheduler.step() 112 | 113 | print('COMPLETE !!!') 114 | print('BEST ACC', best_acc) 115 | print('BEST EPOCH', best_epoch) 116 | print('TIME CONSUMED', time.time() - since) 117 | 118 | 119 | def train(train_loader, model, criterion, constraint, optimizer, device): 120 | loss_cls_meter = AverageMeter('Loss CLS', ':.4e') 121 | loss_c_meter = AverageMeter('Loss C', ':.4e') # channel 122 | loss_s_meter = AverageMeter('Loss S', ':.4e') # spatial 123 | acc1_meter = AverageMeter('Acc@1', ':6.2f') 124 | acc5_meter = AverageMeter('Acc@5', ':6.2f') 125 | progress = ProgressMeter(total=len(train_loader), step=20, prefix='Training', 126 | meters=[loss_cls_meter, loss_c_meter, loss_s_meter, acc1_meter, acc5_meter]) 127 | model.train() 128 | 129 | for i, samples in enumerate(train_loader): 130 | inputs, labels, masks = samples # the third element holds masks, or image names when beta == 0 (no mask loader) 131 | inputs = inputs.to(device) 132 | labels = labels.to(device) 133 | 134 | outputs = model(inputs) 135 | loss_cls = criterion(outputs, labels) 136 | loss_c = constraint.loss_channel(outputs, labels) 137 | loss_s = constraint.loss_spatial(outputs, labels, masks) 138 | acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5)) 139 | 140 | loss_cls_meter.update(loss_cls.item(), inputs.size(0)) 141 | loss_c_meter.update(loss_c.item(), inputs.size(0)) 142 | loss_s_meter.update(loss_s.item(), inputs.size(0)) 143 | acc1_meter.update(acc1.item(), inputs.size(0)) 144 | acc5_meter.update(acc5.item(), inputs.size(0)) 145 | 146 | optimizer.zero_grad() # 1. reset gradients 147 | loss = loss_cls + loss_c + loss_s 148 | loss.backward() # 2. backpropagate 149 | optimizer.step() # 3. update weights 150 | 151 | progress.display(i) 152 | 153 | return loss_cls_meter, loss_c_meter, loss_s_meter, acc1_meter, acc5_meter 154 | 155 | 156 | def test(test_loader, model, criterion, constraint, device): 157 | 
loss_cls_meter = AverageMeter('Loss CLS', ':.4e') 158 | loss_c_meter = AverageMeter('Loss C', ':.4e') # channel 159 | loss_s_meter = AverageMeter('Loss S', ':.4e') # spatial 160 | acc1_meter = AverageMeter('Acc@1', ':6.2f') 161 | acc5_meter = AverageMeter('Acc@5', ':6.2f') 162 | progress = ProgressMeter(total=len(test_loader), step=20, prefix='Test', 163 | meters=[loss_cls_meter, loss_c_meter, loss_s_meter, acc1_meter, acc5_meter]) 164 | model.eval() 165 | 166 | for i, samples in enumerate(test_loader): 167 | inputs, labels, masks = samples # see train(): masks, or image names when beta == 0 168 | inputs = inputs.to(device) 169 | labels = labels.to(device) 170 | 171 | # with torch.set_grad_enabled(False): # intentionally disabled: the gradient-based constraint losses need autograd 172 | outputs = model(inputs) 173 | loss_cls = criterion(outputs, labels) 174 | loss_c = constraint.loss_channel(outputs, labels) 175 | loss_s = constraint.loss_spatial(outputs, labels, masks) 176 | acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5)) 177 | 178 | loss_cls_meter.update(loss_cls.item(), inputs.size(0)) 179 | loss_c_meter.update(loss_c.item(), inputs.size(0)) 180 | loss_s_meter.update(loss_s.item(), inputs.size(0)) 181 | acc1_meter.update(acc1.item(), inputs.size(0)) 182 | acc5_meter.update(acc5.item(), inputs.size(0)) 183 | 184 | progress.display(i) 185 | 186 | return loss_cls_meter, loss_c_meter, loss_s_meter, acc1_meter, acc5_meter 187 | 188 | 189 | if __name__ == '__main__': 190 | main() 191 | -------------------------------------------------------------------------------- /models/inceptionv3.py: -------------------------------------------------------------------------------- 1 | """ inceptionv3 in pytorch 2 | 3 | 4 | [1] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna 5 | 6 | Rethinking the Inception Architecture for Computer Vision 7 | https://arxiv.org/abs/1512.00567v3 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class BasicConv2d(nn.Module): 15 | 16 | def __init__(self, input_channels, output_channels, **kwargs): 17 | super().__init__() 18 | self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs) 19 | self.bn = nn.BatchNorm2d(output_channels) 20 | self.relu = nn.ReLU(inplace=True) 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | x = self.bn(x) 25 | x = self.relu(x) 26 | 27 | return x 28 | 29 | 30 | # same naive inception module 31 | class InceptionA(nn.Module): 32 | 33 | def __init__(self, input_channels, pool_features): 34 | super().__init__() 35 | self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1) 36 | 37 | self.branch5x5 = nn.Sequential( 38 | BasicConv2d(input_channels, 48, kernel_size=1), 39 | BasicConv2d(48, 64, kernel_size=5, padding=2) 40 | ) 41 | 42 | self.branch3x3 = nn.Sequential( 43 | BasicConv2d(input_channels, 64, kernel_size=1), 44 | BasicConv2d(64, 96, kernel_size=3, padding=1), 45 | BasicConv2d(96, 96, kernel_size=3, padding=1) 46 | ) 47 | 48 | self.branchpool = nn.Sequential( 49 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 50 | BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1) 51 | ) 52 | 53 | def forward(self, x): 54 | # x -> 1x1(same) 55 | branch1x1 = self.branch1x1(x) 56 | 57 | # x -> 1x1 -> 5x5(same) 58 | branch5x5 = self.branch5x5(x) 59 | # branch5x5 = self.branch5x5_2(branch5x5) 60 | 61 | # x -> 1x1 -> 3x3 -> 3x3(same) 62 | branch3x3 = self.branch3x3(x) 63 | 64 | # x -> pool -> 1x1(same) 65 | branchpool = self.branchpool(x) 66 | 67 | outputs = [branch1x1, branch5x5, branch3x3, branchpool] 68 | 69 | return torch.cat(outputs, 1) 70 | 71 | 72 | # downsample 73 
| # Factorization into smaller convolutions 74 | class InceptionB(nn.Module): 75 | 76 | def __init__(self, input_channels): 77 | super().__init__() 78 | 79 | self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2) 80 | 81 | self.branch3x3stack = nn.Sequential( 82 | BasicConv2d(input_channels, 64, kernel_size=1), 83 | BasicConv2d(64, 96, kernel_size=3, padding=1), 84 | BasicConv2d(96, 96, kernel_size=3, stride=2) 85 | ) 86 | 87 | self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2) 88 | 89 | def forward(self, x): 90 | # x - > 3x3(downsample) 91 | branch3x3 = self.branch3x3(x) 92 | 93 | # x -> 3x3 -> 3x3(downsample) 94 | branch3x3stack = self.branch3x3stack(x) 95 | 96 | # x -> avgpool(downsample) 97 | branchpool = self.branchpool(x) 98 | 99 | # """We can use two parallel stride 2 blocks: P and C. P is a pooling 100 | # layer (either average or maximum pooling) the activation, both of 101 | # them are stride 2 the filter banks of which are concatenated as in 102 | # figure 10.""" 103 | outputs = [branch3x3, branch3x3stack, branchpool] 104 | 105 | return torch.cat(outputs, 1) 106 | 107 | 108 | # Factorizing Convolutions with Large Filter Size 109 | class InceptionC(nn.Module): 110 | def __init__(self, input_channels, channels_7x7): 111 | super().__init__() 112 | self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1) 113 | 114 | c7 = channels_7x7 115 | 116 | # In theory, we could go even further and argue that one can replace any n × n 117 | # convolution by a 1 × n convolution followed by a n × 1 convolution and the 118 | # computational cost saving increases dramatically as n grows (see figure 6). 119 | self.branch7x7 = nn.Sequential( 120 | BasicConv2d(input_channels, c7, kernel_size=1), 121 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 122 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 123 | ) 124 | 125 | self.branch7x7stack = nn.Sequential( 126 | BasicConv2d(input_channels, c7, kernel_size=1), 127 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 128 | BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)), 129 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 130 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 131 | ) 132 | 133 | self.branch_pool = nn.Sequential( 134 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 135 | BasicConv2d(input_channels, 192, kernel_size=1), 136 | ) 137 | 138 | def forward(self, x): 139 | # x -> 1x1(same) 140 | branch1x1 = self.branch1x1(x) 141 | 142 | # x -> 1layer 1*7 and 7*1 (same) 143 | branch7x7 = self.branch7x7(x) 144 | 145 | # x-> 2layer 1*7 and 7*1(same) 146 | branch7x7stack = self.branch7x7stack(x) 147 | 148 | # x-> avgpool (same) 149 | branchpool = self.branch_pool(x) 150 | 151 | outputs = [branch1x1, branch7x7, branch7x7stack, branchpool] 152 | 153 | return torch.cat(outputs, 1) 154 | 155 | 156 | class InceptionD(nn.Module): 157 | 158 | def __init__(self, input_channels): 159 | super().__init__() 160 | 161 | self.branch3x3 = nn.Sequential( 162 | BasicConv2d(input_channels, 192, kernel_size=1), 163 | BasicConv2d(192, 320, kernel_size=3, stride=2) 164 | ) 165 | 166 | self.branch7x7 = nn.Sequential( 167 | BasicConv2d(input_channels, 192, kernel_size=1), 168 | BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)), 169 | BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)), 170 | BasicConv2d(192, 192, kernel_size=3, stride=2) 171 | ) 172 | 173 | self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2) 174 | 175 | def forward(self, x): 176 | # x -> 1x1 -> 
3x3(downsample) 177 | branch3x3 = self.branch3x3(x) 178 | 179 | # x -> 1x1 -> 1x7 -> 7x1 -> 3x3 (downsample) 180 | branch7x7 = self.branch7x7(x) 181 | 182 | # x -> avgpool (downsample) 183 | branchpool = self.branchpool(x) 184 | 185 | outputs = [branch3x3, branch7x7, branchpool] 186 | 187 | return torch.cat(outputs, 1) 188 | 189 | 190 | # same 191 | class InceptionE(nn.Module): 192 | def __init__(self, input_channels): 193 | super().__init__() 194 | self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1) 195 | 196 | self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1) 197 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 198 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 199 | 200 | self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1) 201 | self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 202 | self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 203 | self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 204 | 205 | self.branch_pool = nn.Sequential( 206 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 207 | BasicConv2d(input_channels, 192, kernel_size=1) 208 | ) 209 | 210 | def forward(self, x): 211 | # x -> 1x1 (same) 212 | branch1x1 = self.branch1x1(x) 213 | 214 | # x -> 1x1 -> 3x1 215 | # x -> 1x1 -> 1x3 216 | # concatenate(3x1, 1x3) 217 | # """7. Inception modules with expanded the filter bank outputs. 218 | # This architecture is used on the coarsest (8 × 8) grids to promote 219 | # high dimensional representations, as suggested by principle 220 | # 2 of Section 2.""" 221 | branch3x3 = self.branch3x3_1(x) 222 | branch3x3 = [ 223 | self.branch3x3_2a(branch3x3), 224 | self.branch3x3_2b(branch3x3) 225 | ] 226 | branch3x3 = torch.cat(branch3x3, 1) 227 | 228 | # x -> 1x1 -> 3x3 -> 1x3 229 | # x -> 1x1 -> 3x3 -> 3x1 230 | # concatenate(1x3, 3x1) 231 | branch3x3stack = self.branch3x3stack_1(x) 232 | branch3x3stack = self.branch3x3stack_2(branch3x3stack) 233 | branch3x3stack = [ 234 | self.branch3x3stack_3a(branch3x3stack), 235 | self.branch3x3stack_3b(branch3x3stack) 236 | ] 237 | branch3x3stack = torch.cat(branch3x3stack, 1) 238 | 239 | branchpool = self.branch_pool(x) 240 | 241 | outputs = [branch1x1, branch3x3, branch3x3stack, branchpool] 242 | 243 | return torch.cat(outputs, 1) 244 | 245 | 246 | class InceptionV3(nn.Module): 247 | 248 | def __init__(self, in_channels=3, num_classes=10): 249 | super().__init__() 250 | self.Conv2d_1a_3x3 = BasicConv2d(in_channels, 32, kernel_size=3, padding=1) 251 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1) 252 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 253 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 254 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 255 | 256 | # naive inception module 257 | self.Mixed_5b = InceptionA(192, pool_features=32) 258 | self.Mixed_5c = InceptionA(256, pool_features=64) 259 | self.Mixed_5d = InceptionA(288, pool_features=64) 260 | 261 | # downsample 262 | self.Mixed_6a = InceptionB(288) 263 | 264 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 265 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 266 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 267 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 268 | 269 | # downsample 270 | self.Mixed_7a = InceptionD(768) 271 | 272 | self.Mixed_7b = InceptionE(1280) 273 | self.Mixed_7c = InceptionE(2048) 274 | 275 
| # 6*6 feature size 276 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 277 | self.dropout = nn.Dropout2d() 278 | self.linear = nn.Linear(2048, num_classes) 279 | 280 | def forward(self, x): 281 | # 32 -> 30 282 | x = self.Conv2d_1a_3x3(x) 283 | x = self.Conv2d_2a_3x3(x) 284 | x = self.Conv2d_2b_3x3(x) 285 | x = self.Conv2d_3b_1x1(x) 286 | x = self.Conv2d_4a_3x3(x) 287 | 288 | # 30 -> 30 289 | x = self.Mixed_5b(x) 290 | x = self.Mixed_5c(x) 291 | x = self.Mixed_5d(x) 292 | 293 | # 30 -> 14 294 | # Efficient Grid Size Reduction to avoid representation 295 | # bottleneck 296 | x = self.Mixed_6a(x) 297 | 298 | # 14 -> 14 299 | # """In practice, we have found that employing this factorization does not 300 | # work well on early layers, but it gives very good results on medium 301 | # grid-sizes (On m × m feature maps, where m ranges between 12 and 20). 302 | # On that level, very good results can be achieved by using 1 × 7 convolutions 303 | # followed by 7 × 1 convolutions.""" 304 | x = self.Mixed_6b(x) 305 | x = self.Mixed_6c(x) 306 | x = self.Mixed_6d(x) 307 | x = self.Mixed_6e(x) 308 | 309 | # 14 -> 6 310 | # Efficient Grid Size Reduction 311 | x = self.Mixed_7a(x) 312 | 313 | # 6 -> 6 314 | # We are using this solution only on the coarsest grid, 315 | # since that is the place where producing high dimensional 316 | # sparse representation is the most critical as the ratio of 317 | # local processing (by 1 × 1 convolutions) is increased compared 318 | # to the spatial aggregation.""" 319 | x = self.Mixed_7b(x) 320 | x = self.Mixed_7c(x) 321 | 322 | # 6 -> 1 323 | x = self.avgpool(x) 324 | x = self.dropout(x) 325 | x = x.view(x.size(0), -1) 326 | x = self.linear(x) 327 | return x 328 | 329 | 330 | def inceptionv3(in_channels=3, num_classes=10): 331 | return InceptionV3(in_channels=in_channels, num_classes=num_classes) 332 | -------------------------------------------------------------------------------- /models/inceptionv4.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ inceptionv4 in pytorch 3 | 4 | 5 | [1] Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi 6 | 7 | Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning 8 | https://arxiv.org/abs/1602.07261 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class BasicConv2d(nn.Module): 16 | 17 | def __init__(self, input_channels, output_channels, **kwargs): 18 | super().__init__() 19 | self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs) 20 | self.bn = nn.BatchNorm2d(output_channels) 21 | self.relu = nn.ReLU(inplace=True) 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | x = self.bn(x) 26 | x = self.relu(x) 27 | 28 | return x 29 | 30 | 31 | class Inception_Stem(nn.Module): 32 | 33 | # """Figure 3. The schema for stem of the pure Inception-v4 and 34 | # Inception-ResNet-v2 networks. 
This is the input part of those 35 | # networks.""" 36 | def __init__(self, input_channels): 37 | super().__init__() 38 | self.conv1 = nn.Sequential( 39 | BasicConv2d(input_channels, 32, kernel_size=3), 40 | BasicConv2d(32, 32, kernel_size=3, padding=1), 41 | BasicConv2d(32, 64, kernel_size=3, padding=1) 42 | ) 43 | 44 | self.branch3x3_conv = BasicConv2d(64, 96, kernel_size=3, padding=1) 45 | self.branch3x3_pool = nn.MaxPool2d(3, stride=1, padding=1) 46 | 47 | self.branch7x7a = nn.Sequential( 48 | BasicConv2d(160, 64, kernel_size=1), 49 | BasicConv2d(64, 64, kernel_size=(7, 1), padding=(3, 0)), 50 | BasicConv2d(64, 64, kernel_size=(1, 7), padding=(0, 3)), 51 | BasicConv2d(64, 96, kernel_size=3, padding=1) 52 | ) 53 | 54 | self.branch7x7b = nn.Sequential( 55 | BasicConv2d(160, 64, kernel_size=1), 56 | BasicConv2d(64, 96, kernel_size=3, padding=1) 57 | ) 58 | 59 | self.branchpoola = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 60 | self.branchpoolb = BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1) 61 | 62 | def forward(self, x): 63 | x = self.conv1(x) 64 | 65 | x = [ 66 | self.branch3x3_conv(x), 67 | self.branch3x3_pool(x) 68 | ] 69 | x = torch.cat(x, 1) 70 | 71 | x = [ 72 | self.branch7x7a(x), 73 | self.branch7x7b(x) 74 | ] 75 | x = torch.cat(x, 1) 76 | 77 | x = [ 78 | self.branchpoola(x), 79 | self.branchpoolb(x) 80 | ] 81 | 82 | x = torch.cat(x, 1) 83 | 84 | return x 85 | 86 | 87 | class InceptionA(nn.Module): 88 | 89 | # """Figure 4. The schema for 35 × 35 grid modules of the pure 90 | # Inception-v4 network. This is the Inception-A block of Figure 9.""" 91 | def __init__(self, input_channels): 92 | super().__init__() 93 | 94 | self.branch3x3stack = nn.Sequential( 95 | BasicConv2d(input_channels, 64, kernel_size=1), 96 | BasicConv2d(64, 96, kernel_size=3, padding=1), 97 | BasicConv2d(96, 96, kernel_size=3, padding=1) 98 | ) 99 | 100 | self.branch3x3 = nn.Sequential( 101 | BasicConv2d(input_channels, 64, kernel_size=1), 102 | BasicConv2d(64, 96, kernel_size=3, padding=1) 103 | ) 104 | 105 | self.branch1x1 = BasicConv2d(input_channels, 96, kernel_size=1) 106 | 107 | self.branchpool = nn.Sequential( 108 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 109 | BasicConv2d(input_channels, 96, kernel_size=1) 110 | ) 111 | 112 | def forward(self, x): 113 | x = [ 114 | self.branch3x3stack(x), 115 | self.branch3x3(x), 116 | self.branch1x1(x), 117 | self.branchpool(x) 118 | ] 119 | 120 | return torch.cat(x, 1) 121 | 122 | 123 | class ReductionA(nn.Module): 124 | 125 | # """Figure 7. The schema for 35 × 35 to 17 × 17 reduction module. 126 | # Different variants of this blocks (with various number of filters) 127 | # are used in Figure 9, and 15 in each of the new Inception(-v4, - ResNet-v1, 128 | # -ResNet-v2) variants presented in this paper. The k, l, m, n numbers 129 | # represent filter bank sizes which can be looked up in Table 1. 
130 | def __init__(self, input_channels, k, l, m, n): 131 | super().__init__() 132 | self.branch3x3stack = nn.Sequential( 133 | BasicConv2d(input_channels, k, kernel_size=1), 134 | BasicConv2d(k, l, kernel_size=3, padding=1), 135 | BasicConv2d(l, m, kernel_size=3, stride=2) 136 | ) 137 | 138 | self.branch3x3 = BasicConv2d(input_channels, n, kernel_size=3, stride=2) 139 | self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2) 140 | self.output_channels = input_channels + n + m 141 | 142 | def forward(self, x): 143 | x = [ 144 | self.branch3x3stack(x), 145 | self.branch3x3(x), 146 | self.branchpool(x) 147 | ] 148 | 149 | return torch.cat(x, 1) 150 | 151 | 152 | class InceptionB(nn.Module): 153 | 154 | # """Figure 5. The schema for 17 × 17 grid modules of the pure Inception-v4 network. 155 | # This is the Inception-B block of Figure 9.""" 156 | def __init__(self, input_channels): 157 | super().__init__() 158 | 159 | self.branch7x7stack = nn.Sequential( 160 | BasicConv2d(input_channels, 192, kernel_size=1), 161 | BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)), 162 | BasicConv2d(192, 224, kernel_size=(7, 1), padding=(3, 0)), 163 | BasicConv2d(224, 224, kernel_size=(1, 7), padding=(0, 3)), 164 | BasicConv2d(224, 256, kernel_size=(7, 1), padding=(3, 0)) 165 | ) 166 | 167 | self.branch7x7 = nn.Sequential( 168 | BasicConv2d(input_channels, 192, kernel_size=1), 169 | BasicConv2d(192, 224, kernel_size=(1, 7), padding=(0, 3)), 170 | BasicConv2d(224, 256, kernel_size=(7, 1), padding=(3, 0)) 171 | ) 172 | 173 | self.branch1x1 = BasicConv2d(input_channels, 384, kernel_size=1) 174 | 175 | self.branchpool = nn.Sequential( 176 | nn.AvgPool2d(3, stride=1, padding=1), 177 | BasicConv2d(input_channels, 128, kernel_size=1) 178 | ) 179 | 180 | def forward(self, x): 181 | x = [ 182 | self.branch1x1(x), 183 | self.branch7x7(x), 184 | self.branch7x7stack(x), 185 | self.branchpool(x) 186 | ] 187 | 188 | return torch.cat(x, 1) 189 | 190 | 191 | class ReductionB(nn.Module): 192 | 193 | # """Figure 8. The schema for 17 × 17 to 8 × 8 grid-reduction mod- ule. 194 | # This is the reduction module used by the pure Inception-v4 network in 195 | # Figure 9.""" 196 | def __init__(self, input_channels): 197 | super().__init__() 198 | self.branch7x7 = nn.Sequential( 199 | BasicConv2d(input_channels, 256, kernel_size=1), 200 | BasicConv2d(256, 256, kernel_size=(1, 7), padding=(0, 3)), 201 | BasicConv2d(256, 320, kernel_size=(7, 1), padding=(3, 0)), 202 | BasicConv2d(320, 320, kernel_size=3, stride=2, padding=1) 203 | ) 204 | 205 | self.branch3x3 = nn.Sequential( 206 | BasicConv2d(input_channels, 192, kernel_size=1), 207 | BasicConv2d(192, 192, kernel_size=3, stride=2, padding=1) 208 | ) 209 | 210 | self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 211 | 212 | def forward(self, x): 213 | x = [ 214 | self.branch3x3(x), 215 | self.branch7x7(x), 216 | self.branchpool(x) 217 | ] 218 | 219 | return torch.cat(x, 1) 220 | 221 | 222 | class InceptionC(nn.Module): 223 | 224 | def __init__(self, input_channels): 225 | # """Figure 6. The schema for 8×8 grid modules of the pure 226 | # Inceptionv4 network. 
This is the Inception-C block of Figure 9.""" 227 | 228 | super().__init__() 229 | 230 | self.branch3x3stack = nn.Sequential( 231 | BasicConv2d(input_channels, 384, kernel_size=1), 232 | BasicConv2d(384, 448, kernel_size=(1, 3), padding=(0, 1)), 233 | BasicConv2d(448, 512, kernel_size=(3, 1), padding=(1, 0)), 234 | ) 235 | self.branch3x3stacka = BasicConv2d(512, 256, kernel_size=(1, 3), padding=(0, 1)) 236 | self.branch3x3stackb = BasicConv2d(512, 256, kernel_size=(3, 1), padding=(1, 0)) 237 | 238 | self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=1) 239 | self.branch3x3a = BasicConv2d(384, 256, kernel_size=(3, 1), padding=(1, 0)) 240 | self.branch3x3b = BasicConv2d(384, 256, kernel_size=(1, 3), padding=(0, 1)) 241 | 242 | self.branch1x1 = BasicConv2d(input_channels, 256, kernel_size=1) 243 | 244 | self.branchpool = nn.Sequential( 245 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 246 | BasicConv2d(input_channels, 256, kernel_size=1) 247 | ) 248 | 249 | def forward(self, x): 250 | branch3x3stack_output = self.branch3x3stack(x) 251 | branch3x3stack_output = [ 252 | self.branch3x3stacka(branch3x3stack_output), 253 | self.branch3x3stackb(branch3x3stack_output) 254 | ] 255 | branch3x3stack_output = torch.cat(branch3x3stack_output, 1) 256 | 257 | branch3x3_output = self.branch3x3(x) 258 | branch3x3_output = [ 259 | self.branch3x3a(branch3x3_output), 260 | self.branch3x3b(branch3x3_output) 261 | ] 262 | branch3x3_output = torch.cat(branch3x3_output, 1) 263 | 264 | branch1x1_output = self.branch1x1(x) 265 | 266 | branchpool = self.branchpool(x) 267 | 268 | output = [ 269 | branch1x1_output, 270 | branch3x3_output, 271 | branch3x3stack_output, 272 | branchpool 273 | ] 274 | 275 | return torch.cat(output, 1) 276 | 277 | 278 | class InceptionV4(nn.Module): 279 | 280 | def __init__(self, A, B, C, k=192, l=224, m=256, n=384, class_nums=10): 281 | super().__init__() 282 | self.stem = Inception_Stem(3) 283 | self.inception_a = self._generate_inception_module(384, 384, A, InceptionA) 284 | self.reduction_a = ReductionA(384, k, l, m, n) 285 | output_channels = self.reduction_a.output_channels 286 | self.inception_b = self._generate_inception_module(output_channels, 1024, B, InceptionB) 287 | self.reduction_b = ReductionB(1024) 288 | self.inception_c = self._generate_inception_module(1536, 1536, C, InceptionC) 289 | self.avgpool = nn.AvgPool2d(7) 290 | 291 | # """Dropout (keep 0.8)""" 292 | self.dropout = nn.Dropout2d(1 - 0.8) 293 | self.linear = nn.Linear(1536, class_nums) 294 | 295 | def forward(self, x): 296 | x = self.stem(x) 297 | x = self.inception_a(x) 298 | x = self.reduction_a(x) 299 | x = self.inception_b(x) 300 | x = self.reduction_b(x) 301 | x = self.inception_c(x) 302 | x = self.avgpool(x) 303 | x = self.dropout(x) 304 | x = x.view(-1, 1536) 305 | x = self.linear(x) 306 | 307 | return x 308 | 309 | @staticmethod 310 | def _generate_inception_module(input_channels, output_channels, block_num, block): 311 | layers = nn.Sequential() 312 | for l in range(block_num): 313 | layers.add_module("{}_{}".format(block.__name__, l), block(input_channels)) 314 | input_channels = output_channels 315 | 316 | return layers 317 | 318 | 319 | class InceptionResNetA(nn.Module): 320 | 321 | # """Figure 16. 
The schema for 35 × 35 grid (Inception-ResNet-A) 322 | # module of the Inception-ResNet-v2 network.""" 323 | def __init__(self, input_channels): 324 | super().__init__() 325 | self.branch3x3stack = nn.Sequential( 326 | BasicConv2d(input_channels, 32, kernel_size=1), 327 | BasicConv2d(32, 48, kernel_size=3, padding=1), 328 | BasicConv2d(48, 64, kernel_size=3, padding=1) 329 | ) 330 | 331 | self.branch3x3 = nn.Sequential( 332 | BasicConv2d(input_channels, 32, kernel_size=1), 333 | BasicConv2d(32, 32, kernel_size=3, padding=1) 334 | ) 335 | 336 | self.branch1x1 = BasicConv2d(input_channels, 32, kernel_size=1) 337 | 338 | self.reduction1x1 = nn.Conv2d(128, 384, kernel_size=1) 339 | self.shortcut = nn.Conv2d(input_channels, 384, kernel_size=1) 340 | self.bn = nn.BatchNorm2d(384) 341 | self.relu = nn.ReLU(inplace=True) 342 | 343 | def forward(self, x): 344 | residual = [ 345 | self.branch1x1(x), 346 | self.branch3x3(x), 347 | self.branch3x3stack(x) 348 | ] 349 | 350 | residual = torch.cat(residual, 1) 351 | residual = self.reduction1x1(residual) 352 | shortcut = self.shortcut(x) 353 | 354 | output = self.bn(shortcut + residual) 355 | output = self.relu(output) 356 | 357 | return output 358 | 359 | 360 | class InceptionResNetB(nn.Module): 361 | 362 | # """Figure 17. The schema for 17 × 17 grid (Inception-ResNet-B) module of 363 | # the Inception-ResNet-v2 network.""" 364 | def __init__(self, input_channels): 365 | super().__init__() 366 | self.branch7x7 = nn.Sequential( 367 | BasicConv2d(input_channels, 128, kernel_size=1), 368 | BasicConv2d(128, 160, kernel_size=(1, 7), padding=(0, 3)), 369 | BasicConv2d(160, 192, kernel_size=(7, 1), padding=(3, 0)) 370 | ) 371 | 372 | self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1) 373 | 374 | self.reduction1x1 = nn.Conv2d(384, 1154, kernel_size=1) 375 | self.shortcut = nn.Conv2d(input_channels, 1154, kernel_size=1) 376 | 377 | self.bn = nn.BatchNorm2d(1154) 378 | self.relu = nn.ReLU(inplace=True) 379 | 380 | def forward(self, x): 381 | residual = [ 382 | self.branch1x1(x), 383 | self.branch7x7(x) 384 | ] 385 | 386 | residual = torch.cat(residual, 1) 387 | 388 | # """In general we picked some scaling factors between 0.1 and 0.3 to scale the residuals 389 | # before their being added to the accumulated layer activations (cf. Figure 20).""" 390 | residual = self.reduction1x1(residual) * 0.1 391 | 392 | shortcut = self.shortcut(x) 393 | 394 | output = self.bn(residual + shortcut) 395 | output = self.relu(output) 396 | 397 | return output 398 | 399 | 400 | class InceptionResNetC(nn.Module): 401 | 402 | def __init__(self, input_channels): 403 | # Figure 19. 
The schema for 8×8 grid (Inception-ResNet-C) 404 | # module of the Inception-ResNet-v2 network. 405 | super().__init__() 406 | self.branch3x3 = nn.Sequential( 407 | BasicConv2d(input_channels, 192, kernel_size=1), 408 | BasicConv2d(192, 224, kernel_size=(1, 3), padding=(0, 1)), 409 | BasicConv2d(224, 256, kernel_size=(3, 1), padding=(1, 0)) 410 | ) 411 | 412 | self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1) 413 | self.reduction1x1 = nn.Conv2d(448, 2048, kernel_size=1) 414 | self.shortcut = nn.Conv2d(input_channels, 2048, kernel_size=1) 415 | self.bn = nn.BatchNorm2d(2048) 416 | self.relu = nn.ReLU(inplace=True) 417 | 418 | def forward(self, x): 419 | residual = [ 420 | self.branch1x1(x), 421 | self.branch3x3(x) 422 | ] 423 | 424 | residual = torch.cat(residual, 1) 425 | residual = self.reduction1x1(residual) * 0.1 426 | 427 | shortcut = self.shortcut(x) 428 | 429 | output = self.bn(shortcut + residual) 430 | output = self.relu(output) 431 | 432 | return output 433 | 434 | 435 | class InceptionResNetReductionA(nn.Module): 436 | 437 | # """Figure 7. The schema for 35 × 35 to 17 × 17 reduction module. 438 | # Different variants of this blocks (with various number of filters) 439 | # are used in Figure 9, and 15 in each of the new Inception(-v4, - ResNet-v1, 440 | # -ResNet-v2) variants presented in this paper. The k, l, m, n numbers 441 | # represent filter bank sizes which can be looked up in Table 1.""" 442 | def __init__(self, input_channels, k, l, m, n): 443 | super().__init__() 444 | self.branch3x3stack = nn.Sequential( 445 | BasicConv2d(input_channels, k, kernel_size=1), 446 | BasicConv2d(k, l, kernel_size=3, padding=1), 447 | BasicConv2d(l, m, kernel_size=3, stride=2) 448 | ) 449 | 450 | self.branch3x3 = BasicConv2d(input_channels, n, kernel_size=3, stride=2) 451 | self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2) 452 | self.output_channels = input_channels + n + m 453 | 454 | def forward(self, x): 455 | x = [ 456 | self.branch3x3stack(x), 457 | self.branch3x3(x), 458 | self.branchpool(x) 459 | ] 460 | 461 | return torch.cat(x, 1) 462 | 
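The three Inception-ResNet blocks above share one pattern worth isolating: the branch outputs are concatenated, projected by a 1x1 convolution, scaled (by 0.1 in the B and C blocks, following the paper's 0.1-0.3 recommendation), and added to a 1x1 shortcut before batch norm and ReLU. A stand-alone sketch of that pattern; the class name and signature are illustrative, not from this repo:

import torch
import torch.nn as nn


class ScaledResidual(nn.Module):
    """Generic Inception-ResNet cell: ReLU(BN(shortcut(x) + scale * residual(x)))."""

    def __init__(self, residual, in_channels, out_channels, scale=0.1):
        super().__init__()
        self.residual = residual  # any branch stack mapping in_channels -> out_channels
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.scale = scale  # damping the residual stabilizes very deep variants

    def forward(self, x):
        return self.relu(self.bn(self.shortcut(x) + self.scale * self.residual(x)))


cell = ScaledResidual(nn.Conv2d(32, 64, kernel_size=1), in_channels=32, out_channels=64)
print(cell(torch.randn(2, 32, 8, 8)).shape)  # torch.Size([2, 64, 8, 8])
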
463 | 464 | class InceptionResNetReductionB(nn.Module): 465 | 466 | # """Figure 18. The schema for 17 × 17 to 8 × 8 grid-reduction module. 467 | # Reduction-B module used by the wider Inception-ResNet-v1 network in 468 | # Figure 15.""" 469 | # I believe it was a typo (Inception-ResNet-v1 should be Inception-ResNet-v2) 470 | def __init__(self, input_channels): 471 | super().__init__() 472 | self.branchpool = nn.MaxPool2d(3, stride=2) 473 | 474 | self.branch3x3a = nn.Sequential( 475 | BasicConv2d(input_channels, 256, kernel_size=1), 476 | BasicConv2d(256, 384, kernel_size=3, stride=2) 477 | ) 478 | 479 | self.branch3x3b = nn.Sequential( 480 | BasicConv2d(input_channels, 256, kernel_size=1), 481 | BasicConv2d(256, 288, kernel_size=3, stride=2) 482 | ) 483 | 484 | self.branch3x3stack = nn.Sequential( 485 | BasicConv2d(input_channels, 256, kernel_size=1), 486 | BasicConv2d(256, 288, kernel_size=3, padding=1), 487 | BasicConv2d(288, 320, kernel_size=3, stride=2) 488 | ) 489 | 490 | def forward(self, x): 491 | x = [ 492 | self.branch3x3a(x), 493 | self.branch3x3b(x), 494 | self.branch3x3stack(x), 495 | self.branchpool(x) 496 | ] 497 | 498 | x = torch.cat(x, 1) 499 | return x 500 | 501 | 502 | class InceptionResNetV2(nn.Module): 503 | 504 | def __init__(self, A, B, C, k=256, l=256, m=384, n=384, class_nums=100): 505 | super().__init__() 506 | self.stem = Inception_Stem(3) 507 | self.inception_resnet_a = self._generate_inception_module(384, 384, A, InceptionResNetA) 508 | self.reduction_a = InceptionResNetReductionA(384, k, l, m, n) 509 | output_channels = self.reduction_a.output_channels 510 | self.inception_resnet_b = self._generate_inception_module(output_channels, 1154, B, InceptionResNetB) 511 | self.reduction_b = InceptionResNetReductionB(1154) 512 | self.inception_resnet_c = self._generate_inception_module(2146, 2048, C, InceptionResNetC) 513 | 514 | # 6x6 feature size 515 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 516 | # """Dropout (keep 0.8)""" 517 | self.dropout = nn.Dropout2d(1 - 0.8) 518 | self.linear = nn.Linear(2048, class_nums) 519 | 520 | def forward(self, x): 521 | x = self.stem(x) 522 | x = self.inception_resnet_a(x) 523 | x = self.reduction_a(x) 524 | x = self.inception_resnet_b(x) 525 | x = self.reduction_b(x) 526 | x = self.inception_resnet_c(x) 527 | x = self.avgpool(x) 528 | x = self.dropout(x) 529 | x = x.view(-1, 2048) 530 | x = self.linear(x) 531 | 532 | return x 533 | 534 | @staticmethod 535 | def _generate_inception_module(input_channels, output_channels, block_num, block): 536 | layers = nn.Sequential() 537 | for block_idx in range(block_num): 538 | layers.add_module("{}_{}".format(block.__name__, block_idx), block(input_channels)) 539 | input_channels = output_channels 540 | 541 | return layers 542 | 543 | 544 | def inceptionv4(): 545 | return InceptionV4(4, 7, 3) 546 | 547 | 548 | def inception_resnet_v2(): 549 | return InceptionResNetV2(5, 10, 5) 550 | --------------------------------------------------------------------------------
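A quick smoke test for the factory functions defined in these two files; this script is not part of the repo and assumes it is run from the repo root so the models package imports. All three networks accept CIFAR-sized 32x32 inputs. Note the differing defaults: inceptionv4() builds a 10-class head, while inception_resnet_v2() keeps the class_nums=100 default.

import torch

from models.inceptionv3 import inceptionv3
from models.inceptionv4 import inceptionv4, inception_resnet_v2

x = torch.randn(2, 3, 32, 32)  # CIFAR-sized batch
for net in (inceptionv3(in_channels=3, num_classes=10), inceptionv4(), inception_resnet_v2()):
    net.eval()  # fixed batch-norm statistics and no dropout for the shape check
    with torch.no_grad():
        print(type(net).__name__, net(x).shape)
# InceptionV3 torch.Size([2, 10])
# InceptionV4 torch.Size([2, 10])
# InceptionResNetV2 torch.Size([2, 100])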