├── requirements.txt ├── pruner ├── __init__.py ├── l1_pruner.py ├── feat_analyze.py ├── l1_pruner_iterative.py ├── reg_pruner.py ├── meta_pruner.py └── reg_pruner_iterative.py ├── .gitignore ├── data ├── data_loader_mnist.py ├── data_loader_fmnist.py ├── data_loader_tiny_imagenet.py ├── data_loader_cifar100.py ├── data_loader_cifar10.py ├── __init__.py ├── data_loader.py └── data_loader_celeba.py ├── model ├── __init__.py ├── mobilenetv2.py ├── mlp.py ├── vgg.py ├── resnet_cifar10.py └── wrn.py ├── README.md ├── option.py ├── logger.py ├── main.py └── utils.py /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.1.2 2 | numpy==1.16.0 3 | matplotlib==3.1.1 4 | torchvision==0.4.2 5 | torchsummary==1.5.1 6 | pathos==0.2.7 7 | scipy==1.1.0 8 | torch==1.3.1 9 | Pillow==8.2.0 10 | PyYAML==5.4.1 11 | -------------------------------------------------------------------------------- /pruner/__init__.py: -------------------------------------------------------------------------------- 1 | from . import reg_pruner, l1_pruner 2 | from . import l1_pruner_iterative, reg_pruner_iterative 3 | 4 | pruner_dict = { 5 | 'RST': reg_pruner, 6 | 'L1': l1_pruner, 7 | 'L1_Iter': l1_pruner_iterative, 8 | 'RST_Iter': reg_pruner_iterative 9 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .script_history 3 | Debug_Dir 4 | Experiments 5 | data/cifar10* 6 | data/mnist 7 | data/imagenet* 8 | data/tiny_imagenet 9 | pruned_models 10 | base_models 11 | model/*.th 12 | train_params/ 13 | scripts 14 | ._* 15 | .auto_run* 16 | .DS_Store 17 | test/ 18 | *.sh 19 | -------------------------------------------------------------------------------- /data/data_loader_mnist.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.mnist import MNIST 2 | from torch.utils.data import DataLoader 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def get_data_loader(data_path, batch_size): 7 | transform = transforms.Compose([ 8 | transforms.Resize((32, 32)), 9 | transforms.ToTensor(), 10 | transforms.Normalize( 11 | (0.1307,), (0.3081,)) 12 | ]) 13 | train_set = MNIST(data_path, 14 | train=True, 15 | download=True, 16 | transform=transform) 17 | test_set = MNIST(data_path, 18 | train=False, 19 | download=True, 20 | transform=transform) 21 | 22 | return train_set, test_set 23 | -------------------------------------------------------------------------------- /data/data_loader_fmnist.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import FashionMNIST 2 | from torch.utils.data import DataLoader 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def get_data_loader(data_path, batch_size): 7 | transform = transforms.Compose([ 8 | transforms.Resize((32, 32)), 9 | transforms.ToTensor(), 10 | transforms.Normalize( 11 | (0.1307,), (0.3081,)) 12 | ]) 13 | train_set = FashionMNIST(data_path, 14 | train=True, 15 | download=True, 16 | transform=transform) 17 | test_set = FashionMNIST(data_path, 18 | train=False, 19 | download=True, 20 | transform=transform) 21 | 22 | return train_set, test_set 23 | -------------------------------------------------------------------------------- /pruner/l1_pruner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import copy 4 | import time 5 | import numpy as np 6 | import torch.optim as optim 7 | from .meta_pruner import MetaPruner 8 | from utils import Timer 9 | 10 | class Pruner(MetaPruner): 11 | def __init__(self, model, args, logger, passer): 12 | super(Pruner, self).__init__(model, args, logger, passer) 13 | 14 | def prune(self): 15 | self._get_kept_wg_L1() 16 | self._prune_and_build_new_model() 17 | return self.model 18 | 19 | def _save_model(self, model, optimizer, acc1=0, acc5=0, mark=''): 20 | state = {'iter': self.total_iter, 21 | 'arch': self.args.arch, 22 | 'model': model, 23 | 'state_dict': model.state_dict(), 24 | 'acc1': acc1, 25 | 'acc5': acc5, 26 | 'optimizer': optimizer.state_dict(), 27 | 'ExpID': self.logger.ExpID, 28 | } 29 | self.save(state, is_best=False, mark=mark) -------------------------------------------------------------------------------- /data/data_loader_tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import torch.utils.data.distributed 5 | import torchvision.transforms as transforms 6 | import torchvision.datasets as datasets 7 | from option import args 8 | from PIL import Image 9 | import os 10 | from utils import Dataset_npy_batch 11 | 12 | 13 | # refer to: https://github.com/pytorch/examples/blob/master/imagenet/main.py 14 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 15 | std=[0.229, 0.224, 0.225]) 16 | transform_train = transforms.Compose([ 17 | transforms.RandomCrop(64, padding=8), # refer to the cifar case 18 | transforms.RandomHorizontalFlip(), 19 | transforms.ToTensor(), 20 | normalize, 21 | ]) 22 | transform_test = transforms.Compose([ 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | 27 | def get_data_loader(data_path, batch_size): 28 | train_set = Dataset_npy_batch( 29 | data_path + "/train", 30 | transform=transform_train, 31 | ) 32 | test_set = Dataset_npy_batch( 33 | data_path + "/val", 34 | transform=transform_test, 35 | ) 36 | return train_set, test_set 37 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from .vgg import vgg11, vgg13, vgg16, vgg19 3 | from .resnet_cifar10 import resnet56 4 | from .mlp import mlp_7_linear, mlp_7_relu 5 | 6 | def set_up_model(args, logger): 7 | logger.log_printer("==> making model ...") 8 | module = import_module("model.model_%s" % args.method) 9 | model = module.make_model(args, logger) 10 | return model 11 | 12 | def is_single_branch(model_name): 13 | for k in single_branch_model: 14 | if model_name.startswith(k): 15 | return True 16 | return False 17 | 18 | 19 | model_dict = { 20 | 'mlp_7_linear': mlp_7_linear, 21 | 'mlp_7_relu': mlp_7_relu, 22 | 'resnet56': resnet56, 23 | 'vgg11': vgg11, 24 | 'vgg13': vgg13, 25 | 'vgg16': vgg16, 26 | 'vgg19': vgg19, 27 | } 28 | 29 | num_layers = { 30 | 'mlp_7_linear': 7, 31 | 'mlp_7_relu': 7, 32 | 'alexnet': 8, 33 | 'vgg11': 11, 34 | 'vgg13': 13, 35 | 'vgg16': 16, 36 | 'vgg19': 19, 37 | 'vgg11_bn': 11, 38 | 'vgg13_bn': 13, 39 | 'vgg16_bn': 16, 40 | 'vgg19_bn': 19, 41 | } 42 | 43 | single_branch_model = [ 44 | 'mlp_7', 45 | 'vgg', 46 | ] -------------------------------------------------------------------------------- /model/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn as nn 3 | from torchvision.models import alexnet 4 | try: 5 | from torchvision.models import mobilenet_v2 6 | except: 7 | pass 8 | from model import generator as g 9 | 10 | # modify mobilenet to my interface 11 | class MobilenetV2(nn.Module): 12 | def __init__(self, n_class=1000, width_mult=1.0): 13 | super(MobilenetV2, self).__init__() 14 | self.net = mobilenet_v2(width_mult=width_mult) 15 | self.net.classifier = nn.Sequential( 16 | nn.Dropout(p=0.2), 17 | nn.Linear(in_features=1280, out_features=n_class, bias=True) 18 | ) 19 | 20 | def forward(self, x, out_feat=False): 21 | embed = self.net.features(x).mean([2, 3]) 22 | x = self.net.classifier(embed) 23 | return (x, embed) if out_feat else x 24 | 25 | # modify alexnet to my interface 26 | class AlexNet(nn.Module): 27 | def __init__(self, pretrained=False): 28 | super(AlexNet, self).__init__() 29 | if pretrained: 30 | self.net = alexnet(True) 31 | else: 32 | self.net = alexnet() 33 | def forward(self, x, out_feat=False): 34 | embed = self.net.features(x).view(x.size(0), -1) 35 | x = self.net.classifier(embed) 36 | return (x, embed) if out_feat else x -------------------------------------------------------------------------------- /data/data_loader_cifar100.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR100 2 | import torchvision.transforms as transforms 3 | 4 | 5 | def get_data_loader(data_path, batch_size): 6 | transform_train = transforms.Compose([ 7 | transforms.RandomCrop(32, padding=4), 8 | transforms.RandomHorizontalFlip(), 9 | transforms.ToTensor(), 10 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), # ref to EigenDamage code 11 | # transforms.Normalize((0.4914, 0.4822, 0.4465), 12 | # (0.2023, 0.1994, 0.2010)), 13 | ]) 14 | 15 | transform_test = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 18 | # transforms.Normalize((0.4914, 0.4822, 0.4465), 19 | # (0.2023, 0.1994, 0.2010)), 20 | ]) 21 | 22 | train_set = CIFAR100(data_path, 23 | train=True, 24 | download=True, 25 | transform=transform_train) 26 | test_set = CIFAR100(data_path, 27 | train=False, 28 | download=True, 29 | transform=transform_test) 30 | 31 | return train_set, test_set 32 | -------------------------------------------------------------------------------- /data/data_loader_cifar10.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10 2 | from torch.utils.data import DataLoader 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def get_data_loader(data_path, batch_size): 7 | transform_train = transforms.Compose([ 8 | transforms.RandomCrop(32, padding=4), 9 | transforms.RandomHorizontalFlip(), 10 | transforms.ToTensor(), 11 | # transforms.Normalize((0.4914, 0.4822, 0.4465), 12 | # (0.2023, 0.1994, 0.2010)), # ref to: https://github.com/kuangliu/pytorch-cifar/blob/master/main.py 13 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 14 | std=[0.229, 0.224, 0.225]) # these mean and var are from official PyTorch ImageNet example 15 | ]) 16 | 17 | transform_test = transforms.Compose([ 18 | transforms.ToTensor(), 19 | # transforms.Normalize((0.4914, 0.4822, 0.4465), 20 | # (0.2023, 0.1994, 0.2010)), 21 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 22 | std=[0.229, 0.224, 0.225]) 23 | ]) 24 | 25 | train_set = CIFAR10(data_path, 26 | train=True, 27 | download=True, 28 | 
transform=transform_train) 29 | test_set = CIFAR10(data_path, 30 | train=False, 31 | download=True, 32 | transform=transform_test) 33 | 34 | return train_set, test_set 35 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from importlib import import_module 4 | import os 5 | import numpy as np 6 | import torch 7 | from torch.utils import data 8 | from torch.utils.data import DataLoader 9 | 10 | 11 | class Data(object): 12 | def __init__(self, args): 13 | self.args = args 14 | loader = import_module("data.data_loader_%s" % args.dataset) 15 | path = os.path.join(args.data_path, args.dataset) 16 | train_set, test_set = loader.get_data_loader(path, args.batch_size) 17 | 18 | self.train_loader = DataLoader(train_set, 19 | batch_size=args.batch_size, 20 | num_workers=args.workers, 21 | shuffle=True, 22 | pin_memory=True) 23 | self.train_loader_prune = DataLoader(train_set, 24 | batch_size=args.batch_size_prune, 25 | num_workers=args.workers, 26 | shuffle=True, 27 | pin_memory=True) 28 | self.test_loader = DataLoader(test_set, 29 | batch_size=256, 30 | num_workers=args.workers, 31 | shuffle=False, 32 | pin_memory=True) 33 | 34 | num_classes_dict = { 35 | 'mnist': 10, 36 | 'cifar10': 10, 37 | 'cifar100': 100, 38 | 'imagenet': 1000, 39 | 'imagenet_subset_200': 200, 40 | 'tiny_imagenet': 200, 41 | } 42 | 43 | img_size_dict = { 44 | 'mnist': 32, 45 | 'cifar10': 32, 46 | 'cifar100': 32, 47 | 'imagenet': 224, 48 | 'tiny_imagenet': 64, 49 | 'imagenet_subset_200': 224, 50 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dual Lottery Ticket Hypothesis 2 | 3 | This repository is for our ICLR'22 paper: 4 | > Dual Lottery Ticket Hypothesis [arXiv](https://arxiv.org/abs/2203.04248) \ 5 | > [Yue Bai](https://yueb17.github.io/), [Huan Wang](http://huanwang.tech/), [Zhiqiang Tao](http://ztao.cc/), [Kunpeng Li](https://kunpengli1994.github.io/), and [Yun Fu](http://www1.ece.neu.edu/~yunfu/) 6 | 7 | This paper articulates the Dual Lottery Ticket Hypothesis (DLTH) as a dual formulation of the original Lottery Ticket Hypothesis (LTH). Correspondingly, a simple regularization-based sparse network training strategy, Random Sparse Network Transformation (RST), is proposed to validate DLTH and enhance sparse network training. 8 | 9 | ## Step 1: Set up environment 10 | - python=3.6 11 | - Install libraries by `pip install -r requirements.txt`.
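A note on two compact flag formats used throughout Step 2 below: `--stage_pr` is a bracketed list of per-stage pruning ratios, and `--lr_ft` is an `epoch:lr` map for the finetuning schedule. For resnet56 the five `--stage_pr` entries presumably line up with [first conv, the three residual stages, the final linear layer], keeping the first and last layers dense; the authoritative mapping lives in `pruner/meta_pruner.py`. The sketch below shows how these strings are expected to parse, inferred from the examples in `option.py`'s comments; the two functions are illustrative stand-ins, not the repo's actual `strlist_to_list`/`strdict_to_dict` helpers from `utils.py` (whose implementation is not shown in this listing).

```python
# Illustrative parsers; semantics inferred from option.py's comments.
def parse_stage_pr(s):
    # "[0,0.7,0.7,0.7,0]" -> [0.0, 0.7, 0.7, 0.7, 0.0]
    return [float(x) for x in s.strip('[]').split(',')]

def parse_lr_ft(s):
    # "0:0.1,100:0.01,150:0.001" -> {0: 0.1, 100: 0.01, 150: 0.001}
    return {int(k): float(v) for k, v in
            (pair.split(':') for pair in s.strip('{}').split(','))}

assert parse_stage_pr("[0,0.7,0.7,0.7,0]") == [0.0, 0.7, 0.7, 0.7, 0.0]
assert parse_lr_ft("0:0.1,100:0.01,150:0.001") == {0: 0.1, 100: 0.01, 150: 0.001}
```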
12 | 13 | ## Step 2: Running 14 | ``` 15 | # Pretraining, ResNet56, CIFAR100 16 | CUDA_VISIBLE_DEVICES=0 python main.py --arch resnet56 --dataset cifar100 --method L1 --stage_pr [0,0,0,0,0] --batch_size 128 --wd 0.0005 --lr_ft 0:0.1,100:0.01,150:0.001 --epochs 200 --project pretrain_resnet56_cifar100 --save_init_model 17 | ``` 18 | 19 | ``` 20 | # RST One-shot, ResNet56, CIFAR100, sparsity ratio = 0.7 21 | CUDA_VISIBLE_DEVICES=0 python main.py --arch resnet56 --dataset cifar100 --batch_size 128 --wd 0.0005 --lr_ft 0:0.1,100:0.01,150:0.001 --epochs 200 --wg weight --base_model_path Experiments/$YOUR SAVED PRETRAINED MODEL FOLDER NAME$/weights/checkpoint_just_finished_prune.pth --stage_pr [0,0.7,0.7,0.7,0] --method RST --project RST_rs56_cifar100_pr0.7 22 | ``` 23 | 24 | ``` 25 | # RST Iter-5, ResNet56, CIFAR100, sparsity ratio = 0.7 26 | CUDA_VISIBLE_DEVICES=0 python main.py --method RST_Iter --dataset cifar100 --arch resnet56 --wd 0.0005 --batch_size 128 --base_model_path Experiments/$YOUR SAVED PRETRAINED MODEL FOLDER NAME$/weights/checkpoint_just_finished_prune.pth --stage_pr [0,0.7,0.7,0.7,0] --lr_ft 0:0.1,100:0.01,150:0.001 --epochs 200 --num_cycles 5 --project RST_Iter5_rs56_cifar100_pr0.7 --wg weight --stabilize_reg_interval 10000 --update_reg_interval 1 --pick_pruned iter_rand --RST_Iter_ft 0 27 | ``` 28 | 29 | 30 | 31 | ## Acknowledgments 32 | We refer to the following repositories in our implementation: [Regularization-Pruning](https://github.com/MingSun-Tse/Regularization-Pruning), [pytorch_resnet_cifar10](https://github.com/akamaster/pytorch_resnet_cifar10). We appreciate their great work! 33 | 34 | ## Reference 35 | Please cite our paper in your publication if this work helps your research. If you have any questions, feel free to reach out to Yue Bai (bai.yue@northeastern.edu).
36 | 37 | ``` 38 | @inproceedings{ 39 | bai2022dual, 40 | title={Dual Lottery Ticket Hypothesis}, 41 | author={Yue Bai and Huan Wang and ZHIQIANG TAO and Kunpeng Li and Yun Fu}, 42 | booktitle={International Conference on Learning Representations}, 43 | year={2022}, 44 | url={https://openreview.net/forum?id=fOsN52jn25l} 45 | } 46 | ``` 47 | 48 | 49 | -------------------------------------------------------------------------------- /pruner/feat_analyze.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import OrderedDict 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value""" 8 | def __init__(self, name, fmt=':f'): 9 | self.name = name 10 | self.fmt = fmt 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | def __str__(self): 26 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 27 | return fmtstr.format(**self.__dict__) 28 | 29 | class FeatureAnalyzer(): 30 | def __init__(self, model, data_loader, criterion, print=print): 31 | self.feat_mean = OrderedDict() 32 | self.grad_mean = OrderedDict() 33 | self.data_loader = data_loader 34 | self.criterion = criterion 35 | self.print = print 36 | self.layer_names = {} 37 | for name, module in model.named_modules(): 38 | self.layer_names[module] = name 39 | self.register_hooks(model) 40 | self.analyze_feat(model) 41 | self.rm_hooks(model) 42 | 43 | def register_hooks(self, model): 44 | def forward_hook(m, i, o): 45 | name = self.layer_names[m] 46 | if name not in self.feat_mean: 47 | self.feat_mean[name] = AverageMeter(name) 48 | self.feat_mean[name].update(o.abs().mean().item(), o.size(0)) 49 | 50 | def backward_hook(m, grad_i, grad_o): 51 | name = self.layer_names[m] 52 | if name not in self.grad_mean: 53 | self.grad_mean[name] = AverageMeter(name) 54 | assert len(grad_o) == 1 55 | self.grad_mean[name].update(grad_o[0].abs().mean().item(), grad_o[0].size(0)) 56 | 57 | for _, module in model.named_modules(): 58 | if isinstance(module, (nn.Conv2d, nn.Linear)): 59 | module.register_forward_hook(forward_hook) 60 | module.register_backward_hook(backward_hook) 61 | 62 | def rm_hooks(self, model): 63 | for _, module in model.named_modules(): 64 | if isinstance(module, (nn.Conv2d, nn.Linear)): 65 | module._forward_hooks = OrderedDict() 66 | module._backward_hooks = OrderedDict() 67 | 68 | def analyze_feat(self, model): 69 | # forward to activate hooks 70 | for i, (images, target) in enumerate(self.data_loader): 71 | images, target = images.cuda(), target.cuda() 72 | output = model(images) 73 | loss = self.criterion(output, target) 74 | loss.backward() 75 | 76 | max_key_len = np.max([len(k) for k in self.feat_mean.keys()]) 77 | for k, v in self.feat_mean.items(): 78 | grad = self.grad_mean[k] 79 | self.print(f'{k.rjust(max_key_len)} -- feat_mean {v.avg:.4f} grad_mean {grad.avg:.10f}') -------------------------------------------------------------------------------- /model/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | class FCNet(nn.Module): 6 | def __init__(self, d_image, n_class, n_fc, width=0, n_param=0, branch_layer_out_dim=[], act='relu', dropout=0): 7 | super(FCNet, 
self).__init__() 8 | # activation func 9 | if act == 'relu': 10 | activation = nn.ReLU() 11 | elif act == 'lrelu': 12 | activation = nn.LeakyReLU() 13 | elif act == 'linear': 14 | activation = nn.Identity() 15 | else: 16 | raise NotImplementedError 17 | 18 | n_middle = n_fc - 2 19 | if width == 0: 20 | # Given a total parameter budget, calculate the width: n_middle * width^2 + width * (d_image + n_class) = n_param 21 | assert n_param > 0 22 | Delta = (d_image + n_class) * (d_image + n_class) + 4 * n_middle * n_param 23 | width = (math.sqrt(Delta) - d_image - n_class) / 2 / n_middle 24 | width = int(width) 25 | print("FC net width = %s" % width) 26 | 27 | # build the stem net 28 | net = [nn.Linear(d_image, width), activation] 29 | for i in range(n_middle): 30 | net.append(nn.Linear(width, width)) 31 | if dropout and n_middle - i <= 2: # dropout is applied to the last two middle fc layers 32 | net.append(nn.Dropout(dropout)) 33 | net.append(activation) 34 | net.append(nn.Linear(width, n_class)) 35 | self.net = nn.Sequential(*net) 36 | 37 | # build branch layers 38 | branch = [] 39 | for x in branch_layer_out_dim: 40 | branch.append(nn.Linear(width, x)) 41 | self.branch = nn.Sequential(*branch) # so that the whole model can be put on cuda 42 | self.branch_layer_ix = [] 43 | 44 | def forward(self, img, branch_out=False, mapping=False): 45 | ''' 46 | branch_out: if True, also output the internal features 47 | mapping: if True, the internal features go through a mapping layer 48 | ''' 49 | if not branch_out: 50 | img = img.view(img.size(0), -1) 51 | return self.net(img) 52 | else: 53 | out = [] 54 | start = 0 55 | y = img.view(img.size(0), -1) 56 | keys = [int(x) for x in self.branch_layer_ix] 57 | for i in range(len(keys)): 58 | end = keys[i] + 1 59 | y = self.net[start:end](y) 60 | y_branch = self.branch[i](y) if mapping else y 61 | out.append(y_branch) 62 | start = end 63 | y = self.net[start:](y) 64 | out.append(y) 65 | return out 66 | 67 | 68 | # Refer to: A Signal Propagation Perspective for Pruning Neural Networks at Initialization (ICLR 2020). 69 | # https://github.com/namhoonlee/spp-public/blob/32bde490f19b4c28843303f1dc2935efcd09ebc9/spp/network.py#L108 70 | def mlp_7_linear(**kwargs): 71 | return FCNet(d_image=1024, n_class=10, n_fc=7, width=100, act='linear') 72 | 73 | def mlp_7_relu(**kwargs): 74 | return FCNet(d_image=1024, n_class=10, n_fc=7, width=100, act='relu') -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | Image.MAX_IMAGE_PIXELS = None 4 | import torchvision.transforms as transforms 5 | import os 6 | import numpy as np 7 | import torch 8 | 9 | def is_img(x): 10 | _, ext = os.path.splitext(x) 11 | return ext.lower() in ['.jpg', '.png', '.bmp', '.jpeg'] 12 | 13 | # The CelebA_multi_attr class below is not used in this project.
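Stepping back to `FCNet` in `model/mlp.py` above: when `width == 0`, the constructor derives the hidden width from a total parameter budget by taking the positive root of `n_middle * w^2 + (d_image + n_class) * w = n_param` (bias terms are ignored, as in the original comment). A small self-contained sanity check of that closed form, with arbitrary example values:

```python
import math

def width_from_budget(d_image, n_class, n_middle, n_param):
    # Positive root of: n_middle * w^2 + (d_image + n_class) * w - n_param = 0
    delta = (d_image + n_class) ** 2 + 4 * n_middle * n_param
    return (math.sqrt(delta) - d_image - n_class) / (2 * n_middle)

w = width_from_budget(d_image=1024, n_class=10, n_middle=5, n_param=100_000)
# Before FCNet truncates the root with int(), it meets the budget exactly:
assert math.isclose(5 * w * w + (1024 + 10) * w, 100_000, rel_tol=1e-9)
```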
14 | class CelebA_multi_attr(data.Dataset): 15 | def __init__(self, img_dir, label_file, transform): 16 | self.img_list = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if is_img(i)] 17 | self.transform = transform 18 | self.label = {} 19 | num_attributes = 40 20 | for line in open(label_file): 21 | if ".jpg" not in line: continue 22 | img_name, *attr = line.strip().split() 23 | label = torch.zeros(num_attributes).long() 24 | for i in range(num_attributes): 25 | if attr[i] == "1": 26 | label[i] = 1 27 | self.label[img_name] = label 28 | def __getitem__(self, index): 29 | img_path = self.img_list[index] 30 | img_name = img_path.split("/")[-1] 31 | img = Image.open(img_path).convert("RGB") 32 | img = img.resize((224, 224)) # for alexnet 33 | img = self.transform(img) 34 | return img.squeeze(0), self.label[img_name] 35 | def __len__(self): 36 | return len(self.img_list) 37 | 38 | # only for the most balanced attribute "Attractive" 39 | # Deprecated. This class is not fully worked through. Be careful. 40 | class CelebA(data.Dataset): 41 | def __init__(self, img_dir, label_file, transform): 42 | self.img_list = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if i.endswith(".npy")] 43 | self.transform = transform 44 | if label_file.endswith(".npy"): 45 | self.label = np.load(label_file) # label file is npy 46 | else: 47 | self.label = {} 48 | for line in open(label_file): # label file is txt 49 | if ".jpg" not in line: continue 50 | img_name, *attr = line.strip().split() 51 | self.label[img_name] = int(attr[2] == "1") # "Attractive" is at the third position of all attrs 52 | 53 | def __getitem__(self, index): 54 | img_path = self.img_list[index] 55 | img_name = img_path.split("/")[-1] 56 | img = Image.open(img_path).convert("RGB") 57 | img = img.resize((224, 224)) # for alexnet 58 | img = self.transform(img) 59 | return img.squeeze(0), self.label[img_name] 60 | def __len__(self): 61 | return len(self.img_list) 62 | 63 | class CelebA_npy(data.Dataset): 64 | def __init__(self, npy_dir, label_file, transform): 65 | self.npy_list = [os.path.join(npy_dir, i) for i in os.listdir(npy_dir) if i.endswith(".npy") and i != "batch.npy"] 66 | self.transform = transform 67 | self.label = torch.from_numpy(np.load(label_file)).long() # label_file should be an npy 68 | def __getitem__(self, index): 69 | npy = self.npy_list[index] 70 | img = np.load(npy) 71 | img = Image.fromarray(img) 72 | img = self.transform(img) 73 | return img.squeeze(0), self.label[int(npy.split("/")[-1].split(".")[0])] 74 | def __len__(self): 75 | return len(self.npy_list) 76 | 77 | class Dataset_npy_batch(data.Dataset): 78 | def __init__(self, npy_dir, transform): 79 | self.data = np.load(os.path.join(npy_dir, "batch.npy")) 80 | self.transform = transform 81 | def __getitem__(self, index): 82 | img = Image.fromarray(self.data[index][0]) 83 | img = self.transform(img) 84 | label = self.data[index][1] 85 | label = torch.LongTensor([label])[0] 86 | return img.squeeze(0), label 87 | def __len__(self): 88 | return len(self.data) 89 | -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | # This file is referring to: [EigenDamage, ICML'19] at https://github.com/alecwangcq/EigenDamage-Pytorch. 2 | # We modified a little to make it more neat and standard. 
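A reading note for the `defaultcfg` lists below: each integer is the output width of a 3x3 conv (each followed by ReLU, and BatchNorm when enabled), `'M'` is a 2x2 max-pool, and the classifier is a single linear layer on top of global average pooling. A quick count check, assuming the repo root is on `PYTHONPATH` so this file imports as `model.vgg`:

```python
import torch.nn as nn
from model.vgg import vgg11  # assumes the repo root is on PYTHONPATH

net = vgg11(num_classes=10)
n_conv = sum(isinstance(m, nn.Conv2d) for m in net.modules())
n_fc = sum(isinstance(m, nn.Linear) for m in net.modules())
print(n_conv, n_fc)  # 8 1 -- eight convs from the depth-11 cfg, one linear head
```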
3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | 9 | def _weights_init(m): 10 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 11 | init.kaiming_normal_(m.weight) # in-place variant; the plain kaiming_normal is deprecated 12 | if m.bias is not None: 13 | m.bias.data.fill_(0) 14 | elif isinstance(m, nn.BatchNorm2d): 15 | if m.weight is not None: 16 | m.weight.data.fill_(1.0) 17 | m.bias.data.zero_() 18 | 19 | _AFFINE = True 20 | 21 | defaultcfg = { 22 | 11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 23 | 13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 24 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 25 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 26 | } 27 | 28 | 29 | class VGG(nn.Module): 30 | def __init__(self, depth=19, num_classes=10, num_channels=3, use_bn=True, init_weights=True, cfg=None): 31 | super(VGG, self).__init__() 32 | if cfg is None: 33 | cfg = defaultcfg[depth] 34 | 35 | self.num_channels = num_channels 36 | self.features = self.make_layers(cfg, use_bn) 37 | self.classifier = nn.Linear(cfg[-1], num_classes) 38 | if init_weights: 39 | self.apply(_weights_init) 40 | 41 | def make_layers(self, cfg, batch_norm=False): 42 | layers = [] 43 | in_channels = self.num_channels 44 | for v in cfg: 45 | if v == 'M': 46 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 47 | else: 48 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 49 | if batch_norm: 50 | layers += [conv2d, nn.BatchNorm2d(v, affine=_AFFINE), nn.ReLU(inplace=True)] 51 | else: 52 | layers += [conv2d, nn.ReLU(inplace=True)] 53 | in_channels = v 54 | return nn.Sequential(*layers) 55 | 56 | def forward(self, x): 57 | x = self.features(x) 58 | x = nn.AvgPool2d(x.size(3))(x) 59 | x = x.view(x.size(0), -1) 60 | y = self.classifier(x) 61 | return y 62 | 63 | def _initialize_weights(self): 64 | for m in self.modules(): 65 | if isinstance(m, nn.Conv2d): 66 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 67 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 68 | if m.bias is not None: 69 | m.bias.data.zero_() 70 | elif isinstance(m, nn.BatchNorm2d): 71 | if m.weight is not None: 72 | m.weight.data.fill_(1.0) 73 | m.bias.data.zero_() 74 | elif isinstance(m, nn.Linear): 75 | m.weight.data.normal_(0, 0.01) 76 | m.bias.data.zero_() 77 | def vgg11(num_classes=10, num_channels=3, use_bn=True): 78 | return VGG(11, num_classes=num_classes, num_channels=num_channels, use_bn=use_bn) 79 | 80 | def vgg13(num_classes=10, num_channels=3, use_bn=True): 81 | return VGG(13, num_classes=num_classes, num_channels=num_channels, use_bn=use_bn) 82 | 83 | def vgg16(num_classes=10, num_channels=3, use_bn=True): 84 | return VGG(16, num_classes=num_classes, num_channels=num_channels, use_bn=use_bn) 85 | 86 | def vgg19(num_classes=10, num_channels=3, use_bn=True): 87 | return VGG(19, num_classes=num_classes, num_channels=num_channels, use_bn=use_bn) 88 | 89 | -------------------------------------------------------------------------------- /data/data_loader_celeba.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.utils.data as data 5 | from torch.utils.data import DataLoader 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | Image.MAX_IMAGE_PIXELS = None 9 | pjoin = os.path.join 10 | 11 | 12 | def is_img(x): 13 | _, ext = os.path.splitext(x) 14 | return ext.lower() in ['.jpg', '.png', '.bmp', '.jpeg'] 15 | 16 | 17 | class CelebA(data.Dataset): 18 | ''' 19 | Only for the most balanced attribute "Attractive". 20 | Deprecated. This class is not fully worked through. Be careful. 21 | ''' 22 | 23 | def __init__(self, img_dir, label_file, transform): 24 | self.img_list = [os.path.join(img_dir, i) for i in os.listdir( 25 | img_dir) if i.endswith(".npy")] 26 | self.transform = transform 27 | if label_file.endswith(".npy"): 28 | self.label = np.load(label_file) # label file is npy 29 | else: 30 | self.label = {} 31 | for line in open(label_file): # label file is txt 32 | if ".jpg" not in line: 33 | continue 34 | img_name, *attr = line.strip().split() 35 | # "Attractive" is at the 3rd position of all attrs 36 | self.label[img_name] = int(attr[2] == "1") 37 | 38 | def __getitem__(self, index): 39 | img_path = self.img_list[index] 40 | img_name = img_path.split("/")[-1] 41 | img = Image.open(img_path).convert("RGB") 42 | img = img.resize((224, 224)) # for alexnet 43 | img = self.transform(img) 44 | return img.squeeze(0), self.label[img_name] 45 | 46 | def __len__(self): 47 | return len(self.img_list) 48 | 49 | 50 | class CelebA_npy(data.Dataset): 51 | def __init__(self, npy_dir, label_file, transform): 52 | self.npy_list = [os.path.join(npy_dir, i) for i in os.listdir( 53 | npy_dir) if i.endswith(".npy") and i != "batch.npy"] 54 | self.transform = transform 55 | # label_file should be an npy 56 | self.label = torch.from_numpy(np.load(label_file)).long() 57 | 58 | def __getitem__(self, index): 59 | npy = self.npy_list[index] 60 | img = np.load(npy) 61 | img = Image.fromarray(img) 62 | img = self.transform(img) 63 | return img.squeeze(0), self.label[int(npy.split("/")[-1].split(".")[0])] 64 | 65 | def __len__(self): 66 | return len(self.npy_list) 67 | 68 | 69 | class Dataset_npy_batch(data.Dataset): 70 | def __init__(self, npy_dir, transform): 71 | self.data = np.load(os.path.join(npy_dir, "batch.npy")) 72 | self.transform = transform 73 | 74 | def __getitem__(self, index): 75 | img = Image.fromarray(self.data[index][0]) 76 | img = 
self.transform(img) 77 | label = self.data[index][1] 78 | label = torch.LongTensor([label])[0] 79 | return img.squeeze(0), label 80 | 81 | def __len__(self): 82 | return len(self.data) 83 | 84 | 85 | def get_data_loader(data_path, batch_size): 86 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 87 | std=[0.229, 0.224, 0.225]) 88 | transform_train = transforms.Compose([ 89 | transforms.RandomHorizontalFlip(), 90 | transforms.ToTensor(), 91 | normalize, 92 | ]) 93 | transform_test = transforms.Compose([ 94 | transforms.ToTensor(), 95 | normalize, 96 | ]) 97 | 98 | train_data_path = pjoin(data_path, "train_npy") 99 | train_label_path = pjoin(data_path, "CelebA_Attractive_label.npy") 100 | test_path = pjoin(data_path, "test_npy") 101 | assert(os.path.exists(train_data_path)) 102 | assert(os.path.exists(train_label_path)) 103 | assert(os.path.exists(test_path)) 104 | 105 | train_set = CelebA_npy( 106 | train_data_path, train_label_path, transform=transform_train) 107 | test_set = Dataset_npy_batch(test_path, transform=transform_test) 108 | 109 | 110 | return train_set, test_set -------------------------------------------------------------------------------- /pruner/l1_pruner_iterative.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | import time 5 | import numpy as np 6 | import torch.optim as optim 7 | from .meta_pruner import MetaPruner 8 | from utils import PresetLRScheduler, Timer 9 | from pdb import set_trace as st 10 | 11 | class Pruner(MetaPruner): 12 | def __init__(self, model, args, logger, passer): 13 | super(Pruner, self).__init__(model, args, logger, passer) 14 | self.pr_backup = {} 15 | for k, v in self.pr.items(): 16 | self.pr_backup[k] = v 17 | 18 | 19 | def _update_pr(self, cycle): 20 | '''update layer pruning ratio in iterative pruning 21 | ''' 22 | for layer, pr in self.pr_backup.items(): 23 | pr_each_time_to_current = 1 - (1 - pr) ** (1. 
/ self.args.num_cycles) 24 | pr_each_time = pr_each_time_to_current * ( (1-pr_each_time_to_current) ** (cycle-1) ) 25 | self.pr[layer] = pr_each_time if self.args.wg in ['filter', 'channel'] else pr_each_time + self.pr[layer] 26 | 27 | 28 | 29 | def _apply_mask_forward(self): 30 | assert hasattr(self, 'mask') and len(self.mask.keys()) > 0 31 | for name, m in self.model.named_modules(): 32 | if name in self.mask: 33 | m.weight.data.mul_(self.mask[name]) 34 | 35 | def _finetune(self, cycle): 36 | lr_scheduler = PresetLRScheduler(self.args.lr_ft_mini) 37 | optimizer = optim.SGD(self.model.parameters(), 38 | lr=0, # placeholder, this will be updated later 39 | momentum=self.args.momentum, 40 | weight_decay=self.args.weight_decay) 41 | 42 | best_acc1, best_acc1_epoch = 0, 0 43 | timer = Timer(self.args.epochs_mini) 44 | for epoch in range(self.args.epochs_mini): 45 | lr = lr_scheduler(optimizer, epoch) 46 | self.logprint(f'[Subprune #{cycle} Finetune] Epoch {epoch} Set LR = {lr}') 47 | for ix, (inputs, targets) in enumerate(self.train_loader): 48 | inputs, targets = inputs.cuda(), targets.cuda() 49 | self.model.train() 50 | y_ = self.model(inputs) 51 | loss = self.criterion(y_, targets) 52 | optimizer.zero_grad() 53 | loss.backward() 54 | optimizer.step() 55 | 56 | mask_forward = not (self.args.LTH_Iter and cycle == 1) 57 | if self.args.method and self.args.wg == 'weight' and mask_forward: 58 | self._apply_mask_forward() 59 | 60 | if ix % self.args.print_interval == 0: 61 | self.logprint(f'[Subprune #{cycle} Finetune] Epoch {epoch} Step {ix} loss {loss:.4f}') 62 | # test 63 | acc1, *_ = self.test(self.model) 64 | if acc1 > best_acc1: 65 | best_acc1 = acc1 66 | best_acc1_epoch = epoch 67 | self.accprint(f'[Subprune #{cycle} Finetune] Epoch {epoch} Acc1 {acc1:.4f} (Best_Acc1 {best_acc1:.4f} @ Best_Acc1_Epoch {best_acc1_epoch}) LR {lr}') 68 | self.logprint(f'predicted finish time: {timer()}') 69 | 70 | def prune(self): 71 | # clear existing pr 72 | for layer in self.pr: 73 | self.pr[layer] = 0 74 | 75 | if self.args.LTH_Iter: 76 | self.random_initialized_model_backup = copy.deepcopy(self.model) # load random model for LTH 77 | 78 | for cycle in range(1, self.args.num_cycles + 1): 79 | self.logprint(f'==> Start subtraining #{cycle}') 80 | self._finetune(cycle) # get subtrained model 81 | self._update_pr(cycle) # get pr 82 | self._get_kept_wg_L1() # from pr, get self.pruned_wg 83 | self._prune_and_build_new_model() # from self.pruned_wg, get mask for wg = weight 84 | self.model = copy.deepcopy(self.random_initialized_model_backup) # reset model 85 | self._apply_mask_forward() # mask once 86 | return self.model 87 | 88 | else: 89 | for cycle in range(1, self.args.num_cycles + 1): 90 | self.logprint(f'==> Start subprune #{cycle}') 91 | self._update_pr(cycle) 92 | self._get_kept_wg_L1() 93 | self._prune_and_build_new_model() 94 | if cycle < self.args.num_cycles: 95 | self._finetune(cycle) # there is a big finetuning after the last pruning, so do not finetune here 96 | 97 | return self.model -------------------------------------------------------------------------------- /model/resnet_cifar10.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Refer to: https://github.com/akamaster/pytorch_resnet_cifar10/blob/master/resnet.py 3 | 4 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 5 | 6 | The implementation and structure of this file is hugely influenced by [2] 7 | which is implemented for ImageNet and doesn't have option A for identity. 
8 | Moreover, most of the implementations on the web are copy-pasted from 9 | torchvision's resnet and have the wrong number of params. 10 | 11 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following 12 | numbers of layers and parameters: 13 | 14 | name      | layers | params 15 | ResNet20  |    20  | 0.27M 16 | ResNet32  |    32  | 0.46M 17 | ResNet44  |    44  | 0.66M 18 | ResNet56  |    56  | 0.85M 19 | ResNet110 |   110  |  1.7M 20 | ResNet1202|  1202  | 19.4M 21 | 22 | which this implementation indeed has. 23 | 24 | Reference: 25 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 26 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 27 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 28 | 29 | If you use this implementation in your work, please don't forget to mention the 30 | author, Yerlan Idelbayev. 31 | ''' 32 | import torch 33 | import torch.nn as nn 34 | import torch.nn.functional as F 35 | import torch.nn.init as init 36 | 37 | from torch.autograd import Variable 38 | 39 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 40 | 41 | def _weights_init(m): 42 | classname = m.__class__.__name__ 43 | #print(classname) 44 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 45 | init.kaiming_normal_(m.weight) 46 | 47 | class LambdaLayer(nn.Module): 48 | def __init__(self, lambd): 49 | super(LambdaLayer, self).__init__() 50 | self.lambd = lambd 51 | 52 | def forward(self, x): 53 | return self.lambd(x) 54 | 55 | class LambdaLayer2(nn.Module): 56 | def __init__(self, planes): 57 | super(LambdaLayer2, self).__init__() 58 | self.planes = planes 59 | 60 | def forward(self, x): 61 | y = F.pad(x[:, :, ::2, ::2], 62 | (0, 0, 0, 0, self.planes//4, self.planes//4), "constant", 0) 63 | return y 64 | 65 | class BasicBlock(nn.Module): 66 | expansion = 1 67 | 68 | def __init__(self, in_planes, planes, stride=1, option='A'): 69 | super(BasicBlock, self).__init__() 70 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(planes) 72 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn2 = nn.BatchNorm2d(planes) 74 | 75 | self.downsample = nn.Sequential() 76 | if stride != 1 or in_planes != planes: 77 | if option == 'A': 78 | """ 79 | For CIFAR10, the ResNet paper uses option A. 80 | """ 81 | # self.downsample = LambdaLayer(lambda x: 82 | # F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 83 | # @mingsun-tse: when pickling, the above lambda func will cause an error, so I replace it.
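# A shape walk-through of the option-A shortcut assigned below, for a stride-2
# block going from in_planes=16 to planes=32 on CIFAR-size inputs:
#   x[:, :, ::2, ::2] : (N, 16, 32, 32) -> (N, 16, 16, 16)   # spatial subsample
#   F.pad with planes//4 = 8 zero channels on each side      -> (N, 32, 16, 16)
# The identity path is downsampled and zero-padded, so it adds no parameters.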
84 | self.downsample = LambdaLayer2(planes) 85 | 86 | elif option == 'B': 87 | self.downsample = nn.Sequential( 88 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 89 | nn.BatchNorm2d(self.expansion * planes) 90 | ) 91 | 92 | 93 | def forward(self, x): 94 | out = F.relu(self.bn1(self.conv1(x))) 95 | out = self.bn2(self.conv2(out)) 96 | out += self.downsample(x) 97 | out = F.relu(out) 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | def __init__(self, block, num_blocks, num_classes=10): 103 | super(ResNet, self).__init__() 104 | self.in_planes = 16 105 | 106 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 107 | self.bn1 = nn.BatchNorm2d(16) 108 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 109 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 110 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 111 | self.linear = nn.Linear(64, num_classes) 112 | 113 | self.apply(_weights_init) 114 | 115 | def _make_layer(self, block, planes, num_blocks, stride): 116 | strides = [stride] + [1]*(num_blocks-1) 117 | layers = [] 118 | for stride in strides: 119 | layers.append(block(self.in_planes, planes, stride)) 120 | self.in_planes = planes * block.expansion 121 | 122 | return nn.Sequential(*layers) 123 | 124 | def forward(self, x): 125 | out = F.relu(self.bn1(self.conv1(x))) 126 | out = self.layer1(out) 127 | out = self.layer2(out) 128 | out = self.layer3(out) 129 | out = F.avg_pool2d(out, out.size()[3]) 130 | out = out.view(out.size(0), -1) 131 | out = self.linear(out) 132 | return out 133 | 134 | 135 | def resnet20(num_classes=10, **kwargs): 136 | return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes) 137 | 138 | 139 | def resnet32(num_classes=10, **kwargs): 140 | return ResNet(BasicBlock, [5, 5, 5], num_classes=num_classes) 141 | 142 | 143 | def resnet44(num_classes=10, **kwargs): 144 | return ResNet(BasicBlock, [7, 7, 7], num_classes=num_classes) 145 | 146 | 147 | def resnet56(num_classes=10, **kwargs): 148 | return ResNet(BasicBlock, [9, 9, 9], num_classes=num_classes) 149 | 150 | 151 | def resnet110(num_classes=10, **kwargs): 152 | return ResNet(BasicBlock, [18, 18, 18], num_classes=num_classes) 153 | 154 | 155 | def resnet1202(num_classes=10, **kwargs): 156 | return ResNet(BasicBlock, [200, 200, 200], num_classes=num_classes) 157 | -------------------------------------------------------------------------------- /model/wrn.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from 2019-NIPS-ZSKT (Spotlight): https://github.com/polo5/ZeroShotKnowledgeTransfer/blob/master/models/wresnet.py 3 | """ 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import model.generator as g 9 | import copy 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 14 | super(BasicBlock, self).__init__() 15 | self.bn1 = nn.BatchNorm2d(in_planes) 16 | self.relu1 = nn.ReLU(inplace=True) 17 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(out_planes) 20 | self.relu2 = nn.ReLU(inplace=True) 21 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 22 | padding=1, bias=False) 23 | self.droprate = dropRate 24 | self.equalInOut = (in_planes == out_planes) 25 | self.convShortcut = (not self.equalInOut) and 
nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 26 | padding=0, bias=False) or None 27 | 28 | def forward(self, x): 29 | if not self.equalInOut: 30 | x = self.relu1(self.bn1(x)) 31 | else: 32 | out = self.relu1(self.bn1(x)) 33 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 34 | if self.droprate > 0: 35 | out = F.dropout(out, p=self.droprate, training=self.training) 36 | out = self.conv2(out) 37 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 38 | 39 | class NetworkBlock(nn.Module): 40 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 41 | super(NetworkBlock, self).__init__() 42 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 43 | 44 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 45 | layers = [] 46 | for i in range(int(nb_layers)): 47 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 48 | return nn.Sequential(*layers) 49 | 50 | def forward(self, x): 51 | return self.layer(x) 52 | 53 | class WideResNet(nn.Module): 54 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 55 | super(WideResNet, self).__init__() 56 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 57 | assert((depth - 4) % 6 == 0) 58 | n = (depth - 4) / 6 59 | block = BasicBlock 60 | # 1st conv before any network block 61 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 62 | padding=1, bias=False) 63 | # 1st block 64 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 65 | # 2nd block 66 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 67 | # 3rd block 68 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 69 | # global average pooling and classifier 70 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.fc = nn.Linear(nChannels[3], num_classes) 73 | self.nChannels = nChannels[3] 74 | 75 | for m in self.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 78 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 79 | elif isinstance(m, nn.BatchNorm2d): 80 | m.weight.data.fill_(1) 81 | m.bias.data.zero_() 82 | elif isinstance(m, nn.Linear): 83 | m.bias.data.zero_() 84 | 85 | def forward(self, x, out_feat=False): 86 | out = self.conv1(x) 87 | out = self.block1(out) 88 | out = self.block2(out) 89 | out = self.block3(out) 90 | out = self.relu(self.bn1(out)) 91 | out = F.avg_pool2d(out, 8); embed = out 92 | out = out.view(-1, self.nChannels) 93 | if out_feat: 94 | return self.fc(out), embed 95 | else: 96 | return self.fc(out) 97 | 98 | def ccl(lr_G, lr_S, G_ix, equal_distill=False, embed=False): 99 | T = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.0) 100 | if equal_distill: 101 | S = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.0) 102 | else: 103 | S = WideResNet(depth=16, num_classes=10, widen_factor=2, dropRate=0.0) 104 | G = eval("g.Generator" + G_ix)() 105 | optim_G = torch.optim.Adam(G.parameters(), lr=lr_G) 106 | optim_S = torch.optim.SGD(S.parameters(), lr=lr_S, momentum=0.9, weight_decay=5e-4) 107 | return T, S, G, optim_G, optim_S 108 | 109 | def train_teacher(lr_T, embed=False, student=False): 110 | T = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.0) 111 | if student: 112 | T = WideResNet(depth=16, num_classes=10, widen_factor=2, dropRate=0.0) 113 | optim_T = torch.optim.SGD(T.parameters(), lr=lr_T, momentum=0.9, weight_decay=5e-4) 114 | return T, optim_T 115 | 116 | def kd(lr_S, equal=False, embed=False): 117 | T = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.0) 118 | if equal: 119 | S = copy.deepcopy(T) 120 | else: 121 | S = WideResNet(depth=16, num_classes=10, widen_factor=2, dropRate=0.0) 122 | optim_S = torch.optim.SGD(S.parameters(), lr=lr_S, momentum=0.9, weight_decay=5e-4) 123 | return T, S, optim_S 124 | 125 | if __name__ == '__main__': 126 | import random 127 | import time 128 | from torchsummary import summary 129 | 130 | x = torch.FloatTensor(64, 3, 32, 32).uniform_(0, 1) 131 | 132 | ### WideResNets 133 | # Notation: W-depth-widening_factor 134 | #model = WideResNet(depth=16, num_classes=10, widen_factor=1, dropRate=0.0) 135 | #model = WideResNet(depth=16, num_classes=10, widen_factor=2, dropRate=0.0) 136 | #model = WideResNet(depth=16, num_classes=10, widen_factor=8, dropRate=0.0) 137 | #model = WideResNet(depth=16, num_classes=10, widen_factor=10, dropRate=0.0) 138 | #model = WideResNet(depth=22, num_classes=10, widen_factor=8, dropRate=0.0) 139 | #model = WideResNet(depth=34, num_classes=10, widen_factor=2, dropRate=0.0) 140 | #model = WideResNet(depth=40, num_classes=10, widen_factor=10, dropRate=0.0) 141 | #model = WideResNet(depth=40, num_classes=10, widen_factor=1, dropRate=0.0) 142 | model = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.0) 143 | ###model = WideResNet(depth=50, num_classes=10, widen_factor=2, dropRate=0.0) 144 | 145 | t0 = time.time() 146 | output = model(x) # forward returns a single tensor when out_feat=False 147 | print("Time taken for forward pass: {} s".format(time.time() - t0)) 148 | print("\nOUTPUT SHAPE: ", output.shape) 149 | 150 | summary(model, input_size=(3, 32, 32)) -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | import argparse 3 | import sys 4 | 5 | 6 | model_names = sorted(name for name in models.__dict__ 7 | if name.islower() and not name.startswith("__") 8 | and callable(models.__dict__[name])) 9 | 10 | parser =
argparse.ArgumentParser(description='Dual Lottery Ticket Hypothesis PyTorch') 11 | parser.add_argument('--data', metavar='DIR', 12 | help='path to dataset') 13 | parser.add_argument('--dataset', 14 | help='dataset name', choices=['mnist', 'cifar10', 'cifar100', 'imagenet', 'imagenet_subset_200', 'tiny_imagenet']) 15 | parser.add_argument('--use_lmdb', action='store_true', 16 | help='use lmdb format data instead of images of .JPEG/.PNG etc.') 17 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 18 | # choices=model_names, # @mst: We will use more than the imagenet models, so remove this 19 | help='model architecture: ' + 20 | ' | '.join(model_names) + 21 | ' (default: resnet18)') 22 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 23 | help='number of data loading workers (default: 4)') 24 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 25 | help='number of total epochs to run') 26 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 27 | help='manual epoch number (useful on restarts)') 28 | parser.add_argument('-b', '--batch-size', '--batch_size', default=256, type=int, 29 | metavar='N', 30 | help='mini-batch size (default: 256), this is the total ' 31 | 'batch size of all GPUs on the current node when ' 32 | 'using Data Parallel or Distributed Data Parallel') 33 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 34 | metavar='LR', help='initial learning rate', dest='lr') 35 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 36 | help='momentum') 37 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 38 | metavar='W', help='weight decay (default: 1e-4)', 39 | dest='weight_decay') 40 | parser.add_argument('-p', '--print-freq', default=10, type=int, 41 | metavar='N', help='print frequency (default: 10)') 42 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 43 | help='path to latest checkpoint (default: none)') 44 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 45 | help='evaluate model on validation set') 46 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 47 | help='use pre-trained model') 48 | parser.add_argument('--world-size', default=-1, type=int, 49 | help='number of nodes for distributed training') 50 | parser.add_argument('--rank', default=-1, type=int, 51 | help='node rank for distributed training') 52 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 53 | help='url used to set up distributed training') 54 | parser.add_argument('--dist-backend', default='nccl', type=str, 55 | help='distributed backend') 56 | parser.add_argument('--seed', default=None, type=int, 57 | help='seed for initializing training. ') 58 | parser.add_argument('--gpu', default=None, type=int, 59 | help='GPU id to use.') 60 | parser.add_argument('--multiprocessing-distributed', action='store_true', 61 | help='Use multi-processing distributed training to launch ' 62 | 'N processes per node, which has N GPUs. 
This is the ' 63 | 'fastest way to use PyTorch for either single node or ' 64 | 'multi node data parallel training') 65 | 66 | import os, copy 67 | from utils import strlist_to_list, strdict_to_dict, check_path, parse_prune_ratio_vgg, merge_args 68 | from model import num_layers, is_single_branch 69 | pjoin = os.path.join 70 | 71 | 72 | # routine params 73 | parser.add_argument('--project_name', type=str, default="") 74 | parser.add_argument('--debug', action="store_true") 75 | parser.add_argument('--screen_print', action="store_true") 76 | parser.add_argument('--note', type=str, default='', help='experiment note') 77 | parser.add_argument('--print_interval', type=int, default=100) 78 | parser.add_argument('--test_interval', type=int, default=2000) 79 | parser.add_argument('--plot_interval', type=int, default=100000000) 80 | parser.add_argument('--save_interval', type=int, default=2000, help="the interval to save model") 81 | parser.add_argument('--ExpID', type=str, default='', 82 | help='Experiment id. In default it will be assigned automatically') 83 | 84 | # base model related 85 | parser.add_argument('--resume_path', type=str, default=None, help="supposed to replace the original 'resume' feature") 86 | parser.add_argument('--directly_ft_weights', type=str, default=None, help="the path to a pretrained model") 87 | parser.add_argument('--base_model_path', type=str, default=None, help="the path to the unpruned base model") 88 | parser.add_argument('--test_pretrained', action="store_true", help='test the pretrained model') 89 | parser.add_argument('--start_epoch', type=int, default=0) 90 | parser.add_argument('--save_init_model', action="store_true", help='save the model after initialization') 91 | 92 | # general pruning method related 93 | parser.add_argument('--method', type=str, default="", choices=['', 'L1', 'L1_Iter', 'RST', 'RST_Iter'], 94 | help='pruning method name; default is "", implying the original training without any pruning') 95 | parser.add_argument('--stage_pr', type=str, default="", help='to appoint layer-wise pruning ratio') 96 | parser.add_argument('--index_layer', type=str, default="numbers", choices=['numbers', 'name_matching'], 97 | help='the rule to index layers in a network by its name; used in designating pruning ratio') 98 | parser.add_argument('--previous_layers', type=str, default='') 99 | parser.add_argument('--skip_layers', type=str, default="", help='layer id to skip when pruning') 100 | parser.add_argument('--lr_ft', type=str, default="{0:0.01,30:0.001,60:0.0001,75:0.00001}") 101 | parser.add_argument('--data_path', type=str, default="./data") 102 | parser.add_argument('--wg', type=str, default="filter", choices=['filter', 'channel', 'weight']) 103 | parser.add_argument('--pick_pruned', type=str, default='min', choices=['min', 'max', 'rand', 'iter_rand'], help='the criterion to select weights to prune') 104 | parser.add_argument('--reinit', type=str, default='', help='before finetuning, the pruned model will be reinited') 105 | parser.add_argument('--not_use_bn', dest='use_bn', default=True, action="store_false", help='if use BN in the network') 106 | parser.add_argument('--block_loss_grad', action="store_true", help="block the grad from loss, only apply weight decay") 107 | parser.add_argument('--save_mag_reg_log', action="store_true", help="save log of L1-norm of filters wrt reg") 108 | parser.add_argument('--save_order_log', action="store_true") 109 | parser.add_argument('--mag_ratio_limit', type=float, default=1000) 110 | 
parser.add_argument('--base_pr_model', type=str, default=None, help='the model that provides layer-wise pr') 111 | parser.add_argument('--inherit_pruned', type=str, default='index', choices=['index', 'pr'], 112 | help='when --base_pr_model is provided, we can choose to inherit the pruned index or only the pruning ratio (pr)') 113 | parser.add_argument('--model_noise_std', type=float, default=0, help='add Gaussian noise to model weights') 114 | parser.add_argument('--model_noise_num', type=int, default=10) 115 | parser.add_argument('--last_n_epoch', type=int, default=5, help='in correlation analysis, collect the last_n_epoch loss and average them') 116 | parser.add_argument('--init', type=str, default='default', help="weight initialization scheme") 117 | parser.add_argument('--activation', type=str, default='relu', help="activation function", choices=['relu', 'leaky_relu', 'linear', 'tanh', 'sigmoid']) 118 | parser.add_argument('--lr_AI', type=float, default=0.001, help="lr in approximate_isometry_optimize") 119 | parser.add_argument('--solver', type=str, default='SGD') 120 | parser.add_argument('--verbose', action="store_true", help='if true, print debug logs') 121 | 122 | # GReg method related (default setting is for ImageNet): 123 | parser.add_argument('--batch_size_prune', type=int, default=64) 124 | parser.add_argument('--update_reg_interval', type=int, default=5) 125 | parser.add_argument('--stabilize_reg_interval', type=int, default=40000) 126 | parser.add_argument('--lr_prune', type=float, default=0.001) 127 | parser.add_argument('--reg_upper_limit', type=float, default=1.0) 128 | parser.add_argument('--reg_upper_limit_pick', type=float, default=1e-2) 129 | parser.add_argument('--reg_granularity_pick', type=float, default=1e-5) 130 | parser.add_argument('--reg_granularity_prune', type=float, default=1e-4) 131 | parser.add_argument('--reg_granularity_recover', type=float, default=-1e-4) 132 | parser.add_argument('--RST_schedule', type=str, default='x', choices=['x', 'x^2', 'x^3']) 133 | 134 | # Iterative RST method related 135 | parser.add_argument('--batch_size_prune_mini', type=int, default=64) 136 | parser.add_argument('--update_reg_interval_mini', type=int, default=1) 137 | parser.add_argument('--stabilize_reg_interval_mini', type=int, default=1000) 138 | parser.add_argument('--lr_prune_mini', type=float, default=0.001) 139 | parser.add_argument('--reg_upper_limit_mini', type=float, default=0.0001) 140 | parser.add_argument('--reg_upper_limit_pick_mini', type=float, default=1e-2) 141 | parser.add_argument('--reg_granularity_pick_mini', type=float, default=1e-5) 142 | parser.add_argument('--reg_granularity_prune_mini', type=float, default=1e-4) 143 | parser.add_argument('--reg_granularity_recover_mini', type=float, default=-1e-4) 144 | parser.add_argument('--RST_Iter_ft', type = int, default = 0) 145 | parser.add_argument('--RST_Iter_weight_delete', action='store_true', 146 | help='if delete the Greged weight in each cycle') 147 | 148 | # LTH related 149 | parser.add_argument('--num_cycles', type=int, default=0, 150 | help='num of cycles in iterative pruning') 151 | parser.add_argument('--lr_ft_mini', type=str, default='', 152 | help='finetuning lr in each iterative pruning cycle') 153 | parser.add_argument('--epochs_mini', type=int, default=0, 154 | help='num of epochs in each iterative pruning cycle') 155 | parser.add_argument('--LTH_Iter', action='store_true', 156 | help='if use iterative way to make LTH, 0 for no, 1 for yes') 157 | # 
parser.add_argument('--random_initialized_model', type=str, default=None, 158 | # help='load the random initialized model weights') 159 | 160 | 161 | args = parser.parse_args() 162 | args_tmp = {} 163 | for k, v in args._get_kwargs(): 164 | args_tmp[k] = v 165 | 166 | # Above is the default setting. But if we explicitly assign new value for some arg in the shell script, 167 | # the following will adjust the arg to the assigned value. 168 | script = " ".join(sys.argv) 169 | for k, v in args_tmp.items(): 170 | if k in script: 171 | args.__dict__[k] = v 172 | 173 | # parse for layer-wise prune ratio 174 | # stage_pr is a list of float, skip_layers is a list of strings 175 | if args.stage_pr: 176 | if args.index_layer == 'numbers': # deprecated, kept for now for back-compatability, will be removed 177 | if is_single_branch(args.arch): # e.g., alexnet, vgg 178 | args.stage_pr = parse_prune_ratio_vgg(args.stage_pr, num_layers=num_layers[args.arch]) # example: [0-4:0.5, 5:0.6, 8-10:0.2] 179 | args.skip_layers = strlist_to_list(args.skip_layers, str) # example: [0, 2, 6] 180 | else: # e.g., resnet 181 | args.stage_pr = strlist_to_list(args.stage_pr, float) # example: [0, 0.4, 0.5, 0] 182 | args.skip_layers = strlist_to_list(args.skip_layers, str) # example: [2.3.1, 3.1] 183 | elif args.index_layer == 'name_matching': 184 | args.stage_pr = strdict_to_dict(args.stage_pr, float) 185 | else: 186 | assert args.base_pr_model, 'If stage_pr is not provided, base_pr_model must be provided' 187 | 188 | # set up finetuning lr 189 | assert args.lr_ft, 'lr_ft must be provided' 190 | args.lr_ft = strdict_to_dict(args.lr_ft, float) 191 | 192 | args.resume_path = check_path(args.resume_path) 193 | args.directly_ft_weights = check_path(args.directly_ft_weights) 194 | args.base_model_path = check_path(args.base_model_path) 195 | args.base_pr_model = check_path(args.base_pr_model) 196 | 197 | args.previous_layers = strdict_to_dict(args.previous_layers, str) 198 | 199 | if args.method in ['L1_Iter', 'RST_Iter']: 200 | assert args.num_cycles > 0 201 | args.lr_ft_mini = strdict_to_dict(args.lr_ft_mini, float) 202 | 203 | # some deprecated params to maintain back-compatibility 204 | args.copy_bn_w = True 205 | args.copy_bn_b = True 206 | args.reg_multiplier = 1 207 | 208 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import time, math, os, sys, copy, numpy as np, shutil as sh 2 | import matplotlib; matplotlib.use("Agg") 3 | import matplotlib.pyplot as plt 4 | try: 5 | from utils import get_project_path, mkdirs 6 | except: 7 | from uutils import get_project_path, mkdirs # sometimes, there is a name conflict for 'utils' then we will use 'uutils' 8 | from mpl_toolkits.axes_grid1 import make_axes_locatable 9 | from collections import OrderedDict 10 | import json, yaml 11 | import logging 12 | import traceback 13 | pjoin = os.path.join 14 | 15 | # globals 16 | CONFIDENTIAL_SERVERS = ['202', '008'] 17 | 18 | class LogPrinter(object): 19 | def __init__(self, file, ExpID, print_to_screen=False): 20 | self.file = file 21 | self.ExpID = ExpID 22 | self.print_to_screen = print_to_screen 23 | 24 | def __call__(self, *in_str): 25 | in_str = [str(x) for x in in_str] 26 | in_str = " ".join(in_str) 27 | short_exp_id = self.ExpID[-6:] 28 | pid = os.getpid() 29 | current_time = time.strftime("%Y/%m/%d-%H:%M:%S") 30 | out_str = "[%s %s %s] %s" % (short_exp_id, pid, current_time, in_str) 31 | print(out_str, 
file=self.file, flush=True) # print to txt
32 |         if self.print_to_screen:
33 |             print(out_str) # print to screen
34 | 
35 |     def logprint(self, *in_str): # to keep the interface uniform
36 |         self.__call__(*in_str)
37 | 
38 |     def accprint(self, *in_str):
39 |         blank = ' ' * int(self.ExpID[-1])
40 |         self.__call__(blank, *in_str)
41 | 
42 |     def netprint(self, *in_str): # i.e., print without any prefix
43 |         '''Deprecated. Use netprint in Logger.
44 |         '''
45 |         for x in in_str:
46 |             print(x, file=self.file, flush=True)
47 |             if self.print_to_screen:
48 |                 print(x)
49 | 
50 |     def print(self, *in_str):
51 |         '''print without any prefix
52 |         '''
53 |         for x in in_str:
54 |             print(x, file=self.file, flush=True)
55 |             if self.print_to_screen:
56 |                 print(x)
57 | 
58 |     def print_args(self, args):
59 |         '''
60 |         Example: ('batch_size', 16) ('CodeID', 12defsd2) ('decoder', models/small16x_ae_base/d5_base.pth)
61 |         It will sort the arg keys in alphabetical order, ignoring case.
62 |         '''
63 |         # build a key map for later sorting
64 |         key_map = {}
65 |         for k in args.__dict__:
66 |             k_lower = k.lower()
67 |             if k_lower in key_map:
68 |                 key_map[k_lower + '_' + k_lower] = k
69 |             else:
70 |                 key_map[k_lower] = k
71 | 
72 |         # print in the order of sorted lower keys
73 |         logtmp = ''
74 |         for k_ in sorted(key_map.keys()):
75 |             real_key = key_map[k_]
76 |             logtmp += "('%s': %s) " % (real_key, args.__dict__[real_key])
77 |         self.print(logtmp[:-1] + '\n') # drop the trailing blank
78 | 
79 | class LogTracker(object):
80 |     def __init__(self, momentum=0.9):
81 |         self.loss = OrderedDict()
82 |         self.momentum = momentum
83 |         self.show = OrderedDict()
84 | 
85 |     def __call__(self, name, value, step=-1, show=True):
86 |         '''
87 |         Update the loss value of <name>.
88 |         '''
89 |         assert(type(step) == int)
90 |         # value = np.array(value)
91 | 
92 |         if step == -1:
93 |             if name not in self.loss:
94 |                 self.loss[name] = value
95 |             else:
96 |                 self.loss[name] = self.loss[name] * \
97 |                     self.momentum + value * (1 - self.momentum)
98 |         else:
99 |             if name not in self.loss:
100 |                 self.loss[name] = [[step, value]]
101 |             else:
102 |                 self.loss[name].append([step, value])
103 | 
104 |         # if the loss item will show in the log printing
105 |         self.show[name] = show
106 | 
107 |     def avg(self, name):
108 |         nparray = np.array(self.loss[name])
109 |         return np.mean(nparray[:, 1], axis=0)
110 | 
111 |     def max(self, name):
112 |         nparray = np.array(self.loss[name])
113 |         # TODO: max index
114 |         return np.max(nparray[:, 1], axis=0)
115 | 
116 |     def format(self):
117 |         '''
118 |         loss example:
119 |             [[1, xx], [2, yy], ...] ==> [[step, [xx, yy]], ...]
120 |             xx ==> [xx, yy, ...]
121 |         '''
122 |         keys = self.loss.keys()
123 |         k_str, v_str = [], []
124 |         for k in keys:
125 |             if self.show[k] == False:
126 |                 continue
127 |             v = self.loss[k]
128 |             if not hasattr(v, "__len__"): # xx
129 |                 v = "%.4f" % v
130 |             else:
131 |                 if not hasattr(v[0], "__len__"): # [xx, yy, ...]
132 |                     v = " ".join(["%.3f" % x for x in v])
133 |                 elif hasattr(v[0][1], "__len__"): # [[step, [xx, yy]], ...]
134 |                     v = " ".join(["%.3f" % x for x in v[-1][1]])
135 |                 else: # [[1, xx], [2, yy], ...]
136 |                     v = "%.4f" % v[-1][1]
137 | 
138 |             length = min(max(len(k), len(v)), 15)
139 |             format_str = "{:<%d}" % (length)
140 |             k_str.append(format_str.format(k))
141 |             v_str.append(format_str.format(v))
142 |         k_str = " | ".join(k_str)
143 |         v_str = " | ".join(v_str)
144 |         return k_str + " |", v_str + " |"
145 | 
146 |     def plot(self, name, out_path):
147 |         '''
148 |         Plot the loss of <name>, save it to <out_path>.
149 | ''' 150 | v = self.loss[name] 151 | if (not hasattr(v, "__len__")) or type(v[0][0]) != int: # do not log the 'step' 152 | return 153 | if hasattr(v[0][1], "__len__"): 154 | # self.plot_heatmap(name, out_path) 155 | return 156 | v = np.array(v) 157 | step, value = v[:, 0], v[:, 1] 158 | fig = plt.figure() 159 | ax = fig.add_subplot(111) 160 | ax.plot(step, value) 161 | ax.set_xlabel("step") 162 | ax.set_ylabel(name) 163 | ax.grid() 164 | fig.savefig(out_path, dpi=200) 165 | plt.close(fig) 166 | 167 | def plot_heatmap(self, name, out_path, show_ticks=False): 168 | ''' 169 | A typical case: plot the training process of 10 weights 170 | x-axis: step 171 | y-axis: index (10 weights, 0-9) 172 | value: the weight values 173 | ''' 174 | v = self.loss[name] 175 | step, value = [], [] 176 | [(step.append(x[0]), value.append(x[1])) for x in v] 177 | n_class = len(value[0]) 178 | fig, ax = plt.subplots(figsize=[0.1*len(step), n_class / 5]) # /5 is set manually 179 | im = ax.imshow(np.transpose(value), cmap='jet') 180 | 181 | # make a beautiful colorbar 182 | divider = make_axes_locatable(ax) 183 | cax = divider.append_axes('right', size=0.05, pad=0.05) 184 | fig.colorbar(im, cax=cax, orientation='vertical') 185 | 186 | # set the x and y ticks 187 | # For now, this can not adjust its range adaptively, so deprecated. 188 | # ax.set_xticks(range(len(step))); ax.set_xticklabels(step) 189 | # ax.set_yticks(range(len(value[0]))); ax.set_yticklabels(range(len(value[0]))) 190 | 191 | interval = step[0] if len(step) == 1 else step[1] - step[0] 192 | ax.set_xlabel("step (* interval = %d)" % interval) 193 | ax.set_ylabel("index") 194 | ax.set_title(name) 195 | fig.savefig(out_path, dpi=200) 196 | plt.close(fig) 197 | 198 | class Logger(object): 199 | ''' 200 | The top logger, which 201 | (1) set up all log directories 202 | (2) maintain the losses and accuracies 203 | ''' 204 | 205 | def __init__(self, args): 206 | self.args = args 207 | 208 | # set up work folder 209 | self.ExpID = args.ExpID if hasattr(args, 'ExpID') and args.ExpID else self.get_ExpID() 210 | self.Exps_Dir = 'Experiments' 211 | if hasattr(self.args, 'Exps_Dir'): 212 | self.Exps_Dir = self.args.Exps_Dir 213 | self.set_up_dir() 214 | 215 | self.log_printer = LogPrinter( 216 | self.logtxt, self.ExpID, self.args.debug or self.args.screen_print) # for all txt logging 217 | self.log_tracker = LogTracker() # for all numerical logging 218 | 219 | # initial print: save args 220 | self.print_script() 221 | self.print_nvidia_smi() 222 | self.print_git_status() 223 | self.print_note() 224 | if (not args.debug) and self.SERVER != '': 225 | # If self.SERVER != '', it shows this is Huan's computer, then call this func, which is just a small feature to my need. 226 | # When others use this code, they probably need NOT call this func. 227 | # self.__send_to_exp_hub() # this function is not very useful. deprecated. 228 | pass 229 | args.CodeID = self.get_CodeID() 230 | self.log_printer.print_args(args) 231 | self.save_args(args) 232 | self.cache_model() 233 | self.n_log_item = 0 234 | 235 | def get_CodeID(self): 236 | if hasattr(self.args, 'CodeID') and self.args.CodeID: 237 | return self.args.CodeID 238 | else: 239 | f = 'wh_git_status_%s.tmp' % time.time() 240 | script = 'git status >> %s' % f 241 | os.system(script) 242 | x = open(f).readlines() 243 | x = "".join(x) 244 | os.remove(f) 245 | if "Changes not staged for commit" in x: 246 | self.log_printer("Warning! Your code is not commited. 
Cannot be too careful.") 247 | time.sleep(3) 248 | 249 | f = 'wh_CodeID_file_%s.tmp' % time.time() 250 | script = "git log --pretty=oneline >> %s" % f 251 | os.system(script) 252 | x = open(f).readline() 253 | os.remove(f) 254 | return x[:8] 255 | 256 | def get_ExpID(self): 257 | self.SERVER = os.environ["SERVER"] if 'SERVER' in os.environ.keys() else '' 258 | TimeID = time.strftime("%Y%m%d-%H%M%S") 259 | ExpID = 'SERVER' + self.SERVER + '-' + TimeID 260 | return ExpID 261 | 262 | def set_up_dir(self): 263 | project_path = pjoin("%s/%s_%s" % (self.Exps_Dir, self.args.project_name, self.ExpID)) 264 | if hasattr(self.args, 'resume_ExpID') and self.args.resume_ExpID: 265 | project_path = get_project_path(self.args.resume_ExpID) 266 | if self.args.debug: # debug has the highest priority. If debug, all the things will be saved in Debug_dir 267 | project_path = "Debug_Dir" 268 | 269 | self.exp_path = project_path 270 | self.weights_path = pjoin(project_path, "weights") 271 | self.gen_img_path = pjoin(project_path, "gen_img") 272 | self.cache_path = pjoin(project_path, ".caches") 273 | self.log_path = pjoin(project_path, "log") 274 | self.logplt_path = pjoin(project_path, "log", "plot") 275 | self.logtxt_path = pjoin(project_path, "log", "log.txt") 276 | mkdirs(self.weights_path, self.gen_img_path, self.logplt_path, self.cache_path) 277 | self.logtxt = open(self.logtxt_path, "a+") 278 | self.script_hist = open('.script_history', 'a+') # save local script history, for convenience of check 279 | 280 | def print_script(self): 281 | script = 'cd %s\n' % os.path.abspath(os.getcwd()) 282 | if 'CUDA_VISIBLE_DEVICES' in os.environ: 283 | gpu_id = os.environ['CUDA_VISIBLE_DEVICES'] 284 | script += ' '.join(['CUDA_VISIBLE_DEVICES=%s python' % gpu_id, *sys.argv]) 285 | else: 286 | script += ' '.join(['python', *sys.argv]) 287 | script += '\n' 288 | print(script, file=self.logtxt, flush=True) 289 | print(script, file=sys.stdout, flush=True) 290 | print(script, file=self.script_hist, flush=True) 291 | 292 | def print_exc(self): 293 | traceback.print_exc(file=self.logtxt) 294 | 295 | def print_nvidia_smi(self): 296 | out = pjoin(self.log_path, 'gpu_info.txt') 297 | script = 'nvidia-smi >> %s' % out 298 | os.system(script) 299 | 300 | def print_git_status(self): 301 | out = pjoin(self.log_path, 'git_status.txt') 302 | script = 'git status >> %s' % out 303 | try: 304 | os.system(script) 305 | except: 306 | pass 307 | 308 | def print_note(self): 309 | project = self.get_project_name() # the current project folder name 310 | exp_id = self.ExpID.split('-')[-1] # SERVER138-20200623-095526 311 | if hasattr(self.args, 'note') and self.args.note: 312 | self.ExpNote = 'ExpNote [%s-%s-%s]: "%s" -- %s' % (self.SERVER, project, exp_id, self.args.note, self.args.project_name) 313 | print(self.ExpNote, file=self.logtxt, flush=True) 314 | print(self.ExpNote, file=sys.stdout, flush=True) 315 | 316 | def plot(self, name, out_path): 317 | self.log_tracker.plot(name, out_path) 318 | 319 | def print(self, step): 320 | keys, values = self.log_tracker.format() 321 | k = keys.split("|")[0].strip() 322 | if k: # only when there is sth to print, print 323 | values += " (step = %d)" % step 324 | if step % (self.args.print_interval * 10) == 0 \ 325 | or len(self.log_tracker.loss.keys()) > self.n_log_item: # when a new loss is added into the loss pool, print 326 | self.log_printer(keys) 327 | self.n_log_item = len(self.log_tracker.loss.keys()) 328 | self.log_printer(values) 329 | 330 | def cache_model(self): 331 | ''' 332 | Save the modle 
architecture, loss, configs, in case of future check. 333 | ''' 334 | if self.args.debug: return 335 | 336 | t0 = time.time() 337 | if not os.path.exists(self.cache_path): 338 | os.makedirs(self.cache_path) 339 | self.log_printer(f"==> Caching various config files to '{self.cache_path}'") 340 | 341 | extensions = ['.py', '.json', '.yaml', '.sh', '.txt', '.md'] # files of these types will be cached 342 | def copy_folder(folder_path): 343 | for root, dirs, files in os.walk(folder_path): 344 | if '__pycache__' in root: continue 345 | for f in files: 346 | _, ext = os.path.splitext(f) 347 | if ext in extensions: 348 | dir_path = pjoin(self.cache_path, root) 349 | f_path = pjoin(root, f) 350 | if not os.path.exists(dir_path): 351 | os.makedirs(dir_path) 352 | if os.path.exists(f_path): 353 | sh.copy(f_path, dir_path) 354 | 355 | # copy files in current dir 356 | [sh.copy(f, self.cache_path) for f in os.listdir('.') if os.path.isfile(f) and os.path.splitext(f)[1] in extensions] 357 | 358 | # copy dirs in current dir 359 | ignore = ['__pycache__', 'Experiments', 'Debug_Dir', '.git'] 360 | if hasattr(self.args, 'cache_ignore'): 361 | ignore += self.args.cache_ignore.split(',') 362 | [copy_folder(d) for d in os.listdir('.') if os.path.isdir(d) and d not in ignore] 363 | self.log_printer(f'==> Caching done (time: {time.time() - t0:.2f}s)') 364 | 365 | def get_project_name(self): 366 | '''For example, 'Projects/CRD/logger.py', then return CRD 367 | ''' 368 | file_path = os.path.abspath(__file__) 369 | return file_path.split('/')[-2] 370 | 371 | def __send_to_exp_hub(self): 372 | '''For every experiment, it will send to a hub for the convenience of checking. 373 | ''' 374 | today_exp = time.strftime("%Y%m%d") + "_exps.txt" 375 | if self.SERVER in CONFIDENTIAL_SERVERS: 376 | today_remote = 'huwang@137.203.141.202:/homes/huwang/Projects/ExpLogs/%s' % today_exp 377 | else: 378 | today_remote = 'wanghuan@155.33.198.138:/home/wanghuan/Projects/ExpLogs/%s' % today_exp 379 | local_f = 'wh_exps_%s.tmp' % time.time() 380 | try: 381 | script_pull = 'scp %s %s' % (today_remote, local_f) 382 | os.system(script_pull) 383 | except: 384 | pass 385 | with open(local_f, 'a+') as f: 386 | f.write(self.ExpNote + '\n') 387 | script_push = 'scp %s %s' % (local_f, today_remote) 388 | os.system(script_push) 389 | os.remove(local_f) 390 | 391 | def save_args(self, args): 392 | # with open(pjoin(self.log_path, 'params.json'), 'w') as f: 393 | # json.dump(args.__dict__, f, indent=4) 394 | with open(pjoin(self.log_path, 'params.yaml'), 'w') as f: 395 | yaml.dump(args.__dict__, f, indent=4) 396 | 397 | def netprint(self, net, comment=''): 398 | with open(pjoin(self.log_path, 'model_arch.txt'), 'w') as f: 399 | if comment: 400 | print('%s:' % comment, file=f) 401 | print('%s\n' % str(net), file=f, flush=True) 402 | -------------------------------------------------------------------------------- /pruner/reg_pruner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import os, copy, time, pickle, numpy as np, math 5 | from .meta_pruner import MetaPruner 6 | from utils import plot_weights_heatmap, Timer 7 | import matplotlib.pyplot as plt 8 | pjoin = os.path.join 9 | 10 | class Pruner(MetaPruner): 11 | def __init__(self, model, args, logger, passer): 12 | super(Pruner, self).__init__(model, args, logger, passer) 13 | 14 | # Reg related variables 15 | self.reg = {} 16 | self.delta_reg = {} 17 | self.hist_mag_ratio = {} 18 | 
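        # [Editor's note] How this pruner operates, in brief: self.reg holds one
        # L2-penalty coefficient per weight group; _update_reg() grows the
        # coefficients of the *pruned* groups step by step, and _apply_reg()
        # adds reg * w to the gradient (the gradient of (reg/2) * w^2), so the
        # pruned weights are pushed toward zero increasingly hard while the kept
        # weights train normally. A standalone numeric sketch of this dynamic
        # (illustration only, with made-up numbers; not project code):
        #
        #     w, reg, lr = 1.0, 0.0, 0.1
        #     for step in range(1000):
        #         if step % 5 == 0:       # cf. --update_reg_interval
        #             reg += 1e-4         # cf. --reg_granularity_prune
        #         w -= lr * reg * w       # SGD step on the penalty alone
        #     # -> reg has grown to 0.02 and w has decayed to ~0.37,
        #     #    shrinking ever faster as reg keeps growing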
        self.n_update_reg = {}
19 |         self.iter_update_reg_finished = {}
20 |         self.iter_finish_pick = {}
21 |         self.iter_stabilize_reg = math.inf
22 |         self.original_w_mag = {}
23 |         self.original_kept_w_mag = {}
24 |         self.ranking = {}
25 |         self.pruned_wg_L1 = {}
26 |         self.all_layer_finish_pick = False
27 |         self.w_abs = {}
28 |         self.mag_reg_log = {}
29 | 
30 |         # prune init: determine the pruned weights.
31 |         # This will update 'self.kept_wg' and 'self.pruned_wg'.
32 |         if self.args.method in ['RST']:
33 |             self._get_kept_wg_L1()
34 |             for k, v in self.pruned_wg.items():
35 |                 self.pruned_wg_L1[k] = v
36 | 
37 |         self.prune_state = "update_reg"
38 |         for name, m in self.model.named_modules():
39 |             if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
40 |                 shape = m.weight.data.shape
41 | 
42 |                 # initialize reg
43 |                 if self.args.wg == 'weight':
44 |                     self.reg[name] = torch.zeros_like(m.weight.data).flatten().cuda()
45 |                 else:
46 |                     self.reg[name] = torch.zeros(shape[0], shape[1]).cuda()
47 | 
48 |                 # get original weight magnitude
49 |                 w_abs = self._get_score(m)
50 |                 n_wg = len(w_abs)
51 |                 self.ranking[name] = []
52 |                 for _ in range(n_wg):
53 |                     self.ranking[name].append([])
54 |                 self.original_w_mag[name] = m.weight.abs().mean().item()
55 |                 # kept_wg_L1 = [i for i in range(n_wg) if i not in self.pruned_wg_L1[name]] # low speed
56 |                 kept_wg_L1 = list(set(range(n_wg)) - set(self.pruned_wg_L1[name]))
57 |                 self.original_kept_w_mag[name] = w_abs[kept_wg_L1].mean().item()
58 | 
59 |         self.reg_ = copy.deepcopy(self.reg)
60 | 
61 |     def _pick_pruned_wg(self, w, pr):
62 |         if pr == 0:
63 |             return []
64 |         elif pr > 0:
65 |             w = w.flatten()
66 |             n_pruned = min(math.ceil(pr * w.size(0)), w.size(0) - 1) # do not prune all
67 |             return w.sort()[1][:n_pruned]
68 |         elif pr == -1: # automatically decide pr for each layer itself, via the largest-gap heuristic below
69 |             tmp = w.flatten().sort()[0]
70 |             n_not_consider = int(len(tmp) * 0.02)
71 |             w = tmp[n_not_consider:-n_not_consider]
72 | 
73 |             sorted_w, sorted_index = w.flatten().sort()
74 |             max_gap = 0
75 |             max_index = 0
76 |             for i in range(len(sorted_w) - 1):
77 |                 # gap = sorted_w[i+1:].mean() - sorted_w[:i+1].mean()
78 |                 gap = sorted_w[i+1] - sorted_w[i]
79 |                 if gap > max_gap:
80 |                     max_gap = gap
81 |                     max_index = i
82 |             max_index += n_not_consider
83 |             return sorted_index[:max_index + 1]
84 |         else:
85 |             self.logprint("Wrong pr. 
Please check.") 86 | exit(1) 87 | 88 | def _update_mag_ratio(self, m, name, w_abs, pruned=None): 89 | if type(pruned) == type(None): 90 | pruned = self.pruned_wg[name] 91 | kept = [i for i in range(len(w_abs)) if i not in pruned] 92 | ave_mag_pruned = w_abs[pruned].mean() 93 | ave_mag_kept = w_abs[kept].mean() 94 | if len(pruned): 95 | mag_ratio = ave_mag_kept / ave_mag_pruned 96 | if name in self.hist_mag_ratio: 97 | self.hist_mag_ratio[name] = self.hist_mag_ratio[name]* 0.9 + mag_ratio * 0.1 98 | else: 99 | self.hist_mag_ratio[name] = mag_ratio 100 | else: 101 | mag_ratio = math.inf 102 | self.hist_mag_ratio[name] = math.inf 103 | 104 | # print 105 | mag_ratio_now_before = ave_mag_kept / self.original_kept_w_mag[name] 106 | if self.total_iter % self.args.print_interval == 0: 107 | self.logprint(" mag_ratio %.4f mag_ratio_momentum %.4f" % (mag_ratio, self.hist_mag_ratio[name])) 108 | self.logprint(" for kept weights, original_kept_w_mag %.6f, now_kept_w_mag %.6f ratio_now_over_original %.4f" % 109 | (self.original_kept_w_mag[name], ave_mag_kept, mag_ratio_now_before)) 110 | return mag_ratio_now_before 111 | 112 | def _get_score(self, m): 113 | shape = m.weight.data.shape 114 | if self.args.wg == "channel": 115 | w_abs = m.weight.abs().mean(dim=[0, 2, 3]) if len(shape) == 4 else m.weight.abs().mean(dim=0) 116 | elif self.args.wg == "filter": 117 | w_abs = m.weight.abs().mean(dim=[1, 2, 3]) if len(shape) == 4 else m.weight.abs().mean(dim=1) 118 | elif self.args.wg == "weight": 119 | w_abs = m.weight.abs().flatten() 120 | return w_abs 121 | 122 | def _greg_1(self, m, name): 123 | if self.pr[name] == 0: 124 | return True 125 | 126 | if self.args.wg != 'weight': # weight is too slow 127 | self._update_mag_ratio(m, name, self.w_abs[name]) 128 | 129 | pruned = self.pruned_wg[name] 130 | if self.args.RST_schedule == 'x': 131 | if self.args.wg == "channel": 132 | self.reg[name][:, pruned] += self.args.reg_granularity_prune 133 | elif self.args.wg == "filter": 134 | self.reg[name][pruned, :] += self.args.reg_granularity_prune 135 | elif self.args.wg == 'weight': 136 | self.reg[name][pruned] += self.args.reg_granularity_prune 137 | else: 138 | raise NotImplementedError 139 | 140 | if self.args.RST_schedule == 'x^2': 141 | if self.args.wg == 'weight': 142 | self.reg_[name][pruned] += self.args.reg_granularity_prune 143 | self.reg[name][pruned] = self.reg_[name][pruned]**2 144 | else: 145 | raise NotImplementedError 146 | 147 | if self.args.RST_schedule == 'x^3': 148 | if self.args.wg == 'weight': 149 | self.reg_[name][pruned] += self.args.reg_granularity_prune 150 | self.reg[name][pruned] = self.reg_[name][pruned]**3 151 | else: 152 | raise NotImplementedError 153 | 154 | # when all layers are pushed hard enough, stop 155 | if self.args.wg == 'weight': # for weight, do not use the magnitude ratio condition, because 'hist_mag_ratio' is not updated, too costly 156 | finish_update_reg = False 157 | else: 158 | finish_update_reg = True 159 | for k in self.hist_mag_ratio: 160 | if self.hist_mag_ratio[k] < self.args.mag_ratio_limit: 161 | finish_update_reg = False 162 | return finish_update_reg or self.reg[name].max() > self.args.reg_upper_limit 163 | 164 | def _update_reg(self): 165 | for name, m in self.model.named_modules(): 166 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 167 | cnt_m = self.layers[name].layer_index 168 | pr = self.pr[name] 169 | 170 | if name in self.iter_update_reg_finished.keys(): 171 | continue 172 | 173 | if self.total_iter % self.args.print_interval == 0: 174 | 
self.logprint("[%d] Update reg for layer '%s'. Pr = %s. Iter = %d" 175 | % (cnt_m, name, pr, self.total_iter)) 176 | 177 | # get the importance score (L1-norm in this case) 178 | self.w_abs[name] = self._get_score(m) 179 | 180 | # update reg functions, two things: 181 | # (1) update reg of this layer (2) determine if it is time to stop update reg 182 | if self.args.method == "RST": 183 | finish_update_reg = self._greg_1(m, name) 184 | else: 185 | self.logprint("Wrong '--method' argument, please check.") 186 | exit(1) 187 | 188 | # check prune state 189 | if finish_update_reg: 190 | # after 'update_reg' stage, keep the reg to stabilize weight magnitude 191 | self.iter_update_reg_finished[name] = self.total_iter 192 | self.logprint("==> [%d] Just finished 'update_reg'. Iter = %d" % (cnt_m, self.total_iter)) 193 | 194 | # check if all layers finish 'update_reg' 195 | self.prune_state = "stabilize_reg" 196 | for n, mm in self.model.named_modules(): 197 | if isinstance(mm, nn.Conv2d) or isinstance(mm, nn.Linear): 198 | if n not in self.iter_update_reg_finished: 199 | self.prune_state = "update_reg" 200 | break 201 | if self.prune_state == "stabilize_reg": 202 | self.iter_stabilize_reg = self.total_iter 203 | self.logprint("==> All layers just finished 'update_reg', go to 'stabilize_reg'. Iter = %d" % self.total_iter) 204 | self._save_model(mark='just_finished_update_reg') 205 | 206 | # after reg is updated, print to check 207 | if self.total_iter % self.args.print_interval == 0: 208 | self.logprint(" reg_status: min = %.5f ave = %.5f max = %.5f" % 209 | (self.reg[name].min(), self.reg[name].mean(), self.reg[name].max())) 210 | 211 | def _apply_reg(self): 212 | for name, m in self.model.named_modules(): 213 | if name in self.reg: 214 | reg = self.reg[name] # [N, C] 215 | if self.args.wg in ['filter', 'channel']: 216 | if reg.shape != m.weight.data.shape: 217 | reg = reg.unsqueeze(2).unsqueeze(3) # [N, C, 1, 1] 218 | elif self.args.wg == 'weight': 219 | reg = reg.view_as(m.weight.data) # [N, C, H, W] 220 | l2_grad = reg * m.weight 221 | if self.args.block_loss_grad: 222 | m.weight.grad = l2_grad 223 | else: 224 | m.weight.grad += l2_grad 225 | 226 | def _resume_prune_status(self, ckpt_path): 227 | state = torch.load(ckpt_path) 228 | self.model = state['model'].cuda() 229 | self.model.load_state_dict(state['state_dict']) 230 | self.optimizer = optim.SGD(self.model.parameters(), 231 | lr=self.args.lr_pick if self.args.__dict__.get('AdaReg_only_picking') else self.args.lr_prune, 232 | momentum=self.args.momentum, 233 | weight_decay=self.args.weight_decay) 234 | self.optimizer.load_state_dict(state['optimizer']) 235 | self.prune_state = state['prune_state'] 236 | self.total_iter = state['iter'] 237 | self.iter_stabilize_reg = state.get('iter_stabilize_reg', math.inf) 238 | self.reg = state['reg'] 239 | self.hist_mag_ratio = state['hist_mag_ratio'] 240 | 241 | def _save_model(self, acc1=0, acc5=0, mark=''): 242 | state = {'iter': self.total_iter, 243 | 'prune_state': self.prune_state, # we will resume prune_state 244 | 'arch': self.args.arch, 245 | 'model': self.model, 246 | 'state_dict': self.model.state_dict(), 247 | 'iter_stabilize_reg': self.iter_stabilize_reg, 248 | 'acc1': acc1, 249 | 'acc5': acc5, 250 | 'optimizer': self.optimizer.state_dict(), 251 | 'reg': self.reg, 252 | 'hist_mag_ratio': self.hist_mag_ratio, 253 | 'ExpID': self.logger.ExpID, 254 | } 255 | self.save(state, is_best=False, mark=mark) 256 | 257 | def prune(self): 258 | self.model = self.model.train() 259 | self.optimizer = 
optim.SGD(self.model.parameters(), 260 | lr=self.args.lr_pick if self.args.__dict__.get('AdaReg_only_picking') else self.args.lr_prune, 261 | momentum=self.args.momentum, 262 | weight_decay=self.args.weight_decay) 263 | 264 | # resume model, optimzer, prune_status 265 | self.total_iter = -1 266 | if self.args.resume_path: 267 | self._resume_prune_status(self.args.resume_path) 268 | self._get_kept_wg_L1() # get pruned and kept wg from the resumed model 269 | self.model = self.model.train() 270 | self.logprint("Resume model successfully: '{}'. Iter = {}. prune_state = {}".format( 271 | self.args.resume_path, self.total_iter, self.prune_state)) 272 | 273 | acc1 = acc5 = 0 274 | total_iter_reg = self.args.reg_upper_limit / self.args.reg_granularity_prune * self.args.update_reg_interval + self.args.stabilize_reg_interval 275 | timer = Timer(total_iter_reg / self.args.print_interval) 276 | while True: 277 | for _, (inputs, targets) in enumerate(self.train_loader): 278 | inputs, targets = inputs.cuda(), targets.cuda() 279 | self.total_iter += 1 280 | total_iter = self.total_iter 281 | 282 | # test 283 | if total_iter % self.args.test_interval == 0: 284 | acc1, acc5, *_ = self.test(self.model) 285 | self.accprint("Acc1 = %.4f Acc5 = %.4f Iter = %d (before update) [prune_state = %s, method = %s]" % 286 | (acc1, acc5, total_iter, self.prune_state, self.args.method)) 287 | 288 | # save model (save model before a batch starts) 289 | if total_iter % self.args.save_interval == 0: 290 | self._save_model(acc1, acc5) 291 | self.logprint('Periodically save model done. Iter = {}'.format(total_iter)) 292 | 293 | if total_iter % self.args.print_interval == 0: 294 | self.logprint("") 295 | self.logprint("Iter = %d [prune_state = %s, method = %s] " 296 | % (total_iter, self.prune_state, self.args.method) + "-"*40) 297 | 298 | # forward 299 | self.model.train() 300 | y_ = self.model(inputs) 301 | 302 | if self.prune_state == "update_reg" and total_iter % self.args.update_reg_interval == 0: 303 | self._update_reg() 304 | 305 | # normal training forward 306 | loss = self.criterion(y_, targets) 307 | self.optimizer.zero_grad() 308 | loss.backward() 309 | 310 | # after backward but before update, apply reg to the grad 311 | self._apply_reg() 312 | self.optimizer.step() 313 | 314 | # log print 315 | if total_iter % self.args.print_interval == 0: 316 | # check BN stats 317 | if self.args.verbose: 318 | for name, m in self.model.named_modules(): 319 | if isinstance(m, nn.BatchNorm2d): 320 | # get the associating conv layer of this BN layer 321 | ix = self.all_layers.index(name) 322 | for k in range(ix-1, -1, -1): 323 | if self.all_layers[k] in self.layers: 324 | last_conv = self.all_layers[k] 325 | break 326 | mask_ = [0] * m.weight.data.size(0) 327 | for i in self.kept_wg[last_conv]: 328 | mask_[i] = 1 329 | wstr = ' '.join(['%.3f (%s)' % (x, y) for x, y in zip(m.weight.data, mask_)]) 330 | bstr = ' '.join(['%.3f (%s)' % (x, y) for x, y in zip(m.bias.data, mask_)]) 331 | logstr = f'{last_conv} BN weight: {wstr}\nBN bias: {bstr}' 332 | self.logprint(logstr) 333 | 334 | # check train acc 335 | _, predicted = y_.max(1) 336 | correct = predicted.eq(targets).sum().item() 337 | train_acc = correct / targets.size(0) 338 | self.logprint("After optim update current_train_loss: %.4f current_train_acc: %.4f" % (loss.item(), train_acc)) 339 | 340 | 341 | # change prune state 342 | if self.prune_state == "stabilize_reg" and total_iter - self.iter_stabilize_reg == self.args.stabilize_reg_interval: 343 | # # --- check accuracy to 
make sure '_prune_and_build_new_model' works normally 344 | # # checked. works normally! 345 | # for name, m in self.model.named_modules(): 346 | # if isinstance(m, self.learnable_layers): 347 | # pruned_filter = self.pruned_wg[name] 348 | # m.weight.data[pruned_filter] *= 0 349 | # next_bn = self._next_bn(self.model, m) 350 | # elif isinstance(m, nn.BatchNorm2d) and m == next_bn: 351 | # m.weight.data[pruned_filter] *= 0 352 | # m.bias.data[pruned_filter] *= 0 353 | 354 | # acc1_before, *_ = self.test(self.model) 355 | # self._prune_and_build_new_model() 356 | # acc1_after, *_ = self.test(self.model) 357 | # print(acc1_before, acc1_after) 358 | # exit() 359 | # # --- 360 | model_before_removing_weights = copy.deepcopy(self.model) 361 | self._prune_and_build_new_model() 362 | self.logprint("'stabilize_reg' is done. Pruned, go to 'finetune'. Iter = %d" % total_iter) 363 | return model_before_removing_weights, copy.deepcopy(self.model) 364 | 365 | if total_iter % self.args.print_interval == 0: 366 | self.logprint(f"predicted_finish_time of reg: {timer()}") 367 | 368 | def _plot_mag_ratio(self, w_abs, name): 369 | fig, ax = plt.subplots() 370 | max_ = w_abs.max().item() 371 | w_abs_normalized = (w_abs / max_).data.cpu().numpy() 372 | ax.plot(w_abs_normalized) 373 | ax.set_ylim([0, 1]) 374 | ax.set_xlabel('filter index') 375 | ax.set_ylabel('relative L1-norm ratio') 376 | layer_index = self.layers[name].layer_index 377 | shape = self.layers[name].size 378 | ax.set_title("layer %d iter %d shape %s\n(max = %s)" 379 | % (layer_index, self.total_iter, shape, max_)) 380 | out = pjoin(self.logger.logplt_path, "%d_iter%d_w_abs_dist.jpg" % 381 | (layer_index, self.total_iter)) 382 | fig.savefig(out) 383 | plt.close(fig) 384 | np.save(out.replace('.jpg', '.npy'), w_abs_normalized) 385 | -------------------------------------------------------------------------------- /pruner/meta_pruner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | import time 5 | import numpy as np 6 | from math import ceil, sqrt 7 | from collections import OrderedDict 8 | from utils import strdict_to_dict 9 | from fnmatch import fnmatch, fnmatchcase 10 | from pdb import set_trace as st 11 | 12 | class Layer: 13 | def __init__(self, name, size, layer_index, res=False, layer_type=None): 14 | self.name = name 15 | self.size = [] 16 | for x in size: 17 | self.size.append(x) 18 | self.layer_index = layer_index 19 | self.layer_type = layer_type 20 | self.is_shortcut = True if "downsample" in name else False 21 | if res: 22 | self.stage, self.seq_index, self.block_index = self._get_various_index_by_name(name) 23 | 24 | def _get_various_index_by_name(self, name): 25 | '''Get the indeces including stage, seq_ix, blk_ix. 26 | Same stage means the same feature map size. 27 | ''' 28 | global lastest_stage # an awkward impel, just for now 29 | if name.startswith('module.'): 30 | name = name[7:] # remove the prefix caused by pytorch data parallel 31 | 32 | if "conv1" == name: # TODO: this might not be so safe 33 | lastest_stage = 0 34 | return 0, None, None 35 | if "linear" in name or 'fc' in name: # Note: this can be risky. Check it fully. TODO: @mingsun-tse 36 | return lastest_stage + 1, None, None # fc layer should always be the last layer 37 | else: 38 | try: 39 | stage = int(name.split(".")[0][-1]) # ONLY work for standard resnets. 
name example: layer2.2.conv1, layer4.0.downsample.0 40 | seq_ix = int(name.split(".")[1]) 41 | if 'conv' in name.split(".")[-1]: 42 | blk_ix = int(name[-1]) - 1 43 | else: 44 | blk_ix = -1 # shortcut layer 45 | lastest_stage = stage 46 | return stage, seq_ix, blk_ix 47 | except: 48 | print('!Parsing the layer name failed: %s. Please check.' % name) 49 | 50 | class MetaPruner: 51 | def __init__(self, model, args, logger, passer): 52 | self.model = model 53 | self.args = args 54 | self.logger = logger 55 | self.logprint = logger.log_printer.logprint 56 | self.accprint = logger.log_printer.accprint 57 | self.netprint = logger.log_printer.netprint 58 | self.test = lambda net: passer.test(passer.test_loader, net, passer.criterion, passer.args) 59 | self.train_loader = passer.train_loader 60 | self.criterion = passer.criterion 61 | self.save = passer.save 62 | self.is_single_branch = passer.is_single_branch 63 | 64 | self.learnable_layers = (nn.Conv2d, nn.Linear) # Note: for now, we only focus on weights in Conv and FC modules, no BN. 65 | self.layers = OrderedDict() # learnable layers 66 | self.all_layers = [] # all layers 67 | self._register_layers() # register learnable layers 68 | 69 | arch = self.args.arch 70 | if arch.startswith('resnet'): 71 | # TODO: add block 72 | self.n_conv_within_block = 0 73 | if args.dataset == "imagenet": 74 | if arch in ['resnet18', 'resnet34']: 75 | self.n_conv_within_block = 2 76 | elif arch in ['resnet50', 'resnet101', 'resnet152']: 77 | self.n_conv_within_block = 3 78 | else: 79 | self.n_conv_within_block = 2 80 | 81 | self.kept_wg = {} 82 | self.pruned_wg = {} 83 | self.get_pr() # set up pr for each layer 84 | 85 | def _pick_pruned(self, w_abs, pr, mode="min", name=None): 86 | if pr == 0: 87 | return [] 88 | w_abs_list = w_abs # .flatten() 89 | n_wg = len(w_abs_list) 90 | n_pruned = min(ceil(pr * n_wg), n_wg - 1) # do not prune all 91 | if mode == "rand": 92 | out = np.random.permutation(n_wg)[:n_pruned] 93 | elif mode == "min": 94 | out = w_abs_list.sort()[1][:n_pruned] 95 | out = out.data.cpu().numpy() 96 | elif mode == "max": 97 | out = w_abs_list.sort()[1][-n_pruned:] 98 | out = out.data.cpu().numpy() 99 | elif mode == "iter_rand": # for iterative setting: 1. min for previous part, rand for current part 100 | if self.pruned_wg.__contains__(name): # means this is NOT first cycle 101 | previous_out_size = len(self.pruned_wg[name]) 102 | previous_out = w_abs_list.sort()[1][:previous_out_size] # min mode picks all zero 103 | previous_out = previous_out.data.cpu().numpy() 104 | 105 | current_existing_index = w_abs_list.sort()[1][previous_out_size:] # sorted 106 | current_existing_index = current_existing_index.cpu().numpy() 107 | np.random.shuffle(current_existing_index) # random (un-sorted) 108 | 109 | current_out_size = n_pruned - previous_out_size 110 | current_out = current_existing_index[:current_out_size] 111 | 112 | # current_existing_size = n_wg - previous_out_size 113 | # current_out = np.random.permutation(current_existing_size)[:current_out_size] + previous_out_size 114 | 115 | out = np.concatenate((previous_out, current_out)) 116 | else: # means this is first cycle: return rand mode 117 | out = np.random.permutation(n_wg)[:n_pruned] 118 | 119 | return out 120 | 121 | def _register_layers(self): 122 | ''' 123 | This will maintain a data structure that can return some useful 124 | information by the name of a layer. 
125 | ''' 126 | ix = -1 # layer index, starts from 0 127 | self._max_len_name = 0 128 | layer_shape = {} 129 | for name, m in self.model.named_modules(): 130 | self.all_layers += [name] 131 | if isinstance(m, self.learnable_layers): 132 | if "downsample" not in name: # hardcoding, an ugly design, will be improved 133 | ix += 1 134 | layer_shape[name] = [ix, m.weight.size()] 135 | self._max_len_name = max(self._max_len_name, len(name)) 136 | 137 | size = m.weight.size() 138 | res = True if self.args.arch.startswith('resnet') else False 139 | self.layers[name] = Layer(name, size, ix, res, layer_type=m.__class__.__name__) 140 | 141 | self._max_len_ix = len("%s" % ix) 142 | print("Register layer index and kernel shape:") 143 | format_str = "[%{}d] %{}s -- kernel_shape: %s".format(self._max_len_ix, self._max_len_name) 144 | for name, (ix, ks) in layer_shape.items(): 145 | print(format_str % (ix, name, ks)) 146 | 147 | def _next_learnable_layer(self, model, name, mm): 148 | '''get the next conv or fc layer name 149 | ''' 150 | if hasattr(self.layers[name], 'block_index'): 151 | block_index = self.layers[name].block_index 152 | if block_index == self.n_conv_within_block - 1: 153 | return None 154 | 155 | ix = self.layers[name].layer_index # layer index of current layer 156 | type_ = mm.__class__.__name__ # layer type of current layer 157 | for name, layer in self.layers.items(): 158 | if layer.layer_type == type_ and layer.layer_index == ix + 1: # for now, requires the same layer_type for wg == 'channel'. TODO: generalize this 159 | return name 160 | return None 161 | 162 | def _prev_learnable_layer(self, model, name, mm): 163 | '''get the previous conv or fc layer name 164 | ''' 165 | # explicitly provide the previous layer name, then use it as the highest priority! 166 | # useful for complex residual networks 167 | for p in self.args.previous_layers: 168 | if fnmatch(name, p): 169 | prev_layer = self.args.previous_layers[p] 170 | if prev_layer.lower() == 'none': 171 | return None 172 | else: 173 | return prev_layer 174 | 175 | # standard resnets. hardcoding, deprecated, will be improved 176 | if hasattr(self.layers[name], 'block_index'): 177 | block_index = self.layers[name].block_index 178 | if block_index in [None, 0, -1]: # 1st conv, 1st conv in a block, 1x1 shortcut layer 179 | return None 180 | 181 | # get the previous layer by order 182 | ix = self.layers[name].layer_index # layer index of current layer 183 | for name, layer in self.layers.items(): 184 | if layer.layer_index == ix - 1: 185 | return name 186 | return None 187 | 188 | def _next_bn(self, model, mm): 189 | just_passed_mm = False 190 | for m in model.modules(): 191 | if m == mm: 192 | just_passed_mm = True 193 | if just_passed_mm and isinstance(m, nn.BatchNorm2d): 194 | return m 195 | return None 196 | 197 | def _replace_module(self, model, name, new_m): 198 | ''' 199 | Replace the module in with 200 | E.g., 'module.layer1.0.conv1' 201 | ==> model.__getattr__('module').__getattr__("layer1").__getitem__(0).__setattr__('conv1', new_m) 202 | ''' 203 | obj = model 204 | segs = name.split(".") 205 | for ix in range(len(segs)): 206 | s = segs[ix] 207 | if ix == len(segs) - 1: # the last one 208 | if s.isdigit(): 209 | obj.__setitem__(int(s), new_m) 210 | else: 211 | obj.__setattr__(s, new_m) 212 | return 213 | if s.isdigit(): 214 | obj = obj.__getitem__(int(s)) 215 | else: 216 | obj = obj.__getattr__(s) 217 | 218 | def _get_n_filter(self, model): 219 | ''' 220 | Do not consider the downsample 1x1 shortcuts. 
221 | ''' 222 | n_filter = OrderedDict() 223 | for name, m in model.named_modules(): 224 | if name in self.layers: 225 | if not self.layers[name].is_shortcut: 226 | ix = self.layers[name].layer_index 227 | n_filter[ix] = m.weight.size(0) 228 | return n_filter 229 | 230 | def _get_layer_pr_vgg(self, name): 231 | '''Example: '[0-4:0.5, 5:0.6, 8-10:0.2]' 232 | 6, 7 not mentioned, default value is 0 233 | ''' 234 | layer_index = self.layers[name].layer_index 235 | pr = self.args.stage_pr[layer_index] 236 | if str(layer_index) in self.args.skip_layers: 237 | pr = 0 238 | return pr 239 | 240 | def _get_layer_pr_resnet(self, name): 241 | ''' 242 | This function will determine the prune_ratio (pr) for each specific layer 243 | by a set of rules. 244 | ''' 245 | wg = self.args.wg 246 | layer_index = self.layers[name].layer_index 247 | stage = self.layers[name].stage 248 | seq_index = self.layers[name].seq_index 249 | block_index = self.layers[name].block_index 250 | is_shortcut = self.layers[name].is_shortcut 251 | pr = self.args.stage_pr[stage] 252 | 253 | # for unstructured pruning, no restrictions, every layer can be pruned 254 | if self.args.wg != 'weight': 255 | # do not prune the shortcut layers for now 256 | if is_shortcut: 257 | pr = 0 258 | 259 | # do not prune layers we set to be skipped 260 | layer_id = '%s.%s.%s' % (str(stage), str(seq_index), str(block_index)) 261 | for s in self.args.skip_layers: 262 | if s and layer_id.startswith(s): 263 | pr = 0 264 | 265 | # for channel/filter prune, do not prune the 1st/last conv in a block 266 | if (wg == "channel" and block_index == 0) or \ 267 | (wg == "filter" and block_index == self.n_conv_within_block - 1): 268 | pr = 0 269 | 270 | return pr 271 | 272 | def _get_pr_by_name_matching(self, name): 273 | pr = 0 # default pr = 0 274 | for p in self.args.stage_pr: 275 | if fnmatch(name, p): 276 | pr = self.args.stage_pr[p] 277 | return pr 278 | 279 | def get_pr(self): 280 | '''Get layer-wise pruning ratio for each layer. 281 | ''' 282 | self.pr = {} 283 | if self.args.stage_pr: # stage_pr may be None (in the case that base_pr_model is provided) 284 | assert self.args.base_pr_model is None 285 | if self.args.index_layer == 'numbers': # old way to assign pruning ratios, deprecated, will be removed 286 | get_layer_pr = self._get_layer_pr_vgg if self.is_single_branch(self.args.arch) else self._get_layer_pr_resnet 287 | for name, m in self.model.named_modules(): 288 | if isinstance(m, self.learnable_layers): 289 | self.pr[name] = get_layer_pr(name) 290 | elif self.args.index_layer == 'name_matching': 291 | for name, m in self.model.named_modules(): 292 | if isinstance(m, self.learnable_layers): 293 | self.pr[name] = self._get_pr_by_name_matching(name) 294 | else: 295 | assert self.args.base_pr_model 296 | state = torch.load(self.args.base_pr_model) 297 | self.pruned_wg_pr_model = state['pruned_wg'] 298 | self.kept_wg_pr_model = state['kept_wg'] 299 | for k in self.pruned_wg_pr_model: 300 | n_pruned = len(self.pruned_wg_pr_model[k]) 301 | n_kept = len(self.kept_wg_pr_model[k]) 302 | self.pr[k] = float(n_pruned) / (n_pruned + n_kept) 303 | self.logprint("==> Load base_pr_model successfully and inherit its pruning ratio: '{}'".format(self.args.base_pr_model)) 304 | 305 | def _get_kept_wg_L1(self): 306 | '''Decide kept (or pruned) weight group by L1-norm sorting. 
307 | ''' 308 | if self.args.base_pr_model and self.args.inherit_pruned == 'index': 309 | self.pruned_wg = self.pruned_wg_pr_model 310 | self.kept_wg = self.kept_wg_pr_model 311 | self.logprint("==> Inherit the pruned index from base_pr_model: '{}'".format(self.args.base_pr_model)) 312 | else: 313 | wg = self.args.wg 314 | for name, m in self.model.named_modules(): 315 | if isinstance(m, self.learnable_layers): 316 | shape = m.weight.data.shape 317 | if wg == "filter": 318 | score = m.weight.abs().mean(dim=[1, 2, 3]) if len(shape) == 4 else m.weight.abs().mean(dim=1) 319 | elif wg == "channel": 320 | score = m.weight.abs().mean(dim=[0, 2, 3]) if len(shape) == 4 else m.weight.abs().mean(dim=0) 321 | elif wg == "weight": 322 | score = m.weight.abs().flatten() 323 | else: 324 | raise NotImplementedError 325 | 326 | self.pruned_wg[name] = self._pick_pruned(score, self.pr[name], self.args.pick_pruned, name) 327 | self.kept_wg[name] = list(set(range(len(score))) - set(self.pruned_wg[name])) 328 | format_str = f"[%{self._max_len_ix}d] %{self._max_len_name}s -- shape {shape} -- got pruned wg by L1 sorting ({self.args.pick_pruned}), pr {self.pr[name]}" 329 | logtmp = format_str % (self.layers[name].layer_index, name) 330 | 331 | # compare the pruned weights picked by L1-sorting vs. other criterion which provides the base_pr_model (e.g., OBD) 332 | if self.args.base_pr_model: 333 | intersection = [x for x in self.pruned_wg_pr_model[name] if x in self.pruned_wg[name]] 334 | intersection_ratio = len(intersection) / len(self.pruned_wg[name]) if len(self.pruned_wg[name]) else 0 335 | logtmp += ', intersection ratio of the weights picked by L1 vs. base_pr_model: %.4f (%d)' % (intersection_ratio, len(intersection)) 336 | self.netprint(logtmp) 337 | 338 | def _get_kept_filter_channel(self, m, name): 339 | '''For filter/channel pruning, prune one layer will affect the following/previous layer. This func is to figure out which filters 340 | and channels will be kept in a layer speficially. 341 | ''' 342 | if self.args.wg == "channel": 343 | kept_chl = self.kept_wg[name] 344 | next_learnable_layer = self._next_learnable_layer(self.model, name, m) 345 | if not next_learnable_layer: 346 | kept_filter = list(range(m.weight.size(0))) 347 | else: 348 | kept_filter = self.kept_wg[next_learnable_layer] 349 | 350 | elif self.args.wg == "filter": 351 | kept_filter = self.kept_wg[name] 352 | prev_learnable_layer = self._prev_learnable_layer(self.model, name, m) 353 | if isinstance(m, nn.Conv2d) and m.groups == m.weight.shape[0] and m.weight.shape[1] == 1: # depth-wise conv 354 | kept_chl = [0] # depth-wise conv, channel number is always 1 355 | if prev_learnable_layer: 356 | kept_filter = [x for x in kept_filter if x in self.kept_wg[prev_learnable_layer]] 357 | self.kept_wg[name] = kept_filter 358 | else: 359 | if not prev_learnable_layer: 360 | kept_chl = list(range(m.weight.size(1))) 361 | else: 362 | if self.layers[name].layer_type == self.layers[prev_learnable_layer].layer_type: 363 | kept_chl = self.kept_wg[prev_learnable_layer] 364 | 365 | else: # current layer is the 1st fc, the previous layer is the last conv 366 | last_conv_n_filter = self.layers[prev_learnable_layer].size[0] 367 | last_conv_fm_size = int(m.weight.size(1) / last_conv_n_filter) # feature map spatial size. 
36 for alexnet 368 | self.logprint('last_conv_feature_map_size: %dx%d (before fed into the first fc)' % (sqrt(last_conv_fm_size), sqrt(last_conv_fm_size))) 369 | last_conv_kept_filter = self.kept_wg[prev_learnable_layer] 370 | kept_chl = [] 371 | for i in last_conv_kept_filter: 372 | tmp = list(range(i * last_conv_fm_size, i * last_conv_fm_size + last_conv_fm_size)) 373 | kept_chl += tmp 374 | 375 | return kept_filter, kept_chl 376 | 377 | def _prune_and_build_new_model(self): 378 | if self.args.wg == 'weight': 379 | self._get_masks() 380 | return 381 | 382 | new_model = copy.deepcopy(self.model) 383 | for name, m in self.model.named_modules(): 384 | if isinstance(m, self.learnable_layers): 385 | kept_filter, kept_chl = self._get_kept_filter_channel(m, name) 386 | # print(f'{name} kept_filter: {kept_filter} kept_chl: {kept_chl}') 387 | 388 | # copy weight and bias 389 | bias = False if isinstance(m.bias, type(None)) else True 390 | if isinstance(m, nn.Conv2d): 391 | kept_weights = m.weight.data[kept_filter][:, kept_chl, :, :] 392 | if m.weight.shape[0] == m.groups and m.weight.shape[1] == 1: # depth-wise conv 393 | groups = len(kept_filter) 394 | else: 395 | groups = m.groups 396 | new_layer = nn.Conv2d(len(kept_chl) * groups, len(kept_filter), m.kernel_size, 397 | m.stride, m.padding, m.dilation, groups, bias).cuda() 398 | elif isinstance(m, nn.Linear): 399 | kept_weights = m.weight.data[kept_filter][:, kept_chl] 400 | new_layer = nn.Linear(in_features=len(kept_chl), out_features=len(kept_filter), bias=bias).cuda() 401 | new_layer.weight.data.copy_(kept_weights) # load weights into the new module 402 | if bias: 403 | kept_bias = m.bias.data[kept_filter] 404 | new_layer.bias.data.copy_(kept_bias) 405 | 406 | # load the new conv 407 | self._replace_module(new_model, name, new_layer) 408 | 409 | # get the corresponding bn (if any) for later use 410 | next_bn = self._next_bn(self.model, m) 411 | 412 | elif isinstance(m, nn.BatchNorm2d) and m == next_bn: 413 | new_bn = nn.BatchNorm2d(len(kept_filter), eps=m.eps, momentum=m.momentum, 414 | affine=m.affine, track_running_stats=m.track_running_stats).cuda() 415 | 416 | # copy bn weight and bias 417 | if self.args.copy_bn_w: 418 | weight = m.weight.data[kept_filter] 419 | new_bn.weight.data.copy_(weight) 420 | if self.args.copy_bn_b: 421 | bias = m.bias.data[kept_filter] 422 | new_bn.bias.data.copy_(bias) 423 | 424 | # copy bn running stats 425 | new_bn.running_mean.data.copy_(m.running_mean[kept_filter]) 426 | new_bn.running_var.data.copy_(m.running_var[kept_filter]) 427 | new_bn.num_batches_tracked.data.copy_(m.num_batches_tracked) 428 | 429 | # load the new bn 430 | self._replace_module(new_model, name, new_bn) 431 | 432 | self.model = new_model 433 | n_filter = self._get_n_filter(self.model) 434 | logtmp = '{' 435 | for ix, num in n_filter.items(): 436 | logtmp += '%s:%d, ' % (ix, num) 437 | logtmp = logtmp[:-2] + '}' 438 | self.logprint('n_filter of pruned model: %s' % logtmp) 439 | 440 | def _get_masks(self): 441 | '''Get masks for unstructured pruning 442 | ''' 443 | self.mask = {} 444 | for name, m in self.model.named_modules(): 445 | if isinstance(m, self.learnable_layers): 446 | mask = torch.ones_like(m.weight.data).cuda().flatten() 447 | pruned = self.pruned_wg[name] 448 | mask[pruned] = 0 449 | self.mask[name] = mask.view_as(m.weight.data) 450 | self.logprint('Get masks done for weight pruning') -------------------------------------------------------------------------------- /pruner/reg_pruner_iterative.py: 
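[Editor's note] Before the iterative variant below: a standalone sketch of the unstructured ('weight') pruning path implemented in MetaPruner above -- L1 scores, 'min' picking, and the 0/1 masks built by _get_masks(). An illustration under those assumptions, not project code:

import torch
from math import ceil

def l1_weight_mask(weight, pr):
    '''Return a 0/1 mask pruning the pr fraction of smallest-|w| entries.'''
    scores = weight.abs().flatten()                                # cf. _get_score for wg == 'weight'
    n_pruned = min(ceil(pr * scores.numel()), scores.numel() - 1)  # do not prune all
    pruned = scores.sort()[1][:n_pruned]                           # cf. _pick_pruned, mode 'min'
    mask = torch.ones_like(scores)
    mask[pruned] = 0
    return mask.view_as(weight)

# Example: w = torch.randn(4, 4); w.data.mul_(l1_weight_mask(w, 0.75))
# keeps only the 4 largest-magnitude entries (cf. _apply_mask_forward below).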
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import os, copy, time, pickle, numpy as np, math 5 | from .meta_pruner import MetaPruner 6 | from utils import plot_weights_heatmap, Timer 7 | import matplotlib.pyplot as plt 8 | pjoin = os.path.join 9 | from utils import PresetLRScheduler, Timer 10 | from pdb import set_trace as st 11 | 12 | class Pruner(MetaPruner): 13 | def __init__(self, model, args, logger, passer): 14 | super(Pruner, self).__init__(model, args, logger, passer) 15 | 16 | # Reg related variables 17 | self.reg = {} # 18 | self.delta_reg = {} 19 | self.hist_mag_ratio = {} 20 | self.n_update_reg = {} 21 | self.iter_update_reg_finished = {} # 22 | self.iter_finish_pick = {} 23 | self.iter_stabilize_reg = math.inf # 24 | self.original_w_mag = {} 25 | self.original_kept_w_mag = {} 26 | self.ranking = {} 27 | self.pruned_wg_L1 = {} 28 | self.all_layer_finish_pick = False 29 | self.w_abs = {} # 30 | self.mag_reg_log = {} 31 | 32 | self.prune_state = "update_reg" # 33 | for name, m in self.model.named_modules(): # 34 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 35 | shape = m.weight.data.shape 36 | 37 | # initialize reg 38 | if self.args.wg == 'weight': 39 | self.reg[name] = torch.zeros_like(m.weight.data).flatten().cuda() 40 | else: 41 | self.reg[name] = torch.zeros(shape[0], shape[1]).cuda() 42 | 43 | # get original weight magnitude 44 | w_abs = self._get_score(m) 45 | n_wg = len(w_abs) 46 | self.ranking[name] = [] 47 | for _ in range(n_wg): 48 | self.ranking[name].append([]) 49 | self.original_w_mag[name] = m.weight.abs().mean().item() 50 | # kept_wg_L1 = [i for i in range(n_wg) if i not in self.pruned_wg_L1[name]] 51 | # self.original_kept_w_mag[name] = w_abs[kept_wg_L1].mean().item() 52 | 53 | self.pr_backup = {} 54 | for k, v in self.pr.items(): 55 | self.pr_backup[k] = v 56 | 57 | def _pick_pruned_wg(self, w, pr): 58 | if pr == 0: 59 | return [] 60 | elif pr > 0: 61 | w = w.flatten() 62 | n_pruned = min(math.ceil(pr * w.size(0)), w.size(0) - 1) # do not prune all 63 | return w.sort()[1][:n_pruned] 64 | elif pr == -1: # automatically decide lr by each layer itself 65 | tmp = w.flatten().sort()[0] 66 | n_not_consider = int(len(tmp) * 0.02) 67 | w = tmp[n_not_consider:-n_not_consider] 68 | 69 | sorted_w, sorted_index = w.flatten().sort() 70 | max_gap = 0 71 | max_index = 0 72 | for i in range(len(sorted_w) - 1): 73 | # gap = sorted_w[i+1:].mean() - sorted_w[:i+1].mean() 74 | gap = sorted_w[i+1] - sorted_w[i] 75 | if gap > max_gap: 76 | max_gap = gap 77 | max_index = i 78 | max_index += n_not_consider 79 | return sorted_index[:max_index + 1] 80 | else: 81 | self.logprint("Wrong pr. 
Please check.") 82 | exit(1) 83 | 84 | def _update_mag_ratio(self, m, name, w_abs, pruned=None): 85 | if type(pruned) == type(None): 86 | pruned = self.pruned_wg[name] 87 | kept = [i for i in range(len(w_abs)) if i not in pruned] 88 | ave_mag_pruned = w_abs[pruned].mean() 89 | ave_mag_kept = w_abs[kept].mean() 90 | if len(pruned): 91 | mag_ratio = ave_mag_kept / ave_mag_pruned 92 | if name in self.hist_mag_ratio: 93 | self.hist_mag_ratio[name] = self.hist_mag_ratio[name]* 0.9 + mag_ratio * 0.1 94 | else: 95 | self.hist_mag_ratio[name] = mag_ratio 96 | else: 97 | mag_ratio = math.inf 98 | self.hist_mag_ratio[name] = math.inf 99 | 100 | # print 101 | mag_ratio_now_before = ave_mag_kept / self.original_kept_w_mag[name] 102 | if self.total_iter % self.args.print_interval == 0: 103 | self.logprint(" mag_ratio %.4f mag_ratio_momentum %.4f" % (mag_ratio, self.hist_mag_ratio[name])) 104 | self.logprint(" for kept weights, original_kept_w_mag %.6f, now_kept_w_mag %.6f ratio_now_over_original %.4f" % 105 | (self.original_kept_w_mag[name], ave_mag_kept, mag_ratio_now_before)) 106 | return mag_ratio_now_before 107 | 108 | def _get_score(self, m): 109 | shape = m.weight.data.shape 110 | if self.args.wg == "channel": 111 | w_abs = m.weight.abs().mean(dim=[0, 2, 3]) if len(shape) == 4 else m.weight.abs().mean(dim=0) 112 | elif self.args.wg == "filter": 113 | w_abs = m.weight.abs().mean(dim=[1, 2, 3]) if len(shape) == 4 else m.weight.abs().mean(dim=1) 114 | elif self.args.wg == "weight": 115 | w_abs = m.weight.abs().flatten() 116 | return w_abs 117 | 118 | 119 | def _greg_1(self, m, name): 120 | if self.pr[name] == 0: 121 | return True 122 | 123 | if self.args.wg != 'weight': # weight is too slow 124 | self._update_mag_ratio(m, name, self.w_abs[name]) 125 | 126 | pruned = self.pruned_wg[name] 127 | if self.args.wg == "channel": 128 | self.reg[name][:, pruned] += self.args.reg_granularity_prune 129 | elif self.args.wg == "filter": 130 | self.reg[name][pruned, :] += self.args.reg_granularity_prune 131 | elif self.args.wg == 'weight': 132 | self.reg[name][pruned] += self.args.reg_granularity_prune 133 | else: 134 | raise NotImplementedError 135 | 136 | # when all layers are pushed hard enough, stop 137 | if self.args.wg == 'weight': # for weight, do not use the magnitude ratio condition, because 'hist_mag_ratio' is not updated, too costly 138 | finish_update_reg = False 139 | else: 140 | finish_update_reg = True 141 | for k in self.hist_mag_ratio: 142 | if self.hist_mag_ratio[k] < self.args.mag_ratio_limit: 143 | finish_update_reg = False 144 | return finish_update_reg or self.reg[name].max() > self.args.reg_upper_limit 145 | 146 | def _update_reg(self): 147 | for name, m in self.model.named_modules(): 148 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 149 | cnt_m = self.layers[name].layer_index 150 | pr = self.pr[name] 151 | # self.logprint("HERE 3 to CHECK total Iter: %d" % self.total_iter) 152 | # self.logprint(self.iter_update_reg_finished.keys()) 153 | if name in self.iter_update_reg_finished.keys(): 154 | continue 155 | 156 | if self.total_iter % self.args.print_interval == 0: 157 | self.logprint("[%d] Update reg for layer '%s'. Pr = %s. 
Iter = %d" 158 | % (cnt_m, name, pr, self.total_iter)) 159 | 160 | # get the importance score (L1-norm in this case) 161 | self.w_abs[name] = self._get_score(m) 162 | 163 | # update reg functions, two things: 164 | # (1) update reg of this layer (2) determine if it is time to stop update reg 165 | if self.args.method == "RST" or self.args.method == "RST_Iter": 166 | finish_update_reg = self._greg_1(m, name) 167 | else: 168 | self.logprint("Wrong '--method' argument, please check.") 169 | exit(1) 170 | 171 | # check prune state 172 | if finish_update_reg: 173 | # after 'update_reg' stage, keep the reg to stabilize weight magnitude 174 | self.iter_update_reg_finished[name] = self.total_iter 175 | self.logprint("==> [%d] Just finished 'update_reg'. Iter = %d" % (cnt_m, self.total_iter)) 176 | 177 | # check if all layers finish 'update_reg' 178 | self.prune_state = "stabilize_reg" 179 | for n, mm in self.model.named_modules(): 180 | if isinstance(mm, nn.Conv2d) or isinstance(mm, nn.Linear): 181 | if n not in self.iter_update_reg_finished: 182 | self.prune_state = "update_reg" 183 | break 184 | if self.prune_state == "stabilize_reg": 185 | self.iter_stabilize_reg = self.total_iter 186 | self.logprint("==> All layers just finished 'update_reg', go to 'stabilize_reg'. Iter = %d" % self.total_iter) 187 | self._save_model(mark='just_finished_update_reg') 188 | 189 | # after reg is updated, print to check 190 | # self.logprint("HERE 4 to CHECK total Iter: %d" % self.total_iter) 191 | if self.total_iter % self.args.print_interval == 0: 192 | self.logprint(" reg_status: min = %.5f ave = %.5f max = %.5f" % 193 | (self.reg[name].min(), self.reg[name].mean(), self.reg[name].max())) 194 | 195 | def _apply_reg(self): 196 | for name, m in self.model.named_modules(): 197 | if name in self.reg: 198 | reg = self.reg[name] # [N, C] 199 | if self.args.wg in ['filter', 'channel']: 200 | if reg.shape != m.weight.data.shape: 201 | reg = reg.unsqueeze(2).unsqueeze(3) # [N, C, 1, 1] 202 | elif self.args.wg == 'weight': 203 | reg = reg.view_as(m.weight.data) # [N, C, H, W] 204 | l2_grad = reg * m.weight 205 | if self.args.block_loss_grad: 206 | m.weight.grad = l2_grad 207 | else: 208 | m.weight.grad += l2_grad 209 | 210 | def _resume_prune_status(self, ckpt_path): 211 | state = torch.load(ckpt_path) 212 | self.model = state['model'].cuda() 213 | self.model.load_state_dict(state['state_dict']) 214 | self.optimizer = optim.SGD(self.model.parameters(), 215 | lr=self.args.lr_pick if self.args.__dict__.get('AdaReg_only_picking') else self.args.lr_prune, 216 | momentum=self.args.momentum, 217 | weight_decay=self.args.weight_decay) 218 | self.optimizer.load_state_dict(state['optimizer']) 219 | self.prune_state = state['prune_state'] 220 | self.total_iter = state['iter'] 221 | self.iter_stabilize_reg = state.get('iter_stabilize_reg', math.inf) 222 | self.reg = state['reg'] 223 | self.hist_mag_ratio = state['hist_mag_ratio'] 224 | 225 | def _save_model(self, acc1=0, acc5=0, mark=''): 226 | state = {'iter': self.total_iter, 227 | 'prune_state': self.prune_state, # we will resume prune_state 228 | 'arch': self.args.arch, 229 | 'model': self.model, 230 | 'state_dict': self.model.state_dict(), 231 | 'iter_stabilize_reg': self.iter_stabilize_reg, 232 | 'acc1': acc1, 233 | 'acc5': acc5, 234 | 'optimizer': self.optimizer.state_dict(), 235 | 'reg': self.reg, 236 | 'hist_mag_ratio': self.hist_mag_ratio, 237 | 'ExpID': self.logger.ExpID, 238 | } 239 | self.save(state, is_best=False, mark=mark) 240 | 241 | ### new content from here ### 
242 | def _apply_mask_forward(self): 243 | assert hasattr(self, 'mask') and len(self.mask.keys()) > 0 244 | for name, m in self.model.named_modules(): 245 | if name in self.mask: 246 | m.weight.data.mul_(self.mask[name]) 247 | 248 | def _update_pr(self, cycle): 249 | '''update layer pruning ratio in iterative pruning 250 | ''' 251 | for layer, pr in self.pr_backup.items(): 252 | pr_each_time_to_current = 1 - (1 - pr) ** (1. / self.args.num_cycles) 253 | pr_each_time = pr_each_time_to_current * ( (1-pr_each_time_to_current) ** (cycle-1) ) 254 | self.pr[layer] = pr_each_time if self.args.wg in ['filter', 'channel'] else pr_each_time + self.pr[layer] 255 | 256 | def _finetune(self, cycle): 257 | lr_scheduler = PresetLRScheduler(self.args.lr_ft_mini) 258 | optimizer = optim.SGD(self.model.parameters(), 259 | lr=0, # placeholder, this will be updated later 260 | momentum=self.args.momentum, 261 | weight_decay=self.args.weight_decay) 262 | 263 | best_acc1, best_acc1_epoch = 0, 0 264 | timer = Timer(self.args.epochs_mini) 265 | for epoch in range(self.args.epochs_mini): 266 | lr = lr_scheduler(optimizer, epoch) 267 | self.logprint(f'[Subprune #{cycle} Finetune] Epoch {epoch} Set LR = {lr}') 268 | for ix, (inputs, targets) in enumerate(self.train_loader): 269 | inputs, targets = inputs.cuda(), targets.cuda() 270 | self.model.train() 271 | y_ = self.model(inputs) 272 | loss = self.criterion(y_, targets) 273 | optimizer.zero_grad() 274 | loss.backward() 275 | optimizer.step() 276 | 277 | if self.args.method and self.args.wg == 'weight': 278 | self._apply_mask_forward() 279 | 280 | if ix % self.args.print_interval == 0: 281 | self.logprint(f'[Subprune #{cycle} Finetune] Epoch {epoch} Step {ix} loss {loss:.4f}') 282 | # test 283 | acc1, *_ = self.test(self.model) 284 | if acc1 > best_acc1: 285 | best_acc1 = acc1 286 | best_acc1_epoch = epoch 287 | self.accprint(f'[Subprune #{cycle} Finetune] Epoch {epoch} Acc1 {acc1:.4f} (Best_Acc1 {best_acc1:.4f} @ Best_Acc1_Epoch {best_acc1_epoch}) LR {lr}') 288 | self.logprint(f'predicted finish time: {timer()}') 289 | 290 | def prune(self): 291 | # clear existing pr 292 | for layer in self.pr: 293 | self.pr[layer] = 0 294 | 295 | for cycle in range(1, self.args.num_cycles + 1): 296 | self.logprint(f'==> Start sub-Reg #{cycle}') 297 | self._update_pr(cycle) # get pr 298 | self._get_kept_wg_L1() # from pr, update self.pruned_wg 299 | 300 | if cycle == 1: 301 | self.mask = {} # pre-define self.mask here, will be updated after mini_prune ( in self._prune_and_build_new_model() ) 302 | 303 | model_before_removing_weights, self.model = self.mini_prune(cycle) 304 | self._prune_and_build_new_model() # from self.pruned_wg, get mask for wg:weight 305 | 306 | self.logprint('==> Check: if the mask does the correct sparsity') 307 | keys_list = [i for i in self.mask.keys()] 308 | pr_list = [ 1-(self.mask[i].sum()/self.mask[i].numel()) for i in keys_list ] 309 | self.logprint("==> Layer-wise sparsity:") 310 | self.logprint(pr_list) 311 | self.logprint("==> Check done") 312 | 313 | if self.args.RST_Iter_weight_delete: 314 | self._apply_mask_forward() # set pruned weights to 0 315 | 316 | if cycle < self.args.num_cycles and self.args.RST_Iter_ft == 1: 317 | self._finetune(cycle) 318 | 319 | return model_before_removing_weights, self.model 320 | 321 | # self._prune_and_build_new_model() 322 | # if cycle < self.args.num_cycles: 323 | # self._finetune(cycle) # there is a big finetuning after the last pruning, so do not finetune here 324 | 325 | ### new content until here ### 326 | 327 
| def mini_prune(self, cycle): # prune --> mini_prune 328 | self.model = self.model.train() 329 | self.optimizer = optim.SGD(self.model.parameters(), 330 | lr=self.args.lr_pick if self.args.__dict__.get('AdaReg_only_picking') else self.args.lr_prune, 331 | momentum=self.args.momentum, 332 | weight_decay=self.args.weight_decay) 333 | 334 | # resume model, optimzer, prune_status 335 | self.total_iter = -1 336 | if self.args.resume_path: 337 | self._resume_prune_status(self.args.resume_path) 338 | self._get_kept_wg_L1() # get pruned and kept wg from the resumed model 339 | self.model = self.model.train() 340 | self.logprint("Resume model successfully: '{}'. Iter = {}. prune_state = {}".format( 341 | self.args.resume_path, self.total_iter, self.prune_state)) 342 | 343 | acc1 = acc5 = 0 344 | total_iter_reg = self.args.reg_upper_limit / self.args.reg_granularity_prune * self.args.update_reg_interval + self.args.stabilize_reg_interval 345 | timer = Timer(total_iter_reg / self.args.print_interval) 346 | while True: 347 | for _, (inputs, targets) in enumerate(self.train_loader): 348 | inputs, targets = inputs.cuda(), targets.cuda() 349 | self.total_iter += 1 350 | total_iter = self.total_iter 351 | 352 | # test 353 | if total_iter % self.args.test_interval == 0: 354 | acc1, acc5, *_ = self.test(self.model) 355 | self.accprint("Acc1 = %.4f Acc5 = %.4f Iter = %d (before update) [prune_state = %s, method = %s]" % 356 | (acc1, acc5, total_iter, self.prune_state, self.args.method)) 357 | 358 | # save model (save model before a batch starts) 359 | if total_iter % self.args.save_interval == 0: 360 | self._save_model(acc1, acc5) 361 | self.logprint('Periodically save model done. Iter = {}'.format(total_iter)) 362 | 363 | if total_iter % self.args.print_interval == 0: 364 | self.logprint("") 365 | self.logprint("Iter = %d [prune_state = %s, method = %s] " 366 | % (total_iter, self.prune_state, self.args.method) + "-"*40) 367 | 368 | # forward 369 | self.model.train() 370 | y_ = self.model(inputs) 371 | # self.logprint("HERE 1 to CHECK total Iter: %d" % self.total_iter) 372 | if self.prune_state == "update_reg" and total_iter % self.args.update_reg_interval == 0: 373 | # self.logprint("HERE 2 to CHECK total Iter: %d" % self.total_iter) 374 | self._update_reg() 375 | 376 | # normal training forward 377 | loss = self.criterion(y_, targets) 378 | self.optimizer.zero_grad() 379 | loss.backward() 380 | 381 | # after backward but before update, apply reg to the grad 382 | self._apply_reg() 383 | self.optimizer.step() 384 | 385 | if self.args.method and self.args.wg == 'weight' and cycle != 1: 386 | if self.args.RST_Iter_weight_delete: 387 | self._apply_mask_forward() # the mask from last cycle should be used here 388 | 389 | # log print 390 | if total_iter % self.args.print_interval == 0: 391 | # check BN stats 392 | if self.args.verbose: 393 | for name, m in self.model.named_modules(): 394 | if isinstance(m, nn.BatchNorm2d): 395 | # get the associating conv layer of this BN layer 396 | ix = self.all_layers.index(name) 397 | for k in range(ix-1, -1, -1): 398 | if self.all_layers[k] in self.layers: 399 | last_conv = self.all_layers[k] 400 | break 401 | mask_ = [0] * m.weight.data.size(0) 402 | for i in self.kept_wg[last_conv]: 403 | mask_[i] = 1 404 | wstr = ' '.join(['%.3f (%s)' % (x, y) for x, y in zip(m.weight.data, mask_)]) 405 | bstr = ' '.join(['%.3f (%s)' % (x, y) for x, y in zip(m.bias.data, mask_)]) 406 | logstr = f'{last_conv} BN weight: {wstr}\nBN bias: {bstr}' 407 | self.logprint(logstr) 408 | 409 | # 
check train acc 410 | _, predicted = y_.max(1) 411 | correct = predicted.eq(targets).sum().item() 412 | train_acc = correct / targets.size(0) 413 | self.logprint("After optim update current_train_loss: %.4f current_train_acc: %.4f" % (loss.item(), train_acc)) 414 | 415 | 416 | # change prune state 417 | if self.prune_state == "stabilize_reg" and total_iter - self.iter_stabilize_reg == self.args.stabilize_reg_interval: 418 | # # --- check accuracy to make sure '_prune_and_build_new_model' works normally 419 | # # checked. works normally! 420 | # for name, m in self.model.named_modules(): 421 | # if isinstance(m, self.learnable_layers): 422 | # pruned_filter = self.pruned_wg[name] 423 | # m.weight.data[pruned_filter] *= 0 424 | # next_bn = self._next_bn(self.model, m) 425 | # elif isinstance(m, nn.BatchNorm2d) and m == next_bn: 426 | # m.weight.data[pruned_filter] *= 0 427 | # m.bias.data[pruned_filter] *= 0 428 | 429 | # acc1_before, *_ = self.test(self.model) 430 | # self._prune_and_build_new_model() 431 | # acc1_after, *_ = self.test(self.model) 432 | # print(acc1_before, acc1_after) 433 | # exit() 434 | # # --- 435 | model_before_removing_weights = copy.deepcopy(self.model) 436 | self._prune_and_build_new_model() 437 | self.logprint("'stabilize_reg' is done. Pruned, go to 'finetune'. Iter = %d" % total_iter) 438 | 439 | if cycle < self.args.num_cycles: # reset all necessary config 440 | self.prune_state = "update_reg" # recover for next mini-prune 441 | self.iter_update_reg_finished = {} 442 | self.reg = {} # 443 | self.w_abs = {} # 444 | self.iter_stabilize_reg = math.inf 445 | 446 | 447 | for name, m in self.model.named_modules(): # 448 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 449 | shape = m.weight.data.shape 450 | 451 | # initialize reg 452 | if self.args.wg == 'weight': 453 | self.reg[name] = torch.zeros_like(m.weight.data).flatten().cuda() 454 | else: 455 | self.reg[name] = torch.zeros(shape[0], shape[1]).cuda() 456 | 457 | # get original weight magnitude 458 | w_abs = self._get_score(m) 459 | n_wg = len(w_abs) 460 | self.ranking[name] = [] 461 | for _ in range(n_wg): 462 | self.ranking[name].append([]) 463 | self.original_w_mag[name] = m.weight.abs().mean().item() 464 | 465 | 466 | return model_before_removing_weights, copy.deepcopy(self.model) 467 | 468 | if total_iter % self.args.print_interval == 0: 469 | self.logprint(f"predicted_finish_time of reg: {timer()}") 470 | 471 | def _plot_mag_ratio(self, w_abs, name): 472 | fig, ax = plt.subplots() 473 | max_ = w_abs.max().item() 474 | w_abs_normalized = (w_abs / max_).data.cpu().numpy() 475 | ax.plot(w_abs_normalized) 476 | ax.set_ylim([0, 1]) 477 | ax.set_xlabel('filter index') 478 | ax.set_ylabel('relative L1-norm ratio') 479 | layer_index = self.layers[name].layer_index 480 | shape = self.layers[name].size 481 | ax.set_title("layer %d iter %d shape %s\n(max = %s)" 482 | % (layer_index, self.total_iter, shape, max_)) 483 | out = pjoin(self.logger.logplt_path, "%d_iter%d_w_abs_dist.jpg" % 484 | (layer_index, self.total_iter)) 485 | fig.savefig(out) 486 | plt.close(fig) 487 | np.save(out.replace('.jpg', '.npy'), w_abs_normalized) 488 | 489 | def _log_down_mag_reg(self, w_abs, name): 490 | step = self.total_iter 491 | reg = self.reg[name].max().item() 492 | mag = w_abs.data.cpu().numpy() 493 | if name not in self.mag_reg_log: 494 | values = [[step, reg, mag]] 495 | log = { 496 | 'name': name, 497 | 'layer_index': self.layers[name].layer_index, 498 | 'shape': self.layers[name].size, 499 | 'values': values, 500 | } 
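            # each record appended to 'values' is a triplet [iter, max reg of this
            # layer, per-group |w| array]; the per-layer log dict is created once
            # here, and later calls only append to 'values' (the else branch below)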
501 | self.mag_reg_log[name] = log 502 | else: 503 | values = self.mag_reg_log[name]['values'] 504 | values.append([step, reg, mag]) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torchvision.transforms as transforms 18 | import torchvision.datasets as datasets 19 | import torchvision.models as models 20 | 21 | import copy 22 | import numpy as np 23 | from importlib import import_module 24 | from data import Data 25 | from logger import Logger 26 | from utils import get_n_params, get_n_flops, get_n_params_, get_n_flops_, PresetLRScheduler, Timer 27 | from utils import add_noise_to_model, compute_jacobian 28 | from utils import Dataset_lmdb_batch 29 | from utils import AverageMeter, ProgressMeter, adjust_learning_rate, accuracy 30 | from model import model_dict, is_single_branch 31 | from data import num_classes_dict, img_size_dict 32 | from pruner import pruner_dict 33 | from option import args 34 | pjoin = os.path.join 35 | 36 | logger = Logger(args) 37 | logprint = logger.log_printer.logprint 38 | accprint = logger.log_printer.accprint 39 | netprint = logger.netprint 40 | timer = Timer(args.epochs) 41 | # --- 42 | from pdb import set_trace as st 43 | 44 | def main(): 45 | if args.seed is not None: 46 | random.seed(args.seed) 47 | torch.manual_seed(args.seed) 48 | cudnn.deterministic = True 49 | warnings.warn('You have chosen to seed training. ' 50 | 'This will turn on the CUDNN deterministic setting, ' 51 | 'which can slow down your training considerably! ' 52 | 'You may see unexpected behavior when restarting ' 53 | 'from checkpoints.') 54 | 55 | if args.gpu is not None: 56 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 57 | 'disable data parallelism.') 58 | 59 | if args.dist_url == "env://" and args.world_size == -1: 60 | args.world_size = int(os.environ["WORLD_SIZE"]) 61 | 62 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 63 | 64 | ngpus_per_node = torch.cuda.device_count() 65 | if args.multiprocessing_distributed: 66 | # Since we have ngpus_per_node processes per node, the total world_size 67 | # needs to be adjusted accordingly 68 | args.world_size = ngpus_per_node * args.world_size 69 | # Use torch.multiprocessing.spawn to launch distributed processes: the 70 | # main_worker process function 71 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 72 | else: 73 | # Simply call main_worker function 74 | main_worker(args.gpu, ngpus_per_node, args) 75 | 76 | 77 | def main_worker(gpu, ngpus_per_node, args): 78 | global best_acc1, best_acc1_epoch 79 | args.gpu = gpu 80 | 81 | # Data loading code 82 | train_sampler = None 83 | if args.dataset not in ['imagenet', 'imagenet_subset_200']: 84 | loader = Data(args) 85 | train_loader = loader.train_loader 86 | val_loader = loader.test_loader 87 | else: 88 | traindir = os.path.join(args.data_path, args.dataset, 'train') 89 | val_folder = 'val' 90 | if args.debug: 91 | val_folder = 'val_tmp' # val_tmp is a tiny version of val to accelerate test in debugging 92 | val_folder_path = f'{args.data_path}/{args.dataset}/{val_folder}' 93 | if not os.path.exists(val_folder_path): 94 | os.makedirs(val_folder_path) 95 | dirs = os.listdir(f'{args.data_path}/{args.dataset}/val')[:3] 96 | [shutil.copytree(f'{args.data_path}/{args.dataset}/val/{d}', f'{val_folder_path}/{d}') for d in dirs] 97 | valdir = os.path.join(args.data_path, args.dataset, val_folder) 98 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 99 | std=[0.229, 0.224, 0.225]) 100 | transforms_train = transforms.Compose([ 101 | transforms.RandomResizedCrop(224), 102 | transforms.RandomHorizontalFlip(), 103 | transforms.ToTensor(), 104 | normalize]) 105 | transforms_val = transforms.Compose([ 106 | transforms.Resize(256), 107 | transforms.CenterCrop(224), 108 | transforms.ToTensor(), 109 | normalize]) 110 | 111 | if args.use_lmdb: 112 | lmdb_path_train = traindir + '/lmdb' 113 | lmdb_path_val = valdir + '/lmdb' 114 | assert os.path.exists(lmdb_path_train) and os.path.exists(lmdb_path_val) 115 | logprint(f'Loading data in LMDB format: "{lmdb_path_train}" and "{lmdb_path_val}"') 116 | train_dataset = Dataset_lmdb_batch(lmdb_path_train, transforms_train) 117 | val_dataset = Dataset_lmdb_batch(lmdb_path_val, transforms_val) 118 | else: 119 | train_dataset = datasets.ImageFolder(traindir, transforms_train) 120 | val_dataset = datasets.ImageFolder(valdir, transforms_val) 121 | 122 | if args.distributed: 123 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 124 | 125 | train_loader = torch.utils.data.DataLoader( 126 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 127 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 128 | 129 | val_loader = torch.utils.data.DataLoader( 130 | val_dataset, batch_size=args.batch_size, shuffle=False, 131 | num_workers=args.workers, pin_memory=True) 132 | 133 | # define loss function (criterion) and optimizer 134 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 135 | 136 | if args.gpu is not None: 137 | logprint("Use GPU: {} for training".format(args.gpu)) 138 | 139 | if args.distributed: 140 | if args.dist_url == "env://" and 
args.rank == -1: 141 | args.rank = int(os.environ["RANK"]) 142 | if args.multiprocessing_distributed: 143 | # For multiprocessing distributed training, rank needs to be the 144 | # global rank among all the processes 145 | args.rank = args.rank * ngpus_per_node + gpu 146 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 147 | world_size=args.world_size, rank=args.rank) 148 | # create model 149 | num_classes = num_classes_dict[args.dataset] 150 | img_size = img_size_dict[args.dataset] 151 | num_channels = 1 if args.dataset == 'mnist' else 3 152 | if args.dataset in ["imagenet", "imagenet_subset_200", "tiny_imagenet"]: 153 | if args.pretrained: 154 | logprint("=> using pre-trained model '{}'".format(args.arch)) 155 | model = models.__dict__[args.arch](num_classes=num_classes, pretrained=True) 156 | else: 157 | logprint("=> creating model '{}'".format(args.arch)) 158 | model = models.__dict__[args.arch](num_classes=num_classes) 159 | else: # @mst: added non-imagenet models 160 | model = model_dict[args.arch](num_classes=num_classes, num_channels=num_channels, use_bn=args.use_bn) 161 | if args.init in ['orth', 'exact_isometry_from_scratch']: 162 | model.apply(lambda m: _weights_init_orthogonal(m, act=args.activation)) 163 | logprint("==> Use weight initialization: 'orthogonal_'. Activation: %s" % args.activation) 164 | 165 | # @mst: save the model after initialization if necessary 166 | if args.save_init_model: 167 | state = { 168 | 'arch': args.arch, 169 | 'model': model, 170 | 'state_dict': model.state_dict(), 171 | 'ExpID': logger.ExpID, 172 | } 173 | save_model(state, mark='init') 174 | 175 | if args.distributed: 176 | # For multiprocessing distributed, DistributedDataParallel constructor 177 | # should always set the single device scope, otherwise, 178 | # DistributedDataParallel will use all available devices. 179 | if args.gpu is not None: 180 | torch.cuda.set_device(args.gpu) 181 | model.cuda(args.gpu) 182 | # When using a single GPU per process and per 183 | # DistributedDataParallel, we need to divide the batch size 184 | # ourselves based on the total number of GPUs we have 185 | args.batch_size = int(args.batch_size / ngpus_per_node) 186 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 187 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 188 | else: 189 | model.cuda() 190 | # DistributedDataParallel will divide and allocate batch_size to all 191 | # available GPUs if device_ids are not set 192 | model = torch.nn.parallel.DistributedDataParallel(model) 193 | elif args.gpu is not None: 194 | torch.cuda.set_device(args.gpu) 195 | model = model.cuda(args.gpu) 196 | else: 197 | # DataParallel will divide and allocate batch_size to all available GPUs 198 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 199 | model.features = torch.nn.DataParallel(model.features) 200 | model.cuda() 201 | else: 202 | model = torch.nn.DataParallel(model).cuda() 203 | 204 | # @mst: load the unpruned model for pruning 205 | # This may be useful for the non-imagenet cases where we use our pretrained models 206 | if args.base_model_path: 207 | ckpt = torch.load(args.base_model_path) 208 | if 'model' in ckpt: 209 | model = ckpt['model'] 210 | model.load_state_dict(ckpt['state_dict']) 211 | logstr = f"==> Load pretrained model successfully: '{args.base_model_path}'" 212 | if args.test_pretrained: 213 | acc1, acc5, loss_test = validate(val_loader, model, criterion, args) 214 | logstr += f". 
Its accuracy: {acc1:.4f}" 215 | logprint(logstr) 216 | 217 | # @mst: print base model arch 218 | netprint(model, comment='base model arch') 219 | 220 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 221 | momentum=args.momentum, 222 | weight_decay=args.weight_decay) # @mst: This solver is not be really used. We will use our own. 223 | 224 | 225 | # optionally resume from a checkpoint 226 | # @mst: we will use our option '--resume_path', keep this simply for back-compatibility 227 | best_acc1, best_acc1_epoch = 0, 0 228 | if args.resume: 229 | if os.path.isfile(args.resume): 230 | logprint("=> loading checkpoint '{}'".format(args.resume)) 231 | if args.gpu is None: 232 | checkpoint = torch.load(args.resume) 233 | else: 234 | # Map model to be loaded to specified single gpu. 235 | loc = 'cuda:{}'.format(args.gpu) 236 | checkpoint = torch.load(args.resume, map_location=loc) 237 | args.start_epoch = checkpoint['epoch'] 238 | best_acc1 = checkpoint['best_acc1'] 239 | if args.gpu is not None: 240 | # best_acc1 may be from a checkpoint from a different GPU 241 | best_acc1 = best_acc1.to(args.gpu) 242 | model.load_state_dict(checkpoint['state_dict']) 243 | optimizer.load_state_dict(checkpoint['optimizer']) 244 | logprint("=> loaded checkpoint '{}' (epoch {})" 245 | .format(args.resume, checkpoint['epoch'])) 246 | else: 247 | logprint("=> no checkpoint found at '{}'".format(args.resume)) 248 | 249 | cudnn.benchmark = True 250 | 251 | if args.evaluate: 252 | acc1, acc5, loss_test = validate(val_loader, model, criterion, args) 253 | logprint('Acc1 %.4f Acc5 %.4f Loss_test %.4f' % (acc1, acc5, loss_test)) 254 | return 255 | 256 | # --- @mst: Structured pruning is basically equivalent to providing a new weight initialization before finetune, 257 | # so just before training, conduct pruning to obtain a new model. 258 | if args.method: 259 | if args.dataset in ['imagenet', 'imagenet_subset_200']: 260 | # imagenet training costs too much time, so we use a smaller batch size for pruning training 261 | train_loader_prune = torch.utils.data.DataLoader( 262 | train_dataset, batch_size=args.batch_size_prune, shuffle=(train_sampler is None), 263 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 264 | else: 265 | train_loader_prune = loader.train_loader_prune 266 | 267 | # get the original unpruned model statistics 268 | n_params_original_v2 = get_n_params_(model) 269 | n_flops_original_v2 = get_n_flops_(model, img_size=img_size, n_channel=num_channels) 270 | 271 | # init some variables for pruning 272 | prune_state, pruner = 'prune', None 273 | if args.wg == 'weight': 274 | global mask 275 | 276 | # resume a model 277 | if args.resume_path: 278 | state = torch.load(args.resume_path) 279 | prune_state = state['prune_state'] # finetune or update_reg or stabilize_reg 280 | if prune_state == 'finetune': 281 | model = state['model'].cuda() 282 | model.load_state_dict(state['state_dict']) 283 | if args.solver == 'Adam': 284 | logprint('==> Using Adam optimizer') 285 | optimizer = torch.optim.Adam(model.parameters(), args.lr) 286 | else: 287 | logprint('==> Using SGD optimizer') 288 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 289 | momentum=args.momentum, 290 | weight_decay=args.weight_decay) 291 | optimizer.load_state_dict(state['optimizer']) 292 | args.start_epoch = state['epoch'] 293 | logprint("==> Resume model successfully: '{}'. Epoch = {}. 
prune_state = '{}'".format( 294 | args.resume_path, args.start_epoch, prune_state)) 295 | else: 296 | raise NotImplementedError 297 | 298 | # finetune a model 299 | if args.directly_ft_weights: 300 | state = torch.load(args.directly_ft_weights) 301 | model = state['model'].cuda() 302 | model.load_state_dict(state['state_dict']) 303 | prune_state = 'finetune' 304 | logprint("==> Load model successfully: '{}'. Epoch = {}. prune_state = '{}'".format( 305 | args.directly_ft_weights, args.start_epoch, prune_state)) 306 | 307 | if 'mask' in state: 308 | mask = state['mask'] 309 | apply_mask_forward(model) 310 | logprint('==> Mask restored') 311 | st() 312 | 313 | if prune_state in ['prune']: 314 | 315 | class passer: pass # to pass arguments 316 | passer.test = validate 317 | passer.finetune = finetune 318 | passer.train_loader = train_loader_prune 319 | passer.test_loader = val_loader 320 | passer.save = save_model 321 | passer.criterion = criterion 322 | passer.train_sampler = train_sampler 323 | passer.pruner = pruner 324 | passer.args = args 325 | passer.is_single_branch = is_single_branch 326 | # *********************************************************** 327 | # key pruning function 328 | pruner = pruner_dict[args.method].Pruner(model, args, logger, passer) 329 | model = pruner.prune() # get the pruned model 330 | if args.method in ['RST', 'RST_Iter']: 331 | model_before_removing_weights, model = model 332 | if args.wg == 'weight': 333 | mask = pruner.mask 334 | apply_mask_forward(model) 335 | logprint('==> Check: if the mask does the correct sparsity') 336 | keys_list = [i for i in mask.keys()] 337 | pr_list = [ 1-(mask[i].sum()/mask[i].numel()) for i in keys_list ] 338 | logprint("==> Layer-wise sparsity:") 339 | logprint(pr_list) 340 | logprint('==> Apply masks before finetuning to ensure the pruned weights are zero') 341 | netprint(model, comment='model that was just pruned') 342 | # *********************************************************** 343 | 344 | # get model statistics of the pruned model 345 | n_params_now_v2 = get_n_params_(model) 346 | n_flops_now_v2 = get_n_flops_(model, img_size=img_size, n_channel=num_channels) 347 | logprint("==> n_params_original_v2: {:>9.6f}M, n_flops_original_v2: {:>9.6f}G".format(n_params_original_v2/1e6, n_flops_original_v2/1e9)) 348 | logprint("==> n_params_now_v2: {:>9.6f}M, n_flops_now_v2: {:>9.6f}G".format(n_params_now_v2/1e6, n_flops_now_v2/1e9)) 349 | ratio_param = (n_params_original_v2 - n_params_now_v2) / n_params_original_v2 350 | ratio_flops = (n_flops_original_v2 - n_flops_now_v2) / n_flops_original_v2 351 | compression_ratio = 1.0 / (1 - ratio_param) 352 | speedup_ratio = 1.0 / (1 - ratio_flops) 353 | logprint("==> reduction ratio -- params: {:>5.2f}% (compression ratio {:>.2f}x), flops: {:>5.2f}% (speedup ratio {:>.2f}x)".format(ratio_param*100, compression_ratio, ratio_flops*100, speedup_ratio)) 354 | 355 | # test the just pruned model 356 | t1 = time.time() 357 | acc1, acc5, loss_test = validate(val_loader, model, criterion, args) # test set 358 | logstr = [] 359 | logstr += ["Acc1 %.4f Acc5 %.4f Loss_test %.4f" % (acc1, acc5, loss_test)] 360 | if args.dataset not in ['imagenet']: # too costly, not test 361 | acc1_train, acc5_train, loss_train = validate(train_loader, model, criterion, args, noisy_model_ensemble=args.model_noise_std) # train set 362 | logstr += ["Acc1_train %.4f Acc5_train %.4f Loss_train %.4f" % (acc1_train, acc5_train, loss_train)] 363 | logstr += ["(test_time %.2fs) Just got pruned model, about to finetune" % 
(time.time() - t1)] 364 | accprint(' | '.join(logstr)) 365 | 366 | # save the just pruned model 367 | state = {'arch': args.arch, 368 | 'model': model, 369 | 'state_dict': model.state_dict(), 370 | 'acc1': acc1, 371 | 'acc5': acc5, 372 | 'ExpID': logger.ExpID, 373 | 'pruned_wg': pruner.pruned_wg, 374 | 'kept_wg': pruner.kept_wg, 375 | } 376 | if args.wg == 'weight': 377 | state['mask'] = mask 378 | save_model(state, mark="just_finished_prune") 379 | 380 | # finetune 381 | finetune(model, train_loader, val_loader, train_sampler, criterion, pruner, best_acc1, best_acc1_epoch, args, num_classes=num_classes) 382 | 383 | # @mst 384 | def finetune(model, train_loader, val_loader, train_sampler, criterion, pruner, best_acc1, best_acc1_epoch, args, num_classes=10, print_log=True): 385 | # since model is new, we need a new optimizer 386 | if args.solver == 'Adam': 387 | logprint('==> Start to finetune: using Adam optimizer') 388 | optimizer = torch.optim.Adam(model.parameters(), args.lr) 389 | else: 390 | logprint('==> Start to finetune: using SGD optimizer') 391 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 392 | momentum=args.momentum, 393 | weight_decay=args.weight_decay) 394 | 395 | # set lr finetune schduler for finetune 396 | if args.method: 397 | assert args.lr_ft is not None 398 | lr_scheduler = PresetLRScheduler(args.lr_ft) 399 | 400 | acc1_list, loss_train_list, loss_test_list = [], [], [] 401 | for epoch in range(args.start_epoch, args.epochs): 402 | if args.distributed: 403 | train_sampler.set_epoch(epoch) 404 | 405 | # @mst: use our own lr scheduler 406 | lr = lr_scheduler(optimizer, epoch) if args.method else adjust_learning_rate(optimizer, epoch, args) 407 | if print_log: 408 | logprint("==> Set lr = %s @ Epoch %d " % (lr, epoch)) 409 | 410 | # train for one epoch 411 | train(train_loader, model, criterion, optimizer, epoch, args, print_log=print_log) 412 | 413 | # @mst: check weights magnitude during finetune 414 | if args.method in ['GReg-1', 'GReg-2'] and not isinstance(pruner, type(None)): 415 | for name, m in model.named_modules(): 416 | if name in pruner.reg: 417 | ix = pruner.layers[name].layer_index 418 | mag_now = m.weight.data.abs().mean() 419 | mag_old = pruner.original_w_mag[name] 420 | ratio = mag_now / mag_old 421 | tmp = '[%2d] %25s -- mag_old = %.4f, mag_now = %.4f (%.2f)' % (ix, name, mag_old, mag_now, ratio) 422 | print(tmp, file=logger.logtxt, flush=True) 423 | if args.screen_print: 424 | print(tmp) 425 | 426 | # evaluate on validation set 427 | acc1, acc5, loss_test = validate(val_loader, model, criterion, args) # @mst: added acc5 428 | if args.dataset != 'imagenet': # too costly, not test for now 429 | acc1_train, acc5_train, loss_train = validate(train_loader, model, criterion, args) 430 | else: 431 | acc1_train, acc5_train, loss_train = -1, -1, -1 432 | acc1_list.append(acc1) 433 | loss_train_list.append(loss_train) 434 | loss_test_list.append(loss_test) 435 | 436 | # remember best acc@1 and save checkpoint 437 | is_best = acc1 > best_acc1 438 | best_acc1 = max(acc1, best_acc1) 439 | if is_best: 440 | best_acc1_epoch = epoch 441 | best_loss_train = loss_train 442 | best_loss_test = loss_test 443 | if print_log: 444 | accprint("Acc1 %.4f Acc5 %.4f Loss_test %.4f | Acc1_train %.4f Acc5_train %.4f Loss_train %.4f | Epoch %d (Best_Acc1 %.4f @ Best_Acc1_Epoch %d) lr %s" % 445 | (acc1, acc5, loss_test, acc1_train, acc5_train, loss_train, epoch, best_acc1, best_acc1_epoch, lr)) 446 | logprint('predicted finish time: %s' % timer()) 447 | 448 | 
ngpus_per_node = torch.cuda.device_count() 449 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 450 | and args.rank % ngpus_per_node == 0): 451 | if args.method: 452 | # @mst: use our own save func 453 | state = {'epoch': epoch + 1, 454 | 'arch': args.arch, 455 | 'model': model, 456 | 'state_dict': model.state_dict(), 457 | 'acc1': acc1, 458 | 'acc5': acc5, 459 | 'optimizer': optimizer.state_dict(), 460 | 'ExpID': logger.ExpID, 461 | 'prune_state': 'finetune', 462 | } 463 | if args.wg == 'weight': 464 | state['mask'] = mask 465 | save_model(state, is_best) 466 | else: 467 | save_checkpoint({ 468 | 'epoch': epoch + 1, 469 | 'arch': args.arch, 470 | 'state_dict': model.state_dict(), 471 | 'best_acc1': best_acc1, 472 | 'optimizer' : optimizer.state_dict(), 473 | }, is_best) 474 | 475 | last5_acc_mean, last5_acc_std = np.mean(acc1_list[-args.last_n_epoch:]), np.std(acc1_list[-args.last_n_epoch:]) 476 | last5_loss_train_mean, last5_loss_train_std = np.mean(loss_train_list[-args.last_n_epoch:]), np.std(loss_train_list[-args.last_n_epoch:]) 477 | last5_loss_test_mean, last5_loss_test_std = np.mean(loss_test_list[-args.last_n_epoch:]), np.std(loss_test_list[-args.last_n_epoch:]) 478 | 479 | best = [best_acc1, best_loss_train, best_loss_test] 480 | last5 = [last5_acc_mean, last5_acc_std, last5_loss_train_mean, last5_loss_train_std, last5_loss_test_mean, last5_loss_test_std] 481 | return best, last5 482 | 483 | def train(train_loader, model, criterion, optimizer, epoch, args, print_log=True): 484 | batch_time = AverageMeter('Time', ':6.3f') 485 | data_time = AverageMeter('Data', ':6.3f') 486 | losses = AverageMeter('Loss', ':.4e') 487 | top1 = AverageMeter('Acc@1', ':6.2f') 488 | top5 = AverageMeter('Acc@5', ':6.2f') 489 | progress = ProgressMeter( 490 | len(train_loader), 491 | [batch_time, data_time, losses, top1, top5], 492 | prefix="Epoch: [{}]".format(epoch)) 493 | 494 | # switch to train mode 495 | model.train() 496 | 497 | end = time.time() 498 | for i, (images, target) in enumerate(train_loader): 499 | # measure data loading time 500 | data_time.update(time.time() - end) 501 | 502 | if args.gpu is not None: 503 | images = images.cuda(args.gpu, non_blocking=True) 504 | target = target.cuda(args.gpu, non_blocking=True) 505 | 506 | # compute output 507 | output = model(images) 508 | loss = criterion(output, target) 509 | 510 | # measure accuracy and record loss 511 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 512 | losses.update(loss.item(), images.size(0)) 513 | top1.update(acc1[0], images.size(0)) 514 | top5.update(acc5[0], images.size(0)) 515 | 516 | # compute gradient and do SGD step 517 | optimizer.zero_grad() 518 | loss.backward() 519 | optimizer.step() 520 | 521 | # @mst: after update, zero out pruned weights 522 | if args.method and args.wg == 'weight': 523 | apply_mask_forward(model) 524 | 525 | # measure elapsed time 526 | batch_time.update(time.time() - end) 527 | end = time.time() 528 | 529 | if print_log and i % args.print_freq == 0: 530 | progress.display(i) 531 | 532 | 533 | def validate(val_loader, model, criterion, args, noisy_model_ensemble=False): 534 | batch_time = AverageMeter('Time', ':6.3f') 535 | losses = AverageMeter('Loss', ':.4e') 536 | top1 = AverageMeter('Acc@1', ':6.2f') 537 | top5 = AverageMeter('Acc@5', ':6.2f') 538 | progress = ProgressMeter( 539 | len(val_loader), 540 | [batch_time, losses, top1, top5], 541 | prefix='Test: ') 542 | 543 | train_state = model.training 544 | 545 | # switch to evaluate mode 546 | model.eval() 
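    # The block below supports an optional noisy-ensemble test: several copies of
    # the model get Gaussian noise (std = args.model_noise_std) added to their
    # weights, and their logits are averaged in the forward pass further down;
    # otherwise the list holds just the single unperturbed model.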
547 | 548 | # @mst: add noise to model 549 | model_ensemble = [] 550 | if noisy_model_ensemble: 551 | for i in range(args.model_noise_num): 552 | noisy_model = add_noise_to_model(model, std=args.model_noise_std) 553 | model_ensemble.append(noisy_model) 554 | logprint('==> added Gaussian noise to model weights (std=%s, num=%d)' % (args.model_noise_std, args.model_noise_num)) 555 | else: 556 | model_ensemble.append(model) 557 | 558 | time_compute = [] 559 | with torch.no_grad(): 560 | end = time.time() 561 | for i, (images, target) in enumerate(val_loader): 562 | if args.gpu is not None: 563 | images = images.cuda(args.gpu, non_blocking=True) 564 | target = target.cuda(args.gpu, non_blocking=True) 565 | 566 | # compute output 567 | t1 = time.time() 568 | output = 0 569 | for model in model_ensemble: # @mst: test model ensemble 570 | output += model(images) 571 | output /= len(model_ensemble) 572 | time_compute.append((time.time() - t1) / images.size(0)) 573 | loss = criterion(output, target) 574 | 575 | # measure accuracy and record loss 576 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 577 | losses.update(loss.item(), images.size(0)) 578 | top1.update(acc1[0], images.size(0)) 579 | top5.update(acc5[0], images.size(0)) 580 | 581 | # measure elapsed time 582 | batch_time.update(time.time() - end) 583 | end = time.time() 584 | 585 | if i % args.print_freq == 0: 586 | progress.display(i) 587 | 588 | # TODO: this should also be done with the ProgressMeter 589 | # logprint(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 590 | # .format(top1=top1, top5=top5)) 591 | # @mst: commented because we will use another print outside 'validate' 592 | # logprint("time compute: %.4f ms" % (np.mean(time_compute)*1000)) 593 | 594 | # change back to original model state if necessary 595 | if train_state: 596 | model.train() 597 | return top1.avg.item(), top5.avg.item(), losses.avg # @mst: added returning top5 acc and loss 598 | 599 | 600 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 601 | torch.save(state, filename) 602 | if is_best: 603 | shutil.copyfile(filename, 'model_best.pth.tar') 604 | 605 | # @mst: use our own save model function 606 | def save_model(state, is_best=False, mark=''): 607 | out = pjoin(logger.weights_path, "checkpoint.pth") 608 | torch.save(state, out) 609 | if is_best: 610 | out_best = pjoin(logger.weights_path, "checkpoint_best.pth") 611 | torch.save(state, out_best) 612 | if mark: 613 | out_mark = pjoin(logger.weights_path, "checkpoint_{}.pth".format(mark)) 614 | torch.save(state, out_mark) 615 | 616 | # @mst: zero out pruned weights for unstructured pruning 617 | def apply_mask_forward(model): 618 | global mask 619 | for name, m in model.named_modules(): 620 | if name in mask: 621 | m.weight.data.mul_(mask[name]) 622 | 623 | if __name__ == '__main__': 624 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | from torch.utils.data import Dataset 5 | import torch.nn.functional as F 6 | import torchvision 7 | from torch.autograd import Variable 8 | from pprint import pprint 9 | import time, math, os, sys, copy, numpy as np, shutil as sh 10 | import matplotlib.pyplot as plt 11 | from mpl_toolkits.axes_grid1 import make_axes_locatable 12 | from collections import OrderedDict 13 | import glob 14 | from PIL import Image 15 | import pickle 16 | import scipy.io 
as sio 17 | 18 | def _weights_init(m): 19 | if isinstance(m, (nn.Conv2d, nn.Linear)): 20 | init.kaiming_normal_(m.weight) 21 | if m.bias is not None: 22 | m.bias.data.fill_(0) 23 | elif isinstance(m, nn.BatchNorm2d): 24 | if m.weight is not None: 25 | m.weight.data.fill_(1.0) 26 | m.bias.data.zero_() 27 | 28 | 29 | # refer to: https://github.com/Eric-mingjie/rethinking-network-pruning/blob/master/imagenet/l1-norm-pruning/compute_flops.py 30 | def get_n_params(model): 31 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()]) 32 | total /= 1e6 33 | return total 34 | 35 | # The above 'get_n_params' requires 'param.requires_grad' to be true. In KD, for the teacher, this is not the case. 36 | def get_n_params_(model): 37 | n_params = 0 38 | for _, module in model.named_modules(): 39 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): # only consider Conv2d and Linear, no BN 40 | n_params += module.weight.numel() 41 | if hasattr(module, 'bias') and module.bias is not None: 42 | n_params += module.bias.numel() 43 | return n_params 44 | 45 | def get_n_flops(model=None, input_res=224, multiply_adds=True, n_channel=3): 46 | model = copy.deepcopy(model) 47 | 48 | prods = {} 49 | def save_hook(name): 50 | def hook_per(self, input, output): 51 | prods[name] = np.prod(input[0].shape) 52 | return hook_per 53 | 54 | list_1 = [] 55 | def simple_hook(self, input, output): 56 | list_1.append(np.prod(input[0].shape)) 57 | list_2 = {} 58 | def simple_hook2(self, input, output): 59 | list_2['names'] = np.prod(input[0].shape) 60 | 61 | 62 | list_conv = [] 63 | def conv_hook(self, input, output): 64 | batch_size, input_channels, input_height, input_width = input[0].size() 65 | output_channels, output_height, output_width = output[0].size() 66 | 67 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 68 | bias_ops = 0 # bias FLOPs are ignored in this count (negligible) 69 | 70 | # params = output_channels * (kernel_ops + bias_ops) # @mst: commented since not used 71 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 72 | 73 | num_weight_params = (self.weight.data != 0).float().sum() # @mst: this should be considering the pruned model 74 | # could be problematic if some weights happen to be 0. 
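        # note: counting only the *nonzero* weights (instead of the dense kernel
        # size) makes this FLOPs estimate sparsity-aware -- a conv whose mask
        # zeroes half of its weights reports roughly half of the dense
        # multiply-adds in the formula below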
75 | flops = (num_weight_params * (2 if multiply_adds else 1) + bias_ops * output_channels) * output_height * output_width * batch_size 76 | 77 | list_conv.append(flops) 78 | 79 | list_linear=[] 80 | def linear_hook(self, input, output): 81 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 82 | 83 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 84 | bias_ops = self.bias.nelement() 85 | 86 | flops = batch_size * (weight_ops + bias_ops) 87 | list_linear.append(flops) 88 | 89 | list_bn=[] 90 | def bn_hook(self, input, output): 91 | list_bn.append(input[0].nelement() * 2) 92 | 93 | list_relu=[] 94 | def relu_hook(self, input, output): 95 | list_relu.append(input[0].nelement()) 96 | 97 | list_pooling=[] 98 | def pooling_hook(self, input, output): 99 | batch_size, input_channels, input_height, input_width = input[0].size() 100 | output_channels, output_height, output_width = output[0].size() 101 | 102 | kernel_ops = self.kernel_size * self.kernel_size 103 | bias_ops = 0 104 | params = 0 105 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 106 | 107 | list_pooling.append(flops) 108 | 109 | list_upsample=[] 110 | 111 | # For bilinear upsample 112 | def upsample_hook(self, input, output): 113 | batch_size, input_channels, input_height, input_width = input[0].size() 114 | output_channels, output_height, output_width = output[0].size() 115 | 116 | flops = output_height * output_width * output_channels * batch_size * 12 117 | list_upsample.append(flops) 118 | 119 | def foo(net): 120 | childrens = list(net.children()) 121 | if not childrens: 122 | if isinstance(net, torch.nn.Conv2d): 123 | net.register_forward_hook(conv_hook) 124 | if isinstance(net, torch.nn.Linear): 125 | net.register_forward_hook(linear_hook) 126 | if isinstance(net, torch.nn.BatchNorm2d): 127 | net.register_forward_hook(bn_hook) 128 | if isinstance(net, torch.nn.ReLU): 129 | net.register_forward_hook(relu_hook) 130 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 131 | net.register_forward_hook(pooling_hook) 132 | if isinstance(net, torch.nn.Upsample): 133 | net.register_forward_hook(upsample_hook) 134 | return 135 | for c in childrens: 136 | foo(c) 137 | 138 | if model == None: 139 | model = torchvision.models.alexnet() 140 | foo(model) 141 | input = Variable(torch.rand(n_channel,input_res,input_res).unsqueeze(0), requires_grad = True) 142 | out = model(input) 143 | 144 | 145 | total_flops = (sum(list_conv) + sum(list_linear)) # + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 146 | total_flops /= 1e9 147 | # print(' Number of FLOPs: %.2fG' % total_flops) 148 | 149 | return total_flops 150 | 151 | # The above version is redundant. Get a neat version as follow. 152 | def get_n_flops_(model=None, img_size=(224,224), n_channel=3, count_adds=True, input=None, **kwargs): 153 | '''Only count the FLOPs of conv and linear layers (no BN layers etc.). 
154 | Only count the weight computation (bias not included since it is negligible) 155 | ''' 156 | if hasattr(img_size, '__len__'): 157 | height, width = img_size 158 | else: 159 | assert isinstance(img_size, int) 160 | height, width = img_size, img_size 161 | 162 | # model = copy.deepcopy(model) 163 | list_conv = [] 164 | def conv_hook(self, input, output): 165 | flops = np.prod(self.weight.data.shape) * output.size(2) * output.size(3) / self.groups 166 | list_conv.append(flops) 167 | 168 | list_linear = [] 169 | def linear_hook(self, input, output): 170 | flops = np.prod(self.weight.data.shape) 171 | list_linear.append(flops) 172 | 173 | def register_hooks(net, hooks): 174 | childrens = list(net.children()) 175 | if not childrens: 176 | if isinstance(net, torch.nn.Conv2d): 177 | h = net.register_forward_hook(conv_hook) 178 | hooks += [h] 179 | if isinstance(net, torch.nn.Linear): 180 | h = net.register_forward_hook(linear_hook) 181 | hooks += [h] 182 | return 183 | 184 | for c in childrens: 185 | register_hooks(c, hooks) 186 | 187 | hooks = [] 188 | register_hooks(model, hooks) 189 | if input is None: 190 | input = torch.rand(1, n_channel, height, width) 191 | use_cuda = next(model.parameters()).is_cuda 192 | if use_cuda: 193 | input = input.cuda() 194 | 195 | # forward 196 | is_train = model.training 197 | model.eval() 198 | with torch.no_grad(): 199 | model(input, **kwargs) 200 | total_flops = (sum(list_conv) + sum(list_linear)) 201 | if count_adds: 202 | total_flops *= 2 203 | 204 | # reset to original model 205 | for h in hooks: h.remove() # clear hooks 206 | if is_train: model.train() 207 | return total_flops 208 | 209 | # refer to: https://github.com/alecwangcq/EigenDamage-Pytorch/blob/master/utils/common_utils.py 210 | class PresetLRScheduler(object): 211 | """Using a manually designed learning rate schedule rules. 212 | """ 213 | def __init__(self, decay_schedule): 214 | if not isinstance(decay_schedule, dict): 215 | assert isinstance(decay_schedule, str) 216 | decay_schedule = strdict_to_dict(decay_schedule) 217 | 218 | # decay_schedule is a dictionary 219 | # which is for specifying iteration -> lr 220 | self.decay_schedule = {} 221 | for k, v in decay_schedule.items(): # a dict, example: {"0":0.001, "30":0.00001, "45":0.000001} 222 | self.decay_schedule[int(float(k))] = v # to float first in case of '1e3' 223 | # print('Using a preset learning rate schedule:') 224 | # print(self.decay_schedule) 225 | 226 | def __call__(self, optimizer, e): 227 | epochs = list(self.decay_schedule.keys()) 228 | epochs = sorted(epochs) # example: [0, 30, 45] 229 | lr = self.decay_schedule[epochs[-1]] 230 | for i in range(len(epochs) - 1): 231 | if epochs[i] <= e < epochs[i+1]: 232 | lr = self.decay_schedule[epochs[i]] 233 | break 234 | for param_group in optimizer.param_groups: 235 | param_group['lr'] = lr 236 | return lr 237 | 238 | def get_lr(optimizer): 239 | for param_group in optimizer.param_groups: 240 | lr = param_group['lr'] 241 | return lr 242 | 243 | def plot_weights_heatmap(weights, out_path): 244 | ''' 245 | weights: [N, C, H, W]. 
Torch tensor 246 | averaged in dim H, W so that we get a 2-dim color map of size [N, C] 247 | ''' 248 | w_abs = weights.abs() 249 | w_abs = w_abs.data.cpu().numpy() 250 | 251 | fig, ax = plt.subplots() 252 | im = ax.imshow(w_abs, cmap='jet') 253 | 254 | # make a beautiful colorbar 255 | divider = make_axes_locatable(ax) 256 | cax = divider.append_axes('right', size=0.05, pad=0.05) 257 | fig.colorbar(im, cax=cax, orientation='vertical') 258 | 259 | ax.set_xlabel("Channel") 260 | ax.set_ylabel("Filter") 261 | fig.savefig(out_path, dpi=200) 262 | plt.close(fig) 263 | 264 | def strlist_to_list(sstr, ttype=str): 265 | ''' 266 | example: 267 | # self.args.stage_pr = [0, 0.3, 0.3, 0.3, 0, ] 268 | # self.args.skip_layers = ['1.0', '2.0', '2.3', '3.0', '3.5', ] 269 | turn these into a list of (float or str or int etc.) 270 | ''' 271 | if not sstr: 272 | return sstr 273 | out = [] 274 | sstr = sstr.strip() 275 | if sstr.startswith('[') and sstr.endswith(']'): 276 | sstr = sstr[1:-1] 277 | for x in sstr.split(','): 278 | x = x.strip() 279 | if x: 280 | x = ttype(x) 281 | out.append(x) 282 | return out 283 | 284 | def strdict_to_dict(sstr, ttype): 285 | ''' 286 | '{"1": 0.04, "2": 0.04, "4": 0.03, "5": 0.02, "7": 0.03, }' 287 | ''' 288 | if not sstr: 289 | return sstr 290 | out = {} 291 | sstr = sstr.strip() 292 | if sstr.startswith('{') and sstr.endswith('}'): 293 | sstr = sstr[1:-1] 294 | sep = ';' if ';' in sstr else ',' 295 | for x in sstr.split(sep): 296 | x = x.strip() 297 | if x: 298 | k = x.split(':')[0] # note: key is always str 299 | if k.startswith("'"): k = k.strip("'") # remove ' ' 300 | if k.startswith('"'): k = k.strip('"') # remove " " 301 | v = ttype(x.split(':')[1].strip()) 302 | out[k] = v 303 | return out 304 | 305 | def check_path(x): 306 | if x: 307 | complete_path = glob.glob(x) 308 | assert(len(complete_path) == 1) 309 | x = complete_path[0] 310 | return x 311 | 312 | def parse_prune_ratio_vgg(sstr, num_layers=20): 313 | # example: [0-4:0.5, 5:0.6, 8-10:0.2] 314 | out = np.zeros(num_layers) 315 | if '[' in sstr: 316 | sstr = sstr.split("[")[1].split("]")[0] 317 | else: 318 | sstr = sstr.strip() 319 | for x in sstr.split(','): 320 | k = x.split(":")[0].strip() 321 | v = x.split(":")[1].strip() 322 | if k.isdigit(): 323 | out[int(k)] = float(v) 324 | else: 325 | begin = int(k.split('-')[0].strip()) 326 | end = int(k.split('-')[1].strip()) 327 | out[begin : end+1] = float(v) 328 | return list(out) 329 | 330 | 331 | def kronecker(A, B): 332 | return torch.einsum("ab,cd->acbd", A, B).view(A.size(0) * B.size(0), A.size(1) * B.size(1)) 333 | 334 | 335 | def np_to_torch(x): 336 | ''' 337 | np array to pytorch float tensor 338 | ''' 339 | x = np.array(x) 340 | x= torch.from_numpy(x).float() 341 | return x 342 | 343 | def kd_loss(student_scores, teacher_scores, temp=1, weights=None): 344 | '''Knowledge distillation loss: soft target 345 | ''' 346 | p = F.log_softmax(student_scores / temp, dim=1) 347 | q = F.softmax(teacher_scores / temp, dim=1) 348 | # l_kl = F.kl_div(p, q, size_average=False) / student_scores.shape[0] # previous working loss 349 | if isinstance(weights, type(None)): 350 | l_kl = F.kl_div(p, q, reduction='batchmean') # 2020-06-21 @mst: Since 'size_average' is deprecated, use 'reduction' instead. 
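    # the branch above is the plain soft-target KD loss (batch-mean KL divergence);
    # the else branch below keeps the per-sample KL (reduction='none'), re-weights
    # each example with the given 'weights', and sums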
351 | else: 352 | l_kl = (F.kl_div(p, q, reduction='none').sum(dim=1) * weights).sum() 353 | return l_kl 354 | 355 | def test(net, test_loader): 356 | n_example_test = 0 357 | total_correct = 0 358 | avg_loss = 0 359 | is_train = net.training 360 | net.eval() 361 | with torch.no_grad(): 362 | pred_total = [] 363 | label_total = [] 364 | for _, (images, labels) in enumerate(test_loader): 365 | n_example_test += images.size(0) 366 | images = images.cuda() 367 | labels = labels.cuda() 368 | output = net(images) 369 | avg_loss += nn.CrossEntropyLoss()(output, labels).sum() 370 | pred = output.data.max(1)[1] 371 | total_correct += pred.eq(labels.data.view_as(pred)).sum() 372 | pred_total.extend(list(pred.data.cpu().numpy())) 373 | label_total.extend(list(labels.data.cpu().numpy())) 374 | 375 | acc = float(total_correct) / n_example_test 376 | avg_loss /= n_example_test 377 | 378 | # get accuracy per class 379 | n_class = output.size(1) 380 | acc_test = [0] * n_class 381 | cnt_test = [0] * n_class 382 | for p, l in zip(pred_total, label_total): 383 | acc_test[l] += int(p == l) 384 | cnt_test[l] += 1 385 | acc_per_class = [] 386 | for c in range(n_class): 387 | acc_test[c] = 0 if cnt_test[c] == 0 else acc_test[c] / float(cnt_test[c]) 388 | acc_per_class.append(acc_test[c]) 389 | 390 | # return to the train state if necessary 391 | if is_train: 392 | net.train() 393 | return acc, avg_loss.item(), acc_per_class 394 | 395 | def get_project_path(ExpID): 396 | full_path = glob.glob("Experiments/*%s*" % ExpID) 397 | assert(len(full_path) == 1) # There should be only ONE folder with in its name. 398 | return full_path[0] 399 | 400 | def parse_ExpID(path): 401 | '''parse out the ExpID from 'path', which can be a file or directory. 402 | Example: Experiments/AE__ckpt_epoch_240.pth__LR1.5__originallabel__vgg13_SERVER138-20200829-202307/gen_img 403 | Example: Experiments/AE__ckpt_epoch_240.pth__LR1.5__originallabel__vgg13_SERVER-20200829-202307/gen_img 404 | ''' 405 | return 'SERVER' + path.split('_SERVER')[1].split('/')[0] 406 | 407 | def mkdirs(*paths): 408 | for p in paths: 409 | if not os.path.exists(p): 410 | os.makedirs(p) 411 | 412 | class EMA(): 413 | ''' 414 | Exponential Moving Average for pytorch tensor 415 | ''' 416 | def __init__(self, mu): 417 | self.mu = mu 418 | self.history = {} 419 | 420 | def __call__(self, name, x): 421 | ''' 422 | Note: this func will modify x directly, no return value. 423 | x is supposed to be a pytorch tensor. 
424 | ''' 425 | if self.mu > 0: 426 | assert(0 < self.mu < 1) 427 | if name in self.history.keys(): 428 | new_average = self.mu * self.history[name] + (1.0 - self.mu) * x.clone() 429 | else: 430 | new_average = x.clone() 431 | self.history[name] = new_average.clone() 432 | return new_average.clone() 433 | else: 434 | return x.clone() 435 | 436 | # Exponential Moving Average 437 | class EMA2(): 438 | def __init__(self, mu): 439 | self.mu = mu 440 | self.shadow = {} 441 | def register(self, name, value): 442 | self.shadow[name] = value.clone() 443 | def __call__(self, name, x): 444 | assert name in self.shadow 445 | new_average = (1.0 - self.mu) * x + self.mu * self.shadow[name] 446 | self.shadow[name] = new_average.clone() 447 | return new_average 448 | 449 | def register_ema(emas): 450 | for net, ema in emas: 451 | for name, param in net.named_parameters(): 452 | if param.requires_grad: 453 | ema.register(name, param.data) 454 | 455 | def apply_ema(emas): 456 | for net, ema in emas: 457 | for name, param in net.named_parameters(): 458 | if param.requires_grad: 459 | param.data = ema(name, param.data) 460 | 461 | colors = ["gray", "blue", "black", "yellow", "green", "yellowgreen", "gold", "royalblue", "peru", "purple"] 462 | def feat_visualize(ax, feat, label): 463 | ''' 464 | feat: N x 2 # 2-d feature, N: number of examples 465 | label: N x 1 466 | ''' 467 | for ix in range(len(label)): 468 | x = feat[ix] 469 | y = label[ix] 470 | ax.scatter(x[0], x[1], color=colors[y], marker=".") 471 | return ax 472 | 473 | def _remove_module_in_name(name): 474 | ''' remove 'module.' in the module name, caused by DataParallel, if any 475 | ''' 476 | module_name_parts = name.split(".") 477 | module_name_parts_new = [] 478 | for x in module_name_parts: 479 | if x != 'module': 480 | module_name_parts_new.append(x) 481 | new_name = '.'.join(module_name_parts_new) 482 | return new_name 483 | 484 | def smart_weights_load(net, w_path, key=None, load_mode='exact'): 485 | ''' 486 | Load the weights in the checkpoint 'w_path' into the network 'net'. 487 | ''' 488 | common_weights_keys = ['T', 'S', 'G', 'model', 'state_dict', 'state_dict_t'] 489 | 490 | ckpt = torch.load(w_path, map_location=lambda storage, location: storage) 491 | 492 | # get state_dict 493 | if isinstance(ckpt, OrderedDict): 494 | state_dict = ckpt 495 | else: 496 | if key: 497 | state_dict = ckpt[key] 498 | else: 499 | intersection = [k for k in ckpt.keys() if k in common_weights_keys and isinstance(ckpt[k], OrderedDict)] 500 | if len(intersection) == 1: 501 | k = intersection[0] 502 | state_dict = ckpt[k] 503 | else: 504 | print('Error: multiple or no model keys found in ckpt: %s. Please explicitly specify one' % intersection) 505 | exit(1) 506 | 507 | if load_mode == 'exact': # net and state_dict have exactly the same architecture (layer names etc. are exactly the same) 508 | try: 509 | net.load_state_dict(state_dict) 510 | except: 511 | ckpt_data_parallel = False 512 | for k, v in state_dict.items(): 513 | if k.startswith('module.'): 514 | ckpt_data_parallel = True # DataParallel was used in the ckpt 515 | break 516 | 517 | if ckpt_data_parallel: 518 | # If the ckpt used DataParallel, the load failure above should be because the net does not use 519 | # DataParallel. Therefore, remove the prefix 'module.' from the ckpt keys. 520 | new_state_dict = OrderedDict() 521 | for k, v in state_dict.items(): 522 | param_name = k.split("module.")[-1] 523 | new_state_dict[param_name] = v 524 | else: 525 | # Similarly, if the ckpt didn't use DataParallel, here we add the prefix 'module.'. 
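            # mirror case: the ckpt keys are plain (e.g. 'conv1.weight') while the
            # net is wrapped in DataParallel and registers its parameters under
            # 'module.conv1.weight', so 'module.' is prepended to each key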
# parse a wanted value from an accuracy print log
def parse_acc_log(line, key, type_func=float):
    line_seg = line.strip().lower().split()
    for i in range(len(line_seg)):
        if key in line_seg[i]:
            break
    if i == len(line_seg) - 1:
        return None # did not find the key in this line
    try:
        value = type_func(line_seg[i+1])
    except ValueError:
        value = type_func(line_seg[i+2])
    return value

def get_layer_by_index(net, index):
    cnt = -1
    for _, m in net.named_modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            cnt += 1
            if cnt == index:
                return m
    return None

def get_total_index_by_learnable_index(net, learnable_index):
    '''
    learnable_index: index when counting only learnable layers (conv or fc, no bn);
    total_index: index when also counting relu, pooling, etc.
    '''
    layer_type_considered = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.PReLU,
                             nn.BatchNorm2d, nn.MaxPool2d, nn.AvgPool2d, nn.Linear]
    cnt_total = -1
    cnt_learnable = -1
    for _, m in net.named_modules():
        cond = [isinstance(m, x) for x in layer_type_considered]
        if any(cond):
            cnt_total += 1
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                cnt_learnable += 1
                if cnt_learnable == learnable_index:
                    return cnt_total
    return None
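# A worked example of the two indexing schemes (the layer sequence is
# hypothetical). For a net laid out as Conv2d -> ReLU -> MaxPool2d -> Linear,
# the learnable indices are {Conv2d: 0, Linear: 1} while the total indices are
# {Conv2d: 0, ReLU: 1, MaxPool2d: 2, Linear: 3}, so:
#
#   get_total_index_by_learnable_index(net, 1)  # -> 3 (the Linear layer)
#   get_layer_by_index(net, 1)                  # -> the Linear module itself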
def cal_correlation(x, coef=False):
    '''Calculate the correlation matrix of a pytorch tensor.
    Input shape: [n_sample, n_attr]
    Output shape: [n_attr, n_attr]
    Refer to: https://github.com/pytorch/pytorch/issues/1254
    '''
    # calculate covariance matrix
    y = x - x.mean(dim=0)
    c = y.t().mm(y) / (y.size(0) - 1)

    if coef:
        # normalize the covariance matrix into correlation coefficients
        d = torch.diag(c)
        stddev = torch.pow(d, 0.5)
        c = c.div(stddev.expand_as(c))
        c = c.div(stddev.expand_as(c).t())

        # clamp between -1 and 1
        # probably not necessary but numpy does it
        c = torch.clamp(c, -1.0, 1.0)
    return c

def get_class_corr(loader, model):
    model.eval().cuda()
    logits = []
    n_batch = len(loader)
    with torch.no_grad():
        for ix, data in enumerate(loader):
            input = data[0]
            print('[%d/%d] -- forwarding' % (ix, n_batch))
            input = input.float().cuda()
            logits.append(model(input)) # [batch_size, n_class]
    logits = torch.cat(logits, dim=0)

    # Use numpy:
    # logits -= logits.mean(dim=0)
    # logits = logits.data.cpu().numpy()
    # corr = np.corrcoef(logits, rowvar=False)

    # Use pytorch
    corr = cal_correlation(logits, coef=True)
    return corr

def cal_acc(logits, y):
    pred = logits.argmax(dim=1)
    acc = pred.eq(y.data.view_as(pred)).sum().float() / y.size(0)
    return acc

class Timer():
    '''Record per-epoch time stamps and predict the finish time of the remaining epochs.
    '''
    def __init__(self, total_epoch):
        self.total_epoch = total_epoch
        self.time_stamp = []

    def predict_finish_time(self, ave_window=3):
        self.time_stamp.append(time.time()) # update time stamp
        if len(self.time_stamp) == 1:
            return 'only one time stamp, not enough to predict'
        interval = []
        for i in range(len(self.time_stamp) - 1):
            t = self.time_stamp[i + 1] - self.time_stamp[i]
            interval.append(t)
        sec_per_epoch = np.mean(interval[-ave_window:])
        left_t = sec_per_epoch * (self.total_epoch - len(interval))
        finish_t = left_t + time.time()
        finish_t = time.strftime('%Y/%m/%d-%H:%M', time.localtime(finish_t))
        total_t = '%.2fh' % ((np.sum(interval) + left_t) / 3600.)
        return finish_t + ' (speed: %.2fs per epoch, total time: %s)' % (sec_per_epoch, total_t)

    def __call__(self):
        return self.predict_finish_time()
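# A usage sketch for Timer (the epoch count and train_one_epoch are
# hypothetical). Calling the timer once per epoch records a time stamp and
# returns the predicted finish time, averaged over the last few epochs:
#
#   timer = Timer(total_epoch=120)
#   for epoch in range(120):
#       train_one_epoch()
#       print('predicted finish time: %s' % timer())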
class Dataset_npy_batch(Dataset):
    '''Dataset to load a batch of images saved together in a single .npy file.
    '''
    def __init__(self, npy_dir, transform, f='batch.npy'):
        self.data = np.load(os.path.join(npy_dir, f), allow_pickle=True)
        self.transform = transform
    def __getitem__(self, index):
        img = Image.fromarray(self.data[index][0])
        img = self.transform(img)
        label = self.data[index][1]
        label = torch.LongTensor([label])[0]
        return img.squeeze(0), label
    def __len__(self):
        return len(self.data)

class Dataset_lmdb_batch(Dataset):
    '''Dataset to load an lmdb data file.
    '''
    def __init__(self, lmdb_path, transform):
        import lmdb
        env = lmdb.open(lmdb_path, readonly=True)
        with env.begin() as txn:
            self.data = [value for key, value in txn.cursor()]
        self.transform = transform
    def __getitem__(self, index):
        img, label = pickle.loads(self.data[index]) # PIL image
        if self.transform:
            img = self.transform(img)
        return img, label
    def __len__(self):
        return len(self.data)

def merge_args(args, params_json):
    '''args is from argparse; params_json is a json/yaml file.
    Merge them; on collision, the param in params_json has the higher priority.
    '''
    import json, yaml
    with open(params_json) as f:
        if params_json.endswith('.json'):
            params = json.load(f)
        elif params_json.endswith('.yaml'):
            params = yaml.load(f, Loader=yaml.FullLoader)
        else:
            raise NotImplementedError
    for k, v in params.items():
        args.__dict__[k] = v
    return args
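# A usage sketch for merge_args (the yaml file name and its content are
# hypothetical). Values from the file overwrite the parsed CLI values:
#
#   # params.yaml contains:  lr: 0.01
#   #                        batch_size: 128
#   args = parser.parse_args()               # an argparse.Namespace
#   args = merge_args(args, 'params.yaml')
#   print(args.lr)                           # -> 0.01, taken from the yaml file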
import pandas as pd # module-level import: AccuracyManager.update below also needs pd

class AccuracyManager():
    def __init__(self):
        self.accuracy = pd.DataFrame()

    def update(self, time, acc1, acc5=None):
        acc = pd.DataFrame([[time, acc1, acc5]], columns=['time', 'acc1', 'acc5']) # time can be epoch or step
        self.accuracy = self.accuracy.append(acc, ignore_index=True)

    def get_best_acc(self, criterion='acc1'):
        assert criterion in ['acc1', 'acc5']
        acc = self.accuracy.sort_values(by=criterion) # ascending sort
        best = acc.iloc[-1] # the last row
        time, acc1, acc5 = best.time, best.acc1, best.acc5
        return time, acc1, acc5

    def get_last_acc(self):
        last = self.accuracy.iloc[-1]
        time, acc1, acc5 = last.time, last.acc1, last.acc5
        return time, acc1, acc5

def format_acc_log(acc1_set, lr, acc5=None, time_unit='Epoch'):
    '''Return a uniformly formatted line for accuracy logging.
    '''
    acc1, acc1_time, acc1_best, acc1_best_time = acc1_set
    if acc5:
        line = 'Acc1 %.4f Acc5 %.4f @ %s %d (Best_Acc1 %.4f @ %s %d) LR %s' % (acc1, acc5, time_unit, acc1_time, acc1_best, time_unit, acc1_best_time, lr)
    else:
        line = 'Acc1 %.4f @ %s %d (Best_Acc1 %.4f @ %s %d) LR %s' % (acc1, time_unit, acc1_time, acc1_best, time_unit, acc1_best_time, lr)
    return line

def get_lambda(alpha=1.0):
    '''Return the mixup coefficient lambda, drawn from Beta(alpha, alpha)'''
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.
    return lam

# refer to: 2018-ICLR-mixup
# https://github.com/facebookresearch/mixup-cifar10/blob/eaff31ab397a90fbc0a4aac71fb5311144b3608b/train.py#L119
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
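# A minimal mixup training-step sketch (model, optimizer and train_loader are
# hypothetical; the two helpers above do the actual mixing):
#
#   criterion = nn.CrossEntropyLoss()
#   for x, y in train_loader:
#       x, y = x.cuda(), y.cuda()
#       mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=1.0)
#       loss = mixup_criterion(criterion, model(mixed_x), y_a, y_b, lam)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()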
def visualize_filter(layer, layer_id, save_dir, n_filter_plot=16, n_channel_plot=16, pick_mode='rand', plot_abs=True, prefix='', ext='.pdf'):
    '''layer is a pytorch model layer
    '''
    w = layer.weight.data.cpu().numpy() # shape: [N, C, H, W]
    if plot_abs:
        w = np.abs(w)
    n, c = w.shape[0], w.shape[1]
    n_filter_plot = min(n_filter_plot, n)
    n_channel_plot = min(n_channel_plot, c)
    if pick_mode == 'rand':
        filter_ix = np.random.permutation(n)[:n_filter_plot] # filter indexes to plot
        channel_ix = np.random.permutation(c)[:n_channel_plot] # channel indexes to plot
    else:
        filter_ix = list(range(n_filter_plot))
        channel_ix = list(range(n_channel_plot))

    # iteration for plotting
    for i in filter_ix:
        f_avg = np.mean(w[i], axis=0)
        fig, ax = plt.subplots()
        im = ax.imshow(f_avg, cmap='jet')
        # make a beautiful colorbar
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size=0.05, pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')
        save_path = '%s/filter_visualize__%s__layer%s__filter%s__average_cross_channel' % (save_dir, prefix, layer_id, i) # prefix is usually a net name
        fig.savefig(save_path + ext, bbox_inches='tight')
        plt.close(fig)

        for j in channel_ix:
            f = w[i][j]
            fig, ax = plt.subplots()
            im = ax.imshow(f, cmap='jet')
            # make a beautiful colorbar
            divider = make_axes_locatable(ax)
            cax = divider.append_axes('right', size=0.05, pad=0.05)
            fig.colorbar(im, cax=cax, orientation='vertical')
            save_path = '%s/filter_visualize__%s__layer%s__filter%s__channel%s' % (save_dir, prefix, layer_id, i, j)
            fig.savefig(save_path + ext, bbox_inches='tight')
            plt.close(fig)


def visualize_feature_map(fm, layer_id, save_dir, n_channel_plot=16, pick_mode='rand', plot_abs=True, prefix='', ext='.pdf'):
    fm = fm.clone().detach()
    fm = fm.data.cpu().numpy()[0] # input shape: [N, C, H, W]; only the first example is taken, so the batch size N is expected to be 1
    if plot_abs:
        fm = np.abs(fm)
    c = fm.shape[0]
    n_channel_plot = min(n_channel_plot, c)
    if pick_mode == 'rand':
        channel_ix = np.random.permutation(c)[:n_channel_plot] # channel indexes to plot
    else:
        channel_ix = list(range(n_channel_plot))

    # plot the channel-averaged map
    fm_avg = np.mean(fm, axis=0)
    fig, ax = plt.subplots()
    im = ax.imshow(fm_avg, cmap='jet')
    # make a beautiful colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size=0.05, pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')
    save_path = '%s/featmap_visualization__%s__layer%s__average_cross_channel' % (save_dir, prefix, layer_id) # prefix is usually a net name
    fig.savefig(save_path + ext, bbox_inches='tight')
    plt.close(fig)

    for j in channel_ix:
        f = fm[j]
        fig, ax = plt.subplots()
        im = ax.imshow(f, cmap='jet')
        # make a beautiful colorbar
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size=0.05, pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')
        save_path = '%s/featmap_visualization__%s__layer%s__channel%s' % (save_dir, prefix, layer_id, j)
        fig.savefig(save_path + ext, bbox_inches='tight')
        plt.close(fig)


def add_noise_to_model(model, std=0.01):
    model = copy.deepcopy(model) # do not modify the original model
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)): # all learnable layers of a typical DNN
            w = module.weight
            w.data += torch.randn_like(w) * std
    return model
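# A usage sketch for add_noise_to_model (std and test_loader are hypothetical).
# Re-evaluating a weight-perturbed copy gives a rough probe of how sensitive
# the trained model is to parameter noise:
#
#   noisy_model = add_noise_to_model(model, std=0.01)
#   acc, loss, _ = test(noisy_model, test_loader)
#   print('accuracy under weight noise: %.4f' % acc)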
# Refer to: https://github.com/ast0414/adversarial-example/blob/26ee4144a1771d3a565285e0a631056a6f42d49c/craft.py#L6
def compute_jacobian(inputs, output):
    """
    :param inputs: Batch X Size (e.g. Depth X Width X Height)
    :param output: Batch X Classes
    :return: jacobian: Batch X Classes X Size
    """
    from torch.autograd.gradcheck import zero_gradients # available in the pinned torch 1.3.1; removed in recent pytorch versions
    assert inputs.requires_grad
    num_classes = output.size()[1]

    jacobian = torch.zeros(num_classes, *inputs.size())
    grad_output = torch.zeros(*output.size())
    if inputs.is_cuda:
        grad_output = grad_output.cuda()
        jacobian = jacobian.cuda()

    for i in range(num_classes):
        zero_gradients(inputs)
        grad_output.zero_()
        grad_output[:, i] = 1
        output.backward(grad_output, retain_graph=True)
        jacobian[i] = inputs.grad.data

    return torch.transpose(jacobian, dim0=0, dim1=1)

def get_jacobian_singular_values(model, data_loader, num_classes, n_loop=20, print_func=print, rand_data=False):
    jsv, condition_number = [], []
    if rand_data:
        picked_batch = np.random.permutation(len(data_loader))[:n_loop]
    else:
        picked_batch = list(range(n_loop))
    for i, (images, target) in enumerate(data_loader):
        if i in picked_batch:
            images, target = images.cuda(), target.cuda()
            batch_size = images.size(0)
            images.requires_grad = True # for Jacobian computation
            output = model(images)
            jacobian = compute_jacobian(images, output) # shape [batch_size, num_classes, num_channels, input_width, input_height]
            jacobian = jacobian.view(batch_size, num_classes, -1) # shape [batch_size, num_classes, num_channels*input_width*input_height]
            u, s, v = torch.svd(jacobian) # u: [batch_size, num_channels*input_width*input_height, num_classes], s: [batch_size, num_classes], v: [batch_size, num_channels*input_width*input_height, num_classes]
            s = s.data.cpu().numpy()
            jsv.append(s)
            condition_number.append(s.max(axis=1) / s.min(axis=1))
            print_func('[%3d/%3d] calculating Jacobian...' % (i, len(data_loader)))
    jsv = np.concatenate(jsv)
    condition_number = np.concatenate(condition_number)
    return jsv, condition_number

def approximate_entropy(X, num_bins=10, eps=1e-30):
    '''X shape: [num_sample, n_var], numpy array.
    '''
    entropy = []
    for di in range(X.shape[1]):
        samples = X[:, di]
        bins = np.linspace(samples.min(), samples.max(), num=num_bins+1)
        prob = np.histogram(samples, bins=bins, density=False)[0] / len(samples)
        entropy.append((-np.log2(prob + eps) * prob).sum()) # eps for numerical stability when prob = 0
    return np.mean(entropy)
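# A worked example for approximate_entropy (the input array is hypothetical).
# For a uniform sample split into 10 equal-width bins, each bin has probability
# ~0.1, so the entropy approaches -sum(0.1 * log2(0.1)) = log2(10) ≈ 3.32 bits:
#
#   X = np.random.uniform(size=(100000, 1))
#   print(approximate_entropy(X, num_bins=10))  # ≈ 3.32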
# matplotlib utility functions
def set_ax(ax):
    '''This modifies ax in place.
    '''
    # set background
    ax.grid(color='white')
    ax.set_facecolor('whitesmoke')

    # remove the axis lines
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)

    # remove the ticks but keep the tick labels
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')

def parse_value(line, key, type_func=float, exact_key=True):
    '''Parse the value following 'key' in a log line.
    '''
    try:
        if exact_key: # backward compatibility
            value = line.split(key)[1].strip().split()[0]
            if value.endswith(')'): # hand-fix case: "Epoch 23)"
                value = value[:-1]
            value = type_func(value)
        else:
            line_seg = line.split()
            for i in range(len(line_seg)):
                if key in line_seg[i]: # example: 'Acc1: 0.7'
                    break
            if i == len(line_seg) - 1:
                return None # did not find the key in this line
            value = type_func(line_seg[i + 1])
        return value
    except (IndexError, ValueError):
        print('Got error for line: "%s". Please check.' % line)

def to_tensor(x):
    x = np.array(x)
    x = torch.from_numpy(x).float()
    return x

def denormalize_image(x, mean, std):
    '''x shape: [N, C, H, W], a batch of images
    '''
    x = x.cuda()
    mean = to_tensor(mean).cuda()
    std = to_tensor(std).cuda()
    mean = mean.unsqueeze(0).unsqueeze(2).unsqueeze(3) # shape: [1, C, 1, 1]
    std = std.unsqueeze(0).unsqueeze(2).unsqueeze(3)
    x = std * x + mean
    return x

def make_one_hot(labels, C): # labels: [N]
    '''Turn a batch of labels into the one-hot form.
    '''
    labels = labels.unsqueeze(1) # [N, 1]
    one_hot = torch.zeros(labels.size(0), C).cuda()
    target = one_hot.scatter_(1, labels, 1)
    return target

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t() # shape [maxk, batch_size]
        correct = pred.eq(target.view(1, -1).expand_as(pred)) # target shape: [batch_size] -> [1, batch_size] -> [maxk, batch_size]
        res = []
        for k in topk:
            # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) # this stopped working in newer pytorch versions (pt1.3 is okay, pt1.9 is not)
            correct_k = correct[:k].flatten().float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
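# A usage sketch combining accuracy and AverageMeter (val_loader and model are
# hypothetical). AverageMeter keeps a running average weighted by batch size:
#
#   top1 = AverageMeter('Acc@1', ':6.2f')
#   for images, target in val_loader:
#       output = model(images.cuda())
#       acc1, = accuracy(output, target.cuda(), topk=(1,))
#       top1.update(acc1.item(), images.size(0))
#   print(top1)  # e.g. "Acc@1  93.12 ( 92.87)"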
class LossLine():
    '''Format loss items for easy printing.
    '''
    def __init__(self):
        self.log_dict = OrderedDict()
        self.formats = OrderedDict()
    def update(self, key, value, format):
        self.log_dict[key] = value
        self.formats[key] = format
    def format(self, sep=' '):
        out = []
        for k, v in self.log_dict.items():
            item = f"{k} {v:{self.formats[k]}}"
            out.append(item)
        return sep.join(out)

--------------------------------------------------------------------------------