├── inclearn
├── __init__.py
├── convnet
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── utils.cpython-38.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── network.cpython-38.pyc
│ │ ├── resnet.cpython-38.pyc
│ │ ├── classifier.cpython-38.pyc
│ │ ├── imbalance.cpython-38.pyc
│ │ ├── cifar_resnet.cpython-38.pyc
│ │ ├── preact_resnet.cpython-38.pyc
│ │ └── modified_resnet_cifar.cpython-38.pyc
│ ├── classifier.py
│ ├── imbalance.py
│ ├── modified_resnet_cifar.py
│ ├── preact_resnet.py
│ ├── cifar_resnet.py
│ ├── utils.py
│ ├── network.py
│ └── resnet.py
├── learn
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ └── pretrain.cpython-38.pyc
│ └── pretrain.py
├── loss
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── focalloss.cpython-38.pyc
│ │ └── tripeloss.cpython-38.pyc
│ ├── focalloss.py
│ └── tripeloss.py
├── tools
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── utils.cpython-38.pyc
│ │ ├── cutout.cpython-38.pyc
│ │ ├── factory.cpython-38.pyc
│ │ ├── memory.cpython-38.pyc
│ │ ├── metrics.cpython-38.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── data_utils.cpython-38.pyc
│ │ ├── scheduler.cpython-38.pyc
│ │ ├── results_utils.cpython-38.pyc
│ │ └── autoaugment_extra.cpython-38.pyc
│ ├── data_utils.py
│ ├── cutout.py
│ ├── results_utils.py
│ ├── factory.py
│ ├── scheduler.py
│ ├── memory.py
│ ├── metrics.py
│ └── utils.py
├── datasets
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── data.cpython-38.pyc
│ │ ├── dataset.cpython-38.pyc
│ │ └── __init__.cpython-38.pyc
│ └── dataset.py
├── models
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── base.cpython-38.pyc
│ │ ├── align.cpython-38.pyc
│ │ ├── prune.cpython-38.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ └── incmodel.cpython-38.pyc
│ ├── base.py
│ └── prune.py
├── .DS_Store
├── prune
│ ├── prune
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── structured.cpython-38.pyc
│ │ │ └── unstructured.cpython-38.pyc
│ │ ├── unstructured.py
│ │ └── structured.py
│ ├── __pycache__
│ │ ├── utils.cpython-38.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── autoslim.cpython-38.pyc
│ │ ├── dependency.cpython-38.pyc
│ │ └── flops_counter.cpython-38.pyc
│ ├── __init__.py
│ ├── utils.py
│ ├── resnet_small.py
│ ├── autoslim_test.py
│ ├── autoslim.py
│ └── sensitivity_analysis.py
└── __pycache__
│ └── __init__.cpython-38.pyc
├── pictures
├── TCIL.png
├── .DS_Store
├── cifar_mem.png
├── non_mem.png
└── imagenet_mem.png
├── requirements.txt
├── scripts
├── run.sh
├── prune.sh
└── inference.sh
├── configs
├── cifar_b0_10s.yaml
├── cifar_b0_20s.yaml
├── cifar_b0_5s.yaml
├── cifar_b50_10s.yaml
├── cifar_b50_2s.yaml
├── cifar_b50_5s.yaml
└── imagenet_b0_10s.yaml
├── README.md
└── main.py
/inclearn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inclearn/convnet/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inclearn/learn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inclearn/loss/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inclearn/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inclearn/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inclearn/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .incmodel import IncModel
2 |
--------------------------------------------------------------------------------
/pictures/TCIL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/pictures/TCIL.png
--------------------------------------------------------------------------------
/inclearn/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/.DS_Store
--------------------------------------------------------------------------------
/inclearn/prune/prune/__init__.py:
--------------------------------------------------------------------------------
1 | from .structured import *
2 | from .unstructured import *
--------------------------------------------------------------------------------
/pictures/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/pictures/.DS_Store
--------------------------------------------------------------------------------
/pictures/cifar_mem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/pictures/cifar_mem.png
--------------------------------------------------------------------------------
/pictures/non_mem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/pictures/non_mem.png
--------------------------------------------------------------------------------
/pictures/imagenet_mem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/pictures/imagenet_mem.png
--------------------------------------------------------------------------------
/inclearn/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/models/__pycache__/base.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/models/__pycache__/base.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/prune/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/tools/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/convnet/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/datasets/__pycache__/data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/datasets/__pycache__/data.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/loss/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/loss/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/models/__pycache__/align.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/models/__pycache__/align.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/models/__pycache__/prune.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/models/__pycache__/prune.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/tools/__pycache__/cutout.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/cutout.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/tools/__pycache__/factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/factory.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/tools/__pycache__/memory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/memory.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/tools/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/convnet/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/convnet/__pycache__/network.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/network.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/convnet/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/datasets/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/datasets/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/learn/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/learn/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/learn/__pycache__/pretrain.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/learn/__pycache__/pretrain.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/loss/__pycache__/focalloss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/loss/__pycache__/focalloss.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/loss/__pycache__/tripeloss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/loss/__pycache__/tripeloss.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/models/__pycache__/incmodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/models/__pycache__/incmodel.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/prune/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/prune/__pycache__/autoslim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/__pycache__/autoslim.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/prune/__pycache__/dependency.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/__pycache__/dependency.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/tools/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/tools/__pycache__/data_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/data_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/tools/__pycache__/scheduler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/scheduler.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/convnet/__pycache__/classifier.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/classifier.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/convnet/__pycache__/imbalance.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/imbalance.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/convnet/__pycache__/cifar_resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/cifar_resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/convnet/__pycache__/preact_resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/preact_resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/prune/__pycache__/flops_counter.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/__pycache__/flops_counter.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/prune/prune/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/prune/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/tools/__pycache__/results_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/results_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/prune/prune/__pycache__/structured.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/prune/__pycache__/structured.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/tools/__pycache__/autoaugment_extra.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/autoaugment_extra.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/prune/prune/__pycache__/unstructured.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/prune/__pycache__/unstructured.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/convnet/__pycache__/modified_resnet_cifar.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/modified_resnet_cifar.cpython-38.pyc
--------------------------------------------------------------------------------
/inclearn/prune/__init__.py:
--------------------------------------------------------------------------------
1 | from .dependency import *
2 | from .prune import *
3 | from .autoslim import *
4 | from . import utils
5 | from .flops_counter import get_model_complexity_info
import warnings

# NOTE(review): importing this package installs an "ignore" filter for every
# warning category, process-wide -- not only for warnings raised by the
# pruning code. simplefilter("ignore") adds the same filter entry as
# filterwarnings("ignore") with its default (empty) message/module patterns.
warnings.simplefilter('ignore')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==1.1.0
2 | easydict==1.9
3 | matplotlib==3.5.1
4 | nni==2.10
5 | numpy==1.22.4
6 | opencv_python==4.5.5.62
7 | Pillow==9.3.0
8 | sacred==0.8.2
9 | scikit_learn==1.1.3
10 | scipy==1.9.3
11 | tensorboardX==2.5.1
12 | thop==0.0.31.post2005241907
13 | torch==1.8.1+cu111
14 | torchvision==0.9.1+cu111
15 |
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run main's "train" command with the cifar_b50_2s config file plus overrides.
name='cifar_b50_2s'
expid='cifar_b50_2s'

# Config file followed by key=value overrides, in the original order.
train_args=(
    "./configs/${expid}.yaml"
    "exp.name=${name}"
    "exp.savedir=./logs/"
    "exp.saveckpt=./ckpts_${expid}/"
    "exp.ckptdir=./logs/"
    "exp.tensorboard_dir=./tensorboard/"
    "exp.debug=True"
)

# Trailing run flags are passed through unchanged.
python -m main train with "${train_args[@]}" --name="${name}" -D -p --force
--------------------------------------------------------------------------------
/scripts/prune.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run main's "prune" command with the cifar_b0_10s config file plus overrides.
name='cifar_b0_10s'
expid='cifar_b0_10s'

# Config file followed by key=value overrides, in the original order.
prune_args=(
    "./configs/${expid}.yaml"
    "exp.name=${name}"
    "exp.savedir=./logs/"
    "exp.ckptdir=./logs/"
    "exp.saveckpt=./ckpts_${expid}/"
    "exp.tensorboard_dir=./tensorboard/"
    "exp.debug=True"
    "load_mem=True"
)

# Trailing run flags are passed through unchanged.
python -m main prune with "${prune_args[@]}" --name="${name}" -D -p --force
--------------------------------------------------------------------------------
/scripts/inference.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run main's "test" command with the cifar_b0_10s config file plus overrides.
name='Inference_cifar_b0_10s'
expid='cifar_b0_10s'

# Config file followed by key=value overrides, in the original order.
test_args=(
    "./configs/${expid}.yaml"
    "exp.name=${name}"
    "exp.savedir=./logs/"
    "exp.ckptdir=./logs/"
    "exp.saveckpt=./ckpts_${expid}/"
    "exp.tensorboard_dir=./tensorboard/"
    "exp.debug=True"
    "load_mem=True"
)

# Trailing run flags are passed through unchanged.
python -m main test with "${test_args[@]}" --name="${name}" -D -p --force
--------------------------------------------------------------------------------
/inclearn/tools/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def construct_balanced_subset(x, y):
    """Return a class-balanced random subset of ``(x, y)``.

    Every class present in ``y`` is downsampled, uniformly at random using
    NumPy's global RNG, to the size of the smallest class. The per-class
    samples are concatenated in ``np.unique(y)`` (sorted) order.

    Args:
        x: Array of samples, indexed with a boolean mask built from ``y``, so
            its first dimension must match ``len(y)``.
        y: 1-D array of labels.

    Returns:
        Tuple ``(x_balanced, y_balanced)``. For an empty ``y``, empty slices
        of ``x`` and ``y`` are returned.
    """
    classes = np.unique(y)
    if classes.size == 0:
        # np.concatenate([]) would raise; an empty input is trivially balanced.
        return x[:0], y[:0]

    xdata = [x[y == cls_] for cls_ in classes]
    ydata = [y[y == cls_] for cls_ in classes]
    # Every class is cut to this size. No class can be smaller than it, so
    # the old "size < minsize" debugger branch (pdb.set_trace) was dead code.
    minsize = min(labels.shape[0] for labels in ydata)

    for i in range(len(xdata)):
        # One permutation per class keeps the x/y pairs aligned.
        idx = np.arange(xdata[i].shape[0])
        np.random.shuffle(idx)
        xdata[i] = xdata[i][idx][:minsize]
        ydata[i] = ydata[i][idx][:minsize]
    return np.concatenate(xdata, 0), np.concatenate(ydata, 0)
--------------------------------------------------------------------------------
/inclearn/convnet/classifier.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.nn.parameter import Parameter
5 | from torch.nn import functional as F
6 | from torch.nn import Module
7 |
8 |
class CosineClassifier(Module):
    """Linear head that scores inputs by cosine similarity to per-class weights.

    Both the input rows and the rows of ``weight`` are L2-normalised along
    dim 1 before the dot product, so raw scores lie in ``[-1, 1]``. When
    ``sigma`` is enabled, a learnable scalar ``sigma`` (initialised to 1)
    rescales the scores; otherwise ``sigma`` is registered as ``None``.
    """

    def __init__(self, in_features, n_classes, sigma=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = n_classes
        # Uninitialised storage; values are set by reset_parameters().
        self.weight = Parameter(torch.Tensor(n_classes, in_features))
        if not sigma:
            self.register_parameter('sigma', None)
        else:
            self.sigma = Parameter(torch.Tensor(1))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weights from U(-1/sqrt(in_features), 1/sqrt(in_features)) and set sigma to 1."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.sigma is not None:
            # Start from an identity scaling of the cosine scores.
            self.sigma.data.fill_(1)

    def forward(self, input):
        """Return the (optionally sigma-scaled) cosine similarities of ``input`` to each class."""
        unit_input = F.normalize(input, p=2, dim=1)
        unit_weight = F.normalize(self.weight, p=2, dim=1)
        scores = F.linear(unit_input, unit_weight)
        if self.sigma is None:
            return scores
        return self.sigma * scores
32 |
--------------------------------------------------------------------------------
/inclearn/tools/cutout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
class Cutout(object):
    """Randomly mask out one or more square patches from an image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): Nominal side length (in pixels) of each square patch.
            A patch spans ``[c - length // 2, c + length // 2)`` around a
            random centre ``c`` and is clipped at the image border, so an odd
            ``length`` gives a ``length - 1`` wide patch.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Apply the cutout mask.

        Args:
            img (Tensor): Tensor image of size (C, H, W). Patch centres are
                drawn with NumPy's global RNG.

        Returns:
            Tensor: ``img`` multiplied by a 0/1 mask in which ``n_holes``
            patches are zero. The same mask is applied to every channel.
        """
        h = img.size(1)
        w = img.size(2)
        half = self.length // 2

        mask = np.ones((h, w), np.float32)
        for _ in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - half, 0, h)
            y2 = np.clip(y + half, 0, h)
            x1 = np.clip(x - half, 0, w)
            x2 = np.clip(x + half, 0, w)

            mask[y1:y2, x1:x2] = 0.

        # torch.from_numpy always yields a CPU tensor. Move it to the image's
        # device so the multiplication also works for GPU images.
        mask = torch.from_numpy(mask).to(img.device)
        mask = mask.expand_as(img)
        return img * mask
--------------------------------------------------------------------------------
/inclearn/prune/utils.py:
--------------------------------------------------------------------------------
1 | from .dependency import TORCH_CONV, TORCH_BATCHNORM, TORCH_PRELU, TORCH_LINEAR
2 |
def count_prunable_params(module):
    """Count the scalar entries of ``module`` that structured pruning can remove.

    - Conv/linear layers: weight plus bias (if any).
    - Batch-norm layers: the running statistics that exist, plus weight and
      bias when ``affine`` is set.
    - PReLU: its weight only when it has more than one slope.
    - Any other module: 0.

    Returns:
        int: number of prunable entries.
    """
    if isinstance(module, (TORCH_CONV, TORCH_LINEAR)):
        num_params = module.weight.numel()
        if module.bias is not None:
            num_params += module.bias.numel()
        return num_params
    if isinstance(module, TORCH_BATCHNORM):
        num_params = 0
        # PyTorch stores these buffers as None when track_running_stats=False.
        for stat in (module.running_mean, module.running_var):
            if stat is not None:
                num_params += stat.numel()
        if module.affine:
            num_params += module.weight.numel() + module.bias.numel()
        return num_params
    if isinstance(module, TORCH_PRELU):
        # A single shared slope cannot be pruned per channel.
        if len(module.weight) == 1:
            return 0
        # Fixed: previously returned the bound method `numel` without calling it.
        return module.weight.numel()
    return 0
21 |
def count_prunable_channels(module):
    """Return the number of output channels of ``module`` that can be pruned.

    - Conv layers: the size of weight dim 0.
    - Linear layers: ``out_features``.
    - Batch-norm layers: ``num_features``.
    - PReLU: its number of slopes, or 0 when it has a single shared slope.
    - Any other module: 0.
    """
    if isinstance(module, TORCH_CONV):
        return module.weight.shape[0]
    if isinstance(module, TORCH_LINEAR):
        return module.out_features
    if isinstance(module, TORCH_BATCHNORM):
        return module.num_features
    if not isinstance(module, TORCH_PRELU):
        return 0
    n_slopes = len(module.weight)
    return 0 if n_slopes == 1 else n_slopes
36 |
def count_params(module):
    """Return the total number of scalar entries across all parameters of ``module``."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
39 |
--------------------------------------------------------------------------------
/inclearn/tools/results_utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import math
4 | import os
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from copy import deepcopy
8 |
9 | from . import utils
10 |
11 |
def get_template_results(cfg):
    """Build an empty results record that holds ``cfg`` and no per-task results yet."""
    template = dict(config=cfg, results=[])
    return template
14 |
15 |
def save_results(results, label):
    """Dump ``results`` as JSON into ``results/<date>_<label>/<date>_<seed>_.json``.

    The ``"device"`` entry of ``results["config"]`` is left out of the file.
    Presumably this is because it is not JSON-serialisable; verify. The
    caller's ``results`` dict is not modified, and a missing ``"device"``
    key is tolerated.

    Args:
        results: Mapping with a ``"config"`` mapping that contains ``"seed"``.
        label: Suffix for the output folder name.
    """
    # Shallow copies, so the caller's config keeps its "device" entry.
    config = dict(results["config"])
    config.pop("device", None)
    payload = dict(results)
    payload["config"] = config

    # Fetch the date once so the folder and file names always agree.
    date = utils.get_date()
    folder_path = os.path.join("results", "{}_{}".format(date, label))
    os.makedirs(folder_path, exist_ok=True)

    file_path = "{}_{}_.json".format(date, config["seed"])
    with open(os.path.join(folder_path, file_path), "w+") as f:
        json.dump(payload, f, indent=2)
26 |
27 |
def compute_avg_inc_acc(results):
    """Compute the average incremental accuracy as defined in iCaRL.

    The average incremental accuracy at task X is the mean of the total
    accuracies obtained at tasks 0, 1, ..., X.

    :param results: Non-empty list of per-task dicts. Each has
        ``r["top1"]["total"]``, and optionally ``r["top5"]["total"]``.
    :return: ``(top1_avg, top5_avg)``. ``top5_avg`` is ``None`` when the
        first entry has no ``"top5"`` key.
    :raises ValueError: if ``results`` is empty.
    """
    if not results:
        raise ValueError("compute_avg_inc_acc needs at least one task result")

    top1_tasks_accuracy = [r["top1"]["total"] for r in results]
    top1acc = sum(top1_tasks_accuracy) / len(top1_tasks_accuracy)
    if "top5" in results[0]:
        top5_tasks_accuracy = [r["top5"]["total"] for r in results]
        top5acc = sum(top5_tasks_accuracy) / len(top5_tasks_accuracy)
    else:
        top5acc = None
    return top1acc, top5acc
--------------------------------------------------------------------------------
/inclearn/convnet/imbalance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | import numpy as np
5 | from torch.optim.lr_scheduler import CosineAnnealingLR
6 |
7 |
8 |
class CR(object):
    """Rescale the newest task's logits by the old/new classifier weight-norm ratio.

    ``update`` stores ``gamma = mean(||w_old||) / mean(||w_new||)``. Here the
    new classes are the last ``task_size`` rows of ``classifier.weight`` and
    the old classes are all other rows. ``post_process`` multiplies the last
    ``task_size`` logit columns by ``gamma`` in place and returns the same
    tensor.
    """

    def __init__(self):
        self.gamma = None  # filled by update(); post_process() relies on it

    @torch.no_grad()
    def update(self, classifier, task_size):
        weight = classifier.weight
        old_norms = weight[:-task_size].norm(p=2, dim=1)
        new_norms = weight[-task_size:].norm(p=2, dim=1)
        self.gamma = old_norms.mean() / new_norms.mean()

    @torch.no_grad()
    def post_process(self, logits, task_size):
        newest = logits[:, -task_size:]
        logits[:, -task_size:] = newest * self.gamma
        return logits
24 |
25 |
class All_av(object):
    """Per-task weight-norm rescaling of logits.

    For every task ``i`` in ``0..taski``, ``update`` stores
    ``gamma[i] = mean(||w_old||) / mean(||w_i||)``:

    - ``w_old`` is every row of ``classifier.weight`` except the last
      ``task_size`` ones.
    - ``w_i`` is the rows of task ``i``; task boundaries are prefix sums of
      ``classnum_list``.

    ``post_process`` multiplies each task's logit columns by its ``gamma``
    in place and returns the same tensor.
    """

    def __init__(self):
        self.gamma = []

    @staticmethod
    def _task_bounds(classnum_list, taski):
        """Return ``[(start, end), ...]`` column ranges for tasks ``0..taski``."""
        return [(sum(classnum_list[:i]), sum(classnum_list[:i + 1]))
                for i in range(taski + 1)]

    @torch.no_grad()
    def update(self, classifier, task_size, classnum_list, taski):
        weight = classifier.weight
        # The old-class reference norm does not depend on the task index,
        # so compute it once rather than once per task.
        old_mean = torch.norm(weight[:-task_size], p=2, dim=1).mean()
        self.gamma = [
            old_mean / torch.norm(weight[start:end], p=2, dim=1).mean()
            for start, end in self._task_bounds(classnum_list, taski)
        ]

    @torch.no_grad()
    def post_process(self, logits, task_size, classnum_list, taski):
        bounds = self._task_bounds(classnum_list, taski)
        for i, (start, end) in enumerate(bounds):
            logits[:, start:end] = logits[:, start:end] * self.gamma[i]
        return logits
44 |
--------------------------------------------------------------------------------
/inclearn/prune/prune/unstructured.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from copy import deepcopy
4 |
5 | __all__=['mask_weight', 'mask_bias']
6 |
7 | def _mask_weight_hook(module, input):
8 | if hasattr(module, 'weight_mask'):
9 | module.weight.data *= module.weight_mask
10 |
11 | def _mask_bias_hook(module, input):
12 | if module.bias is not None and hasattr(module, 'bias_mask'):
13 | module.bias.data *= module.bias_mask
14 |
def mask_weight(layer, mask, inplace=True):
    """Unstructured pruning of ``layer.weight`` with a 0-1 mask.

    The mask is stored on the layer as the ``weight_mask`` buffer.
    ``_mask_weight_hook`` applies it as a forward pre-hook, zeroing the
    masked-out weights before each forward pass. If the layer already has a
    mask, the new one is merged by element-wise OR: an entry is kept if
    either mask keeps it.

    Args:
        layer: A module with a ``weight`` parameter, e.g. a convolution layer.
        mask: 0-1 array or tensor with exactly ``layer.weight.shape``. Any
            other shape leaves the layer untouched.
        inplace: If False, a deep copy of ``layer`` is modified and returned.

    Returns:
        The (possibly copied) layer.
    """
    if not inplace:
        layer = deepcopy(layer)
    if mask.shape != layer.weight.shape:
        return layer
    # as_tensor handles numpy arrays and tensors alike, without the
    # torch.tensor(<Tensor>) warning. detach().clone() drops autograd history
    # and never aliases the caller's data.
    mask = torch.as_tensor(mask, dtype=layer.weight.dtype,
                           device=layer.weight.device).detach().clone()
    if hasattr(layer, 'weight_mask'):
        mask = mask + layer.weight_mask
        mask[mask > 0] = 1
        layer.weight_mask = mask
    else:
        layer.register_buffer('weight_mask', mask)

    # Register the hook only once. Previously every call added another copy,
    # so repeated pruning stacked duplicate hooks on every forward pass.
    if _mask_weight_hook not in layer._forward_pre_hooks.values():
        layer.register_forward_pre_hook(_mask_weight_hook)
    return layer
36 |
def mask_bias(layer, mask, inplace=True):
    """Unstructured pruning of ``layer.bias`` with a 0-1 mask.

    The mask is stored on the layer as the ``bias_mask`` buffer.
    ``_mask_bias_hook`` applies it as a forward pre-hook, zeroing the
    masked-out bias entries before each forward pass. If the layer already
    has a mask, the new one is merged by element-wise OR: an entry is kept
    if either mask keeps it.

    Args:
        layer: A module with a ``bias`` attribute, e.g. a convolution layer.
        mask: 0-1 array or tensor with exactly ``layer.bias.shape``. If the
            shapes differ, or the layer has no bias, the layer is returned
            unchanged.
        inplace: If False, a deep copy of ``layer`` is modified and returned.

    Returns:
        The (possibly copied) layer.
    """
    if not inplace:
        layer = deepcopy(layer)
    if layer.bias is None or mask.shape != layer.bias.shape:
        return layer

    # Match the bias's own dtype/device. detach().clone() drops autograd
    # history and never aliases the caller's data.
    mask = torch.as_tensor(mask, dtype=layer.bias.dtype,
                           device=layer.bias.device).detach().clone()
    if hasattr(layer, 'bias_mask'):
        mask = mask + layer.bias_mask
        mask[mask > 0] = 1
        layer.bias_mask = mask
    else:
        layer.register_buffer('bias_mask', mask)

    # Register the hook only once instead of stacking a new copy per call.
    if _mask_bias_hook not in layer._forward_pre_hooks.values():
        layer.register_forward_pre_hook(_mask_bias_hook)
    return layer
58 |
--------------------------------------------------------------------------------
/inclearn/loss/focalloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from torch.nn import functional as F
5 |
class BCEFocalLoss(torch.nn.Module):
    """Binary focal loss computed from raw logits (sigmoid applied internally).

    With ``p = sigmoid(predict)``::

        loss = -alpha * (1 - p)**gamma * target * log(p)
               - (1 - alpha) * p**gamma * (1 - target) * log(1 - p)

    Args:
        gamma: focusing exponent.
        alpha: weight of the positive term; the negative term gets ``1 - alpha``.
        reduction: 'mean' or 'sum'; any other value returns the element-wise loss.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='mean'):
        super(BCEFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, predict, target):
        pt = torch.sigmoid(predict)  # probability of the positive class
        # Use log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)) in
        # their numerically stable forms. torch.log(pt) returns -inf for
        # large-magnitude logits, which turns into inf/NaN (0 * -inf) losses.
        log_pt = F.logsigmoid(predict)
        log_one_minus_pt = F.logsigmoid(-predict)
        # A dynamic weighting factor on top of plain BCE. The scalar alpha used
        # here only makes sense for the binary case, not the multi-class loss.
        loss = (-self.alpha * (1 - pt) ** self.gamma * target * log_pt
                - (1 - self.alpha) * pt ** self.gamma * (1 - target) * log_one_minus_pt)
        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
22 |
class MultiCEFocalLoss(torch.nn.Module):
    """Multi-class focal loss computed from raw logits (softmax applied internally).

    For each sample with true class ``c`` and ``p = softmax(predict)[c]``::

        loss = -alpha[c] * (1 - p)**gamma * log(p)

    Args:
        class_num: number of classes (width of the one-hot target encoding).
        gamma: focusing exponent.
        alpha: per-class weights indexed by class id. This may be a tensor of
            shape (class_num,) or (class_num, 1), or a sequence. Defaults to
            all ones.
        reduction: 'mean' or 'sum'; any other value returns the per-sample loss
            with shape (N, 1).
    """

    def __init__(self, class_num, gamma=2, alpha=None, reduction='mean'):
        super(MultiCEFocalLoss, self).__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        else:
            self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.class_num = class_num

    def forward(self, predict, target):
        log_probs = F.log_softmax(predict, dim=1)  # stable log of softmax probabilities
        class_mask = F.one_hot(target, self.class_num).to(log_probs.dtype)  # one-hot of target
        ids = target.view(-1)
        # alpha holds one weight per class. Move it to the input's device
        # instead of hard-coding .cuda(). Reshape to (N, 1) so a 1-D alpha
        # does not broadcast against the (N, 1) loss into (N, N).
        alpha = torch.as_tensor(self.alpha, dtype=predict.dtype, device=predict.device)[ids].view(-1, 1)
        # Use the one-hot mask to pick each sample's log-probability of its true class.
        log_p = (log_probs * class_mask).sum(1).view(-1, 1)
        probs = log_p.exp()
        # Plain CE with the dynamic focal weighting factor applied.
        loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
--------------------------------------------------------------------------------
/configs/cifar_b0_10s.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | name: "CIFAR100_B0_10S_TCIL"
3 | savedir: "./logs"
4 | tensorboard_dir: "./tensorboard"
5 | debug: False
6 |
7 |
8 | #Model Cfg
9 | model: "incmodel"
10 | convnet: 'resnet18'
11 | train_head: 'softmax'
12 | infer_head: 'softmax'
13 | channel: 64
14 | use_bias: False
15 | last_relu: False
16 |
17 | dea: True
18 | use_div_cls: True
19 | div_type: "n+1" # n+t, 1+1
20 | distillation: True
21 | disttype: "MTKD"
22 | temperature: 2
23 | distlamb: 1
24 | feature_type: "ffm" # se
25 | attention_use_residual: True
26 | ignore_new: True
27 |
28 | prune: False
29 |
30 | attention:
31 | add_kl: True
32 | kd_warm_up: 50
33 | kd_loss_weight: 0.5
34 | kl_loss_weight: 0.5
35 |
36 | reuse_oldfc: False
37 | weight_normalization: False
38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off.
39 | save_ckpt: True
40 | save_mem: True
41 | load_mem: False
42 |
43 | #Optimization;Training related
44 | task_max: 10
45 | lr_min: 0.00005
46 | lr: 0.1
47 | weight_decay: 0.0005
48 | dynamic_weight_decay: False
49 | scheduler: 'multistep'
50 | scheduling:
51 | - 100
52 | - 120
53 | lr_decay: 0.1
54 | optimizer: "sgd"
55 | epochs: 170
56 | resampling: False
57 | warmup: True
58 | warmup_epochs: 10
59 |
60 | postprocessor:
61 | enable: True
62 | type: 'cr'
63 |
64 | pretrain:
65 | epochs: 200
66 | lr: 0.1
67 | scheduling:
68 | - 60
69 | - 120
70 | - 160
71 | lr_decay: 0.1
72 | weight_decay: 0.0005
73 |
74 |
75 | # Dataset Cfg
76 | dataset: "cifar100" #'imagenet100', 'cifar100'
77 | trial: 2
78 | increment: 10
79 | batch_size: 128
80 | workers: 4
81 | validation: 0 # Validation split (0. <= x <= 1.)
82 | random_classes: False #Randomize classes order of increment
83 | start_class: 0 # number of classes in the first (base) step; 0 means no separate base step.
84 | start_task: 0
85 | max_task: # Cap the number of task
86 |
87 | #Memory
88 | coreset_strategy: "iCaRL" # iCaRL, random
89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem
90 | memory_size: 2000 # Max number of storable exemplars
91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls
92 |
93 | # Misc
94 | device: 0 #GPU index to use, for cpu use -1
95 | seed: 1993
96 |
--------------------------------------------------------------------------------
/configs/cifar_b0_20s.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | name: "CIFAR100_B0_20S_TCIL"
3 | savedir: "./logs"
4 | tensorboard_dir: "./tensorboard"
5 | debug: False
6 |
7 |
8 | #Model Cfg
9 | model: "incmodel"
10 | convnet: 'resnet18'
11 | train_head: 'softmax'
12 | infer_head: 'softmax'
13 | channel: 64
14 | use_bias: False
15 | last_relu: False
16 |
17 | dea: True
18 | use_div_cls: True
19 | div_type: "n+1" # n+t, 1+1
20 | distillation: True
21 | disttype: "MTKD"
22 | temperature: 2
23 | distlamb: 1
24 | feature_type: "ffm" # se
25 | attention_use_residual: True
26 | ignore_new: True
27 |
28 | prune: False
29 |
30 | attention:
31 | add_kl: True
32 | kd_warm_up: 50
33 | kd_loss_weight: 0.5
34 | kl_loss_weight: 0.5
35 |
36 | reuse_oldfc: False
37 | weight_normalization: False
38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off.
39 | save_ckpt: True
40 | save_mem: True
41 | load_mem: False
42 |
43 | #Optimization;Training related
44 | task_max: 10
45 | lr_min: 0.00005
46 | lr: 0.1
47 | weight_decay: 0.0005
48 | dynamic_weight_decay: False
49 | scheduler: 'multistep'
50 | scheduling:
51 | - 100
52 | - 120
53 | lr_decay: 0.1
54 | optimizer: "sgd"
55 | epochs: 170
56 | resampling: False
57 | warmup: True
58 | warmup_epochs: 10
59 |
60 | postprocessor:
61 | enable: True
62 | type: 'cr'
63 |
64 | pretrain:
65 | epochs: 200
66 | lr: 0.1
67 | scheduling:
68 | - 60
69 | - 120
70 | - 160
71 | lr_decay: 0.1
72 | weight_decay: 0.0005
73 |
74 |
75 | # Dataset Cfg
76 | dataset: "cifar100" #'imagenet100', 'cifar100'
77 | trial: 2
78 | increment: 5
79 | batch_size: 128
80 | workers: 4
81 | validation: 0 # Validation split (0. <= x <= 1.)
82 | random_classes: False #Randomize classes order of increment
83 | start_class: 0 # number of classes in the first (base) step; 0 means no separate base step.
84 | start_task: 0
85 | max_task: # Cap the number of task
86 |
87 | #Memory
88 | coreset_strategy: "iCaRL" # iCaRL, random
89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem
90 | memory_size: 2000 # Max number of storable exemplars
91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls
92 |
93 | # Misc
94 | device: 0 #GPU index to use, for cpu use -1
95 | seed: 1993
96 |
--------------------------------------------------------------------------------
/configs/cifar_b0_5s.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | name: "CIFAR100_B0_5S_TCIL"
3 | savedir: "./logs"
4 | tensorboard_dir: "./tensorboard"
5 | debug: False
6 |
7 |
8 | #Model Cfg
9 | model: "incmodel"
10 | convnet: 'resnet18'
11 | train_head: 'softmax'
12 | infer_head: 'softmax'
13 | channel: 64
14 | use_bias: False
15 | last_relu: False
16 |
17 | dea: True
18 | use_div_cls: True
19 | div_type: "n+1" # n+t, 1+1
20 | distillation: True
21 | disttype: "MTKD"
22 | temperature: 2
23 | distlamb: 1
24 | feature_type: "ffm" # se
25 | attention_use_residual: True
26 | ignore_new: True
27 |
28 | prune: False
29 |
30 | attention:
31 | add_kl: True
32 | kd_warm_up: 50
33 | kd_loss_weight: 0.5
34 | kl_loss_weight: 0.5
35 |
36 | reuse_oldfc: False
37 | weight_normalization: False
38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off.
39 | save_ckpt: True
40 | save_mem: True
41 | load_mem: False
42 |
43 | #Optimization;Training related
44 | task_max: 10
45 | lr_min: 0.00005
46 | lr: 0.1
47 | weight_decay: 0.0005
48 | dynamic_weight_decay: False
49 | scheduler: 'multistep'
50 | scheduling:
51 | - 100
52 | - 120
53 | lr_decay: 0.1
54 | optimizer: "sgd"
55 | epochs: 170
56 | resampling: False
57 | warmup: True
58 | warmup_epochs: 10
59 |
60 | postprocessor:
61 | enable: True
62 | type: 'cr'
63 |
64 | pretrain:
65 | epochs: 200
66 | lr: 0.1
67 | scheduling:
68 | - 60
69 | - 120
70 | - 160
71 | lr_decay: 0.1
72 | weight_decay: 0.0005
73 |
74 |
75 | # Dataset Cfg
76 | dataset: "cifar100" #'imagenet100', 'cifar100'
77 | trial: 2
78 | increment: 20
79 | batch_size: 128
80 | workers: 4
81 | validation: 0 # Validation split (0. <= x <= 1.)
82 | random_classes: False #Randomize classes order of increment
83 | start_class: 0 # number of classes in the first (base) step; 0 means no separate base step.
84 | start_task: 0
85 | max_task: # Cap the number of task
86 |
87 | #Memory
88 | coreset_strategy: "iCaRL" # iCaRL, random
89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem
90 | memory_size: 2000 # Max number of storable exemplars
91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls
92 |
93 | # Misc
94 | device: 0 #GPU index to use, for cpu use -1
95 | seed: 1993
96 |
--------------------------------------------------------------------------------
/configs/cifar_b50_10s.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | name: "CIFAR100_B50_10S_TCIL"
3 | savedir: "./logs"
4 | tensorboard_dir: "./tensorboard"
5 | debug: False
6 |
7 |
8 | #Model Cfg
9 | model: "incmodel"
10 | convnet: 'resnet18'
11 | train_head: 'softmax'
12 | infer_head: 'softmax'
13 | channel: 64
14 | use_bias: False
15 | last_relu: False
16 |
17 | dea: True
18 | use_div_cls: True
19 | div_type: "n+1" # n+t, 1+1
20 | distillation: True
21 | disttype: "MTKD"
22 | temperature: 2
23 | distlamb: 1
24 | feature_type: "ffm" # se
25 | attention_use_residual: True
26 | ignore_new: True
27 |
28 | prune: False
29 |
30 | attention:
31 | add_kl: True
32 | kd_warm_up: 50
33 | kd_loss_weight: 0.5
34 | kl_loss_weight: 0.5
35 |
36 | reuse_oldfc: False
37 | weight_normalization: False
38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off.
39 | save_ckpt: True
40 | save_mem: True
41 | load_mem: False
42 |
43 | #Optimization;Training related
44 | task_max: 10
45 | lr_min: 0.00005
46 | lr: 0.1
47 | weight_decay: 0.0005
48 | dynamic_weight_decay: False
49 | scheduler: 'multistep'
50 | scheduling:
51 | - 100
52 | - 120
53 | lr_decay: 0.1
54 | optimizer: "sgd"
55 | epochs: 170
56 | resampling: False
57 | warmup: True
58 | warmup_epochs: 10
59 |
60 | postprocessor:
61 | enable: True
62 | type: 'cr'
63 |
64 | pretrain:
65 | epochs: 200
66 | lr: 0.1
67 | scheduling:
68 | - 60
69 | - 120
70 | - 160
71 | lr_decay: 0.1
72 | weight_decay: 0.0005
73 |
74 |
75 | # Dataset Cfg
76 | dataset: "cifar100" #'imagenet100', 'cifar100'
77 | trial: 2
78 | increment: 5
79 | batch_size: 128
80 | workers: 4
81 | validation: 0 # Validation split (0. <= x <= 1.)
82 | random_classes: False #Randomize classes order of increment
83 | start_class: 50 # number of classes in the first (base) step; 0 means no separate base step.
84 | start_task: 0
85 | max_task: # Cap the number of task
86 |
87 | #Memory
88 | coreset_strategy: "iCaRL" # iCaRL, random
89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem
90 | memory_size: 2000 # Max number of storable exemplars
91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls
92 |
93 | # Misc
94 | device: 0 #GPU index to use, for cpu use -1
95 | seed: 1993
96 |
--------------------------------------------------------------------------------
/configs/cifar_b50_2s.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | name: "CIFAR100_B50_2S_TCIL"
3 | savedir: "./logs"
4 | tensorboard_dir: "./tensorboard"
5 | debug: False
6 |
7 |
8 | #Model Cfg
9 | model: "incmodel"
10 | convnet: 'resnet18'
11 | train_head: 'softmax'
12 | infer_head: 'softmax'
13 | channel: 64
14 | use_bias: False
15 | last_relu: False
16 |
17 | dea: True
18 | use_div_cls: True
19 | div_type: "n+1" # n+t, 1+1
20 | distillation: True
21 | disttype: "MTKD"
22 | temperature: 2
23 | distlamb: 1
24 | feature_type: "ffm" # se
25 | attention_use_residual: True
26 | ignore_new: True
27 |
28 | prune: False
29 |
30 | attention:
31 | add_kl: True
32 | kd_warm_up: 50
33 | kd_loss_weight: 0.5
34 | kl_loss_weight: 0.5
35 |
36 | reuse_oldfc: False
37 | weight_normalization: False
38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off.
39 | save_ckpt: True
40 | save_mem: True
41 | load_mem: False
42 |
43 | #Optimization;Training related
44 | task_max: 10
45 | lr_min: 0.00005
46 | lr: 0.1
47 | weight_decay: 0.0005
48 | dynamic_weight_decay: False
49 | scheduler: 'multistep'
50 | scheduling:
51 | - 100
52 | - 120
53 | lr_decay: 0.1
54 | optimizer: "sgd"
55 | epochs: 170
56 | resampling: False
57 | warmup: True
58 | warmup_epochs: 10
59 |
60 | postprocessor:
61 | enable: True
62 | type: 'cr'
63 |
64 | pretrain:
65 | epochs: 200
66 | lr: 0.1
67 | scheduling:
68 | - 60
69 | - 120
70 | - 160
71 | lr_decay: 0.1
72 | weight_decay: 0.0005
73 |
74 |
75 | # Dataset Cfg
76 | dataset: "cifar100" #'imagenet100', 'cifar100'
77 | trial: 2
78 | increment: 25
79 | batch_size: 128
80 | workers: 4
81 | validation: 0 # Validation split (0. <= x <= 1.)
82 | random_classes: False #Randomize classes order of increment
83 | start_class: 50 # number of classes in the first (base) step; 0 means no separate base step.
84 | start_task: 0
85 | max_task: # Cap the number of task
86 |
87 | #Memory
88 | coreset_strategy: "iCaRL" # iCaRL, random
89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem
90 | memory_size: 2000 # Max number of storable exemplars
91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls
92 |
93 | # Misc
94 | device: 0 #GPU index to use, for cpu use -1
95 | seed: 1993
96 |
--------------------------------------------------------------------------------
/configs/cifar_b50_5s.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | name: "CIFAR100_B50_5S_TCIL"
3 | savedir: "./logs"
4 | tensorboard_dir: "./tensorboard"
5 | debug: False
6 |
7 |
8 | #Model Cfg
9 | model: "incmodel"
10 | convnet: 'resnet18'
11 | train_head: 'softmax'
12 | infer_head: 'softmax'
13 | channel: 64
14 | use_bias: False
15 | last_relu: False
16 |
17 | dea: True
18 | use_div_cls: True
19 | div_type: "n+1" # n+t, 1+1
20 | distillation: True
21 | disttype: "MTKD"
22 | temperature: 2
23 | distlamb: 1
24 | feature_type: "ffm" # se
25 | attention_use_residual: True
26 | ignore_new: True
27 |
28 | prune: False
29 |
30 | attention:
31 | add_kl: True
32 | kd_warm_up: 50
33 | kd_loss_weight: 0.5
34 | kl_loss_weight: 0.5
35 |
36 | reuse_oldfc: False
37 | weight_normalization: False
38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off.
39 | save_ckpt: True
40 | save_mem: True
41 | load_mem: False
42 |
43 | #Optimization;Training related
44 | task_max: 10
45 | lr_min: 0.00005
46 | lr: 0.1
47 | weight_decay: 0.0005
48 | dynamic_weight_decay: False
49 | scheduler: 'multistep'
50 | scheduling:
51 | - 100
52 | - 120
53 | lr_decay: 0.1
54 | optimizer: "sgd"
55 | epochs: 170
56 | resampling: False
57 | warmup: True
58 | warmup_epochs: 10
59 |
60 | postprocessor:
61 | enable: True
62 | type: 'cr'
63 |
64 | pretrain:
65 | epochs: 200
66 | lr: 0.1
67 | scheduling:
68 | - 60
69 | - 120
70 | - 160
71 | lr_decay: 0.1
72 | weight_decay: 0.0005
73 |
74 |
75 | # Dataset Cfg
76 | dataset: "cifar100" #'imagenet100', 'cifar100'
77 | trial: 2
78 | increment: 10
79 | batch_size: 128
80 | workers: 4
81 | validation: 0 # Validation split (0. <= x <= 1.)
82 | random_classes: False #Randomize classes order of increment
83 | start_class: 50 # number of classes in the first (base) step; 0 means no separate base step.
84 | start_task: 0
85 | max_task: # Cap the number of task
86 |
87 | #Memory
88 | coreset_strategy: "iCaRL" # iCaRL, random
89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem
90 | memory_size: 2000 # Max number of storable exemplars
91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls
92 |
93 | # Misc
94 | device: 0 #GPU index to use, for cpu use -1
95 | seed: 1993
96 |
--------------------------------------------------------------------------------
/configs/imagenet_b0_10s.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | name: "ImageNet100_B0_10S_TCIL"
3 | savedir: "./logs"
4 | tensorboard_dir: "./tensorboard"
5 | debug: False
6 |
7 |
8 | #Model Cfg
9 | model: "incmodel"
10 | convnet: 'resnet18'
11 | train_head: 'softmax'
12 | infer_head: 'softmax'
13 | channel: 64
14 | use_bias: False
15 | last_relu: False
16 |
17 | dea: True
18 | use_div_cls: True
19 | div_type: "n+1" # n+t, 1+1
20 | distillation: True
21 | disttype: "MTKD"
22 | temperature: 2
23 | distlamb: 1
24 | feature_type: "ffm" # se
25 | attention_use_residual: True
26 | ignore_new: True
27 |
28 | prune: False
29 |
30 | attention:
31 | add_kl: True
32 | kd_warm_up: 50
33 | kd_loss_weight: 0.5
34 | kl_loss_weight: 0.5
35 |
36 | reuse_oldfc: False
37 | weight_normalization: False
38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off.
39 | save_ckpt: True
40 | save_mem: True
41 | load_mem: False
42 |
43 | #Optimization;Training related
44 | task_max: 10
45 | lr_min: 0.00005
46 | lr: 0.1
47 | weight_decay: 0.0005
48 | dynamic_weight_decay: False
49 | scheduler: 'multistep'
50 | scheduling:
51 | - 60
52 | - 120
53 | - 180
54 | lr_decay: 0.1
55 | optimizer: "sgd"
56 | epochs: 200
57 | resampling: False
58 | warmup: True
59 | warmup_epochs: 20
60 |
61 | postprocessor:
62 | enable: True
63 | type: 'cr'
64 |
65 | pretrain:
66 | epochs: 200
67 | lr: 0.1
68 | scheduling:
69 | - 60
70 | - 120
71 | - 160
72 | lr_decay: 0.1
73 | weight_decay: 0.0005
74 |
75 |
76 | # Dataset Cfg
77 | dataset: "imagenet100" #'imagenet100', 'cifar100'
78 | trial: 2
79 | increment: 10
80 | batch_size: 128
81 | workers: 4
82 | validation: 0 # Validation split (0. <= x <= 1.)
83 | random_classes: False #Randomize classes order of increment
84 | start_class: 0 # number of classes in the first (base) step; 0 means no separate base step.
85 | start_task: 0
86 | max_task: # Cap the number of task
87 |
88 | #Memory
89 | coreset_strategy: "iCaRL" # iCaRL, random
90 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem
91 | memory_size: 2000 # Max number of storable exemplars
92 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls
93 |
94 | # Misc
95 | device: 0 #GPU index to use, for cpu use -1
96 | seed: 1993
97 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # TCIL
4 |
5 | ## Resolving Task Confusion in Dynamic Expansion Architectures for Class Incremental Learning
6 |
7 | [](https://arxiv.org/abs/2212.14284)
8 | 
9 | [](https://www.bilibili.com/video/BV1L14y1u7XV/)
10 |
11 |
12 | 
13 |
14 |
15 |
16 |
17 |
18 |
19 | # Datasets
20 |
21 | - Training datasets
22 | 1. CIFAR100:
23 | The CIFAR100 dataset will be downloaded automatically.
24 | 2. ImageNet100:
25 | ImageNet100 is a subset of ImageNet. You need to download ImageNet first, and then split it according to [ImageNet100_Split](https://github.com/arthurdouillard/incremental_learning.pytorch).
26 |
27 | - Class ordering
28 | - We use the class ordering proposed by [DER](https://github.com/Rhyssiyan/DER-ClassIL.pytorch).
29 |
30 | - Structure of `data` directory
31 | ```
32 | data
33 | ├── cifar100
34 | │ └── cifar-100-python
35 | │ ├── train
36 | │ ├── test
37 | │ ├── meta
38 | │ └── file.txt~
39 | │
40 | ├── imagenet100
41 | │ ├── train
42 | │ └── val
43 | ```
44 |
45 | # Environment
46 | You can find all the libraries in the `requirements.txt`, and configure the experimental environment with the following commands.
47 |
48 | ```
49 | conda create -n TCIL python=3.8
50 | conda install pytorch==1.8.1 torchvision==0.9.1 cudatoolkit=11.1 -c pytorch
51 | pip install -r requirements.txt
52 | ```
53 | Thanks to [DER](https://github.com/Rhyssiyan/DER-ClassIL.pytorch) for their great code base.
54 | # Launching an experiment
55 | ## Train
56 | `sh scripts/run.sh`
57 | ## Eval
58 | `sh scripts/inference.sh`
59 | ## Prune
60 | `sh scripts/prune.sh`
61 |
62 |
63 | # Results
64 |
65 | ## Rehearsal Setting
66 |
67 | 
68 | 
69 |
70 | ## Non-Rehearsal Setting
71 |
72 | 
73 |
74 | ## Checkpoints
75 |
76 | Get the trained models from [BaiduNetdisk(passwd:q3eh)](https://pan.baidu.com/s/1G0XVZCaaZ2LmM_eppr3cXA).
77 | (The training logs are included in the same download.)
78 |
79 |
--------------------------------------------------------------------------------
/inclearn/tools/factory.py:
--------------------------------------------------------------------------------
1 | from matplotlib.transforms import Transform
2 | import torch
3 | from torch import nn
4 | from torch import optim
5 |
6 | from inclearn import models
7 | from inclearn.convnet import resnet, cifar_resnet, modified_resnet_cifar, preact_resnet
8 | from inclearn.datasets import data
9 | from inclearn.convnet.resnet import SEFeatureAt
10 |
def get_optimizer(params, optimizer, lr, weight_decay=0.0):
    """Build a torch optimizer by name.

    Args:
        params: parameters (or parameter groups) to optimize.
        optimizer: "sgd" (momentum 0.9) or "adam" (betas 0.9/0.999).
        lr: learning rate.
        weight_decay: L2 penalty passed straight to the optimizer.

    Raises:
        NotImplementedError: for any other optimizer name.
    """
    if optimizer == "sgd":
        return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9)
    if optimizer == "adam":
        return optim.Adam(params, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
    raise NotImplementedError
18 |
def get_attention(inplane, type, at_res):
    """Construct an ``SEFeatureAt`` attention module.

    All three arguments are forwarded positionally as
    ``SEFeatureAt(inplane, type, at_res)``. The parameter name ``type``
    shadows the builtin but is kept for keyword callers.
    """
    attention_module = SEFeatureAt(inplane, type, at_res)
    return attention_module
21 |
def get_convnet(convnet_type, **kwargs):
    """Instantiate a backbone network by name.

    Args:
        convnet_type: one of "resnet18", "resnet32", "modified_resnet32",
            "preact_resnet18".
        **kwargs: forwarded to the backbone constructor, except for "resnet32",
            whose ``cifar_resnet.resnet32`` is called without arguments.

    Raises:
        NotImplementedError: for an unrecognised ``convnet_type``.
    """
    if convnet_type == "resnet18":
        return resnet.resnet18(**kwargs)
    if convnet_type == "resnet32":
        # NOTE(review): kwargs are not forwarded for this backbone; confirm intended.
        return cifar_resnet.resnet32()
    if convnet_type == "modified_resnet32":
        return modified_resnet_cifar.resnet32(**kwargs)
    if convnet_type == "preact_resnet18":
        return preact_resnet.PreActResNet18(**kwargs)
    raise NotImplementedError("Unknown convnet type {}.".format(convnet_type))
33 |
34 |
def get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset):
    """Create the incremental-learning model selected by ``cfg["model"]``.

    Only "incmodel" (``models.IncModel``) is supported. Any other name raises
    NotImplementedError carrying that name.
    """
    model_name = cfg["model"]
    if model_name != "incmodel":
        raise NotImplementedError(model_name)
    return models.IncModel(cfg, trial_i, _run, ex, tensorboard, inc_dataset)
40 |
41 |
def get_data(cfg, trial_i):
    """Build the ``data.IncrementalDataset`` described by ``cfg`` for ``trial_i``.

    Reads the dataset, ordering, loader, split, increment, data-folder and
    start-class settings from ``cfg``. Shuffling is always enabled.
    """
    dataset_kwargs = dict(
        trial_i=trial_i,
        dataset_name=cfg["dataset"],
        random_order=cfg["random_classes"],
        shuffle=True,
        batch_size=cfg["batch_size"],
        workers=cfg["workers"],
        validation_split=cfg["validation"],
        resampling=cfg["resampling"],
        increment=cfg["increment"],
        data_folder=cfg["data_folder"],
        start_class=cfg["start_class"],
    )
    return data.IncrementalDataset(**dataset_kwargs)
57 |
58 |
def set_device(cfg):
    """Resolve ``cfg["device"]`` (a GPU index, or -1 for CPU) into a ``torch.device``.

    The config entry is overwritten in place with the resolved device, which
    is also returned.
    """
    index = cfg["device"]
    device = torch.device("cpu") if index == -1 else torch.device(f"cuda:{index}")
    cfg["device"] = device
    return device
69 |
--------------------------------------------------------------------------------
/inclearn/loss/tripeloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | import pdb
5 |
6 |
class TripletLossNoHardMining(nn.Module):
    """Triplet ranking loss without hard mining.

    For every anchor ``i`` and every ``j`` in ``range(num_instances - 1)``, the
    positive is the ``(j + 1)``-th same-label sample and the negative is the
    ``(j + 1)``-th different-label sample, both taken in batch order.

    Args:
        margin: margin of the underlying ``nn.MarginRankingLoss``.
        num_instances: every anchor needs at least ``num_instances``
            same-label samples (itself included) and ``num_instances``
            different-label samples in the batch, otherwise the indexing in
            ``forward`` raises. (Presumably the number of samples drawn per
            class; TODO confirm against the sampler.)

    ``forward`` returns ``(loss, prec, dist_p, dist_n)``:
    - the ranking loss;
    - the fraction of triplets with d(a, n) > d(a, p), as a tensor;
    - the mean positive distance and the mean negative distance, as Python floats.
    """

    def __init__(self, margin=0, num_instances=8):
        super(TripletLossNoHardMining, self).__init__()
        self.margin = margin
        self.num_instances = num_instances
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        n = inputs.size(0)
        # Pairwise Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        # Keyword beta/alpha: the positional addmm_(beta, alpha, ...) overload is deprecated.
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            # Loop-invariant per anchor: same-label / different-label distances.
            positives = dist[i][mask[i]]
            negatives = dist[i][mask[i] == 0]
            for j in range(self.num_instances - 1):
                dist_ap.append(positives[j + 1])
                dist_an.append(negatives[j + 1])
        dist_ap = torch.stack(dist_ap)
        dist_an = torch.stack(dist_an)
        # Ranking target y = 1: dist_an should be larger than dist_ap.
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        prec = (dist_an.detach() > dist_ap.detach()).sum() * 1. / y.size(0)
        # `.data[0]` on a 0-dim tensor raises IndexError in modern PyTorch; use .item().
        dist_p = torch.mean(dist_ap).item()
        dist_n = torch.mean(dist_an).item()
        return loss, prec, dist_p, dist_n
42 |
class TripletLoss(nn.Module):
    """Batch-hard triplet ranking loss.

    For each anchor, the hardest positive (the largest same-label distance,
    the anchor itself included) and the hardest negative (the smallest
    different-label distance) are used. Each anchor therefore needs at least
    one different-label sample in the batch.

    Args:
        margin: margin of the underlying ``nn.MarginRankingLoss``.
        num_instances: unused; kept for signature compatibility with
            ``TripletLossNoHardMining``.

    ``forward`` returns ``(loss, prec, dist_p, dist_n)``:
    - the ranking loss;
    - the fraction of anchors with d(a, n) > d(a, p), as a tensor;
    - the mean hardest-positive distance and the mean hardest-negative
      distance, as Python floats.
    """

    def __init__(self, margin=0, num_instances=None):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        n = inputs.size(0)
        # Pairwise Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        # Keyword beta/alpha: the positional addmm_(beta, alpha, ...) overload is deprecated.
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max())
            dist_an.append(dist[i][mask[i] == 0].min())
        dist_ap = torch.stack(dist_ap)
        dist_an = torch.stack(dist_an)
        # Ranking target y = 1: dist_an should be larger than dist_ap.
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        prec = (dist_an.detach() > dist_ap.detach()).sum() * 1. / y.size(0)
        dist_p = torch.mean(dist_ap).item()
        dist_n = torch.mean(dist_an).item()
        return loss, prec, dist_p, dist_n
--------------------------------------------------------------------------------
/inclearn/tools/scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 | from torch.optim.lr_scheduler import ReduceLROnPlateau
4 |
5 |
class ConstantTaskLR:
    """Task-level learning-rate policy that returns the same rate for every task."""

    def __init__(self, lr):
        self._lr = lr

    def get_lr(self, task_i):
        """Return the fixed learning rate; ``task_i`` is accepted but ignored."""
        fixed_lr = self._lr
        return fixed_lr
12 |
13 |
class CosineAnnealTaskLR:
    """Task-indexed learning rate following a half cosine from ``lr_max`` to ``lr_min``.

    ``get_lr(0)`` gives ``lr_max``, and ``get_lr(task_max)`` gives ``lr_min``.
    """

    def __init__(self, lr_max, lr_min, task_max):
        self._lr_max, self._lr_min, self._task_max = lr_max, lr_min, task_max

    def get_lr(self, task_i):
        """Return the cosine-annealed learning rate for task ``task_i``."""
        span = self._lr_max - self._lr_min
        cos_term = 1 + math.cos(math.pi * task_i / self._task_max)
        return self._lr_min + span * cos_term / 2
22 |
23 |
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate of an optimizer.

    Adapted from https://github.com/ildoonet/pytorch-gradual-warmup-lr
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0.
            If multiplier == 1.0, the lr starts from 0 and ends up at base_lr.
        total_epoch: the target learning rate is reached gradually at total_epoch.
        after_scheduler: scheduler used after total_epoch (e.g. ReduceLROnPlateau).

    Raises:
        ValueError: if ``multiplier`` < 1.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def _warmup_lrs(self):
        """Return the warm-up learning rates for the current ``last_epoch``.

        This is shared by ``get_lr`` and the ReduceLROnPlateau path so that
        both follow the documented ramp, including multiplier == 1.0 (start at 0).
        """
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        return [
            base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
            for base_lr in self.base_lrs
        ]

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand the warmed-up rates to the follow-up scheduler once.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        return self._warmup_lrs()

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of an epoch, whereas other
        # schedulers are called at the beginning.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            for param_group, lr in zip(self.optimizer.param_groups, self._warmup_lrs()):
                param_group['lr'] = lr
        else:
            # epoch is never None here (it was filled in above).
            self.after_scheduler.step(metrics, epoch - self.total_epoch)
        # Keep get_last_lr() usable on this path too.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
89 |
--------------------------------------------------------------------------------
/inclearn/learn/pretrain.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from inclearn.tools import factory, utils
6 | from inclearn.tools.metrics import ClassErrorMeter, AverageValueMeter
7 |
8 | # import line_profiler
9 | # import atexit
10 | # profile = line_profiler.LineProfiler()
11 | # atexit.register(profile.print_stats)
12 |
13 |
14 | def _compute_loss(cfg, logits, targets, device):
15 |
16 | if cfg["train_head"] == "sigmoid":
17 | n_classes = cfg["start_class"]
18 | onehot_targets = utils.to_onehot(targets, n_classes).to(device)
19 | loss = F.binary_cross_entropy_with_logits(logits, onehot_targets)
20 | elif cfg["train_head"] == "softmax":
21 | loss = F.cross_entropy(logits, targets)
22 | else:
23 | raise ValueError()
24 |
25 | return loss
26 |
27 |
def train(cfg, model, optimizer, device, train_loader):
    """Run one pre-training epoch over ``train_loader``.

    Each batch is forwarded through ``model._parallel_network`` (its ``'logit'``
    output is used). The loss comes from ``_compute_loss``. One optimizer
    step is taken per batch.

    Returns:
        (mean per-batch loss, ``accu.value()[0]``), both rounded to 3 decimals.
        The second value is presumably top-1 accuracy; confirm in ClassErrorMeter.
    """
    _loss = 0.0
    accu = ClassErrorMeter(accuracy=True)
    accu.reset()

    model.train()
    for i, (inputs, targets) in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model._parallel_network(inputs)['logit']
        accu.add(logits.detach(), targets)

        loss = _compute_loss(cfg, logits, targets, device)
        if torch.isnan(loss):
            # Debugging aid: drop into the debugger on a NaN loss.
            import pdb

            pdb.set_trace()

        loss.backward()
        optimizer.step()
        # Accumulate a detached value. `_loss += loss` chained every batch's
        # autograd graph into one growing graph for the whole epoch.
        _loss += loss.detach()

    return (
        round(_loss.item() / i, 3),
        round(accu.value()[0], 3),
    )
56 |
57 |
def test(cfg, model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` with gradients disabled.

    Returns:
        (mean per-batch loss, ``meter.value()[0]``), both rounded to 3 decimals.
        The second value is presumably top-1 accuracy; confirm in ClassErrorMeter.
    """
    running_loss = 0.0
    meter = ClassErrorMeter(accuracy=True)
    meter.reset()

    model.eval()
    with torch.no_grad():
        for n_batches, (inputs, targets) in enumerate(test_loader, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model._parallel_network(inputs)['logit']
            meter.add(logits.detach(), targets)
            batch_loss = _compute_loss(cfg, logits, targets, device)
            if torch.isnan(batch_loss):
                # Debugging aid: drop into the debugger on a NaN loss.
                import pdb
                pdb.set_trace()
            running_loss = running_loss + batch_loss
    return round(running_loss.item() / n_batches, 3), round(meter.value()[0], 3)
78 |
79 |
def pretrain(cfg, ex, model, device, train_loader, test_loader, model_path):
    """Pretrain the base network on the initial classes and save its weights.

    ``model._network`` is optimised with SGD (momentum 0.9) under a ``MultiStepLR``
    schedule, all configured from ``cfg["pretrain"]``. Evaluation on ``test_loader``
    runs on every 5th epoch starting with the first (0-based epochs 0, 5, 10, ...).
    The final ``state_dict`` is saved to ``model_path``; if the network has a
    ``module`` attribute (a wrapper such as ``DataParallel``) the wrapped module's
    weights are saved instead.
    """
    pre_cfg = cfg["pretrain"]
    n_epochs = pre_cfg["epochs"]
    ex.logger.info(f"nb Train {len(train_loader.dataset)} Eval {len(test_loader.dataset)}")

    optimizer = torch.optim.SGD(
        model._network.parameters(),
        lr=pre_cfg["lr"],
        momentum=0.9,
        weight_decay=pre_cfg["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, pre_cfg["scheduling"], gamma=pre_cfg["lr_decay"])

    for epoch in range(n_epochs):
        train_loss, train_acc = train(cfg, model, optimizer, device, train_loader)
        summary = "Pretrain Class {}, Epoch {}/{} => Clf Train loss: {}, Accu {}".format(
            cfg["start_class"], epoch + 1, n_epochs, train_loss, train_acc)
        if epoch % 5 == 0:
            test_loss, test_acc = test(cfg, model, device, test_loader)
            summary += " | Eval loss: {}, Accu {}".format(test_loss, test_acc)
        else:
            summary += " "
        ex.logger.info(summary)
        scheduler.step()

    network = model._network
    weights = network.module.state_dict() if hasattr(network, "module") else network.state_dict()
    torch.save(weights, model_path)
105 |
--------------------------------------------------------------------------------
/inclearn/convnet/modified_resnet_cifar.py:
--------------------------------------------------------------------------------
1 | """Taken & slightly modified from:
2 | * https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
3 | """
4 | import torch.nn as nn
5 | import torch.utils.model_zoo as model_zoo
6 | from torch.nn import functional as F
7 |
# Public API: only names actually defined in this module. The previous list named
# torchvision's resnet18..resnet152 factories, which do not exist here, so a
# ``from ... import *`` of this module raised AttributeError.
__all__ = ['ResNet', 'resnet20', 'resnet32']

# torchvision ImageNet checkpoint URLs kept from the upstream file.
# NOTE(review): nothing in this module reads this mapping (resnet20/resnet32 ignore
# ``pretrained``), and the architectures differ from the CIFAR ResNets below.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
17 |
18 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 (same spatial size at stride 1)."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
22 |
23 |
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (no padding)."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
27 |
28 |
class BasicBlock(nn.Module):
    """Two-convolution residual block: 3x3 conv-BN-ReLU, 3x3 conv-BN, plus shortcut.

    Args:
        inplanes: input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to build the shortcut.
        remove_last_relu: if True, skip the ReLU after the residual addition.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, remove_last_relu=False):
        super(BasicBlock, self).__init__()
        # Attribute names and creation order are kept so state_dict keys stay the same.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.remove_last_relu = remove_last_relu

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        if self.remove_last_relu:
            return residual
        return self.relu(residual)
60 |
61 |
class ResNet(nn.Module):
    """CIFAR-style ResNet backbone (three stages of 1x, 2x, 4x ``nf`` planes) returning pooled features.

    Args:
        block: residual block class (e.g. ``BasicBlock``); must expose ``expansion``.
        layers: number of blocks in each of the three stages.
        nf: number of planes in the first stage.
        dataset: unused; kept for factory-call compatibility.
        start_class: unused; kept for factory-call compatibility.
        remove_last_relu: if True, the last block of the last stage is built with
            ``remove_last_relu=True`` (previously this argument was silently ignored).
    """

    def __init__(self, block, layers, nf=16, dataset='cifar', start_class=0, remove_last_relu=False):
        super(ResNet, self).__init__()
        self.inplanes = nf
        self.conv1 = nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(nf)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 1 * nf, layers[0])
        self.layer2 = self._make_layer(block, 2 * nf, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 4 * nf, layers[2], stride=2, remove_last_relu=remove_last_relu)
        # Fixed 8x8 window: assumes 8x8 maps after layer3 (e.g. 32x32 inputs).
        self.avgpool = nn.AvgPool2d(8, stride=1)

        self.out_dim = 4 * nf * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, remove_last_relu=False):
        """Build one stage of ``blocks`` residual blocks.

        The first block applies ``stride`` and, when the shape changes, a 1x1 conv + BN
        projection shortcut. With ``remove_last_relu`` the final block (the first block
        when ``blocks == 1``) is built with ``remove_last_relu=True``; previously a
        single-block stage got an extra block in that case.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # A stage always contains at least its first (possibly strided) block, as before.
        n_blocks = max(blocks, 1)
        layers = []
        for idx in range(n_blocks):
            # Pass the keyword only when needed so blocks lacking it still work.
            extra = {"remove_last_relu": True} if remove_last_relu and idx == n_blocks - 1 else {}
            if idx == 0:
                layers.append(block(self.inplanes, planes, stride, downsample, **extra))
                self.inplanes = planes * block.expansion
            else:
                layers.append(block(self.inplanes, planes, **extra))

        return nn.Sequential(*layers)

    def reset_bn(self):
        """Reset the running statistics of every ``BatchNorm2d`` layer."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()

    def forward(self, x, pool=True):
        """Return flattened pooled features of shape ``(N, out_dim)``; ``pool`` is ignored."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x
122 |
123 |
def resnet20(pretrained=False, **kwargs):
    """Build a 20-layer CIFAR ResNet: three stages of three ``BasicBlock``s.

    ``pretrained`` is accepted for API compatibility and ignored; ``kwargs`` are
    forwarded to ``ResNet``.
    """
    blocks_per_stage = [3] * 3
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)
128 |
129 |
def resnet32(pretrained=False, **kwargs):
    """Build a 32-layer CIFAR ResNet: three stages of five ``BasicBlock``s.

    ``pretrained`` is accepted for API compatibility and ignored; ``kwargs`` are
    forwarded to ``ResNet``.
    """
    blocks_per_stage = [5] * 3
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)
134 |
--------------------------------------------------------------------------------
/inclearn/convnet/preact_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.

    Computes ``bn3(conv2(relu(bn2(conv1(relu(bn1(x)))))) + shortcut)`` followed by a
    ReLU unless ``remove_last_relu`` is set. The shortcut is a strided 1x1 conv of the
    pre-activated input when the shape changes, otherwise the raw input ``x``.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, remove_last_relu=False):
        super(PreActBlock, self).__init__()
        # Creation order kept identical so parameter initialisation is unchanged.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.remove_last_relu = remove_last_relu

        out_planes = self.expansion * planes
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False))

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        skip = self.shortcut(pre) if hasattr(self, 'shortcut') else x
        residual = self.conv1(pre)
        residual = self.conv2(F.relu(self.bn2(residual)))
        residual += skip
        normed = self.bn3(residual)
        return normed if self.remove_last_relu else F.relu(normed)
33 |
34 |
class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.

    Computes ``conv3(relu(bn3(conv2(relu(bn2(conv1(relu(bn1(x)))))))) + shortcut`` with
    ``expansion * planes`` output channels. No ReLU follows the residual addition, so
    ``remove_last_relu`` has no effect on the output. It is accepted and stored so that
    ``PreActResNet._make_layer`` can pass it uniformly; previously that raised TypeError.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, remove_last_relu=False):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.remove_last_relu = remove_last_relu

        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False))

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out
60 |
61 |
class PreActResNet(nn.Module):
    """Pre-activation ResNet backbone with four stages, returning pooled features.

    Args:
        block: block class (e.g. ``PreActBlock`` / ``PreActBottleneck``); must expose ``expansion``.
        num_blocks: number of blocks in each of the four stages.
        nf: planes of the first stage; stage ``k`` (0-based) uses ``2**k * nf`` planes.
        zero_init_residual: accepted but currently ignored.
        dataset: names containing ``'cifar'`` use a 3x3 stem and 4x4 pooling; other names
            use a 7x7/stride-2 stem plus max-pool, with 7x7 pooling for names containing
            ``'imagenet'`` and global average pooling otherwise.
        start_class: unused.
        remove_last_relu: build the final block of ``layer4`` with ``remove_last_relu=True``.
    """

    def __init__(self,
                 block,
                 num_blocks,
                 nf=64,
                 zero_init_residual=True,
                 dataset="cifar",
                 start_class=0,
                 remove_last_relu=False):
        super(PreActResNet, self).__init__()
        self.in_planes = nf
        self.dataset = dataset
        self.remove_last_relu = remove_last_relu

        if 'cifar' in dataset:
            self.conv1 = nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False)
        else:
            self.conv1 = nn.Sequential(nn.Conv2d(3, nf, kernel_size=7, stride=2, padding=3, bias=False),
                                       nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.layer1 = self._make_layer(block, 1 * nf, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 2 * nf, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 4 * nf, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 8 * nf, num_blocks[3], stride=2, remove_last_relu=remove_last_relu)
        # Include the block expansion so bottleneck variants report their real width.
        self.out_dim = 8 * nf * block.expansion

        if 'cifar' in dataset:
            self.avgpool = nn.AvgPool2d(4)
        elif 'imagenet' in dataset:
            self.avgpool = nn.AvgPool2d(7)
        else:
            # Previously no pooling layer was created here and forward() raised
            # AttributeError; global pooling yields the same (N, out_dim) features.
            self.avgpool = nn.AdaptiveAvgPool2d(1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # NOTE: ``zero_init_residual`` is ignored; zero-initialising the blocks'
        # BatchNorm weights (bn2 / bn3) is currently disabled.

    def _make_layer(self, block, planes, num_blocks, stride, remove_last_relu=False):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        With ``remove_last_relu`` the final block is built with ``remove_last_relu=True``.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        last_idx = len(strides) - 1
        layers = []
        for idx, block_stride in enumerate(strides):
            if remove_last_relu and idx == last_idx:
                layers.append(block(self.in_planes, planes, block_stride, remove_last_relu=True))
            else:
                layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return flattened pooled features of shape ``(N, out_dim)``."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        return out
132 |
133 |
def PreActResNet18(**kwargs):
    """Build a pre-activation ResNet-18 (``PreActBlock`` stages of depth 2-2-2-2)."""
    stage_depths = [2, 2, 2, 2]
    return PreActResNet(PreActBlock, stage_depths, **kwargs)
136 |
137 |
def PreActResNet34(**kwargs):
    """Build a pre-activation ResNet-34 (``PreActBlock`` stages of depth 3-4-6-3)."""
    stage_depths = [3, 4, 6, 3]
    return PreActResNet(PreActBlock, stage_depths, **kwargs)
140 |
141 |
def PreActResNet50(**kwargs):
    """Build a pre-activation ResNet-50 (``PreActBottleneck`` stages of depth 3-4-6-3)."""
    stage_depths = [3, 4, 6, 3]
    return PreActResNet(PreActBottleneck, stage_depths, **kwargs)
144 |
145 |
def PreActResNet101(**kwargs):
    """Build a pre-activation ResNet-101 (``PreActBottleneck`` stages of depth 3-4-23-3)."""
    stage_depths = [3, 4, 23, 3]
    return PreActResNet(PreActBottleneck, stage_depths, **kwargs)
148 |
149 |
def PreActResNet152(**kwargs):
    """Build a pre-activation ResNet-152 (``PreActBottleneck`` stages of depth 3-8-36-3)."""
    stage_depths = [3, 8, 36, 3]
    return PreActResNet(PreActBottleneck, stage_depths, **kwargs)
152 |
--------------------------------------------------------------------------------
/inclearn/convnet/cifar_resnet.py:
--------------------------------------------------------------------------------
1 | ''' Incremental-Classifier Learning
2 | Authors : Khurram Javed, Muhammad Talha Paracha
3 | Maintainer : Khurram Javed
4 | Lab : TUKL-SEECS R&D Lab
5 | Email : 14besekjaved@seecs.edu.pk '''
6 |
7 | import math
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 |
14 |
class DownsampleA(nn.Module):
    """Parameter-free zero-padding shortcut: stride-2 subsampling, then channel doubling.

    The output concatenates the subsampled input with an all-zero copy, so it has twice
    the input channels. ``nIn``/``nOut`` are accepted for interface compatibility but
    not used; only ``stride == 2`` is allowed.
    """

    def __init__(self, nIn, nOut, stride):
        super(DownsampleA, self).__init__()
        assert stride == 2
        # A 1x1 average pool with stride 2 simply keeps every other pixel.
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        subsampled = self.avg(x)
        padding_channels = subsampled.mul(0)
        return torch.cat((subsampled, padding_channels), 1)
24 |
25 |
class ResNetBasicblock(nn.Module):
    """
    ResNet basic block (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua)

    3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, added to the input (or to ``downsample(x)``
    when a downsample module is given) and passed through a final ReLU.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResNetBasicblock, self).__init__()

        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn_a = nn.BatchNorm2d(planes)

        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_b = nn.BatchNorm2d(planes)

        self.downsample = downsample
        # Not used within this class.
        self.featureSize = 64

    def forward(self, x):
        branch = F.relu(self.bn_a(self.conv_a(x)), inplace=True)
        branch = self.bn_b(self.conv_b(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(shortcut + branch, inplace=True)
57 |
58 |
class CifarResNet(nn.Module):
    """
    ResNet optimized for the Cifar Dataset, as specified in
    https://arxiv.org/abs/1512.03385.pdf

    A 3x3 stem with 16 channels, then three stages of ``(depth - 2) // 6`` blocks with
    16, 32 and 64 planes (stages 2 and 3 use stride 2 with ``DownsampleA`` shortcuts),
    then an 8x8 average pool. ``forward`` returns flattened features of size
    ``out_dim``; no classification layer is created.
    """

    def __init__(self, block, depth, num_classes, channels=3):
        """ Constructor
        Args:
            block: residual block class; must expose ``expansion``.
            depth: number of layers; ``depth - 2`` must be divisible by 6.
            num_classes: number of classes (only stored as ``self.num_classes``).
            channels: number of input image channels (3 for RGB, 1 for grayscale).
        """
        super(CifarResNet, self).__init__()

        self.featureSize = 64
        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        layer_blocks = (depth - 2) // 6

        self.num_classes = num_classes

        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        self.inplanes = 16
        self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
        self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
        # NOTE: the fixed 8x8 window assumes 8x8 maps after stage 3 (e.g. 32x32 inputs).
        self.avgpool = nn.AvgPool2d(8)
        self.out_dim = 64 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-normal init with fan = kernel area * output channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # ``init.kaiming_normal`` (no trailing underscore) is the deprecated alias.
                init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; the first block gets ``stride`` and a ``DownsampleA`` shortcut if the shape changes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, feature=False, T=1, labels=False, scale=None, keep=None):
        """Return flattened pooled features; the keyword arguments are accepted but ignored."""
        x = self.conv_1_3x3(x)
        x = F.relu(self.bn_1(x), inplace=True)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x

    def forwardFeature(self, x):
        """Placeholder; currently returns None."""
        pass
128 |
129 |
def resnet20(num_classes=10):
    """Build a ResNet-20 (three blocks per stage) for 3-channel inputs such as CIFAR-10.

    Args:
        num_classes (uint): number of classes, forwarded to ``CifarResNet``.
    """
    return CifarResNet(ResNetBasicblock, depth=20, num_classes=num_classes)
137 |
138 |
def resnet10mnist(num_classes=10):
    """ResNet-10 for 1-channel (MNIST-style) inputs, which this architecture cannot build.

    ``CifarResNet`` only accepts depths with ``(depth - 2) % 6 == 0`` (8, 14, 20, 26, 32, ...).
    The previous implementation therefore always failed the depth assertion, and under
    ``python -O`` it silently built an 8-layer network instead. This now fails explicitly.

    Args:
        num_classes (uint): number of classes (unused).

    Raises:
        ValueError: always; use ``resnet20mnist`` or ``resnet32mnist`` instead.
    """
    raise ValueError(
        "resnet10mnist is not supported: CifarResNet depth must satisfy (depth - 2) % 6 == 0 "
        "(e.g. 20, 32, 44, 56, 110); use resnet20mnist or resnet32mnist instead."
    )
146 |
147 |
def resnet20mnist(num_classes=10):
    """Build a ResNet-20 for single-channel (MNIST-style) inputs.

    Args:
        num_classes (uint): number of classes, forwarded to ``CifarResNet``.
    """
    return CifarResNet(ResNetBasicblock, depth=20, num_classes=num_classes, channels=1)
155 |
156 |
def resnet32mnist(num_classes=10, channels=1):
    """Build a ResNet-32 for single-channel (MNIST-style) inputs by default.

    Args:
        num_classes (uint): number of classes, forwarded to ``CifarResNet``.
        channels (int): number of input channels.
    """
    return CifarResNet(ResNetBasicblock, depth=32, num_classes=num_classes, channels=channels)
160 |
161 |
def resnet32(num_classes=10):
    """Build a ResNet-32 (five blocks per stage) for 3-channel inputs such as CIFAR-10.

    Args:
        num_classes (uint): number of classes, forwarded to ``CifarResNet``.
    """
    return CifarResNet(ResNetBasicblock, depth=32, num_classes=num_classes)
169 |
170 |
def resnet44(num_classes=10):
    """Build a ResNet-44 (seven blocks per stage) for 3-channel inputs such as CIFAR-10.

    Args:
        num_classes (uint): number of classes, forwarded to ``CifarResNet``.
    """
    return CifarResNet(ResNetBasicblock, depth=44, num_classes=num_classes)
178 |
179 |
def resnet56(num_classes=10):
    """Build a ResNet-56 (nine blocks per stage) for 3-channel inputs such as CIFAR-10.

    Args:
        num_classes (uint): number of classes, forwarded to ``CifarResNet``.
    """
    return CifarResNet(ResNetBasicblock, depth=56, num_classes=num_classes)
187 |
188 |
def resnet110(num_classes=10):
    """Build a ResNet-110 (eighteen blocks per stage) for 3-channel inputs such as CIFAR-10.

    Args:
        num_classes (uint): number of classes, forwarded to ``CifarResNet``.
    """
    return CifarResNet(ResNetBasicblock, depth=110, num_classes=num_classes)
196 |
--------------------------------------------------------------------------------
/inclearn/models/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from inclearn.tools.metrics import ClassErrorMeter
7 |
# Module logger. ``logging.getLogger`` registers it in the logging hierarchy, so
# handlers configured on ancestor loggers receive its records. A directly constructed
# ``logging.Logger`` has no parent and falls back to the WARNING-level last-resort
# handler, so its INFO messages were never emitted.
LOGGER = logging.getLogger("IncLearn")
LOGGER.setLevel(logging.INFO)
9 |
10 |
class IncrementalLearner(abc.ABC):
    """Base incremental learner.

    Methods are called in this order (& repeated for each new task):

    1. set_task_info
    2. before_task
    3. train_task
    4. after_task
    5. eval_task

    Subclasses provide ``eval``/``train`` and the underscore-prefixed hooks. The logging
    helpers ``_after_epoch`` and ``_validation`` read attributes that subclasses must set:
    ``_run``, ``_tensorboard``, ``_trial_i``, ``_ex``, ``_n_classes``,
    ``_val_per_n_epoch``, ``_parallel_network`` and ``_forward_loss``.
    """

    def __init__(self, *args, **kwargs):
        self._increments = []
        self._seen_classes = []

    def set_task_info(self, task, total_n_classes, increment, n_train_data, n_test_data, n_tasks):
        """Record metadata for the task about to be learned; ``increment`` is appended to ``_increments``."""
        self._task = task
        self._task_size = increment
        self._increments.append(self._task_size)
        self._total_n_classes = total_n_classes
        self._n_train_data = n_train_data
        self._n_test_data = n_test_data
        self._n_tasks = n_tasks

    def before_task(self, taski, inc_dataset):
        LOGGER.info("Before task")
        self.eval()
        self._before_task(taski, inc_dataset)

    def train_task(self, train_loader, val_loader):
        LOGGER.info("train task")
        self.train()
        self._train_task(train_loader, val_loader)

    def after_task(self, taski, inc_dataset):
        LOGGER.info("after task")
        self.eval()
        self._after_task(taski, inc_dataset)

    def eval_task(self, data_loader):
        LOGGER.info("eval task")
        self.eval()
        return self._eval_task(data_loader)

    def get_memory(self):
        return None

    def eval(self):
        raise NotImplementedError

    def train(self):
        raise NotImplementedError

    def _before_task(self, taski=None, inc_dataset=None):
        """Hook run before training a task; no-op by default.

        Takes the same arguments ``before_task`` passes (the old single-parameter
        default raised TypeError whenever a subclass did not override it).
        """

    def _train_task(self, train_loader, val_loader):
        raise NotImplementedError

    def _after_task(self, taski=None, inc_dataset=None):
        """Hook run after training a task; no-op by default (same arguments as ``after_task``)."""

    def _eval_task(self, data_loader):
        raise NotImplementedError

    @property
    def _new_task_index(self):
        """Index of the current task's first class, ``task * task_size``.

        NOTE(review): this assumes every task added the same number of classes.
        """
        return self._task * self._task_size

    @property
    def _memory_per_class(self):
        """Per-class exemplar budget (``self._memory_size.mem_per_cls``)."""
        return self._memory_size.mem_per_cls

    def _after_epoch(self, epoch, avg_loss, train_new_accu, train_old_accu, accu):
        """Log an epoch's training loss and top-1 accuracy at step ``epoch + 1``.

        ``train_new_accu`` / ``train_old_accu`` are accepted for interface compatibility
        but are not logged.
        """
        tag = f"trial{self._trial_i}_task{self._task}"
        step = epoch + 1
        train_accu = accu.value()[0]
        self._run.log_scalar(f"train_loss_{tag}", avg_loss, step)
        self._tensorboard.add_scalar(f"{tag}/train_loss", avg_loss, step)
        self._run.log_scalar(f"train_accu_{tag}", train_accu, step)
        self._tensorboard.add_scalar(f"{tag}/train_accu", train_accu, step)
        self._tensorboard.flush()

    def _validation(self, val_loader, epoch):
        """Evaluate on ``val_loader`` every ``self._val_per_n_epoch`` epochs and log metrics.

        Does nothing when ``_val_per_n_epoch == -1`` or ``epoch`` is not a multiple of it.
        Logs top-1 and top-k accuracy (k = min(5, n_classes)), accuracy on the current
        task's classes, accuracy on earlier classes (from the second task on) and the
        mean batch loss to ``self._run`` and ``self._tensorboard``.
        """
        if self._val_per_n_epoch == -1 or epoch % self._val_per_n_epoch != 0:
            return

        topk = 5 if self._n_classes >= 5 else self._n_classes
        first_new_class = self._n_classes - self._task_size
        _val_loss = 0
        n_batches = 0
        _val_accu = ClassErrorMeter(accuracy=True, topk=[1, topk])
        _val_new_accu = ClassErrorMeter(accuracy=True)
        _val_old_accu = ClassErrorMeter(accuracy=True)
        self._parallel_network.eval()
        with torch.no_grad():
            for n_batches, (inputs, targets) in enumerate(val_loader, 1):
                old_classes = targets < first_new_class
                new_classes = targets >= first_new_class
                val_loss, _ = self._forward_loss(
                    inputs,
                    targets,
                    old_classes,
                    new_classes,
                    accu=_val_accu,
                    old_accu=_val_old_accu,
                    new_accu=_val_new_accu,
                )
                _val_loss += val_loss.item()
        self._ex.logger.info(
            f"epoch{epoch} val acc:{_val_accu.value()[0]:.2f}, val top5acc:{_val_accu.value()[1]:.2f}")

        tag = f"trial{self._trial_i}_task{self._task}"
        step = epoch + 1
        # Test accu
        self._run.log_scalar(f"test_accu_{tag}", _val_accu.value()[0], step)
        self._run.log_scalar(f"test_5accu_{tag}", _val_accu.value()[1], step)
        self._tensorboard.add_scalar(f"{tag}/test_accu", _val_accu.value()[0], step)
        self._tensorboard.add_scalar(f"{tag}/test_5accu", _val_accu.value()[1], step)

        # Test new accu
        self._run.log_scalar(f"test_new_accu_{tag}", _val_new_accu.value()[0], step)
        self._tensorboard.add_scalar(f"{tag}/test_new_accu", _val_new_accu.value()[0], step)

        # Test old accu (there are no old classes during the first task)
        if self._task != 0:
            self._run.log_scalar(f"test_old_accu_{tag}", _val_old_accu.value()[0], step)
            self._tensorboard.add_scalar(f"{tag}/test_old_accu", _val_old_accu.value()[0], step)

        # Test loss; NaN rather than an unbound-name error when the loader was empty.
        mean_loss = round(_val_loss / n_batches, 3) if n_batches else float("nan")
        self._run.log_scalar(f"test_loss_{tag}", mean_loss, step)
        self._tensorboard.add_scalar(f"{tag}/test_loss", mean_loss, step)
        # Flush (as _after_epoch does) instead of closing: a closed SummaryWriter is
        # silently reopened on the next add_scalar, producing a new event file each time.
        self._tensorboard.flush()
--------------------------------------------------------------------------------
/inclearn/tools/memory.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | import torch
4 | from torch.nn import functional as F
5 |
6 | from inclearn.tools.utils import get_class_loss
7 | from inclearn.convnet.utils import extract_features
8 |
9 |
class MemorySize:
    """Bookkeeping for the exemplar-memory budget.

    Supported modes (case-insensitive):
      * ``"uniform_fixed_per_cls"``: every class keeps ``fixed_memory_per_cls`` exemplars.
      * ``"uniform_fixed_total_mem"``: ``total_memory`` is split evenly (floor) across classes.
      * ``"dynamic_fixed_per_cls"``: ``update_memory_per_cls`` only recomputes the per-class
        budget when ``n_classes == task_size`` (the first task).
        NOTE(review): later tasks leave ``mem_per_cls`` unchanged here; confirm the caller
        extends it.
    """

    _MODES = ("uniform_fixed_per_cls", "uniform_fixed_total_mem", "dynamic_fixed_per_cls")

    def __init__(self, mode, inc_dataset, total_memory=None, fixed_memory_per_cls=None):
        normalized = mode.lower()
        if normalized not in self._MODES:
            # Explicit error: an ``assert`` would be stripped under ``python -O``.
            raise ValueError(f"Unknown memory mode {mode!r}; expected one of {self._MODES}.")
        # Store the normalized form: the substring checks below are case-sensitive.
        self.mode = normalized
        self.total_memory = total_memory
        self.fixed_memory_per_cls = fixed_memory_per_cls
        self._n_classes = 0
        self.mem_per_cls = []
        self._inc_dataset = inc_dataset

    def update_n_classes(self, n_classes):
        self._n_classes = n_classes

    def update_memory_per_cls_uniform(self, n_classes):
        """Set and return a uniform per-class budget for ``n_classes`` classes."""
        if "fixed_per_cls" in self.mode:
            self.mem_per_cls = [self.fixed_memory_per_cls] * n_classes
        elif "fixed_total_mem" in self.mode:
            self.mem_per_cls = [self.total_memory // n_classes] * n_classes
        return self.mem_per_cls

    def update_memory_per_cls(self, network, n_classes, task_size):
        """Refresh ``mem_per_cls``; ``network`` is unused."""
        if "uniform" in self.mode or n_classes == task_size:
            self.update_memory_per_cls_uniform(n_classes)

    @property
    def memsize(self):
        """Total exemplar budget for the current ``_n_classes``.

        Uses the same substring matching as ``update_memory_per_cls_uniform``; the old
        exact comparisons against "fixed_total_mem"/"fixed_per_cls" could never match
        an accepted mode, so this always returned None.
        """
        if "fixed_total_mem" in self.mode:
            return self.total_memory
        if "fixed_per_cls" in self.mode:
            return self.fixed_memory_per_cls * self._n_classes
        return None
43 |
44 |
def compute_examplar_mean(feat_norm, feat_flip, herding_mat, nb_max):
    """Return the L2-normalised mean of the selected exemplars' features.

    Rows of ``feat_norm`` / ``feat_flip`` are treated as samples (they are transposed and
    each column is L2-normalised). ``herding_mat`` holds each sample's selection rank;
    samples with rank in ``1..nb_max`` are selected. The returned mean averages the
    selected normalised features of both inputs and is then L2-normalised.

    Returns:
        (mean, selected): the normalised mean vector and a 0/1 float mask over samples.
    """
    eps = 1e-8

    def _unit_columns(mat):
        return mat / (np.linalg.norm(mat, axis=0) + eps)

    feats = _unit_columns(feat_norm.T)
    feats_flipped = _unit_columns(feat_flip.T)

    selected = (herding_mat > 0) * (herding_mat < nb_max + 1) * 1.0
    weights = selected / np.sum(selected)

    mean = (feats.dot(weights) + feats_flipped.dot(weights)) / 2
    mean = mean / (np.linalg.norm(mean) + eps)
    return mean, selected
63 |
64 |
def select_examplars(features, nb_max):
    """Greedy herding selection of up to ``nb_max`` exemplars from ``features`` rows.

    Each row is L2-normalised. Samples are picked one at a time by maximising the dot
    product with a running direction that is corrected towards the class mean after
    every pick. The search stops once ``min(nb_max, n_samples)`` distinct samples are
    chosen or after 1000 steps.

    Returns:
        (herding_matrix, idxes): ``herding_matrix[j]`` is the 1-based pick order of
        sample ``j`` (0 if never picked); ``idxes`` lists picked indices in pick order.
    """
    eps = 1e-8
    n_samples = features.shape[0]
    unit = features.T
    unit = unit / (np.linalg.norm(unit, axis=0) + eps)
    class_mean = np.mean(unit, axis=1)

    ranks = np.zeros((n_samples, ))
    chosen = []
    target = min(nb_max, n_samples)
    direction = class_mean
    max_steps = 1000
    steps = 0

    # len(chosen) always equals the number of non-zero ranks.
    while len(chosen) != target and steps < max_steps:
        best = np.argmax(np.dot(direction, unit))
        steps += 1
        if ranks[best] == 0:
            ranks[best] = 1 + len(chosen)
            chosen.append(best)
        direction = direction + class_mean - unit[:, best]

    return ranks, chosen
90 |
91 |
def random_selection(n_classes, task_size, network, logger, inc_dataset, memory_per_class: list):
    """Uniformly sample exemplars for every class seen so far.

    Old-class data (class index < ``n_classes - task_size``) is read with
    ``inc_dataset.get_custom_loader_from_memory``; new-class data with
    ``inc_dataset.get_custom_loader(..., mode="test")``. Each class keeps
    ``min(memory_per_class[c], available)`` samples drawn without replacement via
    ``np.random.choice``. ``network`` is unused.

    Returns:
        (data_memory, targets_memory) concatenated over classes in class order.
    """
    # TODO: Move data_memory / targets_memory into IncDataset.
    logger.info("Building & updating memory.(Random Selection)")
    assert len(memory_per_class) == n_classes
    first_new_class = n_classes - task_size

    def _class_pool(cls):
        if cls < first_new_class:
            return inc_dataset.get_custom_loader_from_memory([cls])
        return inc_dataset.get_custom_loader(cls, mode="test")

    picked_data, picked_targets = [], []
    for cls in range(n_classes):
        data, labels, _ = _class_pool(cls)
        n_available = data.shape[0]
        n_keep = min(memory_per_class[cls], n_available)
        chosen = np.random.choice(n_available, n_keep, replace=False)
        picked_data.append(data[chosen])
        picked_targets.append(labels[chosen])
    return np.concatenate(picked_data), np.concatenate(picked_targets)
110 |
111 |
def herding(n_classes, task_size, network, herding_matrix, inc_dataset, shared_data_inc, memory_per_class: list,
            logger):
    """Rebuild the exemplar memory with herding selection.

    For every new class (index >= ``n_classes - task_size``), features are extracted
    from that class's ``inc_dataset.data_inc`` rows, together with those entries of
    ``shared_data_inc`` whose index exists. A herding ranking is appended to
    ``herding_matrix``. Old classes reuse their stored ranking. Rows whose rank lies in
    ``1..memory_per_class[c]`` are then taken from ``inc_dataset.data_train``.

    NOTE(review): rankings are computed on ``data_inc`` rows but applied to
    ``data_train`` rows of the same class. This assumes both hold the same samples in
    the same order; confirm against IncDataset.

    Args:
        herding_matrix: per-class rankings indexed by class; mutated in place (new
            classes are appended) and also returned.

    Returns:
        (data_memory, targets_memory, herding_matrix).
    """
    logger.info("Building & updating memory.(iCaRL)")
    first_new_class = n_classes - task_size
    kept_data, kept_targets = [], []

    for cls in range(n_classes):
        train_mask = inc_dataset.targets_train == cls
        cls_inputs = inc_dataset.data_train[train_mask]
        cls_targets = inc_dataset.targets_train[train_mask]

        if cls >= first_new_class:
            inc_mask = inc_dataset.targets_inc == cls
            n_shared = len(shared_data_inc)
            # Only indices that exist in shared_data_inc contribute shared samples.
            share_memory = [shared_data_inc[j] for j in np.where(inc_mask)[0].tolist() if j < n_shared]
            loader = inc_dataset._get_loader(inc_dataset.data_inc[inc_mask],
                                             inc_dataset.targets_inc[inc_mask],
                                             share_memory=share_memory,
                                             batch_size=128,
                                             shuffle=False,
                                             mode="test")
            features, _ = extract_features(network, loader)
            herding_matrix.append(select_examplars(features, memory_per_class[cls])[0])

        ranks = herding_matrix[cls]
        keep_rows = np.where((ranks > 0) * (ranks < memory_per_class[cls] + 1))[0]
        kept_data.append(cls_inputs[keep_rows])
        kept_targets.append(cls_targets[keep_rows])

    return np.concatenate(kept_data), np.concatenate(kept_targets), herding_matrix
156 |
--------------------------------------------------------------------------------
/inclearn/prune/prune/structured.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from copy import deepcopy
4 | from functools import reduce
5 | from operator import mul
6 |
# Pruning primitives exported by ``from ... import *``.
__all__ = [
    'prune_conv',
    'prune_related_conv',
    'prune_linear',
    'prune_related_linear',
    'prune_batchnorm',
    'prune_prelu',
    'prune_group_conv',
]
8 |
def prune_group_conv(layer: nn.modules.conv._ConvNd, idxs: list, inplace: bool=True, dry_run: bool=False):
    """Prune channels of a depthwise convolution, e.g. [256 x 1 x 3 x 3] => [192 x 1 x 3 x 3]

    In a depthwise conv (``groups == in_channels == out_channels``) each channel has its
    own filter. Removing channel ``i`` therefore removes filter ``i`` and input channel
    ``i`` together, and ``in_channels``, ``out_channels`` and ``groups`` all shrink by the
    number of distinct indices.

    Args:
        - layer: a depthwise convolution layer.
        - idxs: channel indices to remove (duplicates are ignored).
        - inplace: prune ``layer`` itself; otherwise prune and return a deep copy.
        - dry_run: only count the parameters that would be removed; ``layer`` is untouched.

    Returns:
        (layer, num_pruned): the (possibly copied) layer and the number of removed parameters.

    Raises:
        ValueError: if ``layer`` is not depthwise. ``groups == 1`` convolutions are rejected
            too, because the bookkeeping below would drive ``groups`` to zero or below.
    """
    if not (layer.groups == layer.in_channels == layer.out_channels):
        raise ValueError("only group conv with in_channel==groups==out_channels is supported")

    idxs = list(set(idxs))
    num_pruned = len(idxs) * reduce(mul, layer.weight.shape[1:]) + (len(idxs) if layer.bias is not None else 0)
    if dry_run:
        return layer, num_pruned
    if not inplace:
        layer = deepcopy(layer)

    pruned = set(idxs)  # O(1) membership instead of scanning the list per channel
    keep_idxs = [idx for idx in range(layer.out_channels) if idx not in pruned]
    layer.out_channels = layer.out_channels - len(idxs)
    layer.in_channels = layer.in_channels - len(idxs)
    layer.groups = layer.groups - len(idxs)
    layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs])
    if layer.bias is not None:
        layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])

    return layer, num_pruned
37 |
def prune_conv(layer: nn.modules.conv._ConvNd, idxs: list, inplace: bool=True, dry_run: bool=False):
    """Prune output `filters` of a convolutional layer, e.g. [256 x 128 x 3 x 3] => [192 x 128 x 3 x 3].

    Works for regular and transposed convolutions of any dimensionality
    (1d/2d/3d). Transposed convolutions store their weight as
    ``(in, out / groups, *kernel)``, so their output axis is dim 1.

    Args:
        - layer: a convolution layer.
        - idxs: indices of the output channels to remove (duplicates are ignored).
        - inplace: prune ``layer`` itself; otherwise a deep copy is pruned.
        - dry_run: only count the parameters that would be removed.

    Returns:
        ``(layer, num_pruned)`` -- the (possibly copied) layer and the number
        of weight + bias parameters removed.
    """
    pruned = set(idxs)
    n_removed = len(pruned)
    # `transposed` is set by every _ConvNd; it also covers ConvTranspose1d,
    # which the former isinstance(ConvTranspose2d/3d) check missed.
    transposed = getattr(layer, "transposed", False)
    # Weights owned by one output channel, valid for both weight layouts.
    params_per_filter = layer.weight.numel() // layer.out_channels
    num_pruned = n_removed * params_per_filter + (n_removed if layer.bias is not None else 0)
    if dry_run:
        return layer, num_pruned

    if not inplace:
        layer = deepcopy(layer)

    keep_idxs = [idx for idx in range(layer.out_channels) if idx not in pruned]
    layer.out_channels = layer.out_channels - n_removed
    if transposed:
        layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs])
    else:
        layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs])
    if layer.bias is not None:
        layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])
    return layer, num_pruned
62 |
def prune_related_conv(layer: nn.modules.conv._ConvNd, idxs: list, inplace: bool=True, dry_run: bool=False):
    """Prune input `kernels` of a related (affected) convolutional layer, e.g. [256 x 128 x 3 x 3] => [256 x 96 x 3 x 3].

    Works for regular and transposed convolutions of any dimensionality.
    Regular convs keep input channels on weight dim 1; transposed convs keep
    them on dim 0.

    Args:
        layer: a convolutional layer.
        idxs: indices of the input channels to remove (duplicates are ignored).
        inplace: prune ``layer`` itself; otherwise a deep copy is pruned.
        dry_run: only count the parameters that would be removed.

    Returns:
        ``(layer, num_pruned)`` -- the (possibly copied) layer and the number
        of weight parameters removed.
    """
    pruned = set(idxs)
    n_removed = len(pruned)
    transposed = getattr(layer, "transposed", False)
    # Weights attached to one input channel, valid for both weight layouts.
    params_per_channel = layer.weight.numel() // layer.in_channels
    num_pruned = n_removed * params_per_channel
    if dry_run:
        return layer, num_pruned
    if not inplace:
        layer = deepcopy(layer)

    keep_idxs = [i for i in range(layer.in_channels) if i not in pruned]
    layer.in_channels = layer.in_channels - n_removed
    if transposed:
        layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs, :])
    else:
        layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs])
    # no bias pruning because it does not change the output size
    return layer, num_pruned
88 |
def prune_linear(layer: nn.modules.linear.Linear, idxs: list, inplace: bool=True, dry_run: bool=False):
    """Prune output neurons of a fully-connected layer, e.g. [256 x 128] => [192 x 128].

    Args:
        layer: a fully-connected layer.
        idxs: indices of the output neurons to remove. Duplicates are ignored,
            as in the convolution pruners; previously they shrank
            ``out_features`` more than the weight.
        inplace: prune ``layer`` itself; otherwise a deep copy is pruned.
        dry_run: only count the parameters that would be removed.

    Returns:
        ``(layer, num_pruned)`` -- the (possibly copied) layer and the number
        of weight + bias parameters removed.
    """
    pruned = set(idxs)
    n_removed = len(pruned)
    num_pruned = n_removed * layer.weight.shape[1] + (n_removed if layer.bias is not None else 0)
    if dry_run:
        return layer, num_pruned

    if not inplace:
        layer = deepcopy(layer)
    keep_idxs = [i for i in range(layer.out_features) if i not in pruned]
    layer.out_features = layer.out_features - n_removed
    layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs])
    if layer.bias is not None:
        layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])
    return layer, num_pruned
108 |
def prune_related_linear(layer: nn.modules.linear.Linear, idxs: list, inplace: bool=True, dry_run: bool=False):
    """Prune input weights of a related (affected) fully-connected layer, e.g. [256 x 128] => [256 x 96].

    Args:
        layer: a fully-connected layer.
        idxs: indices of the input features to remove (duplicates are ignored).
        inplace: prune ``layer`` itself; otherwise a deep copy is pruned.
        dry_run: only count the parameters that would be removed.

    Returns:
        ``(layer, num_pruned)`` -- the (possibly copied) layer and the number
        of weight parameters removed. The bias is untouched because the
        output size does not change.
    """
    pruned = set(idxs)
    n_removed = len(pruned)
    num_pruned = n_removed * layer.weight.shape[0]
    if dry_run:
        return layer, num_pruned

    if not inplace:
        layer = deepcopy(layer)
    keep_idxs = [i for i in range(layer.in_features) if i not in pruned]
    layer.in_features = layer.in_features - n_removed
    layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs])
    return layer, num_pruned
126 |
def prune_batchnorm(layer: nn.modules.batchnorm._BatchNorm, idxs: list, inplace: bool=True, dry_run: bool=False):
    """Prune batch normalization layers, e.g. [128] => [64].

    Running statistics are sliced when present. With
    ``track_running_stats=False`` they are ``None`` and are left as is.

    Args:
        layer: a batch normalization layer.
        idxs: indices of the channels to remove (duplicates are ignored).
        inplace: prune ``layer`` itself; otherwise a deep copy is pruned.
        dry_run: only compute the count without modifying anything.

    Returns:
        ``(layer, num_pruned)`` -- the (possibly copied) layer and
        ``len(unique idxs) * (2 if affine else 1)``.
    """
    pruned = set(idxs)
    n_removed = len(pruned)
    num_pruned = n_removed * (2 if layer.affine else 1)
    if dry_run:
        return layer, num_pruned

    if not inplace:
        layer = deepcopy(layer)

    keep_idxs = [i for i in range(layer.num_features) if i not in pruned]
    layer.num_features = layer.num_features - n_removed
    # Assigning a tensor to a registered buffer name keeps it registered.
    if layer.running_mean is not None:
        layer.running_mean = layer.running_mean.data.clone()[keep_idxs]
    if layer.running_var is not None:
        layer.running_var = layer.running_var.data.clone()[keep_idxs]
    if layer.affine:
        layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs])
        layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])
    return layer, num_pruned
150 |
def prune_prelu(layer: nn.PReLU, idxs: list, inplace: bool=True, dry_run: bool=False):
    """Prune PReLU layers, e.g. [128] => [64] or [1] => [1] (no pruning if prelu has only 1 parameter).

    Args:
        layer: a PReLU layer.
        idxs: indices of the channels to remove (duplicates are ignored).
        inplace: prune ``layer`` itself; otherwise a deep copy is pruned.
        dry_run: only count the parameters that would be removed.

    Returns:
        ``(layer, num_pruned)`` -- the (possibly copied) layer and the number
        of slope parameters removed (0 for a shared single-slope PReLU).
    """
    pruned = set(idxs)
    num_pruned = 0 if layer.num_parameters == 1 else len(pruned)
    if dry_run:
        return layer, num_pruned
    if not inplace:
        layer = deepcopy(layer)
    if layer.num_parameters == 1:
        # A single shared slope applies to every channel, so there is nothing to slice.
        return layer, num_pruned
    keep_idxs = [i for i in range(layer.num_parameters) if i not in pruned]
    layer.num_parameters = layer.num_parameters - len(pruned)
    layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs])
    return layer, num_pruned
168 |
169 |
170 |
--------------------------------------------------------------------------------
/inclearn/convnet/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.optim import SGD
5 | import torch.nn.functional as F
6 | from inclearn.tools.metrics import ClassErrorMeter, AverageValueMeter
7 | from inclearn.loss.focalloss import BCEFocalLoss, MultiCEFocalLoss
8 |
9 |
def finetune_last_layer(
    logger,
    network,
    loader,
    n_class,
    nepoch=30,
    lr=0.1,
    scheduling=None,
    lr_decay=0.1,
    weight_decay=5e-4,
    loss_type="ce",
    temperature=5.0,
    test_loader=None,
    samples_per_cls=None,
):
    """Fine-tune only ``network.module.classifier`` for ``nepoch`` epochs.

    The whole network is kept in eval mode. Only the classifier parameters are
    handed to the SGD optimizer, whose learning rate is decayed by ``lr_decay``
    at the ``scheduling`` epochs (default ``[15, 35]``).

    Args:
        logger: object with an ``info`` method used for progress messages.
        network: model wrapped so that ``network.module.classifier`` exists
            (e.g. ``DataParallel``); its forward must return a dict with ``'logit'``.
        loader: training loader yielding ``(inputs, targets)``.
        n_class: number of classes (used for one-hot / class-balanced losses).
        loss_type: ``"ce"`` (cross-entropy), ``"bce"`` (BCE on one-hot targets),
            ``"cb"`` (class-balanced focal loss via ``CB_loss``); any other value
            uses BCE-with-logits on the raw targets.
        temperature: logits are divided by this before the loss.
        test_loader: optional loader evaluated after every epoch.
        samples_per_cls: per-class sample counts for ``"cb"`` (default ``[]``).

    Returns:
        The same ``network`` object.
    """
    if scheduling is None:
        scheduling = [15, 35]
    if samples_per_cls is None:
        samples_per_cls = []

    network.eval()
    optim = SGD(network.module.classifier.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, scheduling, gamma=lr_decay)

    if loss_type == "ce":
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    logger.info("Begin finetuning last layer")

    for i in range(nepoch):
        total_loss = 0.0
        total_correct = 0
        total_count = 0
        for inputs, targets in loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            # Keep integer labels for accuracy even when the loss needs one-hot targets.
            labels = targets
            if loss_type == "bce":
                targets = F.one_hot(targets.long(), n_class).float()
            outputs = network(inputs)['logit']
            _, preds = outputs.max(1)
            optim.zero_grad()
            if loss_type == "cb":
                loss = CB_loss(targets, outputs / temperature, samples_per_cls, n_class, "focal")
            else:
                loss = criterion(outputs / temperature, targets)
            loss.backward()
            optim.step()
            # .item() detaches from the graph so the epoch total does not keep every batch's graph alive.
            total_loss += loss.item() * inputs.size(0)
            total_correct += (preds == labels).sum().item()
            total_count += inputs.size(0)

        if test_loader is not None:
            test_correct = 0.0
            test_count = 0.0
            with torch.no_grad():
                for inputs, targets in test_loader:
                    outputs = network(inputs.cuda())['logit']
                    _, preds = outputs.max(1)
                    test_correct += (preds.cpu() == targets).sum().item()
                    test_count += inputs.size(0)

        scheduler.step()
        denom = max(total_count, 1)  # guard against an empty training loader
        if test_loader is not None:
            logger.info("Epoch %d finetuning loss %.3f acc %.3f Eval %.3f",
                        i, total_loss / denom, total_correct / denom, test_correct / max(test_count, 1))
        else:
            logger.info("Epoch %d finetuning loss %.3f acc %.3f",
                        i, total_loss / denom, total_correct / denom)
    return network
82 |
def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta=0.99, gamma=2.0):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.

    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.
    The per-class weights are normalised so that they sum to ``no_of_classes``.

    Args:
        labels: A int tensor of size [batch].
        logits: A float tensor of size [batch, no_of_classes].
        samples_per_cls: A python list of size [no_of_classes].
        no_of_classes: total number of classes. int
        loss_type: string. One of "sigmoid", "focal", "softmax".
        beta: float. Hyperparameter for Class balanced loss.
        gamma: float. Hyperparameter for Focal loss.

    Returns:
        cb_loss: A float tensor representing class balanced loss

    Raises:
        ValueError: for an unknown ``loss_type``.
    """
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    weights = weights / np.sum(weights) * no_of_classes
    weights = torch.tensor(weights).float()

    # Build the one-hot targets on the logits' device so BCE does not mix CPU/GPU tensors.
    labels_one_hot = F.one_hot(labels, no_of_classes).float().to(logits.device)

    if loss_type == "focal":
        focalloss = MultiCEFocalLoss(class_num=no_of_classes, gamma=gamma, alpha=weights).cuda()
        return focalloss(logits, labels)
    if loss_type == "sigmoid":
        # The keyword is `weight` (singular); `weights=` raised TypeError.
        return F.binary_cross_entropy_with_logits(
            input=logits, target=labels_one_hot, weight=weights.to(logits.device))
    if loss_type == "softmax":
        pred = logits.softmax(dim=1)
        return F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights.to(logits.device))
    raise ValueError("unknown loss_type %r; expected 'sigmoid', 'focal' or 'softmax'" % (loss_type,))
124 |
def extract_features(model, loader):
    """Run ``model`` in eval mode over ``loader`` and stack its outputs.

    Each batch's inputs are moved to CUDA. The model's ``'feature'`` output and
    the batch targets are gathered as numpy arrays.

    Returns:
        ``(features, targets)``: both concatenated along axis 0.
    """
    feature_chunks = []
    target_chunks = []
    model.eval()
    with torch.no_grad():
        for batch_inputs, batch_targets in loader:
            batch_inputs = batch_inputs.cuda()
            batch_targets = batch_targets.numpy()
            feature_chunks.append(model(batch_inputs)['feature'].detach().cpu().numpy())
            target_chunks.append(batch_targets)
    return np.concatenate(feature_chunks), np.concatenate(target_chunks)
137 |
138 |
def calc_class_mean(network, loader, class_idx, metric):
    """Return the mean feature vector of all samples produced by ``loader``.

    ``class_idx`` is accepted but not used. For ``metric`` ``"cosine"`` or
    ``"weight"`` the mean is scaled to unit L2 norm (with a 1e-8 epsilon).
    """
    EPSILON = 1e-8
    features, _ = extract_features(network, loader)
    mean_vec = features.mean(axis=0)
    if metric in ("cosine", "weight"):
        mean_vec /= np.linalg.norm(mean_vec) + EPSILON
    return mean_vec
148 |
149 |
def update_classes_mean(network, inc_dataset, n_classes, task_size, share_memory=None, metric="cosine", EPSILON=1e-8):
    """Compute one mean feature vector per class over ``inc_dataset``'s incremental data.

    Args:
        network: wrapped model (``network.module.features_dim`` is read) whose
            forward returns a dict with ``'feature'``.
        inc_dataset: dataset object providing ``_get_loader``, ``data_inc`` and ``targets_inc``.
        n_classes: number of rows in the returned matrix.
        task_size: accepted but not used.
        share_memory: forwarded to ``inc_dataset._get_loader``.
        metric: for ``"cosine"`` or ``"weight"`` each class mean is scaled to
            unit L2 norm.
        EPSILON: added to the norm to avoid division by zero.

    Returns:
        ``np.ndarray`` of shape ``(n_classes, features_dim)``. Rows of classes
        with no samples stay zero instead of becoming NaN.
    """
    loader = inc_dataset._get_loader(inc_dataset.data_inc,
                                     inc_dataset.targets_inc,
                                     shuffle=False,
                                     share_memory=share_memory,
                                     mode="test")
    class_means = np.zeros((n_classes, network.module.features_dim))
    count = np.zeros(n_classes)
    network.eval()
    with torch.no_grad():
        for x, y in loader:
            feat = network(x.cuda())['feature']
            for lbl in torch.unique(y):
                lbl = int(lbl)
                mask = y == lbl
                class_means[lbl] += feat[mask].sum(0).cpu().numpy()
                count[lbl] += int(mask.sum())

    seen = count > 0
    class_means[seen] /= count[seen, np.newaxis]
    if metric == "cosine" or metric == "weight":
        # Normalise each row by its own norm; the former code divided every row
        # by the norm of the whole (partially normalised) matrix.
        class_means /= np.linalg.norm(class_means, axis=1, keepdims=True) + EPSILON
    return class_means
170 |
--------------------------------------------------------------------------------
/inclearn/models/prune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import shutil
4 | import time
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | import torch.optim
11 | import torch.utils.data
12 | import torchvision.transforms as transforms
13 | import torchvision.datasets as datasets
14 | import torchvision.models
15 | # from models import print_log
16 | import numpy as np
17 | from collections import OrderedDict
18 |
class AverageMeter(object):
    """Track the latest value plus the running sum, count and weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the value of ``n`` samples and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
36 |
class Mask:
    """Soft filter pruning: zero the lowest-L2-norm filters of selected 4-D weights.

    Parameters are addressed by their position in ``model.parameters()``.
    Typical use: ``init_length()`` -> ``init_mask(rate)`` -> ``do_mask()``.

    ``cfg`` keys read: ``"arch"`` and ``"use_cuda"``. ResNet architectures
    also read ``"layer_begin"``, ``"layer_end"``, ``"layer_inter"`` and
    ``"skip_downsample"``; ``if_zero`` reads the first three as well.
    """

    # arch -> (exclusive end of the masked parameter-index range, parameter
    # indices excluded when skip_downsample == 1). Per the upstream note the
    # range includes the last fc layer.
    _RESNET_LAYOUT = {
        'resnet18': (59, (21, 36, 51)),
        'resnet34': (108, (27, 54, 93)),
        'resnet50': (159, (12, 42, 81, 138)),
        'resnet101': (312, (12, 42, 81, 291)),
        'resnet152': (465, (12, 42, 117, 444)),
    }

    def __init__(self, model, cfg):
        self.model_size = {}     # parameter index -> torch.Size
        self.model_length = {}   # parameter index -> number of elements
        self.compress_rate = {}  # parameter index -> fraction of filters kept
        self.mat = {}            # parameter index -> flat 0/1 mask tensor
        self.model = model
        self.mask_index = []     # parameter indices that get masked
        self.cfg = cfg

    def get_codebook(self, weight_torch, compress_rate, length):
        """Return a flat 0/1 numpy mask keeping the largest-magnitude weights.

        The threshold is the magnitude at sorted position
        ``int(length * (1 - compress_rate))``; weights with ``|w| >= threshold``
        get 1. ``weight_torch`` itself is NOT modified. The previous in-place
        version overwrote CPU weights, because ``.numpy()`` shares memory.
        """
        weight_abs = np.abs(weight_torch.view(length).cpu().numpy())
        threshold = np.sort(weight_abs)[int(length * (1 - compress_rate))]
        return (weight_abs >= threshold).astype(weight_abs.dtype)

    def get_filter_codebook(self, weight_torch, compress_rate, length):
        """Return a flat float mask that zeroes whole filters of a 4-D weight.

        ``int(out_channels * (1 - compress_rate))`` filters (never negative)
        with the smallest L2 norm are zeroed. Non-4-D tensors get an all-ones mask.
        """
        codebook = np.ones(length)
        if len(weight_torch.size()) == 4:
            n_filters = weight_torch.size()[0]
            filter_pruned_num = max(0, int(n_filters * (1 - compress_rate)))
            norm2_np = torch.norm(weight_torch.view(n_filters, -1), 2, 1).cpu().numpy()
            filter_index = norm2_np.argsort()[:filter_pruned_num]
            kernel_length = weight_torch.size()[1] * weight_torch.size()[2] * weight_torch.size()[3]
            for f in filter_index:
                codebook[f * kernel_length:(f + 1) * kernel_length] = 0
        return codebook

    def convert2tensor(self, x):
        """Convert a numpy array to a ``torch.FloatTensor``."""
        return torch.FloatTensor(x)

    def init_length(self):
        """Record the shape and element count of every model parameter."""
        for index, item in enumerate(self.model.parameters()):
            self.model_size[index] = item.size()
            self.model_length[index] = item.numel()

    def init_rate(self, layer_rate):
        """Fill ``compress_rate`` and ``mask_index`` according to ``cfg["arch"]``.

        Raises:
            ValueError: for a ResNet arch without a known parameter layout.
        """
        arch = self.cfg["arch"]
        if 'vgg' in arch:
            # Upstream per-conv table; compress_rate = 1 - cfg_5x[i] / out_channels,
            # so get_filter_codebook zeroes about cfg_5x[i] filters of conv i.
            cfg_5x = [24, 22, 41, 51, 108, 89, 111, 184, 276, 228, 512, 512, 512]
            cfg_index = 0
            pre_cfg = True  # False applies `layer_rate` uniformly instead of the table
            for index, (_, param) in enumerate(self.model.named_parameters()):
                self.compress_rate[index] = 1
                if len(param.size()) != 4:
                    continue
                if pre_cfg:
                    self.compress_rate[index] = 1 - cfg_5x[cfg_index] / param.size()[0]
                    cfg_index += 1
                else:
                    self.compress_rate[index] = layer_rate
                self.mask_index.append(index)
        elif "resnet" in arch:
            for index, _ in enumerate(self.model.parameters()):
                self.compress_rate[index] = 1
            for key in range(self.cfg["layer_begin"], self.cfg["layer_end"] + 1, self.cfg["layer_inter"]):
                self.compress_rate[key] = layer_rate
            if arch not in self._RESNET_LAYOUT:
                raise ValueError("unsupported resnet arch for pruning: %s" % arch)
            last_index, skip_list = self._RESNET_LAYOUT[arch]
            self.mask_index = list(range(0, last_index, 3))
            # skip downsample layer
            if self.cfg["skip_downsample"] == 1:
                for x in skip_list:
                    self.compress_rate[x] = 1
                    self.mask_index.remove(x)

    def init_mask(self, layer_rate):
        """Compute the filter mask of every parameter in ``mask_index``.

        Requires ``init_length`` to have been called.
        """
        self.init_rate(layer_rate)
        masked = set(self.mask_index)
        for index, item in enumerate(self.model.parameters()):
            if index in masked:
                codebook = self.get_filter_codebook(item.data, self.compress_rate[index],
                                                    self.model_length[index])
                self.mat[index] = self.convert2tensor(codebook)
                if self.cfg["use_cuda"]:
                    self.mat[index] = self.mat[index].cuda()
        print("mask Ready")

    def do_mask(self):
        """Multiply every masked parameter by its mask, in place on ``.data``."""
        masked = set(self.mask_index)
        for index, item in enumerate(self.model.parameters()):
            if index in masked:
                flat = item.data.view(self.model_length[index])
                item.data = (flat * self.mat[index]).view(self.model_size[index])
        print("mask Done")

    def if_zero(self):
        """Count weights in the configured layer range.

        Returns:
            dict mapping parameter index -> ``(num_nonzero, num_zero)`` for the
            indices in ``range(layer_begin, layer_end + 1, layer_inter)``.
        """
        checked = set(range(self.cfg["layer_begin"], self.cfg["layer_end"] + 1, self.cfg["layer_inter"]))
        stats = {}
        for index, item in enumerate(self.model.parameters()):
            if index in checked:
                flat = item.data.view(self.model_length[index]).cpu().numpy()
                nonzero = int(np.count_nonzero(flat))
                stats[index] = (nonzero, len(flat) - nonzero)
        return stats
176 |
--------------------------------------------------------------------------------
/inclearn/tools/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import numbers
4 | import math
5 |
6 |
class IncConfusionMeter:
    """Maintains a confusion matrix for a given classification problem, summarised per task.

    The meter accumulates a K x K confusion matrix (rows = ground truth,
    columns = prediction). ``value`` collapses it into one row per increment
    (task). It does not support multi-label, multi-class problems: for such
    problems, please use MultiLabelConfusionMeter.

    Args:
        k (int): number of classes in the classification problem
        increments (list[int]): number of classes in each task, in order
        normalized (boolean): Determines whether or not the summary is
            normalized by each task's sample count
    """

    def __init__(self, k, increments, normalized=False):
        self.conf = np.ndarray((k, k), dtype=np.int32)
        self.normalized = normalized
        self.increments = increments
        # Class-index boundaries of each task: task i spans [cum[i], cum[i + 1]).
        self.cum_increments = [0] + [sum(increments[:i + 1]) for i in range(len(increments))]
        self.k = k
        self.reset()

    def reset(self):
        """Clear all accumulated counts."""
        self.conf.fill(0)

    def add(self, predicted, target):
        """Accumulate a batch of predictions into the K x K confusion matrix.

        Args:
            predicted (tensor): Can be an N x K tensor of predicted scores obtained from
                the model for N examples and K classes or an N-tensor of
                integer values between 0 and K-1.
            target (tensor): Can be a N-tensor of integer values assumed to be integer
                values between 0 and K-1 or N x K tensor, where targets are
                assumed to be provided as one-hot vectors

        Raises:
            AssertionError: on mismatched sizes or out-of-range values.
        """
        if isinstance(predicted, torch.Tensor):
            predicted = predicted.cpu().numpy()
        if isinstance(target, torch.Tensor):
            target = target.cpu().numpy()

        assert predicted.shape[0] == target.shape[0], \
            'number of targets and predicted outputs do not match'
        if predicted.shape[0] == 0:
            return  # empty batch: nothing to count (max()/min() below would fail)

        if np.ndim(predicted) != 1:
            assert predicted.shape[1] == self.k, \
                'number of predictions does not match size of confusion matrix'
            predicted = np.argmax(predicted, 1)
        else:
            assert (predicted.max() < self.k) and (predicted.min() >= 0), \
                'predicted values are not between 0 and k-1'

        if np.ndim(target) != 1:
            assert target.shape[1] == self.k, \
                'Onehot target does not match size of confusion matrix'
            assert (target >= 0).all() and (target <= 1).all(), \
                'in one-hot encoding, target values should be 0 or 1'
            assert (target.sum(1) == 1).all(), \
                'multi-label setting is not supported'
            target = np.argmax(target, 1)
        else:
            # Validate the targets themselves (this previously re-checked `predicted`).
            assert (target.max() < self.k) and (target.min() >= 0), \
                'target values are not between 0 and k-1'

        # hack for bincounting 2 arrays together
        x = predicted + self.k * target
        bincount_2d = np.bincount(x.astype(np.int32), minlength=self.k**2)
        assert bincount_2d.size == self.k**2
        self.conf += bincount_2d.reshape((self.k, self.k))

    def value(self):
        """Summarise the confusion matrix per task.

        Returns:
            Array of shape ``(n_tasks, n_tasks + 2)``. For task row ``i``:
            column 0 is the count of correct predictions, column 1 the count of
            wrong predictions that still fall inside task ``i``'s classes, and
            column ``j + 2`` the count of task-``i`` samples predicted as some
            class of task ``j``. If ``normalized``, each row is divided by its
            total over columns ``2:``.
        """
        conf = self.conf.astype(np.float32)
        n_tasks = len(self.increments)
        bounds = self.cum_increments
        summary = np.zeros([n_tasks, n_tasks + 2])
        for i in range(n_tasks):
            rows = slice(bounds[i], bounds[i + 1])
            within_task = conf[rows, rows]
            summary[i, 0] = within_task.trace()
            summary[i, 1] = within_task.sum() - summary[i, 0]
            for j in range(n_tasks):
                summary[i, j + 2] = conf[rows, bounds[j]:bounds[j + 1]].sum()
        if self.normalized:
            return summary / summary[:, 2:].sum(1).clip(min=1e-12)[:, None]
        return summary
97 |
98 |
class ClassErrorMeter:
    """Accumulate top-k classification error (or accuracy, if ``accuracy``) in percent."""

    def __init__(self, topk=[1], accuracy=False):
        super(ClassErrorMeter, self).__init__()
        self.topk = np.sort(topk)
        self.accuracy = accuracy
        self.reset()

    def reset(self):
        """Forget all accumulated batches."""
        self.sum = dict.fromkeys(self.topk, 0)  # k -> number of top-k misses
        self.n = 0

    def add(self, output, target):
        """Add a batch of scores ``output`` (N x C) with integer labels ``target`` (N)."""
        output = torch.Tensor(output) if isinstance(output, np.ndarray) else output
        target = torch.Tensor(target) if isinstance(target, np.ndarray) else target

        largest_k = int(self.topk[-1])  # seems like Python3 wants int and not np.int64
        batch = output.shape[0]
        top_pred = output.topk(largest_k, 1, True, True)[1]
        hits = top_pred == target.unsqueeze(1).repeat(1, top_pred.shape[1])

        for k in self.topk:
            self.sum[k] += batch - hits[:, 0:k].sum()
        self.n += batch

    def value(self, k=-1):
        """Return the percentage for one ``k``, or a list over all ``topk`` when ``k == -1``.

        Returns NaN when nothing has been added yet.
        """
        if k == -1:
            return [self.value(k_) for k_ in self.topk]
        assert k in self.sum.keys(), \
            'invalid k (this k was not provided at construction time)'
        if self.n == 0:
            return float('nan')
        if not self.accuracy:
            return float(self.sum[k]) / self.n * 100.0
        return (1. - float(self.sum[k]) / self.n) * 100.0
155 |
156 |
class AverageValueMeter:
    """Running mean and standard deviation of a stream of values.

    ``add(value, n)`` treats ``value`` as the total of ``n`` observations when
    updating the mean. The std uses a Welford-style update, which is exact
    when every call uses ``n == 1``; for batched calls it is an approximation.
    """

    def __init__(self):
        super(AverageValueMeter, self).__init__()
        self.reset()
        self.val = 0

    def add(self, value, n=1):
        """Record ``value`` (the sum of ``n`` observations) and refresh mean/std."""
        self.val = value
        self.sum += value
        self.var += value * value
        self.n += n

        if self.n == 0:
            self.mean, self.std = np.nan, np.nan
        elif self.n == 1:
            self.mean, self.std = self.sum, np.inf
            self.mean_old = self.mean
            self.m_s = 0.0
        else:
            self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n)
            self.m_s += (value - self.mean_old) * (value - self.mean)
            self.mean_old = self.mean
            # With batched updates (n > 1) m_s can dip below zero; clamp so
            # math.sqrt does not raise "math domain error".
            self.std = math.sqrt(max(self.m_s, 0.0) / (self.n - 1.0))

    def value(self):
        """Return ``(mean, std)``."""
        return self.mean, self.std

    def reset(self):
        """Forget all recorded values."""
        self.n = 0
        self.sum = 0.0
        self.var = 0.0
        self.val = 0.0
        self.mean = np.nan
        self.mean_old = 0.0
        self.m_s = 0.0
        self.std = np.nan
--------------------------------------------------------------------------------
/inclearn/convnet/network.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import pdb
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 | from inclearn.tools import factory
9 | from inclearn.convnet.imbalance import CR, All_av
10 | from inclearn.convnet.classifier import CosineClassifier
11 |
12 |
13 | class BasicNet(nn.Module):
14 | def __init__(
15 | self,
16 | convnet_type,
17 | cfg,
18 | nf=64,
19 | use_bias=False,
20 | init="kaiming",
21 | device=None,
22 | dataset="cifar100",
23 | ):
24 | super(BasicNet, self).__init__()
25 | self.nf = nf
26 | self.init = init
27 | self.convnet_type = convnet_type
28 | self.dataset = dataset
29 | self.start_class = cfg['start_class']
30 | self.weight_normalization = cfg['weight_normalization']
31 | self.remove_last_relu = True if self.weight_normalization else False
32 | self.use_bias = use_bias if not self.weight_normalization else False
33 | self.dea = cfg['dea']
34 | self.ft_type = cfg.get('feature_type', 'normal')
35 | self.at_res = cfg.get('attention_use_residual', False)
36 | self.div_type = cfg['div_type']
37 | self.reuse_oldfc = cfg['reuse_oldfc']
38 | self.prune = cfg.get('prune', False)
39 | self.reset = cfg.get('reset_se', True)
40 |
41 |
42 | if self.dea:
43 | print("Enable dynamical reprensetation expansion!")
44 | self.convnets = nn.ModuleList()
45 | self.convnets.append(
46 | factory.get_convnet(convnet_type,
47 | nf=nf,
48 | dataset=dataset,
49 | start_class=self.start_class,
50 | remove_last_relu=self.remove_last_relu))
51 | self.out_dim = self.convnets[0].out_dim
52 | else:
53 | self.convnet = factory.get_convnet(convnet_type,
54 | nf=nf,
55 | dataset=dataset,
56 | remove_last_relu=self.remove_last_relu)
57 | self.out_dim = self.convnet.out_dim
58 | self.classifier = None
59 | self.se = None
60 | self.aux_classifier = None
61 |
62 | self.n_classes = 0
63 | self.ntask = 0
64 | self.device = device
65 |
66 | if cfg['postprocessor']['enable']:
67 | if cfg['postprocessor']['type'].lower() == "cr":
68 | self.postprocessor = CR()
69 | elif cfg['postprocessor']['type'].lower() == "aver":
70 | self.postprocessor = All_av()
71 | else:
72 | self.postprocessor = None
73 |
74 | self.to(self.device)
75 |
76 | def forward(self, x):
77 | if self.classifier is None:
78 | raise Exception("Add some classes before training.")
79 |
80 | if self.dea:
81 | feature = [convnet(x) for convnet in self.convnets]
82 | features = torch.cat(feature, 1)
83 | last_dim = feature[-1].size(1)
84 | width = features.size(1)
85 |
86 | if self.reset:
87 | se = factory.get_attention(width, self.ft_type, self.at_res).to(self.device)
88 | features = se(features)
89 | else:
90 | features = self.se(features)
91 |
92 | else:
93 | features = self.convnet(x)
94 |
95 | logits = self.classifier(features)
96 |
97 | div_logits = self.aux_classifier(features[:, -last_dim:]) if self.ntask > 1 else None
98 |
99 | return {'feature': features, 'logit': logits, 'div_logit': div_logits, 'features': feature}
100 |
101 | def caculate_dim(self, x):
102 | feature = [convnet(x) for convnet in self.convnets]
103 | features = torch.cat(feature, 1)
104 |
105 | width = features.size(1)
106 |
107 | # se = factory.get_attention(width, self.ft_type, self.at_res).to(self.device)
108 | se = factory.get_attention(width, "ce", self.at_res).cuda()
109 | features = se(features)
110 |
111 | # import pdb
112 | # pdb.set_trace()
113 | return features.size(1), feature[-1].size(1)
114 |
115 | @property
116 | def features_dim(self):
117 | if self.dea:
118 | return self.out_dim * len(self.convnets)
119 | else:
120 | return self.out_dim
121 |
122 | def freeze(self):
123 | for param in self.parameters():
124 | param.requires_grad = False
125 | self.eval()
126 | return self
127 |
128 | def copy(self):
129 | return copy.deepcopy(self)
130 |
131 | def add_classes(self, n_classes):
132 | self.ntask += 1
133 |
134 | if self.dea:
135 | self._add_classes_multi_fc(n_classes)
136 | else:
137 | self._add_classes_single_fc(n_classes)
138 |
139 | self.n_classes += n_classes
140 |
141 | def _add_classes_multi_fc(self, n_classes):
142 | if self.ntask > 1:
143 | new_clf = factory.get_convnet(self.convnet_type,
144 | nf=self.nf,
145 | dataset=self.dataset,
146 | start_class=self.start_class,
147 | remove_last_relu=self.remove_last_relu).to(self.device)
148 | if self.prune:
149 | pass
150 | else:
151 | new_clf.load_state_dict(self.convnets[-1].state_dict())
152 | self.convnets.append(new_clf)
153 |
154 | if not self.reset:
155 | self.se = factory.get_attention(512*len(self.convnets), self.ft_type, self.at_res)
156 | self.se.to(self.device)
157 |
158 | if self.classifier is not None:
159 | weight = copy.deepcopy(self.classifier.weight.data)
160 |
161 | fc = self._gen_classifier(self.out_dim * len(self.convnets), self.n_classes + n_classes)
162 |
163 | if self.classifier is not None and self.reuse_oldfc:
164 | fc.weight.data[:self.n_classes, :self.out_dim * (len(self.convnets) - 1)] = weight
165 | del self.classifier
166 | self.classifier = fc
167 |
168 | if self.div_type == "n+1":
169 | div_fc = self._gen_classifier(self.out_dim, n_classes + 1)
170 | elif self.div_type == "1+1":
171 | div_fc = self._gen_classifier(self.out_dim, 2)
172 | elif self.div_type == "n+t":
173 | div_fc = self._gen_classifier(self.out_dim, self.ntask + n_classes)
174 | else:
175 | div_fc = self._gen_classifier(self.out_dim, self.n_classes + n_classes)
176 | del self.aux_classifier
177 | self.aux_classifier = div_fc
178 |
def _add_classes_single_fc(self, n_classes):
    """Replace the shared classifier with one that has ``n_classes`` extra outputs.

    The new head comes from ``self._gen_classifier`` over ``self.features_dim``
    inputs. If an old head exists and ``self.reuse_oldfc`` is set, the first
    ``self.n_classes`` weight rows (and biases, when ``self.use_bias``) are
    copied from the old head.

    Args:
        n_classes: number of classes being added.
    """
    old_weight = old_bias = None
    if self.classifier is not None:
        # Snapshot the parameters before the old head is deleted below.
        old_weight = copy.deepcopy(self.classifier.weight.data)
        if self.use_bias:
            old_bias = copy.deepcopy(self.classifier.bias.data)

    new_head = self._gen_classifier(self.features_dim, self.n_classes + n_classes)

    keep_old = self.classifier is not None and self.reuse_oldfc
    if keep_old:
        n_old = self.n_classes
        new_head.weight.data[:n_old] = old_weight
        if self.use_bias:
            new_head.bias.data[:n_old] = old_bias

    del self.classifier
    self.classifier = new_head
194 |
def _gen_classifier(self, in_features, n_classes):
    """Build a fresh classifier head on ``self.device``.

    Returns a ``CosineClassifier`` when ``self.weight_normalization`` is set,
    otherwise an ``nn.Linear`` (bias controlled by ``self.use_bias``). For the
    linear head, ``self.init == "kaiming"`` selects Kaiming-normal weights and
    any bias is zeroed.

    Args:
        in_features: input feature size.
        n_classes: number of output logits.
    """
    if not self.weight_normalization:
        head = nn.Linear(in_features, n_classes, bias=self.use_bias).to(self.device)
        if self.init == "kaiming":
            nn.init.kaiming_normal_(head.weight, nonlinearity="linear")
        if self.use_bias:
            nn.init.constant_(head.bias, 0.0)
        return head

    return CosineClassifier(in_features, n_classes).to(self.device)
208 |
--------------------------------------------------------------------------------
/inclearn/prune/resnet_small.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation.

    With ``padding == dilation`` a stride-1 call preserves the spatial size.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
9 |
10 |
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (channel projection, optionally strided)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
14 |
15 |
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3), torchvision style.

    Args:
        inplanes: input channels.
        planes: output channels (``expansion == 1``).
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.
        groups, base_width: only ``1`` and ``64`` are supported.
        dilation: only ``1`` is supported.
        norm_layer: normalisation factory, default ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # When stride != 1 both conv1 and the downsample path reduce resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)
54 |
55 |
class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 reduce -> 3x3 -> 1x1 expand).

    As in torchvision ("ResNet V1.5"), the downsampling stride sits on the 3x3
    convolution (``conv2``) rather than on the first 1x1 convolution used in
    "Deep Residual Learning for Image Recognition"
    (https://arxiv.org/abs/1512.03385); see
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    The block outputs ``planes * expansion`` channels.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        # conv2 (and the downsample path, if any) carry the stride.
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)
103 |
104 |
class ResNet(nn.Module):
    """Truncated ResNet: torchvision stem, one residual stage, 3x3 projection to 1024.

    ``forward`` returns the output of ``final_conv``, a 1024-channel feature
    map; there is no global pooling and no classifier.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: sequence whose first element is the number of blocks in the
            single stage; further elements are ignored.
        num_classes: unused (kept for torchvision signature compatibility).
        zero_init_residual: zero-initialise the last BN of each residual branch.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: validated (None or 3 elements) but has no
            effect, because the only stage is built with stride 1.
        norm_layer: normalisation factory, default ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        # _make_layer leaves self.inplanes at 64 * block.expansion (64 for
        # BasicBlock, 256 for Bottleneck); a hard-coded 64 here broke Bottleneck.
        self.final_conv = nn.Conv2d(self.inplanes, 1024, kernel_size=3, stride=1, padding=1,
                                    bias=False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build a stage of ``blocks`` blocks; updates ``self.inplanes``."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the shortcut when resolution or channel count changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.final_conv(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)
194 |
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a freshly initialised ``ResNet(block, layers, **kwargs)``.

    ``arch``, ``pretrained`` and ``progress`` mirror torchvision's helper
    signature but are ignored: no weights are downloaded or loaded.
    """
    del arch, pretrained, progress  # unused; kept for signature compatibility
    return ResNet(block, layers, **kwargs)
198 |
def resnet_small(pretrained=False, progress=True, **kwargs):
    r"""Small ResNet-style feature extractor (not a full ResNet-18).

    Architecture: 7x7/2 conv stem + BN + ReLU + 3x3/2 max-pool, a single
    ``BasicBlock`` stage with 64 channels, then a 3x3 conv to 1024 channels.
    The model returns that final feature map (no pooling, no classifier).
    The block design follows "Deep Residual Learning for Image Recognition"
    (https://arxiv.org/abs/1512.03385).

    Args:
        pretrained (bool): accepted for torchvision API compatibility only;
            it is ignored and no pretrained weights are loaded.
        progress (bool): ignored, kept for the same reason.
        **kwargs: forwarded to ``ResNet`` (e.g. ``zero_init_residual``,
            ``norm_layer``).
    """
    return _resnet('resnet_small', BasicBlock, [1], pretrained, progress,
                   **kwargs)
209 |
--------------------------------------------------------------------------------
/inclearn/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import numpy as np
3 | import glob
4 |
5 | from albumentations.pytorch import ToTensorV2
6 |
7 | from torchvision import datasets, transforms
8 | import torch
9 | from inclearn.tools.cutout import Cutout
10 | from inclearn.tools.autoaugment_extra import ImageNetPolicy
11 |
12 |
def get_datasets(dataset_names):
    """Resolve a '-'-separated string of dataset names into handler classes.

    Example: ``"cifar10-cifar100"`` -> ``[iCIFAR10, iCIFAR100]``.
    """
    names = dataset_names.split("-")
    return list(map(get_dataset, names))
15 |
16 |
def get_dataset(dataset_name):
    """Return the handler class for one dataset name.

    ``"cifar10"`` and ``"cifar100"`` must match exactly; any name containing
    ``"imagenet100"`` maps to ``iImageNet100``.

    Raises:
        NotImplementedError: for any other name.
    """
    if dataset_name == "cifar10":
        return iCIFAR10
    if dataset_name == "cifar100":
        return iCIFAR100
    if "imagenet100" in dataset_name:
        return iImageNet100
    raise NotImplementedError("Unknown dataset {}.".format(dataset_name))
26 |
27 |
class DataHandler:
    """Base class holding the class-level attributes shared by dataset handlers.

    Concrete handlers override these (e.g. ``train_transforms`` and
    ``class_order``) or set them per instance (``base_dataset``).
    """

    # Set by subclasses / instances.
    base_dataset = None
    class_order = None

    # Default transform lists.
    train_transforms = []
    common_transforms = [ToTensorV2()]
33 |
34 |
class iCIFAR10(DataHandler):
    """CIFAR-10 handler backed by ``torchvision.datasets.CIFAR10``.

    Provides the torchvision train/test pipelines, the raw ``data`` and
    ``targets`` of the underlying dataset, and a fixed class order.
    """

    base_dataset_cls = datasets.cifar.CIFAR10
    transform_type = 'torchvision'

    # Random crop + flip, then per-channel normalisation with CIFAR-10 stats.
    train_transforms = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    test_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )

    def __init__(self, data_folder, train, is_fine_label=False):
        # `is_fine_label` is accepted for interface compatibility; it is unused.
        dataset = self.base_dataset_cls(data_folder, train=train, download=True)
        self.base_dataset = dataset
        self.data, self.targets = dataset.data, dataset.targets
        self.n_cls = 10

    @property
    def is_proc_inc_data(self):
        return False

    @classmethod
    def class_order(cls, trial_i):
        """Fixed class order; ``trial_i`` is ignored."""
        return [4, 0, 2, 5, 8, 3, 1, 6, 9, 7]
63 |
64 |
class iCIFAR100(iCIFAR10):
    """CIFAR-100 handler backed by ``torchvision.datasets.CIFAR100``.

    Reuses the iCIFAR10 layout with CIFAR-100 normalisation statistics, an
    extra brightness jitter during training, and five predefined class orders.
    """

    # Class names indexed by CIFAR-100 fine label id.
    label_list = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy',
        'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
        'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish',
        'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
        'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange',
        'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk',
        'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',
        'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',
        'willow_tree', 'wolf', 'woman', 'worm'
    ]
    base_dataset_cls = datasets.cifar.CIFAR100
    transform_type = 'torchvision'
    # Random crop + flip + brightness jitter, then CIFAR-100 normalisation.
    train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])
    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    def __init__(self, data_folder, train, is_fine_label=False):
        # `is_fine_label` is accepted for interface compatibility; it is unused.
        self.base_dataset = self.base_dataset_cls(data_folder, train=train, download=True)
        self.data = self.base_dataset.data
        self.targets = self.base_dataset.targets
        self.n_cls = 100
        self.transform_type = 'torchvision'

    @property
    def is_proc_inc_data(self):
        return False

    @classmethod
    def class_order(cls, trial_i):
        """Return the predefined class order for trial ``trial_i`` (0-4).

        NOTE(review): any other ``trial_i`` falls through and returns None;
        confirm callers handle that.
        """
        if trial_i == 0:
            return [
                62, 54, 84, 20, 94, 22, 40, 29, 78, 27, 26, 79, 17, 76, 68, 88, 3, 19, 31, 21, 33, 60, 24, 14, 6, 10,
                16, 82, 70, 92, 25, 5, 28, 9, 61, 36, 50, 90, 8, 48, 47, 56, 11, 98, 35, 93, 44, 64, 75, 66, 15, 38, 97,
                42, 43, 12, 37, 55, 72, 95, 18, 7, 23, 71, 49, 53, 57, 86, 39, 87, 34, 63, 81, 89, 69, 46, 2, 1, 73, 32,
                67, 91, 0, 51, 83, 13, 58, 80, 74, 65, 4, 30, 45, 77, 99, 85, 41, 96, 59, 52
            ]
        elif trial_i == 1:
            return [
                68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17,
                50, 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14,
                71, 96, 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30,
                46, 62, 69, 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33
            ]
        elif trial_i == 2:  # order labelled "PODNet" in the original source
            return [
                87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45,
                88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6,
                46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76,
                40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39
            ]
        elif trial_i == 3:  # order labelled "PODNet" in the original source
            return [
                58, 30, 93, 69, 21, 77, 3, 78, 12, 71, 65, 40, 16, 49, 89, 46, 24, 66, 19, 41, 5, 29, 15, 73, 11, 70,
                90, 63, 67, 25, 59, 72, 80, 94, 54, 33, 18, 96, 2, 10, 43, 9, 57, 81, 76, 50, 32, 6, 37, 7, 68, 91, 88,
                95, 85, 4, 60, 36, 22, 27, 39, 42, 34, 51, 55, 28, 53, 48, 38, 17, 83, 86, 56, 35, 45, 79, 99, 84, 97,
                82, 98, 26, 47, 44, 62, 13, 31, 0, 75, 14, 52, 74, 8, 20, 1, 92, 87, 23, 64, 61
            ]
        elif trial_i == 4:  # order labelled "PODNet" in the original source
            return [
                71, 54, 45, 32, 4, 8, 48, 66, 1, 91, 28, 82, 29, 22, 80, 27, 86, 23, 37, 47, 55, 9, 14, 68, 25, 96, 36,
                90, 58, 21, 57, 81, 12, 26, 16, 89, 79, 49, 31, 38, 46, 20, 92, 88, 40, 39, 98, 94, 19, 95, 72, 24, 64,
                18, 60, 50, 63, 61, 83, 76, 69, 35, 0, 52, 7, 65, 42, 73, 74, 30, 41, 3, 6, 53, 13, 56, 70, 77, 34, 97,
                75, 2, 17, 93, 33, 84, 99, 51, 62, 87, 5, 15, 10, 78, 67, 44, 59, 85, 43, 11
            ]
140 |
141 |
# NOTE: a second, identical `DataHandler` class used to be defined here.
# Re-binding the name made `iCIFAR10`/`iCIFAR100` subclasses of a different
# (shadowed) class than `iImageNet100`, so `isinstance(handler, DataHandler)`
# was False for the CIFAR handlers. `iImageNet100` below now inherits from the
# single `DataHandler` defined earlier in this module.
147 |
148 |
class iImageNet100(DataHandler):
    """ImageNet-100 handler built on ``torchvision.datasets.ImageFolder``.

    Expects ``<data_folder>/train`` and ``<data_folder>/val`` directories.
    ``data`` holds the image paths and ``targets`` the labels taken from
    ``ImageFolder.samples``, both as numpy arrays.
    """

    base_dataset_cls = datasets.ImageFolder
    transform_type = 'torchvision'

    # Cutout on the tensor, then random-resized crop + flip + AutoAugment
    # (ImageNet policy), then ImageNet normalisation.
    train_transforms = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            Cutout(n_holes=1, length=16),
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            ImageNetPolicy(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    # NOTE(review): the ToPILImage -> ToTensor -> ToPILImage round trip looks
    # redundant; kept unchanged to preserve exact preprocessing - confirm
    # before removing.
    test_transforms = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    def __init__(self, data_folder, train, is_fine_label=False):
        # `is_fine_label` is accepted for interface compatibility; it is unused.
        split = "train" if train is True else "val"
        self.base_dataset = self.base_dataset_cls(osp.join(data_folder, split))

        paths, labels = zip(*self.base_dataset.samples)
        self.data = np.array(paths)
        self.targets = np.array(labels)
        self.n_cls = 100

    @property
    def is_proc_inc_data(self):
        return False

    @classmethod
    def class_order(cls, trial_i):
        """Fixed class order; ``trial_i`` is ignored."""
        return [
            68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,
            28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
            98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,
            36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33
        ]
198 |
--------------------------------------------------------------------------------
/inclearn/prune/autoslim_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from itertools import chain
5 | from .dependency import *
6 | from . import prune
7 | import math
8 | from scipy.spatial import distance
9 |
10 | __all__ = ['Autoslim']
11 |
12 |
class Autoslim(object):
    """Filter-pruning driver for the conv layers of a model (test copy).

    Each conv layer's output filters are ranked by an L1-norm or FPGM
    (geometric-median) criterion. The selected filters are then removed through
    pruning plans obtained from a ``DependencyGraph`` built on the model.

    Args:
        model: ``nn.Module`` to prune in place (e.g. a torchvision model).
        inputs: example input used to build the dependency graph,
            e.g. ``torch.randn(1, 3, 224, 224)``.
        compression_ratio: desired fraction of filters to remove.
    """

    def __init__(self, model, inputs, compression_ratio):
        self.model = model
        self.inputs = inputs
        self.compression_ratio = compression_ratio
        self.DG = DependencyGraph()
        # Build the layer dependency relations once, by tracing `inputs`.
        self.DG.build_dependency(model, example_inputs=inputs)
        self.model_modules = list(model.modules())
        self.pruning_func = {
            'l1': self._base_l1_pruning,
            'fpgm': self._base_fpgm_pruning
        }

    def index_of_layer(self):
        """Map module index (in ``model.modules()`` order) to conv layer."""
        dicts = {}
        for i, m in enumerate(self.model_modules):
            if isinstance(m, nn.modules.conv._ConvNd):
                dicts[i] = m
        return dicts

    @staticmethod
    def _filter_matrix(m):
        """Return ``m``'s weights as a 2-D tensor with one row per output filter.

        Transposed convolutions store weights as (in, out/groups, *kernel), so
        the first two axes are swapped before flattening. ``m.transposed`` is
        used because ``ConvTranspose*`` is not a subclass of the deprecated
        ``_ConvTransposeMixin`` in recent PyTorch versions.
        """
        weight = m.weight.detach()
        if m.transposed:
            weight = weight.transpose(0, 1)
        return weight.reshape(weight.size(0), -1)

    def _layer_ratio(self, config, i):
        """Pruning ratio of module ``i``: explicit per-layer entry, else the global ratio."""
        ratios = config['layer_compression_ratio']
        if ratios and i in ratios:
            return ratios[i]
        return self.compression_ratio

    def base_prunging(self, config):
        """Compute filter indices with ``config['pruning_func']`` and prune them.

        ``config`` keys read here: ``pruning_func`` ('l1' or 'fpgm'),
        ``layer_compression_ratio`` (dict module-index -> ratio, or None) and
        ``prune_shortcut`` (1 also executes plans flagged ``is_in_shortcut``).
        The pruning functions additionally read ``global_pruning`` and the
        optional ``max_ratio`` ('l1') or ``norm_rate`` / ``dist_type`` ('fpgm').

        Raises:
            KeyError: if ``config['pruning_func']`` is not supported.
        """
        if not config['pruning_func'] in self.pruning_func:
            raise KeyError(
                "-[ERROR] {} not supported.".format((config['pruning_func'])))

        # Original widths, to skip layers whose channel count already changed
        # (presumably pruned as a dependency of an earlier plan).
        ori_output = {
            i: m.out_channels
            for i, m in enumerate(self.model_modules)
            if isinstance(m, nn.modules.conv._ConvNd)
        }

        if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1:
            config['layer_compression_ratio'] = self._compute_auto_ratios()

        prune_indexes = self.pruning_func[config['pruning_func']](config)

        for i, m in enumerate(self.model_modules):
            if i not in prune_indexes or m.out_channels != ori_output[i]:
                continue
            pruning_plan = self.DG.get_pruning_plan(
                m, prune.prune_conv, idxs=prune_indexes[i])
            if not pruning_plan:
                # An empty / missing plan has nothing to execute.
                continue
            if config['prune_shortcut'] == 1 or not pruning_plan.is_in_shortcut:
                pruning_plan.exec()

    def _base_fpgm_pruning(self, config):
        """FPGM: prune filters closest to the geometric median of their layer.

        Filters whose norm is in the lowest ``1 - norm_rate`` fraction are
        excluded from the candidate set. Among the rest, those with the
        smallest summed distance to the other candidates are pruned.
        ``dist_type`` is 'l1' or 'l2' (norm used for the pre-filter, Euclidean
        distance) or 'cos' (L2 pre-filter, cosine distance).

        Returns:
            dict mapping module index to the list of filter indices to prune.
        """
        prune_indexes = {}
        dist_type = config['dist_type']
        norm_p = 1 if dist_type == "l1" else 2
        metric = 'cosine' if dist_type == "cos" else 'euclidean'
        for i, m in enumerate(self.model_modules):
            if not isinstance(m, nn.modules.conv._ConvNd):
                continue
            weight_vec = self._filter_matrix(m)  # e.g. [512, 64, 3, 3] -> [512, 576]
            out_channels = weight_vec.size(0)

            similar_pruned_num = int(out_channels * self._layer_ratio(config, i))
            filter_pruned_num = int(out_channels * (1 - config['norm_rate']))

            norm_np = torch.norm(weight_vec, norm_p, 1).cpu().numpy()
            filter_large_index = norm_np.argsort()[filter_pruned_num:]
            candidates = weight_vec.cpu().numpy()[filter_large_index]

            # Pairwise distances between candidate filters, summed per filter.
            similar_matrix = distance.cdist(candidates, candidates, metric)
            similar_sum = np.sum(np.abs(similar_matrix), axis=0)

            # Smallest summed distance == most redundant (closest to the median).
            similar_small_index = similar_sum.argsort()[:similar_pruned_num]
            prune_indexes[i] = [filter_large_index[j] for j in similar_small_index]
        return prune_indexes

    def _base_l1_pruning(self, config):
        """Prune the output filters with the smallest L1 norm.

        With ``config['global_pruning']`` one threshold is chosen over all
        layers (the author notes this works poorly), and each layer loses at
        most ``config.get('max_ratio', 0.7)`` of its filters. Otherwise each
        layer prunes ``out_channels * ratio`` filters, using the per-layer
        ratio when one is configured.

        Returns:
            dict mapping module index to the list of filter indices to prune.
        """
        def _l1_norms(m):
            return np.sum(np.abs(self._filter_matrix(m).cpu().numpy()), axis=1)

        prune_indexes = {}
        conv_layers = [(i, m) for i, m in enumerate(self.model_modules)
                       if isinstance(m, nn.modules.conv._ConvNd)]

        if config['global_pruning']:
            norms = {i: _l1_norms(m) for i, m in conv_layers}
            filter_record = sorted(chain.from_iterable(v.tolist() for v in norms.values()))
            if not filter_record:
                return prune_indexes
            # Threshold from the global ratio; clamp so ratio == 1 stays in range.
            thre_index = min(int(len(filter_record) * self.compression_ratio), len(filter_record) - 1)
            thre = filter_record[thre_index]
            # Never remove every filter of a layer; 0.7 is the assumed default cap.
            max_ratio = config.get('max_ratio', 0.7)
            for i, l1_norm in norms.items():
                num_pruned = min(int(max_ratio * len(l1_norm)), int(np.sum(l1_norm < thre)))
                prune_indexes[i] = np.argsort(l1_norm)[:num_pruned].tolist()
        else:
            if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1:
                # Shortcut layers are pruned but no per-layer ratios were given.
                config['layer_compression_ratio'] = self._compute_auto_ratios()
            for i, m in conv_layers:
                l1_norm = _l1_norms(m)
                num_pruned = int(len(l1_norm) * self._layer_ratio(config, i))
                prune_indexes[i] = np.argsort(l1_norm)[:num_pruned].tolist()
        return prune_indexes

    def _compute_auto_ratios(self):
        """Assign seven increasing ratio levels from shallow to deep conv layers.

        The levels are spaced symmetrically around ``self.compression_ratio``,
        so their mean equals it.
        """
        mid_value = self.compression_ratio
        one_value = (1 - mid_value) / 4 if mid_value >= 0.43 else mid_value / 4
        values = [mid_value + one_value * k for k in range(-3, 4)]
        conv_indices = [i for i, m in enumerate(self.model_modules)
                        if isinstance(m, nn.modules.conv._ConvNd)]
        layers_of_class = len(conv_indices) / 7
        return {
            i: values[math.floor(conv_cnt / layers_of_class)]
            for conv_cnt, i in enumerate(conv_indices)
        }
208 |
209 |
if __name__ == "__main__":
    # Run as a module (``python -m inclearn.prune.autoslim_test``) so that the
    # package-relative imports at the top of this file resolve.
    from .resnet_small import resnet_small

    model = resnet_small()
    slim = Autoslim(model, inputs=torch.randn(
        1, 3, 224, 224), compression_ratio=0.5)
    # Autoslim's entry point is base_prunging(config); it has no l1_norm_pruning().
    slim.base_prunging({
        'pruning_func': 'l1',
        'global_pruning': False,
        'layer_compression_ratio': None,
        'prune_shortcut': 0,
    })
    print(model)
217 |
--------------------------------------------------------------------------------
/inclearn/prune/autoslim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from itertools import chain
5 | from .dependency import *
6 | from . import prune
7 | import math
8 | from scipy.spatial import distance
9 |
10 | __all__ = ['Autoslim']
11 |
12 |
13 | class Autoslim(object):
14 | def __init__(self, model, inputs, compression_ratio):
15 | self.model = model # torchvision.models模型
16 | self.inputs = inputs # 输入大小,torch.randn(1,3,224,224)
17 | self.compression_ratio = compression_ratio # 期望压缩率
18 | self.DG = DependencyGraph()
19 | # 构建节点依赖关系
20 | self.DG.build_dependency(model, example_inputs=inputs)
21 | self.model_modules = list(model.modules())
22 | self.pruning_func = {
23 | 'l1': self._base_l1_pruning,
24 | 'fpgm': self._base_fpgm_pruning
25 | }
26 |
27 | def index_of_layer(self):
28 | dicts = {}
29 | for i, m in enumerate(self.model_modules):
30 | if isinstance(m, nn.modules.conv._ConvNd):
31 | dicts[i] = m
32 | return dicts
33 |
34 | def base_prunging(self, config):
35 | if not config['pruning_func'] in self.pruning_func:
36 | raise KeyError(
37 | "-[ERROR] {} pruning not supported.".format(config['pruning_func']))
38 |
39 | ori_output = {}
40 | for i, m in enumerate(self.model_modules):
41 | if isinstance(m, nn.modules.conv._ConvNd):
42 | ori_output[i] = m.out_channels
43 |
44 | if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1:
45 | config['layer_compression_ratio'] = self._compute_auto_ratios()
46 |
47 | prune_indexes = self.pruning_func[config['pruning_func']](config)
48 |
49 | for i, m in enumerate(self.model_modules):
50 | if i in prune_indexes and m.out_channels == ori_output[i]:
51 | pruning_plan = self.DG.get_pruning_plan(
52 | m, prune.prune_conv, idxs=prune_indexes[i])
53 | if pruning_plan and config['prune_shortcut'] == 1:
54 | pruning_plan.exec()
55 | elif not pruning_plan:
56 | continue
57 | elif not pruning_plan.is_in_shortcut:
58 | pruning_plan.exec()
59 |
60 | def _base_fpgm_pruning(self, config):
61 | prune_indexes = {}
62 | for i, m in enumerate(self.model_modules):
63 | # _ConvNd包含卷积和反卷积
64 | if isinstance(m, nn.modules.conv._ConvNd):
65 | weight_torch = m.weight.detach().cuda()
66 | if isinstance(m, nn.modules.conv._ConvTransposeMixin):
67 | weight_vec = weight_torch.view(weight_torch.size()[1], -1)
68 | out_channels = weight_torch.size()[1]
69 | else:
70 | weight_vec = weight_torch.view(
71 | weight_torch.size()[0], -1) # 权重[512,64,3,3] -> [512, 64*3*3]
72 | out_channels = weight_torch.size()[0]
73 |
74 | if config['layer_compression_ratio'] and i in config['layer_compression_ratio']:
75 | similar_pruned_num = int(
76 | out_channels * config['layer_compression_ratio'][i])
77 | # 全自动化压缩时,不剪跳连层
78 | else:
79 | similar_pruned_num = int(
80 | out_channels * self.compression_ratio)
81 |
82 | filter_pruned_num = int(
83 | out_channels * (1 - config['norm_rate']))
84 |
85 | if config['dist_type'] == "l2" or "cos":
86 | norm = torch.norm(weight_vec, 2, 1)
87 | norm_np = norm.cpu().numpy()
88 | elif config['dist_type'] == "l1":
89 | norm = torch.norm(weight_vec, 1, 1)
90 | norm_np = norm.cpu().numpy()
91 |
92 | filter_large_index = []
93 | filter_large_index = norm_np.argsort()[filter_pruned_num:]
94 |
95 | indices = torch.LongTensor(filter_large_index).cuda()
96 | # weight_vec_after_norm.size=15
97 | weight_vec_after_norm = torch.index_select(
98 | weight_vec, 0, indices).cpu().numpy()
99 |
100 | # for euclidean distance
101 | if config['dist_type'] == "l2" or "l1":
102 | similar_matrix = distance.cdist(
103 | weight_vec_after_norm, weight_vec_after_norm, 'euclidean')
104 | elif config['dist_type'] == "cos": # for cos similarity
105 | similar_matrix = 1 - \
106 | distance.cdist(weight_vec_after_norm,
107 | weight_vec_after_norm, 'cosine')
108 |
109 | # 将任意一个点与其他点的距离算出来,最后将距离相加,一共得到15组数据
110 | similar_sum = np.sum(np.abs(similar_matrix), axis=0)
111 |
112 | # for distance similar: get the filter index with largest similarity == small distance
113 | similar_large_index = similar_sum.argsort()[
114 | similar_pruned_num:]
115 | similar_small_index = similar_sum.argsort()[
116 | :similar_pruned_num]
117 | prune_index = [filter_large_index[i]
118 | for i in similar_small_index]
119 | prune_indexes[i] = prune_index
120 | return prune_indexes
121 |
122 | def _base_l1_pruning(self, config):
123 | return self.__base_lx_norm_pruning(config, norm='l1')
124 |
125 | def _base_l1_pruning(self, config):
126 | return self.__base_lx_norm_pruning(config, norm='l2')
127 |
128 | def __base_lx_norm_pruning(self, config, norm='l1'):
129 | prune_indexes = {}
130 |
131 | def _compute_lx_norm(m, norm):
132 | weight = m.weight.detach().cpu().numpy()
133 | if isinstance(m, nn.modules.conv._ConvTransposeMixin):
134 | if norm == 'l1':
135 | Lx_norm = np.sum(np.abs(weight), axis=(0, 2, 3))
136 | else:
137 | Lx_norm = np.sum(
138 | np.sqrt(weight ** 2), axis=(0, 2, 3))
139 | else:
140 | if norm == 'l1':
141 | Lx_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
142 | else:
143 | # 注:反卷积维数1对应输出维度
144 | Lx_norm = np.sum(
145 | np.sqrt(weight ** 2), axis=(1, 2, 3))
146 | return Lx_norm
147 |
148 | # 全局阈值剪枝法(最好别用,效果不佳)
149 | if config['global_pruning']:
150 | filter_record = []
151 | for i, m in enumerate(self.model_modules):
152 | if isinstance(m, nn.modules.conv._ConvNd):
153 | Lx_norm = _compute_lx_norm(m, norm)
154 | filter_record.append(Lx_norm.tolist()) # 记录每层卷积的lx_norm参数
155 |
156 | filter_record = list(chain.from_iterable(filter_record))
157 | total = len(filter_record)
158 | filter_record.sort() # 全局排序
159 | thre_index = int(total * self.compression_ratio)
160 | thre = filter_record[thre_index] # 根据裁剪率确定阈值
161 | for i, m in enumerate(self.model_modules):
162 | if isinstance(m, nn.modules.conv._ConvNd):
163 | weight = m.weight.detach().cpu().numpy()
164 | # _ConvTransposeMixin只包含反卷积
165 | if isinstance(m, nn.modules.conv._ConvTransposeMixin):
166 | Lx_norm = np.sum(np.abs(weight), axis=(
167 | 0, 2, 3)) # 注:反卷积维数1对应输出维度
168 | else:
169 | Lx_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
170 | num_pruned = min(int(max_ratio*len(Lx_norm)),
171 | len(Lx_norm[Lx_norm < thre])) # 不能全部减去
172 | # 删除低于阈值的卷积核
173 | prune_index = np.argsort(Lx_norm)[:num_pruned].tolist()
174 | prune_indexes[i] = prune_index
175 |
176 | # 局部阈值加指定层
177 | else:
178 | if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1:
179 | # 需要剪跳连层,并且未指定每一层的裁剪率
180 | config['layer_compression_ratio'] = self._compute_auto_ratios()
181 |
182 | for i, m in enumerate(self.model_modules):
183 | # 逐层裁剪
184 | # _ConvNd包含卷积和反卷积
185 | if isinstance(m, nn.modules.conv._ConvNd):
186 | Lx_norm = _compute_lx_norm(m, norm)
187 |
188 | # 自定义压缩或全自动化压缩时剪跳连层
189 | if config['layer_compression_ratio'] and i in config['layer_compression_ratio']:
190 | num_pruned = int(
191 | out_channels * config['layer_compression_ratio'][i])
192 | # 全自动化压缩时,不剪跳连层
193 | else:
194 | num_pruned = int(out_channels * self.compression_ratio)
195 |
196 | # remove filters with small L1-Norm
197 | prune_index = np.argsort(Lx_norm)[:num_pruned].tolist()
198 | prune_indexes[i] = prune_index
199 | return prune_indexes
200 |
def _compute_auto_ratios(self):
    """Build a per-layer pruning-ratio schedule when none was specified.

    Every ``nn.modules.conv._ConvNd`` in ``self.model_modules`` is assigned,
    in module order, to one of seven consecutive groups. The groups receive
    seven ratios that increase linearly (shallow layers pruned least, deep
    layers most) around ``self.compression_ratio``, which is the middle
    level.

    Returns:
        dict: index of each conv module in ``self.model_modules`` -> ratio.
    """
    mid = self.compression_ratio
    # Spacing between adjacent levels, chosen so all seven levels stay in
    # [0, 1] (0.43 is roughly 3/7, where the two formulas meet).
    if mid >= 0.43:
        step = (1 - mid) / 4
    else:
        step = mid / 4
    levels = [mid + step * k for k in range(-3, 4)]

    conv_positions = [
        position
        for position, module in enumerate(self.model_modules)
        if isinstance(module, nn.modules.conv._ConvNd)
    ]
    n_convs = len(conv_positions)
    # rank * 7 // n_convs == floor(rank / (n_convs / 7)), computed exactly;
    # rank < n_convs keeps the level index within [0, 6].
    return {
        position: levels[rank * 7 // n_convs]
        for rank, position in enumerate(conv_positions)
    }
224 |
225 |
if __name__ == "__main__":
    # Manual smoke test: build a small ResNet, run L1-norm filter pruning on
    # it with compression_ratio=0.5, and print the resulting architecture.
    # NOTE(review): `resnet_small` is imported as a top-level module, so this
    # only works when run from a directory where that module is importable.
    from resnet_small import resnet_small
    model = resnet_small()
    slim = Autoslim(model, inputs=torch.randn(
        1, 3, 224, 224), compression_ratio=0.5)
    slim.l1_norm_pruning()
    print(model)
233 |
--------------------------------------------------------------------------------
/inclearn/convnet/resnet.py:
--------------------------------------------------------------------------------
1 | """Taken & slightly modified from:
2 | * https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.utils.model_zoo as model_zoo
7 | from torch.nn import functional as F
8 |
# Public names exported by `from ... import *`.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']

# torchvision checkpoint URLs used by the `pretrained=True` factory functions.
# NOTE(review): this module's ResNet wraps its stem in an `nn.Sequential`
# (`conv1`) and has no classifier head, so these state dicts may not load
# strictly; confirm before relying on `pretrained=True`.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
18 |
19 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides
    downsample.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
23 |
24 |
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (channel projection, no padding)."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer
28 |
29 |
class BasicBlock(nn.Module):
    """Two-convolution residual block (used by ResNet-18/34).

    Computes ``bn2(conv2(relu(bn1(conv1(x))))) + shortcut`` where
    ``shortcut`` is ``downsample(x)`` when a downsample module is given and
    ``x`` otherwise. A final ReLU is applied unless ``remove_last_relu``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, remove_last_relu=False):
        super(BasicBlock, self).__init__()
        # Submodules are registered in the same order as before so that
        # state_dict keys and initialisation order are unchanged.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.remove_last_relu = remove_last_relu

    def forward(self, x):
        residual_branch = self.relu(self.bn1(self.conv1(x)))
        residual_branch = self.bn2(self.conv2(residual_branch))

        shortcut = x if self.downsample is None else self.downsample(x)

        residual_branch += shortcut
        if self.remove_last_relu:
            return residual_branch
        return self.relu(residual_branch)
61 |
62 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (used by ResNet-50/101/152).

    The block outputs ``planes * expansion`` channels; ``stride`` is applied
    by the 3x3 convolution.

    Args:
        inplanes: input channel count.
        planes: bottleneck width.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the
            shortcut (needed when shape or channel count changes).
        remove_last_relu: if True, skip the ReLU after the residual addition.
            This mirrors BasicBlock, so ``ResNet._make_layer`` can pass
            ``remove_last_relu=True`` to either block type. The default keeps
            the previous behaviour.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, remove_last_relu=False):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.remove_last_relu = remove_last_relu

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        if not self.remove_last_relu:
            out = self.relu(out)

        return out
99 |
100 |
class ChannelAttention(nn.Module):
    """CBAM-style channel attention.

    Produces per-channel gates ``sigmoid(mlp(avgpool(x)) + mlp(maxpool(x)))``
    of shape ``(N, C, 1, 1)``. The MLP (two 1x1 convolutions with a ReLU in
    between) is shared by both pooled descriptors.

    Args:
        in_planes: channel count ``C`` of the input.
        ratio: reduction factor of the hidden MLP layer (hidden width is
            ``in_planes // ratio``). Previously this argument was ignored and
            16 was always used; the default keeps that behaviour.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        hidden = in_planes // ratio
        # Shared-weight MLP
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)
117 |
118 |
class SpatialAttention(nn.Module):
    """CBAM-style spatial attention.

    Reduces the input over channels with a mean and a max, stacks the two
    maps, and applies a single-output convolution followed by a sigmoid.
    The result is a gate of shape ``(N, 1, H, W)``.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "same" padding for the two supported odd kernel sizes (7 -> 3, 3 -> 1)
        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(stacked))
133 |
class SEFeatureAt(nn.Module):
    """Optional attention on a feature map, followed by global average pooling.

    Args:
        inplanes: channel count ``C`` of the incoming feature map.
        type: ``"se"`` applies squeeze-and-excitation gating; ``"ffm"``
            applies channel attention followed by spatial attention; any
            other value applies no attention.
        at_res: if True, add the pre-attention input back before pooling.

    ``forward`` returns a tensor of shape ``(N, C)``.
    """

    def __init__(self, inplanes, type, at_res):
        super(SEFeatureAt, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, inplanes // 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(inplanes // 16, inplanes, kernel_size=1),
            nn.Sigmoid()
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.type = type
        self.at_res = at_res
        self.ca = ChannelAttention(inplanes)
        self.sa = SpatialAttention()

    def forward(self, x):
        residual = x
        if self.type == "se":
            attention = self.se(x)
            x = x * attention
        elif self.type == "ffm":
            x = self.ca(x) * x
            x = self.sa(x) * x
        if self.at_res:
            # Out-of-place add: when no attention branch ran, `x` is still the
            # caller's tensor and `x += residual` would overwrite it in place.
            x = x + residual
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x
164 |
165 |
class ResNet(nn.Module):
    """ResNet backbone that returns the last-stage feature map (no pooling/head).

    ``out_dim`` is the channel count of the returned map
    (``8 * nf * block.expansion``).

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        nf: width of the stem and of stage 1 (each later stage doubles it).
        zero_init_residual: zero the scale of the last BN in every residual
            branch so each block starts as an identity.
        dataset: stem selection; a name containing ``'cifar'`` gets a 3x3
            stride-1 stem, one containing ``'imagenet'`` gets a stem with a
            max-pool (see ``start_class``).
        start_class: for ImageNet, 0 selects the 7x7 stride-2 stem, any other
            value the 3x3 stride-1 stem (following PODNet).
        remove_last_relu: drop the ReLU after the very last block of stage 4.
    """

    def __init__(self,
                 block,
                 layers,
                 nf=64,
                 zero_init_residual=True,
                 dataset='cifar',
                 start_class=0,
                 remove_last_relu=False):
        super(ResNet, self).__init__()
        self.remove_last_relu = remove_last_relu
        self.inplanes = nf
        if 'cifar' in dataset:
            self.conv1 = nn.Sequential(nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(nf), nn.ReLU(inplace=True))
        elif 'imagenet' in dataset:
            if start_class == 0:
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, nf, kernel_size=7, stride=2, padding=3, bias=False),
                    nn.BatchNorm2d(nf),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
            else:
                # Following PODNET implmentation
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(nf),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
        # NOTE(review): any other `dataset` leaves `self.conv1` undefined, so
        # forward() fails with AttributeError.

        self.layer1 = self._make_layer(block, 1 * nf, layers[0])
        self.layer2 = self._make_layer(block, 2 * nf, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 4 * nf, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 8 * nf, layers[3], stride=2, remove_last_relu=remove_last_relu)

        # Not applied in forward() (pooling is commented out there).
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.out_dim = 8 * nf * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, remove_last_relu=False, stride=1):
        """Build one stage made of ``blocks`` residual blocks.

        The first block applies ``stride`` and, when the shape or width
        changes, a 1x1 conv + BN projection shortcut. When
        ``remove_last_relu`` is True, only the final block of the stage is
        built with ``remove_last_relu=True``.

        Previously a single-block stage with ``remove_last_relu=True`` ended
        up with two blocks. At least one block is always built, as before.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        n_blocks = max(blocks, 1)
        layers = []
        for index in range(n_blocks):
            extra = {}
            if remove_last_relu and index == n_blocks - 1:
                extra['remove_last_relu'] = True
            if index == 0:
                layers.append(block(self.inplanes, planes, stride, downsample, **extra))
                self.inplanes = planes * block.expansion
            else:
                layers.append(block(self.inplanes, planes, **extra))

        return nn.Sequential(*layers)

    def reset_bn(self):
        """Reset the running statistics of every BatchNorm2d layer."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()

    def forward(self, x):
        """Return the stage-4 feature map (un-pooled)."""
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # x = self.avgpool(x)
        # x = x.view(x.size(0), -1)
        return x
259 |
260 |
def _resnet(arch, block, layers, pretrained, **kwargs):
    """Build a ResNet variant and optionally load its torchvision checkpoint.

    Args:
        arch: key into ``model_urls`` (e.g. ``'resnet18'``).
        block: residual block class.
        layers: number of blocks in each of the four stages.
        pretrained: if True, load the state dict at ``model_urls[arch]``.
        **kwargs: forwarded to ``ResNet``.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls[arch]))
    return model


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model (BasicBlock, stages [2, 2, 2, 2])."""
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, **kwargs)


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model (BasicBlock, stages [3, 4, 6, 3])."""
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, **kwargs)


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model (Bottleneck, stages [3, 4, 6, 3])."""
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, **kwargs)


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model (Bottleneck, stages [3, 4, 23, 3])."""
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, **kwargs)


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model (Bottleneck, stages [3, 8, 36, 3])."""
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, **kwargs)
309 |
--------------------------------------------------------------------------------
/inclearn/prune/sensitivity_analysis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import copy
5 | import csv
6 | import logging
7 | from collections import OrderedDict
8 |
9 | import numpy as np
10 | import torch.nn as nn
11 |
# FIXME: I don't know where "utils" should be
# Layer types analysed for pruning sensitivity: model_parse() collects every
# submodule that is an instance of one of these.
# NOTE(review): analysis() always passes op_types=['Conv2d'] to the pruner,
# so Conv1d layers are collected but may not actually be pruned; confirm.
SUPPORTED_OP_NAME = ['Conv2d', 'Conv1d']
SUPPORTED_OP_TYPE = [getattr(nn, name) for name in SUPPORTED_OP_NAME]

# Module logger; analysis() logs one INFO line per (layer, sparsity) pair.
logger = logging.getLogger('Sensitivity_Analysis')
logger.setLevel(logging.INFO)
18 |
19 |
class SensitivityAnalysis:
    """Measure how sensitive each conv layer is to filter pruning.

    Each supported layer is pruned on its own at increasing sparsities, and
    the model is re-validated with a user-supplied function after each
    step. The resulting metric trajectories are stored in ``sensitivities``.
    """

    def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
        """
        Perform sensitivity analysis for this model.
        Parameters
        ----------
        model : torch.nn.Module
            the model to perform sensitivity analysis on
        val_func : function
            validation function for the model. Different models may need
            different datasets/criteria, so users must provide this part
            themselves. In val_func, the model should be tested on the
            validation dataset, and the validation accuracy/loss should be
            returned as the output of val_func. There are no restrictions on
            the input parameters of val_func; use the val_args and val_kwargs
            parameters of analysis() to pass everything val_func needs.
        sparsities : list
            The sparsity list provided by users. Set this parameter when you
            only want to test some specific sparsities. Each element is a
            sparsity value, i.e. how much weight the pruner should prune. For
            example, with [0.25, 0.5, 0.75] the SensitivityAnalysis prunes
            25%, 50% and 75% of the weights of each layer in turn. Values are
            sorted and rounded to 2 decimals. Default: 0.1, 0.2, ..., 0.9.
        prune_type : str
            The pruner used to prune the conv layers (a key of NNI's
            PRUNER_DICT). Default is 'l1'; 'l2' and 'fine-grained' are also
            supported.
        early_stop_mode : str
            If set, the sensitivity analysis for a conv layer stops early once
            the validation metric (for example, accuracy/loss) meets the
            threshold. Four modes are supported: minimize, maximize, dropped,
            raised. The default, None, means the analysis does not stop until
            all given sparsities are tested. Use together with
            early_stop_value.

            minimize: stop when the metric returned by val_func is lower than
                early_stop_value.
            maximize: stop when the metric returned by val_func is larger than
                early_stop_value.
            dropped: stop when the metric has dropped by early_stop_value.
            raised: stop when the metric has risen by early_stop_value.
        early_stop_value : float
            Threshold for the early-stop modes. Only used when
            early_stop_mode is set.

        """
        from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT

        self.model = model
        self.val_func = val_func
        # Snapshot of the unpruned weights, restored after each layer.
        self.ori_state_dict = copy.deepcopy(self.model.state_dict())
        # layer name -> module, filled by model_parse()
        self.target_layer = {}
        # layer name -> {sparsity: validation metric}
        self.sensitivities = {}
        if sparsities is not None:
            self.sparsities = sorted(sparsities)
        else:
            self.sparsities = np.arange(0.1, 1.0, 0.1)
        self.sparsities = [np.round(x, 2) for x in self.sparsities]
        self.Pruner = PRUNER_DICT[prune_type]
        self.early_stop_mode = early_stop_mode
        self.early_stop_value = early_stop_value
        self.ori_metric = None  # original validation metric for the model
        # already_pruned is for the iterative sensitivity analysis.
        # For example, sensitivity_pruner iteratively prunes the target
        # model according to the sensitivity. After each round of
        # pruning, the sensitivity_pruner tests the new sensitivity
        # of each layer.
        self.already_pruned = {}
        self.model_parse()

    @property
    def layers_count(self):
        """Number of layers that will be analysed."""
        return len(self.target_layer)

    def model_parse(self):
        """Register every submodule of a supported type as an analysis target."""
        for name, submodel in self.model.named_modules():
            for op_type in SUPPORTED_OP_TYPE:
                if isinstance(submodel, op_type):
                    self.target_layer[name] = submodel
                    self.already_pruned[name] = 0

    def _need_to_stop(self, ori_metric, cur_metric):
        """
        Judge whether the early-stop condition is met.
        Parameters
        ----------
        ori_metric : float
            original validation metric
        cur_metric : float
            current validation metric

        Returns
        -------
        stop : bool
            whether to stop the sensitivity analysis for the current layer
        """
        if self.early_stop_mode is None:
            # early stop mode is not enabled
            return False
        assert self.early_stop_value is not None
        if self.early_stop_mode == 'minimize':
            if cur_metric < self.early_stop_value:
                return True
        elif self.early_stop_mode == 'maximize':
            if cur_metric > self.early_stop_value:
                return True
        elif self.early_stop_mode == 'dropped':
            if cur_metric < ori_metric - self.early_stop_value:
                return True
        elif self.early_stop_mode == 'raised':
            if cur_metric > ori_metric + self.early_stop_value:
                return True
        return False

    def analysis(self, val_args=None, val_kwargs=None, specified_layers=None):
        """
        Analyse the pruning sensitivity of each conv layer in the target model.
        If specified_layers is not set, all supported layers are analysed.
        Users can restrict the analysis to some layers, or parallelize the
        analysis easily, through the specified_layers parameter.

        Parameters
        ----------
        val_args : list
            args for val_func
        val_kwargs : dict
            kwargs for val_func
        specified_layers : list
            list of layer names to analyse.
            If this variable is set, only the conv layers in the list are
            analysed.
        Returns
        -------
        sensitivities : dict
            dict object that stores the trajectory of the
            accuracy/loss when the prune ratio changes
        """
        if val_args is None:
            val_args = []
        if val_kwargs is None:
            val_kwargs = {}
        # Get the accuracy baseline before starting the analysis.
        self.ori_metric = self.val_func(*val_args, **val_kwargs)
        namelist = list(self.target_layer.keys())
        if specified_layers is not None:
            # only analyse the specified conv layers
            namelist = list(filter(lambda x: x in specified_layers, namelist))
        for name in namelist:
            self.sensitivities[name] = {}
            for sparsity in self.sparsities:
                # `sparsity` is relative to the remaining weights; convert it
                # to an absolute prune ratio using the already-pruned ratio.
                real_sparsity = (
                    1.0 - self.already_pruned[name]) * sparsity + self.already_pruned[name]
                # TODO In the current L1/L2 Filter Pruner, 'op_types' is still necessary.
                # The L1/L2 Pruner should infer op_types automatically
                # from op_names.
                cfg = [{'sparsity': real_sparsity, 'op_names': [
                    name], 'op_types': ['Conv2d']}]
                pruner = self.Pruner(self.model, cfg)
                pruner.compress()
                val_metric = self.val_func(*val_args, **val_kwargs)
                logger.info('Layer: %s Sparsity: %.2f Validation Metric: %.4f',
                            name, real_sparsity, val_metric)

                self.sensitivities[name][sparsity] = val_metric
                pruner._unwrap_model()
                del pruner
                # check if the current metric meets the stop condition
                if self._need_to_stop(self.ori_metric, val_metric):
                    break

            # Reset the weights pruned by the pruner. Because the sparsities
            # are sorted ascending, the weights need not be reset between
            # sparsities of the same layer, only when the layer changes.
            self.model.load_state_dict(self.ori_state_dict)

        return self.sensitivities

    def export(self, filepath):
        """
        Export the results of the sensitivity analysis to a csv file. The
        first line is the header: 'layername' followed by the sparsity list.
        Each following line records the validation metrics returned by
        val_func for one layer under the different sparsities. Because of the
        early_stop option, some layers may not have metrics for all
        sparsities.

        layername, 0.25, 0.5, 0.75
        conv1, 0.6, 0.55
        conv2, 0.61, 0.57, 0.56

        Parameters
        ----------
        filepath : str
            Path of the output file
        """
        str_sparsities = [str(x) for x in self.sparsities]
        header = ['layername'] + str_sparsities
        # newline='' lets the csv module control line endings itself
        # (otherwise '\r\r\n' is written on Windows).
        with open(filepath, 'w', newline='') as csvf:
            csv_w = csv.writer(csvf)
            csv_w.writerow(header)
            for layername in self.sensitivities:
                row = []
                row.append(layername)
                for sparsity in sorted(self.sensitivities[layername].keys()):
                    row.append(self.sensitivities[layername][sparsity])
                csv_w.writerow(row)

    def update_already_pruned(self, layername, ratio):
        """
        Set the already pruned ratio for the target layer.
        """
        self.already_pruned[layername] = ratio

    def load_state_dict(self, state_dict):
        """
        Replace the reference weights and load them into the model.
        """
        self.ori_state_dict = copy.deepcopy(state_dict)
        self.model.load_state_dict(self.ori_state_dict)
250 |
--------------------------------------------------------------------------------
/inclearn/tools/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | from copy import deepcopy
3 | import numpy as np
4 | import datetime
5 |
6 | import torch
7 |
8 | from inclearn.tools.metrics import ClassErrorMeter
9 | from sklearn.metrics import classification_report
10 |
11 |
def get_date():
    """Return the current local date formatted as ``YYYYMMDD``."""
    today = datetime.datetime.now()
    return today.strftime("%Y%m%d")
14 |
15 |
def to_onehot(targets, n_classes):
    """One-hot encode N integer class labels.

    Args:
        targets: labels as a torch tensor, or a numpy array (anything without
            a ``device`` attribute is converted with ``torch.from_numpy``).
        n_classes: width of the encoding.

    Returns:
        Float tensor of shape ``(targets.shape[0], n_classes)`` on the
        labels' device.
    """
    labels = targets if hasattr(targets, "device") else torch.from_numpy(targets)
    encoded = torch.zeros(labels.shape[0], n_classes).to(labels.device)
    column_index = labels.long().view(-1, 1)
    encoded.scatter_(dim=1, index=column_index, value=1.0)
    return encoded
22 |
23 |
def get_class_loss(network, cur_n_cls, loader):
    """Mean per-class negative log-likelihood of the true label.

    Runs ``network`` in eval mode (and leaves it in eval mode) over
    ``loader``. For every sample it accumulates
    ``-log(softmax(logit)[label] + 1e-10)`` into that sample's class.

    Args:
        network: model whose forward returns a dict with a ``'logit'`` entry.
        cur_n_cls: number of classes; labels must lie in ``[0, cur_n_cls)``.
        loader: iterable of ``(x, y)`` batches.

    Returns:
        CPU float tensor of shape ``(cur_n_cls,)``. Classes with no samples
        are NaN (0 / 0).
    """
    EPS = 1e-10
    # Put inputs on the model's device. The previous version hard-coded
    # .cuda(); CUDA is still the fallback for parameter-less networks.
    first_param = next(network.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    class_loss = torch.zeros(cur_n_cls)
    n_cls_data = torch.zeros(cur_n_cls)  # the num of imgs for cls i.
    network.eval()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            probs = network(x)['logit'].softmax(dim=1)
            labels = y.long().view(-1, 1)
            true_prob = probs.gather(1, labels).squeeze(1)
            nll = -(true_prob + EPS).log()
            labels_cpu = labels.view(-1).cpu()
            class_loss.index_add_(0, labels_cpu, nll.cpu().to(class_loss.dtype))
            n_cls_data.index_add_(0, labels_cpu, torch.ones(labels_cpu.shape[0]))
    return class_loss / n_cls_data
39 |
40 |
def get_featnorm_grouped_by_class(network, cur_n_cls, loader):
    """Average L2 norm of the network's features, grouped by class.

    Args:
        network: model whose forward returns a dict with a ``'feature'``
            entry (assumed shape ``(N, D)``; TODO confirm for other heads).
            It is left in eval mode.
        cur_n_cls: number of classes; samples with labels ``>= cur_n_cls``
            are ignored.
        loader: iterable of ``(x, y)`` batches.

    Returns:
        numpy array of length ``cur_n_cls``. Entry ``i`` is the mean feature
        norm over all samples of class ``i``, or 0 if the class never occurs.
    """
    # Put inputs on the model's device. The previous version hard-coded
    # .cuda(); CUDA is still the fallback for parameter-less networks.
    first_param = next(network.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    feats = [[] for _ in range(cur_n_cls)]
    feat_norms = np.zeros(cur_n_cls)
    network.eval()
    with torch.no_grad():
        for x, y in loader:
            feat = network(x.to(device))['feature'].cpu()
            labels = torch.as_tensor(y).cpu()
            # Add each class's rows once per batch. The old per-sample loop
            # repeated them k times for a class with k samples, which skewed
            # the mean across batches.
            for lbl in torch.unique(labels).tolist():
                if lbl >= cur_n_cls:
                    continue
                feats[lbl].append(feat[labels == lbl])
    for i, chunks in enumerate(feats):
        if chunks:
            feat_cls = torch.cat(chunks)
            feat_norms[i] = torch.norm(feat_cls, p=2, dim=1).mean().item()
    return feat_norms
62 |
63 |
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN benchmarking,
    trading training speed for reproducibility.
    """
    print("Set seed", seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True  # This will slow down training.
    torch.backends.cudnn.benchmark = False
72 |
73 |
def display_weight_norm(logger, network, increments, tag):
    """Log the mean L2 norm of the classifier weight rows for each task.

    Row ``idx`` of ``network.module.classifier.weight`` belongs to the first
    task whose cumulative class count exceeds ``idx``. Rows past the last
    boundary fall into the last task. Row norms are rounded to 3 decimals
    before averaging, and each task mean is rounded again.

    Args:
        logger: object with an ``info(str)`` method.
        network: object exposing ``.module.classifier.weight`` (presumably a
            DataParallel-wrapped model; confirm at call sites).
        increments: number of classes added by each task.
        tag: prefix of the log line.
    """
    boundaries = np.cumsum(np.array(increments))
    per_task = [[] for _ in range(len(increments))]
    weight = network.module.classifier.weight
    for idx in range(weight.shape[0]):
        row_norm = torch.norm(weight[idx].data, p=2).item()
        for task, boundary in enumerate(boundaries):
            if idx < boundary:
                break
        per_task[task].append(round(row_norm, 3))
    avg_weight_norm = [round(np.array(norms).mean(), 3) for norms in per_task]
    logger.info("%s: Weight norm per task %s" % (tag, str(avg_weight_norm)))
91 |
92 |
def display_feature_norm(logger, network, loader, n_classes, increments, tag, return_norm=False):
    """Log (and optionally return) the mean feature norm of each task.

    Per-class mean feature norms come from ``get_featnorm_grouped_by_class``.
    Classes are bucketed into tasks by the cumulative sum of ``increments``
    (classes past the last boundary go to the last task). Class values are
    rounded to 3 decimals before the per-task mean, which is rounded again.
    Note that the log text says "per class" although the values are per task.

    Returns:
        The list of per-task means when ``return_norm`` is True, else None.
    """
    per_class = get_featnorm_grouped_by_class(network, n_classes, loader)
    boundaries = np.cumsum(np.array(increments))
    per_task = [[] for _ in range(len(increments))]
    for cls_idx, cls_norm in enumerate(per_class):
        # Find the mapping from class idx to its task.
        for task, boundary in enumerate(boundaries):
            if cls_idx < boundary:
                break
        per_task[task].append(round(cls_norm, 3))
    avg_feature_norm = [round(np.array(norms).mean(), 3) for norms in per_task]
    logger.info("%s: Feature norm per class %s" % (tag, str(avg_feature_norm)))
    return avg_feature_norm if return_norm else None
110 |
111 |
def check_loss(loss):
    """Return True when the scalar tensor ``loss`` is neither NaN nor negative.

    Infinite values pass the check.
    """
    if bool(torch.isnan(loss).item()):
        return False
    return bool((loss >= 0.0).item())
114 |
def class2task(class_form, classnum, task_size=10):
    """Map integer class labels to task indices.

    A label ``c`` in ``[0, classnum)`` maps to ``c // task_size``. Any other
    value ``v`` maps to ``-v - 1``, matching the previous implementation.
    The input is not modified.

    Args:
        class_form: numpy array or torch tensor of integer labels.
        classnum: number of valid classes.
        task_size: classes per task (previously hard-coded to 10; the
            default keeps that behaviour).

    Returns:
        A copy of ``class_form`` with labels replaced by task indices.
    """
    target_form = deepcopy(class_form)
    # One vectorised pass instead of one mask per class.
    in_range = (class_form >= 0) & (class_form < classnum)
    out_of_range = ~in_range
    target_form[in_range] = class_form[in_range] // task_size
    target_form[out_of_range] = -class_form[out_of_range] - 1
    return target_form
122 |
def maskclass(pred, target, classnum, type='new'):
    """Collapse labels into old/new groups to isolate specific error types.

    Labels ``>= classnum - 10`` are treated as "new" classes and labels below
    that as "old". ``pred`` and ``target`` are deep-copied, not modified.

    Args:
        pred: predicted class labels (used with numpy ops; presumably numpy
            arrays, confirm at call sites).
        target: true class labels, same form as ``pred``.
        classnum: total number of classes seen so far.
        type: 'old' - samples whose pred OR target is old get both set to 0,
                  so only confusions among new classes remain;
              'new' - samples whose pred OR target is new get both set to
                  1000, so only confusions among old classes remain;
              'all' - old labels become 0 and new labels become 1000,
                  independently in pred and target.

    Returns:
        For 'old' / 'new' (and any other value, which leaves the copies
        unchanged): ``(pred_form, target_form)``.
        For 'all': ``(all_err, new_old_err, old_new_err)``:
            all_err     - samples whose old/new group is mispredicted,
            new_old_err - new-class samples predicted as old,
            old_new_err - old-class samples predicted as new.

    NOTE(review): the sentinels 0 and 1000 must fall on the old/new sides of
    ``classnum - 10``, i.e. this assumes ``10 < classnum <= 1010``.
    """
    target_form = deepcopy(target)
    pred_form = deepcopy(pred)
    if type == 'old':
        # Hide every sample that touches an old class.
        mask = np.logical_or(pred_form<(classnum-10), target_form<(classnum-10))
        pred_form[mask] = 0
        target_form[mask] = 0

    if type == 'new':
        # Hide every sample that touches a new class.
        mask = np.logical_or(pred_form>=(classnum-10), target_form>=(classnum-10))
        pred_form[mask] = 1000
        target_form[mask] = 1000

    if type == 'all':
        # Collapse new labels to 1000 first, so the "< classnum-10" tests
        # below only see the remaining (old) labels.
        mask = (target_form>=(classnum-10))
        target_form[mask] = 1000
        mask = (pred_form>=(classnum-10))
        pred_form[mask] = 1000

        mask = (target_form<(classnum-10))
        target_form[mask] = 0
        mask = (pred_form<(classnum-10))
        pred_form[mask] = 0

        # Any remaining mismatch is an old<->new confusion.
        all_err = np.sum(pred_form!=target_form)
        pred_form1 = deepcopy(pred_form)
        # Force old-target samples to be "correct": what remains is
        # new samples predicted as old.
        mask = (target_form<(classnum-10))
        pred_form[mask] = 0

        new_old_err = np.sum(pred_form!=target_form)

        # Force new-target samples to be "correct": what remains is
        # old samples predicted as new.
        mask = (target_form>=(classnum-10))
        pred_form1[mask] = 1000
        old_new_err = np.sum(pred_form1!=target_form)
        return all_err, new_old_err, old_new_err


    return pred_form, target_form
162 |
def compute_old_new_mix(ypred, ytrue, increments, n_classes, task_order):
    """Break classification errors down into old/new and inter/intra-task parts.

    Args:
        ypred: array of shape ``(N, n_out)`` with per-class scores.
        ytrue: integer array of shape ``(N,)`` with true labels.
        increments, task_order: unused; kept for API compatibility.
        n_classes: number of classes seen so far; the last 10 are "new" and
            tasks are blocks of 10 classes.

    Returns:
        dict with:
          task_mean      - mean score assigned to the true class,
          task_means     - the same mean restricted to each block of 10
                           classes (NaN for a block with no samples),
          new_err        - errors where both pred and label are new classes,
          old_err        - errors where both pred and label are old classes,
          new_old_err    - new-class samples predicted as an old class,
          old_new_err    - old-class samples predicted as a new class,
          err_among_task - samples predicted into the wrong task,
          err_inner_task - remaining errors (right task, wrong class).
    """
    true_scores = ypred[np.arange(ytrue.shape[0]), ytrue]
    task_means = []
    for i in range(n_classes // 10):
        taski_mask = np.logical_and(ytrue >= i * 10, ytrue < (i + 1) * 10)
        task_means.append(true_scores[taski_mask].mean().item())
    task_mean = true_scores.mean().item()

    classnum = ypred.shape[1]
    ypred = ypred.argmax(1)
    total_err = np.sum(ypred != ytrue)
    ypred_task = class2task(ypred, classnum)
    ytrue_task = class2task(ytrue, classnum)
    err_among_task = np.sum(ypred_task != ytrue_task)
    err_inner_task = total_err - err_among_task

    ypred_new, ytrue_new = maskclass(ypred, ytrue, n_classes, 'old')
    new_err = np.sum(ypred_new != ytrue_new)

    ypred_old, ytrue_old = maskclass(ypred, ytrue, n_classes, 'new')
    old_err = np.sum(ypred_old != ytrue_old)

    # The first value (old/new group errors) was only printed for debugging
    # before; the stray print to stdout has been removed.
    _, new_old_err, old_new_err = maskclass(ypred, ytrue, n_classes, 'all')

    all_acc = {"task_mean": task_mean, "task_means": task_means, "new_err": new_err, "old_err": old_err,
               "new_old_err": new_old_err, "old_new_err": old_new_err, "err_among_task": err_among_task,
               "err_inner_task": err_inner_task}

    return all_acc
194 |
195 |
def compute_task_accuracy(ypred, ytrue, increments, n_classes, task_order):
    """Report class-level and task-level classification metrics.

    Args:
        ypred: array of shape ``(N, n_out)`` with per-class scores.
        ytrue: integer array of shape ``(N,)`` with true labels.
        increments, n_classes, task_order: unused; kept for API compatibility.

    Returns:
        dict with ``task_mean`` (mean score of the true class) and sklearn
        ``classification_report`` strings for class labels (``class_info``)
        and for task labels from ``class2task`` (``task_info``).
    """
    n_samples = ytrue.shape[0]
    task_mean = ypred[np.arange(n_samples), ytrue].mean().item()
    n_outputs = ypred.shape[1]
    predicted = ypred.argmax(1)
    predicted_task = class2task(predicted, n_outputs)
    true_task = class2task(ytrue, n_outputs)

    return {
        "task_mean": task_mean,
        "class_info": classification_report(ytrue, predicted),
        "task_info": classification_report(true_task, predicted_task),
    }
207 |
def compute_accuracy(ypred, ytrue, increments, n_classes):
    """Compute overall and per-task top-1 / top-k accuracy.

    Args:
        ypred: per-class scores, shape ``(N, n_out)`` (``argmax(1)`` is taken).
        ytrue: integer labels, shape ``(N,)``.
        increments: number of classes added at each task; task ``i`` covers
            labels ``[start, end)`` built from their running sum.
        n_classes: number of classes; caps the top-k value.

    Returns:
        ``{"top1": {...}, "top5": {...}}``. Each inner dict holds a
        ``"total"`` entry plus one entry per task, keyed ``"SS-EE"`` (first
        and last label of the task, zero-padded to 2 digits). Values are
        rounded to 3 decimals; the "top5" values use k = min(5, #classes).
    """
    all_acc = {"top1": {}, "top5": {}}
    # k is at most 5, at most n_classes, and at most the number of distinct
    # labels actually present.
    topk = 5 if n_classes >= 5 else n_classes
    ncls = np.unique(ytrue).shape[0]
    if topk > ncls:
        topk = ncls
    all_acc_meter = ClassErrorMeter(topk=[1, topk], accuracy=True)
    all_acc_meter.add(ypred, ytrue)
    all_acc["top1"]["total"] = round(all_acc_meter.value()[0], 3)
    all_acc["top5"]["total"] = round(all_acc_meter.value()[1], 3)
    # all_acc["total"] = round((ypred == ytrue).sum() / len(ytrue), 3)

    # for class_id in range(0, np.max(ytrue), task_size):
    start, end = 0, 0
    for i in range(len(increments)):
        # A non-positive increment keeps the previous [start, end) range
        # (or the empty [0, 0) range if it comes first).
        if increments[i] <= 0:
            pass
        else:
            start = end
            end += increments[i]

        idxes = np.where(np.logical_and(ytrue >= start, ytrue < end))[0]
        topk_ = 5 if increments[i] >= 5 else increments[i]
        ncls = np.unique(ytrue[idxes]).shape[0]
        if topk_ > ncls:
            topk_ = ncls
        cur_acc_meter = ClassErrorMeter(topk=[1, topk_], accuracy=True)
        cur_acc_meter.add(ypred[idxes], ytrue[idxes])
        # NOTE(review): divides by idxes.shape[0]; a task with no samples
        # yields NaN (with a numpy warning), and topk_ may become 0. Confirm
        # ClassErrorMeter handles that case.
        top1_acc = (ypred[idxes].argmax(1) == ytrue[idxes]).sum() / idxes.shape[0] * 100
        if start < end:
            label = "{}-{}".format(str(start).rjust(2, "0"), str(end - 1).rjust(2, "0"))
        else:
            label = "{}-{}".format(str(start).rjust(2, "0"), str(end).rjust(2, "0"))
        all_acc["top1"][label] = round(top1_acc, 3)
        all_acc["top5"][label] = round(cur_acc_meter.value()[1], 3)
        # all_acc[label] = round((ypred[idxes] == ytrue[idxes]).sum() / len(idxes), 3)

    return all_acc
246 |
247 |
def make_logger(log_name, savedir='.logs/'):
    """Configure the root logger to log to stderr and to ``<savedir>/<log_name>.log``.

    The log file is truncated on every call. Both handlers emit INFO and
    above. Loggers created before this call (e.g. module-level
    ``logging.getLogger(...)`` objects) keep working; ``dictConfig`` would
    otherwise disable them.

    Args:
        log_name: base name of the log file (``.log`` is appended).
        savedir: directory for the log file; created if missing.

    Return:
        logger: the configured root logger.
    """
    import logging
    import os
    from logging.config import dictConfig

    logging_config = dict(
        version=1,
        disable_existing_loggers=False,
        formatters={'f_t': {
            'format': '\n %(asctime)s | %(levelname)s | %(name)s \t %(message)s'
        }},
        handlers={
            'stream_handler': {
                'class': 'logging.StreamHandler',
                'formatter': 'f_t',
                'level': logging.INFO
            },
            'file_handler': {
                'class': 'logging.FileHandler',
                'formatter': 'f_t',
                'level': logging.INFO,
                'filename': None,
            }
        },
        root={
            'handlers': ['stream_handler', 'file_handler'],
            'level': logging.DEBUG,
        },
    )
    # set up the log file path; create the folder if it does not exist
    log_file = '{}.log'.format(log_name)
    os.makedirs(savedir, exist_ok=True)
    log_file_path = os.path.join(savedir, log_file)

    logging_config['handlers']['file_handler']['filename'] = log_file_path

    open(log_file_path, 'w').close()  # Clear the content of logfile
    # get logger from dictConfig
    dictConfig(logging_config)

    logger = logging.getLogger()

    return logger
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import os.path as osp
4 | import copy
5 | import time
6 | import shutil
7 | import cProfile
8 | import logging
9 | from pathlib import Path
10 | import numpy as np
11 | import random
12 | from easydict import EasyDict as edict
13 | from tensorboardX import SummaryWriter
14 | import os
15 | import inclearn.prune as pruning
16 |
# Pin this process to GPU 0. NOTE: this overrides any CUDA_VISIBLE_DEVICES
# value exported by the caller.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Repository root: the current working directory truncated right after the
# first occurrence of `repo_name`, so the script must be launched from
# somewhere inside the TCIL checkout.
repo_name = 'TCIL'
_cwd = osp.realpath(".")
if repo_name not in _cwd:
    # Without this check str.index fails with an opaque "substring not found".
    raise ValueError(
        f"Cannot determine the repository root: {repo_name!r} does not occur in "
        f"the current working directory {_cwd!r}; run this script from inside "
        f"the {repo_name} checkout."
    )
base_dir = _cwd[:_cwd.index(repo_name) + len(repo_name)]
# Put the repository root first on the import path.
sys.path.insert(0, base_dir)

from sacred import Experiment  # noqa: E402

# Sacred experiment that owns the commands registered with @ex.command in this
# file; collection of git metadata is disabled.
ex = Experiment(base_dir=base_dir, save_git_info=False)
25 |
26 |
27 | import torch
28 |
29 | from inclearn.tools import factory, results_utils, utils
30 | from inclearn.learn.pretrain import pretrain
31 | from inclearn.tools.metrics import IncConfusionMeter
32 |
def initialization(config, seed, mode, exp_id):
    """Build the run config, logger and TensorBoard writer for a sacred command.

    Args:
        config: sacred run configuration; converted to an ``EasyDict``.
        seed: seed passed in by the sacred command. NOTE(review): unused here;
            the seed actually applied is ``config['seed']``.
        mode: command name (e.g. "train", "test", "prune"), embedded in the
            log file name.
        exp_id: sacred run id, or ``None`` (presumably when no observer
            assigned one); ``None`` is replaced by -1.

    Returns:
        Tuple ``(cfg, logger, tensorboard)``.
    """
    # cuDNN autotuning: faster convolutions, but results become non-deterministic.
    torch.backends.cudnn.benchmark = True
    cfg = edict(config)
    utils.set_seed(cfg['seed'])

    # Fall back to -1 so log file and TensorBoard directory names stay well-formed.
    if exp_id is None:
        exp_id = -1

    # NOTE: the log directory is forced to ./logs regardless of the config value.
    cfg.exp.savedir = "./logs"
    logger = utils.make_logger(f"exp{exp_id}_{cfg.exp.name}_{mode}", savedir=cfg.exp.savedir)

    # exp_id is always set at this point, so the TensorBoard sub-directory is
    # always "<run id>_<experiment name>".
    exp_name = f'{exp_id}_{cfg["exp"]["name"]}'
    tensorboard_dir = cfg["exp"]["tensorboard_dir"] + f"/{exp_name}"
    tensorboard = SummaryWriter(tensorboard_dir)

    return cfg, logger, tensorboard
55 |
56 |
@ex.command
def train(_run, _rnd, _seed):
    """Sacred command: run incremental training for the configured experiment.

    ``_run``, ``_rnd`` and ``_seed`` are supplied by sacred. The total
    wall-clock training time (whole seconds) is logged at the end.
    """
    cfg, ex.logger, tensorboard = initialization(_run.config, _seed, "train", _run._id)
    ex.logger.info(cfg)
    cfg.data_folder = osp.join(base_dir, "data")

    t_begin = time.time()
    _train(cfg, _run, ex, tensorboard)
    elapsed_s = int(time.time() - t_begin)
    ex.logger.info("Training finished in {}s.".format(elapsed_s))
66 |
67 |
def _train(cfg, _run, ex, tensorboard):
    """Train the incremental model task by task and log per-task accuracy.

    For every task produced by the incremental dataset, the model is either
    pretrained (step 0 with ``start_class > 0``), restored from
    ``./<cfg.exp.saveckpt>/step<i>.ckpt`` (steps before ``start_task``), or
    trained. It is then evaluated on the test loader of that task. Accuracies
    are written to TensorBoard, sacred metrics and the logger. The average
    incremental accuracy is stored in ``_run.info``, and results are saved
    when the experiment has a name.

    Args:
        cfg: experiment configuration (EasyDict).
        _run: sacred run object.
        ex: sacred experiment; its ``logger`` is used for messages.
        tensorboard: TensorBoard writer handed to the model factory.
    """
    device = factory.set_device(cfg)
    trial_i = cfg['trial']

    inc_dataset = factory.get_data(cfg, trial_i)
    ex.logger.info("classes_order")
    ex.logger.info(inc_dataset.class_order)

    model = factory.get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset)

    results = results_utils.get_template_results(cfg)

    for task_i in range(inc_dataset.n_tasks):
        task_info, train_loader, val_loader, test_loader = inc_dataset.new_task()

        model.set_task_info(
            task=task_info["task"],
            total_n_classes=task_info["max_class"],
            increment=task_info["increment"],
            n_train_data=task_info["n_train_data"],
            n_test_data=task_info["n_test_data"],
            n_tasks=inc_dataset.n_tasks,
        )

        model.before_task(task_i, inc_dataset)
        # TODO: Move to incmodel.py
        if 'min_class' in task_info:
            ex.logger.info("Train on {}->{}.".format(task_info["min_class"], task_info["max_class"]))

        if task_i == 0 and cfg["start_class"] > 0:
            # Step 0: pretrain, or load a cached pretrained model (see do_pretrain).
            do_pretrain(cfg, ex, model, device, train_loader, test_loader)
            inc_dataset.shared_data_inc = train_loader.dataset.share_memory
        elif task_i < cfg['start_task']:
            # Resume: reuse the checkpoint of an already-trained step instead of retraining.
            state_dict = torch.load(f'./{cfg.exp.saveckpt}/step{task_i}.ckpt')
            model._parallel_network.load_state_dict(state_dict)
            inc_dataset.shared_data_inc = train_loader.dataset.share_memory
        else:
            model.train_task(train_loader, val_loader)
        model.after_task(task_i, inc_dataset)

        ex.logger.info("Eval on {}->{}.".format(0, task_info["max_class"]))
        ypred, ytrue = model.eval_task(test_loader)
        acc_stats = utils.compute_accuracy(ypred, ytrue, increments=model._increments, n_classes=model._n_classes)

        # Logging
        model._tensorboard.add_scalar(f"taskaccu/trial{trial_i}", acc_stats["top1"]["total"], task_i)

        _run.log_scalar(f"trial{trial_i}_taskaccu", acc_stats["top1"]["total"], task_i)
        _run.log_scalar(f"trial{trial_i}_task_top5_accu", acc_stats["top5"]["total"], task_i)

        ex.logger.info(f"top1:{acc_stats['top1']}")
        ex.logger.info(f"top5:{acc_stats['top5']}")

        results["results"].append(acc_stats)

    top1_avg_acc, top5_avg_acc = results_utils.compute_avg_inc_acc(results["results"])

    # setdefault: do not assume some other component already created the
    # per-trial dict in _run.info (indexing it directly would raise KeyError).
    trial_info = _run.info.setdefault(f"trial{trial_i}", {})
    trial_info["avg_incremental_accu_top1"] = top1_avg_acc
    trial_info["avg_incremental_accu_top5"] = top5_avg_acc
    ex.logger.info("Average Incremental Accuracy Top 1: {} Top 5: {}.".format(
        top1_avg_acc,
        top5_avg_acc,
    ))
    if cfg["exp"]["name"]:
        results_utils.save_results(results, cfg["exp"]["name"])
142 |
143 |
def do_pretrain(cfg, ex, model, device, train_loader, test_loader):
    """Pretrain the model on the initial classes, reusing a cached checkpoint.

    The checkpoint lives in ``<ex.base_dir>/pretrain/``. Its file name encodes
    the configuration values that identify the pretraining run. If that file
    already exists, its state dict is loaded into ``model._network``, or into
    ``model._network.module`` when such an attribute exists (e.g. a
    DataParallel wrapper). Otherwise ``pretrain`` is called with the
    checkpoint path (presumably it writes the checkpoint there).

    Args:
        cfg: experiment configuration; the keys used in the file name must exist.
        ex: sacred experiment; ``ex.base_dir`` is the root for the cache.
        model: incremental model whose ``_network`` receives the weights.
        device, train_loader, test_loader: forwarded unchanged to ``pretrain``.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(osp.join(ex.base_dir, 'pretrain/'), exist_ok=True)
    model_path = osp.join(
        ex.base_dir,
        "pretrain/{}_{}_cosine_{}_dynamic_{}_nplus1_{}_{}_trial_{}_{}_seed_{}_start_{}_epoch_{}.pth".format(
            cfg["model"],
            cfg["convnet"],
            cfg["weight_normalization"],
            cfg["dea"],
            cfg["div_type"],
            cfg["dataset"],
            cfg["trial"],
            cfg["train_head"],
            cfg['seed'],
            cfg["start_class"],
            cfg["pretrain"]["epochs"],
        ),
    )
    if osp.exists(model_path):
        print("Load pretrain model")
        # Unwrap a `.module` attribute when present so the keys match the bare network.
        network = getattr(model._network, "module", model._network)
        network.load_state_dict(torch.load(model_path))
    else:
        pretrain(cfg, ex, model, device, train_loader, test_loader, model_path)
171 |
@ex.command
def test(_run, _rnd, _seed):
    """Sacred command: evaluate the saved per-step checkpoints task by task.

    For each incremental step the matching ``step<i>.ckpt`` under
    ``cfg.exp.saveckpt`` is loaded, exemplars are rebuilt via ``after_task``,
    and the model is evaluated. Per-step accuracy and the old/new error
    breakdown are logged, followed by the average incremental accuracy.
    """
    cfg, ex.logger, tensorboard = initialization(_run.config, _seed, "test", _run._id)
    ex.logger.info(cfg)

    trial_i = cfg['trial']
    cfg.data_folder = osp.join(base_dir, "data")
    inc_dataset = factory.get_data(cfg, trial_i)

    ex.logger.info("classes_order")
    ex.logger.info(inc_dataset.class_order)

    model = factory.get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset)
    model._network.task_size = cfg.increment

    test_results = results_utils.get_template_results(cfg)

    for step in range(inc_dataset.n_tasks):
        task_info, _, _, test_loader = inc_dataset.new_task()
        model.set_task_info(
            task=task_info["task"],
            total_n_classes=task_info["max_class"],
            increment=task_info["increment"],
            n_train_data=task_info["n_train_data"],
            n_test_data=task_info["n_test_data"],
            n_tasks=task_info["max_task"],
        )
        model.before_task(step, inc_dataset)

        # Restore this step's weights; strict key matching is relaxed when the
        # "caculate_params" (sic) option is enabled.
        checkpoint = torch.load(f'./{cfg.exp.saveckpt}/step{step}.ckpt')
        strict_load = not cfg.get("caculate_params", False)
        model._parallel_network.load_state_dict(checkpoint, strict=strict_load)

        model.eval()

        # Build exemplars for this step.
        model.after_task(step, inc_dataset)

        ypred, ytrue = model.eval_task(test_loader)

        acc = utils.compute_accuracy(ypred, ytrue, increments=model._increments, n_classes=model._n_classes)
        mix = utils.compute_old_new_mix(
            ypred,
            ytrue,
            increments=model._increments,
            n_classes=model._n_classes,
            task_order=inc_dataset.class_order,
        )

        test_results['results'].append(acc)
        ex.logger.info(f"task{step} test acc:{acc['top1']}")
        ex.logger.info(
            f"task{step} all task mean:{mix['task_mean']} \n"
            f" task means: {mix['task_means']} \n"
            f" new err:{mix['new_err']} \n"
            f"old err:{mix['old_err']}\n"
            f"new_old_err:{mix['new_old_err']} \n"
            f"old_new_err:{mix['old_new_err']} \n"
            f"err_among_task:{mix['err_among_task']} \n"
            f"err_inner_task:{mix['err_inner_task']}"
        )

    avg_test_acc = results_utils.compute_avg_inc_acc(test_results['results'])
    ex.logger.info(f"Test Average Incremental Accuracy: {avg_test_acc}")
228 |
229 |
def _prune_convnet(net, taski, mod='fpgm', input_size=(3, 32, 32), compression_ratio=0.4):
    """Prune ``net`` in place with the project's Autoslim engine and print its complexity.

    FLOPs and parameter counts are printed before and after pruning.

    Args:
        net: convnet to prune (modified in place by the pruning engine).
        taski: index of the current incremental step (used only in messages).
        mod: pruning criterion, ``'fpgm'`` or ``'l1'``.
        input_size: ``(C, H, W)`` of a single input sample.
        compression_ratio: overall compression ratio handed to Autoslim.

    Raises:
        ValueError: if ``mod`` is not a supported criterion.
    """
    configs = {
        'fpgm': {
            'layer_compression_ratio': None,
            'norm_rate': 1.0, 'prune_shortcut': 1,
            'dist_type': 'l1', 'pruning_func': 'fpgm'
        },
        'l1': {
            'layer_compression_ratio': None,
            'norm_rate': 1.0, 'prune_shortcut': 1,
            'global_pruning': False, 'pruning_func': 'l1'
        },
    }
    if mod not in configs:
        raise ValueError(f"Unsupported pruning criterion {mod!r}; expected one of {sorted(configs)}")

    flops_raw, params_raw = pruning.get_model_complexity_info(
        net, input_size, as_strings=True, print_per_layer_stat=False)
    print(f'-pruning step{taski} with net{taski}')
    print('-[INFO] before pruning flops: ' + flops_raw)
    print('-[INFO] before pruning params: ' + params_raw)

    # Build the pruning engine, giving it one random (1, C, H, W) example input.
    slim = pruning.Autoslim(net, inputs=torch.randn(1, *input_size), compression_ratio=compression_ratio)
    slim.base_prunging(configs[mod])

    flops_new, params_new = pruning.get_model_complexity_info(
        net, input_size, as_strings=True, print_per_layer_stat=False)
    print('-[INFO] after pruning flops: ' + flops_new)
    print('-[INFO] after pruning params: ' + params_new)


@ex.command
def prune(_run, _rnd, _seed):
    """Sacred command: prune and fine-tune the model step by step.

    For every incremental step, ``tmodel`` loads the saved ``step<i>.ckpt``
    and its newest convnet's weights are copied into ``model``. That convnet
    is then pruned, ``model`` is fine-tuned with ``train_task`` and evaluated,
    and the whole network object is saved as ``prune_step<i>.ckpt``. The
    average incremental accuracy is logged and stored in ``_run.info``.
    """
    cfg, ex.logger, tensorboard = initialization(_run.config, _seed, "prune", _run._id)

    trial_i = cfg['trial']
    cfg.data_folder = osp.join(base_dir, "data")
    inc_dataset = factory.get_data(cfg, trial_i)

    # `model` is pruned and fine-tuned; `tmodel` only serves to load the saved checkpoints.
    model = factory.get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset)
    tmodel = factory.get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset)

    model._network.task_size = cfg.increment
    tmodel._network.task_size = cfg.increment

    test_results = results_utils.get_template_results(cfg)

    for taski in range(inc_dataset.n_tasks):
        # Prints "pruning step<taski>" (message text is in Chinese).
        print(f"--------------对step{taski}进行剪枝--------------")

        task_info, train_loader, val_loader, test_loader = inc_dataset.new_task()

        task_kwargs = dict(
            task=task_info["task"],
            total_n_classes=task_info["max_class"],
            increment=task_info["increment"],
            n_train_data=task_info["n_train_data"],
            n_test_data=task_info["n_test_data"],
            n_tasks=inc_dataset.n_tasks,
        )
        model.set_task_info(**task_kwargs)
        model.before_task(taski, inc_dataset)
        tmodel.set_task_info(**task_kwargs)
        tmodel.before_task(taski, inc_dataset)

        state_dict = torch.load(f'./{cfg.exp.saveckpt}/step{taski}.ckpt')
        tmodel._parallel_network.load_state_dict(state_dict)
        # Only the newest convnet of `model` is overwritten with checkpoint weights.
        model._parallel_network.module.convnets[-1].load_state_dict(
            tmodel._parallel_network.module.convnets[-1].state_dict())

        net = model._parallel_network.module.convnets[-1]
        _prune_convnet(net, taski, mod='fpgm')

        model.after_prune(taski, inc_dataset)
        model.set_optimizer()
        model.train_task(train_loader, val_loader)
        model.eval()
        model.after_task(taski, inc_dataset)

        model._parallel_network = model._parallel_network.cuda()
        ypred, ytrue = model.eval_task(test_loader)

        test_acc_stats = utils.compute_accuracy(ypred, ytrue, increments=model._increments, n_classes=model._n_classes)

        test_results['results'].append(test_acc_stats)
        ex.logger.info(f"task{taski} test acc:{test_acc_stats['top1']}")
        ex.logger.info(f"top1:{test_acc_stats['top1']}")
        ex.logger.info(f"top5:{test_acc_stats['top5']}")

        save_path = os.path.join(os.getcwd(), f"{cfg.exp.saveckpt}")
        # Save the entire network object (structure and parameters), presumably
        # because pruning changed layer shapes. NOTE: nn.Module.cpu() moves the
        # module in place, so the live network is left on the CPU after this call.
        torch.save(model._parallel_network.cpu(), "{}/prune_step{}.ckpt".format(save_path, taski))

    top1_avg_acc, top5_avg_acc = results_utils.compute_avg_inc_acc(test_results["results"])

    # setdefault: do not assume the per-trial dict already exists in _run.info.
    trial_info = _run.info.setdefault(f"trial{trial_i}", {})
    trial_info["avg_incremental_accu_top1"] = top1_avg_acc
    trial_info["avg_incremental_accu_top5"] = top5_avg_acc
    ex.logger.info("Average Incremental Accuracy Top 1: {} Top 5: {}.".format(
        top1_avg_acc,
        top5_avg_acc,
    ))
350 |
351 |
if __name__ == "__main__":
    # Register the default experiment configuration (path relative to the
    # working directory), then dispatch to the sacred command named on the
    # command line.
    default_config_path = "./configs/cifar_b0_10s.yaml"
    ex.add_config(default_config_path)
    ex.run_commandline()
355 |
--------------------------------------------------------------------------------