├── core
│   ├── __init__.py
│   ├── model_assemble.py
│   ├── parameter_scaling.py
│   ├── model_disassemble.py
│   ├── sample_select.py
│   ├── relevant_feature_identifying.py
│   └── model_decision_route_visualizing.py
├── utils
│   ├── __init__.py
│   ├── image_util.py
│   ├── file_util.py
│   └── train_util.py
├── engines
│   ├── __init__.py
│   ├── test.py
│   └── train.py
├── framework.jpg
├── loaders
│   ├── datasets
│   │   ├── __init__.py
│   │   └── image_dataset.py
│   ├── __init__.py
│   └── image_loader.py
├── model_assembling.jpg
├── model_disassembling.jpg
├── metrics
│   ├── __init__.py
│   └── accuracy.py
├── scripts
│   ├── model_decision_route_visualizing.sh
│   ├── model_assemble.sh
│   ├── sample_select.sh
│   ├── relevant_feature_identifying.sh
│   ├── model_disassemble.sh
│   ├── test.sh
│   ├── parameter_scaling.sh
│   └── train.sh
├── models
│   ├── simnet.py
│   ├── lenet.py
│   ├── alexnet.py
│   ├── __init__.py
│   ├── vgg.py
│   ├── simplenetv1.py
│   ├── googlenet.py
│   ├── resnet.py
│   ├── mobilenet.py
│   └── inceptionv3.py
├── .gitignore
└── README.md
/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/engines/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaconghu/Model-LEGO/HEAD/framework.jpg
--------------------------------------------------------------------------------
/loaders/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from loaders.datasets.image_dataset import ImageDataset
2 |
--------------------------------------------------------------------------------
/model_assembling.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaconghu/Model-LEGO/HEAD/model_assembling.jpg
--------------------------------------------------------------------------------
/model_disassembling.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaconghu/Model-LEGO/HEAD/model_disassembling.jpg
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from metrics.accuracy import accuracy
2 | from metrics.accuracy import ClassAccuracy
3 |
--------------------------------------------------------------------------------
/loaders/__init__.py:
--------------------------------------------------------------------------------
1 | from loaders.image_loader import load_images
2 |
3 |
4 | def load_data(data_dir, data_name, data_type):
5 | print('-' * 50)
6 | print('DATA PATH:', data_dir)
7 | print('DATA NAME:', data_name, '\t|\tDATA TYPE:', data_type)
8 | print('-' * 50)
9 |
10 | return load_images(data_dir, data_name, data_type)
11 |
--------------------------------------------------------------------------------
/scripts/model_decision_route_visualizing.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
3 | export CUDA_VISIBLE_DEVICES=0
4 | #----------------------------------------
5 | mask_dir='/nfs3/hjc/projects/cnnlego/output/lenet_cifar10_base/contributions/masks'
6 | layers='-1'
7 | labels='3 4'
8 | #----------------------------------------
9 |
10 | python core/model_decision_route_visualizing.py \
11 | --mask_dir ${mask_dir} \
12 | --layers ${layers} \
13 | --labels ${labels}
--------------------------------------------------------------------------------
/utils/image_util.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('AGG')  # select the non-interactive backend before pyplot is imported
3 |
4 | import matplotlib.pyplot as plt
5 | import seaborn as sns
6 |
7 |
8 | def heatmap(vals, fig_path, fig_w=None, fig_h=None, annot=False):
9 | if fig_w is None:
10 | fig_w = vals.shape[1]
11 | if fig_h is None:
12 | fig_h = vals.shape[0]
13 |
14 | f, ax = plt.subplots(figsize=(fig_w, fig_h), ncols=1)
15 | sns.heatmap(vals, ax=ax, annot=annot)
16 | plt.savefig(fig_path, bbox_inches='tight')
17 |     plt.close(f)  # close the figure to avoid accumulating open figures across calls
18 |
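19 |
20 | if __name__ == '__main__':
21 |     # minimal usage sketch (illustrative): save an annotated 4x6 random heatmap;
22 |     # 'heatmap_demo.png' is a placeholder output path
23 |     import numpy as np
24 |     heatmap(np.random.rand(4, 6), 'heatmap_demo.png', annot=True)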
--------------------------------------------------------------------------------
/scripts/model_assemble.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
3 | export CUDA_VISIBLE_DEVICES=0
4 | result_path='/nfs3/hjc/projects/cnnlego/output'
5 | #----------------------------------------
6 | exp_name='lenet_cifar10_base'
7 | #----------------------------------------
8 | model1_path=${result_path}'/'${exp_name}'/models/model_disa1.pth'
9 | model2_path=${result_path}'/'${exp_name}'/models/model_disa2.pth'
10 | asse_path=${result_path}'/'${exp_name}'/models/model_asse.pth'
11 |
12 | python core/model_assemble.py \
13 | --model1_path ${model1_path} \
14 | --model2_path ${model2_path} \
15 | --asse_path ${asse_path}
16 |
--------------------------------------------------------------------------------
/utils/file_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | def walk_file(path):
6 | count = 0
7 | for root, dirs, files in os.walk(path):
8 | print(root)
9 |
10 | for f in files:
11 | count += 1
12 | # print(os.path.join(root, f))
13 |
14 | for d in dirs:
15 | print(os.path.join(root, d))
16 | print(count)
17 |
18 |
19 | def count_files(path):
20 | for root, dirs, files in os.walk(path):
21 | print(root, len(files))
22 |
23 |
24 | def copy_file(src, dst):
25 | path, name = os.path.split(dst)
26 | if not os.path.exists(path):
27 | os.makedirs(path)
28 | shutil.copyfile(src, dst)
29 |
30 |
31 | if __name__ == '__main__':
32 | count_files('path')
33 |
--------------------------------------------------------------------------------
/scripts/sample_select.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
3 | export CUDA_VISIBLE_DEVICES=0
4 | result_path='/nfs3/hjc/projects/cnnlego/output'
5 | #----------------------------------------
6 | exp_name='lenet_cifar10_base'
7 | #----------------------------------------
8 | model_name='lenet'
9 | #----------------------------------------
10 | data_name='cifar10'
11 | num_classes=10
12 | #----------------------------------------
13 | model_path=${result_path}'/'${exp_name}'/models/model_ori.pth'
14 | data_dir='/nfs3-p1/hjc/datasets/cifar10/train'
15 | save_dir=${result_path}'/'${exp_name}'/images/htrain'
16 | num_samples=50
17 |
18 | python core/sample_select.py \
19 | --model_name ${model_name} \
20 | --data_name ${data_name} \
21 | --num_classes ${num_classes} \
22 | --model_path ${model_path} \
23 | --data_dir ${data_dir} \
24 | --save_dir ${save_dir} \
25 | --num_samples ${num_samples}
26 |
--------------------------------------------------------------------------------
/scripts/relevant_feature_identifying.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
3 | export CUDA_VISIBLE_DEVICES=0
4 | result_path='/nfs3/hjc/projects/cnnlego/output'
5 | #----------------------------------------
6 | exp_name='lenet_cifar10_base'
7 | #----------------------------------------
8 | model_name='lenet'
9 | #----------------------------------------
10 | data_name='cifar10'
11 | num_classes=10
12 | #----------------------------------------
13 | model_path=${result_path}'/'${exp_name}'/models/model_ori.pth'
14 | data_dir=${result_path}'/'${exp_name}'/images/htrain'
15 | save_dir=${result_path}'/'${exp_name}'/contributions'
16 |
17 | python core/relevant_feature_identifying.py \
18 | --model_name ${model_name} \
19 | --data_name ${data_name} \
20 | --num_classes ${num_classes} \
21 | --model_path ${model_path} \
22 | --data_dir ${data_dir} \
23 | --save_dir ${save_dir}
24 |
--------------------------------------------------------------------------------
/scripts/model_disassemble.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
3 | export CUDA_VISIBLE_DEVICES=0
4 | result_path='/nfs3/hjc/projects/cnnlego/output'
5 | #----------------------------------------
6 | exp_name='lenet_cifar10_base'
7 | #----------------------------------------
8 | model_name='lenet'
9 | #----------------------------------------
10 | num_classes=10
11 | #----------------------------------------
12 | model_path=${result_path}'/'${exp_name}'/models/model_ori.pth'
13 | mask_dir=${result_path}'/'${exp_name}'/contributions/masks'
14 | save_dir=${result_path}'/'${exp_name}'/models'
15 | #----------------------------------------
16 | disa_layers='-1'
17 | disa_labels='3 4'
18 |
19 | python core/model_disassemble.py \
20 | --model_name ${model_name} \
21 | --num_classes ${num_classes} \
22 | --model_path ${model_path} \
23 | --mask_dir ${mask_dir} \
24 | --save_dir ${save_dir} \
25 | --disa_layers ${disa_layers} \
26 | --disa_labels ${disa_labels}
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
3 | export CUDA_VISIBLE_DEVICES=0
4 | result_path='/nfs3/hjc/projects/cnnlego/output'
5 | #----------------------------------------
6 | exp_name='lenet_cifar10_base'
7 | #----------------------------------------
8 | #model_name='vgg16'
9 | #model_name='resnet50'
10 | model_name='lenet'
11 | #----------------------------------------
12 | data_name='cifar10'
13 | num_classes=10
14 | #data_name='cifar100'
15 | #num_classes=100
16 | #----------------------------------------
17 | #model_path=${result_path}'/'${exp_name}'/models/model_ori.pth'
18 | model_path=${result_path}'/'${exp_name}'/models/model_disa.pth'
19 | #----------------------------------------
20 | data_dir='/nfs3-p1/hjc/datasets/'${data_name}'/test'
21 | #----------------------------------------
22 |
23 | python engines/test.py \
24 | --model_name ${model_name} \
25 | --data_name ${data_name} \
26 | --num_classes ${num_classes} \
27 | --model_path ${model_path} \
28 | --data_dir ${data_dir}
29 |
--------------------------------------------------------------------------------
/scripts/parameter_scaling.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
3 | export CUDA_VISIBLE_DEVICES=0
4 | result_path='/nfs3/hjc/projects/cnnlego/output'
5 | #----------------------------------------
6 | exp_name='lenet_cifar10_base'
7 | #----------------------------------------
8 | #model_name='vgg16'
9 | #model_name='resnet50'
10 | model_name='lenet'
11 | #----------------------------------------
12 | data_name='cifar10'
13 | num_classes=10
14 | #data_name='cifar100'
15 | #num_classes=100
16 | #----------------------------------------
17 | #model_path=${result_path}'/'${exp_name}'/models/model_ori.pth'
18 | model_path=${result_path}'/'${exp_name}'/models/model_disa.pth'
19 | #----------------------------------------
20 | data_dir='/nfs3-p1/hjc/datasets/'${data_name}'/test'
21 | #----------------------------------------
22 |
23 | python core/parameter_scaling.py \
24 | --model_name ${model_name} \
25 | --data_name ${data_name} \
26 | --num_classes ${num_classes} \
27 | --model_path ${model_path} \
28 | --data_dir ${data_dir}
29 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=/nfs3/hjc/projects/cnnlego/code
3 | export CUDA_VISIBLE_DEVICES=0
4 | result_path='/nfs3/hjc/projects/cnnlego/output'
5 | #----------------------------------------
6 | exp_name='lenet_cifar10_base'
7 | #----------------------------------------
8 | #model_name='vgg16'
9 | #model_name='resnet50'
10 | model_name='lenet'
11 | #----------------------------------------
12 | data_name='cifar10'
13 | num_classes=10
14 | #data_name='cifar100'
15 | #num_classes=100
16 | #----------------------------------------
17 | num_epochs=200
18 | model_dir=${result_path}'/'${exp_name}'/models'
19 | #----------------------------------------
20 | data_train_dir='/nfs3-p1/hjc/datasets/'${data_name}'/train'
21 | data_test_dir='/nfs3-p1/hjc/datasets/'${data_name}'/test'
22 | #----------------------------------------
23 | log_dir=${result_path}'/runs/'${exp_name}
24 |
25 | python engines/train.py \
26 | --model_name ${model_name} \
27 | --data_name ${data_name} \
28 | --num_classes ${num_classes} \
29 | --num_epochs ${num_epochs} \
30 | --model_dir ${model_dir} \
31 | --data_train_dir ${data_train_dir} \
32 | --data_test_dir ${data_test_dir} \
33 | --log_dir ${log_dir}
34 |
--------------------------------------------------------------------------------
/models/simnet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from collections import OrderedDict
3 |
4 |
5 | class SimNet(nn.Module):
6 | def __init__(self, in_channels, num_classes):
7 | super(SimNet, self).__init__()
8 | self.features = nn.Sequential(
9 | OrderedDict([
10 | ('c1', nn.Conv2d(in_channels, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
11 | ('relu1', nn.ReLU()),
12 | ('s1', nn.MaxPool2d(kernel_size=2, stride=2)),
13 | ('c2', nn.Conv2d(9, 27, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
14 | ('relu2', nn.ReLU()),
15 | ('s2', nn.MaxPool2d(kernel_size=2, stride=2)),
16 | ('c3', nn.Conv2d(27, 81, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
17 | ('relu3', nn.ReLU())
18 | ])
19 | )
20 | self.classifier = nn.Sequential(
21 | OrderedDict([
22 | # ('f4', nn.Linear(254016, num_classes))
23 | ('f4', nn.Linear(5184, 2048)),
24 | ('f5', nn.Linear(2048, num_classes))
25 | ])
26 | )
27 |
28 | def forward(self, x):
29 | x = self.features(x)
30 | x = x.view(x.size(0), -1)
31 | x = self.classifier(x)
32 | return x
33 |
34 |
35 | def simnet(in_channels, num_classes):
36 | return SimNet(in_channels, num_classes)
37 |
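38 |
39 | if __name__ == '__main__':
40 |     # minimal smoke test (illustrative), mirroring models/lenet.py: a 3x32x32 input is
41 |     # pooled twice (32 -> 16 -> 8), so flattening gives 81 * 8 * 8 = 5184 features for f4
42 |     import torch
43 |     y = simnet(3, 10)(torch.randn(1, 3, 32, 32))
44 |     print(y.shape)  # expected: torch.Size([1, 10])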
--------------------------------------------------------------------------------
/utils/train_util.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 | def __init__(self, name, fmt=':f'):
3 | self.name = name
4 | self.fmt = fmt
5 | self.reset()
6 |
7 | def reset(self):
8 | self.val = 0
9 | self.avg = 0
10 | self.sum = 0
11 | self.count = 0
12 |
13 | def update(self, val, n=1):
14 | self.val = val
15 | self.sum += val * n
16 | self.count += n
17 | self.avg = self.sum / self.count
18 |
19 | def __str__(self):
20 | # fmtstr = '{name}[VAL:{val' + self.fmt + '} AVG:{avg' + self.fmt + '}]'
21 | fmtstr = '{name}[{avg' + self.fmt + '}]'
22 | return fmtstr.format(**self.__dict__)
23 |
24 |
25 | class ProgressMeter(object):
26 | def __init__(self, total, step, prefix, meters):
27 | self._fmtstr = self._get_fmtstr(total)
28 | self.meters = meters
29 | self.prefix = prefix
30 |
31 | self.step = step
32 |
33 | def display(self, running):
34 | if (running + 1) % self.step == 0:
35 |             entries = [self.prefix + self._fmtstr.format(running)]  # e.g. 'Test[ 19/ 79]'
36 | entries += [str(meter) for meter in self.meters]
37 | print(' '.join(entries))
38 |
39 |     def _get_fmtstr(self, total):
40 |         num_digits = len(str(total))
41 |         fmt = '{:' + str(num_digits) + 'd}'
42 |         return '[' + fmt + '/' + fmt.format(total) + ']'  # e.g. '[ 19/ 79]'
43 |
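44 |
45 | if __name__ == '__main__':
46 |     # minimal usage sketch (illustrative): track a running loss over 10 steps
47 |     # and print the progress line every 2 steps
48 |     loss_meter = AverageMeter('Loss', ':.4f')
49 |     progress = ProgressMeter(total=10, step=2, prefix='Demo', meters=[loss_meter])
50 |     for i in range(10):
51 |         loss_meter.update(val=1.0 / (i + 1), n=4)
52 |         progress.display(i)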
--------------------------------------------------------------------------------
/models/lenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 |
7 | class LeNet(nn.Module):
8 |     """LeNet-like network for 32x32 inputs (e.g., CIFAR-10, or MNIST resized to 32x32)."""
9 |
10 | def __init__(self, in_channels=1, num_classes=10, **kwargs):
11 | super().__init__()
12 | # main part of the network
13 | self.conv1 = nn.Conv2d(in_channels, 6, 5)
14 | self.conv2 = nn.Conv2d(6, 16, 5)
15 | self.fc1 = nn.Linear(400, 120)
16 | self.fc2 = nn.Linear(120, 84)
17 |
18 | # last classifier layer (head) with as many outputs as classes
19 | self.fc = nn.Linear(84, num_classes)
20 | # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
21 | self.head_var = 'fc'
22 |
23 | def forward(self, x):
24 | out = F.relu(self.conv1(x))
25 | out = F.max_pool2d(out, 2)
26 | out = F.relu(self.conv2(out))
27 | out = F.max_pool2d(out, 2)
28 | out = out.view(out.size(0), -1)
29 | out = F.relu(self.fc1(out))
30 | out = F.relu(self.fc2(out))
31 | out = self.fc(out)
32 | return out
33 |
34 |
35 | def lenet(in_channels=3, num_classes=10):
36 | return LeNet(in_channels=in_channels, num_classes=num_classes)
37 |
38 |
39 | if __name__ == '__main__':
40 | model = LeNet(1, 10)
41 | y = model(torch.randn(1, 1, 32, 32))
42 | print(y)
43 |
--------------------------------------------------------------------------------
/models/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class AlexNet(nn.Module):
5 | def __init__(self, in_channels=3, num_classes=10):
6 | super(AlexNet, self).__init__()
7 | self.features = nn.Sequential(
8 | nn.Conv2d(in_channels, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
9 | nn.ReLU(inplace=True),
10 | nn.MaxPool2d(kernel_size=2),
11 | nn.Conv2d(64, 192, kernel_size=(3, 3), padding=(1, 1)),
12 | nn.ReLU(inplace=True),
13 | nn.MaxPool2d(kernel_size=2),
14 | nn.Conv2d(192, 384, kernel_size=(3, 3), padding=(1, 1)),
15 | nn.ReLU(inplace=True),
16 | nn.Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 1)),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1)),
19 | nn.ReLU(inplace=True),
20 | nn.MaxPool2d(kernel_size=2),
21 | )
22 | self.classifier = nn.Sequential(
23 | nn.Dropout(),
24 | nn.Linear(256 * 2 * 2, 4096),
25 | nn.ReLU(inplace=True),
26 | nn.Dropout(),
27 | nn.Linear(4096, 4096),
28 | nn.ReLU(inplace=True),
29 | nn.Linear(4096, num_classes),
30 | )
31 |
32 | def forward(self, x):
33 | x = self.features(x)
34 | x = x.view(x.size(0), 256 * 2 * 2)
35 | x = self.classifier(x)
36 | return x
37 |
38 |
39 | def alexnet(in_channels=3, num_classes=10):
40 | return AlexNet(in_channels=in_channels, num_classes=num_classes)
41 |
--------------------------------------------------------------------------------
/metrics/accuracy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(outputs, labels, topk=(1,)):
5 | with torch.no_grad():
6 | maxk = max(topk)
7 | batch_size = labels.size(0)
8 |
9 | _, pred = outputs.topk(maxk, 1, True, True) # [batch_size, topk]
10 | pred = pred.t() # [topk, batch_size]
11 | correct = pred.eq(labels.view(1, -1).expand_as(pred)) # [topk, batch_size]
12 |
13 | res = []
14 | for k in topk:
15 | correct_k = correct[:k].float().sum()
16 | res.append(correct_k.mul_(100.0 / batch_size))
17 | return res
18 |
19 |
20 | class ClassAccuracy:
21 | def __init__(self):
22 | self.sum = {}
23 | self.count = {}
24 |
25 | def update(self, outputs, labels):
26 | _, pred = outputs.max(dim=1)
27 | correct = pred.eq(labels)
28 |
29 | for b, label in enumerate(labels):
30 | label = label.item()
31 | if label not in self.sum.keys():
32 | self.sum[label] = 0
33 | self.count[label] = 0
34 | self.sum[label] += correct[b].item()
35 | self.count[label] += 1
36 |
37 | def __call__(self):
38 | self.sum = dict(sorted(self.sum.items()))
39 | self.count = dict(sorted(self.count.items()))
40 | return [s / c * 100 for s, c in zip(self.sum.values(), self.count.values())]
41 |
42 | def __getitem__(self, item):
43 | return self.__call__()[item]
44 |
45 | def list(self):
46 | return self.__call__()
47 |
48 | def __str__(self):
49 | fmtstr = '{}:{:6.2f}'
50 | result = '\n'.join([fmtstr.format(l, a) for l, a in enumerate(self.__call__())])
51 | return result
52 |
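53 |
54 | if __name__ == '__main__':
55 |     # minimal usage sketch (illustrative): top-1 accuracy and per-class accuracy
56 |     # on random logits for a batch of 8 samples and 10 classes
57 |     outputs = torch.randn(8, 10)
58 |     labels = torch.randint(0, 10, (8,))
59 |     top1, = accuracy(outputs, labels, topk=(1,))
60 |     print('top-1: {:.2f}%'.format(top1.item()))
61 |     class_acc = ClassAccuracy()
62 |     class_acc.update(outputs, labels)
63 |     print(class_acc)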
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import simnet, alexnet, vgg, resnet, simplenetv1, googlenet, lenet
3 |
4 |
5 | def load_model(model_name, in_channels=3, num_classes=10):
6 | print('-' * 50)
7 | print('LOAD MODEL:', model_name)
8 | print('NUM CLASSES:', num_classes)
9 | print('-' * 50)
10 |
11 | model = None
12 | if model_name == 'simnet':
13 | model = simnet.simnet(in_channels, num_classes)
14 | if model_name == 'alexnet':
15 | model = alexnet.alexnet(in_channels, num_classes)
16 | if model_name == 'vgg16':
17 | model = vgg.vgg16_bn(in_channels, num_classes)
18 | if model_name == 'resnet50':
19 | model = resnet.resnet50(in_channels, num_classes)
20 | if model_name == 'simplenetv1':
21 | model = simplenetv1.simplenet(in_channels, num_classes)
22 | if model_name == 'googlenet':
23 | model = googlenet.googlenet(in_channels, num_classes)
24 | if model_name == 'lenet':
25 | model = lenet.lenet(in_channels, num_classes)
26 |
27 | return model
28 |
29 |
30 | def load_modules(model, model_layers=None):
31 | assert model_layers is None or type(model_layers) is list
32 |
33 | modules = []
34 | for module in model.modules():
35 | if isinstance(module, torch.nn.Conv2d):
36 | modules.append(module)
37 | if isinstance(module, torch.nn.Linear):
38 | modules.append(module)
39 |
40 | modules.reverse() # reverse order
41 | if model_layers is None:
42 | model_modules = modules
43 | else:
44 | model_modules = []
45 | for layer in model_layers:
46 | model_modules.append(modules[layer])
47 |
48 | print('-' * 50)
49 | print('Model Layers:', model_layers)
50 | print('Model Modules:', model_modules)
51 | print('Model Modules Length:', len(model_modules))
52 | print('-' * 50)
53 |
54 | return model_modules
55 |
--------------------------------------------------------------------------------
/loaders/datasets/image_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL.Image as Image
3 | from torch.utils.data import Dataset
4 |
5 |
6 | def _img_loader(path, mode='RGB'):
7 | assert mode in ['RGB', 'L']
8 | with open(path, 'rb') as f:
9 | img = Image.open(f)
10 | return img.convert(mode)
11 |
12 |
13 | def _find_classes(root):
14 | class_names = [d.name for d in os.scandir(root) if d.is_dir()]
15 | class_names.sort()
16 |     class_indices = {class_names[i]: i for i in range(len(class_names))}
17 |     # print(class_indices)
18 |     return class_names, class_indices  # {'class_name': index}
19 |
20 |
21 | def _make_dataset(image_dir):
22 | samples = [] # image_path, class_idx
23 |
24 | class_names, class_indices = _find_classes(image_dir)
25 |
26 | for class_name in sorted(class_names):
27 | class_idx = class_indices[class_name]
28 | target_dir = os.path.join(image_dir, class_name)
29 |
30 | if not os.path.isdir(target_dir):
31 | continue
32 |
33 | for root, _, files in sorted(os.walk(target_dir)):
34 | for file in sorted(files):
35 | image_path = os.path.join(root, file)
36 | item = image_path, class_idx
37 | samples.append(item)
38 | return samples
39 |
40 |
41 | class ImageDataset(Dataset):
42 | def __init__(self, image_dir, transform=None):
43 | self.image_dir = image_dir
44 | self.transform = transform
45 | self.samples = _make_dataset(self.image_dir)
46 | self.targets = [s[1] for s in self.samples]
47 |
48 | def __getitem__(self, index):
49 | image_path, target = self.samples[index]
50 | image = _img_loader(image_path, mode='RGB')
51 | name = os.path.split(image_path)[1]
52 |
53 | if self.transform is not None:
54 | image = self.transform(image)
55 |
56 | return image, target, name
57 |
58 | def __len__(self):
59 | return len(self.samples)
60 |
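61 |
62 | # Usage sketch (illustrative; the layout below is an assumption based on _find_classes,
63 | # which expects one subdirectory per class):
64 | #
65 | #     image_dir/
66 | #         cat/xxx.png
67 | #         dog/yyy.png
68 | #
69 | #     dataset = ImageDataset(image_dir, transform=None)
70 | #     image, target, name = dataset[0]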
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | cfg = {
4 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
5 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
6 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
7 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
8 | }
9 |
10 |
11 | class VGG(nn.Module):
12 |
13 | def __init__(self, features, num_classes=10):
14 | super().__init__()
15 | self.features = features
16 |
17 | self.classifier = nn.Sequential(
18 | nn.Linear(512, 4096),
19 | nn.ReLU(inplace=True),
20 | nn.Dropout(),
21 | nn.Linear(4096, 4096),
22 | nn.ReLU(inplace=True),
23 | nn.Dropout(),
24 | nn.Linear(4096, num_classes)
25 | )
26 |
27 | def forward(self, x):
28 | output = self.features(x)
29 | output = output.view(output.size()[0], -1)
30 | output = self.classifier(output)
31 |
32 | return output
33 |
34 |
35 | def make_layers(cfg, in_channels=3, batch_norm=False):
36 | layers = []
37 |
38 | input_channel = in_channels
39 | for l in cfg:
40 | if l == 'M':
41 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
42 | continue
43 |
44 | layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]
45 |
46 | if batch_norm:
47 | layers += [nn.BatchNorm2d(l)]
48 |
49 | layers += [nn.ReLU(inplace=True)]
50 | input_channel = l
51 |
52 | return nn.Sequential(*layers)
53 |
54 |
55 | def vgg11_bn(in_channels=3, num_classes=10):
56 |     return VGG(make_layers(cfg['A'], batch_norm=True, in_channels=in_channels), num_classes=num_classes)
57 |
58 |
59 | def vgg13_bn(in_channels=3, num_classes=10):
60 |     return VGG(make_layers(cfg['B'], batch_norm=True, in_channels=in_channels), num_classes=num_classes)
61 |
62 |
63 | def vgg16_bn(in_channels=3, num_classes=10):
64 |     return VGG(make_layers(cfg['D'], batch_norm=True, in_channels=in_channels), num_classes=num_classes)
65 |
66 |
67 | def vgg19_bn(in_channels=3, num_classes=10):
68 |     return VGG(make_layers(cfg['E'], batch_norm=True, in_channels=in_channels), num_classes=num_classes)
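69 |
70 |
71 | if __name__ == '__main__':
72 |     # minimal smoke test (illustrative), mirroring models/lenet.py: with 32x32 inputs the
73 |     # five max-pool stages reduce the feature map to 512x1x1, matching nn.Linear(512, 4096)
74 |     import torch
75 |     y = vgg16_bn(in_channels=3, num_classes=10)(torch.randn(1, 3, 32, 32))
76 |     print(y.shape)  # expected: torch.Size([1, 10])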
--------------------------------------------------------------------------------
/core/model_assemble.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 |
4 |
5 | def main():
6 | parser = argparse.ArgumentParser(description='')
7 | parser.add_argument('--model1_path', default='', type=str, help='model path')
8 | parser.add_argument('--model2_path', default='', type=str, help='model path')
9 | parser.add_argument('--asse_path', default='', type=str, help='asse path')
10 | args = parser.parse_args()
11 |
12 | model1 = torch.load(args.model1_path).cuda()
13 | model2 = torch.load(args.model2_path).cuda()
14 |
15 | # architecture
16 | print('=================> Architecture Assembling')
17 | layer = 0
18 | for module1, module2 in zip(model1.modules(), model2.modules()):
19 | if isinstance(module1, torch.nn.Conv2d):
20 | if layer == 0:
21 | module1.out_channels += module2.out_channels
22 | else:
23 | module1.in_channels += module2.in_channels
24 | module1.out_channels += module2.out_channels
25 | layer += 1
26 | if isinstance(module1, torch.nn.Linear):
27 | module1.in_features += module2.in_features
28 | module1.out_features += module2.out_features
29 | if isinstance(module1, torch.nn.BatchNorm2d):
30 | module1.num_features += module2.num_features
31 | module1.running_mean.data = torch.cat([module1.running_mean.data, module2.running_mean.data], dim=0)
32 | module1.running_var.data = torch.cat([module1.running_var.data, module2.running_var.data], dim=0)
33 | print(model1)
34 |
35 | # parameter
36 | print('=================> Parameter Assembling')
37 | layer = 0
38 |     for p1, p2 in zip(model1.parameters(), model2.parameters()):
39 |         if len(p1.shape) > 2:  # conv weight: [out_channels, in_channels, k, k]
40 |             if layer == 0:  # first conv: both models share the input, so only stack output channels
41 |                 p1.data = torch.cat([p1, p2], dim=0)
42 |             else:  # later convs: zero-pad each weight to the combined input width, then stack outputs
43 |                 p1b = torch.zeros(p1.shape[0], p2.shape[1], p1.shape[2], p1.shape[2]).cuda()
44 |                 p2b = torch.zeros(p2.shape[0], p1.shape[1], p2.shape[2], p2.shape[2]).cuda()
45 |                 p1.data = torch.cat([p1, p1b], dim=1)
46 |                 p2.data = torch.cat([p2b, p2], dim=1)
47 |                 p1.data = torch.cat([p1, p2], dim=0)
48 |             layer += 1
49 |         elif len(p1.shape) > 1:  # linear weight: [out_features, in_features], padded the same way
50 |             p1b = torch.zeros(p1.shape[0], p2.shape[1]).cuda()
51 |             p2b = torch.zeros(p2.shape[0], p1.shape[1]).cuda()
52 |             p1.data = torch.cat([p1, p1b], dim=1)
53 |             p2.data = torch.cat([p2b, p2], dim=1)
54 |             p1.data = torch.cat([p1, p2], dim=0)
55 |         else:  # 1-D parameters (biases, BatchNorm weights): concatenate directly
56 |             p1.data = torch.cat([p1, p2], dim=0)
57 | print('=', p1.shape)
58 |
59 | # save model
60 | torch.save(model1, args.asse_path)
61 |
62 |
63 | if __name__ == '__main__':
64 | main()
65 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Example user template template
2 | ### Example user template
3 |
4 | # IntelliJ project files
5 | .idea
6 | *.iml
7 | out
8 | gen
9 | ### Python template
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 | cover/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | .pybuilder/
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | # For a library or package, you might want to ignore these files since the code is
96 | # intended to run in multiple environments; otherwise, check them in:
97 | # .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
143 | # pytype static type analyzer
144 | .pytype/
145 |
146 | # Cython debug symbols
147 | cython_debug/
148 |
149 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Model LEGO: Creating Models Like Disassembling and Assembling Building Blocks
2 |
3 |
4 | ![framework](framework.jpg)
5 |
6 |
7 | For more information,
8 | please visit https://model-lego.github.io/.
9 |
10 | ## Requirements
11 |
12 | + Python Version: 3.9
13 | + PyTorch Version: 2.0.1
14 | + GPU: NVIDIA RTX A6000 / NVIDIA A40
15 |
16 | ## Quick Start
17 |
18 | ### Prepare the Source Models
19 |
20 | * Train a Source Model:
21 |
22 | ```bash
23 | python engines/train.py \
24 | --model_name 'vgg16' \
25 | --data_name 'cifar10' \
26 | --num_classes 10 \
27 | --num_epochs 200 \
28 | --model_dir ${model_dir} \
29 | --data_train_dir ${data_train_dir} \
30 | --data_test_dir ${data_test_dir} \
31 | --log_dir ${log_dir}
32 | ```
33 |
34 | ### Model Disassembling
35 |
36 | 
37 |
38 | * Select the Top 1% of Samples with High Confidence:
39 |
40 | ```bash
41 | python core/sample_select.py \
42 | --model_name 'vgg16' \
43 | --data_name 'cifar10' \
44 | --num_classes 10 \
45 | --model_path ${model_path} \
46 | --data_dir ${data_dir} \
47 | --save_dir ${save_dir} \
48 | --num_samples 50
49 | ```
50 |
51 | * Relevant Feature Identifying (α and β can be configured in core/relevant_feature_identifying.py):
52 |
53 | ```bash
54 | python core/relevant_feature_identifying.py \
55 | --model_name 'vgg16' \
56 |   --data_name 'cifar10' \
57 | --num_classes 10 \
58 | --model_path ${model_path} \
59 | --data_dir ${data_dir} \
60 | --save_dir ${save_dir}
61 | ```
62 |
63 | * Parameter Linking and Model Disassembling (output the disassembled task-aware component):
64 |
65 | ```bash
66 | python core/model_disassemble.py \
67 | --model_name 'vgg16' \
68 | --num_classes 10 \
69 | --model_path ${model_path} \
70 | --mask_dir ${mask_dir} \
71 | --save_dir ${save_dir} \
72 | --disa_layers ${disa_layers} \
73 | --disa_labels ${disa_labels}
74 | ```
75 |
76 | ### Model Assembling
77 |
78 | 
79 |
80 | * Parameter Scaling (optional):
81 |
82 | ```bash
83 | python core/parameter_scaling.py \
84 | --model_name 'vgg16' \
85 | --data_name 'cifar10' \
86 | --num_classes 10 \
87 | --model_path ${model_path} \
88 | --data_dir ${data_dir}
89 | ```
90 |
91 | * Alignment Padding and Model Assembling (output the assembled model):
92 |
93 | ```bash
94 | python core/model_assemble.py \
95 | --model1_path ${model1_path} \
96 | --model2_path ${model2_path} \
97 | --asse_path ${asse_path}
98 | ```
99 |
100 | ### Others
101 |
102 | * Evaluate the Accuracy of the Model or Task-aware Component:
103 |
104 | ```bash
105 | python engines/test.py \
106 | --model_name 'vgg16' \
107 |   --data_name 'cifar10' \
108 | --num_classes 10 \
109 | --model_path ${model_path} \
110 | --data_dir ${data_dir}
111 | ```
112 |
113 | * Visualize Model Decision Routes:
114 |
115 | ```bash
116 | python core/model_decision_route_visualizing.py \
117 | --mask_dir ${mask_dir} \
118 | --layers ${layers} \
119 | --labels ${labels}
120 | ```
121 |
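122 | ### End-to-End Example
123 |
124 | A minimal walk-through chaining the provided scripts (a sketch: first edit the
125 | `PYTHONPATH`, dataset, and output paths inside `scripts/*.sh` to match your environment):
126 |
127 | ```bash
128 | bash scripts/train.sh                         # train the source model
129 | bash scripts/sample_select.sh                 # select high-confidence samples
130 | bash scripts/relevant_feature_identifying.sh  # compute contribution masks
131 | bash scripts/model_disassemble.sh             # extract the task-aware component
132 | bash scripts/test.sh                          # evaluate the disassembled component
133 | ```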
--------------------------------------------------------------------------------
/engines/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import time
4 |
5 | import torch
6 | from torch import nn
7 | import collections
8 |
9 | import loaders
10 | import models
11 | import metrics
12 | from utils.train_util import AverageMeter, ProgressMeter
13 |
14 | from thop import profile
15 |
16 |
17 | def main():
18 | parser = argparse.ArgumentParser(description='')
19 | parser.add_argument('--model_name', default='', type=str, help='model name')
20 | parser.add_argument('--data_name', default='', type=str, help='data name')
21 | parser.add_argument('--num_classes', default='', type=int, help='num classes')
22 | parser.add_argument('--model_path', default='', type=str, help='model path')
23 | parser.add_argument('--data_dir', default='', type=str, help='data directory')
24 | args = parser.parse_args()
25 |
26 | # ----------------------------------------
27 | # basic configuration
28 | # ----------------------------------------
29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30 |
31 | print('-' * 50)
32 | print('TEST ON:', device)
33 | print('MODEL PATH:', args.model_path)
34 | print('DATA PATH:', args.data_dir)
35 | print('-' * 50)
36 |
37 | # ----------------------------------------
38 | # trainer configuration
39 | # ----------------------------------------
40 | state = torch.load(args.model_path)
41 | if isinstance(state, collections.OrderedDict):
42 | model = models.load_model(args.model_name, num_classes=args.num_classes)
43 | model.load_state_dict(state)
44 | else:
45 | model = state
46 | model.to(device)
47 |
48 | test_loader = loaders.load_data(args.data_dir, args.data_name, data_type='test')
49 |
50 | criterion = nn.CrossEntropyLoss()
51 |
52 | # ----------------------------------------
53 | # speed
54 | # ----------------------------------------
55 | speed(model, device)
56 |
57 | # ----------------------------------------
58 | # each epoch
59 | # ----------------------------------------
60 | # since = time.time()
61 |
62 | loss, acc1, acc5, class_acc = test(test_loader, model, criterion, device)
63 |
64 | print('-' * 50)
65 | print(class_acc)
66 | print('AVG:', acc1.avg)
67 | # print('TIME CONSUMED', time.time() - since)
68 |
69 |
70 | def test(test_loader, model, criterion, device):
71 | loss_meter = AverageMeter('Loss', ':.4e')
72 | acc1_meter = AverageMeter('Acc@1', ':6.2f')
73 | acc5_meter = AverageMeter('Acc@5', ':6.2f')
74 | progress = ProgressMeter(total=len(test_loader), step=20, prefix='Test',
75 | meters=[loss_meter, acc1_meter, acc5_meter])
76 | class_acc = metrics.ClassAccuracy()
77 | model.eval()
78 |
79 | for i, samples in enumerate(test_loader):
80 | inputs, labels, _ = samples
81 | inputs = inputs.to(device)
82 | labels = labels.to(device)
83 |
84 | with torch.set_grad_enabled(False):
85 | outputs = model(inputs)
86 | loss = criterion(outputs, labels)
87 |                 acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 1))  # topk=(1, 1): both meters report top-1 accuracy
88 | class_acc.update(outputs, labels)
89 |
90 | loss_meter.update(loss.item(), inputs.size(0))
91 | acc1_meter.update(acc1.item(), inputs.size(0))
92 | acc5_meter.update(acc5.item(), inputs.size(0))
93 |
94 | progress.display(i)
95 |
96 | return loss_meter, acc1_meter, acc5_meter, class_acc
97 |
98 |
99 | def speed(model, device):
100 | # model.eval()
101 |
102 | flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).to(device),))
103 | print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
104 | print('Params = ' + str(params / 1000 ** 2) + 'M')
105 |
106 |
107 | if __name__ == '__main__':
108 | main()
109 |
--------------------------------------------------------------------------------
/core/parameter_scaling.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import torch
4 | import argparse
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | import loaders
9 | import models
10 |
11 |
12 | class ScoreStatistic:
13 | def __init__(self, num_classes):
14 | self.scores = [[] for i in range(num_classes)]
15 | self.nums = torch.zeros(num_classes, dtype=torch.long)
16 |
17 | def __call__(self, outputs, labels):
18 | scores, predicts = torch.max(outputs.detach(), dim=1)
19 |
20 | for i, label in enumerate(labels):
21 | if label == predicts[i]:
22 | self.scores[label].append(scores[i].detach().cpu().numpy())
23 | self.nums[label] += 1
24 |
25 | def display_score(self, save_path):
26 | max_num = self.nums.max()
27 | for i in range(len(self.scores)):
28 | if len(self.scores[i]) != max_num:
29 | self.scores[i] = self.scores[i] + [0 for _ in range(max_num - len(self.scores[i]))]
30 | scores = torch.from_numpy(np.asarray(self.scores))
31 | scores_class = torch.sum(scores, dim=1) / self.nums
32 | fc_ratio = self.nums / torch.sum(scores, dim=1)
33 | np.save(save_path, fc_ratio.numpy())
34 |
35 | print('AVG SCORE RATIO: ', scores_class)
36 | print('Reciprocal AVG SCORE RATIO: ', fc_ratio)
37 | print('PICTURE NUM: ', self.nums)
38 | return fc_ratio
39 |
40 |
41 | def main():
42 | parser = argparse.ArgumentParser(description='')
43 | parser.add_argument('--model_name', default='', type=str, help='model name')
44 | parser.add_argument('--data_name', default='', type=str, help='data name')
45 | parser.add_argument('--num_classes', default='', type=int, help='num classes')
46 | parser.add_argument('--model_path', default='', type=str, help='model path')
47 | parser.add_argument('--data_dir', default='', type=str, help='data dir')
48 | args = parser.parse_args()
49 |
50 | # ----------------------------------------
51 | # basic configuration
52 | # ----------------------------------------
53 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54 |
55 | print('-' * 100)
56 | print('SCALE ON:', device)
57 | print('MODEL PATH:', args.model_path)
58 | print('DATA DIR:', args.data_dir)
59 |
60 | # ----------------------------------------
61 | # model/data configuration
62 | # ----------------------------------------
63 | state = torch.load(args.model_path)
64 | if isinstance(state, collections.OrderedDict):
65 |         model = models.load_model(args.model_name, num_classes=args.num_classes)
66 | model.load_state_dict(state)
67 | else:
68 | model = state
69 | model.to(device)
70 | model.eval()
71 |
72 | data_loader = loaders.load_data(args.data_dir, args.data_name, data_type='test')
73 |
74 | score_statistic = ScoreStatistic(num_classes=args.num_classes)
75 |
76 | # ----------------------------------------
77 | # forward
78 | # ----------------------------------------
79 | for samples in tqdm(data_loader):
80 | inputs, labels, _ = samples
81 | inputs = inputs.to(device)
82 | labels = labels.to(device)
83 | outputs = model(inputs)
84 |
85 | score_statistic(outputs=outputs, labels=labels)
86 |
87 | score_ratio = score_statistic.display_score(
88 |         save_path=os.path.splitext(args.model_path)[0] + '.npy')
89 |
90 | # ----------------------------------------
91 | # parameter scaling
92 | # ----------------------------------------
93 | layer = 0
94 | last_layer = len(models.load_modules(model=model))
95 | for para in model.parameters():
96 | if len(para.shape) > 2: # conv
97 | layer += 1
98 | elif len(para.shape) > 1: # linear
99 | if layer == last_layer - 1:
100 | para.data = score_ratio.view(-1, 1).float().cuda() * para.data
101 | layer += 1
102 | else: # bias
103 | if layer == last_layer:
104 | para.data = score_ratio.view(-1).float().cuda() * para.data
105 |
106 |     scale_model_path = os.path.splitext(args.model_path)[0] + '_scale.pth'
107 | torch.save(model, scale_model_path)
108 |
109 | print('RESCALE MODEL PATH:', scale_model_path)
110 | print('-' * 50)
111 |
112 |
113 | if __name__ == '__main__':
114 | main()
115 |
--------------------------------------------------------------------------------
/core/model_disassemble.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 | import torch_pruning as tp
6 |
7 | import models
8 |
9 |
10 | def disassemble():
11 | parser = argparse.ArgumentParser(description='')
12 | parser.add_argument('--model_name', default='', type=str, help='model name')
13 | parser.add_argument('--num_classes', default='', type=int, help='num classes')
14 | parser.add_argument('--model_path', default='', type=str, help='model path')
15 | parser.add_argument('--save_dir', default='', type=str, help='save dir')
16 | parser.add_argument('--mask_dir', default='', type=str, help='mask dir')
17 | parser.add_argument('--disa_layers', default='', nargs='+', type=int, help='disa layers')
18 | parser.add_argument('--disa_labels', default='', nargs='+', type=int, help='disa labels')
19 | args = parser.parse_args()
20 |
21 | # ----------------------------------------
22 | # basic configuration
23 | # ----------------------------------------
24 | print('-' * 50)
25 | print('SAVE DIR:', args.save_dir)
26 | print('-' * 50)
27 |
28 | # ----------------------------------------
29 | # model configuration
30 | # ----------------------------------------
31 | model = models.load_model(args.model_name, num_classes=args.num_classes)
32 | model.load_state_dict(torch.load(args.model_path, map_location=torch.device('cpu')))
33 | # model = torch.load(args.model_path).cpu()
34 |
35 | modules = models.load_modules(model=model, model_layers=None)
36 |
37 | # ----------------------------------------
38 | # disa configuration
39 | # ----------------------------------------
40 | mask_path = os.path.join(args.mask_dir, 'mask_layer{}.pt')
41 |
42 | if args.disa_layers[0] == -1:
43 | args.disa_layers = [i for i in range(len(modules) - 1)]
44 |
45 | print('disassembling layers:', args.disa_layers)
46 | print('disassembling labels:', args.disa_labels)
47 |
48 | # ----------------------------------------
49 | # model disassemble
50 | # ----------------------------------------
51 | DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 32, 32))
52 |
53 | ###############################
54 | # layers 1-N: input channels
55 | ###############################
56 | for layer in args.disa_layers:
57 | print('===> LAYER', layer)
58 | print('--->', modules[layer])
59 |
60 | # idxs
61 | mask_total_i = None
62 | mask_i = torch.load(mask_path.format(layer))
63 | for label in args.disa_labels:
64 | if mask_total_i is None:
65 | mask_total_i = mask_i[label]
66 | else:
67 | mask_total_i = torch.bitwise_or(mask_i[label], mask_total_i)
68 | idxs = torch.where(mask_total_i == 0)[0].tolist()
69 |
70 | # structure pruning
71 | prune_fn = None
72 | if isinstance(modules[layer], torch.nn.Conv2d):
73 | prune_fn = tp.prune_conv_in_channels
74 | if isinstance(modules[layer], torch.nn.Linear):
75 | prune_fn = tp.prune_linear_in_channels
76 | group = DG.get_pruning_group(modules[layer], prune_fn, idxs=idxs)
77 | if DG.check_pruning_group(group):
78 | group.prune()
79 | print('--->', modules[layer])
80 |
81 | ###############################
82 | # layer N: output channels
83 | ###############################
84 | # layer = 0
85 | # print('--->', modules[layer])
86 | #
87 | # # idxs
88 | # mask_i = torch.load(mask_path.format(-1))
89 | # mask_total_i = None
90 | # for label in args.disa_labels:
91 | # if mask_total_i is None:
92 | # mask_total_i = mask_i[label]
93 | # else:
94 | # mask_total_i = torch.bitwise_or(mask_i[label], mask_total_i)
95 | # idxs = np.where(mask_total_i == 0)[0].tolist()
96 | #
97 | # # structure pruning
98 | # prune_fn = tp.prune_linear_out_channels
99 | # group = DG.get_pruning_group(modules[layer], prune_fn, idxs=idxs)
100 | # if DG.check_pruning_group(group):
101 | # group.prune()
102 | # print('--->', modules[layer])
103 |
104 | ###############################
105 | # save model
106 | ###############################
107 | model.zero_grad()
108 | result_path = os.path.join(args.save_dir, 'model_disa.pth')
109 | torch.save(model, result_path)
110 | print(model)
111 |
112 |
113 | if __name__ == '__main__':
114 | disassemble()
115 |
--------------------------------------------------------------------------------
/loaders/image_loader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from torchvision import transforms
3 | from loaders.datasets import ImageDataset
4 |
5 | mnist_train_transform = transforms.Compose([
6 | transforms.Resize((32, 32)),
7 | transforms.ToTensor(),
8 | transforms.Normalize((0.5, 0.5, 0.5),
9 | (0.5, 0.5, 0.5)),
10 | ])
11 |
12 | mnist_test_transform = transforms.Compose([
13 | transforms.Resize((32, 32)),
14 | transforms.ToTensor(),
15 | transforms.Normalize((0.5, 0.5, 0.5),
16 | (0.5, 0.5, 0.5)),
17 | ])
18 |
19 | cifar10_train_transform = transforms.Compose([
20 | transforms.RandomCrop(32, padding=4),
21 | transforms.Resize((32, 32)),
22 | # transforms.Resize((256, 256)),
23 | transforms.RandomHorizontalFlip(),
24 | transforms.ToTensor(),
25 | # transforms.Normalize((0.4914, 0.4822, 0.4465),
26 | # (0.2023, 0.1994, 0.2010)),
27 | transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
28 | (0.24703233, 0.24348505, 0.26158768)),
29 | ])
30 |
31 | cifar10_test_transform = transforms.Compose([
32 | transforms.Resize((32, 32)),
33 | # transforms.Resize((256, 256)),
34 | transforms.ToTensor(),
35 | # transforms.Normalize((0.4914, 0.4822, 0.4465),
36 | # (0.2023, 0.1994, 0.2010)),
37 | transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
38 | (0.24703233, 0.24348505, 0.26158768)),
39 | ])
40 |
41 | tiny_imagenet_train_transform = transforms.Compose([
42 | transforms.RandomResizedCrop((64, 64)),
43 | transforms.RandomHorizontalFlip(),
44 | transforms.ToTensor(),
45 | transforms.Normalize((0.4802, 0.4481, 0.3975),
46 | (0.2770, 0.2691, 0.2821))
47 | ])
48 |
49 | tiny_imagenet_test_transform = transforms.Compose([
50 | transforms.Resize((64, 64)),
51 | transforms.ToTensor(),
52 | transforms.Normalize((0.4802, 0.4481, 0.3975),
53 | (0.2770, 0.2691, 0.2821))
54 | ])
55 |
56 | imagenet_train_transform = transforms.Compose([
57 | transforms.RandomResizedCrop((224, 224)),
58 | transforms.RandomHorizontalFlip(),
59 | transforms.ToTensor(),
60 | transforms.Normalize((0.485, 0.456, 0.406),
61 | (0.229, 0.224, 0.225))
62 | ])
63 |
64 | imagenet_test_transform = transforms.Compose([
65 | transforms.Resize((224, 224)),
66 | transforms.ToTensor(),
67 | transforms.Normalize((0.485, 0.456, 0.406),
68 | (0.229, 0.224, 0.225))
69 | ])
70 |
71 |
72 | def _get_set(data_path, transform):
73 | return ImageDataset(image_dir=data_path,
74 | transform=transform)
75 |
76 |
77 | def load_images(data_dir, data_name, data_type=None):
78 | assert data_name in ['mnist', 'fashion-mnist', 'cifar10', 'cifar100', 'tiny-imagenet', 'imagenet']
79 | assert data_type is None or data_type in ['train', 'test']
80 |
81 | data_transform = None
82 |     if data_name in ('mnist', 'fashion-mnist') and data_type == 'train':
83 |         data_transform = mnist_train_transform
84 |     elif data_name in ('mnist', 'fashion-mnist') and data_type == 'test':
85 |         data_transform = mnist_test_transform
86 | elif data_name == 'cifar10' and data_type == 'train':
87 | data_transform = cifar10_train_transform
88 | elif data_name == 'cifar10' and data_type == 'test':
89 | data_transform = cifar10_test_transform
90 | elif data_name == 'cifar100' and data_type == 'train':
91 | data_transform = cifar10_train_transform
92 | elif data_name == 'cifar100' and data_type == 'test':
93 | data_transform = cifar10_test_transform
94 | elif data_name == 'tiny-imagenet' and data_type == 'train':
95 | data_transform = tiny_imagenet_train_transform
96 | elif data_name == 'tiny-imagenet' and data_type == 'test':
97 | data_transform = tiny_imagenet_test_transform
98 | elif data_name == 'imagenet' and data_type == 'train':
99 | data_transform = imagenet_train_transform
100 | elif data_name == 'imagenet' and data_type == 'test':
101 | data_transform = imagenet_test_transform
102 | assert data_transform is not None
103 |
104 | data_set = _get_set(data_dir, transform=data_transform)
105 | data_loader = DataLoader(dataset=data_set,
106 | batch_size=256,
107 | num_workers=4,
108 | shuffle=True)
109 |     # rough cost note: ImageNet + VGG16 at batch size 128 -> ~26 GB GPU memory, ~40 days
110 | return data_loader
111 |
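112 |
113 | # Usage sketch (illustrative; the path is a placeholder for an ImageFolder-style
114 | # directory with one subfolder per class):
115 | #
116 | #     loader = load_images('/path/to/cifar10/test', data_name='cifar10', data_type='test')
117 | #     images, targets, names = next(iter(loader))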
--------------------------------------------------------------------------------
/models/simplenetv1.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class SimpleNet(nn.Module):
6 | def __init__(self, in_channels=3, num_classes=10):
7 | super(SimpleNet, self).__init__()
8 | # print(simpnet_name)
9 | self.features = self._make_layers(in_channels) # self._make_layers(cfg[simpnet_name])
10 | self.classifier = nn.Linear(256, num_classes)
11 | self.drp = nn.Dropout(0.1)
12 |
13 | def forward(self, x):
14 | out = self.features(x)
15 |
16 | # Global Max Pooling
17 | out = F.max_pool2d(out, kernel_size=out.size()[2:])
18 | # out = F.dropout2d(out, 0.1, training=True)
19 | out = self.drp(out)
20 |
21 | out = out.view(out.size(0), -1)
22 | out = self.classifier(out)
23 | return out
24 |
25 | def _make_layers(self, in_channels):
26 |
27 | model = nn.Sequential(
28 | nn.Conv2d(in_channels, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
29 | nn.BatchNorm2d(64, eps=1e-05, momentum=0.05, affine=True),
30 | nn.ReLU(inplace=True),
31 |
32 | nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
33 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True),
34 | nn.ReLU(inplace=True),
35 |
36 | nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
37 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True),
38 | nn.ReLU(inplace=True),
39 |
40 | nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
41 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True),
42 | nn.ReLU(inplace=True),
43 |
44 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
45 | nn.Dropout2d(p=0.1),
46 |
47 | nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
48 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True),
49 | nn.ReLU(inplace=True),
50 |
51 | nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
52 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True),
53 | nn.ReLU(inplace=True),
54 |
55 | nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
56 | nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True),
57 | nn.ReLU(inplace=True),
58 |
59 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
60 | nn.Dropout2d(p=0.1),
61 |
62 | nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
63 | nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True),
64 | nn.ReLU(inplace=True),
65 |
66 | nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
67 | nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True),
68 | nn.ReLU(inplace=True),
69 |
70 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
71 | nn.Dropout2d(p=0.1),
72 |
73 | nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
74 | nn.BatchNorm2d(512, eps=1e-05, momentum=0.05, affine=True),
75 | nn.ReLU(inplace=True),
76 |
77 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
78 | nn.Dropout2d(p=0.1),
79 |
80 | nn.Conv2d(512, 2048, kernel_size=[1, 1], stride=(1, 1), padding=(0, 0)),
81 | nn.BatchNorm2d(2048, eps=1e-05, momentum=0.05, affine=True),
82 | nn.ReLU(inplace=True),
83 |
84 | nn.Conv2d(2048, 256, kernel_size=[1, 1], stride=(1, 1), padding=(0, 0)),
85 | nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True),
86 | nn.ReLU(inplace=True),
87 |
88 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
89 | nn.Dropout2d(p=0.1),
90 |
91 | nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)),
92 | nn.BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True),
93 | nn.ReLU(inplace=True),
94 |
95 | )
96 |
97 | for m in model.modules():
98 | if isinstance(m, nn.Conv2d):
99 | nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))
100 |
101 | return model
102 |
103 |
104 | def simplenet(in_channels=3, num_classes=10):
105 | return SimpleNet(in_channels=in_channels, num_classes=num_classes)
--------------------------------------------------------------------------------
/models/googlenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Inception(nn.Module):
6 | def __init__(self, input_channels, n1x1, n3x3_reduce, n3x3, n5x5_reduce, n5x5, pool_proj):
7 | super().__init__()
8 |
9 | # 1x1conv branch
10 | self.b1 = nn.Sequential(
11 | nn.Conv2d(input_channels, n1x1, kernel_size=1),
12 | nn.BatchNorm2d(n1x1),
13 | nn.ReLU(inplace=True)
14 | )
15 |
16 | # 1x1conv -> 3x3conv branch
17 | self.b2 = nn.Sequential(
18 | nn.Conv2d(input_channels, n3x3_reduce, kernel_size=1),
19 | nn.BatchNorm2d(n3x3_reduce),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(n3x3_reduce, n3x3, kernel_size=3, padding=1),
22 | nn.BatchNorm2d(n3x3),
23 | nn.ReLU(inplace=True)
24 | )
25 |
26 | # 1x1conv -> 5x5conv branch
27 | # we use 2 3x3 conv filters stacked instead
28 | # of 1 5x5 filters to obtain the same receptive
29 | # field with fewer parameters
30 | self.b3 = nn.Sequential(
31 | nn.Conv2d(input_channels, n5x5_reduce, kernel_size=1),
32 | nn.BatchNorm2d(n5x5_reduce),
33 | nn.ReLU(inplace=True),
34 | nn.Conv2d(n5x5_reduce, n5x5, kernel_size=3, padding=1),
35 |             nn.BatchNorm2d(n5x5),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
38 | nn.BatchNorm2d(n5x5),
39 | nn.ReLU(inplace=True)
40 | )
41 |
42 | # 3x3pooling -> 1x1conv
43 | # same conv
44 | self.b4 = nn.Sequential(
45 | nn.MaxPool2d(3, stride=1, padding=1),
46 | nn.Conv2d(input_channels, pool_proj, kernel_size=1),
47 | nn.BatchNorm2d(pool_proj),
48 | nn.ReLU(inplace=True)
49 | )
50 |
51 | def forward(self, x):
52 | return torch.cat([self.b1(x), self.b2(x), self.b3(x), self.b4(x)], dim=1)
53 |
54 |
55 | class GoogleNet(nn.Module):
56 |
57 | def __init__(self, in_channels=3, num_classes=10):
58 | super().__init__()
59 | self.prelayer = nn.Sequential(
60 | nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False),
61 | nn.BatchNorm2d(64),
62 | nn.ReLU(inplace=True),
63 | nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
64 | nn.BatchNorm2d(64),
65 | nn.ReLU(inplace=True),
66 | nn.Conv2d(64, 192, kernel_size=3, padding=1, bias=False),
67 | nn.BatchNorm2d(192),
68 | nn.ReLU(inplace=True),
69 | )
70 |
71 |         # although the prelayer differs from the paper's stem,
72 |         # we keep the paper's stage names a3, b3, ...
73 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
74 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
75 |
76 |         # """In general, an Inception network is a network consisting of
77 |         # modules of the above type stacked upon each other, with occasional
78 |         # max-pooling layers with stride 2 to halve the resolution of the
79 |         # grid"""
80 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
81 |
82 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
83 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
84 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
85 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
86 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
87 |
88 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
89 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
90 |
91 | # input feature size: 8*8*1024
92 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
93 | self.dropout = nn.Dropout2d(p=0.4)
94 | self.linear = nn.Linear(1024, num_classes)
95 |
96 | def forward(self, x):
97 | x = self.prelayer(x)
98 | x = self.maxpool(x)
99 | x = self.a3(x)
100 | x = self.b3(x)
101 |
102 | x = self.maxpool(x)
103 |
104 | x = self.a4(x)
105 | x = self.b4(x)
106 | x = self.c4(x)
107 | x = self.d4(x)
108 | x = self.e4(x)
109 |
110 | x = self.maxpool(x)
111 |
112 | x = self.a5(x)
113 | x = self.b5(x)
114 |
115 | # """It was found that a move from fully connected layers to
116 | # average pooling improved the top-1 accuracy by about 0.6%,
117 | # however the use of dropout remained essential even after
118 | # removing the fully connected layers."""
119 | x = self.avgpool(x)
120 | x = self.dropout(x)
121 | x = x.view(x.size()[0], -1)
122 | x = self.linear(x)
123 |
124 | return x
125 |
126 |
127 | def googlenet(in_channels=3, num_classes=10):
128 | return GoogleNet(in_channels=in_channels, num_classes=num_classes)
129 |
--------------------------------------------------------------------------------
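Each Inception block's output width is the sum of its four branch widths (n1x1 + n3x3 + n5x5 + pool_proj), which is how the constructor arguments above chain together: a3 emits 64+128+32+32 = 256 channels, exactly b3's input width, and b5 emits 384+384+128+128 = 1024 channels, matching nn.Linear(1024, num_classes). A minimal check (the import path is an assumption):

    import torch
    from models.googlenet import googlenet  # assumed module path

    model = googlenet(in_channels=3, num_classes=10)
    print(model(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 10])
    assert 64 + 128 + 32 + 32 == 256      # a3 output == b3 input
    assert 384 + 384 + 128 + 128 == 1024  # b5 output == linear input width

--------------------------------------------------------------------------------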
/core/sample_select.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import loaders
7 | import models
8 | from utils import file_util
9 |
10 |
11 | class SampleSift:
12 | def __init__(self, num_classes, num_samples, is_high_confidence=True):
13 | self.names = [[None for j in range(num_samples)] for i in range(num_classes)]
14 | self.scores = torch.zeros((num_classes, num_samples))
15 | self.nums = torch.zeros(num_classes, dtype=torch.long)
16 | self.num_classes = num_classes
17 | self.num_samples = num_samples
18 | self.is_high_confidence = is_high_confidence
19 |
20 | def __call__(self, outputs, labels, names):
21 |         softmaxs = torch.softmax(outputs.detach(), dim=1)
22 | # print(scores)
23 |
24 |         for i, label in enumerate(labels):  # each sample
25 | score = softmaxs[i][label]
26 |
27 | if self.is_high_confidence: # sift high confidence
28 | if self.nums[label] == self.num_samples:
29 | score_min, index = torch.min(self.scores[label], dim=0)
30 | if score > score_min:
31 | self.names[label][index] = names[i]
32 | self.scores[label][index] = score
33 | else:
34 | self.names[label][self.nums[label]] = names[i]
35 | self.scores[label][self.nums[label]] = score
36 | self.nums[label] += 1
37 | else: # sift low confidence
38 | if self.nums[label] == self.num_samples:
39 | score_max, index = torch.max(self.scores[label], dim=0)
40 | if score < score_max:
41 | self.names[label][index] = names[i]
42 | self.scores[label][index] = score
43 | else:
44 | self.names[label][self.nums[label]] = names[i]
45 | self.scores[label][self.nums[label]] = score
46 | self.nums[label] += 1
47 |
48 | def save_image(self, input_path, output_path):
49 | print(self.scores)
50 | print(self.nums)
51 |
52 | class_names = sorted([d.name for d in os.scandir(input_path) if d.is_dir()])
53 | print(class_names)
54 |
55 | for label, image_list in enumerate(self.names):
56 |             for image in tqdm([n for n in image_list if n is not None]):  # skip unfilled slots
57 | class_name = class_names[label]
58 |
59 | src_path = os.path.join(input_path, class_name, str(image))
60 | dst_path = os.path.join(output_path, class_name, str(image))
61 | file_util.copy_file(src_path, dst_path)
62 |
63 |
64 | def main():
65 | parser = argparse.ArgumentParser(description='')
66 | parser.add_argument('--model_name', default='', type=str, help='model name')
67 | parser.add_argument('--data_name', default='', type=str, help='data name')
68 | parser.add_argument('--num_classes', default='', type=int, help='num classes')
69 | parser.add_argument('--model_path', default='', type=str, help='model path')
70 | parser.add_argument('--data_dir', default='', type=str, help='data dir')
71 |     parser.add_argument('--save_dir', default='', type=str, help='save dir for sifted samples')
72 | parser.add_argument('--num_samples', default=10, type=int, help='num samples')
73 | args = parser.parse_args()
74 |
75 | # ----------------------------------------
76 | # basic configuration
77 | # ----------------------------------------
78 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79 |
80 | if not os.path.exists(args.save_dir):
81 | os.makedirs(args.save_dir)
82 |
83 | print('-' * 50)
84 |     print('RUN ON:', device)
85 | print('MODEL PATH:', args.model_path)
86 | print('DATA PATH:', args.data_dir)
87 | print('RESULT PATH:', args.save_dir)
88 | print('-' * 50)
89 |
90 | # ----------------------------------------
91 | # model/data configuration
92 | # ----------------------------------------
93 | model = models.load_model(model_name=args.model_name, num_classes=args.num_classes)
94 | model.load_state_dict(torch.load(args.model_path))
95 | # model = torch.load(args.model_path)
96 | model.to(device)
97 | model.eval()
98 |
99 | data_loader = loaders.load_data(data_dir=args.data_dir, data_name=args.data_name, data_type='test')
100 |
101 | sample_sift = SampleSift(num_classes=args.num_classes, num_samples=args.num_samples, is_high_confidence=True)
102 |
103 | # ----------------------------------------
104 | # forward
105 | # ----------------------------------------
106 | for samples in tqdm(data_loader):
107 | inputs, labels, names = samples
108 | inputs = inputs.to(device)
109 | labels = labels.to(device)
110 | with torch.no_grad():
111 | outputs = model(inputs)
112 | sample_sift(outputs=outputs, labels=labels, names=names)
113 |
114 | sample_sift.save_image(args.data_dir, args.save_dir)
115 |
116 |
117 | if __name__ == '__main__':
118 | main()
119 |
--------------------------------------------------------------------------------
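SampleSift keeps, per class, the num_samples images whose softmax score for their own label is highest (or lowest when is_high_confidence=False). A toy run with made-up logits (a sketch; it assumes the repo root is on PYTHONPATH so core.sample_select imports cleanly):

    import torch
    from core.sample_select import SampleSift

    sift = SampleSift(num_classes=2, num_samples=1, is_high_confidence=True)
    outputs = torch.tensor([[2.0, 0.0],   # very confident class-0 sample
                            [0.5, 0.0],   # weaker class-0 sample
                            [0.0, 3.0]])  # confident class-1 sample
    labels = torch.tensor([0, 0, 1])
    sift(outputs=outputs, labels=labels, names=['a.png', 'b.png', 'c.png'])
    print(sift.names)  # [['a.png'], ['c.png']] -- the most confident sample per class

--------------------------------------------------------------------------------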
/engines/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import shutil
4 | import time
5 | from tqdm import tqdm
6 |
7 | import torch
8 | from torch import nn
9 | from torch import optim
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 | import loaders
13 | import models
14 | import metrics
15 | from utils.train_util import AverageMeter, ProgressMeter
16 |
17 |
18 | def main():
19 | parser = argparse.ArgumentParser(description='')
20 | parser.add_argument('--model_name', default='', type=str, help='model name')
21 | parser.add_argument('--data_name', default='', type=str, help='data name')
22 | parser.add_argument('--num_classes', default='', type=int, help='num classes')
23 | parser.add_argument('--num_epochs', default=200, type=int, help='num epochs')
24 | parser.add_argument('--model_dir', default='', type=str, help='model dir')
25 | parser.add_argument('--data_train_dir', default='', type=str, help='data dir')
26 | parser.add_argument('--data_test_dir', default='', type=str, help='data dir')
27 | parser.add_argument('--log_dir', default='', type=str, help='log dir')
28 | args = parser.parse_args()
29 |
30 | # ----------------------------------------
31 | # basic configuration
32 | # ----------------------------------------
33 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
34 |
35 | if not os.path.exists(args.model_dir):
36 | os.makedirs(args.model_dir)
37 | if os.path.exists(args.log_dir):
38 | shutil.rmtree(args.log_dir)
39 |
40 | print('-' * 50)
41 | print('TRAIN ON:', device)
42 | print('MODEL DIR:', args.model_dir)
43 | # print('LOG DIR:', args.log_dir)
44 | print('-' * 50)
45 |
46 | # ----------------------------------------
47 | # trainer configuration
48 | # ----------------------------------------
49 | model = models.load_model(args.model_name, num_classes=args.num_classes)
50 | model.to(device)
51 |
52 | train_loader = loaders.load_data(args.data_train_dir, args.data_name, data_type='train')
53 | test_loader = loaders.load_data(args.data_test_dir, args.data_name, data_type='test')
54 |
55 | criterion = nn.CrossEntropyLoss()
56 | optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
57 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.num_epochs)
58 |
59 | writer = SummaryWriter(args.log_dir)
60 |
61 | # ----------------------------------------
62 | # each epoch
63 | # ----------------------------------------
64 | since = time.time()
65 |
66 | best_acc = None
67 | best_epoch = None
68 |
69 | for epoch in tqdm(range(args.num_epochs)):
70 | print('\n')
71 | loss, acc1, acc5 = train(train_loader, model, criterion, optimizer, device)
72 | writer.add_scalar(tag='training loss', scalar_value=loss.avg, global_step=epoch)
73 | writer.add_scalar(tag='training acc1', scalar_value=acc1.avg, global_step=epoch)
74 | loss, acc1, acc5 = test(test_loader, model, criterion, device)
75 | writer.add_scalar(tag='test loss', scalar_value=loss.avg, global_step=epoch)
76 | writer.add_scalar(tag='test acc1', scalar_value=acc1.avg, global_step=epoch)
77 |
78 | # ----------------------------------------
79 | # save best model
80 | # ----------------------------------------
81 | if best_acc is None or best_acc < acc1.avg:
82 | best_acc = acc1.avg
83 | best_epoch = epoch
84 | torch.save(model.state_dict(), os.path.join(args.model_dir, 'model_ori.pth'))
85 |
86 | scheduler.step()
87 |
88 | print('BEST ACC', best_acc)
89 | print('BEST EPOCH', best_epoch)
90 | print('TIME CONSUMED', time.time() - since)
91 | print('MODEL DIR', args.model_dir)
92 |
93 |
94 | def train(train_loader, model, criterion, optimizer, device):
95 | loss_meter = AverageMeter('Loss', ':.4e')
96 | acc1_meter = AverageMeter('Acc@1', ':6.2f')
97 | acc5_meter = AverageMeter('Acc@5', ':6.2f')
98 | progress = ProgressMeter(total=len(train_loader), step=20, prefix='Training',
99 | meters=[loss_meter, acc1_meter, acc5_meter])
100 |
101 | model.train()
102 |
103 | for i, samples in enumerate(train_loader):
104 | inputs, labels, _ = samples
105 | inputs = inputs.to(device)
106 | labels = labels.to(device)
107 |
108 | outputs = model(inputs)
109 | loss = criterion(outputs, labels)
110 | acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5))
111 |
112 | loss_meter.update(loss.item(), inputs.size(0))
113 | acc1_meter.update(acc1.item(), inputs.size(0))
114 | acc5_meter.update(acc5.item(), inputs.size(0))
115 |
116 |         optimizer.zero_grad()  # 1. clear accumulated gradients
117 |         loss.backward()  # 2. backpropagate
118 |         optimizer.step()  # 3. update parameters
119 |
120 | progress.display(i)
121 |
122 | return loss_meter, acc1_meter, acc5_meter
123 |
124 |
125 | def test(test_loader, model, criterion, device):
126 | loss_meter = AverageMeter('Loss', ':.4e')
127 | acc1_meter = AverageMeter('Acc@1', ':6.2f')
128 | acc5_meter = AverageMeter('Acc@5', ':6.2f')
129 | progress = ProgressMeter(total=len(test_loader), step=20, prefix='Test',
130 | meters=[loss_meter, acc1_meter, acc5_meter])
131 | model.eval()
132 |
133 | for i, samples in enumerate(test_loader):
134 | inputs, labels, _ = samples
135 | inputs = inputs.to(device)
136 | labels = labels.to(device)
137 |
138 | with torch.set_grad_enabled(False):
139 | outputs = model(inputs)
140 | loss = criterion(outputs, labels)
141 | acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5))
142 |
143 | loss_meter.update(loss.item(), inputs.size(0))
144 | acc1_meter.update(acc1.item(), inputs.size(0))
145 | acc5_meter.update(acc5.item(), inputs.size(0))
146 |
147 | progress.display(i)
148 |
149 | return loss_meter, acc1_meter, acc5_meter
150 |
151 |
152 | if __name__ == '__main__':
153 | main()
154 |
--------------------------------------------------------------------------------
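train() and test() above rely on metrics.accuracy returning top-k percentages. The real implementation lives in metrics/accuracy.py; the sketch below is a hedged re-implementation illustrating what it is assumed to compute:

    import torch

    def topk_accuracy(outputs, labels, topk=(1, 5)):
        # fraction of samples whose true label appears in the top-k logits, as a percentage
        maxk = max(topk)
        _, pred = outputs.topk(maxk, dim=1)      # (batch, maxk) predicted class indices
        correct = pred.eq(labels.view(-1, 1))    # (batch, maxk) boolean hit matrix
        return [correct[:, :k].any(dim=1).float().mean() * 100 for k in topk]

    outputs = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
    labels = torch.tensor([1, 1])
    print(topk_accuracy(outputs, labels, topk=(1, 2)))  # [tensor(50.), tensor(100.)]

--------------------------------------------------------------------------------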
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BasicBlock(nn.Module):
5 | """Basic Block for resnet 18 and resnet 34
6 | """
7 |
8 |     # BasicBlock and BottleNeck blocks
9 |     # have different output sizes,
10 |     # so we use the class attribute expansion
11 |     # to distinguish them
12 | expansion = 1
13 |
14 | def __init__(self, in_channels, out_channels, stride=1):
15 | super().__init__()
16 |
17 | # residual function
18 | self.residual_function = nn.Sequential(
19 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
20 | nn.BatchNorm2d(out_channels),
21 | nn.ReLU(inplace=True),
22 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
23 | nn.BatchNorm2d(out_channels * BasicBlock.expansion)
24 | )
25 |
26 | # shortcut
27 | self.shortcut = nn.Sequential()
28 |
29 |         # if the shortcut output dimension is not the same as the residual function's,
30 |         # use a 1*1 convolution to match the dimension
31 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
32 | self.shortcut = nn.Sequential(
33 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
34 | nn.BatchNorm2d(out_channels * BasicBlock.expansion)
35 | )
36 |
37 | def forward(self, x):
38 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
39 |
40 |
41 | class BottleNeck(nn.Module):
42 | """Residual block for resnet over 50 layers
43 | """
44 | expansion = 4
45 |
46 | def __init__(self, in_channels, out_channels, stride=1):
47 | super().__init__()
48 | self.residual_function = nn.Sequential(
49 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
50 | nn.BatchNorm2d(out_channels),
51 | nn.ReLU(inplace=True),
52 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
53 | nn.BatchNorm2d(out_channels),
54 | nn.ReLU(inplace=True),
55 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
56 | nn.BatchNorm2d(out_channels * BottleNeck.expansion),
57 | )
58 |
59 | self.shortcut = nn.Sequential()
60 |
61 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
62 | self.shortcut = nn.Sequential(
63 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
64 | nn.BatchNorm2d(out_channels * BottleNeck.expansion)
65 | )
66 |
67 | def forward(self, x):
68 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
69 |
70 |
71 | class ResNet(nn.Module):
72 |
73 | def __init__(self, block, num_block, in_channels=3, num_classes=10):
74 | super().__init__()
75 |
76 | self.in_channels = 64
77 |
78 | self.conv1 = nn.Sequential(
79 | nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False),
80 | nn.BatchNorm2d(64),
81 | nn.ReLU(inplace=True))
82 | # we use a different inputsize than the original paper
83 | # so conv2_x's stride is 1
84 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
85 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
86 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
87 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
88 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
89 | self.fc = nn.Linear(512 * block.expansion, num_classes)
90 |
91 | def _make_layer(self, block, out_channels, num_blocks, stride):
92 |         """make a resnet layer (here 'layer' does not mean a single neural
93 |         network layer such as a conv layer; one 'layer' may
94 |         contain more than one residual block)
95 | Args:
96 | block: block type, basic block or bottle neck block
97 | out_channels: output depth channel number of this layer
98 | num_blocks: how many blocks per layer
99 | stride: the stride of the first block of this layer
100 | Return:
101 | return a resnet layer
102 | """
103 |
104 |         # we have num_blocks blocks per layer; the first block's stride
105 |         # could be 1 or 2, the other blocks always use stride 1
106 | strides = [stride] + [1] * (num_blocks - 1)
107 | layers = []
108 | for stride in strides:
109 | layers.append(block(self.in_channels, out_channels, stride))
110 | self.in_channels = out_channels * block.expansion
111 |
112 | return nn.Sequential(*layers)
113 |
114 | def forward(self, x):
115 | output = self.conv1(x)
116 | output = self.conv2_x(output)
117 | output = self.conv3_x(output)
118 | output = self.conv4_x(output)
119 | output = self.conv5_x(output)
120 | output = self.avg_pool(output)
121 | output = output.view(output.size(0), -1)
122 | output = self.fc(output)
123 |
124 | return output
125 |
126 |
127 | def resnet18(in_channels=3, num_classes=10):
128 |     """ return a ResNet 18 object
129 |     """
130 |     return ResNet(BasicBlock, [2, 2, 2, 2], in_channels=in_channels, num_classes=num_classes)
131 |
132 |
133 | def resnet34(in_channels=3, num_classes=10):
134 | """ return a ResNet 34 object
135 | """
136 | return ResNet(BasicBlock, [3, 4, 6, 3], in_channels=in_channels, num_classes=num_classes)
137 |
138 |
139 | def resnet50(in_channels=3, num_classes=10):
140 | """ return a ResNet 50 object
141 | """
142 | return ResNet(BottleNeck, [3, 4, 6, 3], in_channels=in_channels, num_classes=num_classes)
143 |
144 |
145 | def resnet101(in_channels=3, num_classes=10):
146 |     """ return a ResNet 101 object
147 |     """
148 |     return ResNet(BottleNeck, [3, 4, 23, 3], in_channels=in_channels, num_classes=num_classes)
149 |
150 |
151 | def resnet152(in_channels=3, num_classes=10):
152 |     """ return a ResNet 152 object
153 |     """
154 |     return ResNet(BottleNeck, [3, 8, 36, 3], in_channels=in_channels, num_classes=num_classes)
155 |
--------------------------------------------------------------------------------
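The expansion attribute is what lets _make_layer track channel growth: a BasicBlock keeps out_channels while a BottleNeck widens it 4x, so resnet34 ends with 512 features and resnet50 with 2048. A quick sketch verifying this (the import path is an assumption):

    import torch
    from models.resnet import resnet34, resnet50  # assumed module path

    m34, m50 = resnet34(), resnet50()
    x = torch.randn(1, 3, 32, 32)
    print(m34(x).shape, m50(x).shape)              # both torch.Size([1, 10])
    print(m34.fc.in_features, m50.fc.in_features)  # 512 2048

--------------------------------------------------------------------------------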
/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | """mobilenet in pytorch
2 |
3 |
4 |
5 | [1] Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
6 |
7 | MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications
8 | https://arxiv.org/abs/1704.04861
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | class DepthSeperabelConv2d(nn.Module):
16 |
17 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
18 | super().__init__()
19 | self.depthwise = nn.Sequential(
20 | nn.Conv2d(
21 | input_channels,
22 | input_channels,
23 | kernel_size,
24 | groups=input_channels,
25 | **kwargs),
26 | nn.BatchNorm2d(input_channels),
27 | nn.ReLU(inplace=True)
28 | )
29 |
30 | self.pointwise = nn.Sequential(
31 | nn.Conv2d(input_channels, output_channels, 1),
32 | nn.BatchNorm2d(output_channels),
33 | nn.ReLU(inplace=True)
34 | )
35 |
36 | def forward(self, x):
37 | x = self.depthwise(x)
38 | x = self.pointwise(x)
39 |
40 | return x
41 |
42 |
43 | class BasicConv2d(nn.Module):
44 |
45 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
46 | super().__init__()
47 | self.conv = nn.Conv2d(
48 | input_channels, output_channels, kernel_size, **kwargs)
49 | self.bn = nn.BatchNorm2d(output_channels)
50 | self.relu = nn.ReLU(inplace=True)
51 |
52 | def forward(self, x):
53 | x = self.conv(x)
54 | x = self.bn(x)
55 | x = self.relu(x)
56 |
57 | return x
58 |
59 |
60 | class MobileNet(nn.Module):
61 | """
62 | Args:
63 |         width_multiplier: The role of the width multiplier α is to thin
64 | a network uniformly at each layer. For a given
65 | layer and width multiplier α, the number of
66 | input channels M becomes αM and the number of
67 | output channels N becomes αN.
68 | """
69 |
70 | def __init__(self, width_multiplier=1, class_num=100):
71 | super().__init__()
72 |
73 | alpha = width_multiplier
74 | self.stem = nn.Sequential(
75 | BasicConv2d(3, int(32 * alpha), 3, padding=1, bias=False),
76 | DepthSeperabelConv2d(
77 | int(32 * alpha),
78 | int(64 * alpha),
79 | 3,
80 | padding=1,
81 | bias=False
82 | )
83 | )
84 |
85 | # downsample
86 | self.conv1 = nn.Sequential(
87 | DepthSeperabelConv2d(
88 | int(64 * alpha),
89 | int(128 * alpha),
90 | 3,
91 | stride=2,
92 | padding=1,
93 | bias=False
94 | ),
95 | DepthSeperabelConv2d(
96 | int(128 * alpha),
97 | int(128 * alpha),
98 | 3,
99 | padding=1,
100 | bias=False
101 | )
102 | )
103 |
104 | # downsample
105 | self.conv2 = nn.Sequential(
106 | DepthSeperabelConv2d(
107 | int(128 * alpha),
108 | int(256 * alpha),
109 | 3,
110 | stride=2,
111 | padding=1,
112 | bias=False
113 | ),
114 | DepthSeperabelConv2d(
115 | int(256 * alpha),
116 | int(256 * alpha),
117 | 3,
118 | padding=1,
119 | bias=False
120 | )
121 | )
122 |
123 | # downsample
124 | self.conv3 = nn.Sequential(
125 | DepthSeperabelConv2d(
126 | int(256 * alpha),
127 | int(512 * alpha),
128 | 3,
129 | stride=2,
130 | padding=1,
131 | bias=False
132 | ),
133 |
134 | DepthSeperabelConv2d(
135 | int(512 * alpha),
136 | int(512 * alpha),
137 | 3,
138 | padding=1,
139 | bias=False
140 | ),
141 | DepthSeperabelConv2d(
142 | int(512 * alpha),
143 | int(512 * alpha),
144 | 3,
145 | padding=1,
146 | bias=False
147 | ),
148 | DepthSeperabelConv2d(
149 | int(512 * alpha),
150 | int(512 * alpha),
151 | 3,
152 | padding=1,
153 | bias=False
154 | ),
155 | DepthSeperabelConv2d(
156 | int(512 * alpha),
157 | int(512 * alpha),
158 | 3,
159 | padding=1,
160 | bias=False
161 | ),
162 | DepthSeperabelConv2d(
163 | int(512 * alpha),
164 | int(512 * alpha),
165 | 3,
166 | padding=1,
167 | bias=False
168 | )
169 | )
170 |
171 | # downsample
172 | self.conv4 = nn.Sequential(
173 | DepthSeperabelConv2d(
174 | int(512 * alpha),
175 | int(1024 * alpha),
176 | 3,
177 | stride=2,
178 | padding=1,
179 | bias=False
180 | ),
181 | DepthSeperabelConv2d(
182 | int(1024 * alpha),
183 | int(1024 * alpha),
184 | 3,
185 | padding=1,
186 | bias=False
187 | )
188 | )
189 |
190 | self.fc = nn.Linear(int(1024 * alpha), class_num)
191 | self.avg = nn.AdaptiveAvgPool2d(1)
192 |
193 | def forward(self, x):
194 | x = self.stem(x)
195 |
196 | x = self.conv1(x)
197 | x = self.conv2(x)
198 | x = self.conv3(x)
199 | x = self.conv4(x)
200 |
201 | x = self.avg(x)
202 | x = x.view(x.size(0), -1)
203 | x = self.fc(x)
204 | return x
205 |
206 |
207 | def mobilenet(alpha=1, class_num=100):
208 | return MobileNet(alpha, class_num)
--------------------------------------------------------------------------------
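The point of the depthwise-separable block is its parameter cost: a depthwise 3x3 plus a pointwise 1x1 needs C_in*9 + C_in*C_out weights instead of C_in*C_out*9 for a full 3x3 convolution. A self-contained count for one 512->512 stage (bias-free, ignoring BatchNorm parameters):

    import torch.nn as nn

    full = nn.Conv2d(512, 512, 3, padding=1, bias=False)
    depthwise = nn.Conv2d(512, 512, 3, padding=1, groups=512, bias=False)
    pointwise = nn.Conv2d(512, 512, 1, bias=False)

    print(full.weight.numel())                                  # 2359296 = 512*512*9
    print(depthwise.weight.numel() + pointwise.weight.numel())  # 266752 = 512*9 + 512*512

--------------------------------------------------------------------------------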
/core/relevant_feature_identifying.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | from tqdm import tqdm
5 | import torch
6 | import torch.nn as nn
7 |
8 | import models
9 | import loaders
10 |
11 |
12 | def partial_conv(conv: nn.Conv2d, inp: torch.Tensor, o_h=None, o_w=None):
13 | kernel_size = conv.kernel_size
14 | dilation = conv.dilation
15 | padding = conv.padding
16 | stride = conv.stride
17 | weight = conv.weight.to(inp.device) # O I K K
18 | # bias = conv.bias.to(inp.device) # O
19 |
20 | wei_res = weight.view(weight.size(0), weight.size(1), -1).permute((1, 2, 0)) # I K*K O
21 |     inp_unf = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)(inp)  # B I*K*K N
22 | inp_unf = inp_unf.view(inp.size(0), inp.size(1), wei_res.size(1), o_h, o_w) # B I K*K H_O W_O
23 | out = torch.einsum('ijkmn,jkl->iljmn', inp_unf, wei_res) # B O I H W
24 |
25 | # out = out.sum(2)
26 | # bias = bias.unsqueeze(1).unsqueeze(2).expand((out.size(1), out.size(2), out.size(3))) # O H W
27 | # out = out + bias
28 |
29 | return out
30 |
31 |
32 | def partial_linear(linear: nn.Linear, inp: torch.Tensor):
33 | weight = linear.weight.to(inp.device) # (o, i)
34 | # bias = linear.bias.to(inp.device) # (o)
35 |
36 | out = torch.einsum('bi,oi->boi', inp, weight) # (b, o, i)
37 |
38 | # out = torch.sum(out, dim=-1)
39 | # out = out + bias
40 |
41 | return out
42 |
43 |
44 | def mm_norm(a, dim=-1, zero=False):
45 | if zero:
46 |         a_min = torch.zeros_like(a)
47 | else:
48 | a_min, _ = torch.min(a, dim=dim, keepdim=True)
49 | a_max, _ = torch.max(a, dim=dim, keepdim=True)
50 | a_normalized = (a - a_min) / (a_max - a_min + 1e-5)
51 |
52 | return a_normalized
53 |
54 |
55 | class HookModule:
56 | def __init__(self, module):
57 | self.module = module
58 | self.inputs = None
59 | self.outputs = None
60 | module.register_forward_hook(self._hook)
61 |
62 | def _hook(self, module, inputs, outputs):
63 | self.inputs = inputs[0]
64 | self.outputs = outputs
65 |
66 |
67 | class RelevantFeatureIdentifying:
68 | def __init__(self, modules, num_classes, save_dir):
69 | self.modules = [HookModule(module) for module in modules]
70 | # self.values = [[[] for _ in range(num_classes)] for _ in range(len(modules))] # [l, c, n, channels]
71 | self.values = [[0 for _ in range(num_classes)] for _ in range(len(modules))] # [l, c, channels]
72 | self.num_classes = num_classes
73 | self.save_dir = save_dir
74 |
75 | def __call__(self, outputs, labels):
76 | for layer, module in enumerate(self.modules):
77 | torch.cuda.empty_cache()
78 | # print(layer, '==>', layer)
79 | values = None
80 | if isinstance(module.module, nn.Conv2d):
81 | # [b, o, i, h, w]
82 | values = partial_conv(module.module,
83 | module.inputs,
84 | module.outputs.size(2),
85 | module.outputs.size(3))
86 | values = torch.sum(values, dim=(3, 4))
87 | elif isinstance(module.module, nn.Linear):
88 |                 # (b, o, i)
89 | values = partial_linear(module.module,
90 | module.inputs)
91 | values = torch.relu(values)
92 | values = values.cpu()
93 | values = values.numpy()
94 |
95 | for b in range(len(labels)):
96 | # self.values[layer][labels[b]].append(values[b]) # (l, c, n, o, i)
97 | self.values[layer][labels[b]] += values[b] # (l, c, o, i)
98 |
99 | def identify(self):
100 | # parameter configuration
101 | alpha_c = 0.3
102 | beta_c = 0.2
103 | alpha_f = 0.4
104 | beta_f = 0.3
105 |
106 | # layer -1
107 | mask = torch.eye(self.num_classes, dtype=torch.long) # (c, o)
108 | mask_path = os.path.join(self.save_dir, 'masks', 'mask_layer{}.pt'.format('-1'))
109 | torch.save(mask, mask_path)
110 |
111 | # layer 0~n
112 |         for layer, values in enumerate(self.values):  # (l, c, o, i)
113 |             values = torch.from_numpy(np.asarray(self.values[layer]))  # (c, o, i)
114 | # values = torch.sum(values, axis=1) # (c, o, i)
115 | print('-' * 20)
116 | print(mask.shape)
117 | print(values.shape)
118 | print('-' * 20)
119 |
120 | if values.shape[1] != mask.shape[1]:
121 | mask = torch.ones((values.shape[0], values.shape[1]), dtype=torch.long)
122 |
123 | values = mm_norm(values) # (c, o, i)
124 | if isinstance(self.modules[layer].module, nn.Conv2d):
125 | values = torch.where(values > alpha_c, 1, 0) # (c, o, i)
126 | else:
127 | values = torch.where(values > alpha_f, 1, 0) # (c, o, i)
128 | values = torch.einsum('co,coi->ci', mask, values) # (c, i)
129 | # values = torch.sum(values, dim=1) # (c, i)
130 | values = mm_norm(values) # (c, i)
131 | if isinstance(self.modules[layer].module, nn.Conv2d):
132 | mask = torch.where(values > beta_c, 1, 0) # (c, i)
133 | else:
134 | mask = torch.where(values > beta_f, 1, 0) # (c, i)
135 |
136 | mask_path = os.path.join(self.save_dir, 'masks', 'mask_layer{}.pt'.format(layer))
137 | torch.save(mask, mask_path)
138 |
139 |
140 | def main():
141 | parser = argparse.ArgumentParser(description='')
142 | parser.add_argument('--model_name', default='', type=str, help='model name')
143 | parser.add_argument('--data_name', default='', type=str, help='data name')
144 | parser.add_argument('--num_classes', default='', type=int, help='num classes')
145 | parser.add_argument('--model_path', default='', type=str, help='model path')
146 | parser.add_argument('--data_dir', default='', type=str, help='data path')
147 | parser.add_argument('--save_dir', default='', type=str, help='save dir')
148 | args = parser.parse_args()
149 |
150 | # ----------------------------------------
151 | # basic configuration
152 | # ----------------------------------------
153 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
154 |
155 |     # create output directories; exist_ok avoids failing when save_dir already exists
156 |     os.makedirs(os.path.join(args.save_dir, 'masks'), exist_ok=True)
157 |     os.makedirs(os.path.join(args.save_dir, 'figs'), exist_ok=True)
158 |
159 | print('-' * 50)
160 |     print('RUN ON:', device)
161 | print('DATA DIR:', args.data_dir)
162 | print('SAVE DIR:', args.save_dir)
163 | print('-' * 50)
164 |
165 | # ----------------------------------------
166 | # model/data configuration
167 | # ----------------------------------------
168 | model = models.load_model(model_name=args.model_name, num_classes=args.num_classes)
169 | model.load_state_dict(torch.load(args.model_path))
170 | # model = torch.load(args.model_path)
171 | model.to(device)
172 | model.eval()
173 |
174 | data_loader = loaders.load_data(args.data_dir, args.data_name, data_type='test')
175 |
176 | modules = models.load_modules(model=model)
177 |
178 | rfi = RelevantFeatureIdentifying(modules=modules, num_classes=args.num_classes, save_dir=args.save_dir)
179 |
180 | # ----------------------------------------
181 | # forward
182 | # ----------------------------------------
183 | for i, samples in enumerate(tqdm(data_loader)):
184 | inputs, labels, _ = samples
185 | inputs = inputs.to(device)
186 | labels = labels.to(device)
187 | with torch.no_grad():
188 | outputs = model(inputs)
189 | rfi(outputs, labels)
190 |
191 | rfi.identify()
192 |
193 |
194 | if __name__ == '__main__':
195 | main()
196 |
--------------------------------------------------------------------------------
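partial_conv keeps the per-input-channel contributions of a convolution separate; summing them over the input-channel dimension must reproduce the ordinary bias-free convolution. A sanity-check sketch (assuming the repo root is on PYTHONPATH):

    import torch
    import torch.nn as nn
    from core.relevant_feature_identifying import partial_conv

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    x = torch.randn(2, 3, 5, 5)

    contrib = partial_conv(conv, x, o_h=5, o_w=5)  # (B, O, I, H, W) contributions
    print(contrib.shape)                           # torch.Size([2, 8, 3, 5, 5])
    assert torch.allclose(contrib.sum(dim=2), conv(x), atol=1e-5)

--------------------------------------------------------------------------------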
/core/model_decision_route_visualizing.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tkinter import *
3 | import numpy as np
4 | import os
5 |
6 | import torch
7 |
8 | FIG_W = 1536
9 | FIG_H = 1024
10 |
11 | CONV_W = 4
12 | CONV_H = 4
13 | LINEAR_W = 2
14 | LINEAR_H = 2
15 |
16 | INTERVAL_CONV_X = 200
17 | INTERVAL_CONV_Y = 7
18 | INTERVAL_LINEAR_X = 280
19 | INTERVAL_LINEAR_Y = 4.5
20 |
21 | PADDING_X = 10
22 | PADDING_Y = 400 # middle line
23 |
24 | LINE_WIDTH = 1
25 |
26 | # COLOR_PUBLIC = 'orange'
27 | # COLOR_NO_USE = 'gray'
28 | # COLORS = ['purple', 'red']
29 | # COLOR_PUBLIC = '#feb888'
30 | # COLOR_NO_USE = '#c8c8c8'
31 | # COLORS = ['#b0d994', '#a3cbef', ]
32 | COLOR_PUBLIC = '#F8AC8C'
33 | COLOR_NO_USE = '#c8c8c8'
34 | COLORS = ['#C82423', '#2878B5', ]
35 |
36 |
37 | # COLORS = ['#2878B5', '#C82423', ]
38 |
39 |
40 | def draw_route(masks, layers):
41 | root = Tk()
42 | cv = Canvas(root, background='white', width=FIG_W, height=FIG_H)
43 | cv.pack(fill=BOTH, expand=YES)
44 |
45 | # ---------------------------
46 | # each layer
47 | # ---------------------------
48 | masks = np.asarray(masks) # layers, labels, channels
49 | print(masks.shape)
50 |
51 | x = PADDING_X
52 | line_start_p_preceding = [(PADDING_X, PADDING_Y)] # public
53 | line_start_preceding = [[(PADDING_X, PADDING_Y)] for _ in range(masks.shape[1])] # [labels * [init]]
54 |
55 | for layer in range(masks.shape[0]):
56 |
57 | line_end_p = [] # public
58 | line_start_p = [] # public
59 | line_end = [[] for _ in range(masks.shape[1])] # [labels * []] each class
60 | line_start = [[] for _ in range(masks.shape[1])]
61 |
62 | line_p_num = 0
63 | line_num = 0
64 |
65 | # ---------------------------
66 | # each channel
67 | # ---------------------------
68 | layer_masks = np.asarray(list(masks[layer])) # labels, channels
69 |
70 |         # initial position
71 | if layers[layer] == 'conv':
72 | x += CONV_W + INTERVAL_CONV_X
73 | y = PADDING_Y - (layer_masks.shape[1] / 2) * (CONV_H + INTERVAL_CONV_Y) + INTERVAL_CONV_Y / 2
74 | else:
75 | x += LINEAR_W + INTERVAL_LINEAR_X
76 | y = PADDING_Y - (layer_masks.shape[1] / 2) * (LINEAR_H + INTERVAL_LINEAR_Y) + INTERVAL_LINEAR_Y / 2
77 |
78 | # draw conv/linear
79 | for channel in range(layer_masks.shape[1]):
80 | if layer_masks[:, channel].sum() > 1:
81 | if layers[layer] == 'conv':
82 | line_end_p.append(((x), (y + CONV_H / 2)))
83 | line_start_p.append(((x + CONV_W), (y + CONV_H / 2)))
84 | cv.create_rectangle(x, y, x + CONV_W, y + CONV_H,
85 | outline=COLOR_PUBLIC,
86 | fill=COLOR_PUBLIC,
87 | width=LINE_WIDTH)
88 | else:
89 | line_end_p.append(((x), (y + LINEAR_H / 2)))
90 | line_start_p.append(((x + LINEAR_W), (y + LINEAR_H / 2)))
91 | cv.create_oval(x, y, x + LINEAR_W, y + LINEAR_H,
92 | outline=COLOR_PUBLIC,
93 | fill=COLOR_PUBLIC,
94 | width=LINE_WIDTH)
95 | elif layer_masks[:, channel].sum() < 1:
96 | if layers[layer] == 'conv':
97 | cv.create_rectangle(x, y, x + CONV_W, y + CONV_H,
98 | outline=COLOR_NO_USE,
99 | fill=COLOR_NO_USE,
100 | width=LINE_WIDTH)
101 | else:
102 | cv.create_oval(x, y, x + LINEAR_W, y + LINEAR_H,
103 | outline=COLOR_NO_USE,
104 | fill=COLOR_NO_USE,
105 | width=LINE_WIDTH)
106 | else:
107 | # ---------------------------
108 | # each label
109 | # ---------------------------
110 | for l, mask in enumerate(layer_masks[:, channel]):
111 | if mask:
112 | if layers[layer] == 'conv':
113 | line_end[l].append(((x), (y + CONV_H / 2)))
114 | line_start[l].append(((x + CONV_W), (y + CONV_H / 2)))
115 | cv.create_rectangle(x, y, x + CONV_W, y + CONV_H,
116 | outline=COLORS[l],
117 | fill=COLORS[l],
118 | width=LINE_WIDTH)
119 | else:
120 | line_end[l].append(((x), (y + LINEAR_H / 2)))
121 | line_start[l].append(((x + LINEAR_W), (y + LINEAR_H / 2)))
122 | cv.create_oval(x, y, x + LINEAR_W, y + LINEAR_H,
123 | outline=COLORS[l],
124 | fill=COLORS[l],
125 | width=LINE_WIDTH)
126 |
127 |             # next y start position
128 | if layers[layer] == 'conv':
129 | y += CONV_H + INTERVAL_CONV_Y
130 | else:
131 | y += LINEAR_H + INTERVAL_LINEAR_Y
132 |
133 | # draw line
134 | for l in range(layer_masks.shape[0]):
135 | # line_num += (len(line_start_preceding[l]) * len(line_end[l])) # each to each
136 | # line_p_num += (len(line_start_preceding[l]) * len(line_end_p)) # each to public
137 | # line_p_num += (len(line_start_p_preceding) * len(line_end[l])) # public to each
138 | line_num += len(line_start[l]) # each
139 | for x0, y0 in line_start_preceding[l]:
140 | # each to each
141 | for x1, y1 in line_end[l]:
142 | cv.create_line(x0, y0, x1, y1,
143 | width=LINE_WIDTH,
144 | fill=COLORS[l],
145 | # arrow=LAST,
146 | arrowshape=(6, 5, 1))
147 |
148 | # each to public
149 | for x1, y1 in line_end_p:
150 | cv.create_line(x0, y0, x1, y1,
151 | width=LINE_WIDTH,
152 | fill=COLORS[l],
153 | # arrow=LAST,
154 | arrowshape=(6, 5, 1))
155 |
156 | # public to each
157 | for x0, y0 in line_start_p_preceding:
158 | for x1, y1 in line_end[l]:
159 | cv.create_line(x0, y0, x1, y1,
160 | width=LINE_WIDTH,
161 | fill=COLORS[l],
162 | # arrow=LAST,
163 | arrowshape=(6, 5, 1))
164 |
165 | # line_p_num += (len(line_start_p_preceding) * len(line_end_p)) # public to public
166 | line_p_num += len(line_start_p) # public
167 | # public to public
168 | for x0, y0 in line_start_p_preceding:
169 | for x1, y1 in line_end_p:
170 | cv.create_line(x0, y0, x1, y1,
171 | width=LINE_WIDTH + 1,
172 | fill=COLOR_PUBLIC,
173 | # arrow=LAST,
174 | arrowshape=(6, 5, 1))
175 |
176 | line_start_preceding = line_start.copy()
177 | line_start_p_preceding = line_start_p.copy()
178 |
179 |         # report line statistics for this layer
180 |         print('--->', layer,
181 |               '| line--->', line_num,
182 |               '| line_p--->', line_p_num,
183 |               '| public ratio--->', line_p_num / (line_num + line_p_num))
184 |
185 | root.mainloop()
186 |
187 |
188 | def main():
189 | parser = argparse.ArgumentParser(description='')
190 | parser.add_argument('--mask_dir', default='', type=str, help='mask dir')
191 | parser.add_argument('--layers', default='', nargs='+', type=int, help='layers')
192 | parser.add_argument('--labels', default='', nargs='+', type=int, help='labels')
193 | # parser.add_argument('--save_dir', default='', type=str, help='save dir')
194 | args = parser.parse_args()
195 |
196 | mask_path = os.path.join(args.mask_dir, 'mask_layer{}.pt')
197 |
198 | if args.layers[0] == -1:
199 | args.layers = [4, 3, 2, 1, 0] # Please set manually
200 |
201 | layers_name = ['conv' for _ in range(2)] + ['linear' for _ in range(3)] # Please set manually
202 |
203 | for label in args.labels:
204 | masks = []
205 | for layer in args.layers:
206 | mask_o = torch.load(mask_path.format(layer - 1))[label].numpy()
207 | masks.append([mask_o])
208 |
209 | print(masks)
210 | print(np.asarray(masks).shape)
211 | print(layers_name)
212 | draw_route(masks, layers_name)
213 |
214 |
215 | if __name__ == '__main__':
216 | main()
217 |
--------------------------------------------------------------------------------
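draw_route expects masks shaped (layers, labels, channels), with 1 marking a channel on the decision route, plus a matching list of 'conv'/'linear' layer tags. A toy invocation (a sketch: it opens a Tk window, and the import path is an assumption):

    from core.model_decision_route_visualizing import draw_route

    toy_masks = [
        [[1, 0, 1, 0]],  # one label; conv layer with channels 0 and 2 on the route
        [[0, 1, 0, 1]],  # linear layer with channels 1 and 3 on the route
    ]
    draw_route(toy_masks, ['conv', 'linear'])

--------------------------------------------------------------------------------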
/models/inceptionv3.py:
--------------------------------------------------------------------------------
1 | """ inceptionv3 in pytorch
2 |
3 |
4 | [1] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna
5 |
6 | Rethinking the Inception Architecture for Computer Vision
7 | https://arxiv.org/abs/1512.00567v3
8 | """
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 |
14 | class BasicConv2d(nn.Module):
15 |
16 | def __init__(self, input_channels, output_channels, **kwargs):
17 | super().__init__()
18 | self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs)
19 | self.bn = nn.BatchNorm2d(output_channels)
20 | self.relu = nn.ReLU(inplace=True)
21 |
22 | def forward(self, x):
23 | x = self.conv(x)
24 | x = self.bn(x)
25 | x = self.relu(x)
26 |
27 | return x
28 |
29 |
30 | # same naive inception module
31 | class InceptionA(nn.Module):
32 |
33 | def __init__(self, input_channels, pool_features):
34 | super().__init__()
35 | self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1)
36 |
37 | self.branch5x5 = nn.Sequential(
38 | BasicConv2d(input_channels, 48, kernel_size=1),
39 | BasicConv2d(48, 64, kernel_size=5, padding=2)
40 | )
41 |
42 | self.branch3x3 = nn.Sequential(
43 | BasicConv2d(input_channels, 64, kernel_size=1),
44 | BasicConv2d(64, 96, kernel_size=3, padding=1),
45 | BasicConv2d(96, 96, kernel_size=3, padding=1)
46 | )
47 |
48 | self.branchpool = nn.Sequential(
49 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
50 | BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1)
51 | )
52 |
53 | def forward(self, x):
54 | # x -> 1x1(same)
55 | branch1x1 = self.branch1x1(x)
56 |
57 | # x -> 1x1 -> 5x5(same)
58 | branch5x5 = self.branch5x5(x)
59 | # branch5x5 = self.branch5x5_2(branch5x5)
60 |
61 | # x -> 1x1 -> 3x3 -> 3x3(same)
62 | branch3x3 = self.branch3x3(x)
63 |
64 | # x -> pool -> 1x1(same)
65 | branchpool = self.branchpool(x)
66 |
67 | outputs = [branch1x1, branch5x5, branch3x3, branchpool]
68 |
69 | return torch.cat(outputs, 1)
70 |
71 |
72 | # downsample
73 | # Factorization into smaller convolutions
74 | class InceptionB(nn.Module):
75 |
76 | def __init__(self, input_channels):
77 | super().__init__()
78 |
79 | self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2)
80 |
81 | self.branch3x3stack = nn.Sequential(
82 | BasicConv2d(input_channels, 64, kernel_size=1),
83 | BasicConv2d(64, 96, kernel_size=3, padding=1),
84 | BasicConv2d(96, 96, kernel_size=3, stride=2)
85 | )
86 |
87 | self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)
88 |
89 | def forward(self, x):
90 | # x - > 3x3(downsample)
91 | branch3x3 = self.branch3x3(x)
92 |
93 | # x -> 3x3 -> 3x3(downsample)
94 | branch3x3stack = self.branch3x3stack(x)
95 |
96 | # x -> avgpool(downsample)
97 | branchpool = self.branchpool(x)
98 |
99 |         # """We can use two parallel stride 2 blocks: P and C. P is a pooling
100 |         # layer (either average or maximum pooling) of the activation; both of
101 |         # them are stride 2, and their filter banks are concatenated as in
102 |         # figure 10."""
103 | outputs = [branch3x3, branch3x3stack, branchpool]
104 |
105 | return torch.cat(outputs, 1)
106 |
107 |
108 | # Factorizing Convolutions with Large Filter Size
109 | class InceptionC(nn.Module):
110 | def __init__(self, input_channels, channels_7x7):
111 | super().__init__()
112 | self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1)
113 |
114 | c7 = channels_7x7
115 |
116 | # In theory, we could go even further and argue that one can replace any n × n
117 | # convolution by a 1 × n convolution followed by a n × 1 convolution and the
118 | # computational cost saving increases dramatically as n grows (see figure 6).
119 | self.branch7x7 = nn.Sequential(
120 | BasicConv2d(input_channels, c7, kernel_size=1),
121 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
122 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
123 | )
124 |
125 | self.branch7x7stack = nn.Sequential(
126 | BasicConv2d(input_channels, c7, kernel_size=1),
127 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
128 | BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)),
129 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
130 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
131 | )
132 |
133 | self.branch_pool = nn.Sequential(
134 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
135 | BasicConv2d(input_channels, 192, kernel_size=1),
136 | )
137 |
138 | def forward(self, x):
139 | # x -> 1x1(same)
140 | branch1x1 = self.branch1x1(x)
141 |
142 | # x -> 1layer 1*7 and 7*1 (same)
143 | branch7x7 = self.branch7x7(x)
144 |
145 | # x-> 2layer 1*7 and 7*1(same)
146 | branch7x7stack = self.branch7x7stack(x)
147 |
148 | # x-> avgpool (same)
149 | branchpool = self.branch_pool(x)
150 |
151 | outputs = [branch1x1, branch7x7, branch7x7stack, branchpool]
152 |
153 | return torch.cat(outputs, 1)
154 |
155 |
156 | class InceptionD(nn.Module):
157 |
158 | def __init__(self, input_channels):
159 | super().__init__()
160 |
161 | self.branch3x3 = nn.Sequential(
162 | BasicConv2d(input_channels, 192, kernel_size=1),
163 | BasicConv2d(192, 320, kernel_size=3, stride=2)
164 | )
165 |
166 | self.branch7x7 = nn.Sequential(
167 | BasicConv2d(input_channels, 192, kernel_size=1),
168 | BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
169 | BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
170 | BasicConv2d(192, 192, kernel_size=3, stride=2)
171 | )
172 |
173 | self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2)
174 |
175 | def forward(self, x):
176 | # x -> 1x1 -> 3x3(downsample)
177 | branch3x3 = self.branch3x3(x)
178 |
179 | # x -> 1x1 -> 1x7 -> 7x1 -> 3x3 (downsample)
180 | branch7x7 = self.branch7x7(x)
181 |
182 | # x -> avgpool (downsample)
183 | branchpool = self.branchpool(x)
184 |
185 | outputs = [branch3x3, branch7x7, branchpool]
186 |
187 | return torch.cat(outputs, 1)
188 |
189 |
190 | # same
191 | class InceptionE(nn.Module):
192 | def __init__(self, input_channels):
193 | super().__init__()
194 | self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1)
195 |
196 | self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1)
197 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
198 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
199 |
200 | self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1)
201 | self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
202 | self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
203 | self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
204 |
205 | self.branch_pool = nn.Sequential(
206 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
207 | BasicConv2d(input_channels, 192, kernel_size=1)
208 | )
209 |
210 | def forward(self, x):
211 | # x -> 1x1 (same)
212 | branch1x1 = self.branch1x1(x)
213 |
214 | # x -> 1x1 -> 3x1
215 | # x -> 1x1 -> 1x3
216 | # concatenate(3x1, 1x3)
217 | # """7. Inception modules with expanded the filter bank outputs.
218 | # This architecture is used on the coarsest (8 × 8) grids to promote
219 | # high dimensional representations, as suggested by principle
220 | # 2 of Section 2."""
221 | branch3x3 = self.branch3x3_1(x)
222 | branch3x3 = [
223 | self.branch3x3_2a(branch3x3),
224 | self.branch3x3_2b(branch3x3)
225 | ]
226 | branch3x3 = torch.cat(branch3x3, 1)
227 |
228 | # x -> 1x1 -> 3x3 -> 1x3
229 | # x -> 1x1 -> 3x3 -> 3x1
230 | # concatenate(1x3, 3x1)
231 | branch3x3stack = self.branch3x3stack_1(x)
232 | branch3x3stack = self.branch3x3stack_2(branch3x3stack)
233 | branch3x3stack = [
234 | self.branch3x3stack_3a(branch3x3stack),
235 | self.branch3x3stack_3b(branch3x3stack)
236 | ]
237 | branch3x3stack = torch.cat(branch3x3stack, 1)
238 |
239 | branchpool = self.branch_pool(x)
240 |
241 | outputs = [branch1x1, branch3x3, branch3x3stack, branchpool]
242 |
243 | return torch.cat(outputs, 1)
244 |
245 |
246 | class InceptionV3(nn.Module):
247 |
248 | def __init__(self, in_channels=3, num_classes=10):
249 | super().__init__()
250 | self.Conv2d_1a_3x3 = BasicConv2d(in_channels, 32, kernel_size=3, padding=1)
251 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
252 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
253 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
254 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
255 |
256 | # naive inception module
257 | self.Mixed_5b = InceptionA(192, pool_features=32)
258 | self.Mixed_5c = InceptionA(256, pool_features=64)
259 | self.Mixed_5d = InceptionA(288, pool_features=64)
260 |
261 | # downsample
262 | self.Mixed_6a = InceptionB(288)
263 |
264 | self.Mixed_6b = InceptionC(768, channels_7x7=128)
265 | self.Mixed_6c = InceptionC(768, channels_7x7=160)
266 | self.Mixed_6d = InceptionC(768, channels_7x7=160)
267 | self.Mixed_6e = InceptionC(768, channels_7x7=192)
268 |
269 | # downsample
270 | self.Mixed_7a = InceptionD(768)
271 |
272 | self.Mixed_7b = InceptionE(1280)
273 | self.Mixed_7c = InceptionE(2048)
274 |
275 | # 6*6 feature size
276 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
277 | self.dropout = nn.Dropout2d()
278 | self.linear = nn.Linear(2048, num_classes)
279 |
280 | def forward(self, x):
281 | # 32 -> 30
282 | x = self.Conv2d_1a_3x3(x)
283 | x = self.Conv2d_2a_3x3(x)
284 | x = self.Conv2d_2b_3x3(x)
285 | x = self.Conv2d_3b_1x1(x)
286 | x = self.Conv2d_4a_3x3(x)
287 |
288 | # 30 -> 30
289 | x = self.Mixed_5b(x)
290 | x = self.Mixed_5c(x)
291 | x = self.Mixed_5d(x)
292 |
293 | # 30 -> 14
294 | # Efficient Grid Size Reduction to avoid representation
295 | # bottleneck
296 | x = self.Mixed_6a(x)
297 |
298 | # 14 -> 14
299 | # """In practice, we have found that employing this factorization does not
300 | # work well on early layers, but it gives very good results on medium
301 | # grid-sizes (On m × m feature maps, where m ranges between 12 and 20).
302 | # On that level, very good results can be achieved by using 1 × 7 convolutions
303 | # followed by 7 × 1 convolutions."""
304 | x = self.Mixed_6b(x)
305 | x = self.Mixed_6c(x)
306 | x = self.Mixed_6d(x)
307 | x = self.Mixed_6e(x)
308 |
309 | # 14 -> 6
310 | # Efficient Grid Size Reduction
311 | x = self.Mixed_7a(x)
312 |
313 | # 6 -> 6
314 |         # """We are using this solution only on the coarsest grid,
315 | # since that is the place where producing high dimensional
316 | # sparse representation is the most critical as the ratio of
317 | # local processing (by 1 × 1 convolutions) is increased compared
318 | # to the spatial aggregation."""
319 | x = self.Mixed_7b(x)
320 | x = self.Mixed_7c(x)
321 |
322 | # 6 -> 1
323 | x = self.avgpool(x)
324 | x = self.dropout(x)
325 | x = x.view(x.size(0), -1)
326 | x = self.linear(x)
327 | return x
328 |
329 |
330 | def inceptionv3(in_channels=3, num_classes=10):
331 | return InceptionV3(in_channels=in_channels, num_classes=num_classes)
--------------------------------------------------------------------------------
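InceptionC's asymmetric factorization trades a 7x7 kernel for a 7x1 followed by a 1x7, covering the same receptive field with 2*7*C^2 instead of 49*C^2 weights. A self-contained count for c7 = 128 (bias-free, as in BasicConv2d):

    import torch.nn as nn

    c7 = 128
    full = nn.Conv2d(c7, c7, kernel_size=7, padding=3, bias=False)
    factored = nn.Sequential(
        nn.Conv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0), bias=False),
        nn.Conv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3), bias=False),
    )
    print(full.weight.numel())                      # 802816 = 128*128*49
    print(sum(m.weight.numel() for m in factored))  # 229376 = 2*128*128*7

--------------------------------------------------------------------------------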