├── inclearn ├── __init__.py ├── convnet │ ├── __init__.py │ ├── __pycache__ │ │ ├── utils.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── network.cpython-38.pyc │ │ ├── resnet.cpython-38.pyc │ │ ├── classifier.cpython-38.pyc │ │ ├── imbalance.cpython-38.pyc │ │ ├── cifar_resnet.cpython-38.pyc │ │ ├── preact_resnet.cpython-38.pyc │ │ └── modified_resnet_cifar.cpython-38.pyc │ ├── classifier.py │ ├── imbalance.py │ ├── modified_resnet_cifar.py │ ├── preact_resnet.py │ ├── cifar_resnet.py │ ├── utils.py │ ├── network.py │ └── resnet.py ├── learn │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── pretrain.cpython-38.pyc │ └── pretrain.py ├── loss │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── focalloss.cpython-38.pyc │ │ └── tripeloss.cpython-38.pyc │ ├── focalloss.py │ └── tripeloss.py ├── tools │ ├── __init__.py │ ├── __pycache__ │ │ ├── utils.cpython-38.pyc │ │ ├── cutout.cpython-38.pyc │ │ ├── factory.cpython-38.pyc │ │ ├── memory.cpython-38.pyc │ │ ├── metrics.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── data_utils.cpython-38.pyc │ │ ├── scheduler.cpython-38.pyc │ │ ├── results_utils.cpython-38.pyc │ │ └── autoaugment_extra.cpython-38.pyc │ ├── data_utils.py │ ├── cutout.py │ ├── results_utils.py │ ├── factory.py │ ├── scheduler.py │ ├── memory.py │ ├── metrics.py │ └── utils.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── data.cpython-38.pyc │ │ ├── dataset.cpython-38.pyc │ │ └── __init__.cpython-38.pyc │ └── dataset.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── base.cpython-38.pyc │ │ ├── align.cpython-38.pyc │ │ ├── prune.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ └── incmodel.cpython-38.pyc │ ├── base.py │ └── prune.py ├── .DS_Store ├── prune │ ├── prune │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── structured.cpython-38.pyc │ │ │ └── unstructured.cpython-38.pyc │ │ ├── unstructured.py │ │ └── structured.py │ ├── 
__pycache__ │ │ ├── utils.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── autoslim.cpython-38.pyc │ │ ├── dependency.cpython-38.pyc │ │ └── flops_counter.cpython-38.pyc │ ├── __init__.py │ ├── utils.py │ ├── resnet_small.py │ ├── autoslim_test.py │ ├── autoslim.py │ └── sensitivity_analysis.py └── __pycache__ │ └── __init__.cpython-38.pyc ├── pictures ├── TCIL.png ├── .DS_Store ├── cifar_mem.png ├── non_mem.png └── imagenet_mem.png ├── requirements.txt ├── scripts ├── run.sh ├── prune.sh └── inference.sh ├── configs ├── cifar_b0_10s.yaml ├── cifar_b0_20s.yaml ├── cifar_b0_5s.yaml ├── cifar_b50_10s.yaml ├── cifar_b50_2s.yaml ├── cifar_b50_5s.yaml └── imagenet_b0_10s.yaml ├── README.md └── main.py /inclearn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inclearn/convnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inclearn/learn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inclearn/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inclearn/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inclearn/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inclearn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.incmodel import IncModel 2 | -------------------------------------------------------------------------------- /pictures/TCIL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/pictures/TCIL.png -------------------------------------------------------------------------------- /inclearn/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/.DS_Store -------------------------------------------------------------------------------- /inclearn/prune/prune/__init__.py: -------------------------------------------------------------------------------- 1 | from .structured import * 2 | from .unstructured import * -------------------------------------------------------------------------------- /pictures/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/pictures/.DS_Store -------------------------------------------------------------------------------- /pictures/cifar_mem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/pictures/cifar_mem.png -------------------------------------------------------------------------------- /pictures/non_mem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/pictures/non_mem.png -------------------------------------------------------------------------------- /pictures/imagenet_mem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/pictures/imagenet_mem.png -------------------------------------------------------------------------------- 
/inclearn/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/models/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/models/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/prune/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/tools/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/convnet/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/datasets/__pycache__/data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/datasets/__pycache__/data.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/loss/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/loss/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/models/__pycache__/align.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/models/__pycache__/align.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/models/__pycache__/prune.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/models/__pycache__/prune.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/tools/__pycache__/cutout.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/cutout.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/tools/__pycache__/factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/factory.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/tools/__pycache__/memory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/memory.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/tools/__pycache__/metrics.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/convnet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/convnet/__pycache__/network.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/network.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/convnet/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/datasets/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/datasets/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/learn/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/learn/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/learn/__pycache__/pretrain.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/learn/__pycache__/pretrain.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/loss/__pycache__/focalloss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/loss/__pycache__/focalloss.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/loss/__pycache__/tripeloss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/loss/__pycache__/tripeloss.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/models/__pycache__/incmodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/models/__pycache__/incmodel.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/prune/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/prune/__pycache__/autoslim.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/__pycache__/autoslim.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/prune/__pycache__/dependency.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/__pycache__/dependency.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/tools/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/tools/__pycache__/data_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/data_utils.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/tools/__pycache__/scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/convnet/__pycache__/classifier.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/classifier.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/convnet/__pycache__/imbalance.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/imbalance.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/convnet/__pycache__/cifar_resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/cifar_resnet.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/convnet/__pycache__/preact_resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/preact_resnet.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/prune/__pycache__/flops_counter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/__pycache__/flops_counter.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/prune/prune/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/prune/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
/inclearn/tools/__pycache__/results_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/results_utils.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/prune/prune/__pycache__/structured.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/prune/__pycache__/structured.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/tools/__pycache__/autoaugment_extra.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/tools/__pycache__/autoaugment_extra.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/prune/prune/__pycache__/unstructured.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/prune/prune/__pycache__/unstructured.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/convnet/__pycache__/modified_resnet_cifar.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellowPancake/TCIL/HEAD/inclearn/convnet/__pycache__/modified_resnet_cifar.cpython-38.pyc -------------------------------------------------------------------------------- /inclearn/prune/__init__.py: -------------------------------------------------------------------------------- 1 | from .dependency import * 2 | from .prune import * 3 | from .autoslim import * 4 | from . 
import utils 5 | from .flops_counter import get_model_complexity_info 6 | import warnings 7 | 8 | warnings.filterwarnings('ignore') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.1.0 2 | easydict==1.9 3 | matplotlib==3.5.1 4 | nni==2.10 5 | numpy==1.22.4 6 | opencv_python==4.5.5.62 7 | Pillow==9.3.0 8 | sacred==0.8.2 9 | scikit_learn==1.1.3 10 | scipy==1.9.3 11 | tensorboardX==2.5.1 12 | thop==0.0.31.post2005241907 13 | torch==1.8.1+cu111 14 | torchvision==0.9.1+cu111 15 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | name='cifar_b50_2s' 3 | expid='cifar_b50_2s' 4 | 5 | 6 | python -m main train with "./configs/${expid}.yaml" \ 7 | exp.name="${name}" \ 8 | exp.savedir="./logs/" \ 9 | exp.saveckpt="./ckpts_${expid}/" \ 10 | exp.ckptdir="./logs/" \ 11 | exp.tensorboard_dir="./tensorboard/" \ 12 | exp.debug=True \ 13 | --name="${name}" \ 14 | -D \ 15 | -p \ 16 | --force \ 17 | -------------------------------------------------------------------------------- /scripts/prune.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | name='cifar_b0_10s' 3 | expid='cifar_b0_10s' 4 | 5 | 6 | python -m main prune with "./configs/${expid}.yaml" \ 7 | exp.name="${name}" \ 8 | exp.savedir="./logs/" \ 9 | exp.ckptdir="./logs/" \ 10 | exp.saveckpt="./ckpts_${expid}/" \ 11 | exp.tensorboard_dir="./tensorboard/" \ 12 | exp.debug=True \ 13 | load_mem=True \ 14 | --name="${name}" \ 15 | -D \ 16 | -p \ 17 | --force \ 18 | -------------------------------------------------------------------------------- /scripts/inference.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
bash 2 | name='Inference_cifar_b0_10s' 3 | expid='cifar_b0_10s' 4 | 5 | 6 | python -m main test with "./configs/${expid}.yaml" \ 7 | exp.name="${name}" \ 8 | exp.savedir="./logs/" \ 9 | exp.ckptdir="./logs/" \ 10 | exp.saveckpt="./ckpts_${expid}/" \ 11 | exp.tensorboard_dir="./tensorboard/" \ 12 | exp.debug=True \ 13 | load_mem=True \ 14 | --name="${name}" \ 15 | -D \ 16 | -p \ 17 | --force \ -------------------------------------------------------------------------------- /inclearn/tools/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def construct_balanced_subset(x, y): 5 | xdata, ydata = [], [] 6 | minsize = np.inf 7 | for cls_ in np.unique(y): 8 | xdata.append(x[y == cls_]) 9 | ydata.append(y[y == cls_]) 10 | if ydata[-1].shape[0] < minsize: 11 | minsize = ydata[-1].shape[0] 12 | for i in range(len(xdata)): 13 | if xdata[i].shape[0] < minsize: 14 | import pdb 15 | pdb.set_trace() 16 | idx = np.arange(xdata[i].shape[0]) 17 | np.random.shuffle(idx) 18 | xdata[i] = xdata[i][idx][:minsize] 19 | ydata[i] = ydata[i][idx][:minsize] 20 | # !list 21 | return np.concatenate(xdata, 0), np.concatenate(ydata, 0) -------------------------------------------------------------------------------- /inclearn/convnet/classifier.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn.parameter import Parameter 5 | from torch.nn import functional as F 6 | from torch.nn import Module 7 | 8 | 9 | class CosineClassifier(Module): 10 | def __init__(self, in_features, n_classes, sigma=True): 11 | super(CosineClassifier, self).__init__() 12 | self.in_features = in_features 13 | self.out_features = n_classes 14 | self.weight = Parameter(torch.Tensor(n_classes, in_features)) 15 | if sigma: 16 | self.sigma = Parameter(torch.Tensor(1)) 17 | else: 18 | self.register_parameter('sigma', None) 19 | self.reset_parameters() 20 | 21 | def 
reset_parameters(self): 22 | stdv = 1. / math.sqrt(self.weight.size(1)) 23 | self.weight.data.uniform_(-stdv, stdv) 24 | if self.sigma is not None: 25 | self.sigma.data.fill_(1) #for initializaiton of sigma 26 | 27 | def forward(self, input): 28 | out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1)) 29 | if self.sigma is not None: 30 | out = self.sigma * out 31 | return out 32 | -------------------------------------------------------------------------------- /inclearn/tools/cutout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class Cutout(object): 5 | """Randomly mask out one or more patches from an image. 6 | Args: 7 | n_holes (int): Number of patches to cut out of each image. 8 | length (int): The length (in pixels) of each square patch. 9 | """ 10 | def __init__(self, n_holes, length): 11 | self.n_holes = n_holes 12 | self.length = length 13 | 14 | def __call__(self, img): 15 | """ 16 | Args: 17 | img (Tensor): Tensor image of size (C, H, W). 18 | Returns: 19 | Tensor: Image with n_holes of dimension length x length cut out of it. 20 | """ 21 | h = img.size(1) 22 | w = img.size(2) 23 | 24 | mask = np.ones((h, w), np.float32) 25 | 26 | for n in range(self.n_holes): 27 | y = np.random.randint(h) 28 | x = np.random.randint(w) 29 | 30 | y1 = np.clip(y - self.length // 2, 0, h) 31 | y2 = np.clip(y + self.length // 2, 0, h) 32 | x1 = np.clip(x - self.length // 2, 0, w) 33 | x2 = np.clip(x + self.length // 2, 0, w) 34 | 35 | mask[y1: y2, x1: x2] = 0. 
36 | 37 | mask = torch.from_numpy(mask) 38 | mask = mask.expand_as(img) 39 | img = img * mask 40 | 41 | return img -------------------------------------------------------------------------------- /inclearn/prune/utils.py: -------------------------------------------------------------------------------- 1 | from .dependency import TORCH_CONV, TORCH_BATCHNORM, TORCH_PRELU, TORCH_LINEAR 2 | 3 | def count_prunable_params(module): 4 | if isinstance( module, ( TORCH_CONV, TORCH_LINEAR) ): 5 | num_params = module.weight.numel() 6 | if module.bias is not None: 7 | num_params += module.bias.numel() 8 | return num_params 9 | elif isinstance( module, TORCH_BATCHNORM ): 10 | num_params = module.running_mean.numel() + module.running_var.numel() 11 | if module.affine: 12 | num_params+= module.weight.numel() + module.bias.numel() 13 | return num_params 14 | elif isinstance( module, TORCH_PRELU ): 15 | if len( module.weight )==1: 16 | return 0 17 | else: 18 | return module.weight.numel 19 | else: 20 | return 0 21 | 22 | def count_prunable_channels(module): 23 | if isinstance( module, TORCH_CONV ): 24 | return module.weight.shape[0] 25 | elif isinstance( module, TORCH_LINEAR ): 26 | return module.out_features 27 | elif isinstance( module, TORCH_BATCHNORM ): 28 | return module.num_features 29 | elif isinstance( module, TORCH_PRELU ): 30 | if len( module.weight )==1: 31 | return 0 32 | else: 33 | return len(module.weight) 34 | else: 35 | return 0 36 | 37 | def count_params(module): 38 | return sum([ p.numel() for p in module.parameters() ]) 39 | -------------------------------------------------------------------------------- /inclearn/tools/results_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import math 4 | import os 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from copy import deepcopy 8 | 9 | from . 
import utils 10 | 11 | 12 | def get_template_results(cfg): 13 | return {"config": cfg, "results": []} 14 | 15 | 16 | def save_results(results, label): 17 | del results["config"]["device"] 18 | 19 | folder_path = os.path.join("results", "{}_{}".format(utils.get_date(), label)) 20 | if not os.path.exists(folder_path): 21 | os.makedirs(folder_path) 22 | 23 | file_path = "{}_{}_.json".format(utils.get_date(), results["config"]["seed"]) 24 | with open(os.path.join(folder_path, file_path), "w+") as f: 25 | json.dump(results, f, indent=2) 26 | 27 | 28 | def compute_avg_inc_acc(results): 29 | """Computes the average incremental accuracy as defined in iCaRL. 30 | 31 | The average incremental accuracies at task X are the average of accuracies 32 | at task 0, 1, ..., and X. 33 | 34 | :param accs: A list of dict for per-class accuracy at each step. 35 | :return: A float. 36 | """ 37 | top1_tasks_accuracy = [r['top1']["total"] for r in results] 38 | top1acc = sum(top1_tasks_accuracy) / len(top1_tasks_accuracy) 39 | if "top5" in results[0].keys(): 40 | top5_tasks_accuracy = [r['top5']["total"] for r in results] 41 | top5acc = sum(top5_tasks_accuracy) / len(top5_tasks_accuracy) 42 | else: 43 | top5acc = None 44 | return top1acc, top5acc -------------------------------------------------------------------------------- /inclearn/convnet/imbalance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | import numpy as np 5 | from torch.optim.lr_scheduler import CosineAnnealingLR 6 | 7 | 8 | 9 | class CR(object): 10 | def __init__(self): 11 | self.gamma = None 12 | 13 | @torch.no_grad() 14 | def update(self, classifier, task_size): 15 | old_weight_norm = torch.norm(classifier.weight[:-task_size], p=2, dim=1) 16 | new_weight_norm = torch.norm(classifier.weight[-task_size:], p=2, dim=1) 17 | self.gamma = old_weight_norm.mean() / new_weight_norm.mean() 18 | # 
print(self.gamma.cpu().item()) 19 | 20 | @torch.no_grad() 21 | def post_process(self, logits, task_size): # scale the logits of the last task_size classes by gamma (modifies logits in place) 22 | logits[:, -task_size:] = logits[:, -task_size:] * self.gamma 23 | return logits 24 | 25 | 26 | class All_av(object): 27 | def __init__(self): 28 | self.gamma = [] 29 | 30 | @torch.no_grad() 31 | def update(self, classifier, task_size, classnum_list, taski): # gamma[i] = mean row norm of all but the last task_size rows / mean row norm of task i's rows 32 | self.gamma = [] 33 | for i in range(taski+1): 34 | old_weight_norm = torch.norm(classifier.weight[:-task_size], p=2, dim=1) 35 | new_weight_norm = torch.norm(classifier.weight[sum(classnum_list[:i]):sum(classnum_list[:i+1])], p=2, dim=1) 36 | self.gamma.append(old_weight_norm.mean() / new_weight_norm.mean()) 37 | # print(self.gamma) 38 | 39 | @torch.no_grad() 40 | def post_process(self, logits, task_size, classnum_list, taski): # scale each task's logit slice by its gamma (modifies logits in place) 41 | for i in range(taski+1): 42 | logits[:, sum(classnum_list[:i]):sum(classnum_list[:i+1])] = logits[:, sum(classnum_list[:i]):sum(classnum_list[:i+1])] * self.gamma[i] 43 | return logits 44 | -------------------------------------------------------------------------------- /inclearn/prune/prune/unstructured.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from copy import deepcopy 4 | 5 | __all__=['mask_weight', 'mask_bias'] 6 | 7 | def _mask_weight_hook(module, input): 8 | if hasattr(module, 'weight_mask'): 9 | module.weight.data *= module.weight_mask 10 | 11 | def _mask_bias_hook(module, input): 12 | if module.bias is not None and hasattr(module, 'bias_mask'): 13 | module.bias.data *= module.bias_mask 14 | 15 | def mask_weight(layer, mask, inplace=True): 16 | """Unstructured pruning of a layer's weight with a 0-1 mask. 17 | 18 | Args: 19 | layer: a layer with a ``weight`` (e.g. a convolution layer). 20 | mask: 0-1 mask shaped like ``layer.weight`` (any other shape leaves the layer unchanged); 0 entries zero the matching weights before every forward pass.
21 | """ 22 | if not inplace: 23 | layer = deepcopy(layer) 24 | if mask.shape != layer.weight.shape: 25 | return layer 26 | mask = torch.tensor( mask, dtype=layer.weight.dtype, device=layer.weight.device, requires_grad=False ) 27 | if hasattr(layer, 'weight_mask'): 28 | mask = mask + layer.weight_mask 29 | mask[mask>0]=1 30 | layer.weight_mask = mask 31 | else: 32 | layer.register_buffer( 'weight_mask', mask ) 33 | 34 | layer.register_forward_pre_hook( _mask_weight_hook ) 35 | return layer 36 | 37 | def mask_bias(layer, mask, inplace=True): 38 | """Unstructed pruning for convolution layer 39 | 40 | Args: 41 | layer: a convolution layer. 42 | mask: 0-1 mask. 43 | """ 44 | if not inplace: 45 | layer = deepcopy(layer) 46 | if layer.bias is None or mask.shape != layer.bias.shape: 47 | return layer 48 | 49 | mask = torch.tensor( mask, dtype=layer.weight.dtype, device=layer.weight.device, requires_grad=False ) 50 | if hasattr(layer, 'bias_mask'): 51 | mask = mask + layer.bias_mask 52 | mask[mask>0]=1 53 | layer.bias_mask = mask 54 | else: 55 | layer.register_buffer( 'bias_mask', mask ) 56 | layer.register_forward_pre_hook( _mask_bias_hook ) 57 | return layer 58 | -------------------------------------------------------------------------------- /inclearn/loss/focalloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from torch.nn import functional as F 5 | 6 | class BCEFocalLoss(torch.nn.Module): 7 | def __init__(self, gamma=2, alpha=0.25, reduction='mean'): 8 | super(BCEFocalLoss, self).__init__() 9 | self.gamma = gamma 10 | self.alpha = alpha 11 | self.reduction = reduction 12 | 13 | def forward(self, predict, target): 14 | pt = torch.sigmoid(predict) # sigmoide获取概率 15 | #在原始ce上增加动态权重因子,注意alpha的写法,下面多类时不能这样使用 16 | loss = - self.alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - (1 - self.alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - 
pt) 17 | if self.reduction == 'mean': 18 | loss = torch.mean(loss) 19 | elif self.reduction == 'sum': 20 | loss = torch.sum(loss) 21 | return loss 22 | 23 | class MultiCEFocalLoss(torch.nn.Module): 24 | def __init__(self, class_num, gamma=2, alpha=None, reduction='mean'): 25 | super(MultiCEFocalLoss, self).__init__() 26 | if alpha is None: 27 | self.alpha = Variable(torch.ones(class_num, 1)) 28 | else: 29 | self.alpha = alpha 30 | self.gamma = gamma 31 | self.reduction = reduction 32 | self.class_num = class_num 33 | 34 | def forward(self, predict, target): 35 | pt = F.softmax(predict, dim=1) # softmmax获取预测概率 36 | class_mask = F.one_hot(target, self.class_num) #获取target的one hot编码 37 | ids = target.view(-1, 1) 38 | # print(ids.shape, self.alpha.shape) 39 | alpha = self.alpha[ids.data.view(-1)].cuda() # 注意,这里的alpha是给定的一个list(tensor 40 | #),里面的元素分别是每一个类的权重因子 41 | probs = (pt * class_mask).sum(1).view(-1, 1) # 利用onehot作为mask,提取对应的pt 42 | log_p = probs.log() 43 | # 同样,原始ce上增加一个动态权重衰减因子 44 | # print(probs.shape, alpha.shape) 45 | loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p 46 | 47 | if self.reduction == 'mean': 48 | loss = loss.mean() 49 | elif self.reduction == 'sum': 50 | loss = loss.sum() 51 | return loss -------------------------------------------------------------------------------- /configs/cifar_b0_10s.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: "CIFAR100_B0_10S_TCIL" 3 | savedir: "./logs" 4 | tensorboard_dir: "./tensorboard" 5 | debug: False 6 | 7 | 8 | #Model Cfg 9 | model: "incmodel" 10 | convnet: 'resnet18' 11 | train_head: 'softmax' 12 | infer_head: 'softmax' 13 | channel: 64 14 | use_bias: False 15 | last_relu: False 16 | 17 | dea: True 18 | use_div_cls: True 19 | div_type: "n+1" # n+t, 1+1 20 | distillation: True 21 | disttype: "MTKD" 22 | temperature: 2 23 | distlamb: 1 24 | feature_type: "ffm" # se 25 | attention_use_residual: True 26 | ignore_new: True 27 | 28 | prune: False 
29 | 30 | attention: 31 | add_kl: True 32 | kd_warm_up: 50 33 | kd_loss_weight: 0.5 34 | kl_loss_weight: 0.5 35 | 36 | reuse_oldfc: False 37 | weight_normalization: False 38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off. 39 | save_ckpt: True 40 | save_mem: True 41 | load_mem: False 42 | 43 | #Optimization;Training related 44 | task_max: 10 45 | lr_min: 0.00005 46 | lr: 0.1 47 | weight_decay: 0.0005 48 | dynamic_weight_decay: False 49 | scheduler: 'multistep' 50 | scheduling: 51 | - 100 52 | - 120 53 | lr_decay: 0.1 54 | optimizer: "sgd" 55 | epochs: 170 56 | resampling: False 57 | warmup: True 58 | warmup_epochs: 10 59 | 60 | postprocessor: 61 | enable: True 62 | type: 'cr' 63 | 64 | pretrain: 65 | epochs: 200 66 | lr: 0.1 67 | scheduling: 68 | - 60 69 | - 120 70 | - 160 71 | lr_decay: 0.1 72 | weight_decay: 0.0005 73 | 74 | 75 | # Dataset Cfg 76 | dataset: "cifar100" #'imagenet100', 'cifar100' 77 | trial: 2 78 | increment: 10 79 | batch_size: 128 80 | workers: 4 81 | validation: 0 # Validation split (0. <= x <= 1.) 82 | random_classes: False #Randomize classes order of increment 83 | start_class: 0 # number of tasks for the first step, start from 0. 
84 | start_task: 0 85 | max_task: # Cap the number of task 86 | 87 | #Memory 88 | coreset_strategy: "iCaRL" # iCaRL, random 89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem 90 | memory_size: 2000 # Max number of storable examplars 91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls 92 | 93 | # Misc 94 | device: 0 #GPU index to use, for cpu use -1 95 | seed: 1993 96 | -------------------------------------------------------------------------------- /configs/cifar_b0_20s.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: "CIFAR100_B0_20S_TCIL" 3 | savedir: "./logs" 4 | tensorboard_dir: "./tensorboard" 5 | debug: False 6 | 7 | 8 | #Model Cfg 9 | model: "incmodel" 10 | convnet: 'resnet18' 11 | train_head: 'softmax' 12 | infer_head: 'softmax' 13 | channel: 64 14 | use_bias: False 15 | last_relu: False 16 | 17 | dea: True 18 | use_div_cls: True 19 | div_type: "n+1" # n+t, 1+1 20 | distillation: True 21 | disttype: "MTKD" 22 | temperature: 2 23 | distlamb: 1 24 | feature_type: "ffm" # se 25 | attention_use_residual: True 26 | ignore_new: True 27 | 28 | prune: False 29 | 30 | attention: 31 | add_kl: True 32 | kd_warm_up: 50 33 | kd_loss_weight: 0.5 34 | kl_loss_weight: 0.5 35 | 36 | reuse_oldfc: False 37 | weight_normalization: False 38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off. 
39 | save_ckpt: True 40 | save_mem: True 41 | load_mem: False 42 | 43 | #Optimization;Training related 44 | task_max: 10 45 | lr_min: 0.00005 46 | lr: 0.1 47 | weight_decay: 0.0005 48 | dynamic_weight_decay: False 49 | scheduler: 'multistep' 50 | scheduling: 51 | - 100 52 | - 120 53 | lr_decay: 0.1 54 | optimizer: "sgd" 55 | epochs: 170 56 | resampling: False 57 | warmup: True 58 | warmup_epochs: 10 59 | 60 | postprocessor: 61 | enable: True 62 | type: 'cr' 63 | 64 | pretrain: 65 | epochs: 200 66 | lr: 0.1 67 | scheduling: 68 | - 60 69 | - 120 70 | - 160 71 | lr_decay: 0.1 72 | weight_decay: 0.0005 73 | 74 | 75 | # Dataset Cfg 76 | dataset: "cifar100" #'imagenet100', 'cifar100' 77 | trial: 2 78 | increment: 5 79 | batch_size: 128 80 | workers: 4 81 | validation: 0 # Validation split (0. <= x <= 1.) 82 | random_classes: False #Randomize classes order of increment 83 | start_class: 0 # number of tasks for the first step, start from 0. 84 | start_task: 0 85 | max_task: # Cap the number of task 86 | 87 | #Memory 88 | coreset_strategy: "iCaRL" # iCaRL, random 89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem 90 | memory_size: 2000 # Max number of storable examplars 91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls 92 | 93 | # Misc 94 | device: 0 #GPU index to use, for cpu use -1 95 | seed: 1993 96 | -------------------------------------------------------------------------------- /configs/cifar_b0_5s.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: "CIFAR100_B0_5S_TCIL" 3 | savedir: "./logs" 4 | tensorboard_dir: "./tensorboard" 5 | debug: False 6 | 7 | 8 | #Model Cfg 9 | model: "incmodel" 10 | convnet: 'resnet18' 11 | train_head: 'softmax' 12 | infer_head: 'softmax' 13 | channel: 64 14 | use_bias: False 15 | last_relu: False 16 | 17 | dea: True 18 | use_div_cls: True 19 | div_type: "n+1" # n+t, 1+1 20 | distillation: True 21 | disttype: "MTKD" 22 | 
temperature: 2 23 | distlamb: 1 24 | feature_type: "ffm" # se 25 | attention_use_residual: True 26 | ignore_new: True 27 | 28 | prune: False 29 | 30 | attention: 31 | add_kl: True 32 | kd_warm_up: 50 33 | kd_loss_weight: 0.5 34 | kl_loss_weight: 0.5 35 | 36 | reuse_oldfc: False 37 | weight_normalization: False 38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off. 39 | save_ckpt: True 40 | save_mem: True 41 | load_mem: False 42 | 43 | #Optimization;Training related 44 | task_max: 10 45 | lr_min: 0.00005 46 | lr: 0.1 47 | weight_decay: 0.0005 48 | dynamic_weight_decay: False 49 | scheduler: 'multistep' 50 | scheduling: 51 | - 100 52 | - 120 53 | lr_decay: 0.1 54 | optimizer: "sgd" 55 | epochs: 170 56 | resampling: False 57 | warmup: True 58 | warmup_epochs: 10 59 | 60 | postprocessor: 61 | enable: True 62 | type: 'cr' 63 | 64 | pretrain: 65 | epochs: 200 66 | lr: 0.1 67 | scheduling: 68 | - 60 69 | - 120 70 | - 160 71 | lr_decay: 0.1 72 | weight_decay: 0.0005 73 | 74 | 75 | # Dataset Cfg 76 | dataset: "cifar100" #'imagenet100', 'cifar100' 77 | trial: 2 78 | increment: 20 79 | batch_size: 128 80 | workers: 4 81 | validation: 0 # Validation split (0. <= x <= 1.) 82 | random_classes: False #Randomize classes order of increment 83 | start_class: 0 # number of tasks for the first step, start from 0. 
84 | start_task: 0 85 | max_task: # Cap the number of task 86 | 87 | #Memory 88 | coreset_strategy: "iCaRL" # iCaRL, random 89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem 90 | memory_size: 2000 # Max number of storable examplars 91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls 92 | 93 | # Misc 94 | device: 0 #GPU index to use, for cpu use -1 95 | seed: 1993 96 | -------------------------------------------------------------------------------- /configs/cifar_b50_10s.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: "CIFAR100_B50_10S_TCIL" 3 | savedir: "./logs" 4 | tensorboard_dir: "./tensorboard" 5 | debug: False 6 | 7 | 8 | #Model Cfg 9 | model: "incmodel" 10 | convnet: 'resnet18' 11 | train_head: 'softmax' 12 | infer_head: 'softmax' 13 | channel: 64 14 | use_bias: False 15 | last_relu: False 16 | 17 | dea: True 18 | use_div_cls: True 19 | div_type: "n+1" # n+t, 1+1 20 | distillation: True 21 | disttype: "MTKD" 22 | temperature: 2 23 | distlamb: 1 24 | feature_type: "ffm" # se 25 | attention_use_residual: True 26 | ignore_new: True 27 | 28 | prune: False 29 | 30 | attention: 31 | add_kl: True 32 | kd_warm_up: 50 33 | kd_loss_weight: 0.5 34 | kl_loss_weight: 0.5 35 | 36 | reuse_oldfc: False 37 | weight_normalization: False 38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off. 
39 | save_ckpt: True 40 | save_mem: True 41 | load_mem: False 42 | 43 | #Optimization;Training related 44 | task_max: 10 45 | lr_min: 0.00005 46 | lr: 0.1 47 | weight_decay: 0.0005 48 | dynamic_weight_decay: False 49 | scheduler: 'multistep' 50 | scheduling: 51 | - 100 52 | - 120 53 | lr_decay: 0.1 54 | optimizer: "sgd" 55 | epochs: 170 56 | resampling: False 57 | warmup: True 58 | warmup_epochs: 10 59 | 60 | postprocessor: 61 | enable: True 62 | type: 'cr' 63 | 64 | pretrain: 65 | epochs: 200 66 | lr: 0.1 67 | scheduling: 68 | - 60 69 | - 120 70 | - 160 71 | lr_decay: 0.1 72 | weight_decay: 0.0005 73 | 74 | 75 | # Dataset Cfg 76 | dataset: "cifar100" #'imagenet100', 'cifar100' 77 | trial: 2 78 | increment: 5 79 | batch_size: 128 80 | workers: 4 81 | validation: 0 # Validation split (0. <= x <= 1.) 82 | random_classes: False #Randomize classes order of increment 83 | start_class: 50 # number of tasks for the first step, start from 0. 84 | start_task: 0 85 | max_task: # Cap the number of task 86 | 87 | #Memory 88 | coreset_strategy: "iCaRL" # iCaRL, random 89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem 90 | memory_size: 2000 # Max number of storable examplars 91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls 92 | 93 | # Misc 94 | device: 0 #GPU index to use, for cpu use -1 95 | seed: 1993 96 | -------------------------------------------------------------------------------- /configs/cifar_b50_2s.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: "CIFAR100_B50_2S_TCIL" 3 | savedir: "./logs" 4 | tensorboard_dir: "./tensorboard" 5 | debug: False 6 | 7 | 8 | #Model Cfg 9 | model: "incmodel" 10 | convnet: 'resnet18' 11 | train_head: 'softmax' 12 | infer_head: 'softmax' 13 | channel: 64 14 | use_bias: False 15 | last_relu: False 16 | 17 | dea: True 18 | use_div_cls: True 19 | div_type: "n+1" # n+t, 1+1 20 | distillation: True 21 | disttype: "MTKD" 22 | 
temperature: 2 23 | distlamb: 1 24 | feature_type: "ffm" # se 25 | attention_use_residual: True 26 | ignore_new: True 27 | 28 | prune: False 29 | 30 | attention: 31 | add_kl: True 32 | kd_warm_up: 50 33 | kd_loss_weight: 0.5 34 | kl_loss_weight: 0.5 35 | 36 | reuse_oldfc: False 37 | weight_normalization: False 38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off. 39 | save_ckpt: True 40 | save_mem: True 41 | load_mem: False 42 | 43 | #Optimization;Training related 44 | task_max: 10 45 | lr_min: 0.00005 46 | lr: 0.1 47 | weight_decay: 0.0005 48 | dynamic_weight_decay: False 49 | scheduler: 'multistep' 50 | scheduling: 51 | - 100 52 | - 120 53 | lr_decay: 0.1 54 | optimizer: "sgd" 55 | epochs: 170 56 | resampling: False 57 | warmup: True 58 | warmup_epochs: 10 59 | 60 | postprocessor: 61 | enable: True 62 | type: 'cr' 63 | 64 | pretrain: 65 | epochs: 200 66 | lr: 0.1 67 | scheduling: 68 | - 60 69 | - 120 70 | - 160 71 | lr_decay: 0.1 72 | weight_decay: 0.0005 73 | 74 | 75 | # Dataset Cfg 76 | dataset: "cifar100" #'imagenet100', 'cifar100' 77 | trial: 2 78 | increment: 25 79 | batch_size: 128 80 | workers: 4 81 | validation: 0 # Validation split (0. <= x <= 1.) 82 | random_classes: False #Randomize classes order of increment 83 | start_class: 50 # number of tasks for the first step, start from 0. 
84 | start_task: 0 85 | max_task: # Cap the number of task 86 | 87 | #Memory 88 | coreset_strategy: "iCaRL" # iCaRL, random 89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem 90 | memory_size: 2000 # Max number of storable examplars 91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls 92 | 93 | # Misc 94 | device: 0 #GPU index to use, for cpu use -1 95 | seed: 1993 96 | -------------------------------------------------------------------------------- /configs/cifar_b50_5s.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: "CIFAR100_B50_5S_TCIL" 3 | savedir: "./logs" 4 | tensorboard_dir: "./tensorboard" 5 | debug: False 6 | 7 | 8 | #Model Cfg 9 | model: "incmodel" 10 | convnet: 'resnet18' 11 | train_head: 'softmax' 12 | infer_head: 'softmax' 13 | channel: 64 14 | use_bias: False 15 | last_relu: False 16 | 17 | dea: True 18 | use_div_cls: True 19 | div_type: "n+1" # n+t, 1+1 20 | distillation: True 21 | disttype: "MTKD" 22 | temperature: 2 23 | distlamb: 1 24 | feature_type: "ffm" # se 25 | attention_use_residual: True 26 | ignore_new: True 27 | 28 | prune: False 29 | 30 | attention: 31 | add_kl: True 32 | kd_warm_up: 50 33 | kd_loss_weight: 0.5 34 | kl_loss_weight: 0.5 35 | 36 | reuse_oldfc: False 37 | weight_normalization: False 38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off. 
39 | save_ckpt: True 40 | save_mem: True 41 | load_mem: False 42 | 43 | #Optimization;Training related 44 | task_max: 10 45 | lr_min: 0.00005 46 | lr: 0.1 47 | weight_decay: 0.0005 48 | dynamic_weight_decay: False 49 | scheduler: 'multistep' 50 | scheduling: 51 | - 100 52 | - 120 53 | lr_decay: 0.1 54 | optimizer: "sgd" 55 | epochs: 170 56 | resampling: False 57 | warmup: True 58 | warmup_epochs: 10 59 | 60 | postprocessor: 61 | enable: True 62 | type: 'cr' 63 | 64 | pretrain: 65 | epochs: 200 66 | lr: 0.1 67 | scheduling: 68 | - 60 69 | - 120 70 | - 160 71 | lr_decay: 0.1 72 | weight_decay: 0.0005 73 | 74 | 75 | # Dataset Cfg 76 | dataset: "cifar100" #'imagenet100', 'cifar100' 77 | trial: 2 78 | increment: 10 79 | batch_size: 128 80 | workers: 4 81 | validation: 0 # Validation split (0. <= x <= 1.) 82 | random_classes: False #Randomize classes order of increment 83 | start_class: 50 # number of tasks for the first step, start from 0. 84 | start_task: 0 85 | max_task: # Cap the number of task 86 | 87 | #Memory 88 | coreset_strategy: "iCaRL" # iCaRL, random 89 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem 90 | memory_size: 2000 # Max number of storable examplars 91 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls 92 | 93 | # Misc 94 | device: 0 #GPU index to use, for cpu use -1 95 | seed: 1993 96 | -------------------------------------------------------------------------------- /configs/imagenet_b0_10s.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: "ImageNet100_B0_10S_TCIL" 3 | savedir: "./logs" 4 | tensorboard_dir: "./tensorboard" 5 | debug: False 6 | 7 | 8 | #Model Cfg 9 | model: "incmodel" 10 | convnet: 'resnet18' 11 | train_head: 'softmax' 12 | infer_head: 'softmax' 13 | channel: 64 14 | use_bias: False 15 | last_relu: False 16 | 17 | dea: True 18 | use_div_cls: True 19 | div_type: "n+1" # n+t, 1+1 20 | distillation: True 21 | disttype: 
"MTKD" 22 | temperature: 2 23 | distlamb: 1 24 | feature_type: "ffm" # se 25 | attention_use_residual: True 26 | ignore_new: True 27 | 28 | prune: False 29 | 30 | attention: 31 | add_kl: True 32 | kd_warm_up: 50 33 | kd_loss_weight: 0.5 34 | kl_loss_weight: 0.5 35 | 36 | reuse_oldfc: False 37 | weight_normalization: False 38 | val_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off. 39 | save_ckpt: True 40 | save_mem: True 41 | load_mem: False 42 | 43 | #Optimization;Training related 44 | task_max: 10 45 | lr_min: 0.00005 46 | lr: 0.1 47 | weight_decay: 0.0005 48 | dynamic_weight_decay: False 49 | scheduler: 'multistep' 50 | scheduling: 51 | - 60 52 | - 120 53 | - 180 54 | lr_decay: 0.1 55 | optimizer: "sgd" 56 | epochs: 200 57 | resampling: False 58 | warmup: True 59 | warmup_epochs: 20 60 | 61 | postprocessor: 62 | enable: True 63 | type: 'cr' 64 | 65 | pretrain: 66 | epochs: 200 67 | lr: 0.1 68 | scheduling: 69 | - 60 70 | - 120 71 | - 160 72 | lr_decay: 0.1 73 | weight_decay: 0.0005 74 | 75 | 76 | # Dataset Cfg 77 | dataset: "imagenet100" #'imagenet100', 'cifar100' 78 | trial: 2 79 | increment: 10 80 | batch_size: 128 81 | workers: 4 82 | validation: 0 # Validation split (0. <= x <= 1.) 83 | random_classes: False #Randomize classes order of increment 84 | start_class: 0 # number of tasks for the first step, start from 0. 85 | start_task: 0 86 | max_task: # Cap the number of task 87 | 88 | #Memory 89 | coreset_strategy: "iCaRL" # iCaRL, random 90 | mem_size_mode: "uniform_fixed_total_mem" #uniform_fixed_per_cls, uniform_fixed_total_mem 91 | memory_size: 2000 # Max number of storable examplars 92 | fixed_memory_per_cls: 20 # the fixed number of exemplars per cls 93 | 94 | # Misc 95 | device: 0 #GPU index to use, for cpu use -1 96 | seed: 1993 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # TCIL 4 | 5 | ## Resolving Task Confusion in Dynamic Expansion Architectures for Class Incremental Learning 6 | 7 | [![Paper](https://img.shields.io/badge/arXiv-2212.14284-brightgreen)](https://arxiv.org/abs/2212.14284) 8 | ![AAAI](https://img.shields.io/badge/AAAI-2023-%2312457A) 9 | [![bilibili](https://img.shields.io/badge/bilibili-link-%23ff69b4)](https://www.bilibili.com/video/BV1L14y1u7XV/) 10 | 11 | 12 | ![TCIL main figure](pictures/TCIL.png) 13 | 14 | 15 |
16 | 17 | 18 | 19 | # Datasets 20 | 21 | - Training datasets 22 | 1. CIFAR100: 23 | The CIFAR100 dataset is downloaded automatically. 24 | 2. ImageNet100: 25 | ImageNet100 is a subset of ImageNet. Download ImageNet first, then split it following [ImageNet100_Split](https://github.com/arthurdouillard/incremental_learning.pytorch). 26 | 27 | - Class ordering 28 | - We use the class ordering proposed by [DER](https://github.com/Rhyssiyan/DER-ClassIL.pytorch). 29 | 30 | - Structure of the `data` directory 31 | ``` 32 | data 33 | ├── cifar100 34 | │ └── cifar-100-python 35 | │ ├── train 36 | │ ├── test 37 | │ ├── meta 38 | │ └── file.txt~ 39 | │ 40 | ├── imagenet100 41 | │ ├── train 42 | │ └── val 43 | ``` 44 | 45 | # Environment 46 | All required libraries are listed in `requirements.txt`; set up the experimental environment with the following commands. 47 | 48 | ``` 49 | conda create -n TCIL python=3.8 50 | conda install pytorch==1.8.1 torchvision==0.9.1 cudatoolkit=11.1 -c pytorch 51 | pip install -r requirements.txt 52 | ``` 53 | We thank the authors of [DER](https://github.com/Rhyssiyan/DER-ClassIL.pytorch) for their great code base. 54 | # Launching an experiment 55 | ## Train 56 | `sh scripts/run.sh` 57 | ## Eval 58 | `sh scripts/inference.sh` 59 | ## Prune 60 | `sh scripts/prune.sh` 61 | 62 | 63 | # Results 64 | 65 | ## Rehearsal Setting 66 | 67 | ![CIFAR figure rehearsal results](pictures/cifar_mem.png) 68 | ![ImageNet figure rehearsal results](pictures/imagenet_mem.png) 69 | 70 | ## Non-Rehearsal Setting 71 | 72 | ![CIFAR and ImageNet figure non-rehearsal results](pictures/non_mem.png) 73 | 74 | ## Checkpoints 75 | 76 | Get the trained models from [BaiduNetdisk(passwd:q3eh)](https://pan.baidu.com/s/1G0XVZCaaZ2LmM_eppr3cXA).
77 | (We both offer the training logs in the same file) 78 | 79 | -------------------------------------------------------------------------------- /inclearn/tools/factory.py: -------------------------------------------------------------------------------- 1 | from matplotlib.transforms import Transform 2 | import torch 3 | from torch import nn 4 | from torch import optim 5 | 6 | from inclearn import models 7 | from inclearn.convnet import resnet, cifar_resnet, modified_resnet_cifar, preact_resnet 8 | from inclearn.datasets import data 9 | from inclearn.convnet.resnet import SEFeatureAt 10 | 11 | def get_optimizer(params, optimizer, lr, weight_decay=0.0): 12 | if optimizer == "adam": 13 | return optim.Adam(params, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999)) 14 | elif optimizer == "sgd": 15 | return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9) 16 | else: 17 | raise NotImplementedError 18 | 19 | def get_attention(inplane, type, at_res): 20 | return SEFeatureAt(inplane, type, at_res) 21 | 22 | def get_convnet(convnet_type, **kwargs): 23 | if convnet_type == "resnet18": 24 | return resnet.resnet18(**kwargs) 25 | elif convnet_type == "resnet32": 26 | return cifar_resnet.resnet32() 27 | elif convnet_type == "modified_resnet32": 28 | return modified_resnet_cifar.resnet32(**kwargs) 29 | elif convnet_type == "preact_resnet18": 30 | return preact_resnet.PreActResNet18(**kwargs) 31 | else: 32 | raise NotImplementedError("Unknwon convnet type {}.".format(convnet_type)) 33 | 34 | 35 | def get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset): 36 | if cfg["model"] == "incmodel": 37 | return models.IncModel(cfg, trial_i, _run, ex, tensorboard, inc_dataset) 38 | else: 39 | raise NotImplementedError(cfg["model"]) 40 | 41 | 42 | def get_data(cfg, trial_i): 43 | return data.IncrementalDataset( 44 | trial_i=trial_i, 45 | dataset_name=cfg["dataset"], 46 | random_order=cfg["random_classes"], 47 | shuffle=True, 48 | batch_size=cfg["batch_size"], 49 | 
workers=cfg["workers"], 50 | validation_split=cfg["validation"], 51 | resampling=cfg["resampling"], 52 | increment=cfg["increment"], 53 | data_folder=cfg["data_folder"], 54 | start_class=cfg["start_class"], 55 | # transform=cfg.get("transform","normal") 56 | ) 57 | 58 | 59 | def set_device(cfg): 60 | device_type = cfg["device"] 61 | 62 | if device_type == -1: 63 | device = torch.device("cpu") 64 | else: 65 | device = torch.device("cuda:{}".format(device_type)) 66 | 67 | cfg["device"] = device 68 | return device 69 | -------------------------------------------------------------------------------- /inclearn/loss/tripeloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | import pdb 5 | 6 | 7 | class TripletLossNoHardMining(nn.Module): 8 | def __init__(self, margin=0, num_instances=8): 9 | super(TripletLossNoHardMining, self).__init__() 10 | self.margin = margin 11 | self.num_instances = num_instances 12 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 13 | 14 | def forward(self, inputs, targets): 15 | n = inputs.size(0) 16 | # Compute pairwise distance, replace by the official when merged 17 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 18 | dist = dist + dist.t() 19 | dist.addmm_(1, -2, inputs, inputs.t()) 20 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 21 | # For each anchor, find the hardest positive and negative 22 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 23 | dist_ap, dist_an = [], [] 24 | for i in range(n): 25 | for j in range(self.num_instances-1): 26 | tmp = dist[i][mask[i]] 27 | dist_ap.append(tmp[j+1]) 28 | tmp = dist[i][mask[i] == 0] 29 | dist_an.append(tmp[j+1]) 30 | dist_ap = torch.stack(dist_ap) 31 | dist_an = torch.stack(dist_an) 32 | # Compute ranking hinge loss 33 | y = dist_an.data.new() 34 | y.resize_as_(dist_an.data) 35 | y.fill_(1) 36 | y = Variable(y) 37 | 
loss = self.ranking_loss(dist_an, dist_ap, y) 38 | prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0) 39 | dist_p = torch.mean(dist_ap).data[0] 40 | dist_n = torch.mean(dist_an).data[0] 41 | return loss, prec, dist_p, dist_n 42 | 43 | class TripletLoss(nn.Module): 44 | def __init__(self, margin=0, num_instances=None): 45 | super(TripletLoss, self).__init__() 46 | self.margin = margin 47 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 48 | 49 | def forward(self, inputs, targets): 50 | n = inputs.size(0) 51 | # Compute pairwise distance, replace by the official when merged 52 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 53 | dist = dist + dist.t() 54 | dist.addmm_(1, -2, inputs, inputs.t()) 55 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 56 | # For each anchor, find the hardest positive and negative 57 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 58 | dist_ap, dist_an = [], [] 59 | for i in range(n): 60 | dist_ap.append(dist[i][mask[i]].max()) 61 | dist_an.append(dist[i][mask[i] == 0].min()) 62 | dist_ap = torch.stack(dist_ap) 63 | dist_an = torch.stack(dist_an) 64 | # Compute ranking hinge loss 65 | y = dist_an.data.new() 66 | y.resize_as_(dist_an.data) 67 | y.fill_(1) 68 | y = Variable(y) 69 | loss = self.ranking_loss(dist_an, dist_ap, y) 70 | prec = (dist_an.data > dist_ap.data).sum() * 1. 
/ y.size(0) 71 | dist_p = torch.mean(dist_ap).item() 72 | dist_n = torch.mean(dist_an).item() 73 | return loss, prec, dist_p, dist_n -------------------------------------------------------------------------------- /inclearn/tools/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | from torch.optim.lr_scheduler import ReduceLROnPlateau 4 | 5 | 6 | class ConstantTaskLR: 7 | def __init__(self, lr): 8 | self._lr = lr 9 | 10 | def get_lr(self, task_i): 11 | return self._lr 12 | 13 | 14 | class CosineAnnealTaskLR: 15 | def __init__(self, lr_max, lr_min, task_max): 16 | self._lr_max = lr_max 17 | self._lr_min = lr_min 18 | self._task_max = task_max 19 | 20 | def get_lr(self, task_i): 21 | return self._lr_min + (self._lr_max - self._lr_min) * (1 + math.cos(math.pi * task_i / self._task_max)) / 2 22 | 23 | 24 | class GradualWarmupScheduler(_LRScheduler): 25 | """ Gradually warm-up(increasing) learning rate in optimizer. 26 | https://github.com/ildoonet/pytorch-gradual-warmup-lr 27 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 28 | Args: 29 | optimizer (Optimizer): Wrapped optimizer. 30 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 31 | total_epoch: target learning rate is reached at total_epoch, gradually 32 | after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) 33 | """ 34 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 35 | self.multiplier = multiplier 36 | if self.multiplier < 1.: 37 | raise ValueError('multiplier should be greater thant or equal to 1.') 38 | self.total_epoch = total_epoch 39 | self.after_scheduler = after_scheduler 40 | self.finished = False 41 | super(GradualWarmupScheduler, self).__init__(optimizer) 42 | 43 | def get_lr(self): 44 | if self.last_epoch > self.total_epoch: 45 | if self.after_scheduler: 46 | if not self.finished: 47 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 48 | self.finished = True 49 | return self.after_scheduler.get_last_lr() 50 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 51 | 52 | if self.multiplier == 1.0: 53 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 54 | else: 55 | return [ 56 | base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 57 | for base_lr in self.base_lrs 58 | ] 59 | 60 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 61 | if epoch is None: 62 | epoch = self.last_epoch + 1 63 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 64 | if self.last_epoch <= self.total_epoch: 65 | warmup_lr = [ 66 | base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
67 | for base_lr in self.base_lrs 68 | ] 69 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 70 | param_group['lr'] = lr 71 | else: 72 | if epoch is None: 73 | self.after_scheduler.step(metrics, None) 74 | else: 75 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 76 | 77 | def step(self, epoch=None, metrics=None): 78 | if type(self.after_scheduler) != ReduceLROnPlateau: 79 | if self.finished and self.after_scheduler: 80 | if epoch is None: 81 | self.after_scheduler.step(None) 82 | else: 83 | self.after_scheduler.step(epoch - self.total_epoch) 84 | self._last_lr = self.after_scheduler.get_last_lr() 85 | else: 86 | return super(GradualWarmupScheduler, self).step(epoch) 87 | else: 88 | self.step_ReduceLROnPlateau(metrics, epoch) 89 | -------------------------------------------------------------------------------- /inclearn/learn/pretrain.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from inclearn.tools import factory, utils 6 | from inclearn.tools.metrics import ClassErrorMeter, AverageValueMeter 7 | 8 | # import line_profiler 9 | # import atexit 10 | # profile = line_profiler.LineProfiler() 11 | # atexit.register(profile.print_stats) 12 | 13 | 14 | def _compute_loss(cfg, logits, targets, device): 15 | 16 | if cfg["train_head"] == "sigmoid": 17 | n_classes = cfg["start_class"] 18 | onehot_targets = utils.to_onehot(targets, n_classes).to(device) 19 | loss = F.binary_cross_entropy_with_logits(logits, onehot_targets) 20 | elif cfg["train_head"] == "softmax": 21 | loss = F.cross_entropy(logits, targets) 22 | else: 23 | raise ValueError() 24 | 25 | return loss 26 | 27 | 28 | def train(cfg, model, optimizer, device, train_loader): 29 | _loss = 0.0 30 | accu = ClassErrorMeter(accuracy=True) 31 | accu.reset() 32 | 33 | model.train() 34 | for i, (inputs, targets) in enumerate(train_loader, start=1): 35 | # assert 
torch.isnan(inputs).sum().item() == 0 36 | optimizer.zero_grad() 37 | inputs, targets = inputs.to(device), targets.to(device) 38 | logits = model._parallel_network(inputs)['logit'] 39 | if accu is not None: 40 | accu.add(logits.detach(), targets) 41 | 42 | loss = _compute_loss(cfg, logits, targets, device) 43 | if torch.isnan(loss): 44 | import pdb 45 | 46 | pdb.set_trace() 47 | 48 | loss.backward() 49 | optimizer.step() 50 | _loss += loss 51 | 52 | return ( 53 | round(_loss.item() / i, 3), 54 | round(accu.value()[0], 3), 55 | ) 56 | 57 | 58 | def test(cfg, model, device, test_loader): 59 | _loss = 0.0 60 | accu = ClassErrorMeter(accuracy=True) 61 | accu.reset() 62 | 63 | model.eval() 64 | with torch.no_grad(): 65 | for i, (inputs, targets) in enumerate(test_loader, start=1): 66 | # assert torch.isnan(inputs).sum().item() == 0 67 | inputs, targets = inputs.to(device), targets.to(device) 68 | logits = model._parallel_network(inputs)['logit'] 69 | if accu is not None: 70 | accu.add(logits.detach(), targets) 71 | loss = _compute_loss(cfg, logits, targets, device) 72 | if torch.isnan(loss): 73 | import pdb 74 | pdb.set_trace() 75 | 76 | _loss = _loss + loss 77 | return round(_loss.item() / i, 3), round(accu.value()[0], 3) 78 | 79 | 80 | def pretrain(cfg, ex, model, device, train_loader, test_loader, model_path): 81 | ex.logger.info(f"nb Train {len(train_loader.dataset)} Eval {len(test_loader.dataset)}") 82 | optimizer = torch.optim.SGD(model._network.parameters(), 83 | lr=cfg["pretrain"]["lr"], 84 | momentum=0.9, 85 | weight_decay=cfg["pretrain"]["weight_decay"]) 86 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 87 | cfg["pretrain"]["scheduling"], 88 | gamma=cfg["pretrain"]["lr_decay"]) 89 | test_loss, test_acc = float("nan"), float("nan") 90 | for e in range(cfg["pretrain"]["epochs"]): 91 | train_loss, train_acc = train(cfg, model, optimizer, device, train_loader) 92 | if e % 5 == 0: 93 | test_loss, test_acc = test(cfg, model, device, test_loader) 94 | 
ex.logger.info( 95 | "Pretrain Class {}, Epoch {}/{} => Clf Train loss: {}, Accu {} | Eval loss: {}, Accu {}".format( 96 | cfg["start_class"], e + 1, cfg["pretrain"]["epochs"], train_loss, train_acc, test_loss, test_acc)) 97 | else: 98 | ex.logger.info("Pretrain Class {}, Epoch {}/{} => Clf Train loss: {}, Accu {} ".format( 99 | cfg["start_class"], e + 1, cfg["pretrain"]["epochs"], train_loss, train_acc)) 100 | scheduler.step() 101 | if hasattr(model._network, "module"): 102 | torch.save(model._network.module.state_dict(), model_path) 103 | else: 104 | torch.save(model._network.state_dict(), model_path) 105 | -------------------------------------------------------------------------------- /inclearn/convnet/modified_resnet_cifar.py: -------------------------------------------------------------------------------- 1 | """Taken & slightly modified from: 2 | * https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | """ 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.nn import functional as F 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 
30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None, remove_last_relu=False): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | self.remove_last_relu = remove_last_relu 42 | 43 | def forward(self, x): 44 | identity = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | identity = self.downsample(x) 55 | 56 | out += identity 57 | if not self.remove_last_relu: 58 | out = self.relu(out) 59 | return out 60 | 61 | 62 | class ResNet(nn.Module): 63 | def __init__(self, block, layers, nf=16, dataset='cifar', start_class=0, remove_last_relu=False): 64 | super(ResNet, self).__init__() 65 | self.inplanes = nf 66 | self.conv1 = nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(nf) 68 | self.relu = nn.ReLU(inplace=True) 69 | 70 | self.layer1 = self._make_layer(block, 1 * nf, layers[0]) 71 | self.layer2 = self._make_layer(block, 2 * nf, layers[1], stride=2) 72 | self.layer3 = self._make_layer(block, 4 * nf, layers[2], stride=2) 73 | self.avgpool = nn.AvgPool2d(8, stride=1) 74 | 75 | self.out_dim = 4 * nf * block.expansion 76 | 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 80 | elif isinstance(m, nn.BatchNorm2d): 81 | nn.init.constant_(m.weight, 1) 82 | nn.init.constant_(m.bias, 0) 83 | 84 | def _make_layer(self, block, planes, blocks, stride=1, remove_last_relu=False): 85 | downsample = None 86 | if stride != 1 or self.inplanes != planes * block.expansion: 87 | downsample = nn.Sequential( 88 | 
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 89 | nn.BatchNorm2d(planes * block.expansion), 90 | ) 91 | 92 | layers = [] 93 | layers.append(block(self.inplanes, planes, stride, downsample)) 94 | self.inplanes = planes * block.expansion 95 | if remove_last_relu: 96 | for i in range(1, blocks - 1): 97 | layers.append(block(self.inplanes, planes)) 98 | layers.append(block(self.inplanes, planes, remove_last_relu=True)) 99 | else: 100 | for _ in range(1, blocks): 101 | layers.append(block(self.inplanes, planes)) 102 | 103 | return nn.Sequential(*layers) 104 | 105 | def reset_bn(self): 106 | for m in self.modules(): 107 | if isinstance(m, nn.BatchNorm2d): 108 | m.reset_running_stats() 109 | 110 | def forward(self, x, pool=True): 111 | x = self.conv1(x) 112 | x = self.bn1(x) 113 | x = self.relu(x) 114 | 115 | x = self.layer1(x) 116 | x = self.layer2(x) 117 | x = self.layer3(x) 118 | 119 | x = self.avgpool(x) 120 | x = x.view(x.size(0), -1) 121 | return x 122 | 123 | 124 | def resnet20(pretrained=False, **kwargs): 125 | n = 3 126 | model = ResNet(BasicBlock, [n, n, n], **kwargs) 127 | return model 128 | 129 | 130 | def resnet32(pretrained=False, **kwargs): 131 | n = 5 132 | model = ResNet(BasicBlock, [n, n, n], **kwargs) 133 | return model 134 | -------------------------------------------------------------------------------- /inclearn/convnet/preact_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PreActBlock(nn.Module): 7 | '''Pre-activation version of the BasicBlock.''' 8 | expansion = 1 9 | 10 | def __init__(self, in_planes, planes, stride=1, remove_last_relu=False): 11 | super(PreActBlock, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 
15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn3 = nn.BatchNorm2d(planes) 17 | self.remove_last_relu = remove_last_relu 18 | 19 | if stride != 1 or in_planes != self.expansion * planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)) 22 | 23 | def forward(self, x): 24 | out = F.relu(self.bn1(x)) 25 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 26 | out = self.conv1(out) 27 | out = self.conv2(F.relu(self.bn2(out))) 28 | out += shortcut 29 | out = self.bn3(out) 30 | if not self.remove_last_relu: 31 | out = F.relu(out) 32 | return out 33 | 34 | 35 | class PreActBottleneck(nn.Module): 36 | '''Pre-activation version of the original Bottleneck module.''' 37 | expansion = 4 38 | 39 | def __init__(self, in_planes, planes, stride=1): 40 | super(PreActBottleneck, self).__init__() 41 | self.bn1 = nn.BatchNorm2d(in_planes) 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 43 | self.bn2 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(planes) 46 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 47 | 48 | if stride != 1 or in_planes != self.expansion * planes: 49 | self.shortcut = nn.Sequential( 50 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)) 51 | 52 | def forward(self, x): 53 | out = F.relu(self.bn1(x)) 54 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 55 | out = self.conv1(out) 56 | out = self.conv2(F.relu(self.bn2(out))) 57 | out = self.conv3(F.relu(self.bn3(out))) 58 | out += shortcut 59 | return out 60 | 61 | 62 | class PreActResNet(nn.Module): 63 | def __init__(self, 64 | block, 65 | num_blocks, 66 | nf=64, 67 | zero_init_residual=True, 68 | dataset="cifar", 69 | start_class=0, 70 | 
remove_last_relu=False): 71 | super(PreActResNet, self).__init__() 72 | self.in_planes = nf 73 | self.dataset = dataset 74 | self.remove_last_relu = remove_last_relu 75 | 76 | if 'cifar' in dataset: 77 | self.conv1 = nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False) 78 | else: 79 | self.conv1 = nn.Sequential(nn.Conv2d(3, nf, kernel_size=7, stride=2, padding=3, bias=False), 80 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 81 | self.layer1 = self._make_layer(block, 1 * nf, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 2 * nf, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 4 * nf, num_blocks[2], stride=2) 84 | self.layer4 = self._make_layer(block, 8 * nf, num_blocks[3], stride=2, remove_last_relu=remove_last_relu) 85 | self.out_dim = 8 * nf 86 | 87 | if 'cifar' in dataset: 88 | self.avgpool = nn.AvgPool2d(4) 89 | elif 'imagenet' in dataset: 90 | self.avgpool = nn.AvgPool2d(7) 91 | 92 | for m in self.modules(): 93 | if isinstance(m, nn.Conv2d): 94 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 95 | elif isinstance(m, nn.BatchNorm2d): 96 | nn.init.constant_(m.weight, 1) 97 | nn.init.constant_(m.bias, 0) 98 | 99 | # --------------------------------------------- 100 | # if zero_init_residual: 101 | # for m in self.modules(): 102 | # if isinstance(m, PreActBlock): 103 | # nn.init.constant_(m.bn2.weight, 0) 104 | # elif isinstance(m, PreActBottleneck): 105 | # nn.init.constant_(m.bn3.weight, 0) 106 | # --------------------------------------------- 107 | 108 | def _make_layer(self, block, planes, num_blocks, stride, remove_last_relu=False): 109 | strides = [stride] + [1] * (num_blocks - 1) 110 | layers = [] 111 | if remove_last_relu: 112 | for i in range(len(strides) - 1): 113 | layers.append(block(self.in_planes, planes, strides[i])) 114 | self.in_planes = planes * block.expansion 115 | layers.append(block(self.in_planes, planes, strides[-1], remove_last_relu=True)) 116 | 
self.in_planes = planes * block.expansion 117 | else: 118 | for stride in strides: 119 | layers.append(block(self.in_planes, planes, stride)) 120 | self.in_planes = planes * block.expansion 121 | return nn.Sequential(*layers) 122 | 123 | def forward(self, x): 124 | out = self.conv1(x) 125 | out = self.layer1(out) 126 | out = self.layer2(out) 127 | out = self.layer3(out) 128 | out = self.layer4(out) 129 | out = self.avgpool(out) 130 | out = out.view(out.size(0), -1) 131 | return out 132 | 133 | 134 | def PreActResNet18(**kwargs): 135 | return PreActResNet(PreActBlock, [2, 2, 2, 2], **kwargs) 136 | 137 | 138 | def PreActResNet34(**kwargs): 139 | return PreActResNet(PreActBlock, [3, 4, 6, 3], **kwargs) 140 | 141 | 142 | def PreActResNet50(**kwargs): 143 | return PreActResNet(PreActBottleneck, [3, 4, 6, 3], **kwargs) 144 | 145 | 146 | def PreActResNet101(**kwargs): 147 | return PreActResNet(PreActBottleneck, [3, 4, 23, 3], **kwargs) 148 | 149 | 150 | def PreActResNet152(**kwargs): 151 | return PreActResNet(PreActBottleneck, [3, 8, 36, 3], **kwargs) 152 | -------------------------------------------------------------------------------- /inclearn/convnet/cifar_resnet.py: -------------------------------------------------------------------------------- 1 | ''' Incremental-Classifier Learning 2 | Authors : Khurram Javed, Muhammad Talha Paracha 3 | Maintainer : Khurram Javed 4 | Lab : TUKL-SEECS R&D Lab 5 | Email : 14besekjaved@seecs.edu.pk ''' 6 | 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import init 13 | 14 | 15 | class DownsampleA(nn.Module): 16 | def __init__(self, nIn, nOut, stride): 17 | super(DownsampleA, self).__init__() 18 | assert stride == 2 19 | self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) 20 | 21 | def forward(self, x): 22 | x = self.avg(x) 23 | return torch.cat((x, x.mul(0)), 1) 24 | 25 | 26 | class ResNetBasicblock(nn.Module): 27 | expansion = 1 28 | """ 29 | RexNet 
basicblock (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua) 30 | """ 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(ResNetBasicblock, self).__init__() 33 | 34 | self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 35 | self.bn_a = nn.BatchNorm2d(planes) 36 | 37 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 38 | self.bn_b = nn.BatchNorm2d(planes) 39 | 40 | self.downsample = downsample 41 | self.featureSize = 64 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | basicblock = self.conv_a(x) 47 | basicblock = self.bn_a(basicblock) 48 | basicblock = F.relu(basicblock, inplace=True) 49 | 50 | basicblock = self.conv_b(basicblock) 51 | basicblock = self.bn_b(basicblock) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | return F.relu(residual + basicblock, inplace=True) 57 | 58 | 59 | class CifarResNet(nn.Module): 60 | """ 61 | ResNet optimized for the Cifar Dataset, as specified in 62 | https://arxiv.org/abs/1512.03385.pdf 63 | """ 64 | def __init__(self, block, depth, num_classes, channels=3): 65 | """ Constructor 66 | Args: 67 | depth: number of layers. 
68 | num_classes: number of classes 69 | base_width: base width 70 | """ 71 | super(CifarResNet, self).__init__() 72 | 73 | self.featureSize = 64 74 | # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 75 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' 76 | layer_blocks = (depth - 2) // 6 77 | 78 | self.num_classes = num_classes 79 | 80 | self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) 81 | self.bn_1 = nn.BatchNorm2d(16) 82 | 83 | self.inplanes = 16 84 | self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) 85 | self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) 86 | self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) 87 | self.avgpool = nn.AvgPool2d(8) 88 | self.out_dim = 64 * block.expansion 89 | 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 93 | m.weight.data.normal_(0, math.sqrt(2. / n)) 94 | # m.bias.data.zero_() 95 | elif isinstance(m, nn.BatchNorm2d): 96 | m.weight.data.fill_(1) 97 | m.bias.data.zero_() 98 | elif isinstance(m, nn.Linear): 99 | init.kaiming_normal(m.weight) 100 | m.bias.data.zero_() 101 | 102 | def _make_layer(self, block, planes, blocks, stride=1): 103 | downsample = None 104 | if stride != 1 or self.inplanes != planes * block.expansion: 105 | downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) 106 | 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, downsample)) 109 | self.inplanes = planes * block.expansion 110 | for i in range(1, blocks): 111 | layers.append(block(self.inplanes, planes)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x, feature=False, T=1, labels=False, scale=None, keep=None): 116 | 117 | x = self.conv_1_3x3(x) 118 | x = F.relu(self.bn_1(x), inplace=True) 119 | x = self.stage_1(x) 120 | x = self.stage_2(x) 121 | x = self.stage_3(x) 122 | x = self.avgpool(x) 123 
| x = x.view(x.size(0), -1) 124 | return x 125 | 126 | def forwardFeature(self, x): 127 | pass 128 | 129 | 130 | def resnet20(num_classes=10): 131 | """Constructs a ResNet-20 model for CIFAR-10 (by default) 132 | Args: 133 | num_classes (uint): number of classes 134 | """ 135 | model = CifarResNet(ResNetBasicblock, 20, num_classes) 136 | return model 137 | 138 | 139 | def resnet10mnist(num_classes=10): 140 | """Constructs a ResNet-20 model for CIFAR-10 (by default) 141 | Args: 142 | num_classes (uint): number of classes 143 | """ 144 | model = CifarResNet(ResNetBasicblock, 10, num_classes, 1) 145 | return model 146 | 147 | 148 | def resnet20mnist(num_classes=10): 149 | """Constructs a ResNet-20 model for CIFAR-10 (by default) 150 | Args: 151 | num_classes (uint): number of classes 152 | """ 153 | model = CifarResNet(ResNetBasicblock, 20, num_classes, 1) 154 | return model 155 | 156 | 157 | def resnet32mnist(num_classes=10, channels=1): 158 | model = CifarResNet(ResNetBasicblock, 32, num_classes, channels) 159 | return model 160 | 161 | 162 | def resnet32(num_classes=10): 163 | """Constructs a ResNet-32 model for CIFAR-10 (by default) 164 | Args: 165 | num_classes (uint): number of classes 166 | """ 167 | model = CifarResNet(ResNetBasicblock, 32, num_classes) 168 | return model 169 | 170 | 171 | def resnet44(num_classes=10): 172 | """Constructs a ResNet-44 model for CIFAR-10 (by default) 173 | Args: 174 | num_classes (uint): number of classes 175 | """ 176 | model = CifarResNet(ResNetBasicblock, 44, num_classes) 177 | return model 178 | 179 | 180 | def resnet56(num_classes=10): 181 | """Constructs a ResNet-56 model for CIFAR-10 (by default) 182 | Args: 183 | num_classes (uint): number of classes 184 | """ 185 | model = CifarResNet(ResNetBasicblock, 56, num_classes) 186 | return model 187 | 188 | 189 | def resnet110(num_classes=10): 190 | """Constructs a ResNet-110 model for CIFAR-10 (by default) 191 | Args: 192 | num_classes (uint): number of classes 193 | """ 194 | 
model = CifarResNet(ResNetBasicblock, 110, num_classes) 195 | return model 196 | -------------------------------------------------------------------------------- /inclearn/models/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from inclearn.tools.metrics import ClassErrorMeter 7 | 8 | LOGGER = logging.Logger("IncLearn", level="INFO") 9 | 10 | 11 | class IncrementalLearner(abc.ABC): 12 | """Base incremental learner. 13 | 14 | Methods are called in this order (& repeated for each new task): 15 | 16 | 1. set_task_info 17 | 2. before_task 18 | 3. train_task 19 | 4. after_task 20 | 5. eval_task 21 | """ 22 | def __init__(self, *args, **kwargs): 23 | self._increments = [] 24 | self._seen_classes = [] 25 | 26 | def set_task_info(self, task, total_n_classes, increment, n_train_data, n_test_data, n_tasks): 27 | self._task = task 28 | self._task_size = increment 29 | self._increments.append(self._task_size) 30 | self._total_n_classes = total_n_classes 31 | self._n_train_data = n_train_data 32 | self._n_test_data = n_test_data 33 | self._n_tasks = n_tasks 34 | 35 | def before_task(self, taski, inc_dataset): 36 | LOGGER.info("Before task") 37 | self.eval() 38 | self._before_task(taski, inc_dataset) 39 | 40 | def train_task(self, train_loader, val_loader): 41 | LOGGER.info("train task") 42 | self.train() 43 | self._train_task(train_loader, val_loader) 44 | 45 | def after_task(self, taski, inc_dataset): 46 | LOGGER.info("after task") 47 | self.eval() 48 | self._after_task(taski, inc_dataset) 49 | 50 | def eval_task(self, data_loader): 51 | LOGGER.info("eval task") 52 | self.eval() 53 | return self._eval_task(data_loader) 54 | 55 | def get_memory(self): 56 | return None 57 | 58 | def eval(self): 59 | raise NotImplementedError 60 | 61 | def train(self): 62 | raise NotImplementedError 63 | 64 | def _before_task(self, 
data_loader): 65 | pass 66 | 67 | def _train_task(self, train_loader, val_loader): 68 | raise NotImplementedError 69 | 70 | def _after_task(self, data_loader): 71 | pass 72 | 73 | def _eval_task(self, data_loader): 74 | raise NotImplementedError 75 | 76 | @property 77 | def _new_task_index(self): 78 | return self._task * self._task_size 79 | 80 | @property 81 | def _memory_per_class(self): 82 | """Returns the number of examplars per class.""" 83 | return self._memory_size.mem_per_cls 84 | 85 | def _after_epoch(self, epoch, avg_loss, train_new_accu, train_old_accu, accu): 86 | self._run.log_scalar(f"train_loss_trial{self._trial_i}_task{self._task}", avg_loss, epoch + 1) 87 | self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/train_loss", avg_loss, epoch + 1) 88 | 89 | # self._run.log_scalar(f"train_new_accu_trial{self._trial_i}_task{self._task}", 90 | # train_new_accu.value()[0], epoch + 1) 91 | # self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/train_new_accu", 92 | # train_new_accu.value()[0], epoch + 1) 93 | 94 | # if self._task != 0: 95 | # self._run.log_scalar(f"train_old_accu_trial{self._trial_i}_task{self._task}", 96 | # train_old_accu.value()[0], epoch + 1) 97 | # self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/train_old_accu", 98 | # train_old_accu.value()[0], epoch + 1) 99 | 100 | self._run.log_scalar(f"train_accu_trial{self._trial_i}_task{self._task}", accu.value()[0], epoch + 1) 101 | self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/train_accu", accu.value()[0], epoch + 1) 102 | # self._tensorboard.close() 103 | self._tensorboard.flush() 104 | 105 | def _validation(self, val_loader, epoch): 106 | topk = 5 if self._n_classes >= 5 else self._n_classes 107 | if self._val_per_n_epoch != -1 and epoch % self._val_per_n_epoch == 0: 108 | _val_loss = 0 109 | _val_accu = ClassErrorMeter(accuracy=True, topk=[1, topk]) 110 | _val_new_accu = ClassErrorMeter(accuracy=True) 111 | 
_val_old_accu = ClassErrorMeter(accuracy=True) 112 | self._parallel_network.eval() 113 | with torch.no_grad(): 114 | for i, (inputs, targets) in enumerate(val_loader, 1): 115 | old_classes = targets < (self._n_classes - self._task_size) 116 | new_classes = targets >= (self._n_classes - self._task_size) 117 | val_loss, _ = self._forward_loss( 118 | inputs, 119 | targets, 120 | old_classes, 121 | new_classes, 122 | accu=_val_accu, 123 | old_accu=_val_old_accu, 124 | new_accu=_val_new_accu, 125 | ) 126 | _val_loss += val_loss.item() 127 | self._ex.logger.info( 128 | f"epoch{epoch} val acc:{_val_accu.value()[0]:.2f}, val top5acc:{_val_accu.value()[1]:.2f}") 129 | # Test accu 130 | self._run.log_scalar(f"test_accu_trial{self._trial_i}_task{self._task}", _val_accu.value()[0], epoch + 1) 131 | self._run.log_scalar(f"test_5accu_trial{self._trial_i}_task{self._task}", _val_accu.value()[1], epoch + 1) 132 | self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/test_accu", 133 | _val_accu.value()[0], epoch + 1) 134 | self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/test_5accu", 135 | _val_accu.value()[1], epoch + 1) 136 | 137 | # Test new accu 138 | self._run.log_scalar(f"test_new_accu_trial{self._trial_i}_task{self._task}", 139 | _val_new_accu.value()[0], epoch + 1) 140 | self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/test_new_accu", 141 | _val_new_accu.value()[0], epoch + 1) 142 | 143 | # Test old accu 144 | if self._task != 0: 145 | self._run.log_scalar(f"test_old_accu_trial{self._trial_i}_task{self._task}", 146 | _val_old_accu.value()[0], epoch + 1) 147 | self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/test_old_accu", 148 | _val_old_accu.value()[0], epoch + 1) 149 | 150 | # Test loss 151 | self._run.log_scalar(f"test_loss_trial{self._trial_i}_task{self._task}", round(_val_loss / i, 3), epoch + 1) 152 | self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/test_loss", 
round(_val_loss / i, 3), 153 | epoch + 1) 154 | self._tensorboard.close() -------------------------------------------------------------------------------- /inclearn/tools/memory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from inclearn.tools.utils import get_class_loss 7 | from inclearn.convnet.utils import extract_features 8 | 9 | 10 | class MemorySize: 11 | def __init__(self, mode, inc_dataset, total_memory=None, fixed_memory_per_cls=None): 12 | self.mode = mode 13 | assert mode.lower() in ["uniform_fixed_per_cls", "uniform_fixed_total_mem", "dynamic_fixed_per_cls"] 14 | self.total_memory = total_memory 15 | self.fixed_memory_per_cls = fixed_memory_per_cls 16 | self._n_classes = 0 17 | self.mem_per_cls = [] 18 | self._inc_dataset = inc_dataset 19 | 20 | def update_n_classes(self, n_classes): 21 | self._n_classes = n_classes 22 | 23 | def update_memory_per_cls_uniform(self, n_classes): 24 | if "fixed_per_cls" in self.mode: 25 | self.mem_per_cls = [self.fixed_memory_per_cls for i in range(n_classes)] 26 | elif "fixed_total_mem" in self.mode: 27 | self.mem_per_cls = [self.total_memory // n_classes for i in range(n_classes)] 28 | return self.mem_per_cls 29 | 30 | def update_memory_per_cls(self, network, n_classes, task_size): 31 | if "uniform" in self.mode: 32 | self.update_memory_per_cls_uniform(n_classes) 33 | else: 34 | if n_classes == task_size: 35 | self.update_memory_per_cls_uniform(n_classes) 36 | 37 | @property 38 | def memsize(self): 39 | if self.mode == "fixed_total_mem": 40 | return self.total_memory 41 | elif self.mode == "fixed_per_cls": 42 | return self.fixed_memory_per_cls * self._n_classes 43 | 44 | 45 | def compute_examplar_mean(feat_norm, feat_flip, herding_mat, nb_max): 46 | EPSILON = 1e-8 47 | D = feat_norm.T 48 | D = D / (np.linalg.norm(D, axis=0) + EPSILON) 49 | 50 | D2 = feat_flip.T 51 | 
D2 = D2 / (np.linalg.norm(D2, axis=0) + EPSILON) 52 | 53 | alph = herding_mat 54 | alph = (alph > 0) * (alph < nb_max + 1) * 1.0 55 | 56 | alph_mean = alph / np.sum(alph) 57 | 58 | mean = (np.dot(D, alph_mean) + np.dot(D2, alph_mean)) / 2 59 | # mean = np.dot(D, alph_mean) 60 | mean /= np.linalg.norm(mean) + EPSILON 61 | 62 | return mean, alph 63 | 64 | 65 | def select_examplars(features, nb_max): 66 | EPSILON = 1e-8 67 | D = features.T 68 | D = D / (np.linalg.norm(D, axis=0) + EPSILON) 69 | mu = np.mean(D, axis=1) 70 | herding_matrix = np.zeros((features.shape[0], )) 71 | idxes = [] 72 | w_t = mu 73 | 74 | iter_herding, iter_herding_eff = 0, 0 75 | 76 | while not (np.sum(herding_matrix != 0) == min(nb_max, features.shape[0])) and iter_herding_eff < 1000: 77 | tmp_t = np.dot(w_t, D) 78 | # tmp_t = -np.linalg.norm(w_t[:,np.newaxis]-D, axis=0) 79 | # tmp_t = np.linalg.norm(w_t[:,np.newaxis]-D, axis=0) 80 | ind_max = np.argmax(tmp_t) 81 | iter_herding_eff += 1 82 | if herding_matrix[ind_max] == 0: 83 | herding_matrix[ind_max] = 1 + iter_herding 84 | idxes.append(ind_max) 85 | iter_herding += 1 86 | 87 | w_t = w_t + mu - D[:, ind_max] 88 | 89 | return herding_matrix, idxes 90 | 91 | 92 | def random_selection(n_classes, task_size, network, logger, inc_dataset, memory_per_class: list): 93 | # TODO: Move data_memroy,targets_memory into IncDataset 94 | logger.info("Building & updating memory.(Random Selection)") 95 | tmp_data_memory, tmp_targets_memory = [], [] 96 | assert len(memory_per_class) == n_classes 97 | for class_idx in range(n_classes): 98 | # 旧类数据从get_custom_loader_from_memory中读取,新类数据从get_custom_loader中读取 99 | if class_idx < n_classes - task_size: 100 | inputs, targets, loader = inc_dataset.get_custom_loader_from_memory([class_idx]) 101 | else: 102 | inputs, targets, loader = inc_dataset.get_custom_loader(class_idx, mode="test") 103 | memory_this_cls = min(memory_per_class[class_idx], inputs.shape[0]) 104 | idxs = np.random.choice(inputs.shape[0], 
memory_this_cls, replace=False) 105 | tmp_data_memory.append(inputs[idxs]) 106 | tmp_targets_memory.append(targets[idxs]) 107 | tmp_data_memory = np.concatenate(tmp_data_memory) 108 | tmp_targets_memory = np.concatenate(tmp_targets_memory) 109 | return tmp_data_memory, tmp_targets_memory 110 | 111 | 112 | def herding(n_classes, task_size, network, herding_matrix, inc_dataset, shared_data_inc, memory_per_class: list, 113 | logger): 114 | """Herding matrix: list 115 | """ 116 | logger.info("Building & updating memory.(iCaRL)") 117 | tmp_data_memory, tmp_targets_memory = [], [] 118 | 119 | for class_idx in range(n_classes): 120 | inputs = inc_dataset.data_train[inc_dataset.targets_train == class_idx] 121 | targets = inc_dataset.targets_train[inc_dataset.targets_train == class_idx] 122 | # zi = inc_dataset.zimages[inc_dataset.zlabels == class_idx] 123 | # zt = inc_dataset.zlabels[inc_dataset.zlabels == class_idx] 124 | # inputs = np.concatenate((inputs, zi)) 125 | # targets = np.concatenate((targets, zt)) 126 | 127 | 128 | if class_idx >= n_classes - task_size: 129 | if len(shared_data_inc) > len(inc_dataset.targets_inc): 130 | share_memory = [shared_data_inc[i] for i in np.where(inc_dataset.targets_inc == class_idx)[0].tolist()] 131 | else: 132 | share_memory = [] 133 | for i in np.where(inc_dataset.targets_inc == class_idx)[0].tolist(): 134 | if i < len(shared_data_inc): 135 | share_memory.append(shared_data_inc[i]) 136 | 137 | # share_memory = [shared_data_inc[i] for i in np.where(inc_dataset.targets_inc == class_idx)[0].tolist()] 138 | loader = inc_dataset._get_loader(inc_dataset.data_inc[inc_dataset.targets_inc == class_idx], 139 | inc_dataset.targets_inc[inc_dataset.targets_inc == class_idx], 140 | share_memory=share_memory, 141 | batch_size=128, 142 | shuffle=False, 143 | mode="test") 144 | features, _ = extract_features(network, loader) 145 | # features_flipped, _ = extract_features(network, inc_dataset.get_custom_loader(class_idx, mode="flip")[-1]) 146 | 
herding_matrix.append(select_examplars(features, memory_per_class[class_idx])[0]) 147 | alph = herding_matrix[class_idx] 148 | alph = (alph > 0) * (alph < memory_per_class[class_idx] + 1) * 1.0 149 | # examplar_mean, alph = compute_examplar_mean(features, features_flipped, herding_matrix[class_idx], 150 | # memory_per_class[class_idx]) 151 | tmp_data_memory.append(inputs[np.where(alph == 1)[0]]) 152 | tmp_targets_memory.append(targets[np.where(alph == 1)[0]]) 153 | tmp_data_memory = np.concatenate(tmp_data_memory) 154 | tmp_targets_memory = np.concatenate(tmp_targets_memory) 155 | return tmp_data_memory, tmp_targets_memory, herding_matrix 156 | -------------------------------------------------------------------------------- /inclearn/prune/prune/structured.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from copy import deepcopy 4 | from functools import reduce 5 | from operator import mul 6 | 7 | __all__=['prune_conv', 'prune_related_conv', 'prune_linear', 'prune_related_linear', 'prune_batchnorm', 'prune_prelu', 'prune_group_conv'] 8 | 9 | def prune_group_conv(layer: nn.modules.conv._ConvNd, idxs: list, inplace: bool=True, dry_run: bool=False): 10 | """Prune `filters` for the convolutional layer, e.g. [256 x 128 x 3 x 3] => [192 x 128 x 3 x 3] 11 | 12 | Args: 13 | - layer: a convolution layer. 14 | - idxs: pruning index. 
15 | """ 16 | 17 | if layer.groups>1: 18 | assert layer.groups==layer.in_channels and layer.groups==layer.out_channels, "only group conv with in_channel==groups==out_channels is supported" 19 | 20 | 21 | idxs = list(set(idxs)) 22 | num_pruned = len(idxs) * reduce(mul, layer.weight.shape[1:]) + (len(idxs) if layer.bias is not None else 0) 23 | if dry_run: 24 | return layer, num_pruned 25 | if not inplace: 26 | layer = deepcopy(layer) 27 | keep_idxs = [idx for idx in range(layer.out_channels) if idx not in idxs] 28 | layer.out_channels = layer.out_channels-len(idxs) 29 | layer.in_channels = layer.in_channels-len(idxs) 30 | 31 | layer.groups = layer.groups-len(idxs) 32 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs]) 33 | if layer.bias is not None: 34 | layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs]) 35 | 36 | return layer, num_pruned 37 | 38 | def prune_conv(layer: nn.modules.conv._ConvNd, idxs: list, inplace: bool=True, dry_run: bool=False): 39 | """Prune `filters` for the convolutional layer, e.g. [256 x 128 x 3 x 3] => [192 x 128 x 3 x 3] 40 | 41 | Args: 42 | - layer: a convolution layer. 43 | - idxs: pruning index. 
44 | """ 45 | idxs = list(set(idxs)) 46 | num_pruned = len(idxs) * reduce(mul, layer.weight.shape[1:]) + (len(idxs) if layer.bias is not None else 0) 47 | if dry_run: 48 | return layer, num_pruned 49 | 50 | if not inplace: 51 | layer = deepcopy(layer) 52 | 53 | keep_idxs = [idx for idx in range(layer.out_channels) if idx not in idxs] 54 | layer.out_channels = layer.out_channels-len(idxs) 55 | if isinstance(layer,(nn.ConvTranspose2d,nn.ConvTranspose3d)): 56 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs]) 57 | else: 58 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs]) 59 | if layer.bias is not None: 60 | layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs]) 61 | return layer, num_pruned 62 | 63 | def prune_related_conv(layer: nn.modules.conv._ConvNd, idxs: list, inplace: bool=True, dry_run: bool=False): 64 | """Prune `kernels` for the related (affected) convolutional layer, e.g. [256 x 128 x 3 x 3] => [256 x 96 x 3 x 3] 65 | 66 | Args: 67 | layer: a convolutional layer. 68 | idxs: pruning index. 69 | """ 70 | idxs = list(set(idxs)) 71 | num_pruned = len(idxs) * layer.weight.shape[0] * reduce(mul ,layer.weight.shape[2:]) 72 | if dry_run: 73 | return layer, num_pruned 74 | if not inplace: 75 | layer = deepcopy(layer) 76 | 77 | 78 | keep_idxs = [i for i in range(layer.in_channels) if i not in idxs] 79 | 80 | layer.in_channels = layer.in_channels - len(idxs) 81 | 82 | if isinstance(layer,(nn.ConvTranspose2d,nn.ConvTranspose3d)): 83 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs,:]) 84 | else: 85 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs]) 86 | # no bias pruning because it does not change the output size 87 | return layer, num_pruned 88 | 89 | def prune_linear(layer: nn.modules.linear.Linear, idxs: list, inplace: list=True, dry_run: list=False): 90 | """Prune neurons for the fully-connected layer, e.g. 
[256 x 128] => [192 x 128] 91 | 92 | Args: 93 | layer: a fully-connected layer. 94 | idxs: pruning index. 95 | """ 96 | num_pruned = len(idxs)*layer.weight.shape[1] + (len(idxs) if layer.bias is not None else 0) 97 | if dry_run: 98 | return layer, num_pruned 99 | 100 | if not inplace: 101 | layer = deepcopy(layer) 102 | keep_idxs = [i for i in range(layer.out_features) if i not in idxs] 103 | layer.out_features = layer.out_features-len(idxs) 104 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs]) 105 | if layer.bias is not None: 106 | layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs]) 107 | return layer, num_pruned 108 | 109 | def prune_related_linear(layer: nn.modules.linear.Linear, idxs: list, inplace: list=True, dry_run: list=False): 110 | """Prune weights for the related (affected) fully-connected layer, e.g. [256 x 128] => [256 x 96] 111 | 112 | Args: 113 | layer: a fully-connected layer. 114 | idxs: pruning index. 115 | """ 116 | num_pruned = len(idxs) * layer.weight.shape[0] 117 | if dry_run: 118 | return layer, num_pruned 119 | 120 | if not inplace: 121 | layer = deepcopy(layer) 122 | keep_idxs = [i for i in range(layer.in_features) if i not in idxs] 123 | layer.in_features = layer.in_features-len(idxs) 124 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs]) 125 | return layer, num_pruned 126 | 127 | def prune_batchnorm(layer: nn.modules.batchnorm._BatchNorm, idxs: list, inplace: bool=True, dry_run: bool=False ): 128 | """Prune batch normalization layers, e.g. [128] => [64] 129 | 130 | Args: 131 | layer: a batch normalization layer. 132 | idxs: pruning index. 
133 | """ 134 | 135 | num_pruned = len(idxs)* ( 2 if layer.affine else 1) 136 | if dry_run: 137 | return layer, num_pruned 138 | 139 | if not inplace: 140 | layer = deepcopy(layer) 141 | 142 | keep_idxs = [i for i in range(layer.num_features) if i not in idxs] 143 | layer.num_features = layer.num_features-len(idxs) 144 | layer.running_mean = layer.running_mean.data.clone()[keep_idxs] 145 | layer.running_var = layer.running_var.data.clone()[keep_idxs] 146 | if layer.affine: 147 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs]) 148 | layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs]) 149 | return layer, num_pruned 150 | 151 | def prune_prelu(layer: nn.PReLU, idxs: list, inplace: bool=True, dry_run: bool=False): 152 | """Prune PReLU layers, e.g. [128] => [64] or [1] => [1] (no pruning if prelu has only 1 parameter) 153 | 154 | Args: 155 | layer: a PReLU layer. 156 | idxs: pruning index. 157 | """ 158 | num_pruned = 0 if layer.num_parameters==1 else len(idxs) 159 | if dry_run: 160 | return layer, num_pruned 161 | if not inplace: 162 | layer = deepcopy(layer) 163 | if layer.num_parameters==1: return layer, num_pruned 164 | keep_idxs = [i for i in range(layer.num_parameters) if i not in idxs] 165 | layer.num_parameters = layer.num_parameters-len(idxs) 166 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs]) 167 | return layer, num_pruned 168 | 169 | 170 | -------------------------------------------------------------------------------- /inclearn/convnet/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.optim import SGD 5 | import torch.nn.functional as F 6 | from inclearn.tools.metrics import ClassErrorMeter, AverageValueMeter 7 | from inclearn.loss.focalloss import BCEFocalLoss, MultiCEFocalLoss 8 | 9 | 10 | def finetune_last_layer( 11 | logger, 12 | network, 13 | loader, 14 | n_class, 15 | 
nepoch=30, 16 | lr=0.1, 17 | scheduling=[15, 35], 18 | lr_decay=0.1, 19 | weight_decay=5e-4, 20 | loss_type="ce", 21 | temperature=5.0, 22 | test_loader=None, 23 | samples_per_cls = [] 24 | ): 25 | network.eval() 26 | #if hasattr(network.module, "convnets"): 27 | # for net in network.module.convnets: 28 | # net.eval() 29 | #else: 30 | # network.module.convnet.eval() 31 | optim = SGD(network.module.classifier.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay) 32 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, scheduling, gamma=lr_decay) 33 | 34 | if loss_type == "ce": 35 | criterion = nn.CrossEntropyLoss() 36 | else: 37 | criterion = nn.BCEWithLogitsLoss() 38 | 39 | logger.info("Begin finetuning last layer") 40 | 41 | for i in range(nepoch): 42 | total_loss = 0.0 43 | total_correct = 0.0 44 | total_count = 0 45 | # print(f"dataset loader length {len(loader.dataset)}") 46 | for inputs, targets in loader: 47 | inputs, targets = inputs.cuda(), targets.cuda() 48 | if loss_type == "bce": 49 | targets = to_onehot(targets, n_class) 50 | outputs = network(inputs)['logit'] 51 | _, preds = outputs.max(1) 52 | optim.zero_grad() 53 | if loss_type == "cb": 54 | loss = CB_loss(targets, outputs / temperature, samples_per_cls, n_class, "focal") 55 | else: 56 | loss = criterion(outputs / temperature, targets) 57 | loss.backward() 58 | optim.step() 59 | total_loss += loss * inputs.size(0) 60 | total_correct += (preds == targets).sum() 61 | total_count += inputs.size(0) 62 | 63 | if test_loader is not None: 64 | test_correct = 0.0 65 | test_count = 0.0 66 | with torch.no_grad(): 67 | for inputs, targets in test_loader: 68 | outputs = network(inputs.cuda())['logit'] 69 | _, preds = outputs.max(1) 70 | test_correct += (preds.cpu() == targets).sum().item() 71 | test_count += inputs.size(0) 72 | 73 | scheduler.step() 74 | if test_loader is not None: 75 | logger.info( 76 | "Epoch %d finetuning loss %.3f acc %.3f Eval %.3f" % 77 | (i, total_loss.item() / total_count, 
total_correct.item() / total_count, test_correct / test_count)) 78 | else: 79 | logger.info("Epoch %d finetuning loss %.3f acc %.3f" % 80 | (i, total_loss.item() / total_count, total_correct.item() / total_count)) 81 | return network 82 | 83 | def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta=0.99, gamma=2.0): 84 | """Compute the Class Balanced Loss between `logits` and the ground truth `labels`. 85 | Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits) 86 | where Loss is one of the standard losses used for Neural Networks. 87 | Args: 88 | labels: A int tensor of size [batch]. 89 | logits: A float tensor of size [batch, no_of_classes]. 90 | samples_per_cls: A python list of size [no_of_classes]. 91 | no_of_classes: total number of classes. int 92 | loss_type: string. One of "sigmoid", "focal", "softmax". 93 | beta: float. Hyperparameter for Class balanced loss. 94 | gamma: float. Hyperparameter for Focal loss. 95 | Returns: 96 | cb_loss: A float tensor representing class balanced loss 97 | """ 98 | effective_num = 1.0 - np.power(beta, samples_per_cls) 99 | # print(labels.device, logits.device) 100 | weights = (1.0 - beta) / np.array(effective_num) 101 | weights = weights / np.sum(weights) * no_of_classes 102 | 103 | labels_one_hot = F.one_hot(labels, no_of_classes).float().cpu() 104 | 105 | weights = torch.tensor(weights).float() 106 | # weights = weights.unsqueeze(0) 107 | # print(labels_one_hot.device, weights.device) 108 | # weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot 109 | # weights = weights.sum(1) 110 | # weights = weights.unsqueeze(1) 111 | # weights = weights.repeat(1,no_of_classes) 112 | 113 | if loss_type == "focal": 114 | # print(effective_num.shape, weights.shape, len(samples_per_cls)) 115 | focalloss = MultiCEFocalLoss(class_num=no_of_classes, gamma=gamma, alpha=weights).cuda() 116 | cb_loss = focalloss(logits, labels) 117 | # cb_loss = self.focal_loss(labels_one_hot, logits, weights, gamma) 
118 | elif loss_type == "sigmoid": 119 | cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weights = weights) 120 | elif loss_type == "softmax": 121 | pred = logits.softmax(dim = 1) 122 | cb_loss = F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights) 123 | return cb_loss 124 | 125 | def extract_features(model, loader): 126 | targets, features = [], [] 127 | model.eval() 128 | with torch.no_grad(): 129 | for _inputs, _targets in loader: 130 | _inputs = _inputs.cuda() 131 | _targets = _targets.numpy() 132 | _features = model(_inputs)['feature'].detach().cpu().numpy() 133 | features.append(_features) 134 | targets.append(_targets) 135 | 136 | return np.concatenate(features), np.concatenate(targets) 137 | 138 | 139 | def calc_class_mean(network, loader, class_idx, metric): 140 | EPSILON = 1e-8 141 | features, targets = extract_features(network, loader) 142 | # norm_feats = features/(np.linalg.norm(features, axis=1)[:,np.newaxis]+EPSILON) 143 | # examplar_mean = norm_feats.mean(axis=0) 144 | examplar_mean = features.mean(axis=0) 145 | if metric == "cosine" or metric == "weight": 146 | examplar_mean /= (np.linalg.norm(examplar_mean) + EPSILON) 147 | return examplar_mean 148 | 149 | 150 | def update_classes_mean(network, inc_dataset, n_classes, task_size, share_memory=None, metric="cosine", EPSILON=1e-8): 151 | loader = inc_dataset._get_loader(inc_dataset.data_inc, 152 | inc_dataset.targets_inc, 153 | shuffle=False, 154 | share_memory=share_memory, 155 | mode="test") 156 | class_means = np.zeros((n_classes, network.module.features_dim)) 157 | count = np.zeros(n_classes) 158 | network.eval() 159 | with torch.no_grad(): 160 | for x, y in loader: 161 | feat = network(x.cuda())['feature'] 162 | for lbl in torch.unique(y): 163 | class_means[lbl] += feat[y == lbl].sum(0).cpu().numpy() 164 | count[lbl] += feat[y == lbl].shape[0] 165 | for i in range(n_classes): 166 | class_means[i] /= count[i] 167 | if metric == 
"cosine" or metric == "weight": 168 | class_means[i] /= (np.linalg.norm(class_means) + EPSILON) 169 | return class_means 170 | -------------------------------------------------------------------------------- /inclearn/models/prune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import shutil 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim 11 | import torch.utils.data 12 | import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | import torchvision.models 15 | # from models import print_log 16 | import numpy as np 17 | from collections import OrderedDict 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | class Mask: 38 | def __init__(self, model, cfg): 39 | self.model_size = {} 40 | self.model_length = {} 41 | self.compress_rate = {} 42 | self.mat = {} 43 | self.model = model 44 | self.mask_index = [] 45 | self.cfg = cfg 46 | 47 | def get_codebook(self, weight_torch, compress_rate, length): 48 | weight_vec = weight_torch.view(length) 49 | weight_np = weight_vec.cpu().numpy() 50 | 51 | weight_abs = np.abs(weight_np) 52 | weight_sort = np.sort(weight_abs) 53 | 54 | threshold = weight_sort[int(length * (1 - compress_rate))] 55 | weight_np[weight_np <= -threshold] = 1 56 | weight_np[weight_np >= threshold] = 1 57 | weight_np[weight_np != 1] = 0 58 | 59 | # print("codebook done") 60 | return weight_np 61 | 62 | def get_filter_codebook(self, weight_torch, compress_rate, length): 63 | codebook = 
np.ones(length) 64 | if len(weight_torch.size()) == 4: 65 | filter_pruned_num = int(weight_torch.size()[0] * (1 - compress_rate)) 66 | weight_vec = weight_torch.view(weight_torch.size()[0], -1) 67 | # norm1 = torch.norm(weight_vec, 1, 1) 68 | # norm1_np = norm1.cpu().numpy() 69 | norm2 = torch.norm(weight_vec, 2, 1) 70 | norm2_np = norm2.cpu().numpy() 71 | filter_index = norm2_np.argsort()[:filter_pruned_num] 72 | # norm1_sort = np.sort(norm1_np) 73 | # threshold = norm1_sort[int (weight_torch.size()[0] * (1-compress_rate) )] 74 | kernel_length = weight_torch.size()[1] * weight_torch.size()[2] * weight_torch.size()[3] 75 | for x in range(0, len(filter_index)): 76 | codebook[filter_index[x] * kernel_length: (filter_index[x] + 1) * kernel_length] = 0 77 | 78 | # print("filter codebook done") 79 | else: 80 | pass 81 | return codebook 82 | 83 | def convert2tensor(self, x): 84 | x = torch.FloatTensor(x) 85 | return x 86 | 87 | def init_length(self): 88 | for index, item in enumerate(self.model.parameters()): 89 | self.model_size[index] = item.size() 90 | 91 | for index1 in self.model_size: 92 | for index2 in range(0, len(self.model_size[index1])): 93 | if index2 == 0: 94 | self.model_length[index1] = self.model_size[index1][0] 95 | else: 96 | self.model_length[index1] *= self.model_size[index1][index2] 97 | 98 | def init_rate(self, layer_rate): 99 | if 'vgg' in self.cfg["arch"]: 100 | cfg_5x = [24, 22, 41, 51, 108, 89, 111, 184, 276, 228, 512, 512, 512] 101 | cfg_official = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512] 102 | # cfg = [32, 64, 128, 128, 256, 256, 256, 256, 256, 256, 256, 256, 256] 103 | cfg_index = 0 104 | pre_cfg = True 105 | for index, item in enumerate(self.model.named_parameters()): 106 | self.compress_rate[index] = 1 107 | if len(item[1].size()) == 4: 108 | # print(item[1].size()) 109 | if not pre_cfg: 110 | self.compress_rate[index] = layer_rate 111 | self.mask_index.append(index) 112 | # print(item[0], "self.mask_index", 
self.mask_index) 113 | else: 114 | self.compress_rate[index] = 1 - cfg_5x[cfg_index] / item[1].size()[0] 115 | self.mask_index.append(index) 116 | # print(item[0], "self.mask_index", self.mask_index, cfg_index, cfg_5x[cfg_index], item[1].size()[0],) 117 | cfg_index += 1 118 | elif "resnet" in self.cfg["arch"]: 119 | for index, item in enumerate(self.model.parameters()): 120 | self.compress_rate[index] = 1 121 | for key in range(self.cfg["layer_begin"], self.cfg["layer_end"] + 1, self.cfg["layer_inter"]): 122 | self.compress_rate[key] = layer_rate 123 | if self.cfg["arch"] == 'resnet18': 124 | # last index include last fc layer 125 | last_index = 59 126 | skip_list = [21, 36, 51] 127 | elif self.cfg["arch"] == 'resnet34': 128 | last_index = 108 129 | skip_list = [27, 54, 93] 130 | elif self.cfg["arch"] == 'resnet50': 131 | last_index = 159 132 | skip_list = [12, 42, 81, 138] 133 | elif self.cfg["arch"] == 'resnet101': 134 | last_index = 312 135 | skip_list = [12, 42, 81, 291] 136 | elif self.cfg["arch"] == 'resnet152': 137 | last_index = 465 138 | skip_list = [12, 42, 117, 444] 139 | self.mask_index = [x for x in range(0, last_index, 3)] 140 | # skip downsample layer 141 | if self.cfg["skip_downsample"] == 1: 142 | for x in skip_list: 143 | self.compress_rate[x] = 1 144 | self.mask_index.remove(x) 145 | # print(self.mask_index) 146 | else: 147 | pass 148 | 149 | def init_mask(self, layer_rate): 150 | self.init_rate(layer_rate) 151 | for index, item in enumerate(self.model.parameters()): 152 | if (index in self.mask_index): 153 | self.mat[index] = self.get_filter_codebook(item.data, self.compress_rate[index], 154 | self.model_length[index]) 155 | self.mat[index] = self.convert2tensor(self.mat[index]) 156 | if self.cfg["use_cuda"]: 157 | self.mat[index] = self.mat[index].cuda() 158 | print("mask Ready") 159 | 160 | def do_mask(self): 161 | for index, item in enumerate(self.model.parameters()): 162 | if (index in self.mask_index): 163 | a = 
item.data.view(self.model_length[index]) 164 | b = a * self.mat[index] 165 | item.data = b.view(self.model_size[index]) 166 | print("mask Done") 167 | 168 | def if_zero(self): 169 | for index, item in enumerate(self.model.parameters()): 170 | # if(index in self.mask_index): 171 | if index in [x for x in range(self.cfg["layer_begin"], self.cfg["layer_end"] + 1, self.cfg["layer_inter"])]: 172 | a = item.data.view(self.model_length[index]) 173 | b = a.cpu().numpy() 174 | 175 | # print("layer: %d, number of nonzero weight is %d, zero is %d" % (index, np.count_nonzero(b), len(b) - np.count_nonzero(b))) 176 | -------------------------------------------------------------------------------- /inclearn/tools/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import numbers 4 | import math 5 | 6 | 7 | class IncConfusionMeter: 8 | """Maintains a confusion matrix for a given calssification problem. 9 | The ConfusionMeter constructs a confusion matrix for a multi-class 10 | classification problems. It does not support multi-label, multi-class problems: 11 | for such problems, please use MultiLabelConfusionMeter. 
12 | Args: 13 | k (int): number of classes in the classification problem 14 | normalized (boolean): Determines whether or not the confusion matrix 15 | is normalized or not 16 | """ 17 | def __init__(self, k, increments, normalized=False): 18 | self.conf = np.ndarray((k, k), dtype=np.int32) 19 | self.normalized = normalized 20 | self.increments = increments 21 | self.cum_increments = [0] + [sum(increments[:i + 1]) for i in range(len(increments))] 22 | self.k = k 23 | self.reset() 24 | 25 | def reset(self): 26 | self.conf.fill(0) 27 | 28 | def add(self, predicted, target): 29 | """Computes the confusion matrix of K x K size where K is no of classes 30 | Args: 31 | predicted (tensor): Can be an N x K tensor of predicted scores obtained from 32 | the model for N examples and K classes or an N-tensor of 33 | integer values between 0 and K-1. 34 | target (tensor): Can be a N-tensor of integer values assumed to be integer 35 | values between 0 and K-1 or N x K tensor, where targets are 36 | assumed to be provided as one-hot vectors 37 | """ 38 | if isinstance(predicted, torch.Tensor): 39 | predicted = predicted.cpu().numpy() 40 | if isinstance(target, torch.Tensor): 41 | target = target.cpu().numpy() 42 | 43 | assert predicted.shape[0] == target.shape[0], \ 44 | 'number of targets and predicted outputs do not match' 45 | 46 | if np.ndim(predicted) != 1: 47 | assert predicted.shape[1] == self.k, \ 48 | 'number of predictions does not match size of confusion matrix' 49 | predicted = np.argmax(predicted, 1) 50 | else: 51 | assert (predicted.max() < self.k) and (predicted.min() >= 0), \ 52 | 'predicted values are not between 1 and k' 53 | 54 | onehot_target = np.ndim(target) != 1 55 | if onehot_target: 56 | assert target.shape[1] == self.k, \ 57 | 'Onehot target does not match size of confusion matrix' 58 | assert (target >= 0).all() and (target <= 1).all(), \ 59 | 'in one-hot encoding, target values should be 0 or 1' 60 | assert (target.sum(1) == 1).all(), \ 61 | 
'multi-label setting is not supported' 62 | target = np.argmax(target, 1) 63 | else: 64 | assert (predicted.max() < self.k) and (predicted.min() >= 0), \ 65 | 'predicted values are not between 0 and k-1' 66 | 67 | # hack for bincounting 2 arrays together 68 | x = predicted + self.k * target 69 | bincount_2d = np.bincount(x.astype(np.int32), minlength=self.k**2) 70 | assert bincount_2d.size == self.k**2 71 | conf = bincount_2d.reshape((self.k, self.k)) 72 | 73 | self.conf += conf 74 | 75 | def value(self): 76 | """ 77 | Returns: 78 | Confustion matrix of K rows and K columns, where rows corresponds 79 | to ground-truth targets and columns corresponds to predicted 80 | targets. 81 | """ 82 | conf = self.conf.astype(np.float32) 83 | new_conf = np.zeros([len(self.increments), len(self.increments) + 2]) 84 | for i in range(len(self.increments)): 85 | idxs = range(self.cum_increments[i], self.cum_increments[i + 1]) 86 | new_conf[i, 0] = conf[idxs, idxs].sum() 87 | new_conf[i, 1] = conf[self.cum_increments[i]:self.cum_increments[i + 1], 88 | self.cum_increments[i]:self.cum_increments[i + 1]].sum() - new_conf[i, 0] 89 | for j in range(len(self.increments)): 90 | new_conf[i, j + 2] = conf[self.cum_increments[i]:self.cum_increments[i + 1], 91 | self.cum_increments[j]:self.cum_increments[j + 1]].sum() 92 | conf = new_conf 93 | if self.normalized: 94 | return conf / conf[:, 2:].sum(1).clip(min=1e-12)[:, None] 95 | else: 96 | return conf 97 | 98 | 99 | class ClassErrorMeter: 100 | def __init__(self, topk=[1], accuracy=False): 101 | super(ClassErrorMeter, self).__init__() 102 | self.topk = np.sort(topk) 103 | self.accuracy = accuracy 104 | self.reset() 105 | 106 | def reset(self): 107 | self.sum = {v: 0 for v in self.topk} 108 | self.n = 0 109 | 110 | def add(self, output, target): 111 | if isinstance(output, np.ndarray): 112 | output = torch.Tensor(output) 113 | if isinstance(target, np.ndarray): 114 | target = torch.Tensor(target) 115 | # if torch.is_tensor(output): 116 | # 
output = output.cpu().squeeze().numpy() 117 | # if torch.is_tensor(target): 118 | # target = target.cpu().squeeze().numpy() 119 | # elif isinstance(target, numbers.Number): 120 | # target = np.asarray([target]) 121 | # if np.ndim(output) == 1: 122 | # output = output[np.newaxis] 123 | # else: 124 | # assert np.ndim(output) == 2, \ 125 | # 'wrong output size (1D or 2D expected)' 126 | # assert np.ndim(target) == 1, \ 127 | # 'target and output do not match' 128 | # assert target.shape[0] == output.shape[0], \ 129 | # 'target and output do not match' 130 | topk = self.topk 131 | maxk = int(topk[-1]) # seems like Python3 wants int and not np.int64 132 | no = output.shape[0] 133 | 134 | pred = output.topk(maxk, 1, True, True)[1] 135 | correct = pred == target.unsqueeze(1).repeat(1, pred.shape[1]) 136 | # pred = torch.from_numpy(output).topk(maxk, 1, True, True)[1].numpy() 137 | # correct = pred == target[:, np.newaxis].repeat(pred.shape[1], 1) 138 | 139 | for k in topk: 140 | self.sum[k] += no - correct[:, 0:k].sum() 141 | self.n += no 142 | 143 | def value(self, k=-1): 144 | if k != -1: 145 | assert k in self.sum.keys(), \ 146 | 'invalid k (this k was not provided at construction time)' 147 | if self.n == 0: 148 | return float('nan') 149 | if self.accuracy: 150 | return (1. 
- float(self.sum[k]) / self.n) * 100.0 151 | else: 152 | return float(self.sum[k]) / self.n * 100.0 153 | else: 154 | return [self.value(k_) for k_ in self.topk] 155 | 156 | 157 | class AverageValueMeter: 158 | def __init__(self): 159 | super(AverageValueMeter, self).__init__() 160 | self.reset() 161 | self.val = 0 162 | 163 | def add(self, value, n=1): 164 | self.val = value 165 | self.sum += value 166 | self.var += value * value 167 | self.n += n 168 | 169 | if self.n == 0: 170 | self.mean, self.std = np.nan, np.nan 171 | elif self.n == 1: 172 | self.mean, self.std = self.sum, np.inf 173 | self.mean_old = self.mean 174 | self.m_s = 0.0 175 | else: 176 | self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n) 177 | self.m_s += (value - self.mean_old) * (value - self.mean) 178 | self.mean_old = self.mean 179 | self.std = math.sqrt(self.m_s / (self.n - 1.0)) 180 | 181 | def value(self): 182 | return self.mean, self.std 183 | 184 | def reset(self): 185 | self.n = 0 186 | self.sum = 0.0 187 | self.var = 0.0 188 | self.val = 0.0 189 | self.mean = np.nan 190 | self.mean_old = 0.0 191 | self.m_s = 0.0 192 | self.std = np.nan -------------------------------------------------------------------------------- /inclearn/convnet/network.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import pdb 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from inclearn.tools import factory 9 | from inclearn.convnet.imbalance import CR, All_av 10 | from inclearn.convnet.classifier import CosineClassifier 11 | 12 | 13 | class BasicNet(nn.Module): 14 | def __init__( 15 | self, 16 | convnet_type, 17 | cfg, 18 | nf=64, 19 | use_bias=False, 20 | init="kaiming", 21 | device=None, 22 | dataset="cifar100", 23 | ): 24 | super(BasicNet, self).__init__() 25 | self.nf = nf 26 | self.init = init 27 | self.convnet_type = convnet_type 28 | self.dataset = dataset 29 | self.start_class = 
cfg['start_class'] 30 | self.weight_normalization = cfg['weight_normalization'] 31 | self.remove_last_relu = True if self.weight_normalization else False 32 | self.use_bias = use_bias if not self.weight_normalization else False 33 | self.dea = cfg['dea'] 34 | self.ft_type = cfg.get('feature_type', 'normal') 35 | self.at_res = cfg.get('attention_use_residual', False) 36 | self.div_type = cfg['div_type'] 37 | self.reuse_oldfc = cfg['reuse_oldfc'] 38 | self.prune = cfg.get('prune', False) 39 | self.reset = cfg.get('reset_se', True) 40 | 41 | 42 | if self.dea: 43 | print("Enable dynamical reprensetation expansion!") 44 | self.convnets = nn.ModuleList() 45 | self.convnets.append( 46 | factory.get_convnet(convnet_type, 47 | nf=nf, 48 | dataset=dataset, 49 | start_class=self.start_class, 50 | remove_last_relu=self.remove_last_relu)) 51 | self.out_dim = self.convnets[0].out_dim 52 | else: 53 | self.convnet = factory.get_convnet(convnet_type, 54 | nf=nf, 55 | dataset=dataset, 56 | remove_last_relu=self.remove_last_relu) 57 | self.out_dim = self.convnet.out_dim 58 | self.classifier = None 59 | self.se = None 60 | self.aux_classifier = None 61 | 62 | self.n_classes = 0 63 | self.ntask = 0 64 | self.device = device 65 | 66 | if cfg['postprocessor']['enable']: 67 | if cfg['postprocessor']['type'].lower() == "cr": 68 | self.postprocessor = CR() 69 | elif cfg['postprocessor']['type'].lower() == "aver": 70 | self.postprocessor = All_av() 71 | else: 72 | self.postprocessor = None 73 | 74 | self.to(self.device) 75 | 76 | def forward(self, x): 77 | if self.classifier is None: 78 | raise Exception("Add some classes before training.") 79 | 80 | if self.dea: 81 | feature = [convnet(x) for convnet in self.convnets] 82 | features = torch.cat(feature, 1) 83 | last_dim = feature[-1].size(1) 84 | width = features.size(1) 85 | 86 | if self.reset: 87 | se = factory.get_attention(width, self.ft_type, self.at_res).to(self.device) 88 | features = se(features) 89 | else: 90 | features = 
self.se(features) 91 | 92 | else: 93 | features = self.convnet(x) 94 | 95 | logits = self.classifier(features) 96 | 97 | div_logits = self.aux_classifier(features[:, -last_dim:]) if self.ntask > 1 else None 98 | 99 | return {'feature': features, 'logit': logits, 'div_logit': div_logits, 'features': feature} 100 | 101 | def caculate_dim(self, x): 102 | feature = [convnet(x) for convnet in self.convnets] 103 | features = torch.cat(feature, 1) 104 | 105 | width = features.size(1) 106 | 107 | # se = factory.get_attention(width, self.ft_type, self.at_res).to(self.device) 108 | se = factory.get_attention(width, "ce", self.at_res).cuda() 109 | features = se(features) 110 | 111 | # import pdb 112 | # pdb.set_trace() 113 | return features.size(1), feature[-1].size(1) 114 | 115 | @property 116 | def features_dim(self): 117 | if self.dea: 118 | return self.out_dim * len(self.convnets) 119 | else: 120 | return self.out_dim 121 | 122 | def freeze(self): 123 | for param in self.parameters(): 124 | param.requires_grad = False 125 | self.eval() 126 | return self 127 | 128 | def copy(self): 129 | return copy.deepcopy(self) 130 | 131 | def add_classes(self, n_classes): 132 | self.ntask += 1 133 | 134 | if self.dea: 135 | self._add_classes_multi_fc(n_classes) 136 | else: 137 | self._add_classes_single_fc(n_classes) 138 | 139 | self.n_classes += n_classes 140 | 141 | def _add_classes_multi_fc(self, n_classes): 142 | if self.ntask > 1: 143 | new_clf = factory.get_convnet(self.convnet_type, 144 | nf=self.nf, 145 | dataset=self.dataset, 146 | start_class=self.start_class, 147 | remove_last_relu=self.remove_last_relu).to(self.device) 148 | if self.prune: 149 | pass 150 | else: 151 | new_clf.load_state_dict(self.convnets[-1].state_dict()) 152 | self.convnets.append(new_clf) 153 | 154 | if not self.reset: 155 | self.se = factory.get_attention(512*len(self.convnets), self.ft_type, self.at_res) 156 | self.se.to(self.device) 157 | 158 | if self.classifier is not None: 159 | weight = 
copy.deepcopy(self.classifier.weight.data) 160 | 161 | fc = self._gen_classifier(self.out_dim * len(self.convnets), self.n_classes + n_classes) 162 | 163 | if self.classifier is not None and self.reuse_oldfc: 164 | fc.weight.data[:self.n_classes, :self.out_dim * (len(self.convnets) - 1)] = weight 165 | del self.classifier 166 | self.classifier = fc 167 | 168 | if self.div_type == "n+1": 169 | div_fc = self._gen_classifier(self.out_dim, n_classes + 1) 170 | elif self.div_type == "1+1": 171 | div_fc = self._gen_classifier(self.out_dim, 2) 172 | elif self.div_type == "n+t": 173 | div_fc = self._gen_classifier(self.out_dim, self.ntask + n_classes) 174 | else: 175 | div_fc = self._gen_classifier(self.out_dim, self.n_classes + n_classes) 176 | del self.aux_classifier 177 | self.aux_classifier = div_fc 178 | 179 | def _add_classes_single_fc(self, n_classes): 180 | if self.classifier is not None: 181 | weight = copy.deepcopy(self.classifier.weight.data) 182 | if self.use_bias: 183 | bias = copy.deepcopy(self.classifier.bias.data) 184 | 185 | classifier = self._gen_classifier(self.features_dim, self.n_classes + n_classes) 186 | 187 | if self.classifier is not None and self.reuse_oldfc: 188 | classifier.weight.data[:self.n_classes] = weight 189 | if self.use_bias: 190 | classifier.bias.data[:self.n_classes] = bias 191 | 192 | del self.classifier 193 | self.classifier = classifier 194 | 195 | def _gen_classifier(self, in_features, n_classes): 196 | if self.weight_normalization: 197 | classifier = CosineClassifier(in_features, n_classes).to(self.device) 198 | # classifier = CosineClassifier(in_features, n_classes).cuda() 199 | else: 200 | classifier = nn.Linear(in_features, n_classes, bias=self.use_bias).to(self.device) 201 | # classifier = nn.Linear(in_features, n_classes, bias=self.use_bias).cuda() 202 | if self.init == "kaiming": 203 | nn.init.kaiming_normal_(classifier.weight, nonlinearity="linear") 204 | if self.use_bias: 205 | nn.init.constant_(classifier.bias, 0.0) 206 
| 207 | return classifier 208 | -------------------------------------------------------------------------------- /inclearn/prune/resnet_small.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=dilation, groups=groups, bias=False, dilation=dilation) 9 | 10 | 11 | def conv1x1(in_planes, out_planes, stride=1): 12 | """1x1 convolution""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 20 | base_width=64, dilation=1, norm_layer=None): 21 | super(BasicBlock, self).__init__() 22 | if norm_layer is None: 23 | norm_layer = nn.BatchNorm2d 24 | if groups != 1 or base_width != 64: 25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 26 | if dilation > 1: 27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = norm_layer(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = norm_layer(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | identity = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | # Bottleneck in torchvision places the 
stride for downsampling at 3x3 convolution(self.conv2) 58 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 59 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 60 | # This variant is also known as ResNet V1.5 and improves accuracy according to 61 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 62 | 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 66 | base_width=64, dilation=1, norm_layer=None): 67 | super(Bottleneck, self).__init__() 68 | if norm_layer is None: 69 | norm_layer = nn.BatchNorm2d 70 | width = int(planes * (base_width / 64.)) * groups 71 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 72 | self.conv1 = conv1x1(inplanes, width) 73 | self.bn1 = norm_layer(width) 74 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 75 | self.bn2 = norm_layer(width) 76 | self.conv3 = conv1x1(width, planes * self.expansion) 77 | self.bn3 = norm_layer(planes * self.expansion) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.downsample = downsample 80 | self.stride = stride 81 | 82 | def forward(self, x): 83 | identity = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv3(out) 94 | out = self.bn3(out) 95 | 96 | if self.downsample is not None: 97 | identity = self.downsample(x) 98 | 99 | out += identity 100 | out = self.relu(out) 101 | 102 | return out 103 | 104 | 105 | class ResNet(nn.Module): 106 | 107 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 108 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 109 | norm_layer=None): 110 | super(ResNet, self).__init__() 111 | if norm_layer is None: 112 | norm_layer = nn.BatchNorm2d 113 | self._norm_layer = norm_layer 
114 | 115 | self.inplanes = 64 116 | self.dilation = 1 117 | if replace_stride_with_dilation is None: 118 | # each element in the tuple indicates if we should replace 119 | # the 2x2 stride with a dilated convolution instead 120 | replace_stride_with_dilation = [False, False, False] 121 | if len(replace_stride_with_dilation) != 3: 122 | raise ValueError("replace_stride_with_dilation should be None " 123 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 124 | self.groups = groups 125 | self.base_width = width_per_group 126 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 127 | bias=False) 128 | self.bn1 = norm_layer(self.inplanes) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | self.layer1 = self._make_layer(block, 64, layers[0]) 132 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 133 | # self.fc = nn.Linear(64 * block.expansion, num_classes) 134 | self.final_conv=nn.Conv2d(64, 1024, kernel_size=3, stride=1, padding=1, 135 | bias=False) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 140 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 141 | nn.init.constant_(m.weight, 1) 142 | nn.init.constant_(m.bias, 0) 143 | 144 | # Zero-initialize the last BN in each residual branch, 145 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
146 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 147 | if zero_init_residual: 148 | for m in self.modules(): 149 | if isinstance(m, Bottleneck): 150 | nn.init.constant_(m.bn3.weight, 0) 151 | elif isinstance(m, BasicBlock): 152 | nn.init.constant_(m.bn2.weight, 0) 153 | 154 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 155 | norm_layer = self._norm_layer 156 | downsample = None 157 | previous_dilation = self.dilation 158 | if dilate: 159 | self.dilation *= stride 160 | stride = 1 161 | if stride != 1 or self.inplanes != planes * block.expansion: 162 | downsample = nn.Sequential( 163 | conv1x1(self.inplanes, planes * block.expansion, stride), 164 | norm_layer(planes * block.expansion), 165 | ) 166 | 167 | layers = [] 168 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 169 | self.base_width, previous_dilation, norm_layer)) 170 | self.inplanes = planes * block.expansion 171 | for _ in range(1, blocks): 172 | layers.append(block(self.inplanes, planes, groups=self.groups, 173 | base_width=self.base_width, dilation=self.dilation, 174 | norm_layer=norm_layer)) 175 | 176 | return nn.Sequential(*layers) 177 | 178 | def _forward_impl(self, x): 179 | # See note [TorchScript super()] 180 | x = self.conv1(x) 181 | x = self.bn1(x) 182 | x = self.relu(x) 183 | x = self.maxpool(x) 184 | 185 | x = self.layer1(x) 186 | # x = self.avgpool(x) 187 | # x = torch.flatten(x, 1) 188 | # x = self.fc(x) 189 | x=self.final_conv(x) 190 | return x 191 | 192 | def forward(self, x): 193 | return self._forward_impl(x) 194 | 195 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 196 | model = ResNet(block, layers, **kwargs) 197 | return model 198 | 199 | def resnet_small(pretrained=False, progress=True, **kwargs): 200 | r"""ResNet-18 model from 201 | `"Deep Residual Learning for Image Recognition" `_ 202 | 203 | Args: 204 | pretrained (bool): If True, returns a model pre-trained on 
ImageNet 205 | progress (bool): If True, displays a progress bar of the download to stderr 206 | """ 207 | return _resnet('resnet_small', BasicBlock, [1], pretrained, progress, 208 | **kwargs) 209 | -------------------------------------------------------------------------------- /inclearn/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import glob 4 | 5 | from albumentations.pytorch import ToTensorV2 6 | 7 | from torchvision import datasets, transforms 8 | import torch 9 | from inclearn.tools.cutout import Cutout 10 | from inclearn.tools.autoaugment_extra import ImageNetPolicy 11 | 12 | 13 | def get_datasets(dataset_names): 14 | return [get_dataset(dataset_name) for dataset_name in dataset_names.split("-")] 15 | 16 | 17 | def get_dataset(dataset_name): 18 | if dataset_name == "cifar10": 19 | return iCIFAR10 20 | elif dataset_name == "cifar100": 21 | return iCIFAR100 22 | elif "imagenet100" in dataset_name: 23 | return iImageNet100 24 | else: 25 | raise NotImplementedError("Unknown dataset {}.".format(dataset_name)) 26 | 27 | 28 | class DataHandler: 29 | base_dataset = None 30 | train_transforms = [] 31 | common_transforms = [ToTensorV2()] 32 | class_order = None 33 | 34 | 35 | class iCIFAR10(DataHandler): 36 | base_dataset_cls = datasets.cifar.CIFAR10 37 | transform_type = 'torchvision' 38 | train_transforms = transforms.Compose([ 39 | transforms.RandomCrop(32, padding=4), 40 | transforms.RandomHorizontalFlip(), 41 | # transforms.ColorJitter(brightness=63 / 255), 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 44 | ]) 45 | test_transforms = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 48 | ]) 49 | 50 | def __init__(self, data_folder, train, is_fine_label=False): 51 | self.base_dataset = self.base_dataset_cls(data_folder, 
train=train, download=True) 52 | self.data = self.base_dataset.data 53 | self.targets = self.base_dataset.targets 54 | self.n_cls = 10 55 | 56 | @property 57 | def is_proc_inc_data(self): 58 | return False 59 | 60 | @classmethod 61 | def class_order(cls, trial_i): 62 | return [4, 0, 2, 5, 8, 3, 1, 6, 9, 7] 63 | 64 | 65 | class iCIFAR100(iCIFAR10): 66 | label_list = [ 67 | 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 68 | 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 69 | 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 70 | 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 71 | 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 72 | 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 73 | 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 74 | 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 75 | 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 76 | 'willow_tree', 'wolf', 'woman', 'worm' 77 | ] 78 | base_dataset_cls = datasets.cifar.CIFAR100 79 | transform_type = 'torchvision' 80 | train_transforms = transforms.Compose([ 81 | transforms.RandomCrop(32, padding=4), 82 | transforms.RandomHorizontalFlip(), 83 | transforms.ColorJitter(brightness=63 / 255), 84 | transforms.ToTensor(), 85 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 86 | ]) 87 | test_transforms = transforms.Compose([ 88 | transforms.ToTensor(), 89 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 90 | ]) 91 | 92 | 
def __init__(self, data_folder, train, is_fine_label=False): 93 | self.base_dataset = self.base_dataset_cls(data_folder, train=train, download=True) 94 | self.data = self.base_dataset.data 95 | self.targets = self.base_dataset.targets 96 | self.n_cls = 100 97 | self.transform_type = 'torchvision' 98 | 99 | @property 100 | def is_proc_inc_data(self): 101 | return False 102 | 103 | @classmethod 104 | def class_order(cls, trial_i): 105 | if trial_i == 0: 106 | return [ 107 | 62, 54, 84, 20, 94, 22, 40, 29, 78, 27, 26, 79, 17, 76, 68, 88, 3, 19, 31, 21, 33, 60, 24, 14, 6, 10, 108 | 16, 82, 70, 92, 25, 5, 28, 9, 61, 36, 50, 90, 8, 48, 47, 56, 11, 98, 35, 93, 44, 64, 75, 66, 15, 38, 97, 109 | 42, 43, 12, 37, 55, 72, 95, 18, 7, 23, 71, 49, 53, 57, 86, 39, 87, 34, 63, 81, 89, 69, 46, 2, 1, 73, 32, 110 | 67, 91, 0, 51, 83, 13, 58, 80, 74, 65, 4, 30, 45, 77, 99, 85, 41, 96, 59, 52 111 | ] 112 | elif trial_i == 1: 113 | return [ 114 | 68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 115 | 50, 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 116 | 71, 96, 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 117 | 46, 62, 69, 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33 118 | ] 119 | elif trial_i == 2: #PODNet 120 | return [ 121 | 87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45, 122 | 88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6, 123 | 46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76, 124 | 40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39 125 | ] 126 | elif trial_i == 3: #PODNet 127 | return [ 128 | 58, 30, 93, 69, 21, 77, 3, 78, 12, 71, 65, 40, 16, 49, 89, 46, 24, 66, 19, 41, 5, 29, 15, 73, 11, 70, 129 | 90, 63, 67, 25, 
59, 72, 80, 94, 54, 33, 18, 96, 2, 10, 43, 9, 57, 81, 76, 50, 32, 6, 37, 7, 68, 91, 88, 130 | 95, 85, 4, 60, 36, 22, 27, 39, 42, 34, 51, 55, 28, 53, 48, 38, 17, 83, 86, 56, 35, 45, 79, 99, 84, 97, 131 | 82, 98, 26, 47, 44, 62, 13, 31, 0, 75, 14, 52, 74, 8, 20, 1, 92, 87, 23, 64, 61 132 | ] 133 | elif trial_i == 4: #PODNet 134 | return [ 135 | 71, 54, 45, 32, 4, 8, 48, 66, 1, 91, 28, 82, 29, 22, 80, 27, 86, 23, 37, 47, 55, 9, 14, 68, 25, 96, 36, 136 | 90, 58, 21, 57, 81, 12, 26, 16, 89, 79, 49, 31, 38, 46, 20, 92, 88, 40, 39, 98, 94, 19, 95, 72, 24, 64, 137 | 18, 60, 50, 63, 61, 83, 76, 69, 35, 0, 52, 7, 65, 42, 73, 74, 30, 41, 3, 6, 53, 13, 56, 70, 77, 34, 97, 138 | 75, 2, 17, 93, 33, 84, 99, 51, 62, 87, 5, 15, 10, 78, 67, 44, 59, 85, 43, 11 139 | ] 140 | 141 | 142 | class DataHandler: 143 | base_dataset = None 144 | train_transforms = [] 145 | common_transforms = [ToTensorV2()] 146 | class_order = None 147 | 148 | 149 | class iImageNet100(DataHandler): 150 | 151 | base_dataset_cls = datasets.ImageFolder 152 | transform_type = 'torchvision' 153 | train_transforms = transforms.Compose([ 154 | transforms.ToPILImage(), 155 | transforms.ToTensor(), 156 | Cutout(n_holes=1, length=16), 157 | transforms.ToPILImage(), 158 | transforms.RandomResizedCrop(224), 159 | transforms.RandomHorizontalFlip(), 160 | ImageNetPolicy(), 161 | transforms.ToTensor(), 162 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 163 | 164 | ]) 165 | test_transforms = transforms.Compose([ 166 | transforms.ToPILImage(), 167 | transforms.ToTensor(), 168 | transforms.ToPILImage(), 169 | transforms.Resize(256), 170 | transforms.CenterCrop(224), 171 | transforms.ToTensor(), 172 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 173 | ]) 174 | 175 | def __init__(self, data_folder, train, is_fine_label=False): 176 | if train is True: 177 | self.base_dataset = self.base_dataset_cls(osp.join(data_folder, "train")) 178 | else: 179 | self.base_dataset = 
self.base_dataset_cls(osp.join(data_folder, "val")) 180 | 181 | self.data, self.targets = zip(*self.base_dataset.samples) 182 | self.data = np.array(self.data) 183 | self.targets = np.array(self.targets) 184 | self.n_cls = 100 185 | 186 | @property 187 | def is_proc_inc_data(self): 188 | return False 189 | 190 | @classmethod 191 | def class_order(cls, trial_i): 192 | return [ 193 | 68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50, 194 | 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96, 195 | 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69, 196 | 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33 197 | ] 198 | -------------------------------------------------------------------------------- /inclearn/prune/autoslim_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from itertools import chain 5 | from .dependency import * 6 | from . 
import prune 7 | import math 8 | from scipy.spatial import distance 9 | 10 | __all__ = ['Autoslim'] 11 | 12 | 13 | class Autoslim(object): 14 | def __init__(self, model, inputs, compression_ratio): 15 | self.model = model # torchvision.models模型 16 | self.inputs = inputs # 输入大小,torch.randn(1,3,224,224) 17 | self.compression_ratio = compression_ratio # 期望压缩率 18 | self.DG = DependencyGraph() 19 | # 构建节点依赖关系 20 | self.DG.build_dependency(model, example_inputs=inputs) 21 | self.model_modules = list(model.modules()) 22 | self.pruning_func = { 23 | 'l1': self._base_l1_pruning, 24 | 'fpgm': self._base_fpgm_pruning 25 | } 26 | 27 | def index_of_layer(self): 28 | dicts = {} 29 | for i, m in enumerate(self.model_modules): 30 | if isinstance(m, nn.modules.conv._ConvNd): 31 | dicts[i] = m 32 | return dicts 33 | 34 | def base_prunging(self, config): 35 | if not config['pruning_func'] in self.pruning_func: 36 | raise KeyError( 37 | "-[ERROR] {} not supported.".format((config['pruning_func']))) 38 | 39 | ori_output = {} 40 | for i, m in enumerate(self.model_modules): 41 | if isinstance(m, nn.modules.conv._ConvNd): 42 | ori_output[i] = m.out_channels 43 | 44 | if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1: 45 | config['layer_compression_ratio'] = self._compute_auto_ratios() 46 | 47 | prune_indexes = self.pruning_func[config['pruning_func']](config) 48 | 49 | for i, m in enumerate(self.model_modules): 50 | if i in prune_indexes and m.out_channels == ori_output[i]: 51 | pruning_plan = self.DG.get_pruning_plan( 52 | m, prune.prune_conv, idxs=prune_indexes[i]) 53 | if pruning_plan and config['prune_shortcut'] == 1: 54 | pruning_plan.exec() 55 | elif not pruning_plan.is_in_shortcut: 56 | pruning_plan.exec() 57 | 58 | def _base_fpgm_pruning(self, config): 59 | prune_indexes = {} 60 | for i, m in enumerate(self.model_modules): 61 | # _ConvNd包含卷积和反卷积 62 | if isinstance(m, nn.modules.conv._ConvNd): 63 | weight_torch = m.weight.detach().cuda() 64 | if 
isinstance(m, nn.modules.conv._ConvTransposeMixin): 65 | weight_vec = weight_torch.view(weight_torch.size()[1], -1) 66 | out_channels = weight_torch.size()[1] 67 | else: 68 | weight_vec = weight_torch.view( 69 | weight_torch.size()[0], -1) # 权重[512,64,3,3] -> [512, 64*3*3] 70 | out_channels = weight_torch.size()[0] 71 | 72 | if config['layer_compression_ratio'] and i in config['layer_compression_ratio']: 73 | similar_pruned_num = int( 74 | out_channels * config['layer_compression_ratio'][i]) 75 | # 全自动化压缩时,不剪跳连层 76 | else: 77 | similar_pruned_num = int( 78 | out_channels * self.compression_ratio) 79 | 80 | filter_pruned_num = int( 81 | out_channels * (1 - config['norm_rate'])) 82 | 83 | if config['dist_type'] == "l2" or "cos": 84 | norm = torch.norm(weight_vec, 2, 1) 85 | norm_np = norm.cpu().numpy() 86 | elif config['dist_type'] == "l1": 87 | norm = torch.norm(weight_vec, 1, 1) 88 | norm_np = norm.cpu().numpy() 89 | 90 | filter_large_index = [] 91 | filter_large_index = norm_np.argsort()[filter_pruned_num:] 92 | 93 | indices = torch.LongTensor(filter_large_index).cuda() 94 | # weight_vec_after_norm.size=15 95 | weight_vec_after_norm = torch.index_select( 96 | weight_vec, 0, indices).cpu().numpy() 97 | 98 | # for euclidean distance 99 | if config['dist_type'] == "l2" or "l1": 100 | similar_matrix = distance.cdist( 101 | weight_vec_after_norm, weight_vec_after_norm, 'euclidean') 102 | elif config['dist_type'] == "cos": # for cos similarity 103 | similar_matrix = 1 - \ 104 | distance.cdist(weight_vec_after_norm, 105 | weight_vec_after_norm, 'cosine') 106 | 107 | # 将任意一个点与其他点的距离算出来,最后将距离相加,一共得到15组数据 108 | similar_sum = np.sum(np.abs(similar_matrix), axis=0) 109 | 110 | # for distance similar: get the filter index with largest similarity == small distance 111 | similar_large_index = similar_sum.argsort()[ 112 | similar_pruned_num:] 113 | similar_small_index = similar_sum.argsort()[ 114 | :similar_pruned_num] 115 | prune_index = [filter_large_index[i] 116 | for i in 
similar_small_index] 117 | prune_indexes[i] = prune_index 118 | return prune_indexes 119 | 120 | def _base_l1_pruning(self, config): 121 | prune_indexes = {} 122 | # 全局阈值剪枝法(最好别用,效果不佳) 123 | if config['global_pruning']: 124 | filter_record = [] 125 | for i, m in enumerate(self.model_modules): 126 | if isinstance(m, nn.modules.conv._ConvNd): 127 | weight = m.weight.detach().cpu().numpy() 128 | if isinstance(m, nn.modules.conv._ConvTransposeMixin): 129 | L1_norm = np.sum(np.abs(weight), axis=( 130 | 0, 2, 3)) # 注:反卷积维数1对应输出维度 131 | else: 132 | L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3)) 133 | filter_record.append(L1_norm.tolist()) # 记录每层卷积的l1_norm参数 134 | 135 | filter_record = list(chain.from_iterable(filter_record)) 136 | total = len(filter_record) 137 | filter_record.sort() # 全局排序 138 | thre_index = int(total * self.compression_ratio) 139 | thre = filter_record[thre_index] # 根据裁剪率确定阈值 140 | for i, m in enumerate(self.model_modules): 141 | if isinstance(m, nn.modules.conv._ConvNd): 142 | weight = m.weight.detach().cpu().numpy() 143 | # _ConvTransposeMixin只包含反卷积 144 | if isinstance(m, nn.modules.conv._ConvTransposeMixin): 145 | L1_norm = np.sum(np.abs(weight), axis=( 146 | 0, 2, 3)) # 注:反卷积维数1对应输出维度 147 | else: 148 | L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3)) 149 | num_pruned = min(int(max_ratio*len(L1_norm)), 150 | len(L1_norm[L1_norm < thre])) # 不能全部减去 151 | # 删除低于阈值的卷积核 152 | prune_index = np.argsort(L1_norm)[:num_pruned].tolist() 153 | prune_indexes.append(prune_index) 154 | 155 | # 局部阈值加指定层 156 | else: 157 | if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1: 158 | # 需要剪跳连层,并且未指定每一层的裁剪率 159 | config['layer_compression_ratio'] = self._compute_auto_ratios() 160 | 161 | for i, m in enumerate(self.model_modules): 162 | # 逐层裁剪 163 | # _ConvNd包含卷积和反卷积 164 | if isinstance(m, nn.modules.conv._ConvNd): 165 | weight = m.weight.detach().cpu().numpy() 166 | # _ConvTransposeMixin只包含反卷积 167 | if isinstance(m, 
nn.modules.conv._ConvTransposeMixin): 168 | out_channels = weight.shape[1] 169 | L1_norm = np.sum(np.abs(weight), axis=(0, 2, 3)) 170 | else: 171 | out_channels = weight.shape[0] 172 | L1_norm = np.sum( 173 | np.abs(weight), axis=(1, 2, 3)) # 计算卷积核的L1范式 174 | 175 | # 自定义压缩或全自动化压缩时剪跳连层 176 | if config['layer_compression_ratio'] and i in config['layer_compression_ratio']: 177 | num_pruned = int( 178 | out_channels * config['layer_compression_ratio'][i]) 179 | # 全自动化压缩时,不剪跳连层 180 | else: 181 | num_pruned = int(out_channels * self.compression_ratio) 182 | 183 | # remove filters with small L1-Norm 184 | prune_index = np.argsort(L1_norm)[:num_pruned].tolist() 185 | prune_indexes.append(prune_index) 186 | return prune_indexes 187 | 188 | def _compute_auto_ratios(self): 189 | layer_compression_ratio = {} 190 | mid_value = self.compression_ratio 191 | 192 | one_value = (1-mid_value)/4 if mid_value >= 0.43 else mid_value/4 193 | values = [mid_value-one_value*3, mid_value-one_value*2, mid_value-one_value, 194 | mid_value, mid_value+one_value, mid_value+one_value*2, mid_value+one_value*3] 195 | layer_cnt = 0 196 | for i, m in enumerate(self.model_modules): 197 | if isinstance(m, nn.modules.conv._ConvNd): 198 | layer_compression_ratio[i] = 0 199 | layer_cnt += 1 200 | layers_of_class = layer_cnt/7 201 | conv_cnt = 0 202 | for i, m in enumerate(self.model_modules): 203 | if isinstance(m, nn.modules.conv._ConvNd): 204 | layer_compression_ratio[i] = values[math.floor( 205 | conv_cnt/layers_of_class)] 206 | conv_cnt += 1 207 | return layer_compression_ratio 208 | 209 | 210 | if __name__ == "__main__": 211 | from resnet_small import resnet_small 212 | model = resnet_small() 213 | slim = Autoslim(model, inputs=torch.randn( 214 | 1, 3, 224, 224), compression_ratio=0.5) 215 | slim.l1_norm_pruning() 216 | print(model) 217 | -------------------------------------------------------------------------------- /inclearn/prune/autoslim.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from itertools import chain 5 | from .dependency import * 6 | from . import prune 7 | import math 8 | from scipy.spatial import distance 9 | 10 | __all__ = ['Autoslim'] 11 | 12 | 13 | class Autoslim(object): 14 | def __init__(self, model, inputs, compression_ratio): 15 | self.model = model # torchvision.models模型 16 | self.inputs = inputs # 输入大小,torch.randn(1,3,224,224) 17 | self.compression_ratio = compression_ratio # 期望压缩率 18 | self.DG = DependencyGraph() 19 | # 构建节点依赖关系 20 | self.DG.build_dependency(model, example_inputs=inputs) 21 | self.model_modules = list(model.modules()) 22 | self.pruning_func = { 23 | 'l1': self._base_l1_pruning, 24 | 'fpgm': self._base_fpgm_pruning 25 | } 26 | 27 | def index_of_layer(self): 28 | dicts = {} 29 | for i, m in enumerate(self.model_modules): 30 | if isinstance(m, nn.modules.conv._ConvNd): 31 | dicts[i] = m 32 | return dicts 33 | 34 | def base_prunging(self, config): 35 | if not config['pruning_func'] in self.pruning_func: 36 | raise KeyError( 37 | "-[ERROR] {} pruning not supported.".format(config['pruning_func'])) 38 | 39 | ori_output = {} 40 | for i, m in enumerate(self.model_modules): 41 | if isinstance(m, nn.modules.conv._ConvNd): 42 | ori_output[i] = m.out_channels 43 | 44 | if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1: 45 | config['layer_compression_ratio'] = self._compute_auto_ratios() 46 | 47 | prune_indexes = self.pruning_func[config['pruning_func']](config) 48 | 49 | for i, m in enumerate(self.model_modules): 50 | if i in prune_indexes and m.out_channels == ori_output[i]: 51 | pruning_plan = self.DG.get_pruning_plan( 52 | m, prune.prune_conv, idxs=prune_indexes[i]) 53 | if pruning_plan and config['prune_shortcut'] == 1: 54 | pruning_plan.exec() 55 | elif not pruning_plan: 56 | continue 57 | elif not pruning_plan.is_in_shortcut: 58 | 
pruning_plan.exec() 59 | 60 | def _base_fpgm_pruning(self, config): 61 | prune_indexes = {} 62 | for i, m in enumerate(self.model_modules): 63 | # _ConvNd包含卷积和反卷积 64 | if isinstance(m, nn.modules.conv._ConvNd): 65 | weight_torch = m.weight.detach().cuda() 66 | if isinstance(m, nn.modules.conv._ConvTransposeMixin): 67 | weight_vec = weight_torch.view(weight_torch.size()[1], -1) 68 | out_channels = weight_torch.size()[1] 69 | else: 70 | weight_vec = weight_torch.view( 71 | weight_torch.size()[0], -1) # 权重[512,64,3,3] -> [512, 64*3*3] 72 | out_channels = weight_torch.size()[0] 73 | 74 | if config['layer_compression_ratio'] and i in config['layer_compression_ratio']: 75 | similar_pruned_num = int( 76 | out_channels * config['layer_compression_ratio'][i]) 77 | # 全自动化压缩时,不剪跳连层 78 | else: 79 | similar_pruned_num = int( 80 | out_channels * self.compression_ratio) 81 | 82 | filter_pruned_num = int( 83 | out_channels * (1 - config['norm_rate'])) 84 | 85 | if config['dist_type'] == "l2" or "cos": 86 | norm = torch.norm(weight_vec, 2, 1) 87 | norm_np = norm.cpu().numpy() 88 | elif config['dist_type'] == "l1": 89 | norm = torch.norm(weight_vec, 1, 1) 90 | norm_np = norm.cpu().numpy() 91 | 92 | filter_large_index = [] 93 | filter_large_index = norm_np.argsort()[filter_pruned_num:] 94 | 95 | indices = torch.LongTensor(filter_large_index).cuda() 96 | # weight_vec_after_norm.size=15 97 | weight_vec_after_norm = torch.index_select( 98 | weight_vec, 0, indices).cpu().numpy() 99 | 100 | # for euclidean distance 101 | if config['dist_type'] == "l2" or "l1": 102 | similar_matrix = distance.cdist( 103 | weight_vec_after_norm, weight_vec_after_norm, 'euclidean') 104 | elif config['dist_type'] == "cos": # for cos similarity 105 | similar_matrix = 1 - \ 106 | distance.cdist(weight_vec_after_norm, 107 | weight_vec_after_norm, 'cosine') 108 | 109 | # 将任意一个点与其他点的距离算出来,最后将距离相加,一共得到15组数据 110 | similar_sum = np.sum(np.abs(similar_matrix), axis=0) 111 | 112 | # for distance similar: get the filter 
index with largest similarity == small distance 113 | similar_large_index = similar_sum.argsort()[ 114 | similar_pruned_num:] 115 | similar_small_index = similar_sum.argsort()[ 116 | :similar_pruned_num] 117 | prune_index = [filter_large_index[i] 118 | for i in similar_small_index] 119 | prune_indexes[i] = prune_index 120 | return prune_indexes 121 | 122 | def _base_l1_pruning(self, config): 123 | return self.__base_lx_norm_pruning(config, norm='l1') 124 | 125 | def _base_l1_pruning(self, config): 126 | return self.__base_lx_norm_pruning(config, norm='l2') 127 | 128 | def __base_lx_norm_pruning(self, config, norm='l1'): 129 | prune_indexes = {} 130 | 131 | def _compute_lx_norm(m, norm): 132 | weight = m.weight.detach().cpu().numpy() 133 | if isinstance(m, nn.modules.conv._ConvTransposeMixin): 134 | if norm == 'l1': 135 | Lx_norm = np.sum(np.abs(weight), axis=(0, 2, 3)) 136 | else: 137 | Lx_norm = np.sum( 138 | np.sqrt(weight ** 2), axis=(0, 2, 3)) 139 | else: 140 | if norm == 'l1': 141 | Lx_norm = np.sum(np.abs(weight), axis=(1, 2, 3)) 142 | else: 143 | # 注:反卷积维数1对应输出维度 144 | Lx_norm = np.sum( 145 | np.sqrt(weight ** 2), axis=(1, 2, 3)) 146 | return Lx_norm 147 | 148 | # 全局阈值剪枝法(最好别用,效果不佳) 149 | if config['global_pruning']: 150 | filter_record = [] 151 | for i, m in enumerate(self.model_modules): 152 | if isinstance(m, nn.modules.conv._ConvNd): 153 | Lx_norm = _compute_lx_norm(m, norm) 154 | filter_record.append(Lx_norm.tolist()) # 记录每层卷积的lx_norm参数 155 | 156 | filter_record = list(chain.from_iterable(filter_record)) 157 | total = len(filter_record) 158 | filter_record.sort() # 全局排序 159 | thre_index = int(total * self.compression_ratio) 160 | thre = filter_record[thre_index] # 根据裁剪率确定阈值 161 | for i, m in enumerate(self.model_modules): 162 | if isinstance(m, nn.modules.conv._ConvNd): 163 | weight = m.weight.detach().cpu().numpy() 164 | # _ConvTransposeMixin只包含反卷积 165 | if isinstance(m, nn.modules.conv._ConvTransposeMixin): 166 | Lx_norm = np.sum(np.abs(weight), 
axis=( 167 | 0, 2, 3)) # 注:反卷积维数1对应输出维度 168 | else: 169 | Lx_norm = np.sum(np.abs(weight), axis=(1, 2, 3)) 170 | num_pruned = min(int(max_ratio*len(Lx_norm)), 171 | len(Lx_norm[Lx_norm < thre])) # 不能全部减去 172 | # 删除低于阈值的卷积核 173 | prune_index = np.argsort(Lx_norm)[:num_pruned].tolist() 174 | prune_indexes[i] = prune_index 175 | 176 | # 局部阈值加指定层 177 | else: 178 | if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1: 179 | # 需要剪跳连层,并且未指定每一层的裁剪率 180 | config['layer_compression_ratio'] = self._compute_auto_ratios() 181 | 182 | for i, m in enumerate(self.model_modules): 183 | # 逐层裁剪 184 | # _ConvNd包含卷积和反卷积 185 | if isinstance(m, nn.modules.conv._ConvNd): 186 | Lx_norm = _compute_lx_norm(m, norm) 187 | 188 | # 自定义压缩或全自动化压缩时剪跳连层 189 | if config['layer_compression_ratio'] and i in config['layer_compression_ratio']: 190 | num_pruned = int( 191 | out_channels * config['layer_compression_ratio'][i]) 192 | # 全自动化压缩时,不剪跳连层 193 | else: 194 | num_pruned = int(out_channels * self.compression_ratio) 195 | 196 | # remove filters with small L1-Norm 197 | prune_index = np.argsort(Lx_norm)[:num_pruned].tolist() 198 | prune_indexes[i] = prune_index 199 | return prune_indexes 200 | 201 | def _compute_auto_ratios(self): 202 | # 如果未指定每层裁剪率,则自动生成 203 | layer_compression_ratio = {} 204 | mid_value = self.compression_ratio 205 | 206 | one_value = (1-mid_value)/4 if mid_value >= 0.43 else mid_value/4 207 | values = [mid_value-one_value*3, mid_value-one_value*2, mid_value-one_value, 208 | mid_value, mid_value+one_value, mid_value+one_value*2, mid_value+one_value*3] 209 | # 分为七级裁剪率,从浅到深,从小到大 210 | # 均值为期望裁剪率 211 | layer_cnt = 0 212 | for i, m in enumerate(self.model_modules): 213 | if isinstance(m, nn.modules.conv._ConvNd): 214 | layer_compression_ratio[i] = 0 215 | layer_cnt += 1 216 | layers_of_class = layer_cnt/7 217 | conv_cnt = 0 218 | for i, m in enumerate(self.model_modules): 219 | if isinstance(m, nn.modules.conv._ConvNd): 220 | layer_compression_ratio[i] = 
values[math.floor( 221 | conv_cnt/layers_of_class)] 222 | conv_cnt += 1 223 | return layer_compression_ratio 224 | 225 | 226 | if __name__ == "__main__": 227 | from resnet_small import resnet_small 228 | model = resnet_small() 229 | slim = Autoslim(model, inputs=torch.randn( 230 | 1, 3, 224, 224), compression_ratio=0.5) 231 | slim.l1_norm_pruning() 232 | print(model) 233 | -------------------------------------------------------------------------------- /inclearn/convnet/resnet.py: -------------------------------------------------------------------------------- 1 | """Taken & slightly modified from: 2 | * https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.model_zoo as model_zoo 7 | from torch.nn import functional as F 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | 24 | 25 | def conv1x1(in_planes, out_planes, stride=1): 26 | """1x1 convolution""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None, remove_last_relu=False): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = 
nn.BatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | self.remove_last_relu = remove_last_relu 43 | 44 | def forward(self, x): 45 | identity = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | 57 | out += identity 58 | if not self.remove_last_relu: 59 | out = self.relu(out) 60 | return out 61 | 62 | 63 | class Bottleneck(nn.Module): 64 | expansion = 4 65 | 66 | def __init__(self, inplanes, planes, stride=1, downsample=None): 67 | super(Bottleneck, self).__init__() 68 | self.conv1 = conv1x1(inplanes, planes) 69 | self.bn1 = nn.BatchNorm2d(planes) 70 | self.conv2 = conv3x3(planes, planes, stride) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | self.conv3 = conv1x1(planes, planes * self.expansion) 73 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | identity = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | identity = self.downsample(x) 94 | 95 | out += identity 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ChannelAttention(nn.Module): 102 | def __init__(self, in_planes, ratio=16): 103 | super(ChannelAttention, self).__init__() 104 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 105 | self.max_pool = nn.AdaptiveMaxPool2d(1) 106 | # 共享权重的MLP 107 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 108 | self.relu1 = nn.ReLU() 109 | self.fc2 = nn.Conv2d(in_planes // 16, 
in_planes, 1, bias=False) 110 | self.sigmoid = nn.Sigmoid() 111 | 112 | def forward(self, x): 113 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 114 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 115 | out = avg_out + max_out 116 | return self.sigmoid(out) 117 | 118 | 119 | class SpatialAttention(nn.Module): 120 | def __init__(self, kernel_size=7): 121 | super(SpatialAttention, self).__init__() 122 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 123 | padding = 3 if kernel_size == 7 else 1 124 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 125 | self.sigmoid = nn.Sigmoid() 126 | 127 | def forward(self, x): 128 | avg_out = torch.mean(x, dim=1, keepdim=True) 129 | max_out, _ = torch.max(x, dim=1, keepdim=True) 130 | x = torch.cat([avg_out, max_out], dim=1) 131 | x = self.conv1(x) 132 | return self.sigmoid(x) 133 | 134 | class SEFeatureAt(nn.Module): 135 | def __init__(self, inplanes, type, at_res): 136 | super(SEFeatureAt, self).__init__() 137 | self.se = nn.Sequential( 138 | nn.AdaptiveAvgPool2d((1,1)), 139 | nn.Conv2d(inplanes,inplanes//16,kernel_size=1), 140 | nn.ReLU(), 141 | nn.Conv2d(inplanes//16,inplanes,kernel_size=1), 142 | nn.Sigmoid() 143 | ) 144 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 145 | self.type = type 146 | self.at_res = at_res 147 | self.ca = ChannelAttention(inplanes) 148 | self.sa = SpatialAttention() 149 | 150 | def forward(self, x): 151 | residual = x 152 | if self.type == "se": 153 | attention = self.se(x) 154 | x = x * attention 155 | elif self.type == "ffm": 156 | x = self.ca(x) * x 157 | x = self.sa(x) * x 158 | if self.at_res: 159 | x += residual 160 | x = self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | 163 | return x 164 | 165 | 166 | class ResNet(nn.Module): 167 | def __init__(self, 168 | block, 169 | layers, 170 | nf=64, 171 | zero_init_residual=True, 172 | dataset='cifar', 173 | start_class=0, 174 | remove_last_relu=False): 175 | super(ResNet, 
self).__init__() 176 | self.remove_last_relu = remove_last_relu 177 | self.inplanes = nf 178 | if 'cifar' in dataset: 179 | self.conv1 = nn.Sequential(nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False), 180 | nn.BatchNorm2d(nf), nn.ReLU(inplace=True)) 181 | elif 'imagenet' in dataset: 182 | if start_class == 0: 183 | self.conv1 = nn.Sequential( 184 | nn.Conv2d(3, nf, kernel_size=7, stride=2, padding=3, bias=False), 185 | nn.BatchNorm2d(nf), 186 | nn.ReLU(inplace=True), 187 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 188 | ) 189 | else: 190 | # Following PODNET implmentation 191 | self.conv1 = nn.Sequential( 192 | nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False), 193 | nn.BatchNorm2d(nf), 194 | nn.ReLU(inplace=True), 195 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 196 | ) 197 | 198 | self.layer1 = self._make_layer(block, 1 * nf, layers[0]) 199 | self.layer2 = self._make_layer(block, 2 * nf, layers[1], stride=2) 200 | self.layer3 = self._make_layer(block, 4 * nf, layers[2], stride=2) 201 | self.layer4 = self._make_layer(block, 8 * nf, layers[3], stride=2, remove_last_relu=remove_last_relu) 202 | 203 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 204 | 205 | self.out_dim = 8 * nf * block.expansion 206 | 207 | for m in self.modules(): 208 | if isinstance(m, nn.Conv2d): 209 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 210 | elif isinstance(m, nn.BatchNorm2d): 211 | nn.init.constant_(m.weight, 1) 212 | nn.init.constant_(m.bias, 0) 213 | 214 | # Zero-initialize the last BN in each residual branch, 215 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
216 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 217 | if zero_init_residual: 218 | for m in self.modules(): 219 | if isinstance(m, Bottleneck): 220 | nn.init.constant_(m.bn3.weight, 0) 221 | elif isinstance(m, BasicBlock): 222 | nn.init.constant_(m.bn2.weight, 0) 223 | 224 | def _make_layer(self, block, planes, blocks, remove_last_relu=False, stride=1): 225 | downsample = None 226 | if stride != 1 or self.inplanes != planes * block.expansion: 227 | downsample = nn.Sequential( 228 | conv1x1(self.inplanes, planes * block.expansion, stride), 229 | nn.BatchNorm2d(planes * block.expansion), 230 | ) 231 | 232 | layers = [] 233 | layers.append(block(self.inplanes, planes, stride, downsample)) 234 | self.inplanes = planes * block.expansion 235 | if remove_last_relu: 236 | for i in range(1, blocks - 1): 237 | layers.append(block(self.inplanes, planes)) 238 | layers.append(block(self.inplanes, planes, remove_last_relu=True)) 239 | else: 240 | for _ in range(1, blocks): 241 | layers.append(block(self.inplanes, planes)) 242 | 243 | return nn.Sequential(*layers) 244 | 245 | def reset_bn(self): 246 | for m in self.modules(): 247 | if isinstance(m, nn.BatchNorm2d): 248 | m.reset_running_stats() 249 | 250 | def forward(self, x): 251 | x = self.conv1(x) 252 | x = self.layer1(x) 253 | x = self.layer2(x) 254 | x = self.layer3(x) 255 | x = self.layer4(x) 256 | # x = self.avgpool(x) 257 | # x = x.view(x.size(0), -1) 258 | return x 259 | 260 | 261 | def resnet18(pretrained=False, **kwargs): 262 | """Constructs a ResNet-18 model. 263 | 264 | """ 265 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 266 | if pretrained: 267 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 268 | return model 269 | 270 | 271 | def resnet34(pretrained=False, **kwargs): 272 | """Constructs a ResNet-34 model. 
273 | 274 | """ 275 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 276 | if pretrained: 277 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 278 | return model 279 | 280 | 281 | def resnet50(pretrained=False, **kwargs): 282 | """Constructs a ResNet-50 model. 283 | 284 | """ 285 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 286 | if pretrained: 287 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 288 | return model 289 | 290 | 291 | def resnet101(pretrained=False, **kwargs): 292 | """Constructs a ResNet-101 model. 293 | 294 | """ 295 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 296 | if pretrained: 297 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 298 | return model 299 | 300 | 301 | def resnet152(pretrained=False, **kwargs): 302 | """Constructs a ResNet-152 model. 303 | 304 | """ 305 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 306 | if pretrained: 307 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 308 | return model 309 | -------------------------------------------------------------------------------- /inclearn/prune/sensitivity_analysis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import copy 5 | import csv 6 | import logging 7 | from collections import OrderedDict 8 | 9 | import numpy as np 10 | import torch.nn as nn 11 | 12 | # FIXME: I don't know where "utils" should be 13 | SUPPORTED_OP_NAME = ['Conv2d', 'Conv1d'] 14 | SUPPORTED_OP_TYPE = [getattr(nn, name) for name in SUPPORTED_OP_NAME] 15 | 16 | logger = logging.getLogger('Sensitivity_Analysis') 17 | logger.setLevel(logging.INFO) 18 | 19 | 20 | class SensitivityAnalysis: 21 | def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None): 22 | """ 23 | Perform sensitivity analysis for this model. 
24 | Parameters 25 | ---------- 26 | model : torch.nn.Module 27 | the model to perform sensitivity analysis 28 | val_func : function 29 | validation function for the model. Due to 30 | different models may need different dataset/criterion 31 | , therefore the user need to cover this part by themselves. 32 | In the val_func, the model should be tested on the validation dateset, 33 | and the validation accuracy/loss should be returned as the output of val_func. 34 | There are no restrictions on the input parameters of the val_function. 35 | User can use the val_args, val_kwargs parameters in analysis 36 | to pass all the parameters that val_func needed. 37 | sparsities : list 38 | The sparsity list provided by users. This parameter is set when the user 39 | only wants to test some specific sparsities. In the sparsity list, each element 40 | is a sparsity value which means how much weight the pruner should prune. Take 41 | [0.25, 0.5, 0.75] for an example, the SensitivityAnalysis will prune 25% 50% 75% 42 | weights gradually for each layer. 43 | prune_type : str 44 | The pruner type used to prune the conv layers, default is 'l1', 45 | and 'l2', 'fine-grained' is also supported. 46 | early_stop_mode : str 47 | If this flag is set, the sensitivity analysis 48 | for a conv layer will early stop when the validation metric( 49 | for example, accurracy/loss) has alreay meet the threshold. We 50 | support four different early stop modes: minimize, maximize, dropped, 51 | raised. The default value is None, which means the analysis won't stop 52 | until all given sparsities are tested. This option should be used with 53 | early_stop_value together. 54 | 55 | minimize: The analysis stops when the validation metric return by the val_func 56 | lower than early_stop_value. 57 | maximize: The analysis stops when the validation metric return by the val_func 58 | larger than early_stop_value. 59 | dropped: The analysis stops when the validation metric has dropped by early_stop_value. 
60 | raised: The analysis stops when the validation metric has raised by early_stop_value. 61 | early_stop_value : float 62 | This value is used as the threshold for different earlystop modes. 63 | This value is effective only when the early_stop_mode is set. 64 | 65 | """ 66 | from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT 67 | 68 | self.model = model 69 | self.val_func = val_func 70 | self.target_layer = OrderedDict() 71 | self.ori_state_dict = copy.deepcopy(self.model.state_dict()) 72 | self.target_layer = {} 73 | self.sensitivities = {} 74 | if sparsities is not None: 75 | self.sparsities = sorted(sparsities) 76 | else: 77 | self.sparsities = np.arange(0.1, 1.0, 0.1) 78 | self.sparsities = [np.round(x, 2) for x in self.sparsities] 79 | self.Pruner = PRUNER_DICT[prune_type] 80 | self.early_stop_mode = early_stop_mode 81 | self.early_stop_value = early_stop_value 82 | self.ori_metric = None # original validation metric for the model 83 | # already_pruned is for the iterative sensitivity analysis 84 | # For example, sensitivity_pruner iteratively prune the target 85 | # model according to the sensitivity. After each round of 86 | # pruning, the sensitivity_pruner will test the new sensitivity 87 | # for each layer 88 | self.already_pruned = {} 89 | self.model_parse() 90 | 91 | @property 92 | def layers_count(self): 93 | return len(self.target_layer) 94 | 95 | def model_parse(self): 96 | for name, submodel in self.model.named_modules(): 97 | for op_type in SUPPORTED_OP_TYPE: 98 | if isinstance(submodel, op_type): 99 | self.target_layer[name] = submodel 100 | self.already_pruned[name] = 0 101 | 102 | def _need_to_stop(self, ori_metric, cur_metric): 103 | """ 104 | Judge if meet the stop conditon(early_stop, min_threshold, 105 | max_threshold). 
106 | Parameters 107 | ---------- 108 | ori_metric : float 109 | original validation metric 110 | cur_metric : float 111 | current validation metric 112 | 113 | Returns 114 | ------- 115 | stop : bool 116 | if stop the sensitivity analysis 117 | """ 118 | if self.early_stop_mode is None: 119 | # early stop mode is not enable 120 | return False 121 | assert self.early_stop_value is not None 122 | if self.early_stop_mode == 'minimize': 123 | if cur_metric < self.early_stop_value: 124 | return True 125 | elif self.early_stop_mode == 'maximize': 126 | if cur_metric > self.early_stop_value: 127 | return True 128 | elif self.early_stop_mode == 'dropped': 129 | if cur_metric < ori_metric - self.early_stop_value: 130 | return True 131 | elif self.early_stop_mode == 'raised': 132 | if cur_metric > ori_metric + self.early_stop_value: 133 | return True 134 | return False 135 | 136 | def analysis(self, val_args=None, val_kwargs=None, specified_layers=None): 137 | """ 138 | This function analyze the sensitivity to pruning for 139 | each conv layer in the target model. 140 | If start and end are not set, we analyze all the conv 141 | layers by default. Users can specify several layers to 142 | analyze or parallelize the analysis process easily through 143 | the start and end parameter. 144 | 145 | Parameters 146 | ---------- 147 | val_args : list 148 | args for the val_function 149 | val_kwargs : dict 150 | kwargs for the val_funtion 151 | specified_layers : list 152 | list of layer names to analyze sensitivity. 153 | If this variable is set, then only analyze 154 | the conv layers that specified in the list. 155 | User can also use this option to parallelize 156 | the sensitivity analysis easily. 
157 | Returns 158 | ------- 159 | sensitivities : dict 160 | dict object that stores the trajectory of the 161 | accuracy/loss when the prune ratio changes 162 | """ 163 | if val_args is None: 164 | val_args = [] 165 | if val_kwargs is None: 166 | val_kwargs = {} 167 | # Get the original validation metric(accuracy/loss) before pruning 168 | # Get the accuracy baseline before starting the analysis. 169 | self.ori_metric = self.val_func(*val_args, **val_kwargs) 170 | namelist = list(self.target_layer.keys()) 171 | if specified_layers is not None: 172 | # only analyze several specified conv layers 173 | namelist = list(filter(lambda x: x in specified_layers, namelist)) 174 | for name in namelist: 175 | self.sensitivities[name] = {} 176 | for sparsity in self.sparsities: 177 | # here the sparsity is the relative sparsity of the 178 | # the remained weights 179 | # Calculate the actual prune ratio based on the already pruned ratio 180 | real_sparsity = ( 181 | 1.0 - self.already_pruned[name]) * sparsity + self.already_pruned[name] 182 | # TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary 183 | # I think the L1/L2 Pruner should specify the op_types automaticlly 184 | # according to the op_names 185 | cfg = [{'sparsity': real_sparsity, 'op_names': [ 186 | name], 'op_types': ['Conv2d']}] 187 | pruner = self.Pruner(self.model, cfg) 188 | pruner.compress() 189 | val_metric = self.val_func(*val_args, **val_kwargs) 190 | logger.info('Layer: %s Sparsity: %.2f Validation Metric: %.4f', 191 | name, real_sparsity, val_metric) 192 | 193 | self.sensitivities[name][sparsity] = val_metric 194 | pruner._unwrap_model() 195 | del pruner 196 | # check if the current metric meet the stop condition 197 | if self._need_to_stop(self.ori_metric, val_metric): 198 | break 199 | 200 | # reset the weights pruned by the pruner, because the 201 | # input sparsities is sorted, so we donnot need to reset 202 | # weight of the layer when the sparsity changes, instead, 203 | # we 
only need reset the weight when the pruning layer changes. 204 | self.model.load_state_dict(self.ori_state_dict) 205 | 206 | return self.sensitivities 207 | 208 | def export(self, filepath): 209 | """ 210 | Export the results of the sensitivity analysis 211 | to a csv file. The firstline of the csv file describe the content 212 | structure. The first line is constructed by 'layername' and sparsity 213 | list. Each line below records the validation metric returned by val_func 214 | when this layer is under different sparsities. Note that, due to the early_stop 215 | option, some layers may not have the metrics under all sparsities. 216 | 217 | layername, 0.25, 0.5, 0.75 218 | conv1, 0.6, 0.55 219 | conv2, 0.61, 0.57, 0.56 220 | 221 | Parameters 222 | ---------- 223 | filepath : str 224 | Path of the output file 225 | """ 226 | str_sparsities = [str(x) for x in self.sparsities] 227 | header = ['layername'] + str_sparsities 228 | with open(filepath, 'w') as csvf: 229 | csv_w = csv.writer(csvf) 230 | csv_w.writerow(header) 231 | for layername in self.sensitivities: 232 | row = [] 233 | row.append(layername) 234 | for sparsity in sorted(self.sensitivities[layername].keys()): 235 | row.append(self.sensitivities[layername][sparsity]) 236 | csv_w.writerow(row) 237 | 238 | def update_already_pruned(self, layername, ratio): 239 | """ 240 | Set the already pruned ratio for the target layer. 
241 | """ 242 | self.already_pruned[layername] = ratio 243 | 244 | def load_state_dict(self, state_dict): 245 | """ 246 | Update the weight of the model 247 | """ 248 | self.ori_state_dict = copy.deepcopy(state_dict) 249 | self.model.load_state_dict(self.ori_state_dict) 250 | -------------------------------------------------------------------------------- /inclearn/tools/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | import numpy as np 4 | import datetime 5 | 6 | import torch 7 | 8 | from inclearn.tools.metrics import ClassErrorMeter 9 | from sklearn.metrics import classification_report 10 | 11 | 12 | def get_date(): 13 | return datetime.datetime.now().strftime("%Y%m%d") 14 | 15 | 16 | def to_onehot(targets, n_classes): 17 | if not hasattr(targets, "device"): 18 | targets = torch.from_numpy(targets) 19 | onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device) 20 | onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.0) 21 | return onehot 22 | 23 | 24 | def get_class_loss(network, cur_n_cls, loader): 25 | class_loss = torch.zeros(cur_n_cls) 26 | n_cls_data = torch.zeros(cur_n_cls) # the num of imgs for cls i. 27 | EPS = 1e-10 28 | task_size = 10 29 | network.eval() 30 | for x, y in loader: 31 | x, y = x.cuda(), y.cuda() 32 | preds = network(x)['logit'].softmax(dim=1) 33 | # preds[:,-task_size:] = preds[:,-task_size:].softmax(dim=1) 34 | for i, lbl in enumerate(y): 35 | class_loss[lbl] = class_loss[lbl] - (preds[i, lbl] + EPS).detach().log().cpu() 36 | n_cls_data[lbl] += 1 37 | class_loss = class_loss / n_cls_data 38 | return class_loss 39 | 40 | 41 | def get_featnorm_grouped_by_class(network, cur_n_cls, loader): 42 | """ 43 | Ret: feat_norms: list of list 44 | feat_norms[idx] is the list of feature norm of the images for class idx. 
45 | """ 46 | feats = [[] for i in range(cur_n_cls)] 47 | feat_norms = np.zeros(cur_n_cls) 48 | network.eval() 49 | with torch.no_grad(): 50 | for x, y in loader: 51 | x = x.cuda() 52 | feat = network(x)['feature'].cpu() 53 | for i, lbl in enumerate(y): 54 | if lbl >= cur_n_cls: 55 | continue 56 | feats[lbl].append(feat[y == lbl]) 57 | for i in range(len(feats)): 58 | if len(feats[i]) != 0: 59 | feat_cls = torch.cat((feats[i])) 60 | feat_norms[i] = torch.norm(feat_cls, p=2, dim=1).mean().data.numpy() 61 | return feat_norms 62 | 63 | 64 | def set_seed(seed): 65 | print("Set seed", seed) 66 | random.seed(seed) 67 | np.random.seed(seed) 68 | torch.manual_seed(seed) 69 | torch.cuda.manual_seed_all(seed) 70 | torch.backends.cudnn.deterministic = True # This will slow down training. 71 | torch.backends.cudnn.benchmark = False 72 | 73 | 74 | def display_weight_norm(logger, network, increments, tag): 75 | weight_norms = [[] for _ in range(len(increments))] 76 | increments = np.cumsum(np.array(increments)) 77 | for idx in range(network.module.classifier.weight.shape[0]): 78 | norm = torch.norm(network.module.classifier.weight[idx].data, p=2).item() 79 | for i in range(len(weight_norms)): 80 | if idx < increments[i]: 81 | break 82 | weight_norms[i].append(round(norm, 3)) 83 | avg_weight_norm = [] 84 | # all_weight_norms = [] 85 | for idx in range(len(weight_norms)): 86 | # all_weight_norms += weight_norms[idx] 87 | # logger.info("task %s: Weight norm per class %s" % (str(idx), str(weight_norms[idx]))) 88 | avg_weight_norm.append(round(np.array(weight_norms[idx]).mean(), 3)) 89 | 90 | logger.info("%s: Weight norm per task %s" % (tag, str(avg_weight_norm))) 91 | 92 | 93 | def display_feature_norm(logger, network, loader, n_classes, increments, tag, return_norm=False): 94 | avg_feat_norm_per_cls = get_featnorm_grouped_by_class(network, n_classes, loader) 95 | feature_norms = [[] for _ in range(len(increments))] 96 | increments = np.cumsum(np.array(increments)) 97 | for idx in 
range(len(avg_feat_norm_per_cls)): 98 | for i in range(len(feature_norms)): 99 | if idx < increments[i]: #Find the mapping from class idx to step i. 100 | break 101 | feature_norms[i].append(round(avg_feat_norm_per_cls[idx], 3)) 102 | avg_feature_norm = [] 103 | for idx in range(len(feature_norms)): 104 | avg_feature_norm.append(round(np.array(feature_norms[idx]).mean(), 3)) 105 | logger.info("%s: Feature norm per class %s" % (tag, str(avg_feature_norm))) 106 | if return_norm: 107 | return avg_feature_norm 108 | else: 109 | return 110 | 111 | 112 | def check_loss(loss): 113 | return not bool(torch.isnan(loss).item()) and bool((loss >= 0.0).item()) 114 | 115 | def class2task(class_form, classnum): 116 | target_form = deepcopy(class_form) 117 | for i in range(classnum): 118 | mask = (target_form==i) 119 | target_form[mask] = -(i//10)-1 120 | target_form = (target_form+1)*(-1) 121 | return target_form 122 | 123 | def maskclass(pred, target, classnum, type='new'): 124 | # type 为new,遮盖new的class,old遮盖旧的,all遮盖新旧 125 | target_form = deepcopy(target) 126 | pred_form = deepcopy(pred) 127 | if type == 'old': 128 | mask = np.logical_or(pred_form<(classnum-10), target_form<(classnum-10)) 129 | pred_form[mask] = 0 130 | target_form[mask] = 0 131 | 132 | if type == 'new': 133 | mask = np.logical_or(pred_form>=(classnum-10), target_form>=(classnum-10)) 134 | pred_form[mask] = 1000 135 | target_form[mask] = 1000 136 | 137 | if type == 'all': 138 | mask = (target_form>=(classnum-10)) 139 | target_form[mask] = 1000 140 | mask = (pred_form>=(classnum-10)) 141 | pred_form[mask] = 1000 142 | 143 | mask = (target_form<(classnum-10)) 144 | target_form[mask] = 0 145 | mask = (pred_form<(classnum-10)) 146 | pred_form[mask] = 0 147 | 148 | all_err = np.sum(pred_form!=target_form) 149 | pred_form1 = deepcopy(pred_form) 150 | mask = (target_form<(classnum-10)) 151 | pred_form[mask] = 0 152 | 153 | new_old_err = np.sum(pred_form!=target_form) 154 | 155 | mask = (target_form>=(classnum-10)) 156 
| pred_form1[mask] = 1000 157 | old_new_err = np.sum(pred_form1!=target_form) 158 | return all_err, new_old_err, old_new_err 159 | 160 | 161 | return pred_form, target_form 162 | 163 | def compute_old_new_mix(ypred, ytrue, increments, n_classes, task_order): 164 | 165 | task_means = [] 166 | for i in range (n_classes//10): 167 | taski_mask = np.logical_and(ytrue>=i*10, ytrue<(i+1)*10) 168 | task_i_mean = ypred[np.arange(ytrue.shape[0]), ytrue][taski_mask].mean().item() 169 | task_means.append(task_i_mean) 170 | task_mean = ypred[np.arange(ytrue.shape[0]), ytrue].mean().item() 171 | 172 | classnum = ypred.shape[1] 173 | ypred = ypred.argmax(1) 174 | all_err = np.sum(ypred!=ytrue) 175 | ypred_task = class2task(ypred, classnum) 176 | ytrue_task = class2task(ytrue, classnum) 177 | err_among_task = np.sum(ypred_task!=ytrue_task) 178 | err_inner_task = all_err - err_among_task 179 | # print("all err : {}\n among task err: {}\n inner task err: {}\n".format(all_err, err_among_task, err_inner_task)) 180 | 181 | 182 | ypred_new, ytrue_new = maskclass(ypred, ytrue, n_classes, 'old') 183 | new_err = np.sum(ypred_new!=ytrue_new) 184 | 185 | ypred_old, ytrue_old = maskclass(ypred, ytrue, n_classes, 'new') 186 | old_err = np.sum(ypred_old!=ytrue_old) 187 | 188 | all_err, new_old_err, old_new_err = maskclass(ypred, ytrue, n_classes, 'all') 189 | print("******all_err:****", all_err) 190 | 191 | all_acc = {"task_mean": task_mean, "task_means":task_means, "new_err": new_err, "old_err":old_err, "new_old_err": new_old_err, "old_new_err": old_new_err, "err_among_task": err_among_task, "err_inner_task": err_inner_task} 192 | 193 | return all_acc 194 | 195 | 196 | def compute_task_accuracy(ypred, ytrue, increments, n_classes, task_order): 197 | task_mean = ypred[np.arange(ytrue.shape[0]), ytrue].mean().item() 198 | classnum = ypred.shape[1] 199 | ypred = ypred.argmax(1) 200 | ypred_task = class2task(ypred, classnum) 201 | ytrue_task = class2task(ytrue, classnum) 202 | 203 | 204 | all_acc 
= {"task_mean": task_mean, "class_info": classification_report(ytrue, ypred), "task_info": classification_report(ytrue_task, ypred_task)} 205 | 206 | return all_acc 207 | 208 | def compute_accuracy(ypred, ytrue, increments, n_classes): 209 | all_acc = {"top1": {}, "top5": {}} 210 | topk = 5 if n_classes >= 5 else n_classes 211 | ncls = np.unique(ytrue).shape[0] 212 | if topk > ncls: 213 | topk = ncls 214 | all_acc_meter = ClassErrorMeter(topk=[1, topk], accuracy=True) 215 | all_acc_meter.add(ypred, ytrue) 216 | all_acc["top1"]["total"] = round(all_acc_meter.value()[0], 3) 217 | all_acc["top5"]["total"] = round(all_acc_meter.value()[1], 3) 218 | # all_acc["total"] = round((ypred == ytrue).sum() / len(ytrue), 3) 219 | 220 | # for class_id in range(0, np.max(ytrue), task_size): 221 | start, end = 0, 0 222 | for i in range(len(increments)): 223 | if increments[i] <= 0: 224 | pass 225 | else: 226 | start = end 227 | end += increments[i] 228 | 229 | idxes = np.where(np.logical_and(ytrue >= start, ytrue < end))[0] 230 | topk_ = 5 if increments[i] >= 5 else increments[i] 231 | ncls = np.unique(ytrue[idxes]).shape[0] 232 | if topk_ > ncls: 233 | topk_ = ncls 234 | cur_acc_meter = ClassErrorMeter(topk=[1, topk_], accuracy=True) 235 | cur_acc_meter.add(ypred[idxes], ytrue[idxes]) 236 | top1_acc = (ypred[idxes].argmax(1) == ytrue[idxes]).sum() / idxes.shape[0] * 100 237 | if start < end: 238 | label = "{}-{}".format(str(start).rjust(2, "0"), str(end - 1).rjust(2, "0")) 239 | else: 240 | label = "{}-{}".format(str(start).rjust(2, "0"), str(end).rjust(2, "0")) 241 | all_acc["top1"][label] = round(top1_acc, 3) 242 | all_acc["top5"][label] = round(cur_acc_meter.value()[1], 3) 243 | # all_acc[label] = round((ypred[idxes] == ytrue[idxes]).sum() / len(idxes), 3) 244 | 245 | return all_acc 246 | 247 | 248 | def make_logger(log_name, savedir='.logs/'): 249 | """Set up the logger for saving log file on the disk 250 | Args: 251 | cfg: configuration dict 252 | 253 | Return: 254 | logger: 
a logger for record essential information 255 | """ 256 | import logging 257 | import os 258 | from logging.config import dictConfig 259 | import time 260 | 261 | logging_config = dict( 262 | version=1, 263 | formatters={'f_t': { 264 | 'format': '\n %(asctime)s | %(levelname)s | %(name)s \t %(message)s' 265 | }}, 266 | handlers={ 267 | 'stream_handler': { 268 | 'class': 'logging.StreamHandler', 269 | 'formatter': 'f_t', 270 | 'level': logging.INFO 271 | }, 272 | 'file_handler': { 273 | 'class': 'logging.FileHandler', 274 | 'formatter': 'f_t', 275 | 'level': logging.INFO, 276 | 'filename': None, 277 | } 278 | }, 279 | root={ 280 | 'handlers': ['stream_handler', 'file_handler'], 281 | 'level': logging.DEBUG, 282 | }, 283 | ) 284 | # set up logger 285 | log_file = '{}.log'.format(log_name) 286 | # if folder not exist,create it 287 | if not os.path.exists(savedir): 288 | os.makedirs(savedir) 289 | log_file_path = os.path.join(savedir, log_file) 290 | 291 | logging_config['handlers']['file_handler']['filename'] = log_file_path 292 | 293 | open(log_file_path, 'w').close() # Clear the content of logfile 294 | # get logger from dictConfig 295 | dictConfig(logging_config) 296 | 297 | logger = logging.getLogger() 298 | 299 | return logger -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path as osp 4 | import copy 5 | import time 6 | import shutil 7 | import cProfile 8 | import logging 9 | from pathlib import Path 10 | import numpy as np 11 | import random 12 | from easydict import EasyDict as edict 13 | from tensorboardX import SummaryWriter 14 | import os 15 | import inclearn.prune as pruning 16 | 17 | os.environ['CUDA_VISIBLE_DEVICES']='0' 18 | 19 | repo_name = 'TCIL' 20 | base_dir = osp.realpath(".")[:osp.realpath(".").index(repo_name) + len(repo_name)] 21 | sys.path.insert(0, base_dir) 22 | 23 | from sacred 
import Experiment 24 | ex = Experiment(base_dir=base_dir, save_git_info=False) 25 | 26 | 27 | import torch 28 | 29 | from inclearn.tools import factory, results_utils, utils 30 | from inclearn.learn.pretrain import pretrain 31 | from inclearn.tools.metrics import IncConfusionMeter 32 | 33 | def initialization(config, seed, mode, exp_id): 34 | 35 | torch.backends.cudnn.benchmark = True # This will result in non-deterministic results. 36 | # ex.captured_out_filter = lambda text: 'Output capturing turned off.' 37 | cfg = edict(config) 38 | utils.set_seed(cfg['seed']) 39 | if exp_id is None: 40 | exp_id = -1 41 | cfg.exp.savedir = "./logs" 42 | logger = utils.make_logger(f"exp{exp_id}_{cfg.exp.name}_{mode}", savedir=cfg.exp.savedir) 43 | 44 | # Tensorboard 45 | exp_name = f'{exp_id}_{cfg["exp"]["name"]}' if exp_id is not None else f'../inbox/{cfg["exp"]["name"]}' 46 | tensorboard_dir = cfg["exp"]["tensorboard_dir"] + f"/{exp_name}" 47 | 48 | # If not only save latest tensorboard log. 49 | # if Path(tensorboard_dir).exists(): 50 | # shutil.move(tensorboard_dir, cfg["exp"]["tensorboard_dir"] + f"/../inbox/{time.time()}_{exp_name}") 51 | 52 | tensorboard = SummaryWriter(tensorboard_dir) 53 | 54 | return cfg, logger, tensorboard 55 | 56 | 57 | @ex.command 58 | def train(_run, _rnd, _seed): 59 | cfg, ex.logger, tensorboard = initialization(_run.config, _seed, "train", _run._id) 60 | ex.logger.info(cfg) 61 | cfg.data_folder = osp.join(base_dir, "data") 62 | 63 | start_time = time.time() 64 | _train(cfg, _run, ex, tensorboard) 65 | ex.logger.info("Training finished in {}s.".format(int(time.time() - start_time))) 66 | 67 | 68 | def _train(cfg, _run, ex, tensorboard): 69 | device = factory.set_device(cfg) 70 | trial_i = cfg['trial'] 71 | 72 | inc_dataset = factory.get_data(cfg, trial_i) 73 | ex.logger.info("classes_order") 74 | ex.logger.info(inc_dataset.class_order) 75 | 76 | model = factory.get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset) 77 | 78 | if 
_run.meta_info["options"]["--file_storage"] is not None: 79 | _save_dir = osp.join(_run.meta_info["options"]["--file_storage"], str(_run._id)) 80 | else: 81 | _save_dir = cfg["exp"]["ckptdir"] 82 | 83 | results = results_utils.get_template_results(cfg) 84 | 85 | for task_i in range(inc_dataset.n_tasks): 86 | task_info, train_loader, val_loader, test_loader = inc_dataset.new_task() 87 | 88 | model.set_task_info( 89 | task=task_info["task"], 90 | total_n_classes=task_info["max_class"], 91 | increment=task_info["increment"], 92 | n_train_data=task_info["n_train_data"], 93 | n_test_data=task_info["n_test_data"], 94 | n_tasks=inc_dataset.n_tasks, 95 | ) 96 | 97 | model.before_task(task_i, inc_dataset) 98 | # TODO: Move to incmodel.py 99 | if 'min_class' in task_info: 100 | ex.logger.info("Train on {}->{}.".format(task_info["min_class"], task_info["max_class"])) 101 | 102 | # Pretraining at step0 if needed 103 | if task_i == 0 and cfg["start_class"] > 0: 104 | do_pretrain(cfg, ex, model, device, train_loader, test_loader) 105 | inc_dataset.shared_data_inc = train_loader.dataset.share_memory 106 | elif task_i < cfg['start_task']: 107 | state_dict = torch.load(f'./{cfg.exp.saveckpt}/step{task_i}.ckpt') 108 | model._parallel_network.load_state_dict(state_dict) 109 | inc_dataset.shared_data_inc = train_loader.dataset.share_memory 110 | else: 111 | model.train_task(train_loader, val_loader) 112 | model.after_task(task_i, inc_dataset) 113 | 114 | ex.logger.info("Eval on {}->{}.".format(0, task_info["max_class"])) 115 | 116 | ypred, ytrue = model.eval_task(test_loader) 117 | 118 | 119 | acc_stats = utils.compute_accuracy(ypred, ytrue, increments=model._increments, n_classes=model._n_classes) 120 | 121 | #Logging 122 | model._tensorboard.add_scalar(f"taskaccu/trial{trial_i}", acc_stats["top1"]["total"], task_i) 123 | 124 | _run.log_scalar(f"trial{trial_i}_taskaccu", acc_stats["top1"]["total"], task_i) 125 | _run.log_scalar(f"trial{trial_i}_task_top5_accu", 
acc_stats["top5"]["total"], task_i) 126 | 127 | ex.logger.info(f"top1:{acc_stats['top1']}") 128 | ex.logger.info(f"top5:{acc_stats['top5']}") 129 | 130 | results["results"].append(acc_stats) 131 | 132 | top1_avg_acc, top5_avg_acc = results_utils.compute_avg_inc_acc(results["results"]) 133 | 134 | _run.info[f"trial{trial_i}"][f"avg_incremental_accu_top1"] = top1_avg_acc 135 | _run.info[f"trial{trial_i}"][f"avg_incremental_accu_top5"] = top5_avg_acc 136 | ex.logger.info("Average Incremental Accuracy Top 1: {} Top 5: {}.".format( 137 | _run.info[f"trial{trial_i}"][f"avg_incremental_accu_top1"], 138 | _run.info[f"trial{trial_i}"][f"avg_incremental_accu_top5"], 139 | )) 140 | if cfg["exp"]["name"]: 141 | results_utils.save_results(results, cfg["exp"]["name"]) 142 | 143 | 144 | def do_pretrain(cfg, ex, model, device, train_loader, test_loader): 145 | if not os.path.exists(osp.join(ex.base_dir, 'pretrain/')): 146 | os.makedirs(osp.join(ex.base_dir, 'pretrain/')) 147 | model_path = osp.join( 148 | ex.base_dir, 149 | "pretrain/{}_{}_cosine_{}_dynamic_{}_nplus1_{}_{}_trial_{}_{}_seed_{}_start_{}_epoch_{}.pth".format( 150 | cfg["model"], 151 | cfg["convnet"], 152 | cfg["weight_normalization"], 153 | cfg["dea"], 154 | cfg["div_type"], 155 | cfg["dataset"], 156 | cfg["trial"], 157 | cfg["train_head"], 158 | cfg['seed'], 159 | cfg["start_class"], 160 | cfg["pretrain"]["epochs"], 161 | ), 162 | ) 163 | if osp.exists(model_path): 164 | print("Load pretrain model") 165 | if hasattr(model._network, "module"): 166 | model._network.module.load_state_dict(torch.load(model_path)) 167 | else: 168 | model._network.load_state_dict(torch.load(model_path)) 169 | else: 170 | pretrain(cfg, ex, model, device, train_loader, test_loader, model_path) 171 | 172 | @ex.command 173 | def test(_run, _rnd, _seed): 174 | cfg, ex.logger, tensorboard = initialization(_run.config, _seed, "test", _run._id) 175 | ex.logger.info(cfg) 176 | 177 | trial_i = cfg['trial'] 178 | cfg.data_folder = osp.join(base_dir, 
"data") 179 | inc_dataset = factory.get_data(cfg, trial_i) 180 | 181 | ex.logger.info("classes_order") 182 | ex.logger.info(inc_dataset.class_order) 183 | 184 | # inc_dataset._current_task = taski 185 | # train_loader = inc_dataset._get_loader(inc_dataset.data_cur, inc_dataset.targets_cur) 186 | model = factory.get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset) 187 | model._network.task_size = cfg.increment 188 | 189 | test_results = results_utils.get_template_results(cfg) 190 | 191 | for taski in range(inc_dataset.n_tasks): 192 | task_info, train_loader, _, test_loader = inc_dataset.new_task() 193 | model.set_task_info( 194 | task=task_info["task"], 195 | total_n_classes=task_info["max_class"], 196 | increment=task_info["increment"], 197 | n_train_data=task_info["n_train_data"], 198 | n_test_data=task_info["n_test_data"], 199 | n_tasks=task_info["max_task"] 200 | ) 201 | model.before_task(taski, inc_dataset) 202 | state_dict = torch.load(f'./{cfg.exp.saveckpt}/step{taski}.ckpt') 203 | if cfg.get("caculate_params", False): 204 | model._parallel_network.load_state_dict(state_dict,False) 205 | else: 206 | model._parallel_network.load_state_dict(state_dict) 207 | 208 | model.eval() 209 | 210 | #Build exemplars 211 | model.after_task(taski, inc_dataset) 212 | 213 | 214 | ypred, ytrue = model.eval_task(test_loader) 215 | 216 | test_acc_stats = utils.compute_accuracy(ypred, ytrue, increments=model._increments, n_classes=model._n_classes) 217 | 218 | test_acc_task_stats = utils.compute_old_new_mix(ypred, ytrue, increments=model._increments, n_classes=model._n_classes, task_order=inc_dataset.class_order) 219 | 220 | test_results['results'].append(test_acc_stats) 221 | ex.logger.info(f"task{taski} test acc:{test_acc_stats['top1']}") 222 | 223 | # ex.logger.info(f"task{taski} task mean:{test_acc_task_stats['task_mean']} \ntest class\n:{test_acc_task_stats['class_info']} \ntest task\n:{test_acc_task_stats['task_info']}") 224 | ex.logger.info(f"task{taski} all task 
mean:{test_acc_task_stats['task_mean']} \n task means: {test_acc_task_stats['task_means']} \n new err:{test_acc_task_stats['new_err']} \nold err:{test_acc_task_stats['old_err']}\nnew_old_err:{test_acc_task_stats['new_old_err']} \nold_new_err:{test_acc_task_stats['old_new_err']} \nerr_among_task:{test_acc_task_stats['err_among_task']} \nerr_inner_task:{test_acc_task_stats['err_inner_task']}") 225 | 226 | avg_test_acc = results_utils.compute_avg_inc_acc(test_results['results']) 227 | ex.logger.info(f"Test Average Incremental Accuracy: {avg_test_acc}") 228 | 229 | 230 | @ex.command 231 | def prune(_run, _rnd, _seed): 232 | from copy import deepcopy 233 | 234 | cfg, ex.logger, tensorboard = initialization(_run.config, _seed, "prune", _run._id) 235 | #ex.logger.info(cfg) 236 | 237 | trial_i = cfg['trial'] 238 | cfg.data_folder = osp.join(base_dir, "data") 239 | inc_dataset = factory.get_data(cfg, trial_i) 240 | 241 | #ex.logger.info("classes_order") 242 | #ex.logger.info(inc_dataset.class_order) 243 | 244 | model = factory.get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset) 245 | tmodel = factory.get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset) 246 | 247 | model._network.task_size = cfg.increment 248 | tmodel._network.task_size = cfg.increment 249 | 250 | test_results = results_utils.get_template_results(cfg) 251 | 252 | 253 | for taski in range(inc_dataset.n_tasks): 254 | 255 | print(f"--------------对step{taski}进行剪枝--------------") 256 | 257 | task_info, train_loader, val_loader, test_loader = inc_dataset.new_task() 258 | 259 | 260 | model.set_task_info( 261 | task=task_info["task"], 262 | total_n_classes=task_info["max_class"], 263 | increment=task_info["increment"], 264 | n_train_data=task_info["n_train_data"], 265 | n_test_data=task_info["n_test_data"], 266 | n_tasks=inc_dataset.n_tasks, 267 | ) 268 | 269 | model.before_task(taski, inc_dataset) 270 | 271 | tmodel.set_task_info( 272 | task=task_info["task"], 273 | 
total_n_classes=task_info["max_class"], 274 | increment=task_info["increment"], 275 | n_train_data=task_info["n_train_data"], 276 | n_test_data=task_info["n_test_data"], 277 | n_tasks=inc_dataset.n_tasks, 278 | ) 279 | 280 | tmodel.before_task(taski, inc_dataset) 281 | 282 | state_dict = torch.load(f'./{cfg.exp.saveckpt}/step{taski}.ckpt') 283 | tmodel._parallel_network.load_state_dict(state_dict) 284 | model._parallel_network.module.convnets[-1].load_state_dict(tmodel._parallel_network.module.convnets[-1].state_dict()) 285 | 286 | net = model._parallel_network.module.convnets[-1] 287 | 288 | flops_raw, params_raw = pruning.get_model_complexity_info( 289 | net, (3, 32, 32), as_strings=True, print_per_layer_stat=False) 290 | print(f'-pruning step{taski} with net{taski}') 291 | print('-[INFO] before pruning flops: ' + flops_raw) 292 | print('-[INFO] before pruning params: ' + params_raw) 293 | # 选择裁剪方式 294 | mod = 'fpgm' 295 | 296 | # 剪枝引擎建立 297 | slim = pruning.Autoslim(net, inputs=torch.randn( 298 | 1, 3, 32, 32), compression_ratio=0.4) 299 | 300 | if mod == 'fpgm': 301 | config = { 302 | 'layer_compression_ratio': None, 303 | 'norm_rate': 1.0, 'prune_shortcut': 1, 304 | 'dist_type': 'l1', 'pruning_func': 'fpgm' 305 | } 306 | elif mod == 'l1': 307 | config = { 308 | 'layer_compression_ratio': None, 309 | 'norm_rate': 1.0, 'prune_shortcut': 1, 310 | 'global_pruning': False, 'pruning_func': 'l1' 311 | } 312 | slim.base_prunging(config) 313 | flops_new, params_new = pruning.get_model_complexity_info( 314 | net, (3, 32, 32), as_strings=True, print_per_layer_stat=False) 315 | print('-[INFO] after pruning flops: ' + flops_new) 316 | print('-[INFO] after pruning params: ' + params_new) 317 | 318 | model.after_prune(taski, inc_dataset) 319 | 320 | model.set_optimizer() 321 | 322 | model.train_task(train_loader, val_loader) 323 | 324 | model.eval() 325 | 326 | model.after_task(taski, inc_dataset) 327 | 328 | model._parallel_network = model._parallel_network.cuda() 329 | 
ypred, ytrue = model.eval_task(test_loader) 330 | 331 | test_acc_stats = utils.compute_accuracy(ypred, ytrue, increments=model._increments, n_classes=model._n_classes) 332 | 333 | test_results['results'].append(test_acc_stats) 334 | ex.logger.info(f"task{taski} test acc:{test_acc_stats['top1']}") 335 | ex.logger.info(f"top1:{test_acc_stats['top1']}") 336 | ex.logger.info(f"top5:{test_acc_stats['top5']}") 337 | 338 | save_path = os.path.join(os.getcwd(), f"{cfg.exp.saveckpt}") 339 | torch.save(model._parallel_network.cpu(), "{}/prune_step{}.ckpt".format(save_path, taski)) # 保存整个神经网络的模型结构以及参数 340 | 341 | 342 | top1_avg_acc, top5_avg_acc = results_utils.compute_avg_inc_acc(test_results["results"]) 343 | 344 | _run.info[f"trial{trial_i}"][f"avg_incremental_accu_top1"] = top1_avg_acc 345 | _run.info[f"trial{trial_i}"][f"avg_incremental_accu_top5"] = top5_avg_acc 346 | ex.logger.info("Average Incremental Accuracy Top 1: {} Top 5: {}.".format( 347 | top1_avg_acc, 348 | top5_avg_acc, 349 | )) 350 | 351 | 352 | if __name__ == "__main__": 353 | ex.add_config("./configs/cifar_b0_10s.yaml") 354 | ex.run_commandline() 355 | --------------------------------------------------------------------------------