├── utils ├── __init__.py ├── PyTransformer │ ├── transformers │ │ ├── __init__.py │ │ ├── quantize.py │ │ ├── utils.py │ │ └── torchTransformer.py │ ├── examples │ │ ├── vgg16.png │ │ ├── alexnet.png │ │ ├── resnet18.png │ │ ├── googlenet.png │ │ ├── inception_v3.png │ │ ├── mobilenet_v2.png │ │ ├── squeezenet1_0.png │ │ └── shufflenet_v2_x1_0.png │ ├── .gitignore │ ├── test.py │ ├── README.md │ ├── transform_example.ipynb │ └── visualize_example.ipynb ├── summaries.py ├── lazy_property.py ├── metrics.py ├── utils.py ├── lr_scheduler.py ├── loss.py ├── sparsity.py ├── saver.py ├── cluster.py └── misc.py ├── dataloader ├── datasets │ ├── __init__.py │ ├── cars.py │ ├── imagenet.py │ ├── aircraft.py │ ├── cub200.py │ └── cifar10.py ├── __init__.py └── custom_transforms.py ├── modeling ├── DGMS │ ├── __init__.py │ ├── DGMSConv.py │ └── GMM.py ├── networks │ ├── resnet.py │ ├── proxylessnas.py │ ├── __init__.py │ ├── mnasnet.py │ ├── vgg_small_cifar.py │ └── resnet_cifar.py └── __init__.py ├── requirements.txt ├── .gitignore ├── tools ├── train_cifar.sh ├── train_imgnet.sh └── validation.sh ├── mypath.py ├── config.py ├── README.md ├── LICENSE └── main.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataloader/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/PyTransformer/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modeling/DGMS/__init__.py: -------------------------------------------------------------------------------- 1 | from .DGMSConv import * 2 | from .GMM import * -------------------------------------------------------------------------------- /utils/PyTransformer/examples/vgg16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunpeiDong/DGMS/HEAD/utils/PyTransformer/examples/vgg16.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torchvision==0.7.0 3 | tensorboardX 4 | pillow 5 | graphviz 6 | pydot 7 | kmeans-pytorch -------------------------------------------------------------------------------- /utils/PyTransformer/examples/alexnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunpeiDong/DGMS/HEAD/utils/PyTransformer/examples/alexnet.png -------------------------------------------------------------------------------- /utils/PyTransformer/examples/resnet18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunpeiDong/DGMS/HEAD/utils/PyTransformer/examples/resnet18.png -------------------------------------------------------------------------------- /utils/PyTransformer/examples/googlenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunpeiDong/DGMS/HEAD/utils/PyTransformer/examples/googlenet.png -------------------------------------------------------------------------------- /utils/PyTransformer/.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | .ipynb_checkpoints 4 | *.py[cod] 5 | *$py.class 6 | -------------------------------------------------------------------------------- /utils/PyTransformer/examples/inception_v3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunpeiDong/DGMS/HEAD/utils/PyTransformer/examples/inception_v3.png -------------------------------------------------------------------------------- /utils/PyTransformer/examples/mobilenet_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunpeiDong/DGMS/HEAD/utils/PyTransformer/examples/mobilenet_v2.png -------------------------------------------------------------------------------- /utils/PyTransformer/examples/squeezenet1_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunpeiDong/DGMS/HEAD/utils/PyTransformer/examples/squeezenet1_0.png -------------------------------------------------------------------------------- /utils/PyTransformer/examples/shufflenet_v2_x1_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunpeiDong/DGMS/HEAD/utils/PyTransformer/examples/shufflenet_v2_x1_0.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | .ipynb_checkpoints 4 | *.py[cod] 5 | *$py.class 6 | 7 | # Checkpoint files 8 | *.pth.tar -------------------------------------------------------------------------------- /utils/summaries.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tensorboardX import SummaryWriter 3 | 4 | class TensorboardSummary(object): 5 | def __init__(self, directory): 6 | self.directory = directory 7 | 8 | def create_summary(self): 9 | writer = SummaryWriter(log_dir=os.path.join(self.directory)) 10 | return writer 11 | -------------------------------------------------------------------------------- /utils/lazy_property.py: -------------------------------------------------------------------------------- 1 | # https://stevenloria.com/lazy-properties/ 2 | 3 | def lazy_property(fn): 4 | '''Decorator that makes a property lazy-evaluated. 
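The wrapped function runs once on first access; its result is cached on the instance under '_lazy_' + fn.__name__ and returned directly on subsequent accesses.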
5 | ''' 6 | attr_name = '_lazy_' + fn.__name__ 7 | 8 | @property 9 | def _lazy_property(self): 10 | if not hasattr(self, attr_name): 11 | setattr(self, attr_name, fn(self)) 12 | return getattr(self, attr_name) 13 | return _lazy_property 14 | -------------------------------------------------------------------------------- /modeling/networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import torch.utils.model_zoo as model_zoo 6 | import torchvision.models as models 7 | 8 | def resnet(args, **kwargs): 9 | """Constructs a ResNet model.""" 10 | if args.pretrained: 11 | model = models.__dict__[args.network](pretrained=True) 12 | print("ImageNet pretrained model loaded!") 13 | else: 14 | model = models.__dict__[args.network]() 15 | return model 16 | -------------------------------------------------------------------------------- /modeling/networks/proxylessnas.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def proxyless_nas_mobile(args): 4 | target_platform = "proxyless_mobile" # proxyless_cpu, proxyless_gpu, and proxyless_mobile_14 are also available. 5 | if args.pretrained: 6 | model = torch.hub.load('mit-han-lab/ProxylessNAS', target_platform, pretrained=True) 7 | print("ImageNet pretrained ProxylessNAS-Mobile loaded! (Pretrained Top-1 Acc: 74.59%)") 8 | else: 9 | model = torch.hub.load('mit-han-lab/ProxylessNAS', target_platform, pretrained=False) 10 | return model 11 | -------------------------------------------------------------------------------- /tools/train_cifar.sh: -------------------------------------------------------------------------------- 1 | DATASET="--train-dir Path2DatasetCIFAR10/train/ --val-dir Path2DatasetCIFAR10/val/ -d cifar10 --num-classes 10" 2 | GENERAL="--lr 2e-5 --batch-size 128 --epochs 350 --workers 4 --base-size 32 --crop-size 32" 3 | INFO="--checkname vggsmall2bit --lr-scheduler one-cycle" 4 | MODEL="--network vggsmall --mask --K 4 --weight-decay 5e-4 --empirical True" 5 | PARAMS="--tau 0.01" 6 | NORMAL="--normal" 7 | RESUME="--resume Path2FP32Model --rt --show-info" 8 | DEVICES="0" 9 | GPU="--gpu-ids 0" 10 | CUDA_VISIBLE_DEVICES=$DEVICES python3 main.py $GPU $DATASET $GENERAL $MODEL $INFO $PARAMS $RESUME -------------------------------------------------------------------------------- /tools/train_imgnet.sh: -------------------------------------------------------------------------------- 1 | DATASET="--train-dir Path2DatasetImageNet/train/ --val-dir Path2DatasetImageNet/val/ -d imagenet --num-classes 1000" 2 | GENERAL="--lr 2e-5 --batch-size 256 --test-batch-size 256 --epochs 60 --workers 4 --base-size 256 --crop-size 224" 3 | INFO="--checkname resnet18_4bit --lr-scheduler one-cycle" 4 | MODEL="--network resnet18 --mask --K 16 --weight-decay 5e-4" 5 | PARAMS="--tau 0.01" 6 | NORMAL="--normal" 7 | PRETRAINED="--pretrained --rt --show-info" 8 | DEVICES="0,1" 9 | GPU="--gpu-ids 0,1" 10 | CUDA_VISIBLE_DEVICES=$DEVICES python3 main.py $GPU $DATASET $GENERAL $MODEL $INFO $PARAMS $PRETRAINED -------------------------------------------------------------------------------- /modeling/networks/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | from .resnet import resnet 6 | from .mnasnet import * 7 | from .proxylessnas import * 8 | from 
.resnet_cifar import * 9 | from .vgg_small_cifar import * 10 | 11 | def get_network(args): 12 | return { 13 | 'resnet18': resnet, 14 | 'resnet50': resnet, 15 | 'mnasnet': mnasnet1_0, 16 | 'proxylessnas': proxyless_nas_mobile, 17 | 'vggsmall': vggsmall, 18 | 'resnet20': resnet20, 19 | 'resnet32': resnet32, 20 | 'resnet56': resnet56, 21 | }[args.network](args) 22 | -------------------------------------------------------------------------------- /modeling/networks/mnasnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | def mnasnet0_5(args, **kwargs): 7 | if args.pretrained: 8 | model = models.__dict__['mnasnet0_5'](pretrained=True) 9 | print("ImageNet pretrained model loaded!") 10 | else: 11 | model = models.__dict__['mnasnet0_5']() 12 | return model 13 | 14 | def mnasnet1_0(args, **kwargs): 15 | if args.pretrained: 16 | model = models.mnasnet1_0(pretrained=True) 17 | print("ImageNet pretrained model loaded!") 18 | else: 19 | model = models.mnasnet1_0() 20 | return model 21 | -------------------------------------------------------------------------------- /tools/validation.sh: -------------------------------------------------------------------------------- 1 | DATASET="--train-dir Path2DatasetImageNet/train/ --val-dir Path2DatasetImageNet/val/ -d imagenet --num-classes 1000" 2 | GENERAL="--lr 2e-5 --batch-size 256 --test-batch-size 256 --epochs 60 --workers 4 --base-size 256 --crop-size 224" 3 | INFO="--checkname inference --lr-scheduler one-cycle" 4 | RESUME="--resume checkpoints/resnet18/resnet18_4bit_K16_7029.pth.tar --only-inference True" 5 | MODEL="--network resnet18 --mask --K 16 --weight-decay 5e-4" 6 | PARAMS="--tau 0.01" 7 | NORMAL="--normal" 8 | PRETRAINED="--pretrained --rt --show-info" 9 | DEVICES="0" 10 | GPU="--gpu-ids 0" 11 | CUDA_VISIBLE_DEVICES=$DEVICES python3 main.py $GPU $DATASET $GENERAL $MODEL $INFO $PARAMS $PRETRAINED $RESUME -------------------------------------------------------------------------------- /mypath.py: -------------------------------------------------------------------------------- 1 | import config as cfg 2 | class Path(object): 3 | @staticmethod 4 | def db_root_dir(dataset): 5 | if dataset == 'cifar10': 6 | return cfg.DATA_FOLDERS['cifar'] 7 | elif dataset == 'imagenet': 8 | return cfg.DATA_FOLDERS['imagenet'] 9 | elif dataset == 'cub200': 10 | return cfg.DATA_FOLDERS['cub200'] 11 | elif dataset == 'cars': 12 | return cfg.DATA_FOLDERS['cars'] 13 | elif dataset == 'aircraft': 14 | return cfg.DATA_FOLDERS['aircraft'] 15 | else: 16 | raise NotImplementedError("no support for dataset " + dataset) -------------------------------------------------------------------------------- /utils/PyTransformer/test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.models as models 6 | import copy 7 | from transformers.torchTransformer import TorchTransformer 8 | from transformers.quantize import QConv2d 9 | model = models.__dict__["resnet18"]() 10 | model.cuda() 11 | model = model.eval() 12 | 13 | transformer = TorchTransformer() 14 | transformer.register(nn.Conv2d, QConv2d) 15 | model = transformer.trans_layers(model) 16 | print(model) 17 | sys.exit()  # early exit after the layer-transform check; remove this line to run the summary below 18 | 19 | 20 | input_tensor = torch.randn([1, 3, 224, 224]) 21 | input_tensor = input_tensor.cuda() 
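# summary() traces the model with the dummy input and prints a Keras-style layer table; visualize() renders the graph with graphviz/pydot (see the PyTransformer README).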
22 | net = transformer.summary(model, input_tensor=input_tensor) 23 | # transformer.visualize(model, input_tensor=input_tensor, save_name="example", graph_size=80) -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | 4 | class Evaluator(object): 5 | def __init__(self, num_class, args): 6 | self.num_class = num_class 7 | self.args = args 8 | 9 | def Accuracy(self, output, target, topk=(1,)): 10 | """ Computes the top-k accuracy for the specified values of k. """ 11 | maxk = max(topk) 12 | batch_size = target.size(0) 13 | 14 | _, pred = output.topk(maxk, 1, True, True) 15 | pred = pred.t() 16 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 17 | 18 | res = [] 19 | for k in topk: 20 | correct_k = correct[:k].reshape(-1).float().sum(0) 21 | res.append(correct_k.mul_(100.0 / batch_size)) 22 | return res 23 | 24 | def reset(self): 25 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 26 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import struct as st 2 | 3 | 4 | def save_data(tensor, path, is_act=False, to_int=False, to_hex=False, output_dir=None, q=0.0): 5 | def identity(x): 6 | return x 7 | 8 | def convert_int(x): 9 | return int(x) 10 | 11 | def convert_hex(x): 12 | return '%X' % st.unpack('H', st.pack('e', x)) 13 | 14 | def convert_act(x): 15 | return round(x * (2 ** q)) 16 | 17 | print(f'Saving {path}') 18 | dir_name = output_dir 19 | 20 | type_cast = identity 21 | if to_int: 22 | type_cast = convert_int 23 | elif to_hex: 24 | type_cast = convert_hex 25 | elif is_act: 26 | type_cast = convert_act 27 | 28 | path = f'{dir_name}/{path}' 29 | with open(f'{path}.txt', 'w') as f: 30 | print('\n'.join( 31 | f'{type_cast(num.item())}' 32 | for num in tensor.half().view(-1) 33 | ), file=f) 34 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau 3 | 4 | def get_scheduler(args, optimizer, base_lr, steps_per_epoch=0): 5 | mode = args.lr_scheduler 6 | print('Using {} LR Scheduler!'.format(mode)) 7 | if mode == 'one-cycle': 8 | scheduler = OneCycleLR(optimizer, base_lr, 9 | steps_per_epoch=steps_per_epoch, 10 | epochs=args.epochs) 11 | elif mode == 'cosine': 12 | scheduler = CosineAnnealingLR( 13 | optimizer, T_max=args.epochs * steps_per_epoch) 14 | elif mode == 'multi-step': 15 | scheduler = MultiStepLR(optimizer, milestones=[ 16 | e * steps_per_epoch for e in args.schedule], gamma=0.1) 17 | else: 18 | assert mode == 'reduce' 19 | scheduler = ReduceLROnPlateau(optimizer) 20 | 21 | return scheduler 22 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | class RecognitionLosses(object): 7 | def __init__(self, size_average=True, batch_average=True, cuda=True, 8 | num_classes=10): 9 | self.size_average = 
size_average 10 | self.batch_average = batch_average 11 | self.cuda = cuda 12 | self.num_classes = num_classes 13 | 14 | def build_losses(self, mode='ce'): 15 | """Choices: 'ce' or 'focal' ('mse' is not implemented yet).""" 16 | if mode == 'ce': 17 | return self.CrossEntropyLoss 18 | elif mode == 'focal': 19 | return self.FocalLoss 20 | else: 21 | raise NotImplementedError 22 | 23 | def CrossEntropyLoss(self, logit, target): 24 | return F.cross_entropy(logit, target) 25 | 26 | def FocalLoss(self, logit, target, gamma=2.0): 27 | # standard focal loss (Lin et al., 2017); gamma=2.0 is an assumed default 28 | logpt = -F.cross_entropy(logit, target) 29 | pt = torch.exp(logpt) 30 | return -((1 - pt) ** gamma) * logpt 31 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | 5 | from .networks import get_network 6 | from modeling.DGMS import DGMSConv 7 | 8 | class DGMSNet(nn.Module): 9 | def __init__(self, args, freeze_bn=False): 10 | super(DGMSNet, self).__init__() 11 | self.args = args 12 | self.network = get_network(args) 13 | self.freeze_bn = freeze_bn 14 | 15 | def init_mask_params(self): 16 | print("--> Start to initialize sub-distribution parameters, this may take some time...") 17 | for name, m in self.network.named_modules(): 18 | if isinstance(m, DGMSConv): 19 | m.init_mask_params() 20 | print("--> Sub-distribution parameters initialization finished!") 21 | 22 | def forward(self, x): 23 | x = self.network(x) 24 | return x 25 | 26 | def get_1x_lr_params(self): 27 | self.init_mask_params() 28 | modules = [self.network] 29 | for i in range(len(modules)): 30 | for m in modules[i].named_modules(): 31 | if self.freeze_bn: 32 | if isinstance(m[1], nn.Conv2d): 33 | if self.args.freeze_weights: 34 | # frozen convolution weights are intentionally not yielded to the optimizer 35 | continue 36 | else: 37 | for p in m[1].parameters(): 38 | if p.requires_grad: 39 | yield p 40 | else: 41 | for p in m[1].parameters(): 42 | if p.requires_grad: 43 | yield p 44 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import cifar10, imagenet, cub200, cars, aircraft 2 | 3 | def make_data_loader(args, **kwargs): 4 | 5 | if args.dataset == 'cifar10': 6 | _cifar10 = cifar10.CIFAR10_Module(args, **kwargs) 7 | train_loader = _cifar10.train_dataloader() 8 | val_loader = _cifar10.val_dataloader() 9 | test_loader = None 10 | num_class = _cifar10.num_class 11 | 12 | return train_loader, val_loader, test_loader, num_class 13 | 14 | elif args.dataset == 'cub200': 15 | _cub200 = cub200.CUB200(args, **kwargs) 16 | train_loader = _cub200.train_dataloader() 17 | val_loader = _cub200.val_dataloader() 18 | test_loader = None 19 | num_class = _cub200.num_class 20 | 21 | return train_loader, val_loader, test_loader, num_class 22 | 23 | elif args.dataset == 'cars': 24 | _cars = cars.Cars(args, **kwargs) 25 | train_loader = _cars.train_dataloader() 26 | val_loader = _cars.val_dataloader() 27 | test_loader = None 28 | num_class = _cars.num_class 29 | 30 | return train_loader, val_loader, test_loader, num_class 31 | 32 | elif args.dataset == 'aircraft': 33 | _aircraft = aircraft.Aircraft(args, **kwargs) 34 | train_loader = _aircraft.train_dataloader() 35 | val_loader = _aircraft.val_dataloader() 36 | test_loader = None 37 | num_class = _aircraft.num_class 38 | 39 | return train_loader, val_loader, test_loader, num_class 40 | 41 | 42 | elif args.dataset == 'imagenet': 43 | _imagenet = imagenet.ImageNet(args, **kwargs) 44 | 
train_loader = _imagenet.train_dataloader() 45 | val_loader = _imagenet.val_dataloader() 46 | test_loader = None 47 | num_class = _imagenet.num_class 48 | 49 | return train_loader, val_loader, test_loader, num_class 50 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ Global configurations file. 2 | """ 3 | 4 | # Dataset settings 5 | NUM_CLASSES = { 6 | 'cifar10': 10, 7 | 'imagenet': 1000, 8 | 'cub200': 200, 9 | 'cars': 196, 10 | 'aircraft': 100, 11 | } 12 | 13 | DATA_FOLDERS = { 14 | 'cifar': 'Path2DatasetCIFAR10/', 15 | 'imagenet': 'Path2DatasetImageNet/', 16 | 'cub200': 'Path2DatasetCUB_200_2011/', 17 | 'cars': 'Path2DatasetStanfordCars/', 18 | 'aircraft': 'Path2DatasetFGVCAircraft/', 19 | } 20 | 21 | MEANS = { 22 | 'cifar': (0.4914, 0.4822, 0.4465), 23 | 'imagenet': (0.485, 0.456, 0.406), 24 | 'cub200': (0.485, 0.456, 0.406), 25 | 'cars': (0.485, 0.456, 0.406), 26 | 'aircraft': (0.485, 0.456, 0.406), 27 | } 28 | 29 | STDS = { 30 | 'cifar': (0.2023, 0.1994, 0.2010), 31 | 'imagenet': (0.229, 0.224, 0.225), 32 | 'cub200': (0.229, 0.224, 0.225), 33 | 'cars': (0.229, 0.224, 0.225), 34 | 'aircraft': (0.229, 0.224, 0.225), 35 | } 36 | 37 | # Model definition 38 | TAU = 0.01 39 | IS_TRAIN = True 40 | K_LEVEL = 16 41 | IS_NORMAL = True 42 | IS_EMP = False 43 | 44 | # Training settings 45 | BATCH_SIZE = { 46 | 'cifar10': 128, 47 | 'imagenet': 256, 48 | 'cub200': 256, 49 | 'cars': 256, 50 | 'aircraft': 256, 51 | } 52 | 53 | EPOCH = { 54 | 'cifar10': 350, 55 | 'imagenet': 60, 56 | 'cub200': 60, 57 | 'cars': 60, 58 | 'aircraft': 60, 59 | } 60 | 61 | LAYER = { 62 | 'resnet20': 20, 63 | 'resnet32': 32, 64 | 'resnet56': 56, 65 | 'vggsmall': 7, 66 | 'resnet18': 21, 67 | 'resnet50': 54, 68 | 'mnasnet': 53, 69 | 'proxylessnas': 62, 70 | } 71 | 72 | L_CNT = 0 73 | LAYER_NUM = 20 74 | EPS = 1e-11 75 | KEEP = True 76 | DEBUG = False 77 | SKIPPED_LAYERS = [] 78 | 79 | def set_status(flag): 80 | global IS_TRAIN 81 | IS_TRAIN = flag 82 | 83 | def count_layer(): 84 | global L_CNT 85 | L_CNT = L_CNT + 1 86 | 87 | 88 | def set_config(args): 89 | global IS_EMP, IS_NORMAL, K_LEVEL, TAU, LAYER, LAYER_NUM, SKIPPED_LAYERS 90 | IS_EMP = args.empirical 91 | IS_NORMAL = args.normal 92 | TAU = args.tau 93 | K_LEVEL = args.K 94 | LAYER_NUM = LAYER[args.network] 95 | SKIPPED_LAYERS = [1, LAYER_NUM] -------------------------------------------------------------------------------- /dataloader/datasets/cars.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import config as cfg 4 | 5 | from torch.utils.data import Dataset, DataLoader 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets import ImageFolder 8 | 9 | class Cars(Dataset): 10 | """`Stanford Cars `_ Dataset. 
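Loads data via torchvision's ImageFolder, so --train-dir and --val-dir are expected to contain one sub-directory per class.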
11 | """ 12 | def __init__(self, args, **kwargs): 13 | super(Cars, self).__init__() 14 | self.args = args 15 | self.num_class = cfg.NUM_CLASSES[args.dataset.lower()] 16 | 17 | @property 18 | def mean(self): 19 | return cfg.MEANS[self.args.dataset.lower()] 20 | 21 | @property 22 | def std(self): 23 | return cfg.STDS[self.args.dataset.lower()] 24 | 25 | def train_transform(self): 26 | return transforms.Compose([ 27 | transforms.Resize(self.args.base_size), 28 | transforms.RandomCrop(self.args.crop_size), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor(), 31 | transforms.Normalize(self.mean, self.std), 32 | ]) 33 | 34 | def val_transform(self): 35 | return transforms.Compose([ 36 | transforms.Resize(self.args.base_size), 37 | transforms.CenterCrop(self.args.crop_size), 38 | transforms.ToTensor(), 39 | transforms.Normalize(self.mean, self.std) 40 | ]) 41 | 42 | def train_dataloader(self): 43 | dataset = ImageFolder( 44 | root=self.args.train_dir, 45 | transform=self.train_transform(), 46 | ) 47 | dataloader = DataLoader(dataset, batch_size=self.args.batch_size, 48 | num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 49 | return dataloader 50 | 51 | def val_dataloader(self): 52 | dataset = ImageFolder( 53 | root=self.args.val_dir, 54 | transform=self.val_transform(), 55 | ) 56 | dataloader = DataLoader(dataset, batch_size=self.args.test_batch_size, 57 | num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 58 | return dataloader 59 | -------------------------------------------------------------------------------- /dataloader/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import dataloader 2 | import os 3 | from torch.utils.data import dataset 4 | 5 | import torchvision.transforms as transforms 6 | import config as cfg 7 | 8 | from torch.utils.data import Dataset, DataLoader 9 | from torchvision.datasets import ImageFolder 10 | 11 | class ImageNet(Dataset): 12 | """`ImageNet `_ Dataset. 
13 | """ 14 | def __init__(self, args, **kwargs): 15 | super(ImageNet, self).__init__() 16 | self.args = args 17 | self.num_class = cfg.NUM_CLASSES[args.dataset.lower()] 18 | 19 | @property 20 | def mean(self): 21 | return cfg.MEANS['imagenet'] 22 | 23 | @property 24 | def std(self): 25 | return cfg.STDS['imagenet'] 26 | 27 | def train_transform(self): 28 | return transforms.Compose([ 29 | transforms.Resize(self.args.base_size), 30 | transforms.RandomCrop(self.args.crop_size), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 33 | transforms.Normalize(self.mean, self.std), 34 | ]) 35 | 36 | def val_transform(self): 37 | return transforms.Compose([ 38 | transforms.Resize(self.args.base_size), 39 | transforms.CenterCrop(self.args.crop_size), 40 | transforms.ToTensor(), 41 | transforms.Normalize(self.mean, self.std) 42 | ]) 43 | 44 | def train_dataloader(self): 45 | dataset = ImageFolder( 46 | root=self.args.train_dir, 47 | transform=self.train_transform(), 48 | ) 49 | dataloader = DataLoader(dataset, batch_size=self.args.batch_size, 50 | num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 51 | return dataloader 52 | 53 | def val_dataloader(self): 54 | dataset = ImageFolder( 55 | root=self.args.val_dir, 56 | transform=self.val_transform(), 57 | ) 58 | dataloader = DataLoader(dataset, batch_size=self.args.test_batch_size, 59 | num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 60 | return dataloader 61 | -------------------------------------------------------------------------------- /dataloader/datasets/aircraft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import config as cfg 4 | 5 | from torch.utils.data import Dataset, DataLoader 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets import ImageFolder 8 | 9 | class Aircraft(Dataset): 10 | """`FGVC Aircraft `_ Dataset. 
11 | """ 12 | def __init__(self, args, **kwargs): 13 | super(Aircraft, self).__init__() 14 | self.args = args 15 | self.num_class = cfg.NUM_CLASSES[args.dataset.lower()] 16 | 17 | @property 18 | def mean(self): 19 | return cfg.MEANS[self.args.dataset.lower()] 20 | 21 | @property 22 | def std(self): 23 | return cfg.STDS[self.args.dataset.lower()] 24 | 25 | def train_transform(self): 26 | return transforms.Compose([ 27 | transforms.Resize(self.args.base_size), 28 | transforms.RandomCrop(self.args.crop_size), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor(), 31 | transforms.Normalize(self.mean, self.std), 32 | ]) 33 | 34 | def val_transform(self): 35 | return transforms.Compose([ 36 | transforms.Resize(self.args.base_size), 37 | transforms.CenterCrop(self.args.crop_size), 38 | transforms.ToTensor(), 39 | transforms.Normalize(self.mean, self.std) 40 | ]) 41 | 42 | def train_dataloader(self): 43 | dataset = ImageFolder( 44 | root=self.args.train_dir, 45 | transform=self.train_transform(), 46 | ) 47 | dataloader = DataLoader(dataset, batch_size=self.args.batch_size, 48 | num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 49 | return dataloader 50 | 51 | def val_dataloader(self): 52 | dataset = ImageFolder( 53 | root=self.args.val_dir, 54 | transform=self.val_transform(), 55 | ) 56 | dataloader = DataLoader(dataset, batch_size=self.args.test_batch_size, 57 | num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 58 | return dataloader 59 | -------------------------------------------------------------------------------- /dataloader/datasets/cub200.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import config as cfg 4 | 5 | from torch.utils.data import Dataset, DataLoader 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets import ImageFolder 8 | 9 | class CUB200(Dataset): 10 | """`Caltech-UCSD Birds 200 `_ Dataset. 
11 | """ 12 | def __init__(self, args, **kwargs): 13 | super(CUB200, self).__init__() 14 | self.args = args 15 | self.num_class = cfg.NUM_CLASSES[args.dataset.lower()] 16 | 17 | @property 18 | def mean(self): 19 | return cfg.MEANS[self.args.dataset.lower()] 20 | 21 | @property 22 | def std(self): 23 | return cfg.STDS[self.args.dataset.lower()] 24 | 25 | def train_transform(self): 26 | return transforms.Compose([ 27 | transforms.Resize(self.args.base_size), 28 | transforms.RandomCrop(self.args.crop_size), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor(), 31 | transforms.Normalize(self.mean, self.std), 32 | ]) 33 | 34 | def val_transform(self): 35 | return transforms.Compose([ 36 | transforms.Resize(self.args.base_size), 37 | transforms.CenterCrop(self.args.crop_size), 38 | transforms.ToTensor(), 39 | transforms.Normalize(self.mean, self.std) 40 | ]) 41 | 42 | def train_dataloader(self): 43 | dataset = ImageFolder( 44 | root=self.args.train_dir, 45 | transform=self.train_transform(), 46 | ) 47 | dataloader = DataLoader(dataset, batch_size=self.args.batch_size, 48 | num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 49 | return dataloader 50 | 51 | def val_dataloader(self): 52 | dataset = ImageFolder( 53 | root=self.args.val_dir, 54 | transform=self.val_transform(), 55 | ) 56 | dataloader = DataLoader(dataset, batch_size=self.args.test_batch_size, 57 | num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 58 | return dataloader 59 | -------------------------------------------------------------------------------- /modeling/networks/vgg_small_cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import math 4 | 5 | class VGG_SMALL(nn.Module): 6 | """ Pytorch implementation of VGGSmall artecture, modified from 7 | https://github.com/microsoft/LQ-Nets/blob/master/cifar10-vgg-small.py (Tensorflow Version). 
8 | """ 9 | def __init__(self, features, num_classes=10, init_weights=True): 10 | super(VGG_SMALL, self).__init__() 11 | self.features = features 12 | self.classifier = nn.Linear(in_features=512*4*4, out_features=num_classes, bias=True) 13 | if init_weights: 14 | self._initialize_weights() 15 | 16 | def _initialize_weights(self): 17 | for m in self.modules(): 18 | if isinstance(m, nn.Conv2d): 19 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 20 | if m.bias is not None: 21 | nn.init.zeros_(m.bias) 22 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 23 | nn.init.ones_(m.weight) 24 | nn.init.zeros_(m.bias) 25 | elif isinstance(m, nn.Linear): 26 | nn.init.normal_(m.weight, 0, 0.01) 27 | nn.init.zeros_(m.bias) 28 | 29 | def forward(self, input): 30 | x = self.features(input) 31 | x = x.view(x.size(0), -1) 32 | x = self.classifier(x) 33 | return x 34 | 35 | def make_layers(cfg, batch_norm=False): 36 | layers = [] 37 | in_channels = 3 38 | batch_norm = True 39 | for v in cfg: 40 | if v == 'M': 41 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 42 | elif v == 'A': 43 | layers += [nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True)] 44 | else: 45 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 46 | if batch_norm: 47 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 48 | else: 49 | layers += [conv2d] 50 | batch_norm = not batch_norm 51 | in_channels = v 52 | return nn.Sequential(*layers) 53 | 54 | def vggsmall(args): 55 | config = [128, 128, 'M', 'A', 256, 256, 'M', 'A', 512, 512, 'M', 'A'] 56 | return VGG_SMALL(make_layers(config)) 57 | -------------------------------------------------------------------------------- /utils/sparsity.py: -------------------------------------------------------------------------------- 1 | from modeling.DGMS import DGMSConv 2 | import torch 3 | import torch.nn as nn 4 | 5 | def _count_zero(x): 6 | return x.eq(0.0).float().mean().item() 7 | 8 | 9 | def _check_filter(x): 10 | return _count_zero(x.abs().sum(dim=(1, 2, 3))) 11 | 12 | 13 | def _check_channel(x): 14 | return _count_zero(x.abs().sum(dim=(0, 2, 3))) 15 | 16 | 17 | def _check_kernel(x): 18 | return _count_zero(x.abs().sum(dim=(2, 3))) 19 | 20 | 21 | def _check_irregular(x): 22 | return _count_zero(x) 23 | 24 | def check_total_zero(x): 25 | with torch.no_grad(): 26 | return x.eq(0.0).float().sum().item() 27 | 28 | def check_total_weights(x): 29 | with torch.no_grad(): 30 | return x.numel() 31 | 32 | 33 | _CHECKS = { 34 | 'filter': _check_filter, 35 | 'kernel': _check_kernel, 36 | 'channel': _check_channel, 37 | 'irregular': _check_irregular, 38 | } 39 | 40 | 41 | def check(x, method): 42 | with torch.no_grad(): 43 | return _CHECKS[method](x) 44 | 45 | 46 | class SparsityMeasure(object): 47 | def __init__(self, args): 48 | super(SparsityMeasure, self).__init__() 49 | self.args = args 50 | 51 | def check_sparsity_per_layer(self, model): 52 | total_sparsity_num = 0 53 | total_weight_num = 0 54 | skipped_weight_num = 0 55 | for name, m in model.named_modules(): 56 | if isinstance(m, DGMSConv): 57 | Pweight = m.get_Pweight() 58 | sparse_ratio = check(Pweight, 'irregular') 59 | total_sparsity_num += check_total_zero(Pweight) 60 | total_weight_num += check_total_weights(Pweight) 61 | print(f'{name}\t{m.weight.size()}:\t{sparse_ratio:.3f}') 62 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 63 | skipped_weight_num += check_total_weights(m.weight) 64 | print(f'{name}\t{m.weight.size()}:\tfp32') 65 | total_sparse_ratio = total_sparsity_num / total_weight_num 66 | 
nz_parameters_num = total_weight_num-total_sparsity_num 67 | print(f"Total sparsity is {total_sparsity_num} / {total_weight_num}:\t {total_sparse_ratio:.4f}") 68 | nz_ratio = 1 - total_sparse_ratio 69 | print(f"NZ ratio is :\t {nz_ratio:.4f}") 70 | model_params = (skipped_weight_num+nz_parameters_num) / 1e6 71 | print(f"Skipped weights number: {skipped_weight_num}") 72 | print(f"NZ parameters size: {model_params:.2f}M") 73 | return total_sparse_ratio, model_params 74 | -------------------------------------------------------------------------------- /modeling/DGMS/DGMSConv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Runpei Dong, ArChip Lab. 2 | 3 | """ DGMS convolution implementation. 4 | 5 | Author: Runpei Dong 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import config as cfg 12 | from .GMM import * 13 | 14 | class DGMSConv(nn.Conv2d): 15 | """ DGMS Convolution: 16 | Convolution operator based on Differentiable Gaussian Mixture Weight Sharing (DGMS) for model compression. 17 | """ 18 | def __init__( 19 | self, 20 | in_channels: int, 21 | out_channels: int, 22 | kernel_size: int, 23 | stride=1, 24 | padding=0, 25 | dilation=1, 26 | groups=1, 27 | bias=False, 28 | padding_mode: str = 'zeros', 29 | ): 30 | super(DGMSConv, self).__init__( 31 | in_channels, out_channels, kernel_size, stride, padding, dilation, 32 | groups, bias, padding_mode) 33 | self.is_normal = cfg.IS_NORMAL 34 | 35 | self.k_level = cfg.K_LEVEL 36 | self.temperature = cfg.TAU 37 | 38 | def init_mask_params(self): 39 | init_method = 'empirical' if cfg.IS_EMP else 'k-means' 40 | self.sub_distribution = gmm_approximation(self.k_level, self.weight, self.temperature, init_method) 41 | 42 | def get_Sweight(self): 43 | # soft quantized weights during training 44 | with torch.no_grad(): 45 | return self.sub_distribution(weights=self.weight, train=True) 46 | 47 | def get_Pweight(self): 48 | # hard quantized weights during inference 49 | with torch.no_grad(): 50 | return self.sub_distribution(weights=self.weight, train=False) 51 | 52 | def forward(self, input): 53 | if cfg.IS_NORMAL: 54 | # pretraning using normal convolution operator 55 | output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 56 | else: 57 | # DGMS convolution operator 58 | if cfg.IS_TRAIN: 59 | # training using DGMS differentiable indicator 60 | Sweight = self.sub_distribution(weights=self.weight, train=True) 61 | output = F.conv2d(input, Sweight, self.bias, self.stride, self.padding, self.dilation, self.groups) 62 | else: 63 | # inference using hard mask 64 | Pweight = self.sub_distribution(weights=self.weight, train=False) 65 | output = F.conv2d(input, Pweight, self.bias, self.stride, self.padding, self.dilation, self.groups) 66 | return output 67 | 68 | 69 | # unit test script 70 | if __name__ == '__main__': 71 | m = DGMSConv(16, 33, 3, stride=2) 72 | input = torch.randn(20, 16, 50, 100) 73 | output = m(input) 74 | print(output.size()) 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Differentiable Gaussian Mixture Weight Sharing Network Quantization, ICML 2022 Spotlight 2 | > [**Finding the Task-Optimal Low-Bit Sub-Distribution in Deep Neural Networks**](https://proceedings.mlr.press/v162/dong22a.html), ICML 2022
3 | > [Runpei Dong](https://runpeidong.com/)\*, [Zhanhong Tan](https://www.zhanhongtan.com/)\*, [Mengdi Wu](https://scholar.google.com.hk/citations?user=F9EN5zgAAAAJ&hl=en&oi=sra), [Linfeng Zhang](https://scholar.google.com.hk/citations?user=AK9VF30AAAAJ&hl=en), and [Kaisheng Ma](http://group.iiis.tsinghua.edu.cn/~maks/leader.html)<br>
4 | 5 | Created by [Runpei Dong](https://runpeidong.com/)\*, [Zhanhong Tan](https://www.zhanhongtan.com/)\*, [Mengdi Wu](https://scholar.google.com.hk/citations?user=F9EN5zgAAAAJ&hl=en&oi=sra), [Linfeng Zhang](http://group.iiis.tsinghua.edu.cn/~maks/linfeng/index.html), and [Kaisheng Ma](http://group.iiis.tsinghua.edu.cn/~maks/leader.html). 6 | 7 | [PMLR](https://proceedings.mlr.press/v162/dong22a.html) | [arXiv](https://arxiv.org/abs/2112.15139) | [Models](https://drive.google.com/drive/folders/1rQJLAbP8gb5ZIUyIjEVHof0euyhsVGu4?usp=sharing) 8 | 9 | This repository contains the code release of the paper **Finding the Task-Optimal Low-Bit Sub-Distribution in Deep Neural Networks** (ICML 2022). 10 | 11 | 12 | ## Installation 13 | 14 | Our code works with Python 3.8.3. We recommend working inside a virtual environment (e.g., `venv` as below, or [Anaconda](https://www.anaconda.com/)); you can install the dependencies by running: 15 | 16 | ```shell 17 | $ python3 -m venv env 18 | $ source env/bin/activate 19 | (env) $ python3 -m pip install -r requirements.txt 20 | ``` 21 | ## How to Run 22 | 23 | The main procedures are implemented in `main.py`; run the following command for usage instructions: 24 | 25 | ```shell 26 | $ python main.py -h 27 | ``` 28 | 29 | ### Datasets 30 | 31 | Before running the code, specify the dataset paths in `config.py`, or pass them via `--train-dir` and `--val-dir`. 32 | 33 | ### Training on ImageNet 34 | 35 | We provide a simple shell script that trains a 4-bit `ResNet-18` with `DGMS`. Run: 36 | 37 | ```shell 38 | $ sh tools/train_imgnet.sh 39 | ``` 40 | 41 | ### Inference on ImageNet 42 | 43 | To run inference with compressed models on ImageNet, you only need to follow two steps: 44 | 45 | * **Step-1**: Download the checkpoints released on [Google Drive](https://drive.google.com/drive/folders/1rQJLAbP8gb5ZIUyIjEVHof0euyhsVGu4?usp=sharing). 46 | 47 | * **Step-2**: Run the inference shell script we provide: 48 | 49 | ```shell 50 | $ sh tools/validation.sh 51 | ``` 52 | 53 | ## Q-SIMD 54 | 55 | The [TVM](https://github.com/apache/tvm)-based Q-SIMD code can be downloaded from [Google Drive](https://drive.google.com/file/d/1hGeXXdHetGKZKSd4dp7xTSRxjWXgPkjc/view?usp=sharing). 56 | 57 | ## Citation 58 | 59 | If you find our work useful in your research, please consider citing: 60 | 61 | ```tex 62 | @inproceedings{dong2021finding, 63 | title={Finding the Task-Optimal Low-Bit Sub-Distribution in Deep Neural Networks}, 64 | author={Dong, Runpei and Tan, Zhanhong and Wu, Mengdi and Zhang, Linfeng and Ma, Kaisheng}, 65 | booktitle={Proceedings of the International Conference on Machine Learning (ICML)}, 66 | year={2022} 67 | } 68 | ``` 69 | 70 | ## License 71 | 72 | DGMS is released under the Apache 2.0 license. See the [LICENSE](./LICENSE) file for more details. 
73 | -------------------------------------------------------------------------------- /utils/saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | from collections import OrderedDict 5 | import glob 6 | 7 | class Saver(object): 8 | 9 | def __init__(self, args): 10 | self.args = args 11 | if args.normal: 12 | checkname = args.checkname + '_uncompressed' 13 | else: 14 | checkname = args.checkname 15 | self.directory = os.path.join('run', args.dataset, checkname) 16 | if not os.path.exists(self.directory): 17 | os.makedirs(self.directory) 18 | self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*'))) 19 | run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0 20 | 21 | self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id))) 22 | if not os.path.exists(self.experiment_dir): 23 | os.makedirs(self.experiment_dir) 24 | 25 | def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'): 26 | """Saves checkpoint to disk""" 27 | filename = os.path.join(self.experiment_dir, filename) 28 | torch.save(state, filename) 29 | if is_best: 30 | epoch = state['epoch'] 31 | best_top1 = state['best_top1'] 32 | best_top5 = state['best_top5'] 33 | params = state['params'] 34 | bitwidth = state['bits'] 35 | compression_rate = state['CR'] 36 | with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f: 37 | f.write('Epoch: ' + str(epoch) + "\n") 38 | f.write('Top1: ' + str(best_top1) + "\n") 39 | f.write('Top5: ' + str(best_top5) + "\n") 40 | f.write('#Params: ' + str(params) + "M" + "\n") 41 | f.write('Bits: ' + str(bitwidth) + "\n") 42 | f.write('CR: ' + str(compression_rate) + "\n") 43 | if self.runs: 44 | previous_acc = [0.0] 45 | for run in self.runs: 46 | run_id = run.split('_')[-1] 47 | path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt') 48 | if os.path.exists(path): 49 | with open(path, 'r') as f: 50 | lines = f.readlines() 51 | if len(lines) >= 2: 52 | # the second line of best_pred.txt stores the Top-1 accuracy 53 | acc = float(lines[1].split(' ')[-1]) 54 | previous_acc.append(acc) 55 | else: 56 | continue 57 | max_acc = max(previous_acc) 58 | if best_top1 > max_acc: 59 | shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar')) 60 | else: 61 | shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar')) 62 | 63 | def save_experiment_config(self): 64 | logfile = os.path.join(self.experiment_dir, 'training_configs.txt') 65 | log_file = open(logfile, 'w') 66 | p = OrderedDict() 67 | p['dataset'] = self.args.dataset 68 | p['network'] = self.args.network 69 | p['lr'] = self.args.lr 70 | p['lr_scheduler'] = self.args.lr_scheduler 71 | p['epoch'] = self.args.epochs 72 | p['tau'] = self.args.tau 73 | p['K'] = self.args.K 74 | 75 | for key, val in p.items(): 76 | log_file.write(key + ':' + str(val) + '\n') 77 | log_file.close() 78 | -------------------------------------------------------------------------------- /utils/PyTransformer/README.md: -------------------------------------------------------------------------------- 1 | # PyTransformer 2 | Modified from https://github.com/ricky40403/PyTransformer. 
3 | 4 | ## summary 5 | This repository implements a summary function similar to Keras' `summary()`: 6 | 7 | ``` 8 | model = nn.Sequential( 9 | nn.Conv2d(3,20,5), 10 | nn.ReLU(), 11 | nn.Conv2d(20,64,5), 12 | nn.ReLU() 13 | ) 14 | 15 | model.eval() 16 | 17 | transformer = TorchTransformer() 18 | input_tensor = torch.randn([1, 3, 224, 224]) 19 | net = transformer.summary(model, input_tensor) 20 | 21 | ########################################################################################## 22 | Index| Layer (type) | Bottoms Output Shape Param # 23 | --------------------------------------------------------------------------- 24 | 1| Data | [(1, 3, 224, 224)] 0 25 | --------------------------------------------------------------------------- 26 | 2| Conv2d_1 | Data [(1, 20, 220, 220)] 1500 27 | --------------------------------------------------------------------------- 28 | 3| ReLU_2 | Conv2d_1 [(1, 20, 220, 220)] 0 29 | --------------------------------------------------------------------------- 30 | 4| Conv2d_3 | ReLU_2 [(1, 64, 216, 216)] 32000 31 | --------------------------------------------------------------------------- 32 | 5| ReLU_4 | Conv2d_3 [(1, 64, 216, 216)] 0 33 | --------------------------------------------------------------------------- 34 | ================================================================================== 35 | Total Trainable params: 33500 36 | Total Non-Trainable params: 0 37 | Total params: 33500 38 | ``` 39 | 40 | Another example is in [summary_example.ipynb](summary_example.ipynb). 41 | 42 | ## visualize 43 | Visualization uses [graphviz](https://graphviz.readthedocs.io/en/stable/) and [pydot](https://pypi.org/project/pydot/) 44 | to draw the network architecture, 45 | e.g. alexnet from torchvision: 46 | ``` 47 | model = models.__dict__["alexnet"]() 48 | model.eval() 49 | transformer = TorchTransformer() 50 | transformer.visualize(model, save_name= "example", graph_size = 80) 51 | # graph_size changes the size of the output graph. 52 | # graphviz does not auto-fit the model's layers, so a very deep model 53 | # may be rendered too small to read. 54 | # Increase graph_size to enlarge the image for higher resolution. 55 | ``` 56 | 57 | 58 | An example is in [visualize_example.ipynb](visualize_example.ipynb); 59 | other example images are in [examples](/examples). 60 | 61 | ## transform layers 62 | You can register a layer type to be transformed: 63 | register the source and target types with the transformer, and it will replace every occurrence of the registered layer type, as in the sketch below. 
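A minimal sketch of the registration flow, mirroring `test.py` (`QConv2d` is the quantized convolution shipped in `transformers/quantize.py`):

```
import torch.nn as nn
import torchvision.models as models

from transformers.torchTransformer import TorchTransformer
from transformers.quantize import QConv2d

model = models.resnet18().eval()

transformer = TorchTransformer()
# every nn.Conv2d in the model is replaced by a QConv2d
transformer.register(nn.Conv2d, QConv2d)
model = transformer.trans_layers(model)
```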
64 | 65 | A full example is in [transform_example](transform_example.ipynb). 66 | 67 | 68 | 69 | 70 | ## Note 71 | Avoid passing in models with very many layers, because graphviz may generate the image slowly (e.g., densenet161 in torchvision 0.4.0 may get stuck when generating the png). 72 | 73 | ## TODO 74 | - [x] support registration (replacement) for custom layer types 75 | - [ ] support replacement of a specified layer instance in a model 76 | - [x] activation size calculation for supported layers 77 | - [x] network summary output as in keras 78 | - [x] model graph visualization 79 | - [ ] replace multiple modules with one module 80 | - [ ] conditional module replacement 81 | - [ ] add additional modules to the forward graph -------------------------------------------------------------------------------- /modeling/networks/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | 6 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet56'] 7 | 8 | def _weights_init(m): 9 | classname = m.__class__.__name__ 10 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 11 | init.kaiming_normal_(m.weight) 12 | 13 | class LambdaLayer(nn.Module): 14 | def __init__(self, lambd): 15 | super(LambdaLayer, self).__init__() 16 | self.lambd = lambd 17 | 18 | def forward(self, x): 19 | return self.lambd(x) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, in_planes, planes, stride=1, option='A'): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_planes != planes: 34 | if option == 'A': 35 | """ 36 | For CIFAR-10, the ResNet paper uses option A. 
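Option A pads the strided identity shortcut with zeros to match the channel count (parameter-free), while option B uses a 1x1 convolution projection followed by batch normalization.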
37 | """ 38 | self.shortcut = LambdaLayer(lambda x: 39 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 40 | elif option == 'B': 41 | self.shortcut = nn.Sequential( 42 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 43 | nn.BatchNorm2d(self.expansion * planes) 44 | ) 45 | 46 | def forward(self, x): 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = self.bn2(self.conv2(out)) 49 | out += self.shortcut(x) 50 | out = F.relu(out) 51 | return out 52 | 53 | 54 | class ResNet(nn.Module): 55 | def __init__(self, block, num_blocks, num_classes=10): 56 | super(ResNet, self).__init__() 57 | self.in_planes = 16 58 | 59 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(16) 61 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 62 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 63 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 64 | self.linear = nn.Linear(64, num_classes) 65 | 66 | self.apply(_weights_init) 67 | 68 | def _make_layer(self, block, planes, num_blocks, stride): 69 | strides = [stride] + [1]*(num_blocks-1) 70 | layers = [] 71 | for stride in strides: 72 | layers.append(block(self.in_planes, planes, stride)) 73 | self.in_planes = planes * block.expansion 74 | 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = F.relu(self.bn1(self.conv1(x))) 79 | out = self.layer1(out) 80 | out = self.layer2(out) 81 | out = self.layer3(out) 82 | out = F.avg_pool2d(out, out.size()[3]) 83 | out = out.view(out.size(0), -1) 84 | out = self.linear(out) 85 | return out 86 | 87 | 88 | def resnet20(args, **kwargs): 89 | model = ResNet(BasicBlock, [3, 3, 3]) 90 | return model 91 | 92 | 93 | def resnet32(args, **kwargs): 94 | model = ResNet(BasicBlock, [5, 5, 5]) 95 | return model 96 | 97 | 98 | def resnet56(args, **kwargs): 99 | model = ResNet(BasicBlock, [9, 9, 9]) 100 | return model 101 | 102 | 103 | def test(net): 104 | import numpy as np 105 | total_params = 0 106 | 107 | for x in filter(lambda p: p.requires_grad, net.parameters()): 108 | total_params += np.prod(x.data.numpy().shape) 109 | print("Total number of params", total_params) 110 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 111 | -------------------------------------------------------------------------------- /modeling/DGMS/GMM.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Runpei Dong, ArChip Lab. 2 | 3 | """ DGMS GM Sub-distribution implementation. 4 | 5 | Author: Runpei Dong 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import math 12 | import config as cfg 13 | 14 | from utils.misc import cluster_weights 15 | 16 | class GaussianMixtureModel(nn.Module): 17 | """Concrete GMM for sub-distribution approximation. 18 | """ 19 | def __init__(self, num_components, init_weights, temperature=0.01, init_method="k-means"): 20 | super(GaussianMixtureModel, self).__init__() 21 | self.num_components = num_components 22 | self.temperature = temperature 23 | if torch.cuda.is_available(): 24 | self.device = torch.device('cuda') 25 | else: 26 | self.device = torch.device('cpu') 27 | self.params_initialization(init_weights, init_method) 28 | 29 | def params_initialization(self, init_weights, method='k-means'): 30 | """ Initialization of GMM parameters using k-means algorithm. 
""" 31 | self.mu_zero = torch.tensor([0.0], device=self.device).float() 32 | self.pi_k, self.mu, self.sigma = \ 33 | torch.ones(self.num_components-1, device=self.device), \ 34 | torch.ones(self.num_components-1, device=self.device), \ 35 | torch.ones(self.num_components-1, device=self.device) 36 | if method == 'k-means': 37 | initial_region_saliency, pi_init, pi_zero_init, sigma_init, _sigma_zero = cluster_weights(init_weights, self.num_components) 38 | elif method == 'empirical': 39 | initial_region_saliency, pi_init, pi_zero_init, sigma_init, _sigma_zero = cluster_weights(init_weights, self.num_components) 40 | sigma_init, _sigma_zero = torch.ones_like(sigma_init).mul(0.01).cuda(), torch.ones_like(torch.tensor([_sigma_zero])).mul(0.01).cuda() 41 | self.mu = nn.Parameter(data=torch.mul(self.mu.cuda(), initial_region_saliency.flatten().cuda())) 42 | self.pi_k = nn.Parameter(data=torch.mul(self.pi_k.cuda(), pi_init)).cuda().float() 43 | self.pi_zero = nn.Parameter(data=torch.tensor([pi_zero_init], device=self.device)).cuda().float() 44 | self.sigma_zero = nn.Parameter(data=torch.tensor([_sigma_zero], device=self.device)).float() 45 | self.sigma = nn.Parameter(data=torch.mul(self.sigma, sigma_init)).cuda().float() 46 | self.temperature = nn.Parameter(data=torch.tensor([self.temperature], device=self.device)) 47 | 48 | def gaussian_mixing_regularization(self): 49 | pi_tmp = torch.cat([self.pi_zero, self.pi_k], dim=-1).abs() 50 | return torch.div(pi_tmp, pi_tmp.sum(dim=-1).unsqueeze(-1)).cuda() 51 | 52 | def Normal_pdf(self, x, _pi, mu, sigma): 53 | """ Standard Normal Distribution PDF. """ 54 | return torch.mul(torch.reciprocal(torch.sqrt(torch.mul( \ 55 | torch.tensor([2 * math.pi], device=self.device), sigma**2))), \ 56 | torch.exp(-torch.div((x - mu)**2, 2 * sigma**2))).mul(_pi) 57 | 58 | def GMM_region_responsibility(self, weights): 59 | """" Region responsibility of GMM. """ 60 | pi_normalized = self.gaussian_mixing_regularization().cuda() 61 | responsibility = torch.zeros([self.num_components, weights.size(0)], device=self.device) 62 | responsibility[0] = self.Normal_pdf(weights.cuda(), pi_normalized[0], 0.0, self.sigma_zero.cuda()) 63 | for k in range(self.num_components-1): 64 | responsibility[k+1] = self.Normal_pdf(weights, pi_normalized[k+1], self.mu[k].cuda(), self.sigma[k].cuda()) 65 | responsibility = torch.div(responsibility, responsibility.sum(dim=0) + cfg.EPS) 66 | return F.softmax(responsibility / self.temperature, dim=0) 67 | 68 | def forward(self, weights, train=True): 69 | if train: 70 | # soft mask generalized pruning during training 71 | self.region_belonging = self.GMM_region_responsibility(weights.flatten()) 72 | Sweight = torch.mul(self.region_belonging[0], 0.) \ 73 | + torch.mul(self.region_belonging[1:], self.mu.unsqueeze(1)).sum(dim=0) 74 | return Sweight.view(weights.size()) 75 | else: 76 | self.region_belonging = self.GMM_region_responsibility(weights.flatten()) 77 | max_index = torch.argmax(self.region_belonging, dim=0).unsqueeze(0) 78 | mask_w = torch.zeros_like(self.region_belonging).scatter_(dim=0, index=max_index, value=1.) 
79 | Pweight = torch.mul(mask_w[1:], self.mu.unsqueeze(1)).sum(dim=0) 80 | return Pweight.view(weights.size()) 81 | 82 | def gmm_approximation(num_components, init_weights, temperature=0.5, init_method='k-means'): 83 | return GaussianMixtureModel(num_components, init_weights.flatten(), temperature, init_method) 84 | -------------------------------------------------------------------------------- /dataloader/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | import config as cfg 3 | 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import transforms 6 | from torchvision.datasets import CIFAR10, CIFAR100 7 | 8 | class CIFAR10_Module(Dataset): 9 | """`CIFAR10 `_ Dataset. 10 | """ 11 | base_folder = 'cifar-10-batches-py' 12 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 13 | filename = "cifar-10-python.tar.gz" 14 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 15 | train_list = [ 16 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 17 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 18 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 19 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 20 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 21 | ] 22 | 23 | test_list = [ 24 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 25 | ] 26 | 27 | NUM_CLASSES = cfg.NUM_CLASSES['cifar10'] 28 | 29 | def __init__(self, args, **kwargs): 30 | super(CIFAR10_Module, self).__init__() 31 | self.args = args 32 | 33 | @property 34 | def mean(self): 35 | return cfg.MEANS['cifar'] 36 | 37 | @property 38 | def std(self): 39 | return cfg.STDS['cifar'] 40 | 41 | @property 42 | def num_class(self): 43 | return self.NUM_CLASSES 44 | 45 | def train_dataloader(self): 46 | transform_train = transforms.Compose([transforms.RandomHorizontalFlip(), 47 | transforms.RandomCrop(32, padding=4), 48 | transforms.ToTensor(), 49 | transforms.Normalize(self.mean, self.std)]) 50 | dataset = CIFAR10(root=self.args.train_dir, 51 | train=True, download=True, transform=transform_train) 52 | dataloader = DataLoader(dataset, batch_size=self.args.batch_size, 53 | num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 54 | return dataloader 55 | 56 | def val_dataloader(self): 57 | transform_val = transforms.Compose([transforms.ToTensor(), 58 | transforms.Normalize(self.mean, self.std)]) 59 | dataset = CIFAR10(root=self.args.val_dir, 60 | train=False, transform=transform_val) 61 | dataloader = DataLoader(dataset, batch_size=self.args.batch_size, 62 | num_workers=4, pin_memory=True) 63 | return dataloader 64 | 65 | class CIFAR100_Module(CIFAR10_Module): 66 | """`CIFAR100 `_ Dataset. 67 | 68 | This is a subclass of the `CIFAR10` Dataset. 
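It overrides only the dataset metadata (folder name, download URL, and MD5 checksums) and the dataloaders; the normalization statistics are inherited from the CIFAR-10 module.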
69 | """ 70 | base_folder = 'cifar-100-python' 71 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 72 | filename = "cifar-100-python.tar.gz" 73 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 74 | train_list = [ 75 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 76 | ] 77 | 78 | test_list = [ 79 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 80 | ] 81 | meta = { 82 | 'filename': 'meta', 83 | 'key': 'fine_label_names', 84 | 'md5': '7973b15100ade9c7d40fb424638fde48', 85 | } 86 | 87 | def __init__(self, args, **kwargs): 88 | super(CIFAR100_Module, self).__init__(args, **kwargs) 89 | 90 | def train_dataloader(self): 91 | transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), 92 | transforms.RandomHorizontalFlip(), 93 | transforms.ToTensor(), 94 | transforms.Normalize(self.mean, self.std)]) 95 | dataset = CIFAR100(root=self.args.train_dir, 96 | train=True, download=True, transform=transform_train) 97 | dataloader = DataLoader(dataset, batch_size=self.args.test_batch_size, 98 | num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 99 | return dataloader 100 | 101 | def val_dataloader(self): 102 | transform_val = transforms.Compose([transforms.ToTensor(), 103 | transforms.Normalize(self.mean, self.std)]) 104 | dataset = CIFAR100(root=self.args.val_dir, 105 | train=False, transform=transform_val) 106 | dataloader = DataLoader( 107 | dataset, batch_size=self.args.batch_size, num_workers=4, pin_memory=True) 108 | return dataloader 109 | -------------------------------------------------------------------------------- /utils/PyTransformer/transformers/quantize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Source: 3 | https://github.com/eladhoffer/quantized.pytorch 4 | """ 5 | 6 | import torch 7 | from torch.autograd.function import InplaceFunction, Function 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import math 11 | 12 | 13 | class UniformQuantize(InplaceFunction): 14 | @staticmethod 15 | def forward(ctx, input, num_bits=8, min_value=None, max_value=None, inplace=False, symmetric=False, num_chunks=None): 16 | num_chunks = num_chunks = input.shape[0] if num_chunks is None else num_chunks 17 | if min_value is None or max_value is None: 18 | B = input.shape[0] 19 | y = input.view(B // num_chunks, -1) 20 | 21 | if min_value is None: 22 | min_value = y.min(-1)[0].mean(-1) # C 23 | #min_value = float(input.view(input.size(0), -1).min(-1)[0].mean()) 24 | 25 | if max_value is None: 26 | #max_value = float(input.view(input.size(0), -1).max(-1)[0].mean()) 27 | max_value = y.max(-1)[0].mean(-1) # C 28 | 29 | ctx.inplace = inplace 30 | ctx.num_bits = num_bits 31 | ctx.min_value = min_value 32 | ctx.max_value = max_value 33 | 34 | if ctx.inplace: 35 | ctx.mark_dirty(input) 36 | output = input 37 | 38 | else: 39 | output = input.clone() 40 | 41 | if symmetric: 42 | qmin = -2. ** (num_bits - 1) 43 | qmax = 2 ** (num_bits - 1) - 1 44 | max_value = torch.max(torch.abs(max_value), torch.abs(min_value)) 45 | min_value = 0. 46 | 47 | else: 48 | qmin = 0. 49 | qmax = 2. ** num_bits - 1. 
50 | 51 | scale = (max_value - min_value) / (qmax - qmin) 52 | scale = max(scale, 1e-8) 53 | 54 | output.add_(-min_value).div_(scale) 55 | 56 | output.clamp_(qmin, qmax).round_() # quantize 57 | 58 | output.mul_(scale).add_(min_value) # dequantize 59 | 60 | return output 61 | 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | # straight-through estimator 66 | grad_input = grad_output 67 | return grad_input, None, None, None, None, None, None 68 | 69 | 70 | def quantize(x, num_bits=8, min_value=None, max_value=None, inplace=False, symmetric=False, num_chunks=None): 71 | return UniformQuantize().apply(x, num_bits, min_value, max_value, inplace, symmetric, num_chunks) 72 | 73 | 74 | class QuantMeasure(nn.Module): 75 | """docstring for QuantMeasure.""" 76 | 77 | def __init__(self, num_bits=8, momentum=0.1): 78 | super(QuantMeasure, self).__init__() 79 | self.register_buffer('running_min', torch.zeros(1)) 80 | self.register_buffer('running_max', torch.zeros(1)) 81 | self.momentum = momentum 82 | self.num_bits = num_bits 83 | 84 | 85 | def forward(self, input): 86 | if self.training: 87 | min_value = input.detach().view(input.size(0), -1).min(-1)[0].mean() 88 | max_value = input.detach().view(input.size(0), -1).max(-1)[0].mean() 89 | self.running_min.mul_(1 - self.momentum).add_(min_value * (self.momentum)) 90 | self.running_max.mul_(1 - self.momentum).add_(max_value * (self.momentum)) 91 | 92 | else: 93 | min_value = self.running_min 94 | max_value = self.running_max 95 | 96 | return quantize(input, self.num_bits, min_value=float(min_value), max_value=float(max_value), num_chunks=16) 97 | 98 | 99 | class QConv2d(nn.Conv2d): 100 | """docstring for QConv2d.""" 101 | 102 | def __init__(self, in_channels, out_channels, kernel_size, 103 | stride=1, padding=0, dilation=1, groups=1, bias=True, num_bits=8, num_bits_weight=None): 104 | super(QConv2d, self).__init__(in_channels, out_channels, kernel_size, 105 | stride, padding, dilation, groups, bias) 106 | self.num_bits = num_bits 107 | self.num_bits_weight = num_bits_weight or num_bits 108 | 109 | 110 | def forward(self, input): 111 | qweight = quantize(self.weight, num_bits=self.num_bits_weight, 112 | min_value=float(self.weight.min()), 113 | max_value=float(self.weight.max())) 114 | if self.bias is not None: 115 | qbias = quantize(self.bias, num_bits=self.num_bits_weight) 116 | else: 117 | qbias = None 118 | 119 | output = F.conv2d(input, qweight, qbias, self.stride, 120 | self.padding, self.dilation, self.groups) 121 | 122 | return output 123 | 124 | 125 | class QLinear(nn.Linear): 126 | """docstring for QConv2d.""" 127 | 128 | def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=None, num_bits_grad=None, biprecision=False): 129 | super(QLinear, self).__init__(in_features, out_features, bias) 130 | self.num_bits = num_bits 131 | self.num_bits_weight = num_bits_weight or num_bits 132 | 133 | 134 | def forward(self, input): 135 | qweight = quantize(self.weight, num_bits=self.num_bits_weight, 136 | min_value=float(self.weight.min()), 137 | max_value=float(self.weight.max())) 138 | if self.bias is not None: 139 | qbias = quantize(self.bias, num_bits=self.num_bits_weight) 140 | else: 141 | qbias = None 142 | 143 | output = F.linear(input, qweight, qbias) 144 | 145 | return output 146 | -------------------------------------------------------------------------------- /utils/cluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of k-means 
algorithm on GPU devices, modified from https://github.com/subhadarship/kmeans_pytorch. 3 | """ 4 | import numpy as np 5 | import torch 6 | from tqdm import tqdm 7 | 8 | 9 | def initialize(X, num_clusters): 10 | """ 11 | initialize cluster centers 12 | :param X: (torch.tensor) matrix 13 | :param num_clusters: (int) number of clusters 14 | :return: (np.array) initial state 15 | """ 16 | num_samples = len(X) 17 | indices = np.random.choice(num_samples, num_clusters, replace=False) 18 | initial_state = X[indices] 19 | return initial_state 20 | 21 | 22 | def kmeans( 23 | X, 24 | num_clusters, 25 | distance='euclidean', 26 | cluster_centers = [], 27 | tol=1e-5, 28 | tqdm_flag=True, 29 | iter_limit=0, 30 | device=torch.device('cuda:0') 31 | ): 32 | """ 33 | perform kmeans 34 | :param X: (torch.tensor) matrix 35 | :param num_clusters: (int) number of clusters 36 | :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] 37 | :param tol: (float) threshold [default: 0.00001] 38 | :param device: (torch.device) device [default: cpu] 39 | :param tqdm_flag: Allows to turn logs on and off 40 | :param iter_limit: hard limit for max number of iterations 41 | :return: (torch.tensor, torch.tensor) cluster ids, cluster centers 42 | """ 43 | 44 | if distance == 'euclidean': 45 | pairwise_distance_function = pairwise_distance 46 | elif distance == 'cosine': 47 | pairwise_distance_function = pairwise_cosine 48 | else: 49 | raise NotImplementedError 50 | 51 | # convert to float 52 | X = X.float() 53 | 54 | # transfer to device 55 | X = X.to(device) 56 | 57 | # initialize 58 | if type(cluster_centers) == list: #ToDo: make this less annoyingly weird 59 | initial_state = initialize(X, num_clusters) 60 | else: 61 | print('resuming') 62 | # find data point closest to the initial cluster center 63 | initial_state = cluster_centers 64 | dis = pairwise_distance_function(X, initial_state) 65 | choice_points = torch.argmin(dis, dim=0) 66 | initial_state = X[choice_points] 67 | initial_state = initial_state.to(device) 68 | 69 | iteration = 0 70 | if tqdm_flag: 71 | tqdm_meter = tqdm(desc='[running kmeans]') 72 | while True: 73 | 74 | dis = pairwise_distance_function(X, initial_state) 75 | 76 | choice_cluster = torch.argmin(dis, dim=1) 77 | 78 | initial_state_pre = initial_state.clone() 79 | 80 | for index in range(num_clusters): 81 | selected = torch.nonzero(choice_cluster == index).squeeze().to(device) 82 | 83 | selected = torch.index_select(X, 0, selected) 84 | 85 | initial_state[index] = selected.mean(dim=0) 86 | 87 | center_shift = torch.sum( 88 | torch.sqrt( 89 | torch.sum((initial_state - initial_state_pre) ** 2, dim=1) 90 | )) 91 | 92 | # increment iteration 93 | iteration = iteration + 1 94 | 95 | # update tqdm meter 96 | if tqdm_flag: 97 | tqdm_meter.set_postfix( 98 | iteration=f'{iteration}', 99 | center_shift=f'{center_shift ** 2:0.6f}', 100 | tol=f'{tol:0.6f}' 101 | ) 102 | tqdm_meter.update() 103 | if center_shift ** 2 < tol: 104 | break 105 | if iter_limit != 0 and iteration >= iter_limit: 106 | break 107 | 108 | return choice_cluster.cuda(), initial_state.cuda() 109 | 110 | 111 | def kmeans_predict( 112 | X, 113 | cluster_centers, 114 | distance='euclidean', 115 | device=torch.device('cpu') 116 | ): 117 | """ 118 | predict using cluster centers 119 | :param X: (torch.tensor) matrix 120 | :param cluster_centers: (torch.tensor) cluster centers 121 | :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] 122 | :param device: (torch.device) device 
[default: 'cpu'] 123 | :return: (torch.tensor) cluster ids 124 | """ 125 | print(f'predicting on {device}..') 126 | 127 | if distance == 'euclidean': 128 | pairwise_distance_function = pairwise_distance 129 | elif distance == 'cosine': 130 | pairwise_distance_function = pairwise_cosine 131 | else: 132 | raise NotImplementedError 133 | 134 | # convert to float 135 | X = X.float() 136 | 137 | # transfer to device 138 | X = X.to(device) 139 | 140 | dis = pairwise_distance_function(X, cluster_centers) 141 | choice_cluster = torch.argmin(dis, dim=1) 142 | 143 | return choice_cluster.cpu() 144 | 145 | 146 | def pairwise_distance(data1, data2, device=torch.device('cpu')): 147 | # transfer to device 148 | data1, data2 = data1.to(device), data2.to(device) 149 | 150 | # N*1*M 151 | A = data1.unsqueeze(dim=1) 152 | 153 | # 1*N*M 154 | B = data2.unsqueeze(dim=0) 155 | 156 | dis = (A - B) ** 2.0 157 | # return N*N matrix for pairwise distance 158 | dis = dis.sum(dim=-1).squeeze() 159 | return dis 160 | 161 | 162 | def pairwise_cosine(data1, data2, device=torch.device('cpu')): 163 | # transfer to device 164 | data1, data2 = data1.to(device), data2.to(device) 165 | 166 | # N*1*M 167 | A = data1.unsqueeze(dim=1) 168 | 169 | # 1*N*M 170 | B = data2.unsqueeze(dim=0) 171 | 172 | # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5] 173 | A_normalized = A / A.norm(dim=-1, keepdim=True) 174 | B_normalized = B / B.norm(dim=-1, keepdim=True) 175 | 176 | cosine = A_normalized * B_normalized 177 | 178 | # return N*N matrix for pairwise distance 179 | cosine_dis = 1 - cosine.sum(dim=-1).squeeze() 180 | return cosine_dis 181 | -------------------------------------------------------------------------------- /dataloader/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image, ImageOps, ImageFilter 6 | 7 | class Normalize(object): 8 | """Normalize a tensor image with mean and standard deviation. 9 | Args: 10 | mean (tuple): means for each channel. 11 | std (tuple): standard deviations for each channel. 
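The sample is a dict holding an 'image' and a 'label'; __call__ scales the
image from [0, 255] to [0, 1] before applying the mean/std normalization.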
12 | """ 13 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 14 | self.mean = mean 15 | self.std = std 16 | 17 | def __call__(self, sample): 18 | img = sample['image'] 19 | mask = sample['label'] 20 | img = np.array(img).astype(np.float32) 21 | if mask is not None: 22 | mask = np.array(mask).astype(np.float32) 23 | else: 24 | mask = 0 25 | img /= 255.0 26 | img -= self.mean 27 | img /= self.std 28 | 29 | return {'image': img, 30 | 'label': mask} 31 | 32 | class ToTensor(object): 33 | """Convert ndarrays in sample to Tensors.""" 34 | 35 | def __init__(self, ignore_label=255, num_class=21, flag=False): 36 | self.ignore_label = ignore_label 37 | self.num_class = num_class 38 | self.flag = flag 39 | 40 | def __call__(self, sample): 41 | # swap color axis because 42 | # numpy image: H x W x C 43 | # torch image: C X H X W 44 | img = sample['image'] 45 | mask = sample['label'] 46 | img = np.array(img).astype(np.float32).transpose((2, 0, 1)) 47 | img = torch.from_numpy(img).float() 48 | if mask is not None: 49 | if self.flag: 50 | mask[mask >= self.num_class] = self.ignore_label 51 | mask[mask < 0] = self.ignore_label 52 | mask = np.array(mask).astype(np.float32) 53 | mask = torch.from_numpy(mask).float() 54 | else: 55 | mask = 0 56 | 57 | 58 | return {'image': img, 59 | 'label': mask} 60 | 61 | 62 | class RandomHorizontalFlip(object): 63 | def __call__(self, sample): 64 | img = sample['image'] 65 | mask = sample['label'] 66 | if random.random() < 0.5: 67 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 68 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 69 | 70 | return {'image': img, 71 | 'label': mask} 72 | 73 | 74 | class RandomRotate(object): 75 | def __init__(self, degree): 76 | self.degree = degree 77 | 78 | def __call__(self, sample): 79 | img = sample['image'] 80 | mask = sample['label'] 81 | rotate_degree = random.uniform(-1*self.degree, self.degree) 82 | img = img.rotate(rotate_degree, Image.BILINEAR) 83 | mask = mask.rotate(rotate_degree, Image.NEAREST) 84 | 85 | return {'image': img, 86 | 'label': mask} 87 | 88 | 89 | class RandomGaussianBlur(object): 90 | def __call__(self, sample): 91 | img = sample['image'] 92 | mask = sample['label'] 93 | if random.random() < 0.5: 94 | img = img.filter(ImageFilter.GaussianBlur( 95 | radius=random.random())) 96 | 97 | return {'image': img, 98 | 'label': mask} 99 | 100 | 101 | class RandomScaleCrop(object): 102 | def __init__(self, base_size, crop_size, fill=0): 103 | self.base_size = base_size 104 | self.crop_size = crop_size 105 | self.fill = fill 106 | 107 | def __call__(self, sample): 108 | img = sample['image'] 109 | mask = sample['label'] 110 | # random scale (short edge) 111 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 112 | w, h = img.size 113 | if h > w: 114 | ow = short_size 115 | oh = int(1.0 * h * ow / w) 116 | else: 117 | oh = short_size 118 | ow = int(1.0 * w * oh / h) 119 | img = img.resize((ow, oh), Image.BILINEAR) 120 | mask = mask.resize((ow, oh), Image.NEAREST) 121 | # pad crop 122 | if short_size < self.crop_size: 123 | padh = self.crop_size - oh if oh < self.crop_size else 0 124 | padw = self.crop_size - ow if ow < self.crop_size else 0 125 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 126 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 127 | # random crop crop_size 128 | w, h = img.size 129 | x1 = random.randint(0, w - self.crop_size) 130 | y1 = random.randint(0, h - self.crop_size) 131 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + 
self.crop_size)) 132 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 133 | 134 | return {'image': img, 135 | 'label': mask} 136 | 137 | 138 | class FixScaleCrop(object): 139 | def __init__(self, crop_size): 140 | self.crop_size = crop_size 141 | 142 | def __call__(self, sample): 143 | img = sample['image'] 144 | mask = sample['label'] 145 | w, h = img.size 146 | if w > h: 147 | oh = self.crop_size 148 | ow = int(1.0 * w * oh / h) 149 | else: 150 | ow = self.crop_size 151 | oh = int(1.0 * h * ow / w) 152 | img = img.resize((ow, oh), Image.BILINEAR) 153 | mask = mask.resize((ow, oh), Image.NEAREST) 154 | # center crop 155 | w, h = img.size 156 | x1 = int(round((w - self.crop_size) / 2.)) 157 | y1 = int(round((h - self.crop_size) / 2.)) 158 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 159 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 160 | 161 | return {'image': img, 162 | 'label': mask} 163 | 164 | class MultiScaleResize(object): 165 | def __init__(self, size, scales): 166 | ratio = np.random.choice(scales) 167 | size = int(size * ratio) 168 | self.size = (size, size) # size: (h, w) 169 | 170 | def __call__(self, sample): 171 | img = sample['image'] 172 | mask = sample['label'] 173 | 174 | assert img.size == mask.size 175 | 176 | img = img.resize(self.size, Image.BILINEAR) 177 | mask = mask.resize(self.size, Image.NEAREST) 178 | 179 | return {'image': img, 180 | 'label': mask} 181 | 182 | class MultiScaleResizeTest(object): 183 | def __init__(self, scales): 184 | self.ratio = np.random.choice(scales) 185 | # h, w = size 186 | # h_ = int(h * ratio) 187 | # w_ = int(w * ratio) 188 | # size_ = (h_, w_) 189 | 190 | def __call__(self, sample): 191 | img = sample['image'] 192 | mask = sample['label'] 193 | size = sample['size'] 194 | 195 | # assert img.size == mask.size 196 | 197 | h, w = size 198 | size_new = (int(h * self.ratio), int(w * self.ratio)) 199 | img = img.resize(size_new, Image.BILINEAR) 200 | if mask is not None: 201 | mask = mask.resize(size_new, Image.BILINEAR) 202 | 203 | return {'image': img, 204 | 'label': mask, 205 | 'size': size} -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import sys 4 | import time 5 | import math 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.nn.init as init 11 | import config as cfg 12 | from utils.cluster import kmeans 13 | from utils.lr_scheduler import get_scheduler 14 | from sklearn.mixture import GaussianMixture 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter', 'cluster_weights', 'get_optimizer', 'resume_ckpt'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 
| if m.bias is not None:  # truthiness of a multi-element bias tensor is ambiguous; test against None
40 | init.constant(m.bias, 0)
41 | elif isinstance(m, nn.BatchNorm2d):
42 | init.constant(m.weight, 1)
43 | init.constant(m.bias, 0)
44 | elif isinstance(m, nn.Linear):
45 | init.normal(m.weight, std=1e-3)
46 | if m.bias is not None:
47 | init.constant(m.bias, 0)
48 | 
49 | def mkdir_p(path):
50 | '''make dir if not exist'''
51 | try:
52 | os.makedirs(path)
53 | except OSError as exc: # Python >2.5
54 | if exc.errno == errno.EEXIST and os.path.isdir(path):
55 | pass
56 | else:
57 | raise
58 | 
59 | @torch.no_grad()
60 | def cluster_weights(weights, n_clusters):
61 | """ Initialization of GMM with the k-means algorithm. Note that this procedure
62 | is stochastic, so repeated runs may yield slightly different initializations.
63 | Args:
64 | weights:[weight_size]
65 | n_clusters: 1 + 2^i = K Gaussian Mixture Component number
66 | Returns:
67 | [n_clusters-1] initial region saliency obtained by the k-means algorithm
68 | """
69 | flat_weight = weights.view(-1, 1).cuda()
70 | _tol = 1e-11
71 | if cfg.IS_NORMAL is True:
72 | print("skip k-means")
73 | tmp = torch.rand(n_clusters-1).cuda()
74 | return tmp, tmp, 0.5, tmp, 0.01
75 | _cluster_idx, region_saliency = kmeans(X=flat_weight, num_clusters=n_clusters, tol=_tol, \
76 | distance='euclidean', device=torch.device('cuda'), tqdm_flag=False)
77 | pi_initialization = torch.tensor([torch.true_divide(_cluster_idx.eq(i).sum(), _cluster_idx.numel()) \
78 | for i in range(n_clusters)], device='cuda')
79 | zero_center_idx = torch.argmin(torch.abs(region_saliency))  # the center nearest zero becomes the pruning (zero) component
80 | region_saliency_tmp = region_saliency.clone()
81 | region_saliency_zero = region_saliency[zero_center_idx]
82 | region_saliency_tmp[zero_center_idx] = 0.0
83 | pi_zero = pi_initialization[zero_center_idx]
84 | 
85 | sigma_tmp = torch.zeros(n_clusters,1).cuda()
86 | for i in range(flat_weight.size(0)):
87 | _idx = _cluster_idx[i]
88 | sigma_tmp[_idx] += (flat_weight[i,0]-region_saliency_tmp[_idx])**2
89 | sigma_initialization = torch.tensor([torch.true_divide(sigma_tmp[i], _cluster_idx.eq(i).sum()-1) \
90 | for i in range(n_clusters)], device='cuda').sqrt()
91 | sigma_zero = sigma_initialization[zero_center_idx]
92 | sigma_initialization = sigma_initialization[torch.arange(region_saliency.size(0)).cuda() != zero_center_idx]
93 | 
94 | pi_initialization = pi_initialization[torch.arange(region_saliency.size(0)).cuda() != zero_center_idx]
95 | region_saliency = region_saliency[torch.arange(region_saliency.size(0)).cuda() != zero_center_idx] # remove zero component center
96 | return region_saliency, pi_initialization, pi_zero, sigma_initialization, sigma_zero
97 | 
98 | @torch.no_grad()
99 | def cluster_weights_em(weights, n_clusters):
100 | """ Initialization of GMM with the EM algorithm. Note that this procedure
101 | may yield slightly different initialization results across runs.
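Compared with `cluster_weights`, this variant fits the mixture with EM
(sklearn's GaussianMixture) instead of k-means.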
102 | Args:
103 | weights:[weight_size]
104 | n_clusters: 1 + 2^i = K Gaussian Mixture Component number
105 | Returns:
106 | [n_clusters-1] initial region saliency obtained by the EM algorithm
107 | """
108 | flat_weight = weights.view(-1, 1).contiguous().detach().numpy()
109 | _tol = 1e-5
110 | if cfg.IS_NORMAL is True:
111 | print("skip EM")
112 | tmp = torch.rand(n_clusters-1).cuda()
113 | return tmp, tmp, 0.5, tmp, 0.01
114 | # construct GMM using EM algorithm
115 | gm = GaussianMixture(n_components=n_clusters, random_state=0, tol=_tol).fit(flat_weight)
116 | region_saliency = torch.from_numpy(gm.means_).view(-1).cuda()
117 | pi_initialization = torch.from_numpy(gm.weights_).cuda()
118 | sigma_initialization = torch.from_numpy(gm.covariances_).view(-1).sqrt().cuda()
119 | 
120 | zero_center_idx = torch.argmin(torch.abs(region_saliency))
121 | pi_zero = pi_initialization[zero_center_idx]
122 | sigma_zero = sigma_initialization[zero_center_idx]
123 | sigma_initialization = sigma_initialization[torch.arange(region_saliency.size(0)).cuda() != zero_center_idx]
124 | pi_initialization = pi_initialization[torch.arange(region_saliency.size(0)).cuda() != zero_center_idx]
125 | region_saliency = region_saliency[torch.arange(region_saliency.size(0)).cuda() != zero_center_idx] # remove zero component center
126 | return region_saliency, pi_initialization, pi_zero, sigma_initialization, sigma_zero
127 | 
128 | def get_optimizer(model, args):
129 | train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr}]
130 | optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
131 | weight_decay=args.weight_decay, nesterov=args.nesterov)
132 | return optimizer
133 | 
134 | def resume_ckpt(args, model, train_loader, optimizer, lr_scheduler):
135 | if not os.path.isfile(args.resume):
136 | raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
137 | if args.only_inference:
138 | checkpoint = torch.load(args.resume)
139 | if args.cuda:
140 | model.module.load_state_dict(checkpoint)
141 | model = model.cuda()
142 | else:
143 | model.load_state_dict(checkpoint['state_dict'])
144 | model.init_mask_params()
145 | optimizer = get_optimizer(model, args)
146 | lr_scheduler = get_scheduler(args, optimizer, \
147 | args.lr, len(train_loader))
148 | best_pred = 0.0
149 | print("=> loaded checkpoint '{}'".format(args.resume))
150 | else:
151 | checkpoint = torch.load(args.resume)
152 | args.start_epoch = checkpoint['epoch']
153 | if args.cuda:
154 | model.module.load_state_dict(checkpoint['state_dict'])
155 | model.module.init_mask_params()
156 | optimizer = get_optimizer(model, args)
157 | lr_scheduler = get_scheduler(args, optimizer, \
158 | args.lr, len(train_loader))
159 | model = model.cuda()
160 | else:
161 | model.load_state_dict(checkpoint['state_dict'])
162 | model.init_mask_params()
163 | optimizer.load_state_dict(checkpoint['optimizer'])
164 | if args.rt:
165 | best_pred = 0.0
166 | else:
167 | best_pred = checkpoint['best_pred']
168 | print("=> loaded checkpoint '{}' (epoch {})"
169 | .format(args.resume, checkpoint['epoch']))
170 | print("Checkpoint Top-1 Acc: ", checkpoint['best_pred'])
171 | return model, optimizer, lr_scheduler, best_pred
172 | 
173 | class AverageMeter(object):
174 | """Computes and stores the average and current value
175 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L363-L384
176 | """
177 | def __init__(self):
178 | self.reset()
179 | 
180 | def reset(self):
181 | self.val = 0
182 | self.avg = 0
183 | self.sum = 0
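# count tracks how many samples have been aggregated; update() maintains avg = sum / count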
184 | self.count = 0 185 | 186 | def update(self, val, n=1): 187 | self.val = val 188 | self.sum += val * n 189 | self.count += n 190 | self.avg = self.sum / self.count 191 | -------------------------------------------------------------------------------- /utils/PyTransformer/transformers/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from collections import OrderedDict 6 | 7 | class _ReplaceFunc(object): 8 | """! 9 | This Function replace torch functions with self-define Function. 10 | Inorder to get the imformation of torch model layer infomration. 11 | """ 12 | def __init__(self, ori_func, replace_func, **kwargs): 13 | self.torch_func = ori_func 14 | self.replace_func = replace_func 15 | 16 | def __call__(self, *args, **kwargs): 17 | out = self.replace_func(self.torch_func, *args, **kwargs) 18 | return out 19 | 20 | 21 | class Log(object): 22 | """! 23 | This class use as an log to replace input tensor and store all the information 24 | """ 25 | def __init__(self): 26 | self.graph = OrderedDict() 27 | self.bottoms = OrderedDict() 28 | self.output_shape = OrderedDict() 29 | self.cur_tensor = None 30 | self.cur_id = None 31 | self.tmp_list = None 32 | self.log_init() 33 | 34 | def __len__(self): 35 | """! 36 | Log should be one 37 | """ 38 | return 1 39 | 40 | def __copy__(self): 41 | """! 42 | copy, create new one and assign clone tensor in log 43 | """ 44 | copy_paster = Log() 45 | copy_paster.__dict__.update(self.__dict__) 46 | copy_paster.cur_tensor = self.cur_tensor.clone() 47 | return copy_paster 48 | 49 | def __deepcopy__(self, memo): 50 | """! 51 | deepcopy, create new one and assign clone tensor in log 52 | """ 53 | copy_paster = Log() 54 | copy_paster.__dict__.update(self.__dict__) 55 | copy_paster.cur_tensor = self.cur_tensor.clone() 56 | return copy_paster 57 | 58 | def reset(self): 59 | """ 60 | This function reset all attribute in log. 61 | """ 62 | self.graph = OrderedDict() 63 | self.bottoms = OrderedDict() 64 | self.output_shape = OrderedDict() 65 | self.cur_tensor = None 66 | self.cur_id = None 67 | self.tmp_list = [] 68 | self.log_init() 69 | 70 | 71 | # add data input layer to log 72 | def log_init(self): 73 | """ 74 | Init log attribute, set Data Layer as the first layer 75 | """ 76 | layer_id = "Data" 77 | self.graph[layer_id] = layer_id 78 | self.bottoms[layer_id] = None 79 | self.output_shape[layer_id] = "" 80 | self.cur_id = layer_id 81 | self.tmp_list = [] 82 | 83 | 84 | # for general layer (should has only one input?) 85 | def putLayer(self, layer): 86 | """! 87 | Put genreal layer's information into log 88 | """ 89 | # force use different address id ( prevent use same defined layer more than once, eg: bottleneck in torchvision) 90 | # tmp_layer = copy.deepcopy(layer) 91 | layer_id = id(layer) 92 | self.tmp_list.append(layer) 93 | layer_id = id(self.tmp_list[-1]) 94 | if layer_id in self.graph: 95 | tmp_layer = copy.deepcopy(layer) 96 | self.tmp_list.append(tmp_layer) 97 | # layer_id = id(self.tmp_list[-1]) 98 | layer_id = id(tmp_layer) 99 | 100 | self.graph[layer_id] = layer 101 | self.bottoms[layer_id] = [self.cur_id] 102 | self.cur_id = layer_id 103 | # del layer, tmp_layer, layer_id 104 | 105 | def getGraph(self): 106 | """! 107 | This function get the layers graph from log 108 | """ 109 | return self.graph 110 | 111 | def getBottoms(self): 112 | """! 
113 | This function get the layers bottoms from log 114 | """ 115 | return self.bottoms 116 | 117 | def getOutShapes(self): 118 | """! 119 | This function get the layers output shape from log 120 | """ 121 | return self.output_shape 122 | 123 | def getTensor(self): 124 | """! 125 | This function get the layers current tensor (output tensor) 126 | """ 127 | return self.cur_tensor 128 | 129 | def setTensor(self, tensor): 130 | """! 131 | This function set the layer's current tensor 132 | and also change output shape by the input tensor 133 | """ 134 | self.cur_tensor = tensor 135 | if tensor is not None: 136 | self.output_shape[self.cur_id] = self.cur_tensor.size() 137 | else: 138 | self.output_shape[self.cur_id] = None 139 | 140 | 141 | # handle tensor operation(eg: tensor.view) 142 | def __getattr__(self, name): 143 | """! 144 | This function handle all the tensor operation 145 | """ 146 | if name == "__deepcopy__" or name == "__setstate__": 147 | return object.__getattribute__(self, name) 148 | # if get data => get cur_tensor.data 149 | elif name == "data": 150 | return self.cur_tensor.data 151 | 152 | elif hasattr(self.cur_tensor, name): 153 | def wrapper(*args, **kwargs): 154 | func = self.cur_tensor.__getattribute__(name) 155 | out_tensor = func(*args, **kwargs) 156 | 157 | if not isinstance(out_tensor, torch.Tensor): 158 | out_logs = [] 159 | for t in out_tensor: 160 | out_log = copy.deepcopy(self) 161 | out_log.setTensor(t) 162 | out_logs.append(out_log) 163 | 164 | return out_logs 165 | else: 166 | self.cur_tensor = out_tensor 167 | self.output_shape[self.cur_id] = out_tensor.size() 168 | 169 | return self 170 | # print(wrapper) 171 | return wrapper 172 | 173 | # return self 174 | 175 | 176 | else: 177 | return object.__getattribute__(self, name) 178 | 179 | 180 | def __add__(self, other): 181 | """! 182 | Log addition 183 | """ 184 | #print("add") 185 | # merge other branch 186 | self.graph.update(other.graph) 187 | self.bottoms.update(other.bottoms) 188 | self.output_shape.update(other.output_shape) 189 | layer_name = "add_{}".format(len(self.graph)) 190 | self.graph[layer_name] = layer_name 191 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 192 | self.output_shape[layer_name] = self.cur_tensor.size() 193 | self.cur_id = layer_name 194 | # save memory 195 | del other 196 | 197 | return self 198 | 199 | 200 | def __iadd__(self, other): 201 | """! 202 | Log identity addition 203 | """ 204 | #print("iadd") 205 | # merge other branch 206 | self.graph.update(other.graph) 207 | self.bottoms.update(other.bottoms) 208 | self.output_shape.update(other.output_shape) 209 | layer_name = "iadd_{}".format(len(self.graph)) 210 | self.graph[layer_name] = layer_name 211 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 212 | self.output_shape[layer_name] = self.cur_tensor.size() 213 | self.cur_id = layer_name 214 | # save memory 215 | del other 216 | return self 217 | 218 | 219 | def __sub__(self, other): 220 | """! 221 | Log substraction 222 | """ 223 | #print("sub") 224 | # merge other branch 225 | self.graph.update(other.graph) 226 | self.bottoms.update(other.bottoms) 227 | self.output_shape.update(other.output_shape) 228 | layer_name = "sub_{}".format(len(self.graph)) 229 | self.graph[layer_name] = layer_name 230 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 231 | self.output_shape[layer_name] = self.cur_tensor.size() 232 | self.cur_id = layer_name 233 | # save memory 234 | del other 235 | return self 236 | 237 | 238 | def __isub__(self, other): 239 | """! 
240 | Log identity substraction 241 | """ 242 | #print("isub") 243 | # merge other branch 244 | self.graph.update(other.graph) 245 | self.bottoms.update(other.bottoms) 246 | self.output_shape.update(other.output_shape) 247 | layer_name = "sub_{}".format(len(self.graph)) 248 | self.graph[layer_name] = layer_name 249 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 250 | self.output_shape[layer_name] = self.cur_tensor.size() 251 | self.cur_id = layer_name 252 | # save memory 253 | del other 254 | return self 255 | 256 | 257 | def __mul__(self, other): 258 | """! 259 | Log multiplication 260 | """ 261 | #print("mul") 262 | # merge other branch 263 | self.graph.update(other.graph) 264 | self.bottoms.update(other.bottoms) 265 | self.output_shape.update(other.output_shape) 266 | layer_name = "mul_{}".format(len(self.graph)) 267 | self.graph[layer_name] = layer_name 268 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 269 | self.output_shape[layer_name] = self.cur_tensor.size() 270 | self.cur_id = layer_name 271 | # save memory 272 | del other 273 | return self 274 | 275 | 276 | def __imul__(self, other): 277 | """! 278 | Log identity multiplication 279 | """ 280 | #print("imul") 281 | # merge other branch 282 | self.graph.update(other.graph) 283 | self.bottoms.update(other.bottoms) 284 | self.output_shape.update(other.output_shape) 285 | layer_name = "mul_{}".format(len(self.graph)) 286 | self.graph[layer_name] = layer_name 287 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 288 | self.output_shape[layer_name] = self.cur_tensor.size() 289 | self.cur_id = layer_name 290 | # save memory 291 | del other 292 | return self 293 | 294 | 295 | def size(self, dim=None): 296 | """! 297 | This function return the size of the tensor by given dim 298 | 299 | @param dim: defult None, return as tensor.size(dim) 300 | 301 | @return tensor size by dim 302 | """ 303 | return self.cur_tensor.size(dim) if dim is not None else self.cur_tensor.size() 304 | 305 | 306 | 307 | class UnitLayer(nn.Module): 308 | """! 309 | This class is an Unit-layer act like an identity layer 310 | """ 311 | def __init__(self, ori_layer): 312 | super(UnitLayer, self).__init__() 313 | self.origin_layer = ori_layer 314 | 315 | 316 | def setOrigin(self, ori_layer): 317 | self.origin_layer = ori_layer 318 | 319 | 320 | # general layer should has only one input? 321 | def forward(self, log, *args): 322 | # prevent overwrite log for other forward flow 323 | cur_log = copy.deepcopy(log) 324 | # print(cur_log) 325 | cur_log.putLayer(self.origin_layer) 326 | 327 | # print(log.cur_tensor) 328 | log_tensor = log.getTensor() 329 | # out_tensor = self.origin_layer(log_tensor).clone().detach() 330 | out_tensor = self.origin_layer(log_tensor).clone() 331 | cur_log.setTensor(out_tensor) 332 | 333 | return cur_log -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020-2021 Runpei Dong, ArChip Lab. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 
16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. 
Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 
136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright (c) 2020-2021 Runpei Dong, ArChip Lab. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 
195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Runpei Dong, ArChip Lab. 2 | # This source code is licensed under the Apache 2.0 license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import torch 6 | 7 | import argparse 8 | import time 9 | import os 10 | import sys 11 | import math 12 | 13 | import torch.nn as nn 14 | import config as cfg 15 | 16 | from tqdm import tqdm 17 | from mypath import Path 18 | from dataloader import make_data_loader 19 | from modeling import DGMSNet 20 | from modeling.DGMS import DGMSConv 21 | from utils.sparsity import SparsityMeasure 22 | from utils.lr_scheduler import get_scheduler 23 | from utils.PyTransformer.transformers.torchTransformer import TorchTransformer 24 | from utils.summaries import TensorboardSummary 25 | from utils.metrics import Evaluator 26 | from utils.saver import Saver 27 | from utils.misc import AverageMeter, get_optimizer, resume_ckpt 28 | from utils.loss import * 29 | 30 | class Trainer(object): 31 | def __init__(self, args): 32 | self.args = args 33 | cfg.set_config(args) 34 | 35 | self.saver = Saver(args) 36 | self.saver.save_experiment_config() 37 | 38 | self.summary = TensorboardSummary(self.saver.experiment_dir) 39 | self.writer = self.summary.create_summary() 40 | 41 | kwargs = {'num_workers': args.workers, 'pin_memory': True} 42 | self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs) 43 | 44 | model = DGMSNet(args, args.freeze_bn) 45 | 46 | if args.mask: 47 | print("DGMS Conv!") 48 | _transformer = TorchTransformer() 49 | _transformer.register(nn.Conv2d, DGMSConv) 50 | model = _transformer.trans_layers(model) 51 | else: 52 | print("Normal Conv!") 53 | 54 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 55 | 56 | cfg.IS_NORMAL = True if (args.resume is not None) else False 57 | optimizer = get_optimizer(model, args) 58 | cfg.IS_NORMAL = self.args.normal 59 | self.model, self.optimizer = model, optimizer 60 | 61 | self.criterion = nn.CrossEntropyLoss() 62 | self.sparsity = SparsityMeasure(args) 63 | 64 | self.lr_scheduler = get_scheduler(args, self.optimizer, args.lr, len(self.train_loader)) 65 | 66 | self.evaluator = Evaluator(self.nclass, self.args) 67 | 68 | if args.cuda: 69 | torch.backends.cudnn.benchmark=True 70 | self.model = torch.nn.parallel.DataParallel(self.model, device_ids=self.args.gpu_ids) 71 | self.model = self.model.cuda() 72 | 73 | self.best_top1 = 0.0 74 | self.best_top5 = 0.0 75 | self.best_sparse_ratio = 0.0 76 | self.this_sparsity = 0.0 77 | self.best_params = 0.0 78 | if args.resume is not None: 79 | self.model, self.optimizer, self.lr_scheduler, self.best_top1 = \ 80 | resume_ckpt(args, self.model, self.train_loader, self.optimizer, self.lr_scheduler) 81 | 82 | print(' Total params (+GMM) : %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 83 | 84 | if args.rt: 85 | args.start_epoch = 0 86 | 87 | 
def training(self, epoch): 88 | cfg.set_status(True) 89 | self.model.train() 90 | 91 | batch_time = AverageMeter() 92 | data_time = AverageMeter() 93 | losses = AverageMeter() 94 | top1 = AverageMeter() 95 | top5 = AverageMeter() 96 | end = time.time() 97 | 98 | train_loss = 0.0 99 | num_img_tr = len(self.train_loader) 100 | 101 | tbar = tqdm(self.train_loader) 102 | for i, (image, target) in enumerate(tbar): 103 | data_time.update(time.time() - end) 104 | if self.args.cuda: 105 | image, target = image.cuda(), target.cuda() 106 | outputs = self.model(image) 107 | loss = self.criterion(outputs, target) 108 | 109 | prec1, prec5 = self.evaluator.Accuracy(outputs.data, target.data, topk=(1, 5)) 110 | losses.update(loss.item(), image.size(0)) 111 | top1.update(prec1.item(), image.size(0)) 112 | top5.update(prec5.item(), image.size(0)) 113 | 114 | self.optimizer.zero_grad() 115 | 116 | loss.backward(retain_graph=True) 117 | self.optimizer.step() 118 | self.lr_scheduler.step() 119 | batch_time.update(time.time() - end) 120 | end = time.time() 121 | 122 | train_loss = (loss.item() + train_loss) 123 | tbar.set_description('Train Loss: {loss:.4f} | T1: {top1: .3f} | T5: {top5: .2f} | best T1: {pre_best:.2f} T5: {best_top5:.2f} NZ: {nz_val:.4f} #Params: {params:.2f}M | lr: {_lr:.8f}' 124 | .format(loss=losses.avg, 125 | top1=top1.avg, 126 | top5=top5.avg, 127 | pre_best=self.best_top1, 128 | best_top5=self.best_top5, 129 | nz_val=1-self.best_sparse_ratio, 130 | params=self.best_params, 131 | _lr=self.optimizer.param_groups[0]['lr'] 132 | )) 133 | self.writer.add_scalar('train/train_loss_iter', loss.item(), i + num_img_tr * epoch) 134 | self.writer.add_scalar('train/top1', top1.avg, i + num_img_tr * epoch) 135 | self.writer.add_scalar('train/top5', top5.avg, i + num_img_tr * epoch) 136 | self.writer.add_scalar('train/lr', self.optimizer.param_groups[0]['lr'], i + num_img_tr * epoch) 137 | 138 | self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) 139 | print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) 140 | print('Loss: %.3f' % train_loss) 141 | 142 | return (losses.avg, top1.avg) 143 | 144 | def validation(self, epoch): 145 | cfg.set_status(False) 146 | num_img_tr = len(self.val_loader) 147 | 148 | batch_time = AverageMeter() 149 | data_time = AverageMeter() 150 | losses = AverageMeter() 151 | top1 = AverageMeter() 152 | top5 = AverageMeter() 153 | 154 | self.model.eval() 155 | self.evaluator.reset() 156 | tbar = tqdm(self.val_loader, desc='\r') 157 | test_loss = 0.0 158 | end = time.time() 159 | for i, (image, target) in enumerate(tbar): 160 | data_time.update(time.time() - end) 161 | 162 | if self.args.cuda: 163 | image, target = image.cuda(), target.cuda() 164 | with torch.no_grad(): 165 | outputs = self.model(image) 166 | loss = self.criterion(outputs, target) 167 | test_loss += loss.item() 168 | 169 | prec1, prec5 = self.evaluator.Accuracy(outputs.data, target.data, topk=(1, 5)) 170 | losses.update(loss.item(), image.size(0)) 171 | top1.update(prec1.item(), image.size(0)) 172 | top5.update(prec5.item(), image.size(0)) 173 | 174 | batch_time.update(time.time() - end) 175 | end = time.time() 176 | 177 | tbar.set_description('({batch}/{size}) Test Loss: {loss:.4f} | Top1: {top1: .4f} | Top5: {top5: .4f}' 178 | .format(batch=i + 1, 179 | size=len(self.val_loader), 180 | loss=losses.avg, 181 | top1=top1.avg, 182 | top5=top5.avg, 183 | )) 184 | self.writer.add_scalar('val/val_loss_iter', loss.item(), i + num_img_tr * epoch) 185 | 
self.writer.add_scalar('val/top1', top1.avg, i + num_img_tr * epoch)
186 |             self.writer.add_scalar('val/top5', top5.avg, i + num_img_tr * epoch)
187 |         if self.args.show_info:  # this_sparsity/this_params set here are used for best-model bookkeeping below
188 |             self.this_sparsity, this_params = self.sparsity.check_sparsity_per_layer(self.model)
189 |             self.writer.add_scalar('val/total_sparsity', self.this_sparsity, epoch)
190 |         new_pred = top1.avg
191 |         if new_pred > self.best_top1 and not self.args.only_inference:
192 |             is_best = True
193 |             self.best_top1 = new_pred
194 |             self.best_params = this_params
195 |             self.best_top5 = top5.avg
196 |             self.best_sparse_ratio = self.this_sparsity
197 |             bitwidth = math.floor(math.log(cfg.K_LEVEL, 2))
198 |             self.saver.save_checkpoint({
199 |                 'epoch': epoch + 1,
200 |                 'state_dict': self.model.module.state_dict(),
201 |                 'optimizer': self.optimizer.state_dict(),
202 |                 'best_top1': self.best_top1,
203 |                 'best_top5': self.best_top5,
204 |                 'params': self.best_params,
205 |                 'bits': bitwidth,
206 |                 'CR': 1/((1-self.best_sparse_ratio) * bitwidth / 32),
207 |                 'according_sparsity': self.this_sparsity,
208 |             }, is_best)
209 |         return (losses.avg, top1.avg)
210 | 
211 | def main():
212 |     parser = argparse.ArgumentParser(description="Differentiable Gaussian Mixture Weight Sharing (DGMS)",
213 |                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
214 |     parser.add_argument('-b', '--network', type=str, default='resnet18',
215 |                         choices=['resnet18', 'resnet50', 'mnasnet', 'proxylessnas',
216 |                                  'resnet20', 'resnet32', 'resnet56', 'vggsmall'],
217 |                         help='network name (default: resnet18)')
218 |     parser.add_argument('-d', '--dataset', type=str, default='imagenet',
219 |                         choices=['cifar10', 'imagenet', 'cars', 'cub200', 'aircraft'],
220 |                         help='dataset name (default: imagenet)')
221 |     parser.add_argument('-j', '--workers', type=int, default=4,
222 |                         metavar='N', help='dataloader threads')
223 |     parser.add_argument('--base-size', type=int, default=32,
224 |                         help='base image size')
225 |     parser.add_argument('--crop-size', type=int, default=32,
226 |                         help='crop image size')
227 |     parser.add_argument('--sync-bn', type=bool, default=None,
228 |                         help='whether to use sync bn (default: auto)')
229 |     parser.add_argument('--freeze-bn', type=bool, default=False,
230 |                         help='whether to freeze bn parameters (default: False)')
231 |     parser.add_argument('--train-dir', type=str, default=None,
232 |                         help='training set directory (default: None)')
233 |     parser.add_argument('--val-dir', type=str, default=None,
234 |                         help='validation set directory (default: None)')
235 |     parser.add_argument('--num-classes', type=int, default=None,
236 |                         help='number of classes (default: auto from dataset)')
237 |     parser.add_argument('--show-info', action='store_true', default=False,
238 |                         help='show model compression info (default: False)')
239 | 
240 |     # training hyper params
241 |     parser.add_argument('--epochs', type=int, default=None, metavar='N',
242 |                         help='number of epochs to train (default: auto)')
243 |     parser.add_argument('--start_epoch', type=int, default=0,
244 |                         metavar='N', help='start epoch (default: 0)')
245 |     parser.add_argument('--batch-size', type=int, default=256, metavar='N',
246 |                         help='input batch size for training (default: 256)')
247 |     parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
248 |                         help='input batch size for testing (default: 256)')
249 |     # model params
250 |     parser.add_argument('--K', type=int, default=16, metavar='K',
251 |                         help='number of GMM components (default: 2^4=16)')
252 |     parser.add_argument('--tau', type=float, default=0.01, metavar='TAU',
253 |                         help='Gumbel-Softmax temperature (default: 0.01)')
254 |     parser.add_argument('--normal', action='store_true', default=False,
255 |                         help='whether to train normally (default: False)')
256 |     parser.add_argument('--empirical', type=bool, default=False,
257 |                         help='whether to use empirical initialization for the sigma parameter (default: False)')
258 |     parser.add_argument('--mask', action='store_true', default=False,
259 |                         help='whether to transform normal convolutions into DGMS convolutions (default: False)')
260 |     # optimizer params
261 |     parser.add_argument('--lr', type=float, default=2e-5, metavar='LR',
262 |                         help='learning rate (default: 2e-5)')
263 |     parser.add_argument('--lr-scheduler', type=str, default='one-cycle',
264 |                         choices=['one-cycle', 'cosine', 'multi-step', 'reduce'],
265 |                         help='lr scheduler mode (default: one-cycle)')
266 |     parser.add_argument('--schedule', type=str, default='70,140,190', help='epoch schedule milestones (default: 70,140,190)')
267 |     parser.add_argument('--momentum', type=float, default=0.9,
268 |                         metavar='M', help='momentum (default: 0.9)')
269 |     parser.add_argument('--weight-decay', type=float, default=5e-4,
270 |                         metavar='M', help='weight decay (default: 5e-4)')
271 |     parser.add_argument('--nesterov', action='store_true', default=False,
272 |                         help='whether to use Nesterov momentum (default: False)')
273 |     # cuda, seed and logging
274 |     parser.add_argument('--no-cuda', action='store_true', default=False,
275 |                         help='disables CUDA training (default: False)')
276 |     parser.add_argument('--gpu-ids', type=str, default='0',
277 |                         help='which GPUs to train on; must be a \
278 |                         comma-separated list of integers only (default: 0)')
279 |     parser.add_argument('--seed', type=int, default=1, metavar='S',
280 |                         help='random seed (default: 1)')
281 |     # checkpointing
282 |     parser.add_argument('--resume', type=str, default=None,
283 |                         help='path to a checkpoint to resume from (default: None)')
284 |     parser.add_argument('--checkname', type=str, default="Experiments",
285 |                         help='set the checkpoint name')
286 |     parser.add_argument('--pretrained', action='store_true', default=True,
287 |                         help='use a pretrained network (default: True)')
288 |     # re-train a pre-trained model
289 |     parser.add_argument('--rt', action='store_true', default=False,
290 |                         help='retrain a pre-trained model for quantization')
291 |     # evaluation option
292 |     parser.add_argument('--eval-interval', type=int, default=1,
293 |                         help='evaluation interval (default: 1)')
294 |     parser.add_argument('--only-inference', type=bool, default=False,
295 |                         help='skip training and run inference only')
296 | 
297 |     args = parser.parse_args()
298 |     args.schedule = [int(s) for s in args.schedule.split(',')]
299 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
300 |     if args.cuda:
301 |         try:
302 |             args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
303 |         except ValueError:
304 |             raise ValueError("Argument --gpu_ids must be a comma-separated list of integers only")
305 |     if args.sync_bn is None:
306 |         if args.cuda and len(args.gpu_ids) > 1:
307 |             args.sync_bn = True
308 |         else:
309 |             args.sync_bn = False
310 | 
311 |     if args.epochs is None:
312 |         args.epochs = cfg.EPOCH[args.dataset.lower()]
313 | 
314 |     if args.num_classes is None:
315 |         args.num_classes = cfg.NUM_CLASSES[args.dataset.lower()]
316 | 
317 |     if args.train_dir is None or args.val_dir is None:
318 |         args.train_dir, args.val_dir = Path.db_root_dir(args.dataset.lower()), Path.db_root_dir(args.dataset.lower())
319 | 
320 |     print(args)
321 |     torch.manual_seed(args.seed)
322 |     trainer = Trainer(args)
323 |     print('Starting Epoch:', trainer.args.start_epoch)
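    # Note: the compression ratio printed in the epoch loop below mirrors the 'CR' field saved in
    # the checkpoint: CR = 32 / (NZ * bits), with bits = floor(log2(K_LEVEL)); e.g. K_LEVEL = 16
    # gives 4-bit weights, and NZ = 0.5 non-zero weights yields CR = 32 / (0.5 * 4) = 16x.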
324 |     print('Total Epochs:', trainer.args.epochs)
325 |     if args.only_inference:
326 |         print("Only running inference with the given resumed model...")
327 |         val_loss, val_acc = trainer.validation(0)
328 |         return
329 |     for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
330 |         train_loss, train_acc = trainer.training(epoch)
331 |         if epoch % args.eval_interval == (args.eval_interval - 1):
332 |             val_loss, val_acc = trainer.validation(epoch)
333 |         nz_val = 1 - trainer.best_sparse_ratio
334 |         params_val = trainer.best_params
335 |         compression_rate = 1/(nz_val * (math.floor(math.log(cfg.K_LEVEL, 2)) / 32))
336 |         print(f"Best Top-1: {trainer.best_top1} | Top-5: {trainer.best_top5} | NZ: {nz_val} | #Params: {params_val:.2f}M | CR: {compression_rate:.2f}")
337 | 
338 |     trainer.writer.close()
339 | 
340 | if __name__ == '__main__':
341 |     main()
342 | 
--------------------------------------------------------------------------------
/utils/PyTransformer/transformers/torchTransformer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import copy
3 | import types
4 | import inspect
5 | from collections import OrderedDict
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | 
11 | import pydot
12 | import config as cfg
13 | from graphviz import Digraph
14 | 
15 | from .utils import _ReplaceFunc, Log, UnitLayer
16 | 
17 | 
18 | 
19 | class TorchTransformer(nn.Module):
20 |     """!
21 |     This class handles layer swapping and the summary/visualization of the input model
22 |     """
23 |     def __init__(self):
24 |         super(TorchTransformer, self).__init__()
25 | 
26 |         self._register_dict = OrderedDict()
27 |         self.log = Log()
28 |         self._raw_TrochFuncs = OrderedDict()
29 |         self._raw_TrochFunctionals = OrderedDict()
30 | 
31 |     # register classes to transform
32 |     def register(self, origin_class, target_class):
33 |         """!
34 |         This function registers which class should be transformed into which target class.
35 |         """
36 |         print("register", origin_class, target_class)
37 | 
38 |         self._register_dict[origin_class] = target_class
39 | 
40 | 
41 | 
42 |     def trans_layers(self, model, update = True):
43 |         """!
44 |         This function transforms layers according to the registered dictionary
45 | 
46 |         @param model: input model to transform
47 | 
48 |         @param update: default is True; whether to update the parameters from the origin layer or not.
49 |         Note that it will update matched parameters only.
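        Example (from transform_example.ipynb):
            transformer.register(nn.Conv2d, QConv2d)
            model = transformer.trans_layers(model)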
50 | 
51 |         @return transformed model
52 |         """
53 |         # print("trans layer")
54 |         if len(self._register_dict) == 0:
55 |             print("No layer to swap")
56 |             print("Please use register( {origin_layer}, {target_layer} ) to register a layer")
57 |             return model
58 |         else:
59 |             for module_name in model._modules:
60 |                 # has children
61 |                 if len(model._modules[module_name]._modules) > 0:
62 |                     self.trans_layers(model._modules[module_name])
63 |                 else:
64 |                     if type(getattr(model, module_name)) in self._register_dict:
65 |                         cfg.count_layer()
66 |                         if cfg.KEEP:
67 |                             if cfg.L_CNT in cfg.SKIPPED_LAYERS:
68 |                                 continue
69 |                         # use inspect.signature to know the args and kwargs of __init__
70 |                         _sig = inspect.signature(type(getattr(model, module_name)))
71 |                         _kwargs = {}
72 |                         for key in _sig.parameters:
73 |                             if _sig.parameters[key].default == inspect.Parameter.empty: # positional args
74 |                                 # assign args
75 |                                 # default values should be handled more properly; unknown data types might be an issue
76 |                                 if 'kernel' in key:
77 |                                     # _sig.parameters[key].replace(default=inspect.Parameter.empty, annotation=3)
78 |                                     value = 3
79 |                                 elif 'channel' in key:
80 |                                     # _sig.parameters[key].replace(default=inspect.Parameter.empty, annotation=32)
81 |                                     value = 32
82 |                                 else:
83 |                                     # _sig.parameters[key].replace(default=inspect.Parameter.empty, annotation=None)
84 |                                     # value = None
85 |                                     value = 1000
86 | 
87 |                                 _kwargs[key] = value
88 | 
89 |                         _attr_dict = getattr(model, module_name).__dict__
90 |                         _layer_new = self._register_dict[type(getattr(model, module_name))](**_kwargs) # only give positional args
91 |                         _layer_new.__dict__.update(_attr_dict)
92 | 
93 |                         setattr(model, module_name, _layer_new)
94 | 
95 |         return model
96 | 
97 | 
98 | 
99 | 
100 |     def summary(self, model = None, input_tensor = None):
101 |         """!
102 |         This function acts like the Keras summary function
103 | 
104 |         @param model: input model to summarize
105 | 
106 |         @param input_tensor: input data to forward through the model
107 | 
108 |         """
109 |         # input_tensor = torch.randn([1, 3, 224, 224])
110 |         # input_tensor = input_tensor.cuda()
111 | 
112 | 
113 |         self._build_graph(model, input_tensor)
114 | 
115 |         # get dicts and variables
116 |         model_graph = self.log.getGraph()
117 |         bottoms_graph = self.log.getBottoms()
118 |         output_shape_graph = self.log.getOutShapes()
119 |         # store top names for bottoms
120 |         topNames = OrderedDict()
121 |         total_trainable_params = 0
122 |         total_params = 0
123 |         # loop over the graph
124 |         print("##########################################################################################")
125 |         line_title = "{:>5}| {:<15} | {:<15} {:<25} {:<15}".format("Index","Layer (type)", "Bottoms","Output Shape", "Param #")
126 |         print(line_title)
127 |         print("---------------------------------------------------------------------------")
128 | 
129 | 
130 |         for layer_index, key in enumerate(model_graph):
131 | 
132 |             # data layer
133 |             if bottoms_graph[key] is None:
134 |                 # Layer information
135 |                 layer = model_graph[key]
136 |                 layer_type = layer.__class__.__name__
137 |                 if layer_type == "str":
138 |                     layer_type = key
139 |                 else:
140 |                     layer_type = layer.__class__.__name__ + "_{}".format(layer_index)
141 | 
142 |                 topNames[key] = layer_type
143 | 
144 |                 # Layer Output shape
145 |                 output_shape = "[{}]".format(tuple(output_shape_graph[key]))
146 | 
147 |                 # Layer Params
148 |                 param_weight_num = 0
149 |                 if hasattr(layer, "weight") and hasattr(layer.weight, "size"):
150 |                     param_weight_num += torch.prod(torch.LongTensor(list(layer.weight.size())))
151 |                     if layer.weight.requires_grad:
152 |                         total_trainable_params += param_weight_num
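                # bias parameters (if present) are tallied separately below so the weight count is not added twice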
153 | if hasattr(layer, "bias") and hasattr(layer.weight, "bias"): 154 | param_weight_num += torch.prod(torch.LongTensor(list(layer.bias.size()))) 155 | if layer.bias.requires_grad: 156 | totoal_trainable_params += param_weight_num 157 | 158 | total_params += param_weight_num 159 | 160 | new_layer = "{:5}| {:<15} | {:<15} {:<25} {:<15}".format(layer_index, layer_type, "", output_shape, param_weight_num) 161 | print(new_layer) 162 | 163 | else: 164 | # Layer Information 165 | layer = model_graph[key] 166 | layer_type = layer.__class__.__name__ 167 | 168 | # add, sub, mul...,etc. (custom string) 169 | if layer_type == "str": 170 | # the key should be XXX_{idx_prevent_duplicate} 171 | tmp_key = key.split("_") 172 | tmp_key[-1] = "_{}".format(layer_index) 173 | tmp_key = "".join(tmp_key) 174 | layer_type = tmp_key 175 | else: 176 | layer_type = layer.__class__.__name__ + "_{}".format(layer_index) 177 | 178 | topNames[key] = layer_type 179 | 180 | # Layer Bottoms 181 | bottoms = [] 182 | for b_key in bottoms_graph[key]: 183 | bottom = topNames[b_key] 184 | bottoms.append(bottom) 185 | 186 | # Layer Output Shape 187 | if key in output_shape_graph: 188 | output_shape = "[{}]".format(tuple(output_shape_graph[key])) 189 | else: 190 | output_shape = "None" 191 | 192 | # Layer Params 193 | param_weight_num = 0 194 | if hasattr(layer, "weight") and hasattr(layer.weight, "size"): 195 | param_weight_num += torch.prod(torch.LongTensor(list(layer.weight.size()))) 196 | if layer.weight.requires_grad: 197 | totoal_trainable_params += param_weight_num 198 | if hasattr(layer, "bias") and hasattr(layer.weight, "bias"): 199 | param_weight_num += torch.prod(torch.LongTensor(list(layer.bias.size()))) 200 | if layer.bias.requires_grad: 201 | totoal_trainable_params += param_weight_num 202 | total_params += param_weight_num 203 | 204 | # Print (one bottom a line) 205 | for idx, b in enumerate(bottoms): 206 | # if more than one bottom, only print bottom 207 | if idx == 0: 208 | new_layer = "{:>5}| {:<15} | {:<15} {:<25} {:<15}".format(layer_index, layer_type, b, output_shape, param_weight_num) 209 | else: 210 | new_layer = "{:>5}| {:<15} | {:<15} {:<25} {:<15}".format("", "", b, "", "") 211 | print(new_layer) 212 | print("---------------------------------------------------------------------------") 213 | 214 | 215 | # total information 216 | print("==================================================================================") 217 | print("Total Trainable params: {} ".format(totoal_trainable_params)) 218 | print("Total Non-Trainable params: {} ".format(total_params - totoal_trainable_params)) 219 | print("Total params: {} ".format(total_params)) 220 | 221 | # del model_graph, bottoms_graph, output_shape_graph, topNames 222 | # return model 223 | 224 | def visualize(self, model = None, input_tensor = None, save_name = None, graph_size = 30): 225 | """! 
226 |         This function visualizes the model architecture
227 | 
228 |         @param model: input model to visualize
229 | 
230 |         @param input_tensor: input data of the model to forward
231 | 
232 |         @param save_name: if save_name is not None, it will save as '{save_name}.png'
233 | 
234 |         @param graph_size: graph_size for graphviz, to help increase the resolution of the output graph
235 | 
236 |         @return dot, graphviz's Digraph element
237 |         """
238 |         # input_tensor = torch.randn([1, 3, 224, 224])
239 |         # model_graph = self.log.getGraph()
240 | 
241 |         # if graph empty
242 |         if model is None:
243 |             # check whether to use self's own modules
244 |             if len(self._modules) > 0:
245 |                 self._build_graph(self, input_tensor)
246 |             else:
247 |                 raise ValueError("Please input a model to visualize")
248 |         else:
249 |             self._build_graph(model, input_tensor)
250 | 
251 |         # graph
252 |         node_attr = dict(style='filled',
253 |                          shape='box',
254 |                          align='left',
255 |                          fontsize='30',
256 |                          ranksep='0.1',
257 |                          height='0.2')
258 | 
259 |         dot = Digraph(node_attr=node_attr, graph_attr=dict(size="{},{}".format(graph_size, graph_size)))
260 | 
261 |         # get dicts and variables
262 |         model_graph = self.log.getGraph()
263 |         bottoms_graph = self.log.getBottoms()
264 |         output_shape_graph = self.log.getOutShapes()
265 |         topNames = OrderedDict()
266 | 
267 |         for layer_index, key in enumerate(model_graph):
268 |             # Input Data layer
269 |             if bottoms_graph[key] is None:
270 |                 layer = model_graph[key]
271 |                 layer_type = layer.__class__.__name__
272 |                 # add, sub, mul..., etc. (custom string)
273 |                 if layer_type == "str":
274 |                     layer_type = key
275 |                 else:
276 |                     layer_type = layer.__class__.__name__ + "_{}".format(layer_index)
277 | 
278 |                 # record the display name for this node
279 |                 topNames[key] = layer_type
280 |                 output_shape = "[{}]".format(tuple(output_shape_graph[key]))
281 |                 layer_type = layer_type + "\nShape: " + output_shape
282 | 
283 |                 dot.node(str(key), layer_type, fillcolor='orange')
284 |             else:
285 |                 # Layer Information
286 |                 layer = model_graph[key]
287 |                 layer_type = layer.__class__.__name__
288 |                 # add, sub, mul..., etc. (custom string)
289 |                 if layer_type == "str":
290 |                     # the key should be XXX_{idx_prevent_duplicate}
291 |                     tmp_key = key.split("_")
292 |                     tmp_key[-1] = "_{}".format(layer_index)
293 |                     tmp_key = "".join(tmp_key)
294 |                     layer_type = tmp_key
295 |                 else:
296 |                     layer_type = layer.__class__.__name__ + "_{}".format(layer_index)
297 | 
298 |                 topNames[key] = layer_type
299 |                 # layer_type = layer_type
300 |                 # print("Layer: {}".format(layer_type))
301 |                 # print("Key: {}".format(key))
302 |                 # add bottoms
303 | 
304 |                 layer_type = layer_type + "\nBottoms: "
305 |                 for b_key in bottoms_graph[key]:
306 |                     layer_type = layer_type + topNames[b_key] + "\n"
307 | 
308 |                 output_shape = "[{}]".format(tuple(output_shape_graph[key]))
309 |                 layer_type = layer_type + "Shape: " + output_shape
310 | 
311 |                 dot.node(str(key), layer_type, fillcolor='orange')
312 |                 # link bottoms
313 |                 # print("Bottoms: ")
314 |                 for bot_key in bottoms_graph[key]:
315 |                     # print(bot_key)
316 |                     dot.edge(str(bot_key), str(key))
317 | 
318 |         # optionally save, then return the graph
319 |         if save_name is not None:
320 |             (graph,) = pydot.graph_from_dot_data(dot.source)
321 |             graph.write_png(save_name + ".png")
322 |         return dot
323 | 
324 |     def _build_graph(self, model, input_tensor = None):
325 | 
326 |         if input_tensor is None:
327 |             raise ValueError("Please set an input tensor")
328 | 
329 |         # reset log
330 |         self.log = Log()
331 |         # add Data input
332 |         self.log.setTensor(input_tensor)
333 | 
334 | 
335 |         tmp_model = self._trans_unit(copy.deepcopy(model))
336 | 
337 |         for f in dir(torch):
338 | 
339 |             # skip private functions
340 |             if f.startswith("_") or "tensor" == f:
341 |                 continue
342 |             if isinstance(getattr(torch, f), types.BuiltinMethodType) or isinstance(getattr(torch, f), types.BuiltinFunctionType):
343 |                 self._raw_TrochFuncs[f] = getattr(torch, f)
344 |                 setattr(torch, f, _ReplaceFunc(getattr(torch, f), self._torchFunctions))
345 | 
346 |         for f in dir(F):
347 |             # skip private functions
348 |             if f.startswith("_"):
349 |                 continue
350 | 
351 |             if isinstance(getattr(F, f), types.BuiltinMethodType) or isinstance(getattr(F, f), types.BuiltinFunctionType) or isinstance(getattr(F, f), types.FunctionType):
352 |                 self._raw_TrochFunctionals[f] = getattr(F, f)
353 |                 setattr(F, f, _ReplaceFunc(getattr(F, f), self._torchFunctionals))
354 | 
355 | 
356 |         self.log = tmp_model.forward(self.log)
357 | 
358 |         # restore the original functions
359 |         for f in self._raw_TrochFuncs:
360 |             setattr(torch, f, self._raw_TrochFuncs[f])
361 | 
362 |         for f in self._raw_TrochFunctionals:
363 |             setattr(F, f, self._raw_TrochFunctionals[f])
364 | 
365 |         del tmp_model
366 | 
367 |     def _trans_unit(self, model):
368 |         # print("TRNS_UNIT")
369 |         for module_name in model._modules:
370 |             # has children
371 |             if len(model._modules[module_name]._modules) > 0:
372 |                 self._trans_unit(model._modules[module_name])
373 |             else:
374 |                 unitlayer = UnitLayer(getattr(model, module_name))
375 |                 setattr(model, module_name, unitlayer)
376 | 
377 |         return model
378 | 
379 |     def _torchFunctions(self, raw_func, *args, **kwargs):
380 |         """!
381 |         The replaced torch functions (e.g. torch.{function}) are routed here
382 |         """
383 |         # print("Torch function")
384 |         function_name = raw_func.__name__
385 | 
386 |         # a torch function may have no input,
387 |         # so check first
388 | 
389 |         if len(args) > 0:
390 |             logs = args[0]
391 |             cur_args = args[1:]
392 |         elif len(kwargs) > 0:
393 | 
394 |             return raw_func(**kwargs)
395 |         else:
396 |             return raw_func()
397 | 
398 |         # check whether this is a user call or an internal torch function call
399 |         is_tensor_in = False
400 |         # tensor input
401 |         # multi tensor input
402 |         if isinstance(logs, tuple) and (type(logs[0]) == torch.Tensor):
403 |             cur_inputs = logs
404 |             is_tensor_in = True
405 |             return raw_func(*args, **kwargs)
406 |         # single tensor input
407 |         elif (type(logs) == torch.Tensor):
408 | 
409 |             cur_inputs = logs
410 |             is_tensor_in = True
411 |             # print(*args)
412 |             # print(**kwargs)
413 |             return raw_func(*args, **kwargs)
414 |         elif (type(logs) == nn.Parameter):
415 |             cur_inputs = logs
416 |             is_tensor_in = True
417 |             return raw_func(*args, **kwargs)
418 |         # log input
419 |         else:
420 |             # multi inputs
421 |             bottoms = []
422 |             cur_inputs = []
423 | 
424 |             if isinstance(logs, tuple) or isinstance(logs, list):
425 |                 # the original input log may be reused as another layer's input,
426 |                 # e.g. densenet in torchvision 0.4.0
427 |                 cur_log = copy.deepcopy(logs[0])
428 |                 for log in logs:
429 |                     cur_inputs.append(log.cur_tensor)
430 |                     # print(log.cur_tensor.size())
431 |                     bottoms.append(log.cur_id)
432 |                     # update information
433 |                     cur_log.graph.update(log.graph)
434 |                     cur_log.bottoms.update(log.bottoms)
435 |                     cur_log.output_shape.update(log.output_shape)
436 |                 cur_inputs = tuple(cur_inputs)
437 |             # one input
438 |             else:
439 |                 # print(args)
440 |                 # print(kwargs)
441 |                 cur_log = logs
442 |                 cur_inputs = cur_log.cur_tensor
443 |                 bottoms.append(cur_log.cur_id)
444 | 
445 |         # replace logs with tensors as function inputs to get the output tensor
446 |         args = list(args)
447 |         args[0] = cur_inputs
448 |         args = tuple(args)
449 |         # send into the original function
450 |         #out_tensor = raw_func(*args, **kwargs).clone().detach()
451 |         out_tensor = raw_func(*args, **kwargs).clone()
452 | 
453 |         # if this was an internal function call, just return the output tensor
454 |         if is_tensor_in:
455 |             return out_tensor
456 | 
457 |         # most multi-input ops produce a single output,
458 |         # most multi-output ops have a single input;
459 |         # if the shape changes,
460 |         # store these kinds of operations as a layer
461 |         if isinstance(logs, tuple) or isinstance(logs, list) or isinstance(out_tensor, tuple) or (logs.cur_tensor.size() != out_tensor.size()):
462 |             layer_name = "torch.{}_{}".format(function_name, len(cur_log.graph))
463 |             cur_log.graph[layer_name] = layer_name
464 |             cur_log.bottoms[layer_name] = bottoms
465 |             cur_log.cur_id = layer_name
466 | 
467 |         # multi output
468 |         if not isinstance(out_tensor, torch.Tensor):
469 |             # print("multi output")
470 |             out_logs = []
471 |             for t in out_tensor:
472 |                 out_log = copy.deepcopy(cur_log)
473 |                 out_log.setTensor(t)
474 |                 out_logs.append(out_log)
475 | 
476 |             # the output is sometimes a one-element tuple (out,); unwrap it
477 |             if len(out_logs) == 1:
478 |                 out_logs = out_logs[0]
479 |             return out_logs
480 | 
481 |         else:
482 |             cur_log.setTensor(out_tensor)
483 |             return cur_log
484 | 
485 |     # torch.nn.functional functions
486 |     def _torchFunctionals(self, raw_func, *args, **kwargs):
487 |         """!
488 |         The replaced torch.nn.functional functions (e.g. F.{function}) are routed here
489 |         """
490 |         # print("Functional")
491 |         function_name = raw_func.__name__
492 |         # print(raw_func.__name__)
493 | 
494 |         # every functional takes an input, except affine_grid
495 |         if function_name == "affine_grid":
496 |             pass
497 |         else:
498 |             logs = args[0]
499 |             cur_args = args[1:]
500 | 
501 |         # check whether this is a user call or an internal torch function call
502 |         is_tensor_in = False
503 |         # tensor input
504 |         if (len(logs) > 1) and (type(logs[0]) == torch.Tensor):
505 |             # print(logs[0].size(), logs[1].size())
506 |             cur_inputs = logs
507 |             is_tensor_in = True
508 |             out = raw_func(*args, **kwargs)
509 |             # print("Functional return : {}".format(out.size()))
510 |             return out
511 | 
512 |         elif (len(logs) == 1) and (type(logs) == torch.Tensor):
513 |             cur_inputs = logs
514 |             is_tensor_in = True
515 |             out = raw_func(*args, **kwargs)
516 |             # print("Functional return : {}".format(out.size()))
517 |             return out
518 | 
519 |         # log input
520 |         else:
521 |             # multi inputs
522 |             bottoms = []
523 |             cur_inputs = []
524 |             if len(logs) > 1:
525 |                 cur_log = logs[0]
526 |                 for log in logs:
527 |                     cur_inputs.append(log.cur_tensor)
528 |                     bottoms.append(log.cur_id)
529 |                     # update information
530 |                     cur_log.graph.update(log.graph)
531 |                     cur_log.bottoms.update(log.bottoms)
532 |                     cur_log.output_shape.update(log.output_shape)
533 |                 cur_inputs = tuple(cur_inputs)
534 |             # one input
535 |             else:
536 |                 cur_log = logs
537 |                 cur_inputs = cur_log.cur_tensor
538 |                 bottoms.append(cur_log.cur_id)
539 | 
540 | 
541 | 
542 |         # replace logs with tensors as function inputs to get the output tensor
543 |         args = list(args)
544 |         args[0] = cur_inputs
545 |         args = tuple(args)
546 |         # send into the original function
547 |         #out_tensor = raw_func(*args, **kwargs).clone().detach()
548 |         out_tensor = raw_func(*args, **kwargs).clone()
549 | 
550 |         # if this was an internal function call, just return the output tensor
551 |         if is_tensor_in:
552 |             return out_tensor
553 | 
554 |         # if the input is a log and raw_func is a plain function, store it as a layer
555 |         if isinstance(raw_func, types.FunctionType):
556 |             # combine several object addresses into the name to avoid duplicates
557 |             layer_name = "F.{}_{}{}{}".format(function_name, id(out_tensor), id(args), id(kwargs))
558 |             # regenerate the name if there is still a duplicate
559 |             while layer_name in cur_log.graph:
560 |                 #if layer_name in cur_log.graph:
561 |                 # tmp_list = []
562 |                 # tmp_list.append(out_tensor)
563 |                 # tmp_tensor = copy.deepcopy(tmp_list[-1])
564 |                 # tmp_tensor = tmp_list[-1].clone()
565 |                 tmp_tensor = torch.tensor([0])
566 | 
567 |                 # should not duplicate again?
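                # fall back to appending a time-based suffix so repeated functional calls always get a unique name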
568 |                 # layer_name = layer_name.split('.')[0] + "F" + ".{}_{}{}{}".format(function_name, id(tmp_tensor), id(args), id(kwargs))
569 |                 layer_name = "F.{}_{}{}{}{}".format(function_name, id(tmp_tensor), id(args), id(kwargs), int((time.time()*100000)%1000000))
570 | 
571 |             cur_log.graph[layer_name] = layer_name
572 |             cur_log.bottoms[layer_name] = bottoms
573 |             cur_log.cur_id = layer_name
574 | 
575 |         # if multi-output
576 |         # if len(out_tensor) > 1:
577 |         if not isinstance(out_tensor, torch.Tensor):
578 |             out_logs = []
579 |             for t in out_tensor:
580 |                 out_log = copy.deepcopy(cur_log)
581 |                 out_log.setTensor(t)
582 |                 out_logs.append(out_log)
583 | 
584 |             return out_logs
585 |         else:
586 |             cur_log.setTensor(out_tensor)
587 |             return cur_log
588 | 
589 | 
--------------------------------------------------------------------------------
/utils/PyTransformer/transform_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.nn as nn\n",
11 | "import torchvision\n",
12 | "import torchvision.models as models\n",
13 | "\n",
14 | "from transformers.torchTransformer import TorchTransformer\n",
15 | "from transformers.quantize import QConv2d"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 2,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "model = models.__dict__[\"resnet18\"]()\n",
25 | "model.cuda()\n",
26 | "model = model.eval()\n"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 | "## Register layers to be transformed"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 3,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "name": "stdout",
43 | "output_type": "stream",
44 | "text": [
45 | "register <class 'torch.nn.modules.conv.Conv2d'> <class 'transformers.quantize.QConv2d'>\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "transformer = TorchTransformer()\n",
51 | "transformer.register(nn.Conv2d, QConv2d)\n",
52 | "model = transformer.trans_layers(model)"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 4,
58 | "metadata": {
59 | "scrolled": true
60 | },
61 | "outputs": [
62 | {
63 | "data": {
64 | "text/plain": [
65 | "ResNet(\n",
66 | " (conv1): QConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
67 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
68 | " (relu): ReLU(inplace=True)\n",
69 | " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
70 | " (layer1): Sequential(\n",
71 | " (0): BasicBlock(\n",
72 | " (conv1): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
73 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
74 | " (relu): ReLU(inplace=True)\n",
75 | " (conv2): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
76 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
77 | " )\n",
78 | " (1): BasicBlock(\n",
79 | " (conv1): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
80 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
81 | " (relu): ReLU(inplace=True)\n",
82 | " (conv2): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
83 | " (bn2): BatchNorm2d(64, eps=1e-05,
momentum=0.1, affine=True, track_running_stats=True)\n", 84 | " )\n", 85 | " )\n", 86 | " (layer2): Sequential(\n", 87 | " (0): BasicBlock(\n", 88 | " (conv1): QConv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 89 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 90 | " (relu): ReLU(inplace=True)\n", 91 | " (conv2): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 92 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 93 | " (downsample): Sequential(\n", 94 | " (0): QConv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 95 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 96 | " )\n", 97 | " )\n", 98 | " (1): BasicBlock(\n", 99 | " (conv1): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 100 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 101 | " (relu): ReLU(inplace=True)\n", 102 | " (conv2): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 103 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 104 | " )\n", 105 | " )\n", 106 | " (layer3): Sequential(\n", 107 | " (0): BasicBlock(\n", 108 | " (conv1): QConv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 109 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 110 | " (relu): ReLU(inplace=True)\n", 111 | " (conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 112 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 113 | " (downsample): Sequential(\n", 114 | " (0): QConv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 115 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 116 | " )\n", 117 | " )\n", 118 | " (1): BasicBlock(\n", 119 | " (conv1): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 120 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 121 | " (relu): ReLU(inplace=True)\n", 122 | " (conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 123 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 124 | " )\n", 125 | " )\n", 126 | " (layer4): Sequential(\n", 127 | " (0): BasicBlock(\n", 128 | " (conv1): QConv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 129 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 130 | " (relu): ReLU(inplace=True)\n", 131 | " (conv2): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 132 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 133 | " (downsample): Sequential(\n", 134 | " (0): QConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 135 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 136 | " )\n", 137 | " )\n", 138 | " (1): BasicBlock(\n", 139 | " (conv1): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 140 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 
141 | " (relu): ReLU(inplace=True)\n", 142 | " (conv2): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 143 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 144 | " )\n", 145 | " )\n", 146 | " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", 147 | " (fc): Linear(in_features=512, out_features=1000, bias=True)\n", 148 | ")" 149 | ] 150 | }, 151 | "execution_count": 4, 152 | "metadata": {}, 153 | "output_type": "execute_result" 154 | } 155 | ], 156 | "source": [ 157 | "model" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "name": "stdout", 167 | "output_type": "stream", 168 | "text": [ 169 | "##########################################################################################\n", 170 | "Index| Layer (type) | Bottoms Output Shape Param # \n", 171 | "---------------------------------------------------------------------------\n", 172 | " 0| Data | [(1, 3, 224, 224)] 0 \n", 173 | "---------------------------------------------------------------------------\n", 174 | " 1| QConv2d_1 | Data [(1, 64, 112, 112)] 9408 \n", 175 | "---------------------------------------------------------------------------\n", 176 | " 2| BatchNorm2d_2 | QConv2d_1 [(1, 64, 112, 112)] 64 \n", 177 | "---------------------------------------------------------------------------\n", 178 | " 3| ReLU_3 | BatchNorm2d_2 [(1, 64, 112, 112)] 0 \n", 179 | "---------------------------------------------------------------------------\n", 180 | " 4| MaxPool2d_4 | ReLU_3 [(1, 64, 56, 56)] 0 \n", 181 | "---------------------------------------------------------------------------\n", 182 | " 5| QConv2d_5 | MaxPool2d_4 [(1, 64, 56, 56)] 36864 \n", 183 | "---------------------------------------------------------------------------\n", 184 | " 6| BatchNorm2d_6 | QConv2d_5 [(1, 64, 56, 56)] 64 \n", 185 | "---------------------------------------------------------------------------\n", 186 | " 7| ReLU_7 | BatchNorm2d_6 [(1, 64, 56, 56)] 0 \n", 187 | "---------------------------------------------------------------------------\n", 188 | " 8| QConv2d_8 | ReLU_7 [(1, 64, 56, 56)] 36864 \n", 189 | "---------------------------------------------------------------------------\n", 190 | " 9| BatchNorm2d_9 | QConv2d_8 [(1, 64, 56, 56)] 64 \n", 191 | "---------------------------------------------------------------------------\n", 192 | " 10| iadd_10 | BatchNorm2d_9 [(1, 64, 56, 56)] 0 \n", 193 | " | | MaxPool2d_4 \n", 194 | "---------------------------------------------------------------------------\n", 195 | " 11| ReLU_11 | iadd_10 [(1, 64, 56, 56)] 0 \n", 196 | "---------------------------------------------------------------------------\n", 197 | " 12| QConv2d_12 | ReLU_11 [(1, 64, 56, 56)] 36864 \n", 198 | "---------------------------------------------------------------------------\n", 199 | " 13| BatchNorm2d_13 | QConv2d_12 [(1, 64, 56, 56)] 64 \n", 200 | "---------------------------------------------------------------------------\n", 201 | " 14| ReLU_14 | BatchNorm2d_13 [(1, 64, 56, 56)] 0 \n", 202 | "---------------------------------------------------------------------------\n", 203 | " 15| QConv2d_15 | ReLU_14 [(1, 64, 56, 56)] 36864 \n", 204 | "---------------------------------------------------------------------------\n", 205 | " 16| BatchNorm2d_16 | QConv2d_15 [(1, 64, 56, 56)] 64 \n", 206 | "---------------------------------------------------------------------------\n", 
207 | " 17| iadd_17 | BatchNorm2d_16 [(1, 64, 56, 56)] 0 \n", 208 | " | | ReLU_11 \n", 209 | "---------------------------------------------------------------------------\n", 210 | " 18| ReLU_18 | iadd_17 [(1, 64, 56, 56)] 0 \n", 211 | "---------------------------------------------------------------------------\n", 212 | " 19| QConv2d_19 | ReLU_18 [(1, 128, 28, 28)] 73728 \n", 213 | "---------------------------------------------------------------------------\n", 214 | " 20| BatchNorm2d_20 | QConv2d_19 [(1, 128, 28, 28)] 128 \n", 215 | "---------------------------------------------------------------------------\n", 216 | " 21| ReLU_21 | BatchNorm2d_20 [(1, 128, 28, 28)] 0 \n", 217 | "---------------------------------------------------------------------------\n", 218 | " 22| QConv2d_22 | ReLU_21 [(1, 128, 28, 28)] 147456 \n", 219 | "---------------------------------------------------------------------------\n", 220 | " 23| BatchNorm2d_23 | QConv2d_22 [(1, 128, 28, 28)] 128 \n", 221 | "---------------------------------------------------------------------------\n", 222 | " 24| QConv2d_24 | ReLU_18 [(1, 128, 28, 28)] 8192 \n", 223 | "---------------------------------------------------------------------------\n", 224 | " 25| BatchNorm2d_25 | QConv2d_24 [(1, 128, 28, 28)] 128 \n", 225 | "---------------------------------------------------------------------------\n", 226 | " 26| iadd_26 | BatchNorm2d_23 [(1, 128, 28, 28)] 0 \n", 227 | " | | BatchNorm2d_25 \n", 228 | "---------------------------------------------------------------------------\n", 229 | " 27| ReLU_27 | iadd_26 [(1, 128, 28, 28)] 0 \n", 230 | "---------------------------------------------------------------------------\n", 231 | " 28| QConv2d_28 | ReLU_27 [(1, 128, 28, 28)] 147456 \n", 232 | "---------------------------------------------------------------------------\n", 233 | " 29| BatchNorm2d_29 | QConv2d_28 [(1, 128, 28, 28)] 128 \n", 234 | "---------------------------------------------------------------------------\n", 235 | " 30| ReLU_30 | BatchNorm2d_29 [(1, 128, 28, 28)] 0 \n", 236 | "---------------------------------------------------------------------------\n", 237 | " 31| QConv2d_31 | ReLU_30 [(1, 128, 28, 28)] 147456 \n", 238 | "---------------------------------------------------------------------------\n", 239 | " 32| BatchNorm2d_32 | QConv2d_31 [(1, 128, 28, 28)] 128 \n", 240 | "---------------------------------------------------------------------------\n", 241 | " 33| iadd_33 | BatchNorm2d_32 [(1, 128, 28, 28)] 0 \n", 242 | " | | ReLU_27 \n", 243 | "---------------------------------------------------------------------------\n", 244 | " 34| ReLU_34 | iadd_33 [(1, 128, 28, 28)] 0 \n", 245 | "---------------------------------------------------------------------------\n", 246 | " 35| QConv2d_35 | ReLU_34 [(1, 256, 14, 14)] 294912 \n", 247 | "---------------------------------------------------------------------------\n", 248 | " 36| BatchNorm2d_36 | QConv2d_35 [(1, 256, 14, 14)] 256 \n", 249 | "---------------------------------------------------------------------------\n", 250 | " 37| ReLU_37 | BatchNorm2d_36 [(1, 256, 14, 14)] 0 \n", 251 | "---------------------------------------------------------------------------\n", 252 | " 38| QConv2d_38 | ReLU_37 [(1, 256, 14, 14)] 589824 \n", 253 | "---------------------------------------------------------------------------\n", 254 | " 39| BatchNorm2d_39 | QConv2d_38 [(1, 256, 14, 14)] 256 \n", 255 | "---------------------------------------------------------------------------\n", 256 | 
" 40| QConv2d_40 | ReLU_34 [(1, 256, 14, 14)] 32768 \n", 257 | "---------------------------------------------------------------------------\n", 258 | " 41| BatchNorm2d_41 | QConv2d_40 [(1, 256, 14, 14)] 256 \n", 259 | "---------------------------------------------------------------------------\n", 260 | " 42| iadd_42 | BatchNorm2d_39 [(1, 256, 14, 14)] 0 \n", 261 | " | | BatchNorm2d_41 \n", 262 | "---------------------------------------------------------------------------\n", 263 | " 43| ReLU_43 | iadd_42 [(1, 256, 14, 14)] 0 \n", 264 | "---------------------------------------------------------------------------\n", 265 | " 44| QConv2d_44 | ReLU_43 [(1, 256, 14, 14)] 589824 \n", 266 | "---------------------------------------------------------------------------\n", 267 | " 45| BatchNorm2d_45 | QConv2d_44 [(1, 256, 14, 14)] 256 \n", 268 | "---------------------------------------------------------------------------\n", 269 | " 46| ReLU_46 | BatchNorm2d_45 [(1, 256, 14, 14)] 0 \n", 270 | "---------------------------------------------------------------------------\n", 271 | " 47| QConv2d_47 | ReLU_46 [(1, 256, 14, 14)] 589824 \n", 272 | "---------------------------------------------------------------------------\n", 273 | " 48| BatchNorm2d_48 | QConv2d_47 [(1, 256, 14, 14)] 256 \n", 274 | "---------------------------------------------------------------------------\n", 275 | " 49| iadd_49 | BatchNorm2d_48 [(1, 256, 14, 14)] 0 \n", 276 | " | | ReLU_43 \n", 277 | "---------------------------------------------------------------------------\n", 278 | " 50| ReLU_50 | iadd_49 [(1, 256, 14, 14)] 0 \n", 279 | "---------------------------------------------------------------------------\n", 280 | " 51| QConv2d_51 | ReLU_50 [(1, 512, 7, 7)] 1179648 \n", 281 | "---------------------------------------------------------------------------\n", 282 | " 52| BatchNorm2d_52 | QConv2d_51 [(1, 512, 7, 7)] 512 \n", 283 | "---------------------------------------------------------------------------\n", 284 | " 53| ReLU_53 | BatchNorm2d_52 [(1, 512, 7, 7)] 0 \n", 285 | "---------------------------------------------------------------------------\n", 286 | " 54| QConv2d_54 | ReLU_53 [(1, 512, 7, 7)] 2359296 \n", 287 | "---------------------------------------------------------------------------\n", 288 | " 55| BatchNorm2d_55 | QConv2d_54 [(1, 512, 7, 7)] 512 \n", 289 | "---------------------------------------------------------------------------\n", 290 | " 56| QConv2d_56 | ReLU_50 [(1, 512, 7, 7)] 131072 \n", 291 | "---------------------------------------------------------------------------\n", 292 | " 57| BatchNorm2d_57 | QConv2d_56 [(1, 512, 7, 7)] 512 \n", 293 | "---------------------------------------------------------------------------\n", 294 | " 58| iadd_58 | BatchNorm2d_55 [(1, 512, 7, 7)] 0 \n", 295 | " | | BatchNorm2d_57 \n", 296 | "---------------------------------------------------------------------------\n", 297 | " 59| ReLU_59 | iadd_58 [(1, 512, 7, 7)] 0 \n", 298 | "---------------------------------------------------------------------------\n", 299 | " 60| QConv2d_60 | ReLU_59 [(1, 512, 7, 7)] 2359296 \n", 300 | "---------------------------------------------------------------------------\n", 301 | " 61| BatchNorm2d_61 | QConv2d_60 [(1, 512, 7, 7)] 512 \n", 302 | "---------------------------------------------------------------------------\n", 303 | " 62| ReLU_62 | BatchNorm2d_61 [(1, 512, 7, 7)] 0 \n", 304 | "---------------------------------------------------------------------------\n", 305 | " 63| QConv2d_63 
| ReLU_62 [(1, 512, 7, 7)] 2359296 \n", 306 | "---------------------------------------------------------------------------\n", 307 | " 64| BatchNorm2d_64 | QConv2d_63 [(1, 512, 7, 7)] 512 \n", 308 | "---------------------------------------------------------------------------\n", 309 | " 65| iadd_65 | BatchNorm2d_64 [(1, 512, 7, 7)] 0 \n", 310 | " | | ReLU_59 \n", 311 | "---------------------------------------------------------------------------\n", 312 | " 66| ReLU_66 | iadd_65 [(1, 512, 7, 7)] 0 \n", 313 | "---------------------------------------------------------------------------\n", 314 | " 67| AdaptiveAvgPool2d_67 | ReLU_66 [(1, 512, 1, 1)] 0 \n", 315 | "---------------------------------------------------------------------------\n", 316 | " 68| torch.flatten_68 | AdaptiveAvgPool2d_67 [(1, 512)] 0 \n", 317 | "---------------------------------------------------------------------------\n", 318 | " 69| Linear_69 | torch.flatten_68 [(1, 1000)] 512000 \n", 319 | "---------------------------------------------------------------------------\n", 320 | "==================================================================================\n", 321 | "Total Trainable params: 11683712 \n", 322 | "Total Non-Trainable params: 0 \n", 323 | "Total params: 11683712 \n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "input_tensor = torch.randn([1, 3, 224, 224]).cuda()\n", 329 | "model = model.cuda()\n", 330 | "transformer.summary(model, input_tensor = input_tensor)" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [] 339 | } 340 | ], 341 | "metadata": { 342 | "kernelspec": { 343 | "display_name": "Python 3", 344 | "language": "python", 345 | "name": "python3" 346 | }, 347 | "language_info": { 348 | "codemirror_mode": { 349 | "name": "ipython", 350 | "version": 3 351 | }, 352 | "file_extension": ".py", 353 | "mimetype": "text/x-python", 354 | "name": "python", 355 | "nbconvert_exporter": "python", 356 | "pygments_lexer": "ipython3", 357 | "version": "3.6.9" 358 | } 359 | }, 360 | "nbformat": 4, 361 | "nbformat_minor": 2 362 | } 363 | -------------------------------------------------------------------------------- /utils/PyTransformer/visualize_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torchvision\n", 12 | "import torchvision.models as models\n", 13 | "\n", 14 | "from transformers.torchTransformer import TorchTransformer" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": { 21 | "scrolled": true 22 | }, 23 | "outputs": [ 24 | { 25 | "data": { 26 | "text/plain": [ 27 | "AlexNet(\n", 28 | " (features): Sequential(\n", 29 | " (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n", 30 | " (1): ReLU(inplace=True)\n", 31 | " (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 32 | " (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", 33 | " (4): ReLU(inplace=True)\n", 34 | " (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 35 | " (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 36 | " (7): ReLU(inplace=True)\n", 37 | " (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 38 
| " (9): ReLU(inplace=True)\n", 39 | " (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 40 | " (11): ReLU(inplace=True)\n", 41 | " (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 42 | " )\n", 43 | " (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n", 44 | " (classifier): Sequential(\n", 45 | " (0): Dropout(p=0.5, inplace=False)\n", 46 | " (1): Linear(in_features=9216, out_features=4096, bias=True)\n", 47 | " (2): ReLU(inplace=True)\n", 48 | " (3): Dropout(p=0.5, inplace=False)\n", 49 | " (4): Linear(in_features=4096, out_features=4096, bias=True)\n", 50 | " (5): ReLU(inplace=True)\n", 51 | " (6): Linear(in_features=4096, out_features=1000, bias=True)\n", 52 | " )\n", 53 | ")" 54 | ] 55 | }, 56 | "execution_count": 2, 57 | "metadata": {}, 58 | "output_type": "execute_result" 59 | } 60 | ], 61 | "source": [ 62 | "model_name = \"alexnet\"\n", 63 | "model = models.__dict__[model_name]()\n", 64 | "model.eval()" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "input_tensor = torch.randn([1, 3, 224, 224])" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "## visualization" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "### without saving image" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 4, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "transformer = TorchTransformer()" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "dot = transformer.visualize(model, input_tensor = input_tensor)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": { 112 | "scrolled": true 113 | }, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "image/svg+xml": [ 118 | "\n", 119 | "\n", 121 | "\n", 123 | "\n", 124 | "\n", 126 | "\n", 127 | "%3\n", 128 | "\n", 129 | "\n", 130 | "\n", 131 | "Data\n", 132 | "\n", 133 | "Data\n", 134 | "Shape: [(1, 3, 224, 224)]\n", 135 | "\n", 136 | "\n", 137 | "\n", 138 | "140437396772176\n", 139 | "\n", 140 | "Conv2d_1\n", 141 | "Bottoms: Data\n", 142 | "Shape: [(1, 64, 55, 55)]\n", 143 | "\n", 144 | "\n", 145 | "\n", 146 | "Data->140437396772176\n", 147 | "\n", 148 | "\n", 149 | "\n", 150 | "\n", 151 | "\n", 152 | "140437396773744\n", 153 | "\n", 154 | "ReLU_2\n", 155 | "Bottoms: Conv2d_1\n", 156 | "Shape: [(1, 64, 55, 55)]\n", 157 | "\n", 158 | "\n", 159 | "\n", 160 | "140437396772176->140437396773744\n", 161 | "\n", 162 | "\n", 163 | "\n", 164 | "\n", 165 | "\n", 166 | "140437394718280\n", 167 | "\n", 168 | "MaxPool2d_3\n", 169 | "Bottoms: ReLU_2\n", 170 | "Shape: [(1, 64, 27, 27)]\n", 171 | "\n", 172 | "\n", 173 | "\n", 174 | "140437396773744->140437394718280\n", 175 | "\n", 176 | "\n", 177 | "\n", 178 | "\n", 179 | "\n", 180 | "140437394761432\n", 181 | "\n", 182 | "Conv2d_4\n", 183 | "Bottoms: MaxPool2d_3\n", 184 | "Shape: [(1, 192, 27, 27)]\n", 185 | "\n", 186 | "\n", 187 | "\n", 188 | "140437394718280->140437394761432\n", 189 | "\n", 190 | "\n", 191 | "\n", 192 | "\n", 193 | "\n", 194 | "140437394760760\n", 195 | "\n", 196 | "ReLU_5\n", 197 | "Bottoms: Conv2d_4\n", 198 | "Shape: [(1, 192, 27, 27)]\n", 199 | "\n", 200 | "\n", 201 | "\n", 202 | "140437394761432->140437394760760\n", 203 | "\n", 204 | "\n", 205 | "\n", 206 | "\n", 207 | "\n", 208 | 
"140437394813616\n", 209 | "\n", 210 | "MaxPool2d_6\n", 211 | "Bottoms: ReLU_5\n", 212 | "Shape: [(1, 192, 13, 13)]\n", 213 | "\n", 214 | "\n", 215 | "\n", 216 | "140437394760760->140437394813616\n", 217 | "\n", 218 | "\n", 219 | "\n", 220 | "\n", 221 | "\n", 222 | "140437394834040\n", 223 | "\n", 224 | "Conv2d_7\n", 225 | "Bottoms: MaxPool2d_6\n", 226 | "Shape: [(1, 384, 13, 13)]\n", 227 | "\n", 228 | "\n", 229 | "\n", 230 | "140437394813616->140437394834040\n", 231 | "\n", 232 | "\n", 233 | "\n", 234 | "\n", 235 | "\n", 236 | "140437394857768\n", 237 | "\n", 238 | "ReLU_8\n", 239 | "Bottoms: Conv2d_7\n", 240 | "Shape: [(1, 384, 13, 13)]\n", 241 | "\n", 242 | "\n", 243 | "\n", 244 | "140437394834040->140437394857768\n", 245 | "\n", 246 | "\n", 247 | "\n", 248 | "\n", 249 | "\n", 250 | "140437394897496\n", 251 | "\n", 252 | "Conv2d_9\n", 253 | "Bottoms: ReLU_8\n", 254 | "Shape: [(1, 256, 13, 13)]\n", 255 | "\n", 256 | "\n", 257 | "\n", 258 | "140437394857768->140437394897496\n", 259 | "\n", 260 | "\n", 261 | "\n", 262 | "\n", 263 | "\n", 264 | "140437394945640\n", 265 | "\n", 266 | "ReLU_10\n", 267 | "Bottoms: Conv2d_9\n", 268 | "Shape: [(1, 256, 13, 13)]\n", 269 | "\n", 270 | "\n", 271 | "\n", 272 | "140437394897496->140437394945640\n", 273 | "\n", 274 | "\n", 275 | "\n", 276 | "\n", 277 | "\n", 278 | "140437394476448\n", 279 | "\n", 280 | "Conv2d_11\n", 281 | "Bottoms: ReLU_10\n", 282 | "Shape: [(1, 256, 13, 13)]\n", 283 | "\n", 284 | "\n", 285 | "\n", 286 | "140437394945640->140437394476448\n", 287 | "\n", 288 | "\n", 289 | "\n", 290 | "\n", 291 | "\n", 292 | "140437394497376\n", 293 | "\n", 294 | "ReLU_12\n", 295 | "Bottoms: Conv2d_11\n", 296 | "Shape: [(1, 256, 13, 13)]\n", 297 | "\n", 298 | "\n", 299 | "\n", 300 | "140437394476448->140437394497376\n", 301 | "\n", 302 | "\n", 303 | "\n", 304 | "\n", 305 | "\n", 306 | "140437394495920\n", 307 | "\n", 308 | "MaxPool2d_13\n", 309 | "Bottoms: ReLU_12\n", 310 | "Shape: [(1, 256, 6, 6)]\n", 311 | "\n", 312 | "\n", 313 | "\n", 314 | "140437394497376->140437394495920\n", 315 | "\n", 316 | "\n", 317 | "\n", 318 | "\n", 319 | "\n", 320 | "140437394496424\n", 321 | "\n", 322 | "AdaptiveAvgPool2d_14\n", 323 | "Bottoms: MaxPool2d_13\n", 324 | "Shape: [(1, 256, 6, 6)]\n", 325 | "\n", 326 | "\n", 327 | "\n", 328 | "140437394495920->140437394496424\n", 329 | "\n", 330 | "\n", 331 | "\n", 332 | "\n", 333 | "\n", 334 | "torch.flatten_15\n", 335 | "\n", 336 | "torch.flatten_15\n", 337 | "Bottoms: AdaptiveAvgPool2d_14\n", 338 | "Shape: [(1, 9216)]\n", 339 | "\n", 340 | "\n", 341 | "\n", 342 | "140437394496424->torch.flatten_15\n", 343 | "\n", 344 | "\n", 345 | "\n", 346 | "\n", 347 | "\n", 348 | "140437394496928\n", 349 | "\n", 350 | "Dropout_16\n", 351 | "Bottoms: torch.flatten_15\n", 352 | "Shape: [(1, 9216)]\n", 353 | "\n", 354 | "\n", 355 | "\n", 356 | "torch.flatten_15->140437394496928\n", 357 | "\n", 358 | "\n", 359 | "\n", 360 | "\n", 361 | "\n", 362 | "140437396656704\n", 363 | "\n", 364 | "Linear_17\n", 365 | "Bottoms: Dropout_16\n", 366 | "Shape: [(1, 4096)]\n", 367 | "\n", 368 | "\n", 369 | "\n", 370 | "140437394496928->140437396656704\n", 371 | "\n", 372 | "\n", 373 | "\n", 374 | "\n", 375 | "\n", 376 | "140437396693792\n", 377 | "\n", 378 | "ReLU_18\n", 379 | "Bottoms: Linear_17\n", 380 | "Shape: [(1, 4096)]\n", 381 | "\n", 382 | "\n", 383 | "\n", 384 | "140437396656704->140437396693792\n", 385 | "\n", 386 | "\n", 387 | "\n", 388 | "\n", 389 | "\n", 390 | "140437396625336\n", 391 | "\n", 392 | "Dropout_19\n", 393 | "Bottoms: ReLU_18\n", 394 
| "Shape: [(1, 4096)]\n", 395 | "\n", 396 | "\n", 397 | "\n", 398 | "140437396693792->140437396625336\n", 399 | "\n", 400 | "\n", 401 | "\n", 402 | "\n", 403 | "\n", 404 | "140439190873704\n", 405 | "\n", 406 | "Linear_20\n", 407 | "Bottoms: Dropout_19\n", 408 | "Shape: [(1, 4096)]\n", 409 | "\n", 410 | "\n", 411 | "\n", 412 | "140437396625336->140439190873704\n", 413 | "\n", 414 | "\n", 415 | "\n", 416 | "\n", 417 | "\n", 418 | "140437394655440\n", 419 | "\n", 420 | "ReLU_21\n", 421 | "Bottoms: Linear_20\n", 422 | "Shape: [(1, 4096)]\n", 423 | "\n", 424 | "\n", 425 | "\n", 426 | "140439190873704->140437394655440\n", 427 | "\n", 428 | "\n", 429 | "\n", 430 | "\n", 431 | "\n", 432 | "140439262563408\n", 433 | "\n", 434 | "Linear_22\n", 435 | "Bottoms: ReLU_21\n", 436 | "Shape: [(1, 1000)]\n", 437 | "\n", 438 | "\n", 439 | "\n", 440 | "140437394655440->140439262563408\n", 441 | "\n", 442 | "\n", 443 | "\n", 444 | "\n", 445 | "\n" 446 | ], 447 | "text/plain": [ 448 | "" 449 | ] 450 | }, 451 | "execution_count": 6, 452 | "metadata": {}, 453 | "output_type": "execute_result" 454 | } 455 | ], 456 | "source": [ 457 | "dot" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "metadata": {}, 463 | "source": [ 464 | "### save image" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 7, 470 | "metadata": {}, 471 | "outputs": [], 472 | "source": [ 473 | "model.cuda()\n", 474 | "input_tensor = input_tensor.cuda()" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 8, 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "dot = transformer.visualize(model, input_tensor = input_tensor, save_name = \"examples/{}\".format(model_name), graph_size = 80)" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [] 492 | } 493 | ], 494 | "metadata": { 495 | "kernelspec": { 496 | "display_name": "Python 3", 497 | "language": "python", 498 | "name": "python3" 499 | }, 500 | "language_info": { 501 | "codemirror_mode": { 502 | "name": "ipython", 503 | "version": 3 504 | }, 505 | "file_extension": ".py", 506 | "mimetype": "text/x-python", 507 | "name": "python", 508 | "nbconvert_exporter": "python", 509 | "pygments_lexer": "ipython3", 510 | "version": "3.6.9" 511 | } 512 | }, 513 | "nbformat": 4, 514 | "nbformat_minor": 2 515 | } 516 | --------------------------------------------------------------------------------