├── core
│   ├── __init__.py
│   ├── model_assemble.py
│   ├── parameter_scaling.py
│   ├── model_disassemble.py
│   ├── sample_select.py
│   ├── relevant_feature_identifying.py
│   └── model_decision_route_visualizing.py
├── utils
│   ├── __init__.py
│   ├── image_util.py
│   ├── file_util.py
│   └── train_util.py
├── engines
│   ├── __init__.py
│   ├── test.py
│   └── train.py
├── framework.jpg
├── loaders
│   ├── datasets
│   │   ├── __init__.py
│   │   └── image_dataset.py
│   ├── __init__.py
│   └── image_loader.py
├── model_assembling.jpg
├── model_disassembling.jpg
├── metrics
│   ├── __init__.py
│   └── accuracy.py
├── scripts
│   ├── model_decision_route_visualizing.sh
│   ├── model_assemble.sh
│   ├── sample_select.sh
│   ├── relevant_feature_identifying.sh
│   ├── model_disassemble.sh
│   ├── test.sh
│   ├── parameter_scaling.sh
│   └── train.sh
├── models
│   ├── simnet.py
│   ├── lenet.py
│   ├── alexnet.py
│   ├── __init__.py
│   ├── vgg.py
│   ├── simplenetv1.py
│   ├── googlenet.py
│   ├── resnet.py
│   ├── mobilenet.py
│   └── inceptionv3.py
├── .gitignore
└── README.md

/core/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/engines/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaconghu/Model-LEGO/HEAD/framework.jpg
--------------------------------------------------------------------------------
/loaders/datasets/__init__.py:
--------------------------------------------------------------------------------
from loaders.datasets.image_dataset import ImageDataset
--------------------------------------------------------------------------------
/model_assembling.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaconghu/Model-LEGO/HEAD/model_assembling.jpg
--------------------------------------------------------------------------------
/model_disassembling.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaconghu/Model-LEGO/HEAD/model_disassembling.jpg
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
from metrics.accuracy import accuracy
from metrics.accuracy import ClassAccuracy
--------------------------------------------------------------------------------
/loaders/__init__.py:
--------------------------------------------------------------------------------
from loaders.image_loader import load_images


def load_data(data_dir, data_name, data_type):
    print('-' * 50)
    print('DATA PATH:', data_dir)
    print('DATA NAME:', data_name, '\t|\tDATA TYPE:', data_type)
    print('-' * 50)

    return load_images(data_dir, data_name, data_type)
--------------------------------------------------------------------------------
/scripts/model_decision_route_visualizing.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
export CUDA_VISIBLE_DEVICES=0
#----------------------------------------
mask_dir='/nfs3/hjc/projects/cnnlego/output/lenet_cifar10_base/contributions/masks'
layers='-1'
labels='3 4'
#----------------------------------------

python core/model_decision_route_visualizing.py \
  --mask_dir ${mask_dir} \
  --layers ${layers} \
  --labels ${labels}
--------------------------------------------------------------------------------
/utils/image_util.py:
--------------------------------------------------------------------------------
import matplotlib

matplotlib.use('AGG')  # select the non-interactive backend before pyplot/seaborn are imported

import matplotlib.pyplot as plt
import seaborn as sns


def heatmap(vals, fig_path, fig_w=None, fig_h=None, annot=False):
    if fig_w is None:
        fig_w = vals.shape[1]
    if fig_h is None:
        fig_h = vals.shape[0]

    f, ax = plt.subplots(figsize=(fig_w, fig_h), ncols=1)
    sns.heatmap(vals, ax=ax, annot=annot)
    plt.savefig(fig_path, bbox_inches='tight')
    plt.clf()
--------------------------------------------------------------------------------
/scripts/model_assemble.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
export CUDA_VISIBLE_DEVICES=0
result_path='/nfs3/hjc/projects/cnnlego/output'
#----------------------------------------
exp_name='lenet_cifar10_base'
#----------------------------------------
model1_path=${result_path}'/'${exp_name}'/models/model_disa1.pth'
model2_path=${result_path}'/'${exp_name}'/models/model_disa2.pth'
asse_path=${result_path}'/'${exp_name}'/models/model_asse.pth'

python core/model_assemble.py \
  --model1_path ${model1_path} \
  --model2_path ${model2_path} \
  --asse_path ${asse_path}
--------------------------------------------------------------------------------
/utils/file_util.py:
--------------------------------------------------------------------------------
import os
import shutil


def walk_file(path):
    count = 0
    for root, dirs, files in os.walk(path):
        print(root)

        for f in files:
            count += 1
            # print(os.path.join(root, f))

        for d in dirs:
            print(os.path.join(root, d))
    print(count)


def count_files(path):
    for root, dirs, files in os.walk(path):
        print(root, len(files))


def copy_file(src, dst):
    path, name = os.path.split(dst)
    if not os.path.exists(path):
        os.makedirs(path)
    shutil.copyfile(src, dst)


if __name__ == '__main__':
    count_files('path')
--------------------------------------------------------------------------------
/scripts/sample_select.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
export CUDA_VISIBLE_DEVICES=0
result_path='/nfs3/hjc/projects/cnnlego/output'
#----------------------------------------
exp_name='lenet_cifar10_base'
#----------------------------------------
model_name='lenet'
#----------------------------------------
data_name='cifar10'
num_classes=10
#----------------------------------------
model_path=${result_path}'/'${exp_name}'/models/model_ori.pth'
data_dir='/nfs3-p1/hjc/datasets/cifar10/train'
save_dir=${result_path}'/'${exp_name}'/images/htrain'
num_samples=50

python core/sample_select.py \
  --model_name ${model_name} \
  --data_name ${data_name} \
  --num_classes ${num_classes} \
  --model_path ${model_path} \
  --data_dir ${data_dir} \
  --save_dir ${save_dir} \
  --num_samples ${num_samples}
--------------------------------------------------------------------------------
/scripts/relevant_feature_identifying.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
export CUDA_VISIBLE_DEVICES=0
result_path='/nfs3/hjc/projects/cnnlego/output'
#----------------------------------------
exp_name='lenet_cifar10_base'
#----------------------------------------
model_name='lenet'
#----------------------------------------
export data_name='cifar10'
export num_classes=10
#----------------------------------------
export model_path=${result_path}'/'${exp_name}'/models/model_ori.pth'
export data_dir=${result_path}'/'${exp_name}'/images/htrain'
export save_dir=${result_path}'/'${exp_name}'/contributions'

python core/relevant_feature_identifying.py \
  --model_name ${model_name} \
  --data_name ${data_name} \
  --num_classes ${num_classes} \
  --model_path ${model_path} \
  --data_dir ${data_dir} \
  --save_dir ${save_dir}
--------------------------------------------------------------------------------
/scripts/model_disassemble.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
export CUDA_VISIBLE_DEVICES=0
result_path='/nfs3/hjc/projects/cnnlego/output'
#----------------------------------------
exp_name='lenet_cifar10_base'
#----------------------------------------
model_name='lenet'
#----------------------------------------
num_classes=10
#----------------------------------------
model_path=${result_path}'/'${exp_name}'/models/model_ori.pth'
mask_dir=${result_path}'/'${exp_name}'/contributions/masks'
save_dir=${result_path}'/'${exp_name}'/models'
#----------------------------------------
disa_layers='-1'
disa_labels='3 4'

python core/model_disassemble.py \
  --model_name ${model_name} \
  --num_classes ${num_classes} \
  --model_path ${model_path} \
  --mask_dir ${mask_dir} \
  --save_dir ${save_dir} \
  --disa_layers ${disa_layers} \
  --disa_labels ${disa_labels}
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
export CUDA_VISIBLE_DEVICES=0
result_path='/nfs3/hjc/projects/cnnlego/output'
#----------------------------------------
exp_name='lenet_cifar10_base'
#----------------------------------------
#model_name='vgg16'
#model_name='resnet50'
model_name='lenet'
#----------------------------------------
data_name='cifar10'
num_classes=10
#data_name='cifar100'
#num_classes=100
#----------------------------------------
#model_path=${result_path}'/'${exp_name}'/models/model_ori.pth'
model_path=${result_path}'/'${exp_name}'/models/model_disa.pth'
#----------------------------------------
data_dir='/nfs3-p1/hjc/datasets/'${data_name}'/test'
#----------------------------------------

python engines/test.py \
  --model_name ${model_name} \
  --data_name ${data_name} \
  --num_classes ${num_classes} \
  --model_path ${model_path} \
  --data_dir ${data_dir}
--------------------------------------------------------------------------------
/scripts/parameter_scaling.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
export CUDA_VISIBLE_DEVICES=0
result_path='/nfs3/hjc/projects/cnnlego/output'
#----------------------------------------
exp_name='lenet_cifar10_base'
#----------------------------------------
#model_name='vgg16'
#model_name='resnet50'
model_name='lenet'
#----------------------------------------
data_name='cifar10'
num_classes=10
#data_name='cifar100'
#num_classes=100
#----------------------------------------
#model_path=${result_path}'/'${exp_name}'/models/model_ori.pth'
model_path=${result_path}'/'${exp_name}'/models/model_disa.pth'
#----------------------------------------
data_dir='/nfs3-p1/hjc/datasets/'${data_name}'/test'
#----------------------------------------

python core/parameter_scaling.py \
  --model_name ${model_name} \
  --data_name ${data_name} \
  --num_classes ${num_classes} \
  --model_path ${model_path} \
  --data_dir ${data_dir}
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
export CUDA_VISIBLE_DEVICES=0
result_path='/nfs3/hjc/projects/cnnlego/output'
#----------------------------------------
exp_name='lenet_cifar10_base'
#----------------------------------------
#model_name='vgg16'
#model_name='resnet50'
model_name='lenet'
#----------------------------------------
data_name='cifar10'
num_classes=10
#data_name='cifar100'
#num_classes=100
#----------------------------------------
num_epochs=200
model_dir=${result_path}'/'${exp_name}'/models'
#----------------------------------------
data_train_dir='/nfs3-p1/hjc/datasets/'${data_name}'/train'
data_test_dir='/nfs3-p1/hjc/datasets/'${data_name}'/test'
#----------------------------------------
log_dir=${result_path}'/runs/'${exp_name}

python engines/train.py \
  --model_name ${model_name} \
  --data_name ${data_name} \
  --num_classes ${num_classes} \
  --num_epochs ${num_epochs} \
  --model_dir ${model_dir} \
  --data_train_dir ${data_train_dir} \
  --data_test_dir ${data_test_dir} \
  --log_dir ${log_dir}
--------------------------------------------------------------------------------
/models/simnet.py:
--------------------------------------------------------------------------------
from torch import nn
from collections import OrderedDict


class SimNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(SimNet, self).__init__()
        self.features = nn.Sequential(
            OrderedDict([
                ('c1', nn.Conv2d(in_channels, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
                ('relu1', nn.ReLU()),
                ('s1', nn.MaxPool2d(kernel_size=2, stride=2)),
                ('c2', nn.Conv2d(9, 27, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
                ('relu2', nn.ReLU()),
                ('s2', nn.MaxPool2d(kernel_size=2, stride=2)),
                ('c3', nn.Conv2d(27, 81, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
                ('relu3', nn.ReLU())
            ])
        )
        self.classifier = nn.Sequential(
            OrderedDict([
                # ('f4', nn.Linear(254016, num_classes))
                ('f4', nn.Linear(5184, 2048)),
                ('f5', nn.Linear(2048, num_classes))
            ])
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


def simnet(in_channels, num_classes):
    return SimNet(in_channels, num_classes)
--------------------------------------------------------------------------------
/utils/train_util.py:
--------------------------------------------------------------------------------
class AverageMeter(object):
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # fmtstr = '{name}[VAL:{val' + self.fmt + '} AVG:{avg' + self.fmt + '}]'
        fmtstr = '{name}[{avg' + self.fmt + '}]'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, total, step, prefix, meters):
        self._fmtstr = self._get_fmtstr(total)
        self.meters = meters
        self.prefix = prefix

        self.step = step

    def display(self, running):
        if (running + 1) % self.step == 0:
            entries = [self.prefix + self._fmtstr.format(running)]  # [prefix xx.xx/xx.xx]
            entries += [str(meter) for meter in self.meters]
            print(' '.join(entries))

    def _get_fmtstr(self, total):
        num_digits = len(str(total))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(total) + ']'  # [prefix xx.xx/xx.xx]
--------------------------------------------------------------------------------
/models/lenet.py:
--------------------------------------------------------------------------------
from torch import nn
import torch
import torch.nn.functional as F


class LeNet(nn.Module):
    """LeNet-like network for tests with MNIST (inputs resized to 32x32 in this repo)."""

    def __init__(self, in_channels=1, num_classes=10, **kwargs):
        super().__init__()
        # main part of the network
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)

        # last classifier layer (head) with as many outputs as classes
        self.fc = nn.Linear(84, num_classes)
        # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
        self.head_var = 'fc'

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc(out)
        return out


def lenet(in_channels=3, num_classes=10):
    return LeNet(in_channels=in_channels, num_classes=num_classes)


if __name__ == '__main__':
    model = LeNet(1, 10)
    y = model(torch.randn(1, 1, 32, 32))
    print(y)
--------------------------------------------------------------------------------
/models/alexnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=10):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 192, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(192, 384, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 2 * 2)
        x = self.classifier(x)
        return x


def alexnet(in_channels=3, num_classes=10):
    return AlexNet(in_channels=in_channels, num_classes=num_classes)
--------------------------------------------------------------------------------
/metrics/accuracy.py:
--------------------------------------------------------------------------------
import torch


def accuracy(outputs, labels, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = labels.size(0)

        _, pred = outputs.topk(maxk, 1, True, True)  # [batch_size, topk]
        pred = pred.t()  # [topk, batch_size]
        correct = pred.eq(labels.view(1, -1).expand_as(pred))  # [topk, batch_size]

        res = []
        for k in topk:
            correct_k = correct[:k].float().sum()
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class ClassAccuracy:
    def __init__(self):
        self.sum = {}
        self.count = {}

    def update(self, outputs, labels):
        _, pred = outputs.max(dim=1)
        correct = pred.eq(labels)

        for b, label in enumerate(labels):
            label = label.item()
            if label not in self.sum.keys():
                self.sum[label] = 0
                self.count[label] = 0
            self.sum[label] += correct[b].item()
            self.count[label] += 1

    def __call__(self):
        self.sum = dict(sorted(self.sum.items()))
        self.count = dict(sorted(self.count.items()))
        return [s / c * 100 for s, c in zip(self.sum.values(), self.count.values())]

    def __getitem__(self, item):
        return self.__call__()[item]

    def list(self):
        return self.__call__()

    def __str__(self):
        fmtstr = '{}:{:6.2f}'
        result = '\n'.join([fmtstr.format(l, a) for l, a in enumerate(self.__call__())])
        return result
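
# Illustrative usage (a sketch; shapes are assumed): for a batch of logits
# `outputs` with shape [N, num_classes] and integer `labels` of shape [N],
#
#   acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))  # percentages
#   class_acc = ClassAccuracy()
#   class_acc.update(outputs, labels)
#   print(class_acc)                                     # per-class accuracy, one line per class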
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
import torch
from models import simnet, alexnet, vgg, resnet, simplenetv1, googlenet, lenet


def load_model(model_name, in_channels=3, num_classes=10):
    print('-' * 50)
    print('LOAD MODEL:', model_name)
    print('NUM CLASSES:', num_classes)
    print('-' * 50)

    model = None  # stays None (and is returned as None) for unknown model names
    if model_name == 'simnet':
        model = simnet.simnet(in_channels, num_classes)
    if model_name == 'alexnet':
        model = alexnet.alexnet(in_channels, num_classes)
    if model_name == 'vgg16':
        model = vgg.vgg16_bn(in_channels, num_classes)
    if model_name == 'resnet50':
        model = resnet.resnet50(in_channels, num_classes)
    if model_name == 'simplenetv1':
        model = simplenetv1.simplenet(in_channels, num_classes)
    if model_name == 'googlenet':
        model = googlenet.googlenet(in_channels, num_classes)
    if model_name == 'lenet':
        model = lenet.lenet(in_channels, num_classes)

    return model


def load_modules(model, model_layers=None):
    assert model_layers is None or type(model_layers) is list

    modules = []
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            modules.append(module)
        if isinstance(module, torch.nn.Linear):
            modules.append(module)

    modules.reverse()  # reverse order: index 0 is the output layer
    if model_layers is None:
        model_modules = modules
    else:
        model_modules = []
        for layer in model_layers:
            model_modules.append(modules[layer])

    print('-' * 50)
    print('Model Layers:', model_layers)
    print('Model Modules:', model_modules)
    print('Model Modules Length:', len(model_modules))
    print('-' * 50)

    return model_modules
--------------------------------------------------------------------------------
/loaders/datasets/image_dataset.py:
--------------------------------------------------------------------------------
import os
import PIL.Image as Image
from torch.utils.data import Dataset


def _img_loader(path, mode='RGB'):
    assert mode in ['RGB', 'L']
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert(mode)


def _find_classes(root):
    class_names = [d.name for d in os.scandir(root) if d.is_dir()]
    class_names.sort()
    classes_indices = {class_names[i]: i for i in range(len(class_names))}
    # print(classes_indices)
    return class_names, classes_indices  # 'class_name':index


def _make_dataset(image_dir):
    samples = []  # image_path, class_idx

    class_names, class_indices = _find_classes(image_dir)

    for class_name in sorted(class_names):
        class_idx = class_indices[class_name]
        target_dir = os.path.join(image_dir, class_name)

        if not os.path.isdir(target_dir):
            continue

        for root, _, files in sorted(os.walk(target_dir)):
            for file in sorted(files):
                image_path = os.path.join(root, file)
                item = image_path, class_idx
                samples.append(item)
    return samples


class ImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.samples = _make_dataset(self.image_dir)
        self.targets = [s[1] for s in self.samples]

    def __getitem__(self, index):
        image_path, target = self.samples[index]
        image = _img_loader(image_path, mode='RGB')
        name = os.path.split(image_path)[1]

        if self.transform is not None:
            image = self.transform(image)

        return image, target, name

    def __len__(self):
        return len(self.samples)
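
# Illustrative usage (a minimal sketch; the directory layout is assumed to be
# <image_dir>/<class_name>/<image_file>, as _make_dataset expects):
#
#   from torchvision import transforms
#   dataset = ImageDataset('/path/to/cifar10/test', transform=transforms.ToTensor())
#   image, target, name = dataset[0]  # image tensor, class index, file name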
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
import torch.nn as nn

cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}


class VGG(nn.Module):

    def __init__(self, features, num_classes=10):
        super().__init__()
        self.features = features

        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        output = self.features(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)

        return output


def make_layers(cfg, in_channels=3, batch_norm=False):
    layers = []

    input_channel = in_channels
    for l in cfg:
        if l == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            continue

        layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]

        if batch_norm:
            layers += [nn.BatchNorm2d(l)]

        layers += [nn.ReLU(inplace=True)]
        input_channel = l

    return nn.Sequential(*layers)


def vgg11_bn(in_channels=3, num_classes=10):
    return VGG(make_layers(cfg['A'], batch_norm=True, in_channels=in_channels), num_classes=num_classes)


def vgg13_bn(in_channels=3, num_classes=10):
    return VGG(make_layers(cfg['B'], batch_norm=True, in_channels=in_channels), num_classes=num_classes)


def vgg16_bn(in_channels=3, num_classes=10):
    return VGG(make_layers(cfg['D'], batch_norm=True, in_channels=in_channels), num_classes=num_classes)


def vgg19_bn(in_channels=3, num_classes=10):
    return VGG(make_layers(cfg['E'], batch_norm=True, in_channels=in_channels), num_classes=num_classes)
--------------------------------------------------------------------------------
/core/model_assemble.py:
--------------------------------------------------------------------------------
import torch
import argparse


def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--model1_path', default='', type=str, help='model path')
    parser.add_argument('--model2_path', default='', type=str, help='model path')
    parser.add_argument('--asse_path', default='', type=str, help='asse path')
    args = parser.parse_args()

    model1 = torch.load(args.model1_path).cuda()
    model2 = torch.load(args.model2_path).cuda()

    # architecture
    print('=================> Architecture Assembling')
    layer = 0
    for module1, module2 in zip(model1.modules(), model2.modules()):
        if isinstance(module1, torch.nn.Conv2d):
            if layer == 0:  # first conv: the two models share the same input channels
                module1.out_channels += module2.out_channels
            else:
                module1.in_channels += module2.in_channels
                module1.out_channels += module2.out_channels
            layer += 1
        if isinstance(module1, torch.nn.Linear):
            module1.in_features += module2.in_features
            module1.out_features += module2.out_features
        if isinstance(module1, torch.nn.BatchNorm2d):
            module1.num_features += module2.num_features
            module1.running_mean.data = torch.cat([module1.running_mean.data, module2.running_mean.data], dim=0)
            module1.running_var.data = torch.cat([module1.running_var.data, module2.running_var.data], dim=0)
    print(model1)

    # parameter
    print('=================> Parameter Assembling')
    layer = 0
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        if len(p1.shape) > 2:  # conv weight [out, in, kh, kw]
            if layer == 0:
                p1.data = torch.cat([p1, p2], dim=0)
            else:
                # zero-pad the input channels so the two components stay independent
                p1b = torch.zeros(p1.shape[0], p2.shape[1], p1.shape[2], p1.shape[3]).cuda()
                p2b = torch.zeros(p2.shape[0], p1.shape[1], p2.shape[2], p2.shape[3]).cuda()
                p1.data = torch.cat([p1, p1b], dim=1)
                p2.data = torch.cat([p2b, p2], dim=1)
                p1.data = torch.cat([p1, p2], dim=0)
            layer += 1
        elif len(p1.shape) > 1:  # linear weight [out, in]: block-diagonal padding
            p1b = torch.zeros(p1.shape[0], p2.shape[1]).cuda()
            p2b = torch.zeros(p2.shape[0], p1.shape[1]).cuda()
            p1.data = torch.cat([p1, p1b], dim=1)
            p2.data = torch.cat([p2b, p2], dim=1)
            p1.data = torch.cat([p1, p2], dim=0)
        else:  # bias / batch-norm vector
            p1.data = torch.cat([p1, p2], dim=0)
        print('=', p1.shape)

    # save model
    torch.save(model1, args.asse_path)


if __name__ == '__main__':
    main()
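
# A minimal, self-contained sketch of the alignment-padding idea above, applied
# to two nn.Linear layers (illustration only; the feature sizes are made up):
#
#   import torch
#   from torch import nn
#
#   fc1, fc2 = nn.Linear(2, 3), nn.Linear(4, 5)
#   w1b = torch.zeros(fc1.out_features, fc2.in_features)
#   w2b = torch.zeros(fc2.out_features, fc1.in_features)
#   w = torch.cat([torch.cat([fc1.weight, w1b], dim=1),
#                  torch.cat([w2b, fc2.weight], dim=1)], dim=0)  # block-diagonal [8, 6]
#   b = torch.cat([fc1.bias, fc2.bias], dim=0)
#   assert w.shape == (8, 6) and b.shape == (8,)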
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
### Example user template

# IntelliJ project files
.idea
*.iml
out
gen
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Model LEGO: Creating Models Like Disassembling and Assembling Building Blocks

![framework](framework.jpg)

For more information, please visit https://model-lego.github.io/.

## Requirements

+ Python Version: 3.9
+ PyTorch Version: 2.0.1
+ GPU: NVIDIA RTX A6000 / NVIDIA A40

## Quick Start

### Prepare the Source Models

* Train a Source Model:

```bash
python engines/train.py \
  --model_name 'vgg16' \
  --data_name 'cifar10' \
  --num_classes 10 \
  --num_epochs 200 \
  --model_dir ${model_dir} \
  --data_train_dir ${data_train_dir} \
  --data_test_dir ${data_test_dir} \
  --log_dir ${log_dir}
```
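
The checkpoint saved under `${model_dir}` is what the later steps take via `--model_path`. As a minimal sketch (the checkpoint name `model_ori.pth` follows the provided scripts; adjust the path to your setup), a saved model can be restored the same way the engines do, from either a plain `state_dict` or a whole pickled module:

```python
import collections
import torch
import models

state = torch.load('output/lenet_cifar10_base/models/model_ori.pth')
if isinstance(state, collections.OrderedDict):  # saved as a state_dict
    model = models.load_model('lenet', num_classes=10)
    model.load_state_dict(state)
else:                                           # saved as a whole module
    model = state
model.eval()
```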

### Model Disassembling

![model disassembling](model_disassembling.jpg)

* Select the Top 1% of Samples with High Confidence:

```bash
python core/sample_select.py \
  --model_name 'vgg16' \
  --data_name 'cifar10' \
  --num_classes 10 \
  --model_path ${model_path} \
  --data_dir ${data_dir} \
  --save_dir ${save_dir} \
  --num_samples 50
```

* Relevant Feature Identifying (α and β can be configured in `core/relevant_feature_identifying.py`):

```bash
python core/relevant_feature_identifying.py \
  --model_name 'vgg16' \
  --data_name 'cifar10' \
  --num_classes 10 \
  --model_path ${model_path} \
  --data_dir ${data_dir} \
  --save_dir ${save_dir}
```

* Parameter Linking and Model Disassembling (outputs the disassembled task-aware component):

```bash
python core/model_disassemble.py \
  --model_name 'vgg16' \
  --num_classes 10 \
  --model_path ${model_path} \
  --mask_dir ${mask_dir} \
  --save_dir ${save_dir} \
  --disa_layers ${disa_layers} \
  --disa_labels ${disa_labels}
```
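
Under the hood, `model_disassemble.py` loads one boolean mask per layer (`mask_layer{i}.pt`, one row per class), ORs together the masks of the selected labels, and prunes every channel whose combined mask is 0. A minimal sketch of that selection step (the mask shape is assumed to be `[num_classes, num_channels]`, and the path follows the provided scripts):

```python
import torch

mask_i = torch.load('contributions/masks/mask_layer0.pt')  # [num_classes, num_channels]
disa_labels = [3, 4]

mask_total_i = None
for label in disa_labels:
    mask_total_i = mask_i[label] if mask_total_i is None else torch.bitwise_or(mask_i[label], mask_total_i)

idxs = torch.where(mask_total_i == 0)[0].tolist()  # channels irrelevant to labels 3 and 4 get pruned
```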

### Model Assembling

![model assembling](model_assembling.jpg)

* Parameter Scaling (optional):

```bash
python core/parameter_scaling.py \
  --model_name 'vgg16' \
  --data_name 'cifar10' \
  --num_classes 10 \
  --model_path ${model_path} \
  --data_dir ${data_dir}
```

* Alignment Padding and Model Assembling (outputs the assembled model):

```bash
python core/model_assemble.py \
  --model1_path ${model1_path} \
  --model2_path ${model2_path} \
  --asse_path ${asse_path}
```

### Others

* Evaluate the Accuracy of the Model or Task-aware Component:

```bash
python engines/test.py \
  --model_name 'vgg16' \
  --data_name 'cifar10' \
  --num_classes 10 \
  --model_path ${model_path} \
  --data_dir ${data_dir}
```

* Visualize the Model Decision Routes:

```bash
python core/model_decision_route_visualizing.py \
  --mask_dir ${mask_dir} \
  --layers ${layers} \
  --labels ${labels}
```
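
Note on output ordering: `model_assemble.py` concatenates the two output heads, so the assembled model's logits are `[model1 outputs, model2 outputs]` in that order. A quick sanity check (a sketch, not part of the repo; the path follows `scripts/model_assemble.sh`):

```python
import torch

model = torch.load('output/lenet_cifar10_base/models/model_asse.pth').cuda().eval()
x = torch.randn(1, 3, 32, 32).cuda()  # CIFAR-sized input
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # [1, n1 + n2]: model1's classes first, then model2's
```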
--------------------------------------------------------------------------------
/engines/test.py:
--------------------------------------------------------------------------------
import os
import argparse
import time

import torch
from torch import nn
import collections

import loaders
import models
import metrics
from utils.train_util import AverageMeter, ProgressMeter

from thop import profile


def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--model_name', default='', type=str, help='model name')
    parser.add_argument('--data_name', default='', type=str, help='data name')
    parser.add_argument('--num_classes', default='', type=int, help='num classes')
    parser.add_argument('--model_path', default='', type=str, help='model path')
    parser.add_argument('--data_dir', default='', type=str, help='data directory')
    args = parser.parse_args()

    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('-' * 50)
    print('TEST ON:', device)
    print('MODEL PATH:', args.model_path)
    print('DATA PATH:', args.data_dir)
    print('-' * 50)

    # ----------------------------------------
    # trainer configuration
    # ----------------------------------------
    state = torch.load(args.model_path)
    if isinstance(state, collections.OrderedDict):
        model = models.load_model(args.model_name, num_classes=args.num_classes)
        model.load_state_dict(state)
    else:
        model = state
    model.to(device)

    test_loader = loaders.load_data(args.data_dir, args.data_name, data_type='test')

    criterion = nn.CrossEntropyLoss()

    # ----------------------------------------
    # speed
    # ----------------------------------------
    speed(model, device)

    # ----------------------------------------
    # each epoch
    # ----------------------------------------
    # since = time.time()

    loss, acc1, acc5, class_acc = test(test_loader, model, criterion, device)

    print('-' * 50)
    print(class_acc)
    print('AVG:', acc1.avg)
    # print('TIME CONSUMED', time.time() - since)


def test(test_loader, model, criterion, device):
    loss_meter = AverageMeter('Loss', ':.4e')
    acc1_meter = AverageMeter('Acc@1', ':6.2f')
    acc5_meter = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(total=len(test_loader), step=20, prefix='Test',
                             meters=[loss_meter, acc1_meter, acc5_meter])
    class_acc = metrics.ClassAccuracy()
    model.eval()

    for i, samples in enumerate(test_loader):
        inputs, labels, _ = samples
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5))  # top-1 and top-5, matching the meters
            class_acc.update(outputs, labels)

            loss_meter.update(loss.item(), inputs.size(0))
            acc1_meter.update(acc1.item(), inputs.size(0))
            acc5_meter.update(acc5.item(), inputs.size(0))

        progress.display(i)

    return loss_meter, acc1_meter, acc5_meter, class_acc


def speed(model, device):
    # model.eval()

    flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).to(device),))
    print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/core/parameter_scaling.py:
--------------------------------------------------------------------------------
import collections
import os

import torch
import argparse
import numpy as np
from tqdm import tqdm

import loaders
import models


class ScoreStatistic:
    def __init__(self, num_classes):
        self.scores = [[] for i in range(num_classes)]
        self.nums = torch.zeros(num_classes, dtype=torch.long)

    def __call__(self, outputs, labels):
        scores, predicts = torch.max(outputs.detach(), dim=1)

        for i, label in enumerate(labels):
            if label == predicts[i]:
                self.scores[label].append(scores[i].detach().cpu().numpy())
                self.nums[label] += 1

    def display_score(self, save_path):
        max_num = self.nums.max()
        for i in range(len(self.scores)):
            if len(self.scores[i]) != max_num:
                self.scores[i] = self.scores[i] + [0 for _ in range(max_num - len(self.scores[i]))]
        scores = torch.from_numpy(np.asarray(self.scores))
        scores_class = torch.sum(scores, dim=1) / self.nums
        fc_ratio = self.nums / torch.sum(scores, dim=1)
        np.save(save_path, fc_ratio.numpy())

        print('AVG SCORE RATIO: ', scores_class)
        print('Reciprocal AVG SCORE RATIO: ', fc_ratio)
        print('PICTURE NUM: ', self.nums)
        return fc_ratio


def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--model_name', default='', type=str, help='model name')
    parser.add_argument('--data_name', default='', type=str, help='data name')
    parser.add_argument('--num_classes', default='', type=int, help='num classes')
    parser.add_argument('--model_path', default='', type=str, help='model path')
    parser.add_argument('--data_dir', default='', type=str, help='data dir')
    args = parser.parse_args()

    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('-' * 100)
    print('SCALE ON:', device)
    print('MODEL PATH:', args.model_path)
    print('DATA DIR:', args.data_dir)

    # ----------------------------------------
    # model/data configuration
    # ----------------------------------------
    state = torch.load(args.model_path)
    if isinstance(state, collections.OrderedDict):
        model = models.load_model(args.model_name, num_classes=args.num_classes)
        model.load_state_dict(state)
    else:
        model = state
    model.to(device)
    model.eval()

    data_loader = loaders.load_data(args.data_dir, args.data_name, data_type='test')

    score_statistic = ScoreStatistic(num_classes=args.num_classes)

    # ----------------------------------------
    # forward
    # ----------------------------------------
    for samples in tqdm(data_loader):
        inputs, labels, _ = samples
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)

        score_statistic(outputs=outputs, labels=labels)

    score_ratio = score_statistic.display_score(
        save_path=os.path.splitext(args.model_path)[0] + '.npy')

    # ----------------------------------------
    # parameter scaling
    # ----------------------------------------
    layer = 0
    last_layer = len(models.load_modules(model=model))
    for para in model.parameters():
        if len(para.shape) > 2:  # conv
            layer += 1
        elif len(para.shape) > 1:  # linear
            if layer == last_layer - 1:
                para.data = score_ratio.view(-1, 1).float().cuda() * para.data
            layer += 1
        else:  # bias
            if layer == last_layer:
                para.data = score_ratio.view(-1).float().cuda() * para.data

    scale_model_path = os.path.splitext(args.model_path)[0] + '_scale.pth'
    torch.save(model, scale_model_path)

    print('RESCALE MODEL PATH:', scale_model_path)
    print('-' * 50)


if __name__ == '__main__':
    main()
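
# A worked toy example of the statistic above (made-up numbers): if class 0
# keeps 3 correctly-predicted samples with logit scores [0.5, 0.6, 0.7], then
#   avg score ratio  = (0.5 + 0.6 + 0.7) / 3 = 0.6
#   reciprocal ratio = 3 / 1.8 ≈ 1.667
# and the last Linear layer's row (and bias entry) for class 0 is multiplied
# by ≈ 1.667, rescaling the disassembled component's under-confident outputs.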
--------------------------------------------------------------------------------
/core/model_disassemble.py:
--------------------------------------------------------------------------------
import os
import argparse

import torch
import torch_pruning as tp

import models


def disassemble():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--model_name', default='', type=str, help='model name')
    parser.add_argument('--num_classes', default='', type=int, help='num classes')
    parser.add_argument('--model_path', default='', type=str, help='model path')
    parser.add_argument('--save_dir', default='', type=str, help='save dir')
    parser.add_argument('--mask_dir', default='', type=str, help='mask dir')
    parser.add_argument('--disa_layers', default='', nargs='+', type=int, help='disa layers')
    parser.add_argument('--disa_labels', default='', nargs='+', type=int, help='disa labels')
    args = parser.parse_args()

    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    print('-' * 50)
    print('SAVE DIR:', args.save_dir)
    print('-' * 50)

    # ----------------------------------------
    # model configuration
    # ----------------------------------------
    model = models.load_model(args.model_name, num_classes=args.num_classes)
    model.load_state_dict(torch.load(args.model_path, map_location=torch.device('cpu')))
    # model = torch.load(args.model_path).cpu()

    modules = models.load_modules(model=model, model_layers=None)  # reversed: index 0 is the output layer

    # ----------------------------------------
    # disa configuration
    # ----------------------------------------
    mask_path = os.path.join(args.mask_dir, 'mask_layer{}.pt')

    if args.disa_layers[0] == -1:  # -1 expands to all layers except the output layer
        args.disa_layers = [i for i in range(len(modules) - 1)]

    print('disassembling layers:', args.disa_layers)
    print('disassembling labels:', args.disa_labels)

    # ----------------------------------------
    # model disassemble
    # ----------------------------------------
    # example_inputs assumes 3x32x32 (CIFAR-sized) images
    DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 32, 32))

    ###############################
    # layers 1-N: input channels
    ###############################
    for layer in args.disa_layers:
        print('===> LAYER', layer)
        print('--->', modules[layer])

        # idxs
        mask_total_i = None
        mask_i = torch.load(mask_path.format(layer))
        for label in args.disa_labels:
            if mask_total_i is None:
                mask_total_i = mask_i[label]
            else:
                mask_total_i = torch.bitwise_or(mask_i[label], mask_total_i)
        idxs = torch.where(mask_total_i == 0)[0].tolist()

        # structure pruning
        prune_fn = None
        if isinstance(modules[layer], torch.nn.Conv2d):
            prune_fn = tp.prune_conv_in_channels
        if isinstance(modules[layer], torch.nn.Linear):
            prune_fn = tp.prune_linear_in_channels
        group = DG.get_pruning_group(modules[layer], prune_fn, idxs=idxs)
        if DG.check_pruning_group(group):
            group.prune()
        print('--->', modules[layer])

    ###############################
    # layer N: output channels
    ###############################
    # layer = 0
    # print('--->', modules[layer])
    #
    # # idxs
    # mask_i = torch.load(mask_path.format(-1))
    # mask_total_i = None
    # for label in args.disa_labels:
    #     if mask_total_i is None:
    #         mask_total_i = mask_i[label]
    #     else:
    #         mask_total_i = torch.bitwise_or(mask_i[label], mask_total_i)
    # idxs = np.where(mask_total_i == 0)[0].tolist()
    #
    # # structure pruning
    # prune_fn = tp.prune_linear_out_channels
    # group = DG.get_pruning_group(modules[layer], prune_fn, idxs=idxs)
    # if DG.check_pruning_group(group):
    #     group.prune()
    # print('--->', modules[layer])

    ###############################
    # save model
    ###############################
    model.zero_grad()
    result_path = os.path.join(args.save_dir, 'model_disa.pth')
    torch.save(model, result_path)
    print(model)


if __name__ == '__main__':
    disassemble()
--------------------------------------------------------------------------------
/loaders/image_loader.py:
--------------------------------------------------------------------------------
from torch.utils.data import DataLoader
from torchvision import transforms
from loaders.datasets import ImageDataset

mnist_train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5)),
])

mnist_test_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5)),
])

cifar10_train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize((32, 32)),
    # transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465),
    #                      (0.2023, 0.1994, 0.2010)),
    transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                         (0.24703233, 0.24348505, 0.26158768)),
])

cifar10_test_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    # transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465),
    #                      (0.2023, 0.1994, 0.2010)),
    transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                         (0.24703233, 0.24348505, 0.26158768)),
])

tiny_imagenet_train_transform = transforms.Compose([
    transforms.RandomResizedCrop((64, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4802, 0.4481, 0.3975),
                         (0.2770, 0.2691, 0.2821))
])

tiny_imagenet_test_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.4802, 0.4481, 0.3975),
                         (0.2770, 0.2691, 0.2821))
])

imagenet_train_transform = transforms.Compose([
    transforms.RandomResizedCrop((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))
])

imagenet_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))
])


def _get_set(data_path, transform):
    return ImageDataset(image_dir=data_path,
                        transform=transform)


def load_images(data_dir, data_name, data_type=None):
    assert data_name in ['mnist', 'fashion-mnist', 'cifar10', 'cifar100', 'tiny-imagenet', 'imagenet']
    assert data_type is None or data_type in ['train', 'test']

    data_transform = None
    if data_name == 'mnist' and data_type == 'train':
        data_transform = mnist_train_transform
    elif data_name == 'mnist' and data_type == 'test':
        data_transform = mnist_test_transform
    elif data_name == 'fashion-mnist' and data_type == 'train':  # reuse the MNIST transforms
        data_transform = mnist_train_transform
    elif data_name == 'fashion-mnist' and data_type == 'test':
        data_transform = mnist_test_transform
    elif data_name == 'cifar10' and data_type == 'train':
        data_transform = cifar10_train_transform
    elif data_name == 'cifar10' and data_type == 'test':
        data_transform = cifar10_test_transform
    elif data_name == 'cifar100' and data_type == 'train':
        data_transform = cifar10_train_transform
    elif data_name == 'cifar100' and data_type == 'test':
        data_transform = cifar10_test_transform
    elif data_name == 'tiny-imagenet' and data_type == 'train':
        data_transform = tiny_imagenet_train_transform
    elif data_name == 'tiny-imagenet' and data_type == 'test':
        data_transform = tiny_imagenet_test_transform
    elif data_name == 'imagenet' and data_type == 'train':
        data_transform = imagenet_train_transform
    elif data_name == 'imagenet' and data_type == 'test':
        data_transform = imagenet_test_transform
    assert data_transform is not None

    data_set = _get_set(data_dir, transform=data_transform)
    data_loader = DataLoader(dataset=data_set,
                             batch_size=256,
                             num_workers=4,
                             shuffle=True)
    # ImageNet+VGG16: bs128->gpu26311->40days
    return data_loader
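
# Example (a sketch; path assumed): build a CIFAR-10 test loader whose
# batches yield (images [256, 3, 32, 32], targets [256], file names):
#
#   loader = load_images('/path/to/cifar10/test', 'cifar10', data_type='test')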
--------------------------------------------------------------------------------
/models/simplenetv1.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class SimpleNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=10):
        super(SimpleNet, self).__init__()
        # print(simpnet_name)
        self.features = self._make_layers(in_channels)  # self._make_layers(cfg[simpnet_name])
        self.classifier = nn.Linear(256, num_classes)
        self.drp = nn.Dropout(0.1)

    def forward(self, x):
        out = self.features(x)

        # Global Max Pooling
        out = F.max_pool2d(out, kernel_size=out.size()[2:])
        # out = F.dropout2d(out, 0.1, training=True)
        out = self.drp(out)

        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, in_channels):

        model = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(64, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
            nn.Dropout2d(p=0.1),

            nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
            nn.Dropout2d(p=0.1),

            nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
            nn.Dropout2d(p=0.1),

            nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
            nn.Dropout2d(p=0.1),

            nn.Conv2d(512, 2048, kernel_size=[1, 1], stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(2048, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.Conv2d(2048, 256, kernel_size=[1, 1], stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
            nn.Dropout2d(p=0.1),

            nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True),
            nn.ReLU(inplace=True),

        )

        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))

        return model


def simplenet(in_channels=3, num_classes=10):
    return SimpleNet(in_channels=in_channels, num_classes=num_classes)
--------------------------------------------------------------------------------
/models/googlenet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class Inception(nn.Module):
    def __init__(self, input_channels, n1x1, n3x3_reduce, n3x3, n5x5_reduce, n5x5, pool_proj):
        super().__init__()

        # 1x1conv branch
        self.b1 = nn.Sequential(
            nn.Conv2d(input_channels, n1x1, kernel_size=1),
            nn.BatchNorm2d(n1x1),
            nn.ReLU(inplace=True)
        )

        # 1x1conv -> 3x3conv branch
        self.b2 = nn.Sequential(
            nn.Conv2d(input_channels, n3x3_reduce, kernel_size=1),
            nn.BatchNorm2d(n3x3_reduce),
            nn.ReLU(inplace=True),
            nn.Conv2d(n3x3_reduce, n3x3, kernel_size=3, padding=1),
            nn.BatchNorm2d(n3x3),
            nn.ReLU(inplace=True)
        )

        # 1x1conv -> 5x5conv branch
        # we use 2 3x3 conv filters stacked instead
        # of 1 5x5 filters to obtain the same receptive
        # field with fewer parameters
        self.b3 = nn.Sequential(
            nn.Conv2d(input_channels, n5x5_reduce, kernel_size=1),
            nn.BatchNorm2d(n5x5_reduce),
            nn.ReLU(inplace=True),
            nn.Conv2d(n5x5_reduce, n5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(n5x5),
            nn.ReLU(inplace=True),
            nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(n5x5),
            nn.ReLU(inplace=True)
        )

        # 3x3pooling -> 1x1conv
        # same conv
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(input_channels, pool_proj, kernel_size=1),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return torch.cat([self.b1(x), self.b2(x), self.b3(x), self.b4(x)], dim=1)


class GoogleNet(nn.Module):

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        self.prelayer = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 192, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
        )

        # although we only use 1 conv layer as prelayer,
        # we still use the names a3, b3, ...
        self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

        # """In general, an Inception network is a network consisting of
        # modules of the above type stacked upon each other, with occasional
        # max-pooling layers with stride 2 to halve the resolution of the
        # grid"""
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
        self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
        self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
        self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

        # input feature size: 8*8*1024
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout2d(p=0.4)
        self.linear = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.prelayer(x)
        x = self.maxpool(x)
        x = self.a3(x)
        x = self.b3(x)

        x = self.maxpool(x)

        x = self.a4(x)
        x = self.b4(x)
        x = self.c4(x)
        x = self.d4(x)
        x = self.e4(x)

        x = self.maxpool(x)

        x = self.a5(x)
        x = self.b5(x)

        # """It was found that a move from fully connected layers to
        # average pooling improved the top-1 accuracy by about 0.6%,
        # however the use of dropout remained essential even after
        # removing the fully connected layers."""
        x = self.avgpool(x)
        x = self.dropout(x)
        x = x.view(x.size()[0], -1)
        x = self.linear(x)

        return x


def googlenet(in_channels=3, num_classes=10):
    return GoogleNet(in_channels=in_channels, num_classes=num_classes)
73 |         self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
74 |         self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
75 | 
76 |         # """In general, an Inception network is a network consisting of
77 |         # modules of the above type stacked upon each other, with occasional
78 |         # max-pooling layers with stride 2 to halve the resolution of the
79 |         # grid"""
80 |         self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
81 | 
82 |         self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
83 |         self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
84 |         self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
85 |         self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
86 |         self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
87 | 
88 |         self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
89 |         self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
90 | 
91 |         # feature size before pooling is 4*4*1024 for 32x32 inputs; adaptive pooling handles any size
92 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
93 |         self.dropout = nn.Dropout2d(p=0.4)
94 |         self.linear = nn.Linear(1024, num_classes)
95 | 
96 |     def forward(self, x):
97 |         x = self.prelayer(x)
98 |         x = self.maxpool(x)
99 |         x = self.a3(x)
100 |         x = self.b3(x)
101 | 
102 |         x = self.maxpool(x)
103 | 
104 |         x = self.a4(x)
105 |         x = self.b4(x)
106 |         x = self.c4(x)
107 |         x = self.d4(x)
108 |         x = self.e4(x)
109 | 
110 |         x = self.maxpool(x)
111 | 
112 |         x = self.a5(x)
113 |         x = self.b5(x)
114 | 
115 |         # """It was found that a move from fully connected layers to
116 |         # average pooling improved the top-1 accuracy by about 0.6%,
117 |         # however the use of dropout remained essential even after
118 |         # removing the fully connected layers."""
119 |         x = self.avgpool(x)
120 |         x = self.dropout(x)
121 |         x = x.view(x.size()[0], -1)
122 |         x = self.linear(x)
123 | 
124 |         return x
125 | 
126 | 
127 | def googlenet(in_channels=3, num_classes=10):
128 |     return GoogleNet(in_channels=in_channels, num_classes=num_classes)
129 | 
--------------------------------------------------------------------------------
/core/sample_select.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from tqdm import tqdm
5 | 
6 | import loaders
7 | import models
8 | from utils import file_util
9 | 
10 | 
11 | class SampleSift:
12 |     def __init__(self, num_classes, num_samples, is_high_confidence=True):
13 |         self.names = [[None for j in range(num_samples)] for i in range(num_classes)]
14 |         self.scores = torch.zeros((num_classes, num_samples))
15 |         self.nums = torch.zeros(num_classes, dtype=torch.long)
16 |         self.num_classes = num_classes
17 |         self.num_samples = num_samples
18 |         self.is_high_confidence = is_high_confidence
19 | 
20 |     def __call__(self, outputs, labels, names):
21 |         softmaxs = torch.nn.Softmax(dim=1)(outputs.detach())
22 |         # print(softmaxs)
23 | 
24 |         for i, label in enumerate(labels):  # each sample in the batch
25 |             score = softmaxs[i][label]
26 | 
27 |             if self.is_high_confidence:  # sift out the highest-confidence samples
28 |                 if self.nums[label] == self.num_samples:
29 |                     score_min, index = torch.min(self.scores[label], dim=0)
30 |                     if score > score_min:
31 |                         self.names[label][index] = names[i]
32 |                         self.scores[label][index] = score
33 |                 else:
34 |                     self.names[label][self.nums[label]] = names[i]
35 |                     self.scores[label][self.nums[label]] = score
36 |                     self.nums[label] += 1
37 |             else:  # sift out the lowest-confidence samples
38 |                 if self.nums[label] == self.num_samples:
39 |                     score_max, index = torch.max(self.scores[label], dim=0)
40 |                     if score < score_max:
41 |                         self.names[label][index] = names[i]
42 |                         self.scores[label][index] = score
43 |                 else:
44 | 
self.names[label][self.nums[label]] = names[i] 45 | self.scores[label][self.nums[label]] = score 46 | self.nums[label] += 1 47 | 48 | def save_image(self, input_path, output_path): 49 | print(self.scores) 50 | print(self.nums) 51 | 52 | class_names = sorted([d.name for d in os.scandir(input_path) if d.is_dir()]) 53 | print(class_names) 54 | 55 | for label, image_list in enumerate(self.names): 56 | for image in tqdm(image_list): 57 | class_name = class_names[label] 58 | 59 | src_path = os.path.join(input_path, class_name, str(image)) 60 | dst_path = os.path.join(output_path, class_name, str(image)) 61 | file_util.copy_file(src_path, dst_path) 62 | 63 | 64 | def main(): 65 | parser = argparse.ArgumentParser(description='') 66 | parser.add_argument('--model_name', default='', type=str, help='model name') 67 | parser.add_argument('--data_name', default='', type=str, help='data name') 68 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 69 | parser.add_argument('--model_path', default='', type=str, help='model path') 70 | parser.add_argument('--data_dir', default='', type=str, help='data dir') 71 | parser.add_argument('--save_dir', default='', type=str, help='sift dir') 72 | parser.add_argument('--num_samples', default=10, type=int, help='num samples') 73 | args = parser.parse_args() 74 | 75 | # ---------------------------------------- 76 | # basic configuration 77 | # ---------------------------------------- 78 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 79 | 80 | if not os.path.exists(args.save_dir): 81 | os.makedirs(args.save_dir) 82 | 83 | print('-' * 50) 84 | print('TRAIN ON:', device) 85 | print('MODEL PATH:', args.model_path) 86 | print('DATA PATH:', args.data_dir) 87 | print('RESULT PATH:', args.save_dir) 88 | print('-' * 50) 89 | 90 | # ---------------------------------------- 91 | # model/data configuration 92 | # ---------------------------------------- 93 | model = models.load_model(model_name=args.model_name, num_classes=args.num_classes) 94 | model.load_state_dict(torch.load(args.model_path)) 95 | # model = torch.load(args.model_path) 96 | model.to(device) 97 | model.eval() 98 | 99 | data_loader = loaders.load_data(data_dir=args.data_dir, data_name=args.data_name, data_type='test') 100 | 101 | sample_sift = SampleSift(num_classes=args.num_classes, num_samples=args.num_samples, is_high_confidence=True) 102 | 103 | # ---------------------------------------- 104 | # forward 105 | # ---------------------------------------- 106 | for samples in tqdm(data_loader): 107 | inputs, labels, names = samples 108 | inputs = inputs.to(device) 109 | labels = labels.to(device) 110 | with torch.no_grad(): 111 | outputs = model(inputs) 112 | sample_sift(outputs=outputs, labels=labels, names=names) 113 | 114 | sample_sift.save_image(args.data_dir, args.save_dir) 115 | 116 | 117 | if __name__ == '__main__': 118 | main() 119 | -------------------------------------------------------------------------------- /engines/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | import time 5 | from tqdm import tqdm 6 | 7 | import torch 8 | from torch import nn 9 | from torch import optim 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | import loaders 13 | import models 14 | import metrics 15 | from utils.train_util import AverageMeter, ProgressMeter 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser(description='') 20 | 
parser.add_argument('--model_name', default='', type=str, help='model name') 21 | parser.add_argument('--data_name', default='', type=str, help='data name') 22 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 23 | parser.add_argument('--num_epochs', default=200, type=int, help='num epochs') 24 | parser.add_argument('--model_dir', default='', type=str, help='model dir') 25 | parser.add_argument('--data_train_dir', default='', type=str, help='data dir') 26 | parser.add_argument('--data_test_dir', default='', type=str, help='data dir') 27 | parser.add_argument('--log_dir', default='', type=str, help='log dir') 28 | args = parser.parse_args() 29 | 30 | # ---------------------------------------- 31 | # basic configuration 32 | # ---------------------------------------- 33 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 34 | 35 | if not os.path.exists(args.model_dir): 36 | os.makedirs(args.model_dir) 37 | if os.path.exists(args.log_dir): 38 | shutil.rmtree(args.log_dir) 39 | 40 | print('-' * 50) 41 | print('TRAIN ON:', device) 42 | print('MODEL DIR:', args.model_dir) 43 | # print('LOG DIR:', args.log_dir) 44 | print('-' * 50) 45 | 46 | # ---------------------------------------- 47 | # trainer configuration 48 | # ---------------------------------------- 49 | model = models.load_model(args.model_name, num_classes=args.num_classes) 50 | model.to(device) 51 | 52 | train_loader = loaders.load_data(args.data_train_dir, args.data_name, data_type='train') 53 | test_loader = loaders.load_data(args.data_test_dir, args.data_name, data_type='test') 54 | 55 | criterion = nn.CrossEntropyLoss() 56 | optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) 57 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.num_epochs) 58 | 59 | writer = SummaryWriter(args.log_dir) 60 | 61 | # ---------------------------------------- 62 | # each epoch 63 | # ---------------------------------------- 64 | since = time.time() 65 | 66 | best_acc = None 67 | best_epoch = None 68 | 69 | for epoch in tqdm(range(args.num_epochs)): 70 | print('\n') 71 | loss, acc1, acc5 = train(train_loader, model, criterion, optimizer, device) 72 | writer.add_scalar(tag='training loss', scalar_value=loss.avg, global_step=epoch) 73 | writer.add_scalar(tag='training acc1', scalar_value=acc1.avg, global_step=epoch) 74 | loss, acc1, acc5 = test(test_loader, model, criterion, device) 75 | writer.add_scalar(tag='test loss', scalar_value=loss.avg, global_step=epoch) 76 | writer.add_scalar(tag='test acc1', scalar_value=acc1.avg, global_step=epoch) 77 | 78 | # ---------------------------------------- 79 | # save best model 80 | # ---------------------------------------- 81 | if best_acc is None or best_acc < acc1.avg: 82 | best_acc = acc1.avg 83 | best_epoch = epoch 84 | torch.save(model.state_dict(), os.path.join(args.model_dir, 'model_ori.pth')) 85 | 86 | scheduler.step() 87 | 88 | print('BEST ACC', best_acc) 89 | print('BEST EPOCH', best_epoch) 90 | print('TIME CONSUMED', time.time() - since) 91 | print('MODEL DIR', args.model_dir) 92 | 93 | 94 | def train(train_loader, model, criterion, optimizer, device): 95 | loss_meter = AverageMeter('Loss', ':.4e') 96 | acc1_meter = AverageMeter('Acc@1', ':6.2f') 97 | acc5_meter = AverageMeter('Acc@5', ':6.2f') 98 | progress = ProgressMeter(total=len(train_loader), step=20, prefix='Training', 99 | meters=[loss_meter, acc1_meter, acc5_meter]) 100 | 101 | model.train() 102 | 103 | for i, samples in 
enumerate(train_loader):
104 |         inputs, labels, _ = samples
105 |         inputs = inputs.to(device)
106 |         labels = labels.to(device)
107 | 
108 |         outputs = model(inputs)
109 |         loss = criterion(outputs, labels)
110 |         acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5))
111 | 
112 |         loss_meter.update(loss.item(), inputs.size(0))
113 |         acc1_meter.update(acc1.item(), inputs.size(0))
114 |         acc5_meter.update(acc5.item(), inputs.size(0))
115 | 
116 |         optimizer.zero_grad()  # 1. clear the gradients from the previous step
117 |         loss.backward()        # 2. backpropagate the current loss
118 |         optimizer.step()       # 3. update the parameters
119 | 
120 |         progress.display(i)
121 | 
122 |     return loss_meter, acc1_meter, acc5_meter
123 | 
124 | 
125 | def test(test_loader, model, criterion, device):
126 |     loss_meter = AverageMeter('Loss', ':.4e')
127 |     acc1_meter = AverageMeter('Acc@1', ':6.2f')
128 |     acc5_meter = AverageMeter('Acc@5', ':6.2f')
129 |     progress = ProgressMeter(total=len(test_loader), step=20, prefix='Test',
130 |                              meters=[loss_meter, acc1_meter, acc5_meter])
131 |     model.eval()
132 | 
133 |     for i, samples in enumerate(test_loader):
134 |         inputs, labels, _ = samples
135 |         inputs = inputs.to(device)
136 |         labels = labels.to(device)
137 | 
138 |         with torch.no_grad():
139 |             outputs = model(inputs)
140 |             loss = criterion(outputs, labels)
141 |             acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5))
142 | 
143 |             loss_meter.update(loss.item(), inputs.size(0))
144 |             acc1_meter.update(acc1.item(), inputs.size(0))
145 |             acc5_meter.update(acc5.item(), inputs.size(0))
146 | 
147 |             progress.display(i)
148 | 
149 |     return loss_meter, acc1_meter, acc5_meter
150 | 
151 | 
152 | if __name__ == '__main__':
153 |     main()
154 | 
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class BasicBlock(nn.Module):
5 |     """Basic block for ResNet-18 and ResNet-34
6 |     """
7 | 
8 |     # BasicBlock and BottleNeck produce
9 |     # different output sizes, so we use the
10 |     # class attribute expansion
11 |     # to distinguish them
12 |     expansion = 1
13 | 
14 |     def __init__(self, in_channels, out_channels, stride=1):
15 |         super().__init__()
16 | 
17 |         # residual function
18 |         self.residual_function = nn.Sequential(
19 |             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
20 |             nn.BatchNorm2d(out_channels),
21 |             nn.ReLU(inplace=True),
22 |             nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
23 |             nn.BatchNorm2d(out_channels * BasicBlock.expansion)
24 |         )
25 | 
26 |         # shortcut
27 |         self.shortcut = nn.Sequential()
28 | 
29 |         # if the shortcut's output dimension does not match the residual function's,
30 |         # use a 1x1 convolution to match the dimensions
31 |         if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
32 |             self.shortcut = nn.Sequential(
33 |                 nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
34 |                 nn.BatchNorm2d(out_channels * BasicBlock.expansion)
35 |             )
36 | 
37 |     def forward(self, x):
38 |         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
39 | 
40 | 
41 | class BottleNeck(nn.Module):
42 |     """Residual block for ResNets with over 50 layers
43 |     """
44 |     expansion = 4
45 | 
46 |     def __init__(self, in_channels, out_channels, stride=1):
47 |         super().__init__()
48 |         self.residual_function = nn.Sequential(
49 |             nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
50 |             nn.BatchNorm2d(out_channels),
51 |             nn.ReLU(inplace=True),
52 |             nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
53 |             nn.BatchNorm2d(out_channels),
54 |             nn.ReLU(inplace=True),
55 |             nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
56 |             nn.BatchNorm2d(out_channels * BottleNeck.expansion),
57 |         )
58 | 
59 |         self.shortcut = nn.Sequential()
60 | 
61 |         if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
62 |             self.shortcut = nn.Sequential(
63 |                 nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
64 |                 nn.BatchNorm2d(out_channels * BottleNeck.expansion)
65 |             )
66 | 
67 |     def forward(self, x):
68 |         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
69 | 
70 | 
71 | class ResNet(nn.Module):
72 | 
73 |     def __init__(self, block, num_block, in_channels=3, num_classes=10):
74 |         super().__init__()
75 | 
76 |         self.in_channels = 64
77 | 
78 |         self.conv1 = nn.Sequential(
79 |             nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False),
80 |             nn.BatchNorm2d(64),
81 |             nn.ReLU(inplace=True))
82 |         # we use a smaller input size than the original paper,
83 |         # so conv2_x's stride is 1
84 |         self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
85 |         self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
86 |         self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
87 |         self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
88 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
89 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
90 | 
91 |     def _make_layer(self, block, out_channels, num_blocks, stride):
92 |         """make a resnet stage ('layer' here does not mean a single network
93 |         layer such as a conv layer); one stage may
94 |         contain more than one residual block
95 |         Args:
96 |             block: block type, basic block or bottle neck block
97 |             out_channels: output depth channel number of this stage
98 |             num_blocks: how many blocks per stage
99 |             stride: the stride of the first block of this stage
100 |         Return:
101 |             return a resnet stage
102 |         """
103 | 
104 |         # we have num_blocks blocks per stage; the first block's stride
105 |         # can be 1 or 2, all remaining blocks use stride 1
106 |         strides = [stride] + [1] * (num_blocks - 1)
107 |         layers = []
108 |         for stride in strides:
109 |             layers.append(block(self.in_channels, out_channels, stride))
110 |             self.in_channels = out_channels * block.expansion
111 | 
112 |         return nn.Sequential(*layers)
113 | 
114 |     def forward(self, x):
115 |         output = self.conv1(x)
116 |         output = self.conv2_x(output)
117 |         output = self.conv3_x(output)
118 |         output = self.conv4_x(output)
119 |         output = self.conv5_x(output)
120 |         output = self.avg_pool(output)
121 |         output = output.view(output.size(0), -1)
122 |         output = self.fc(output)
123 | 
124 |         return output
125 | 
126 | 
127 | def resnet18(in_channels=3, num_classes=10):
128 |     """ return a ResNet 18 object
129 |     """
130 |     return ResNet(BasicBlock, [2, 2, 2, 2], in_channels=in_channels, num_classes=num_classes)
131 | 
132 | 
133 | def resnet34(in_channels=3, num_classes=10):
134 |     """ return a ResNet 34 object
135 |     """
136 |     return ResNet(BasicBlock, [3, 4, 6, 3], in_channels=in_channels, num_classes=num_classes)
137 | 
138 | 
139 | def resnet50(in_channels=3, num_classes=10):
140 |     """ return a ResNet 50 object
141 |     """
142 |     return ResNet(BottleNeck, [3, 4, 6, 3], in_channels=in_channels, num_classes=num_classes)
143 | 
144 | 
145 | def resnet101(in_channels=3, num_classes=10):
146 |     """ return a ResNet 101 object
147 |     """
148 |     return ResNet(BottleNeck, [3, 4, 23, 3], in_channels=in_channels, num_classes=num_classes)
149 | 
150 | 
151 | def resnet152(in_channels=3, num_classes=10):
152 |     """ return a ResNet 152 object
153 |     """
154 |     return ResNet(BottleNeck, [3, 8, 36, 3], in_channels=in_channels, num_classes=num_classes)
155 | 
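A quick shape check for the BottleNeck bookkeeping above (an added sketch, not a file from this repo): with expansion = 4, the final stage emits 512 * 4 = 2048 channels, which is exactly what nn.Linear(512 * block.expansion, num_classes) consumes. Assuming the repo root is on PYTHONPATH, as the shell scripts arrange:

    import torch
    from models.resnet import resnet50

    model = resnet50()                      # defaults: in_channels=3, num_classes=10
    out = model(torch.randn(2, 3, 32, 32))  # 32x32 halves three times: 32 -> 16 -> 8 -> 4
    print(out.shape)                        # torch.Size([2, 10])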
--------------------------------------------------------------------------------
/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | """mobilenet in pytorch
2 | 
3 | 
4 | 
5 | [1] Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
6 | 
7 |     MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications
8 |     https://arxiv.org/abs/1704.04861
9 | """
10 | 
11 | import torch
12 | import torch.nn as nn
13 | 
14 | 
15 | class DepthSeperabelConv2d(nn.Module):
16 | 
17 |     def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
18 |         super().__init__()
19 |         self.depthwise = nn.Sequential(
20 |             nn.Conv2d(
21 |                 input_channels,
22 |                 input_channels,
23 |                 kernel_size,
24 |                 groups=input_channels,
25 |                 **kwargs),
26 |             nn.BatchNorm2d(input_channels),
27 |             nn.ReLU(inplace=True)
28 |         )
29 | 
30 |         self.pointwise = nn.Sequential(
31 |             nn.Conv2d(input_channels, output_channels, 1),
32 |             nn.BatchNorm2d(output_channels),
33 |             nn.ReLU(inplace=True)
34 |         )
35 | 
36 |     def forward(self, x):
37 |         x = self.depthwise(x)
38 |         x = self.pointwise(x)
39 | 
40 |         return x
41 | 
42 | 
43 | class BasicConv2d(nn.Module):
44 | 
45 |     def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
46 |         super().__init__()
47 |         self.conv = nn.Conv2d(
48 |             input_channels, output_channels, kernel_size, **kwargs)
49 |         self.bn = nn.BatchNorm2d(output_channels)
50 |         self.relu = nn.ReLU(inplace=True)
51 | 
52 |     def forward(self, x):
53 |         x = self.conv(x)
54 |         x = self.bn(x)
55 |         x = self.relu(x)
56 | 
57 |         return x
58 | 
59 | 
60 | class MobileNet(nn.Module):
61 |     """
62 |     Args:
63 |         width multiplier: The role of the width multiplier α is to thin
64 |                           a network uniformly at each layer. For a given
65 |                           layer and width multiplier α, the number of
66 |                           input channels M becomes αM and the number of
67 |                           output channels N becomes αN.
68 | """ 69 | 70 | def __init__(self, width_multiplier=1, class_num=100): 71 | super().__init__() 72 | 73 | alpha = width_multiplier 74 | self.stem = nn.Sequential( 75 | BasicConv2d(3, int(32 * alpha), 3, padding=1, bias=False), 76 | DepthSeperabelConv2d( 77 | int(32 * alpha), 78 | int(64 * alpha), 79 | 3, 80 | padding=1, 81 | bias=False 82 | ) 83 | ) 84 | 85 | # downsample 86 | self.conv1 = nn.Sequential( 87 | DepthSeperabelConv2d( 88 | int(64 * alpha), 89 | int(128 * alpha), 90 | 3, 91 | stride=2, 92 | padding=1, 93 | bias=False 94 | ), 95 | DepthSeperabelConv2d( 96 | int(128 * alpha), 97 | int(128 * alpha), 98 | 3, 99 | padding=1, 100 | bias=False 101 | ) 102 | ) 103 | 104 | # downsample 105 | self.conv2 = nn.Sequential( 106 | DepthSeperabelConv2d( 107 | int(128 * alpha), 108 | int(256 * alpha), 109 | 3, 110 | stride=2, 111 | padding=1, 112 | bias=False 113 | ), 114 | DepthSeperabelConv2d( 115 | int(256 * alpha), 116 | int(256 * alpha), 117 | 3, 118 | padding=1, 119 | bias=False 120 | ) 121 | ) 122 | 123 | # downsample 124 | self.conv3 = nn.Sequential( 125 | DepthSeperabelConv2d( 126 | int(256 * alpha), 127 | int(512 * alpha), 128 | 3, 129 | stride=2, 130 | padding=1, 131 | bias=False 132 | ), 133 | 134 | DepthSeperabelConv2d( 135 | int(512 * alpha), 136 | int(512 * alpha), 137 | 3, 138 | padding=1, 139 | bias=False 140 | ), 141 | DepthSeperabelConv2d( 142 | int(512 * alpha), 143 | int(512 * alpha), 144 | 3, 145 | padding=1, 146 | bias=False 147 | ), 148 | DepthSeperabelConv2d( 149 | int(512 * alpha), 150 | int(512 * alpha), 151 | 3, 152 | padding=1, 153 | bias=False 154 | ), 155 | DepthSeperabelConv2d( 156 | int(512 * alpha), 157 | int(512 * alpha), 158 | 3, 159 | padding=1, 160 | bias=False 161 | ), 162 | DepthSeperabelConv2d( 163 | int(512 * alpha), 164 | int(512 * alpha), 165 | 3, 166 | padding=1, 167 | bias=False 168 | ) 169 | ) 170 | 171 | # downsample 172 | self.conv4 = nn.Sequential( 173 | DepthSeperabelConv2d( 174 | int(512 * alpha), 175 | int(1024 * alpha), 176 | 3, 177 | stride=2, 178 | padding=1, 179 | bias=False 180 | ), 181 | DepthSeperabelConv2d( 182 | int(1024 * alpha), 183 | int(1024 * alpha), 184 | 3, 185 | padding=1, 186 | bias=False 187 | ) 188 | ) 189 | 190 | self.fc = nn.Linear(int(1024 * alpha), class_num) 191 | self.avg = nn.AdaptiveAvgPool2d(1) 192 | 193 | def forward(self, x): 194 | x = self.stem(x) 195 | 196 | x = self.conv1(x) 197 | x = self.conv2(x) 198 | x = self.conv3(x) 199 | x = self.conv4(x) 200 | 201 | x = self.avg(x) 202 | x = x.view(x.size(0), -1) 203 | x = self.fc(x) 204 | return x 205 | 206 | 207 | def mobilenet(alpha=1, class_num=100): 208 | return MobileNet(alpha, class_num) -------------------------------------------------------------------------------- /core/relevant_feature_identifying.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn as nn 7 | 8 | import models 9 | import loaders 10 | 11 | 12 | def partial_conv(conv: nn.Conv2d, inp: torch.Tensor, o_h=None, o_w=None): 13 | kernel_size = conv.kernel_size 14 | dilation = conv.dilation 15 | padding = conv.padding 16 | stride = conv.stride 17 | weight = conv.weight.to(inp.device) # O I K K 18 | # bias = conv.bias.to(inp.device) # O 19 | 20 | wei_res = weight.view(weight.size(0), weight.size(1), -1).permute((1, 2, 0)) # I K*K O 21 | inp_unf = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)(inp) # B K*K N 22 | 
inp_unf = inp_unf.view(inp.size(0), inp.size(1), wei_res.size(1), o_h, o_w)  # B I K*K H_O W_O
23 |     out = torch.einsum('ijkmn,jkl->iljmn', inp_unf, wei_res)  # B O I H W
24 | 
25 |     # out = out.sum(2)
26 |     # bias = bias.unsqueeze(1).unsqueeze(2).expand((out.size(1), out.size(2), out.size(3)))  # O H W
27 |     # out = out + bias
28 | 
29 |     return out
30 | 
31 | 
32 | def partial_linear(linear: nn.Linear, inp: torch.Tensor):
33 |     weight = linear.weight.to(inp.device)  # (o, i)
34 |     # bias = linear.bias.to(inp.device)  # (o)
35 | 
36 |     out = torch.einsum('bi,oi->boi', inp, weight)  # (b, o, i)
37 | 
38 |     # out = torch.sum(out, dim=-1)
39 |     # out = out + bias
40 | 
41 |     return out
42 | 
43 | 
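# Sanity check for the two decompositions above (an added sketch, not repo
# code): summing the per-input-channel slices and adding the bias back
# reproduces the ordinary layer output, e.g.
#   conv = nn.Conv2d(4, 8, 3, padding=1); x = torch.randn(2, 4, 16, 16)
#   parts = partial_conv(conv, x, 16, 16)  # (B, O, I, H, W)
#   torch.allclose(conv(x), parts.sum(2) + conv.bias.view(1, -1, 1, 1), atol=1e-5)  # True
# and likewise partial_linear(linear, v).sum(-1) + linear.bias == linear(v).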
44 | def mm_norm(a, dim=-1, zero=False):
45 |     if zero:
46 |         a_min = torch.zeros(a.size(), device=a.device)
47 |     else:
48 |         a_min, _ = torch.min(a, dim=dim, keepdim=True)
49 |     a_max, _ = torch.max(a, dim=dim, keepdim=True)
50 |     a_normalized = (a - a_min) / (a_max - a_min + 1e-5)
51 | 
52 |     return a_normalized
53 | 
54 | 
55 | class HookModule:
56 |     def __init__(self, module):
57 |         self.module = module
58 |         self.inputs = None
59 |         self.outputs = None
60 |         module.register_forward_hook(self._hook)
61 | 
62 |     def _hook(self, module, inputs, outputs):
63 |         self.inputs = inputs[0]
64 |         self.outputs = outputs
65 | 
66 | 
67 | class RelevantFeatureIdentifying:
68 |     def __init__(self, modules, num_classes, save_dir):
69 |         self.modules = [HookModule(module) for module in modules]
70 |         # self.values = [[[] for _ in range(num_classes)] for _ in range(len(modules))]  # [l, c, n, channels]
71 |         self.values = [[0 for _ in range(num_classes)] for _ in range(len(modules))]  # [l, c, channels]
72 |         self.num_classes = num_classes
73 |         self.save_dir = save_dir
74 | 
75 |     def __call__(self, outputs, labels):
76 |         for layer, module in enumerate(self.modules):
77 |             torch.cuda.empty_cache()
78 |             # print(layer, '==>', layer)
79 |             values = None
80 |             if isinstance(module.module, nn.Conv2d):
81 |                 # (b, o, i, h, w)
82 |                 values = partial_conv(module.module,
83 |                                       module.inputs,
84 |                                       module.outputs.size(2),
85 |                                       module.outputs.size(3))
86 |                 values = torch.sum(values, dim=(3, 4))
87 |             elif isinstance(module.module, nn.Linear):
88 |                 # (b, o, i)
89 |                 values = partial_linear(module.module,
90 |                                         module.inputs)
91 |             values = torch.relu(values)
92 |             values = values.cpu()
93 |             values = values.numpy()
94 | 
95 |             for b in range(len(labels)):
96 |                 # self.values[layer][labels[b]].append(values[b])  # (l, c, n, o, i)
97 |                 self.values[layer][labels[b]] += values[b]  # (l, c, o, i)
98 | 
99 |     def identify(self):
100 |         # parameter configuration
101 |         alpha_c = 0.3
102 |         beta_c = 0.2
103 |         alpha_f = 0.4
104 |         beta_f = 0.3
105 | 
106 |         # layer -1
107 |         mask = torch.eye(self.num_classes, dtype=torch.long)  # (c, o)
108 |         mask_path = os.path.join(self.save_dir, 'masks', 'mask_layer{}.pt'.format('-1'))
109 |         torch.save(mask, mask_path)
110 | 
111 |         # layer 0~n
112 |         for layer, values in enumerate(self.values):  # (l, c, o, i)
113 |             values = torch.from_numpy(np.asarray(self.values[layer]))  # (c, o, i)
114 |             # values = torch.sum(values, axis=1)  # (c, o, i)
115 |             print('-' * 20)
116 |             print(mask.shape)
117 |             print(values.shape)
118 |             print('-' * 20)
119 | 
120 |             if values.shape[1] != mask.shape[1]:  # e.g. at the conv->linear flatten boundary
121 |                 mask = torch.ones((values.shape[0], values.shape[1]), dtype=torch.long)
122 | 
123 |             values = mm_norm(values)  # (c, o, i)
124 |             if isinstance(self.modules[layer].module, nn.Conv2d):
125 |                 values = torch.where(values > alpha_c, 1, 0)  # (c, o, i)
126 |             else:
127 |                 values = torch.where(values > alpha_f, 1, 0)  # (c, o, i)
128 |             values = torch.einsum('co,coi->ci', mask, values)  # (c, i)
129 |             # values = torch.sum(values, dim=1)  # (c, i)
130 |             values = mm_norm(values)  # (c, i)
131 |             if isinstance(self.modules[layer].module, nn.Conv2d):
132 |                 mask = torch.where(values > beta_c, 1, 0)  # (c, i)
133 |             else:
134 |                 mask = torch.where(values > beta_f, 1, 0)  # (c, i)
135 | 
136 |             mask_path = os.path.join(self.save_dir, 'masks', 'mask_layer{}.pt'.format(layer))
137 |             torch.save(mask, mask_path)
138 | 
139 | 
140 | def main():
141 |     parser = argparse.ArgumentParser(description='')
142 |     parser.add_argument('--model_name', default='', type=str, help='model name')
143 |     parser.add_argument('--data_name', default='', type=str, help='data name')
144 |     parser.add_argument('--num_classes', default='', type=int, help='num classes')
145 |     parser.add_argument('--model_path', default='', type=str, help='model path')
146 |     parser.add_argument('--data_dir', default='', type=str, help='data path')
147 |     parser.add_argument('--save_dir', default='', type=str, help='save dir')
148 |     args = parser.parse_args()
149 | 
150 |     # ----------------------------------------
151 |     # basic configuration
152 |     # ----------------------------------------
153 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
154 | 
155 |     # exist_ok also covers the case where save_dir exists but the subdirectories do not
156 |     os.makedirs(os.path.join(args.save_dir, 'masks'), exist_ok=True)
157 |     os.makedirs(os.path.join(args.save_dir, 'figs'), exist_ok=True)
158 | 
159 |     print('-' * 50)
160 |     print('TRAIN ON:', device)
161 |     print('DATA DIR:', args.data_dir)
162 |     print('SAVE DIR:', args.save_dir)
163 |     print('-' * 50)
164 | 
165 |     # ----------------------------------------
166 |     # model/data configuration
167 |     # ----------------------------------------
168 |     model = models.load_model(model_name=args.model_name, num_classes=args.num_classes)
169 |     model.load_state_dict(torch.load(args.model_path))
170 |     # model = torch.load(args.model_path)
171 |     model.to(device)
172 |     model.eval()
173 | 
174 |     data_loader = loaders.load_data(args.data_dir, args.data_name, data_type='test')
175 | 
176 |     modules = models.load_modules(model=model)
177 | 
178 |     rfi = RelevantFeatureIdentifying(modules=modules, num_classes=args.num_classes, save_dir=args.save_dir)
179 | 
180 |     # ----------------------------------------
181 |     # forward
182 |     # ----------------------------------------
183 |     for i, samples in enumerate(tqdm(data_loader)):
184 |         inputs, labels, _ = samples
185 |         inputs = inputs.to(device)
186 |         labels = labels.to(device)
187 |         with torch.no_grad():
188 |             outputs = model(inputs)
189 |         rfi(outputs, labels)
190 | 
191 |     rfi.identify()
192 | 
193 | 
194 | if __name__ == '__main__':
195 |     main()
196 | 
--------------------------------------------------------------------------------
/core/model_decision_route_visualizing.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tkinter import *
3 | import numpy as np
4 | import os
5 | 
6 | import torch
7 | 
8 | FIG_W = 1536
9 | FIG_H = 1024
10 | 
11 | CONV_W = 4
12 | CONV_H = 4
13 | LINEAR_W = 2
14 | LINEAR_H = 2
15 | 
16 | INTERVAL_CONV_X = 200
17 | INTERVAL_CONV_Y = 7
18 | INTERVAL_LINEAR_X = 280
19 | INTERVAL_LINEAR_Y = 4.5
20 | 
21 | PADDING_X = 10
22 | PADDING_Y = 400  # middle line
23 | 
24 | LINE_WIDTH = 1
25 | 
26 | # COLOR_PUBLIC = 'orange'
27 | # COLOR_NO_USE = 'gray'
28 | # COLORS = ['purple', 'red']
29 | # COLOR_PUBLIC = '#feb888'
30 | # COLOR_NO_USE = '#c8c8c8'
31 | # COLORS = ['#b0d994', '#a3cbef', ]
32 | COLOR_PUBLIC = '#F8AC8C'
33 | COLOR_NO_USE = '#c8c8c8'
34 | COLORS = ['#C82423', '#2878B5', ]
35 | 
36 | 
37 | # COLORS = ['#2878B5', '#C82423', ]
38 | 
39 | 
40 | def draw_route(masks, layers):
41 |     root = Tk()
42 |     cv = Canvas(root, background='white', width=FIG_W, height=FIG_H)
43 |     cv.pack(fill=BOTH, expand=YES)
44 | 
45 |     # ---------------------------
46 |     # each layer
47 |     # ---------------------------
48 |     masks = np.asarray(masks)  # layers, labels, channels
49 |     print(masks.shape)
50 | 
51 |     x = PADDING_X
52 |     line_start_p_preceding = [(PADDING_X, PADDING_Y)]  # public
53 |     line_start_preceding = [[(PADDING_X, PADDING_Y)] for _ in range(masks.shape[1])]  # [labels * [init]]
54 | 
55 |     for layer in range(masks.shape[0]):
56 | 
57 |         line_end_p = []  # public
58 |         line_start_p = []  # public
59 |         line_end = [[] for _ in range(masks.shape[1])]  # [labels * []] each class
60 |         line_start = [[] for _ in range(masks.shape[1])]
61 | 
62 |         line_p_num = 0
63 |         line_num = 0
64 | 
65 |         # ---------------------------
66 |         # each channel
67 |         # ---------------------------
68 |         layer_masks = np.asarray(list(masks[layer]))  # labels, channels
69 | 
70 |         # initial position
71 |         if layers[layer] == 'conv':
72 |             x += CONV_W + INTERVAL_CONV_X
73 |             y = PADDING_Y - (layer_masks.shape[1] / 2) * (CONV_H + INTERVAL_CONV_Y) + INTERVAL_CONV_Y / 2
74 |         else:
75 |             x += LINEAR_W + INTERVAL_LINEAR_X
76 |             y = PADDING_Y - (layer_masks.shape[1] / 2) * (LINEAR_H + INTERVAL_LINEAR_Y) + INTERVAL_LINEAR_Y / 2
77 | 
78 |         # draw conv/linear units
79 |         for channel in range(layer_masks.shape[1]):
80 |             if layer_masks[:, channel].sum() > 1:  # channel used by several classes -> public color
81 |                 if layers[layer] == 'conv':
82 |                     line_end_p.append(((x), (y + CONV_H / 2)))
83 |                     line_start_p.append(((x + CONV_W), (y + CONV_H / 2)))
84 |                     cv.create_rectangle(x, y, x + CONV_W, y + CONV_H,
85 |                                         outline=COLOR_PUBLIC,
86 |                                         fill=COLOR_PUBLIC,
87 |                                         width=LINE_WIDTH)
88 |                 else:
89 |                     line_end_p.append(((x), (y + LINEAR_H / 2)))
90 |                     line_start_p.append(((x + LINEAR_W), (y + LINEAR_H / 2)))
91 |                     cv.create_oval(x, y, x + LINEAR_W, y + LINEAR_H,
92 |                                    outline=COLOR_PUBLIC,
93 |                                    fill=COLOR_PUBLIC,
94 |                                    width=LINE_WIDTH)
95 |             elif layer_masks[:, channel].sum() < 1:  # channel used by no class -> gray
96 |                 if layers[layer] == 'conv':
97 |                     cv.create_rectangle(x, y, x + CONV_W, y + CONV_H,
98 |                                         outline=COLOR_NO_USE,
99 |                                         fill=COLOR_NO_USE,
100 |                                         width=LINE_WIDTH)
101 |                 else:
102 |                     cv.create_oval(x, y, x + LINEAR_W, y + LINEAR_H,
103 |                                    outline=COLOR_NO_USE,
104 |                                    fill=COLOR_NO_USE,
105 |                                    width=LINE_WIDTH)
106 |             else:  # channel used by exactly one class -> that class's color
107 |                 # ---------------------------
108 |                 # each label
109 |                 # ---------------------------
110 |                 for l, mask in enumerate(layer_masks[:, channel]):
111 |                     if mask:
112 |                         if layers[layer] == 'conv':
113 |                             line_end[l].append(((x), (y + CONV_H / 2)))
114 |                             line_start[l].append(((x + CONV_W), (y + CONV_H / 2)))
115 |                             cv.create_rectangle(x, y, x + CONV_W, y + CONV_H,
116 |                                                 outline=COLORS[l],
117 |                                                 fill=COLORS[l],
118 |                                                 width=LINE_WIDTH)
119 |                         else:
120 |                             line_end[l].append(((x), (y + LINEAR_H / 2)))
121 |                             line_start[l].append(((x + LINEAR_W), (y + LINEAR_H / 2)))
122 |                             cv.create_oval(x, y, x + LINEAR_W, y + LINEAR_H,
123 |                                            outline=COLORS[l],
124 |                                            fill=COLORS[l],
125 |                                            width=LINE_WIDTH)
126 | 
127 |             # next y start position
128 | if layers[layer] == 'conv': 129 | y += CONV_H + INTERVAL_CONV_Y 130 | else: 131 | y += LINEAR_H + INTERVAL_LINEAR_Y 132 | 133 | # draw line 134 | for l in range(layer_masks.shape[0]): 135 | # line_num += (len(line_start_preceding[l]) * len(line_end[l])) # each to each 136 | # line_p_num += (len(line_start_preceding[l]) * len(line_end_p)) # each to public 137 | # line_p_num += (len(line_start_p_preceding) * len(line_end[l])) # public to each 138 | line_num += len(line_start[l]) # each 139 | for x0, y0 in line_start_preceding[l]: 140 | # each to each 141 | for x1, y1 in line_end[l]: 142 | cv.create_line(x0, y0, x1, y1, 143 | width=LINE_WIDTH, 144 | fill=COLORS[l], 145 | # arrow=LAST, 146 | arrowshape=(6, 5, 1)) 147 | 148 | # each to public 149 | for x1, y1 in line_end_p: 150 | cv.create_line(x0, y0, x1, y1, 151 | width=LINE_WIDTH, 152 | fill=COLORS[l], 153 | # arrow=LAST, 154 | arrowshape=(6, 5, 1)) 155 | 156 | # public to each 157 | for x0, y0 in line_start_p_preceding: 158 | for x1, y1 in line_end[l]: 159 | cv.create_line(x0, y0, x1, y1, 160 | width=LINE_WIDTH, 161 | fill=COLORS[l], 162 | # arrow=LAST, 163 | arrowshape=(6, 5, 1)) 164 | 165 | # line_p_num += (len(line_start_p_preceding) * len(line_end_p)) # public to public 166 | line_p_num += len(line_start_p) # public 167 | # public to public 168 | for x0, y0 in line_start_p_preceding: 169 | for x1, y1 in line_end_p: 170 | cv.create_line(x0, y0, x1, y1, 171 | width=LINE_WIDTH + 1, 172 | fill=COLOR_PUBLIC, 173 | # arrow=LAST, 174 | arrowshape=(6, 5, 1)) 175 | 176 | line_start_preceding = line_start.copy() 177 | line_start_p_preceding = line_start_p.copy() 178 | 179 | # calculate 180 | print('--->', layer, 181 | '| line--->', line_num, 182 | '| line_p--->', line_p_num, 183 | '| --->', line_p_num / (line_num + line_p_num)) 184 | 185 | root.mainloop() 186 | 187 | 188 | def main(): 189 | parser = argparse.ArgumentParser(description='') 190 | parser.add_argument('--mask_dir', default='', type=str, help='mask dir') 191 | parser.add_argument('--layers', default='', nargs='+', type=int, help='layers') 192 | parser.add_argument('--labels', default='', nargs='+', type=int, help='labels') 193 | # parser.add_argument('--save_dir', default='', type=str, help='save dir') 194 | args = parser.parse_args() 195 | 196 | mask_path = os.path.join(args.mask_dir, 'mask_layer{}.pt') 197 | 198 | if args.layers[0] == -1: 199 | args.layers = [4, 3, 2, 1, 0] # Please set manually 200 | 201 | layers_name = ['conv' for _ in range(2)] + ['linear' for _ in range(3)] # Please set manually 202 | 203 | for label in args.labels: 204 | masks = [] 205 | for layer in args.layers: 206 | mask_o = torch.load(mask_path.format(layer - 1))[label].numpy() 207 | masks.append([mask_o]) 208 | 209 | print(masks) 210 | print(np.asarray(masks).shape) 211 | print(layers_name) 212 | draw_route(masks, layers_name) 213 | 214 | 215 | if __name__ == '__main__': 216 | main() 217 | -------------------------------------------------------------------------------- /models/inceptionv3.py: -------------------------------------------------------------------------------- 1 | """ inceptionv3 in pytorch 2 | 3 | 4 | [1] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna 5 | 6 | Rethinking the Inception Architecture for Computer Vision 7 | https://arxiv.org/abs/1512.00567v3 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class BasicConv2d(nn.Module): 15 | 16 | def __init__(self, input_channels, output_channels, **kwargs): 17 | super().__init__() 
18 |         self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs)
19 |         self.bn = nn.BatchNorm2d(output_channels)
20 |         self.relu = nn.ReLU(inplace=True)
21 | 
22 |     def forward(self, x):
23 |         x = self.conv(x)
24 |         x = self.bn(x)
25 |         x = self.relu(x)
26 | 
27 |         return x
28 | 
29 | 
30 | # same naive inception module
31 | class InceptionA(nn.Module):
32 | 
33 |     def __init__(self, input_channels, pool_features):
34 |         super().__init__()
35 |         self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1)
36 | 
37 |         self.branch5x5 = nn.Sequential(
38 |             BasicConv2d(input_channels, 48, kernel_size=1),
39 |             BasicConv2d(48, 64, kernel_size=5, padding=2)
40 |         )
41 | 
42 |         self.branch3x3 = nn.Sequential(
43 |             BasicConv2d(input_channels, 64, kernel_size=1),
44 |             BasicConv2d(64, 96, kernel_size=3, padding=1),
45 |             BasicConv2d(96, 96, kernel_size=3, padding=1)
46 |         )
47 | 
48 |         self.branchpool = nn.Sequential(
49 |             nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
50 |             BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1)
51 |         )
52 | 
53 |     def forward(self, x):
54 |         # x -> 1x1(same)
55 |         branch1x1 = self.branch1x1(x)
56 | 
57 |         # x -> 1x1 -> 5x5(same)
58 |         branch5x5 = self.branch5x5(x)
59 |         # branch5x5 = self.branch5x5_2(branch5x5)
60 | 
61 |         # x -> 1x1 -> 3x3 -> 3x3(same)
62 |         branch3x3 = self.branch3x3(x)
63 | 
64 |         # x -> pool -> 3x3(same)
65 |         branchpool = self.branchpool(x)
66 | 
67 |         outputs = [branch1x1, branch5x5, branch3x3, branchpool]
68 | 
69 |         return torch.cat(outputs, 1)
70 | 
71 | 
72 | # downsample
73 | # Factorization into smaller convolutions
74 | class InceptionB(nn.Module):
75 | 
76 |     def __init__(self, input_channels):
77 |         super().__init__()
78 | 
79 |         self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2)
80 | 
81 |         self.branch3x3stack = nn.Sequential(
82 |             BasicConv2d(input_channels, 64, kernel_size=1),
83 |             BasicConv2d(64, 96, kernel_size=3, padding=1),
84 |             BasicConv2d(96, 96, kernel_size=3, stride=2)
85 |         )
86 | 
87 |         self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)
88 | 
89 |     def forward(self, x):
90 |         # x -> 3x3(downsample)
91 |         branch3x3 = self.branch3x3(x)
92 | 
93 |         # x -> 3x3 -> 3x3(downsample)
94 |         branch3x3stack = self.branch3x3stack(x)
95 | 
96 |         # x -> maxpool(downsample)
97 |         branchpool = self.branchpool(x)
98 | 
99 |         # """We can use two parallel stride 2 blocks: P and C. P is a pooling
100 |         # layer (either average or maximum pooling) the activation, both of
101 |         # them are stride 2 the filter banks of which are concatenated as in
102 |         # figure 10."""
103 |         outputs = [branch3x3, branch3x3stack, branchpool]
104 | 
105 |         return torch.cat(outputs, 1)
106 | 
107 | 
108 | # Factorizing Convolutions with Large Filter Size
109 | class InceptionC(nn.Module):
110 |     def __init__(self, input_channels, channels_7x7):
111 |         super().__init__()
112 |         self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1)
113 | 
114 |         c7 = channels_7x7
115 | 
116 |         # In theory, we could go even further and argue that one can replace any n × n
117 |         # convolution by a 1 × n convolution followed by a n × 1 convolution and the
118 |         # computational cost saving increases dramatically as n grows (see figure 6).
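        # Worked cost check for n = 7 (added note): a dense 7x7 conv between
        # c-channel maps needs 49*c^2 weights, while a 7x1 + 1x7 pair like the
        # one below needs 7*c^2 + 7*c^2 = 14*c^2, a 3.5x saving per pair
        # (ignoring biases), and the saving grows with n.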
119 | self.branch7x7 = nn.Sequential( 120 | BasicConv2d(input_channels, c7, kernel_size=1), 121 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 122 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 123 | ) 124 | 125 | self.branch7x7stack = nn.Sequential( 126 | BasicConv2d(input_channels, c7, kernel_size=1), 127 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 128 | BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)), 129 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 130 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 131 | ) 132 | 133 | self.branch_pool = nn.Sequential( 134 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 135 | BasicConv2d(input_channels, 192, kernel_size=1), 136 | ) 137 | 138 | def forward(self, x): 139 | # x -> 1x1(same) 140 | branch1x1 = self.branch1x1(x) 141 | 142 | # x -> 1layer 1*7 and 7*1 (same) 143 | branch7x7 = self.branch7x7(x) 144 | 145 | # x-> 2layer 1*7 and 7*1(same) 146 | branch7x7stack = self.branch7x7stack(x) 147 | 148 | # x-> avgpool (same) 149 | branchpool = self.branch_pool(x) 150 | 151 | outputs = [branch1x1, branch7x7, branch7x7stack, branchpool] 152 | 153 | return torch.cat(outputs, 1) 154 | 155 | 156 | class InceptionD(nn.Module): 157 | 158 | def __init__(self, input_channels): 159 | super().__init__() 160 | 161 | self.branch3x3 = nn.Sequential( 162 | BasicConv2d(input_channels, 192, kernel_size=1), 163 | BasicConv2d(192, 320, kernel_size=3, stride=2) 164 | ) 165 | 166 | self.branch7x7 = nn.Sequential( 167 | BasicConv2d(input_channels, 192, kernel_size=1), 168 | BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)), 169 | BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)), 170 | BasicConv2d(192, 192, kernel_size=3, stride=2) 171 | ) 172 | 173 | self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2) 174 | 175 | def forward(self, x): 176 | # x -> 1x1 -> 3x3(downsample) 177 | branch3x3 = self.branch3x3(x) 178 | 179 | # x -> 1x1 -> 1x7 -> 7x1 -> 3x3 (downsample) 180 | branch7x7 = self.branch7x7(x) 181 | 182 | # x -> avgpool (downsample) 183 | branchpool = self.branchpool(x) 184 | 185 | outputs = [branch3x3, branch7x7, branchpool] 186 | 187 | return torch.cat(outputs, 1) 188 | 189 | 190 | # same 191 | class InceptionE(nn.Module): 192 | def __init__(self, input_channels): 193 | super().__init__() 194 | self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1) 195 | 196 | self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1) 197 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 198 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 199 | 200 | self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1) 201 | self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 202 | self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 203 | self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 204 | 205 | self.branch_pool = nn.Sequential( 206 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 207 | BasicConv2d(input_channels, 192, kernel_size=1) 208 | ) 209 | 210 | def forward(self, x): 211 | # x -> 1x1 (same) 212 | branch1x1 = self.branch1x1(x) 213 | 214 | # x -> 1x1 -> 3x1 215 | # x -> 1x1 -> 1x3 216 | # concatenate(3x1, 1x3) 217 | # """7. Inception modules with expanded the filter bank outputs. 
218 | # This architecture is used on the coarsest (8 × 8) grids to promote 219 | # high dimensional representations, as suggested by principle 220 | # 2 of Section 2.""" 221 | branch3x3 = self.branch3x3_1(x) 222 | branch3x3 = [ 223 | self.branch3x3_2a(branch3x3), 224 | self.branch3x3_2b(branch3x3) 225 | ] 226 | branch3x3 = torch.cat(branch3x3, 1) 227 | 228 | # x -> 1x1 -> 3x3 -> 1x3 229 | # x -> 1x1 -> 3x3 -> 3x1 230 | # concatenate(1x3, 3x1) 231 | branch3x3stack = self.branch3x3stack_1(x) 232 | branch3x3stack = self.branch3x3stack_2(branch3x3stack) 233 | branch3x3stack = [ 234 | self.branch3x3stack_3a(branch3x3stack), 235 | self.branch3x3stack_3b(branch3x3stack) 236 | ] 237 | branch3x3stack = torch.cat(branch3x3stack, 1) 238 | 239 | branchpool = self.branch_pool(x) 240 | 241 | outputs = [branch1x1, branch3x3, branch3x3stack, branchpool] 242 | 243 | return torch.cat(outputs, 1) 244 | 245 | 246 | class InceptionV3(nn.Module): 247 | 248 | def __init__(self, in_channels=3, num_classes=10): 249 | super().__init__() 250 | self.Conv2d_1a_3x3 = BasicConv2d(in_channels, 32, kernel_size=3, padding=1) 251 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1) 252 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 253 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 254 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 255 | 256 | # naive inception module 257 | self.Mixed_5b = InceptionA(192, pool_features=32) 258 | self.Mixed_5c = InceptionA(256, pool_features=64) 259 | self.Mixed_5d = InceptionA(288, pool_features=64) 260 | 261 | # downsample 262 | self.Mixed_6a = InceptionB(288) 263 | 264 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 265 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 266 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 267 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 268 | 269 | # downsample 270 | self.Mixed_7a = InceptionD(768) 271 | 272 | self.Mixed_7b = InceptionE(1280) 273 | self.Mixed_7c = InceptionE(2048) 274 | 275 | # 6*6 feature size 276 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 277 | self.dropout = nn.Dropout2d() 278 | self.linear = nn.Linear(2048, num_classes) 279 | 280 | def forward(self, x): 281 | # 32 -> 30 282 | x = self.Conv2d_1a_3x3(x) 283 | x = self.Conv2d_2a_3x3(x) 284 | x = self.Conv2d_2b_3x3(x) 285 | x = self.Conv2d_3b_1x1(x) 286 | x = self.Conv2d_4a_3x3(x) 287 | 288 | # 30 -> 30 289 | x = self.Mixed_5b(x) 290 | x = self.Mixed_5c(x) 291 | x = self.Mixed_5d(x) 292 | 293 | # 30 -> 14 294 | # Efficient Grid Size Reduction to avoid representation 295 | # bottleneck 296 | x = self.Mixed_6a(x) 297 | 298 | # 14 -> 14 299 | # """In practice, we have found that employing this factorization does not 300 | # work well on early layers, but it gives very good results on medium 301 | # grid-sizes (On m × m feature maps, where m ranges between 12 and 20). 
302 |         # On that level, very good results can be achieved by using 1 × 7 convolutions
303 |         # followed by 7 × 1 convolutions."""
304 |         x = self.Mixed_6b(x)
305 |         x = self.Mixed_6c(x)
306 |         x = self.Mixed_6d(x)
307 |         x = self.Mixed_6e(x)
308 | 
309 |         # 14 -> 6
310 |         # Efficient Grid Size Reduction
311 |         x = self.Mixed_7a(x)
312 | 
313 |         # 6 -> 6
314 |         # """We are using this solution only on the coarsest grid,
315 |         # since that is the place where producing high dimensional
316 |         # sparse representation is the most critical as the ratio of
317 |         # local processing (by 1 × 1 convolutions) is increased compared
318 |         # to the spatial aggregation."""
319 |         x = self.Mixed_7b(x)
320 |         x = self.Mixed_7c(x)
321 | 
322 |         # 6 -> 1
323 |         x = self.avgpool(x)
324 |         x = self.dropout(x)
325 |         x = x.view(x.size(0), -1)
326 |         x = self.linear(x)
327 |         return x
328 | 
329 | 
330 | def inceptionv3(in_channels=3, num_classes=10):
331 |     return InceptionV3(in_channels=in_channels, num_classes=num_classes)
--------------------------------------------------------------------------------
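A closing smoke test — a minimal added sketch, not a file from this repo — exercising the two factories above on CIFAR-sized inputs, assuming the repo root is on PYTHONPATH as the shell scripts arrange. The spatial sizes follow the inline comments (32 -> 30 -> 14 -> 6 -> 1 for InceptionV3):

    import torch
    from models.googlenet import googlenet
    from models.inceptionv3 import inceptionv3

    for factory in (googlenet, inceptionv3):
        model = factory(in_channels=3, num_classes=10)
        out = model(torch.randn(2, 3, 32, 32))   # batch of two CIFAR-sized images
        assert out.shape == (2, 10), out.shape   # one logit per class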