├── .gitignore
├── ImageNet
│   ├── data_loader.py
│   ├── experiments
│   │   ├── resnet101_ItN.sh
│   │   ├── resnet101_ItN_DF.sh
│   │   ├── resnet18_ItN.sh
│   │   ├── resnet18_ItN_DF.sh
│   │   ├── resnet50_ItN.sh
│   │   └── resnet50_ItN_DF.sh
│   ├── imagenet.py
│   └── models
│       ├── __init__.py
│       ├── alexnet.py
│       ├── densenet.py
│       ├── inception.py
│       ├── resnet.py
│       ├── squeezenet.py
│       └── vgg.py
├── LICENSE
├── README.md
├── cifar10
│   ├── cifar10.py
│   ├── experiments
│   │   ├── vgg
│   │   │   ├── vgg_LargeLR_BN.sh
│   │   │   ├── vgg_LargeLR_ItN.sh
│   │   │   ├── vgg_b1024_BN.sh
│   │   │   ├── vgg_b1024_ItN.sh
│   │   │   ├── vgg_b16_BN.sh
│   │   │   ├── vgg_b16_ItN.sh
│   │   │   ├── vgg_base_BN.sh
│   │   │   └── vgg_base_ItN.sh
│   │   └── wrn
│   │       ├── wrn_28_10_BN.sh
│   │       ├── wrn_28_10_ItN.sh
│   │       ├── wrn_40_10_BN.sh
│   │       └── wrn_40_10_ItN.sh
│   ├── mnist.py
│   └── models
│       ├── WRN.py
│       ├── __init__.py
│       ├── resnet.py
│       └── vgg.py
└── extension
    ├── __init__.py
    ├── checkpoint.py
    ├── dataset.py
    ├── layers
    │   ├── __init__.py
    │   ├── scale.py
    │   ├── sequential.py
    │   └── view.py
    ├── logger.py
    ├── normailzation
    │   ├── __init__.py
    │   ├── center_normalization.py
    │   ├── dbn.py
    │   ├── group_batch_normalization.py
    │   ├── iterative_normalization.py
    │   ├── iterative_normalization_FlexGroup.py
    │   └── normailzation.py
    ├── optimizer.py
    ├── progress_bar.py
    ├── scheduler.py
    ├── test
    │   ├── IterNorm_test.py
    │   └── test_util.py
    ├── trainer.py
    ├── utils.py
    └── visualization.py

/.gitignore:
--------------------------------------------------------------------------------
1 | /extension/src/**
2 | !/extension/src/**/*.cpp
3 | !/extension/src/**/*.cu
4 | !/extension/src/**/*.cuh
5 | !/extension/src/**/*.hpp
6 | !/extension/src/**/*.py
7 | !/extension/src/test
8 |
9 | temp.py
10 |
11 | # compilation and distribution
12 | __pycache__
13 | _ext
14 | *.pyc
15 | *.so
16 | *.a
17 | *.exe
18 | maskrcnn_benchmark.egg-info/
19 | build/
20 | dist/
21 | results/
22 |
23 | # pytorch/python/numpy formats
24 | *.pth
25 | *.pkl
26 | *.npy
27 |
28 | # ipython/jupyter notebooks
29 | *.ipynb
30 | **/.ipynb_checkpoints/
31 |
32 | # Editor temporaries
33 | *.swn
34 | *.swo
35 | *.swp
36 | *~
37 |
38 | # PyCharm/VS Code editor settings
39 | .idea
40 | .vscode
41 | .DS_Store
42 | .pytest*
43 |
44 | # CMake
45 | CMakeLists.txt.user
46 | CMakeCache.txt
47 | CMakeFiles
48 | CMakeScripts
49 | Testing
50 | Makefile
51 | cmake_install.cmake
52 | install_manifest.txt
53 | compile_commands.json
54 | CTestTestfile.cmake
--------------------------------------------------------------------------------
/ImageNet/data_loader.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 | import warnings
6 |
7 | import extension as ext
8 | import torchvision.transforms as transforms
9 |
10 | has_DALI = True
11 | try:
12 |     from nvidia.dali.plugin.pytorch import DALIClassificationIterator
13 |     from nvidia.dali.pipeline import Pipeline
14 |     import nvidia.dali.ops as ops
15 |     import nvidia.dali.types as types
16 | except ImportError:
17 |     warnings.warn("Please install DALI from https://www.github.com/NVIDIA/DALI to enable the DALI data loader")
18 |     has_DALI = False
19 |     Pipeline = object
20 |     DALIClassificationIterator = object
21 |
22 |
23 | def add_arguments(parser: argparse.ArgumentParser):
24 |     group = ext.dataset.add_arguments(parser)
25 |     group.add_argument('--dali', default=has_DALI, type=ext.utils.str2bool, metavar='BOOL',
26 |                        help="Use NVIDIA DALI to accelerate data loading.")
27 |     group.set_defaults(dataset='ImageNet', batch_size=[256, 200])
28 |     return group
29 |
30 |
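# The two pipelines below follow the standard DALI pattern: declare operators in
# __init__, connect them in define_graph(), then build() the pipeline before use.
# A minimal standalone sketch (assuming DALI is installed; the path and sizes
# here are illustrative only):
#
#     pipe = HybridTrainPipe(batch_size=256, num_threads=4, device_id=0,
#                            data_dir='/data/ImageNet/train', crop=224)
#     pipe.build()
#     loader = ImageNetDataLoader([pipe], size=pipe.epoch_size("Reader"), auto_reset=True)
#     inputs, targets = next(iter(loader))   # NCHW float tensor on GPU, int64 labels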
31 | class HybridTrainPipe(Pipeline):
32 |     def __init__(self, batch_size, num_threads, device_id, data_dir, crop, dali_cpu=False, local_rank=0, world_size=1):
33 |         super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
34 |         self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, random_shuffle=True)
35 |         # let the user decide which pipeline (CPU or GPU decoding) works best for the network being run
36 |         scale = [0.08, 1.0]
37 |         ratio = [3. / 4., 4. / 3.]
38 |         if dali_cpu:
39 |             dali_device = "cpu"
40 |             self.decode = ops.HostDecoderRandomCrop(device=dali_device, output_type=types.RGB,
41 |                                                     random_aspect_ratio=ratio, random_area=scale, num_attempts=100)
42 |         else:
43 |             dali_device = "gpu"
44 |             # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from
45 |             # full-sized ImageNet without additional reallocations
46 |             self.decode = ops.nvJPEGDecoderRandomCrop(device="mixed", output_type=types.RGB,
47 |                                                       device_memory_padding=211025920, host_memory_padding=140544512,
48 |                                                       random_aspect_ratio=ratio, random_area=scale, num_attempts=100)
49 |         self.res = ops.Resize(device=dali_device, resize_x=crop, resize_y=crop, interp_type=types.INTERP_TRIANGULAR)
50 |         self.cmnp = ops.CropMirrorNormalize(device="gpu", output_dtype=types.FLOAT, output_layout=types.NCHW,
51 |                                             crop=(crop, crop), image_type=types.RGB,
52 |                                             mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
53 |                                             std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
54 |         self.coin = ops.CoinFlip(probability=0.5)
55 |         print('DALI "{0}" variant'.format(dali_device))
56 |
57 |     def define_graph(self):
58 |         rng = self.coin()
59 |         self.jpegs, self.labels = self.input(name="Reader")
60 |         images = self.decode(self.jpegs)
61 |         images = self.res(images)
62 |         output = self.cmnp(images.gpu(), mirror=rng)
63 |         return [output, self.labels]
64 |
65 |
66 | class HybridValPipe(Pipeline):
67 |     def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, local_rank=0, world_size=1):
68 |         super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
69 |         self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size,
70 |                                     random_shuffle=False)
71 |         self.decode = ops.nvJPEGDecoder(device="mixed", output_type=types.RGB)
72 |         self.res = ops.Resize(device="gpu", resize_shorter=size, interp_type=types.INTERP_TRIANGULAR)
73 |         self.cmnp = ops.CropMirrorNormalize(device="gpu", output_dtype=types.FLOAT, output_layout=types.NCHW,
74 |                                             crop=(crop, crop), image_type=types.RGB,
75 |                                             mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
76 |                                             std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
77 |
78 |     def define_graph(self):
79 |         self.jpegs, self.labels = self.input(name="Reader")
80 |         images = self.decode(self.jpegs)
81 |         images = self.res(images)
82 |         output = self.cmnp(images)
83 |         return [output, self.labels]
84 |
85 |
86 | class ImageNetDataLoader(DALIClassificationIterator):
87 |     def __next__(self):
88 |         data = super(ImageNetDataLoader, self).__next__()
89 |         if isinstance(data, list):
90 |             inputs = data[0]["data"]
91 |             targets = data[0]["label"].squeeze().long()
92 |             return inputs, targets
93 |         else:
94 |             return data
95 |
96 |     def __len__(self):
97 |         return int(self._size / self.batch_size)
98 |
99 |
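# DALIClassificationIterator yields a list with one dict per pipeline, of the form
# {"data": images, "label": labels}; the subclass above unwraps this into the plain
# (inputs, targets) pair the training loop expects, so the DALI loader and the
# torchvision loader below stay interchangeable downstream.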
100 | def dali_loader(args, test=False, local_rank=0, world_size=1):
101 |     logger = ext.get_logger()
102 |     args.dataset_root = os.path.expanduser(args.dataset_root)
103 |     root = os.path.join(args.dataset_root, args.dataset)
104 |     assert os.path.exists(root), 'Please assign the correct dataset root path with --dataset-root'
105 |     train_dir = os.path.join(root, 'train')
106 |     val_dir = os.path.join(root, 'val')
107 |     crop_size = 224
108 |     if len(args.batch_size) == 0:
109 |         args.batch_size = [256, 200]
110 |     elif len(args.batch_size) == 1:
111 |         args.batch_size.append(args.batch_size[0])
112 |     if test:
113 |         train_loader = None
114 |     else:
115 |         logger('==> Load ImageNet train dataset:')
116 |         pipe = HybridTrainPipe(batch_size=args.batch_size[0], num_threads=args.workers, device_id=local_rank,
117 |                                data_dir=train_dir, crop=crop_size)
118 |         pipe.build()
119 |         train_loader = ImageNetDataLoader([pipe], size=int(pipe.epoch_size("Reader") / world_size), auto_reset=True,
120 |                                           stop_at_epoch=True)
121 |
122 |     logger('==> Load ImageNet val dataset:')
123 |     pipe = HybridValPipe(batch_size=args.batch_size[1], num_threads=args.workers, device_id=local_rank,
124 |                          data_dir=val_dir, crop=crop_size, size=256)
125 |     pipe.build()
126 |     val_loader = ImageNetDataLoader([pipe], size=int(pipe.epoch_size("Reader") / world_size), auto_reset=True)
127 |     return train_loader, val_loader
128 |
129 |
130 | def set_dataset(cfg, test=False):
131 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
132 |     train_transform = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
133 |                        normalize, ]
134 |     val_transform = [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ]
135 |     if not test:
136 |         train_loader = ext.dataset.get_dataset_loader(cfg, train_transform, None, True)
137 |     else:
138 |         train_loader = None
139 |     val_loader = ext.dataset.get_dataset_loader(cfg, val_transform, None, False)
140 |     return train_loader, val_loader
141 |
142 |
143 | def setting(cfg, test=False):
144 |     if has_DALI and cfg.dali:
145 |         return dali_loader(cfg, test=test)
146 |     return set_dataset(cfg, test=test)
147 |
148 |
149 | if __name__ == '__main__':
150 |     parser = argparse.ArgumentParser('Test Data Loader')
151 |     add_arguments(parser)
152 |     args = parser.parse_args()
153 |     print('==> args: ', args)
154 |     train_loader_, val_loader_ = dali_loader(args)
155 |     print('len of train_loader', len(train_loader_))
156 |     total = 0
157 |     start_time = time.time()
158 |     for i, (inputs, targets) in enumerate(train_loader_, 1):
159 |         # inputs = data[0]["data"]
160 |         # targets = data[0]["label"].squeeze().cuda().long()
161 |         total += targets.size(0)
162 |         print('Load train data [{}/{}]: {}, {}'.format(i, len(train_loader_), inputs.size(), targets.size()), end='\r')
163 |     print('\nTrain Read {} images, use {:.2f}s'.format(total, time.time() - start_time))
164 |
165 |     print('len of val_loader', len(val_loader_))
166 |     total = 0
167 |     start_time = time.time()
168 |     for i, (inputs, targets) in enumerate(val_loader_, 1):
169 |         # inputs = data[0]["data"]
170 |         # targets = data[0]["label"].squeeze().cuda().long()
171 |         total += targets.size(0)
172 |         print('Load val data [{}/{}]: {}, {}'.format(i, len(val_loader_), inputs.size(), targets.size()), end='\r')
173 |     print('\nVal Read {} images, use {:.2f}s'.format(total, time.time() - start_time))
174 |
--------------------------------------------------------------------------------
/ImageNet/experiments/resnet101_ItN.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd "$(dirname $0)/.."
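# Flag conventions (defined in imagenet.py and the extension package): -a selects
# the model from ImageNet/models, -oo/-oc choose the optimizer and its config dict,
# --norm=ItN selects IterNorm, and --norm-cfg=T=5,num_channels=64 presumably sets
# the IterNorm iteration count T and the per-group channel size.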
3 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 imagenet.py \ 4 | -a=resnet101 \ 5 | --arch-cfg=last=False \ 6 | --batch-size=256 \ 7 | --epochs=100 \ 8 | -oo=sgd \ 9 | -oc=momentum=0.9 \ 10 | -wd=1e-4 \ 11 | --lr=0.1 \ 12 | --lr-method=step \ 13 | --lr-steps=30 \ 14 | --lr-gamma=0.1 \ 15 | --dataset-root=/data/lei/imageNet/input_torch/ \ 16 | --dataset=folder \ 17 | --norm=ItN \ 18 | --norm-cfg=T=5,num_channels=64 \ 19 | --log-suffix=ItN \ 20 | $@ 21 | -------------------------------------------------------------------------------- /ImageNet/experiments/resnet101_ItN_DF.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/.." 3 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 imagenet.py \ 4 | -a=resnet101 \ 5 | --arch-cfg=last=True \ 6 | --batch-size=256 \ 7 | --epochs=100 \ 8 | -oo=sgd \ 9 | -oc=momentum=0.9 \ 10 | -wd=1e-4 \ 11 | --lr=0.1 \ 12 | --lr-method=step \ 13 | --lr-steps=30 \ 14 | --lr-gamma=0.1 \ 15 | --dataset-root=/data/lei/imageNet/input_torch/ \ 16 | --dataset=folder \ 17 | --norm=ItN \ 18 | --norm-cfg=T=5,num_channels=64 \ 19 | --log-suffix=DF \ 20 | $@ 21 | -------------------------------------------------------------------------------- /ImageNet/experiments/resnet18_ItN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/.." 3 | CUDA_VISIBLE_DEVICES=0 python3 imagenet.py \ 4 | -a=resnet18 \ 5 | --arch-cfg=last=False \ 6 | --batch-size=256 \ 7 | --epochs=100 \ 8 | -oo=sgd \ 9 | -oc=momentum=0.9 \ 10 | -wd=1e-4 \ 11 | --lr=0.1 \ 12 | --lr-method=step \ 13 | --lr-steps=30 \ 14 | --lr-gamma=0.1 \ 15 | --dataset-root=/data/lei/imageNet/input_torch/ \ 16 | --dataset=folder \ 17 | --norm=ItN \ 18 | --norm-cfg=T=5,num_channels=64 \ 19 | $@ 20 | -------------------------------------------------------------------------------- /ImageNet/experiments/resnet18_ItN_DF.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/.." 3 | CUDA_VISIBLE_DEVICES=0 python3 imagenet.py \ 4 | -a=resnet18 \ 5 | --arch-cfg=last=True \ 6 | --batch-size=256 \ 7 | --epochs=100 \ 8 | -oo=sgd \ 9 | -oc=momentum=0.9 \ 10 | -wd=1e-4 \ 11 | --lr=0.1 \ 12 | --lr-method=step \ 13 | --lr-steps=30 \ 14 | --lr-gamma=0.1 \ 15 | --dataset-root=/data/lei/imageNet/input_torch/ \ 16 | --dataset=folder \ 17 | --norm=ItN \ 18 | --norm-cfg=T=5,num_channels=64 \ 19 | $@ 20 | -------------------------------------------------------------------------------- /ImageNet/experiments/resnet50_ItN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/.." 3 | CUDA_VISIBLE_DEVICES=4,5,6,7 python3 imagenet.py \ 4 | -a=resnet50 \ 5 | -ac=last_bn=False \ 6 | --arch-cfg=dropout=0.3 \ 7 | --batch-size=256 \ 8 | --epochs=100 \ 9 | -oo=sgd \ 10 | -oc=momentum=0.9 \ 11 | -wd=1e-4 \ 12 | --lr=0.1 \ 13 | --lr-method=step \ 14 | --lr-steps=30 \ 15 | --lr-gamma=0.1 \ 16 | --dataset-root=/data/lei/imageNet/input_torch/ \ 17 | --dataset=folder \ 18 | --norm=ItN \ 19 | --norm-cfg=T=5,num_channels=64 \ 20 | $@ 21 | -------------------------------------------------------------------------------- /ImageNet/experiments/resnet50_ItN_DF.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/.." 
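# Note: the _DF variants differ from the base scripts mainly in --arch-cfg=last=True
# (plus --log-suffix=DF); per models/resnet.py, last=True inserts an extra
# normalization layer between the global average pooling and the final classifier.
# The base resnet50_ItN.sh passes both -ac=last_bn=False and --arch-cfg=dropout=0.3;
# since the two spellings share one argparse destination, the later dict appears to
# win, and resnet.py reads the key 'last' (not 'last_bn'), so last=False applies
# there either way.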
3 | CUDA_VISIBLE_DEVICES=4,5,6,7 python3 imagenet.py \
4 |     -a=resnet50 \
5 |     --arch-cfg=last=True \
6 |     --batch-size=256 \
7 |     --epochs=100 \
8 |     -oo=sgd \
9 |     -oc=momentum=0.9 \
10 |     -wd=1e-4 \
11 |     --lr=0.1 \
12 |     --lr-method=step \
13 |     --lr-steps=30 \
14 |     --lr-gamma=0.1 \
15 |     --dataset-root=/data/lei/imageNet/input_torch/ \
16 |     --dataset=folder \
17 |     --norm=ItN \
18 |     --norm-cfg=T=5,num_channels=64 \
19 |     --log-suffix=DF \
20 |     $@
21 |
--------------------------------------------------------------------------------
/ImageNet/imagenet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import time
4 | import argparse
5 | import shutil
6 | import sys
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.optim
12 | import torch.utils.data
13 | import torch.utils.data.distributed
14 | import numpy as np
15 |
16 | sys.path.append('..')
17 | import models
18 | import data_loader
19 | import extension as ext
20 | from extension.progress_bar import format_time
21 |
22 |
23 | class ClassificationLarge:
24 |     def __init__(self):
25 |         self.args = self.add_arguments()
26 |         self.best_prec1 = 0
27 |         self.model_name = self.args.arch + ext.normailzation.setting(self.args) + '_{}'.format(self.args.optimizer)
28 |         if not self.args.resume:
29 |             self.args.output = os.path.join(self.args.output, self.model_name, self.args.log_suffix)
30 |         self.logger = ext.logger.setting('log.txt', self.args.output, self.args.test, bool(self.args.resume))
31 |         ext.trainer.setting(self.args)
32 |         self.model = models.__dict__[self.args.arch](**self.args.arch_cfg)
33 |         self.logger('==> Model [{}]: {}'.format(self.model_name, self.model))
34 |         self.optimizer = ext.optimizer.setting(self.model, self.args)
35 |         self.scheduler = ext.scheduler.setting(self.optimizer, self.args)
36 |         self.device = torch.device('cuda')
37 |         self.num_gpus = torch.cuda.device_count()
38 |         self.logger('==> The number of gpus: {}'.format(self.num_gpus))
39 |         self.saver = ext.checkpoint.Checkpoint(self.model, self.args, self.optimizer, self.scheduler, self.args.output,
40 |                                                not self.args.test)
41 |         self.saver.load(self.args.load)
42 |         if self.num_gpus > 1:
43 |             self.model = torch.nn.DataParallel(self.model).cuda()
44 |         self.model = self.model.cuda()
45 |         if self.args.resume:
46 |             saved = self.saver.resume(self.args.resume)
47 |             self.args.start_epoch = saved['epoch']
48 |             self.best_prec1 = saved['best_prec1']
49 |
50 |         self.train_loader, self.val_loader = data_loader.setting(self.args, self.args.test)
51 |         self.criterion = nn.CrossEntropyLoss().to(self.device)
52 |         self.vis = ext.visualization.setting(self.args, self.model_name,
53 |                                              {'train loss': 'loss', 'train top-1': 'accuracy',
54 |                                               'train top-5': 'accuracy', 'test loss': 'loss', 'test top-1': 'accuracy',
55 |                                               'test top-5': 'accuracy', 'epoch loss': 'epoch_loss',
56 |                                               'loss average': 'epoch_loss'})
57 |         return
58 |
59 |     def add_arguments(self):
60 |         parser = argparse.ArgumentParser('ImageNet Classification')
61 |         ext.normailzation.add_arguments(parser)
62 |         data_loader.add_arguments(parser)
63 |         ext.trainer.add_arguments(parser)
64 |         ext.optimizer.add_arguments(parser)
65 |         ext.visualization.add_arguments(parser)
66 |         ext.logger.add_arguments(parser)
67 |         ext.checkpoint.add_arguments(parser)
68 |         ext.scheduler.add_arguments(parser)
69 |         model_names = sorted(name for name in models.__dict__ if
70 |                              name.islower() and not name.startswith("__") and
callable(models.__dict__[name])) 71 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=model_names, 72 | help='model architecture: {' + ' | '.join(model_names) + '} (default: resnet18)') 73 | 74 | parser.add_argument('-ac', '--arch-cfg', metavar='DICT', default={}, type=ext.utils.str2dict, 75 | help='The extra configure for model architecture') 76 | parser.set_defaults(dataset='ImageNet') 77 | parser.set_defaults(lr_method='step', lr=0.1, lr_step=30, weight_decay=1e-4) 78 | parser.set_defaults(epochs=90) 79 | parser.set_defaults(workers=10) 80 | args = parser.parse_args() 81 | if args.resume is not None: 82 | args = parser.parse_args(namespace=ext.checkpoint.Checkpoint.load_config(args.resume)) 83 | return args 84 | 85 | def train(self): 86 | if self.args.test: 87 | self.validate() 88 | return 89 | self.logger('\n++++++++++++++++++ Begin Train ++++++++++++++++++') 90 | used_times = [] 91 | for epoch in range(self.args.start_epoch + 1, self.args.epochs): 92 | epoch_start_time = time.time() 93 | # adjust_learning_rate(optimizer, epoch) 94 | if self.args.lr_method != 'auto': 95 | self.scheduler.step(epoch) 96 | self.logger('Model {} [{}/{}]: lr={:.3g}, weight decay={:.2g}, time: {}'.format(self.model_name, epoch, 97 | self.args.epochs, 98 | self.optimizer.param_groups[ 99 | 0]['lr'], 100 | self.optimizer.param_groups[ 101 | 0]['weight_decay'], 102 | time.asctime())) 103 | # train for one epoch 104 | self.train_epoch(epoch) 105 | 106 | # evaluate on validation set 107 | prec1, val_loss = self.validate(epoch) 108 | 109 | if self.args.lr_method == 'auto': 110 | self.scheduler.step(val_loss, epoch) 111 | # remember best prec@1 and save checkpoint 112 | is_best = prec1 > self.best_prec1 113 | self.best_prec1 = max(prec1, self.best_prec1) 114 | self.saver.save_checkpoint('checkpoint.pth', epoch=epoch, best_prec1=self.best_prec1, arch=self.args.arch) 115 | if is_best: 116 | self.saver.save_model('best.pth') 117 | used_times.append(time.time() - epoch_start_time) 118 | self.logger('Epoch [{}/{}] use: {}, average: {}, expect: {}\n'.format(epoch, self.args.epochs, 119 | format_time(used_times[-1]), 120 | format_time(np.mean(used_times)), 121 | format_time(( 122 | self.args.epochs - 1 - epoch) * np.mean( 123 | used_times)))) 124 | 125 | now_date = time.strftime("%y-%m-%d_%H:%M:%S", time.localtime(time.time())) 126 | new_log_filename = '{}_{}_{:.2f}%.txt'.format(self.model_name, now_date, self.best_prec1) 127 | self.logger('\n==> Network training completed. 
Copy log file to {}'.format(new_log_filename)) 128 | shutil.copy(self.logger.filename, os.path.join(self.args.output, new_log_filename)) 129 | return 130 | 131 | def train_epoch(self, epoch): 132 | batch_time = AverageMeter() 133 | data_time = AverageMeter() 134 | losses = AverageMeter() 135 | top1 = AverageMeter() 136 | top5 = AverageMeter() 137 | 138 | # switch to train mode 139 | self.model.train() 140 | self.vis.clear('epoch_loss') 141 | end = time.time() 142 | for i, (inputs, targets) in enumerate(self.train_loader, 1): 143 | # measure data loading time 144 | # if self.args.gpu is not None: 145 | inputs = inputs.cuda(non_blocking=True) 146 | targets = targets.cuda(non_blocking=True) 147 | data_time.update(time.time() - end) 148 | 149 | # compute output 150 | output = self.model(inputs) 151 | loss = self.criterion(output, targets) 152 | 153 | # measure accuracy and record loss 154 | prec1, prec5 = accuracy(output, targets, topk=(1, 5)) 155 | losses.update(loss.item(), inputs.size(0)) 156 | top1.update(prec1[0], inputs.size(0)) 157 | top5.update(prec5[0], inputs.size(0)) 158 | 159 | # compute gradient and do SGD step 160 | self.optimizer.zero_grad() 161 | loss.backward() 162 | self.optimizer.step() 163 | 164 | # measure elapsed time 165 | torch.cuda.synchronize() 166 | batch_time.update(time.time() - end) 167 | end = time.time() 168 | 169 | # logger 170 | is_log = i % self.args.print_f == 0 or i == len(self.train_loader) 171 | self.logger('Epoch: [{0}][{1:5d}/{2:5d}] ' 172 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 173 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' 174 | 'Loss {loss.val:.2f} ({loss.avg:.2f}) ' 175 | 'Prec@1 {top1.val:5.2f} ({top1.avg:5.2f}) ' 176 | 'Prec@5 {top5.val:5.2f} ({top5.avg:5.2f}) ' 177 | ''.format(epoch, i, len(self.train_loader), batch_time=batch_time, data_time=data_time, 178 | loss=losses, top1=top1, top5=top5), end='\n' if is_log else '\r', is_log=is_log) 179 | if is_log: 180 | self.vis.add_value('epoch loss', losses.val) 181 | self.vis.add_value('loss average', losses.ravg) 182 | 183 | self.vis.add_value('train loss', losses.avg) 184 | self.vis.add_value('train top-1', top1.avg) 185 | self.vis.add_value('train top-5', top5.avg) 186 | return 187 | 188 | def validate(self, epoch=-1): 189 | batch_time = AverageMeter() 190 | losses = AverageMeter() 191 | top1 = AverageMeter() 192 | top5 = AverageMeter() 193 | 194 | # switch to evaluate mode 195 | self.model.eval() 196 | 197 | with torch.no_grad(): 198 | end = time.time() 199 | for i, (inputs, targets) in enumerate(self.val_loader, 1): 200 | # if self.args.gpu is not None: 201 | # inputs = inputs.cuda(self.args.gpu, non_blocking=True) 202 | inputs = inputs.cuda(non_blocking=True) 203 | targets = targets.cuda(None, non_blocking=True) 204 | 205 | # compute output 206 | output = self.model(inputs) 207 | loss = self.criterion(output, targets) 208 | 209 | # measure accuracy and record loss 210 | prec1, prec5 = accuracy(output, targets, topk=(1, 5)) 211 | losses.update(loss.item(), inputs.size(0)) 212 | top1.update(prec1[0], inputs.size(0)) 213 | top5.update(prec5[0], inputs.size(0)) 214 | 215 | # measure elapsed time 216 | torch.cuda.synchronize() 217 | batch_time.update(time.time() - end) 218 | end = time.time() 219 | 220 | is_log = i % self.args.print_f == 0 or i == len(self.val_loader) 221 | self.logger('Test: [{0:3d}/{1:3d}] ' 222 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 223 | 'Loss {loss.val:.4f} ({loss.avg:.4f}) ' 224 | 'Prec@1 {top1.val:5.2f} ({top1.avg:5.2f}) ' 225 | 'Prec@5 
{top5.val:5.2f} ({top5.avg:5.2f}) '
226 |                             ''.format(i, len(self.val_loader), batch_time=batch_time, loss=losses, top1=top1,
227 |                                       top5=top5), end='\n' if is_log else '\r', is_log=is_log)
228 |
229 |         self.logger(' * Prec@1 {top1.avg:5.2f} Prec@5 {top5.avg:5.2f} on epoch {epoch}'.format(top1=top1, top5=top5,
230 |                                                                                                epoch=epoch))
231 |
232 |         self.vis.add_value('test loss', losses.avg)
233 |         self.vis.add_value('test top-1', top1.avg)
234 |         self.vis.add_value('test top-5', top5.avg)
235 |         return top1.avg, losses.avg
236 |
237 |
238 | class AverageMeter(object):
239 |     """Computes and stores the average and current value"""
240 |
241 |     def __init__(self, momentum=0.9):
242 |         self.val = 0
243 |         self.avg = 0
244 |         self.sum = 0
245 |         self.count = 0
246 |         self.ravg = 0
247 |         self.momentum = momentum
248 |
249 |     def reset(self):
250 |         self.ravg = 0
251 |         self.val = 0
252 |         self.avg = 0
253 |         self.sum = 0
254 |         self.count = 0
255 |
256 |     def update(self, val, n=1):
257 |         self.val = val
258 |         self.sum += val * n
259 |         if self.count == 0:
260 |             self.ravg = self.val
261 |         else:
262 |             self.ravg = self.momentum * self.ravg + (1. - self.momentum) * val
263 |         self.count += n
264 |         self.avg = self.sum / self.count
265 |
266 |
267 | def accuracy(output, target, topk=(1,)):
268 |     """Computes the precision@k for the specified values of k"""
269 |     with torch.no_grad():
270 |         maxk = max(topk)
271 |         batch_size = target.size(0)
272 |
273 |         _, prediction = output.topk(maxk, 1, True, True)
274 |         prediction = prediction.t()
275 |         correct = prediction.eq(target.view(1, -1).expand_as(prediction))
276 |
277 |         res = []
278 |         for k in topk:
279 |             correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
280 |             res.append(correct_k.mul_(100.0 / batch_size))
281 |         return res
282 |
283 |
284 | if __name__ == '__main__':
285 |     ImageNet = ClassificationLarge()
286 |     ImageNet.train()
287 |
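Note: a quick sanity check of the accuracy() helper above (a sketch; the tensor
values are illustrative only):

    import torch
    output = torch.tensor([[0.1, 0.9, 0.0],
                           [0.7, 0.2, 0.1]])  # two samples, three classes
    target = torch.tensor([1, 1])
    prec1, prec2 = accuracy(output, target, topk=(1, 2))
    # sample 0: class 1 is the top-1 prediction -> counted at k=1
    # sample 1: class 1 is only the second-highest score -> counted at k=2 only
    # prec1 == tensor([50.]), prec2 == tensor([100.])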
--------------------------------------------------------------------------------
/ImageNet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .alexnet import *
2 | from .resnet import *
3 | from .vgg import *
4 | from .squeezenet import *
5 | from .inception import *
6 | from .densenet import *
--------------------------------------------------------------------------------
/ImageNet/models/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 | __all__ = ['AlexNet', 'alexnet']
5 |
6 | model_urls = {'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', }
7 |
8 |
9 | class AlexNet(nn.Module):
10 |
11 |     def __init__(self, num_classes=1000):
12 |         super(AlexNet, self).__init__()
13 |         self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True),
14 |             nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.ReLU(inplace=True),
15 |             nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True),
16 |             nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True),
17 |             nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True),
18 |             nn.MaxPool2d(kernel_size=3, stride=2), )
19 |         self.classifier = nn.Sequential(nn.Dropout(), nn.Linear(256 * 6 * 6, 4096), nn.ReLU(inplace=True), nn.Dropout(),
20 |             nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Linear(4096, num_classes), )
21 |
22 |     def forward(self, x):
23 |         x = self.features(x)
24 |         x = x.view(x.size(0), 256 * 6 * 6)
25 |         x = self.classifier(x)
26 |         return x
27 |
28 |
29 | def alexnet(pretrained=False, **kwargs):
30 |     r"""AlexNet model architecture from the
31 |     `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
32 |
33 |     Args:
34 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
35 |     """
36 |     model = AlexNet(**kwargs)
37 |     if pretrained:
38 |         model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
39 |     return model
--------------------------------------------------------------------------------
/ImageNet/models/densenet.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 | from collections import OrderedDict
7 |
8 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
9 |
10 | model_urls = {'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
11 |               'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
12 |               'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
13 |               'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', }
14 |
15 |
16 | def densenet121(pretrained=False, **kwargs):
17 |     r"""Densenet-121 model from
18 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
19 |
20 |     Args:
21 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
22 |     """
23 |     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
24 |     if pretrained:
25 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
26 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
27 |         # They are also in the checkpoints in model_urls. This pattern is used
28 |         # to find such keys.
29 |         pattern = re.compile(
30 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
31 |         state_dict = model_zoo.load_url(model_urls['densenet121'])
32 |         for key in list(state_dict.keys()):
33 |             res = pattern.match(key)
34 |             if res:
35 |                 new_key = res.group(1) + res.group(2)
36 |                 state_dict[new_key] = state_dict[key]
37 |                 del state_dict[key]
38 |         model.load_state_dict(state_dict)
39 |     return model
40 |
41 |
42 | def densenet169(pretrained=False, **kwargs):
43 |     r"""Densenet-169 model from
44 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
45 |
46 |     Args:
47 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
48 |     """
49 |     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
50 |     if pretrained:
51 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
52 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
53 |         # They are also in the checkpoints in model_urls. This pattern is used
54 |         # to find such keys.
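        # For example, the checkpoint key 'features.denseblock1.denselayer1.norm.1.weight'
        # is rewritten to 'features.denseblock1.denselayer1.norm1.weight' below.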
55 |         pattern = re.compile(
56 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
57 |         state_dict = model_zoo.load_url(model_urls['densenet169'])
58 |         for key in list(state_dict.keys()):
59 |             res = pattern.match(key)
60 |             if res:
61 |                 new_key = res.group(1) + res.group(2)
62 |                 state_dict[new_key] = state_dict[key]
63 |                 del state_dict[key]
64 |         model.load_state_dict(state_dict)
65 |     return model
66 |
67 |
68 | def densenet201(pretrained=False, **kwargs):
69 |     r"""Densenet-201 model from
70 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
71 |
72 |     Args:
73 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
74 |     """
75 |     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
76 |     if pretrained:
77 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
78 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
79 |         # They are also in the checkpoints in model_urls. This pattern is used
80 |         # to find such keys.
81 |         pattern = re.compile(
82 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
83 |         state_dict = model_zoo.load_url(model_urls['densenet201'])
84 |         for key in list(state_dict.keys()):
85 |             res = pattern.match(key)
86 |             if res:
87 |                 new_key = res.group(1) + res.group(2)
88 |                 state_dict[new_key] = state_dict[key]
89 |                 del state_dict[key]
90 |         model.load_state_dict(state_dict)
91 |     return model
92 |
93 |
94 | def densenet161(pretrained=False, **kwargs):
95 |     r"""Densenet-161 model from
96 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
97 |
98 |     Args:
99 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
100 |     """
101 |     model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
102 |     if pretrained:
103 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
104 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
105 |         # They are also in the checkpoints in model_urls. This pattern is used
106 |         # to find such keys.
107 | pattern = re.compile( 108 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 109 | state_dict = model_zoo.load_url(model_urls['densenet161']) 110 | for key in list(state_dict.keys()): 111 | res = pattern.match(key) 112 | if res: 113 | new_key = res.group(1) + res.group(2) 114 | state_dict[new_key] = state_dict[key] 115 | del state_dict[key] 116 | model.load_state_dict(state_dict) 117 | return model 118 | 119 | 120 | class _DenseLayer(nn.Sequential): 121 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 122 | super(_DenseLayer, self).__init__() 123 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 124 | self.add_module('relu1', nn.ReLU(inplace=True)), 125 | self.add_module('conv1', 126 | nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), 127 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 128 | self.add_module('relu2', nn.ReLU(inplace=True)), 129 | self.add_module('conv2', 130 | nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), 131 | self.drop_rate = drop_rate 132 | 133 | def forward(self, x): 134 | new_features = super(_DenseLayer, self).forward(x) 135 | if self.drop_rate > 0: 136 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 137 | return torch.cat([x, new_features], 1) 138 | 139 | 140 | class _DenseBlock(nn.Sequential): 141 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 142 | super(_DenseBlock, self).__init__() 143 | for i in range(num_layers): 144 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 145 | self.add_module('denselayer%d' % (i + 1), layer) 146 | 147 | 148 | class _Transition(nn.Sequential): 149 | def __init__(self, num_input_features, num_output_features): 150 | super(_Transition, self).__init__() 151 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 152 | self.add_module('relu', nn.ReLU(inplace=True)) 153 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) 154 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 155 | 156 | 157 | class DenseNet(nn.Module): 158 | r"""Densenet-BC model class, based on 159 | `"Densely Connected Convolutional Networks" `_ 160 | 161 | Args: 162 | growth_rate (int) - how many filters to add each layer (`k` in paper) 163 | block_config (list of 4 ints) - how many layers in each pooling block 164 | num_init_features (int) - the number of filters to learn in the first convolution layer 165 | bn_size (int) - multiplicative factor for number of bottle neck layers 166 | (i.e. 
bn_size * k features in the bottleneck layer) 167 | drop_rate (float) - dropout rate after each dense layer 168 | num_classes (int) - number of classification classes 169 | """ 170 | 171 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0, 172 | num_classes=1000): 173 | 174 | super(DenseNet, self).__init__() 175 | 176 | # First convolution 177 | self.features = nn.Sequential(OrderedDict( 178 | [('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 179 | ('norm0', nn.BatchNorm2d(num_init_features)), ('relu0', nn.ReLU(inplace=True)), 180 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), ])) 181 | 182 | # Each denseblock 183 | num_features = num_init_features 184 | for i, num_layers in enumerate(block_config): 185 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, bn_size=bn_size, 186 | growth_rate=growth_rate, drop_rate=drop_rate) 187 | self.features.add_module('denseblock%d' % (i + 1), block) 188 | num_features = num_features + num_layers * growth_rate 189 | if i != len(block_config) - 1: 190 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 191 | self.features.add_module('transition%d' % (i + 1), trans) 192 | num_features = num_features // 2 193 | 194 | # Final batch norm 195 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 196 | 197 | # Linear layer 198 | self.classifier = nn.Linear(num_features, num_classes) 199 | 200 | # Official init from torch repo. 201 | for m in self.modules(): 202 | if isinstance(m, nn.Conv2d): 203 | nn.init.kaiming_normal(m.weight.data) 204 | elif isinstance(m, nn.BatchNorm2d): 205 | m.weight.data.fill_(1) 206 | m.bias.data.zero_() 207 | elif isinstance(m, nn.Linear): 208 | m.bias.data.zero_() 209 | 210 | def forward(self, x): 211 | features = self.features(x) 212 | out = F.relu(features, inplace=True) 213 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 214 | out = self.classifier(out) 215 | return out 216 | -------------------------------------------------------------------------------- /ImageNet/models/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | __all__ = ['Inception3', 'inception_v3'] 7 | 8 | model_urls = { # Inception v3 ported from TensorFlow 9 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', } 10 | 11 | 12 | def inception_v3(pretrained=False, **kwargs): 13 | r"""Inception v3 model architecture from 14 | `"Rethinking the Inception Architecture for Computer Vision" `_. 
15 | 16 | Args: 17 | pretrained (bool): If True, returns a model pre-trained on ImageNet 18 | """ 19 | if pretrained: 20 | if 'transform_input' not in kwargs: 21 | kwargs['transform_input'] = True 22 | model = Inception3(**kwargs) 23 | model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google'])) 24 | return model 25 | 26 | return Inception3(**kwargs) 27 | 28 | 29 | class Inception3(nn.Module): 30 | 31 | def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): 32 | super(Inception3, self).__init__() 33 | self.aux_logits = aux_logits 34 | self.transform_input = transform_input 35 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) 36 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 37 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 38 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 39 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 40 | self.Mixed_5b = InceptionA(192, pool_features=32) 41 | self.Mixed_5c = InceptionA(256, pool_features=64) 42 | self.Mixed_5d = InceptionA(288, pool_features=64) 43 | self.Mixed_6a = InceptionB(288) 44 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 45 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 46 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 47 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 48 | if aux_logits: 49 | self.AuxLogits = InceptionAux(768, num_classes) 50 | self.Mixed_7a = InceptionD(768) 51 | self.Mixed_7b = InceptionE(1280) 52 | self.Mixed_7c = InceptionE(2048) 53 | self.fc = nn.Linear(2048, num_classes) 54 | 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 57 | import scipy.stats as stats 58 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1 59 | X = stats.truncnorm(-2, 2, scale=stddev) 60 | values = torch.Tensor(X.rvs(m.weight.data.numel())) 61 | values = values.view(m.weight.data.size()) 62 | m.weight.data.copy_(values) 63 | elif isinstance(m, nn.BatchNorm2d): 64 | m.weight.data.fill_(1) 65 | m.bias.data.zero_() 66 | 67 | def forward(self, x): 68 | if self.transform_input: 69 | x = x.clone() 70 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 71 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 72 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 73 | # 299 x 299 x 3 74 | x = self.Conv2d_1a_3x3(x) 75 | # 149 x 149 x 32 76 | x = self.Conv2d_2a_3x3(x) 77 | # 147 x 147 x 32 78 | x = self.Conv2d_2b_3x3(x) 79 | # 147 x 147 x 64 80 | x = F.max_pool2d(x, kernel_size=3, stride=2) 81 | # 73 x 73 x 64 82 | x = self.Conv2d_3b_1x1(x) 83 | # 73 x 73 x 80 84 | x = self.Conv2d_4a_3x3(x) 85 | # 71 x 71 x 192 86 | x = F.max_pool2d(x, kernel_size=3, stride=2) 87 | # 35 x 35 x 192 88 | x = self.Mixed_5b(x) 89 | # 35 x 35 x 256 90 | x = self.Mixed_5c(x) 91 | # 35 x 35 x 288 92 | x = self.Mixed_5d(x) 93 | # 35 x 35 x 288 94 | x = self.Mixed_6a(x) 95 | # 17 x 17 x 768 96 | x = self.Mixed_6b(x) 97 | # 17 x 17 x 768 98 | x = self.Mixed_6c(x) 99 | # 17 x 17 x 768 100 | x = self.Mixed_6d(x) 101 | # 17 x 17 x 768 102 | x = self.Mixed_6e(x) 103 | # 17 x 17 x 768 104 | if self.training and self.aux_logits: 105 | aux = self.AuxLogits(x) 106 | # 17 x 17 x 768 107 | x = self.Mixed_7a(x) 108 | # 8 x 8 x 1280 109 | x = self.Mixed_7b(x) 110 | # 8 x 8 x 2048 111 | x = self.Mixed_7c(x) 112 | # 8 x 8 x 2048 113 | x = F.avg_pool2d(x, kernel_size=8) 114 | # 1 x 1 x 2048 115 | x = F.dropout(x, training=self.training) 116 | # 1 x 1 x 2048 117 | x = x.view(x.size(0), -1) 118 | # 2048 119 | x = 
self.fc(x) 120 | # 1000 (num_classes) 121 | if self.training and self.aux_logits: 122 | return x, aux 123 | return x 124 | 125 | 126 | class InceptionA(nn.Module): 127 | 128 | def __init__(self, in_channels, pool_features): 129 | super(InceptionA, self).__init__() 130 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) 131 | 132 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 133 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 134 | 135 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 136 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 137 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 138 | 139 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 140 | 141 | def forward(self, x): 142 | branch1x1 = self.branch1x1(x) 143 | 144 | branch5x5 = self.branch5x5_1(x) 145 | branch5x5 = self.branch5x5_2(branch5x5) 146 | 147 | branch3x3dbl = self.branch3x3dbl_1(x) 148 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 149 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 150 | 151 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 152 | branch_pool = self.branch_pool(branch_pool) 153 | 154 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 155 | return torch.cat(outputs, 1) 156 | 157 | 158 | class InceptionB(nn.Module): 159 | 160 | def __init__(self, in_channels): 161 | super(InceptionB, self).__init__() 162 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) 163 | 164 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 165 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 166 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) 167 | 168 | def forward(self, x): 169 | branch3x3 = self.branch3x3(x) 170 | 171 | branch3x3dbl = self.branch3x3dbl_1(x) 172 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 173 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 174 | 175 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 176 | 177 | outputs = [branch3x3, branch3x3dbl, branch_pool] 178 | return torch.cat(outputs, 1) 179 | 180 | 181 | class InceptionC(nn.Module): 182 | 183 | def __init__(self, in_channels, channels_7x7): 184 | super(InceptionC, self).__init__() 185 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 186 | 187 | c7 = channels_7x7 188 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 189 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 190 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 191 | 192 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) 193 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 194 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 195 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 196 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 197 | 198 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 199 | 200 | def forward(self, x): 201 | branch1x1 = self.branch1x1(x) 202 | 203 | branch7x7 = self.branch7x7_1(x) 204 | branch7x7 = self.branch7x7_2(branch7x7) 205 | branch7x7 = self.branch7x7_3(branch7x7) 206 | 207 | branch7x7dbl = self.branch7x7dbl_1(x) 208 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 209 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 210 | branch7x7dbl = 
self.branch7x7dbl_4(branch7x7dbl) 211 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 212 | 213 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 214 | branch_pool = self.branch_pool(branch_pool) 215 | 216 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 217 | return torch.cat(outputs, 1) 218 | 219 | 220 | class InceptionD(nn.Module): 221 | 222 | def __init__(self, in_channels): 223 | super(InceptionD, self).__init__() 224 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 225 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) 226 | 227 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 228 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 229 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 230 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 231 | 232 | def forward(self, x): 233 | branch3x3 = self.branch3x3_1(x) 234 | branch3x3 = self.branch3x3_2(branch3x3) 235 | 236 | branch7x7x3 = self.branch7x7x3_1(x) 237 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 238 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 239 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 240 | 241 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 242 | outputs = [branch3x3, branch7x7x3, branch_pool] 243 | return torch.cat(outputs, 1) 244 | 245 | 246 | class InceptionE(nn.Module): 247 | 248 | def __init__(self, in_channels): 249 | super(InceptionE, self).__init__() 250 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 251 | 252 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 253 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 254 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 255 | 256 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 257 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 258 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 259 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 260 | 261 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 262 | 263 | def forward(self, x): 264 | branch1x1 = self.branch1x1(x) 265 | 266 | branch3x3 = self.branch3x3_1(x) 267 | branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3), ] 268 | branch3x3 = torch.cat(branch3x3, 1) 269 | 270 | branch3x3dbl = self.branch3x3dbl_1(x) 271 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 272 | branch3x3dbl = [self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl), ] 273 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 274 | 275 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 276 | branch_pool = self.branch_pool(branch_pool) 277 | 278 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 279 | return torch.cat(outputs, 1) 280 | 281 | 282 | class InceptionAux(nn.Module): 283 | 284 | def __init__(self, in_channels, num_classes): 285 | super(InceptionAux, self).__init__() 286 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 287 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 288 | self.conv1.stddev = 0.01 289 | self.fc = nn.Linear(768, num_classes) 290 | self.fc.stddev = 0.001 291 | 292 | def forward(self, x): 293 | # 17 x 17 x 768 294 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 295 | # 5 x 5 x 768 296 | x = self.conv0(x) 297 | # 5 x 5 x 128 298 | x = self.conv1(x) 299 | # 1 x 1 x 
768 300 | x = x.view(x.size(0), -1) 301 | # 768 302 | x = self.fc(x) 303 | # 1000 304 | return x 305 | 306 | 307 | class BasicConv2d(nn.Module): 308 | 309 | def __init__(self, in_channels, out_channels, **kwargs): 310 | super(BasicConv2d, self).__init__() 311 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 312 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 313 | 314 | def forward(self, x): 315 | x = self.conv(x) 316 | x = self.bn(x) 317 | return F.relu(x, inplace=True) 318 | -------------------------------------------------------------------------------- /ImageNet/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import extension as my 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 7 | 8 | model_urls = {'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 9 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 10 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 11 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 12 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', } 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | out = self.relu(out) 48 | 49 | return out 50 | 51 | 52 | class Bottleneck(nn.Module): 53 | expansion = 4 54 | 55 | def __init__(self, inplanes, planes, stride=1, downsample=None): 56 | super(Bottleneck, self).__init__() 57 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 58 | self.bn1 = nn.BatchNorm2d(planes) 59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 60 | self.bn2 = nn.BatchNorm2d(planes) 61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(planes * 4) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | out = self.bn3(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | out += residual 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | 90 | class ResNet(nn.Module): 91 | 92 | def __init__(self, block, layers, num_classes=1000, **kwargs): 93 | self.inplanes = 64 94 | super(ResNet, self).__init__() 95 | 
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 96 | self.bn1 = my.Norm(64) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 99 | self.layer1 = self._make_layer(block, 64, layers[0]) 100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 102 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 103 | self.avgpool = nn.AvgPool2d(7, stride=1) 104 | if kwargs.setdefault('last', False): 105 | self.last_bn = my.Norm(512 * block.expansion, dim=2) 106 | else: 107 | self.last_bn = None 108 | self.fc = nn.Linear(512 * block.expansion, num_classes) 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 113 | m.weight.data.normal_(0, math.sqrt(2. / n)) 114 | elif isinstance(m, nn.BatchNorm2d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | 118 | def _make_layer(self, block, planes, blocks, stride=1): 119 | downsample = None 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 123 | nn.BatchNorm2d(planes * block.expansion), ) 124 | 125 | layers = [] 126 | layers.append(block(self.inplanes, planes, stride, downsample)) 127 | self.inplanes = planes * block.expansion 128 | for i in range(1, blocks): 129 | layers.append(block(self.inplanes, planes)) 130 | 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | x = self.bn1(x) 136 | x = self.relu(x) 137 | x = self.maxpool(x) 138 | 139 | x = self.layer1(x) 140 | x = self.layer2(x) 141 | x = self.layer3(x) 142 | x = self.layer4(x) 143 | 144 | x = self.avgpool(x) 145 | x = x.view(x.size(0), -1) 146 | if self.last_bn is not None: 147 | x = self.last_bn(x) 148 | x = self.fc(x) 149 | 150 | return x 151 | 152 | 153 | def resnet18(pretrained=False, **kwargs): 154 | """Constructs a ResNet-18 model. 155 | 156 | Args: 157 | pretrained (bool): If True, returns a model pre-trained on ImageNet 158 | """ 159 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 160 | if pretrained: 161 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 162 | return model 163 | 164 | 165 | def resnet34(pretrained=False, **kwargs): 166 | """Constructs a ResNet-34 model. 167 | 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | """ 171 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 172 | if pretrained: 173 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 174 | return model 175 | 176 | 177 | def resnet50(pretrained=False, **kwargs): 178 | """Constructs a ResNet-50 model. 179 | 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 184 | if pretrained: 185 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 186 | return model 187 | 188 | 189 | def resnet101(pretrained=False, **kwargs): 190 | """Constructs a ResNet-101 model. 
191 |
192 |     Args:
193 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
194 |     """
195 |     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
196 |     if pretrained:
197 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
198 |     return model
199 |
200 |
201 | def resnet152(pretrained=False, **kwargs):
202 |     """Constructs a ResNet-152 model.
203 |
204 |     Args:
205 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
206 |     """
207 |     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
208 |     if pretrained:
209 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
210 |     return model
211 |
--------------------------------------------------------------------------------
/ImageNet/models/squeezenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | import torch.utils.model_zoo as model_zoo
6 |
7 | __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
8 |
9 | model_urls = {'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
10 |               'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', }
11 |
12 |
13 | class Fire(nn.Module):
14 |
15 |     def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):
16 |         super(Fire, self).__init__()
17 |         self.inplanes = inplanes
18 |         self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
19 |         self.squeeze_activation = nn.ReLU(inplace=True)
20 |         self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
21 |         self.expand1x1_activation = nn.ReLU(inplace=True)
22 |         self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
23 |         self.expand3x3_activation = nn.ReLU(inplace=True)
24 |
25 |     def forward(self, x):
26 |         x = self.squeeze_activation(self.squeeze(x))
27 |         return torch.cat([self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))],
28 |                          1)
29 |
30 |
31 | class SqueezeNet(nn.Module):
32 |
33 |     def __init__(self, version=1.0, num_classes=1000):
34 |         super(SqueezeNet, self).__init__()
35 |         if version not in [1.0, 1.1]:
36 |             raise ValueError("Unsupported SqueezeNet version {version}:"
37 |                              "1.0 or 1.1 expected".format(version=version))
38 |         self.num_classes = num_classes
39 |         if version == 1.0:
40 |             self.features = nn.Sequential(nn.Conv2d(3, 96, kernel_size=7, stride=2), nn.ReLU(inplace=True),
41 |                 nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(96, 16, 64, 64), Fire(128, 16, 64, 64),
42 |                 Fire(128, 32, 128, 128), nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(256, 32, 128, 128),
43 |                 Fire(256, 48, 192, 192), Fire(384, 48, 192, 192), Fire(384, 64, 256, 256),
44 |                 nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(512, 64, 256, 256), )
45 |         else:
46 |             self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=2), nn.ReLU(inplace=True),
47 |                 nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(64, 16, 64, 64), Fire(128, 16, 64, 64),
48 |                 nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(128, 32, 128, 128), Fire(256, 32, 128, 128),
49 |                 nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(256, 48, 192, 192), Fire(384, 48, 192, 192),
50 |                 Fire(384, 64, 256, 256), Fire(512, 64, 256, 256), )
51 |         # Final convolution is initialized differently from the rest
52 |         final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
53 |         self.classifier =
nn.Sequential(nn.Dropout(p=0.5), final_conv, nn.ReLU(inplace=True), 54 | nn.AvgPool2d(13, stride=1)) 55 | 56 | for m in self.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | if m is final_conv: 59 | init.normal_(m.weight.data, mean=0.0, std=0.01) 60 | else: 61 | init.kaiming_uniform_(m.weight.data) 62 | if m.bias is not None: 63 | m.bias.data.zero_() 64 | 65 | def forward(self, x): 66 | x = self.features(x) 67 | x = self.classifier(x) 68 | return x.view(x.size(0), self.num_classes) 69 | 70 | 71 | def squeezenet1_0(pretrained=False, **kwargs): 72 | r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level 73 | accuracy with 50x fewer parameters and <0.5MB model size" 74 | <https://arxiv.org/abs/1602.07360>`_ paper. 75 | 76 | Args: 77 | pretrained (bool): If True, returns a model pre-trained on ImageNet 78 | """ 79 | model = SqueezeNet(version=1.0, **kwargs) 80 | if pretrained: 81 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0'])) 82 | return model 83 | 84 | 85 | def squeezenet1_1(pretrained=False, **kwargs): 86 | r"""SqueezeNet 1.1 model from the `official SqueezeNet repo 87 | <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_. 88 | SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters 89 | than SqueezeNet 1.0, without sacrificing accuracy. 90 | 91 | Args: 92 | pretrained (bool): If True, returns a model pre-trained on ImageNet 93 | """ 94 | model = SqueezeNet(version=1.1, **kwargs) 95 | if pretrained: 96 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1'])) 97 | return model 98 | -------------------------------------------------------------------------------- /ImageNet/models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import math 4 | 5 | __all__ = ['VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19', ] 6 | 7 | model_urls = {'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 8 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 9 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 10 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 11 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 12 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 13 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 14 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', } 15 | 16 | 17 | class VGG(nn.Module): 18 | 19 | def __init__(self, features, num_classes=1000, init_weights=True): 20 | super(VGG, self).__init__() 21 | self.features = features 22 | self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(), 23 | nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(), 24 | nn.Linear(4096, num_classes), ) 25 | if init_weights: 26 | self._initialize_weights() 27 | 28 | def forward(self, x): 29 | x = self.features(x) 30 | x = x.view(x.size(0), -1) 31 | x = self.classifier(x) 32 | return x 33 | 34 | def _initialize_weights(self): 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 38 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 39 | if m.bias is not None: 40 | m.bias.data.zero_() 41 | elif isinstance(m, nn.BatchNorm2d): 42 | m.weight.data.fill_(1) 43 | m.bias.data.zero_() 44 | elif isinstance(m, nn.Linear): 45 | m.weight.data.normal_(0, 0.01) 46 | m.bias.data.zero_() 47 | 48 | 49 | def make_layers(cfg, batch_norm=False): 50 | layers = [] 51 | in_channels = 3 52 | for v in cfg: 53 | if v == 'M': 54 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 55 | else: 56 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 57 | if batch_norm: 58 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 59 | else: 60 | layers += [conv2d, nn.ReLU(inplace=True)] 61 | in_channels = v 62 | return nn.Sequential(*layers) 63 | 64 | 65 | cfg = {'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 66 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 67 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 68 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], } 69 | 70 | 71 | def vgg11(pretrained=False, **kwargs): 72 | """VGG 11-layer model (configuration "A") 73 | 74 | Args: 75 | pretrained (bool): If True, returns a model pre-trained on ImageNet 76 | """ 77 | if pretrained: 78 | kwargs['init_weights'] = False 79 | model = VGG(make_layers(cfg['A']), **kwargs) 80 | if pretrained: 81 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 82 | return model 83 | 84 | 85 | def vgg11_bn(pretrained=False, **kwargs): 86 | """VGG 11-layer model (configuration "A") with batch normalization 87 | 88 | Args: 89 | pretrained (bool): If True, returns a model pre-trained on ImageNet 90 | """ 91 | if pretrained: 92 | kwargs['init_weights'] = False 93 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 94 | if pretrained: 95 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 96 | return model 97 | 98 | 99 | def vgg13(pretrained=False, **kwargs): 100 | """VGG 13-layer model (configuration "B") 101 | 102 | Args: 103 | pretrained (bool): If True, returns a model pre-trained on ImageNet 104 | """ 105 | if pretrained: 106 | kwargs['init_weights'] = False 107 | model = VGG(make_layers(cfg['B']), **kwargs) 108 | if pretrained: 109 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 110 | return model 111 | 112 | 113 | def vgg13_bn(pretrained=False, **kwargs): 114 | """VGG 13-layer model (configuration "B") with batch normalization 115 | 116 | Args: 117 | pretrained (bool): If True, returns a model pre-trained on ImageNet 118 | """ 119 | if pretrained: 120 | kwargs['init_weights'] = False 121 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 122 | if pretrained: 123 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 124 | return model 125 | 126 | 127 | def vgg16(pretrained=False, **kwargs): 128 | """VGG 16-layer model (configuration "D") 129 | 130 | Args: 131 | pretrained (bool): If True, returns a model pre-trained on ImageNet 132 | """ 133 | if pretrained: 134 | kwargs['init_weights'] = False 135 | model = VGG(make_layers(cfg['D']), **kwargs) 136 | if pretrained: 137 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 138 | return model 139 | 140 | 141 | def vgg16_bn(pretrained=False, **kwargs): 142 | """VGG 16-layer model (configuration "D") with batch normalization 143 | 144 | Args: 145 | pretrained (bool): If True, returns a model pre-trained on ImageNet 146 | """ 147 | if pretrained: 
148 | kwargs['init_weights'] = False 149 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 150 | if pretrained: 151 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 152 | return model 153 | 154 | 155 | def vgg19(pretrained=False, **kwargs): 156 | """VGG 19-layer model (configuration "E") 157 | 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | """ 161 | if pretrained: 162 | kwargs['init_weights'] = False 163 | model = VGG(make_layers(cfg['E']), **kwargs) 164 | if pretrained: 165 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 166 | return model 167 | 168 | 169 | def vgg19_bn(pretrained=False, **kwargs): 170 | """VGG 19-layer model (configuration 'E') with batch normalization 171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | if pretrained: 176 | kwargs['init_weights'] = False 177 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 178 | if pretrained: 179 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 180 | return model 181 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2019, Lei Huang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IterNorm-pytorch 2 | PyTorch reimplementation of the IterNorm method, which is described in the following paper: 3 | 4 | **Iterative Normalization: Beyond Standardization towards Efficient Whitening** 5 | 6 | Lei Huang, Yi Zhou, Fan Zhu, Li Liu, Ling Shao 7 | 8 | *IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019 (accepted).* 9 | [arXiv:1904.03441](https://arxiv.org/abs/1904.03441) 10 | 11 | 12 | This project also provides a PyTorch implementation of Decorrelated Batch Normalization (CVPR 2018, [arXiv:1804.08450](https://arxiv.org/abs/1804.08450)); for more details, please refer to the [Torch project](https://github.com/princeton-vl/DecorrelatedBN). 13 | 14 | ## Requirements and Dependency 15 | * Install [PyTorch](http://torch.ch) with CUDA (for GPU). (Experiments are validated on Python 3.6.8 and pytorch-nightly 1.0.0.) 16 | * (Optional, for visualization) install the [visdom](https://github.com/facebookresearch/visdom) dependency: 17 | ```Bash 18 | pip install visdom 19 | ``` 20 | 21 | 22 | ## Experiments 23 | 24 | #### 1. VGG network on the CIFAR-10 dataset: 25 | 26 | Run the scripts in `./cifar10/experiments/vgg`. Note that the dataset root directory should be set with the `--dataset-root` parameter, and the expected dataset layout is: 27 | ``` 28 | - 29 | |-cifar10-batches-py 30 | ||-data_batch_1 31 | ||-data_batch_2 32 | ||-data_batch_3 33 | ||-data_batch_4 34 | ||-data_batch_5 35 | ||-test_batch 36 | ``` 37 | If the dataset does not exist, the script will download it, provided the `dataset-root` directory exists. 38 | 39 | #### 2. Wide Residual Network on the CIFAR-10 dataset: 40 | 41 | Run the scripts in `./cifar10/experiments/wrn`. 42 | 43 | #### 3. ImageNet experiments: 44 | 45 | Run the scripts in `./ImageNet/experiments`. Note that the resnet18 experiments run on one GPU, while the resnet-50/101 scripts use 4 GPUs. 46 | 47 | Note that the dataset root directory should be set with the `--dataset-root` parameter, 48 | and the expected dataset layout is: 49 | 50 | ``` 51 | - 52 | |-train 53 | ||-class1 54 | ||-... 55 | ||-class1000 56 | |-var 57 | ||-class1 58 | ||-... 59 | ||-class1000 60 | ``` 61 | 62 | ## Using IterNorm in other projects/tasks 63 | (1) Copy `./extension/normailzation/iterative_normalization.py` into the respective directory. 64 | 65 | (2) Import the `IterNorm` class from `iterative_normalization.py`. 66 | 67 | (3) Generally speaking, replace a `BatchNorm` layer with `IterNorm`, or insert one wherever you want the features/channels decorrelated. Considering efficiency (note that `BatchNorm` is integrated in `cudnn`, while `IterNorm` is implemented in plain PyTorch without such optimization), we recommend: 1) replacing the first `BatchNorm`; 2) inserting an extra `IterNorm` before the first skip connection in a ResNet; 3) inserting one before the final linear classifier, as described in the paper. A minimal sketch is given below. 68 | 69 | (4) Some tips on the hyperparameters (group size `G` and iteration number `T`): we recommend `G=64` (i.e., 64 channels per group) and `T=5` by default. If you run with a large batch size (e.g., >1024), you can increase either `G` or `T`. For fine-tuning, fixing `G=64` or `G=32` and searching over `T={3,4,5,6,7,8}` may help.
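A minimal sketch of steps (2)–(3) follows. The `ToyNet` module below is made up for illustration; it assumes the `IterNorm` constructor in `iterative_normalization.py` accepts `num_features` plus the `num_channels` and `T` keyword arguments passed by the `--norm-cfg` options above, and that `dim=2` selects the 2-D variant used before a classifier, as in `my.Norm(512 * block.expansion, dim=2)` in the ImageNet ResNet:

```python
import torch
import torch.nn as nn

from iterative_normalization import IterNorm  # step (1): file copied into your project


class ToyNet(nn.Module):
    """Toy CNN illustrating recommendations 1) and 3) above."""

    def __init__(self, num_classes=10):
        super(ToyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1, bias=False)
        # 1) replace the first BatchNorm: whiten with G=64 channels per group, T=5 iterations
        self.bn1 = IterNorm(64, num_channels=64, T=5, dim=4)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)  # later layers keep the cheap cudnn BatchNorm
        self.pool = nn.AdaptiveAvgPool2d(1)
        # 3) decorrelate the flattened 2-D features right before the linear classifier
        self.last_bn = IterNorm(128, num_channels=64, T=5, dim=2)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.last_bn(x)
        return self.fc(x)


if __name__ == '__main__':
    net = ToyNet()
    y = net(torch.randn(8, 3, 32, 32))  # whitening statistics come from the batch, so use batch size > 1
    print(y.size())  # expected: torch.Size([8, 10])
```

As with `BatchNorm`, remember to call `net.eval()` at test time so `IterNorm` uses its running whitening statistics rather than the mini-batch ones.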
70 | -------------------------------------------------------------------------------- /cifar10/cifar10.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import time 3 | import os 4 | import sys 5 | import shutil 6 | import argparse 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.data 11 | import torchvision.transforms as transforms 12 | 13 | sys.path.append('..') 14 | import models 15 | import extension as ext 16 | 17 | 18 | class ClassificationSmall: 19 | def __init__(self): 20 | self.cfg = self.add_arguments() 21 | self.model_name = self.cfg.arch + ext.normailzation.setting(self.cfg) + '_' + self.cfg.dataset 22 | 23 | self.result_path = os.path.join(self.cfg.output, self.model_name, self.cfg.log_suffix) 24 | os.makedirs(self.result_path, exist_ok=True) 25 | self.logger = ext.logger.setting('log.txt', self.result_path, self.cfg.test, bool(self.cfg.resume)) 26 | ext.trainer.setting(self.cfg) 27 | self.model = models.__dict__[self.cfg.arch](**self.cfg.arch_cfg) 28 | self.logger('==> model [{}]: {}'.format(self.model_name, self.model)) 29 | self.optimizer = ext.optimizer.setting(self.model, self.cfg) 30 | self.scheduler = ext.scheduler.setting(self.optimizer, self.cfg) 31 | 32 | self.saver = ext.checkpoint.Checkpoint(self.model, self.cfg, self.optimizer, self.scheduler, self.result_path, 33 | not self.cfg.test) 34 | self.saver.load(self.cfg.load) 35 | 36 | # dataset loader 37 | normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]) 38 | val_transform = [transforms.ToTensor(), normalize, ] 39 | if self.cfg.augmentation: 40 | train_transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] 41 | else: 42 | train_transform = [] 43 | train_transform.extend([transforms.ToTensor(), normalize, ]) 44 | self.train_loader = ext.dataset.get_dataset_loader(self.cfg, train_transform, train=True) 45 | self.val_loader = ext.dataset.get_dataset_loader(self.cfg, val_transform, train=False) 46 | 47 | self.device = torch.device('cuda') 48 | self.num_gpu = torch.cuda.device_count() 49 | self.logger('==> use {:d} GPUs'.format(self.num_gpu)) 50 | if self.num_gpu > 1: 51 | self.model = torch.nn.DataParallel(self.model) 52 | self.model.cuda() 53 | 54 | self.best_acc = 0 55 | if self.cfg.resume: 56 | saved = self.saver.resume(self.cfg.resume) 57 | self.cfg.start_epoch = saved['epoch'] 58 | self.best_acc = saved['best_acc'] 59 | self.criterion = nn.CrossEntropyLoss() 60 | 61 | self.vis = ext.visualization.setting(self.cfg, self.model_name, 62 | {'train loss': 'loss', 'test loss': 'loss', 'train accuracy': 'accuracy', 63 | 'test accuracy': 'accuracy'}) 64 | return 65 | 66 | def add_arguments(self): 67 | model_names = sorted( 68 | name for name in models.__dict__ if not name.startswith("__") and callable(models.__dict__[name])) 69 | parser = argparse.ArgumentParser('Small Scale Image Classification') 70 | parser.add_argument('-a', '--arch', metavar='ARCH', default='simple', choices=model_names, 71 | help='model architecture: ' + ' | '.join(model_names) + '\t(Default: simple)') 72 | parser.add_argument('--arch-cfg', metavar='DICT', default={}, type=ext.utils.str2dict, 73 | help='The extra model architecture configuration.') 74 | parser.add_argument('-A', '--augmentation', type=ext.utils.str2bool, default=True, metavar='BOOL', 75 | help='Use data augmentation? 
(default: True)') 76 | ext.trainer.add_arguments(parser) 77 | parser.set_defaults(epochs=200) 78 | ext.dataset.add_arguments(parser) 79 | parser.set_defaults(dataset='cifar10', workers=4) 80 | ext.scheduler.add_arguments(parser) 81 | parser.set_defaults(lr_method='steps', lr_steps=[100, 150], lr=0.1) 82 | ext.optimizer.add_arguments(parser) 83 | parser.set_defaults(optimizer='sgd', weight_decay=1e-4) 84 | ext.logger.add_arguments(parser) 85 | ext.checkpoint.add_arguments(parser) 86 | ext.visualization.add_arguments(parser) 87 | ext.normailzation.add_arguments(parser) 88 | args = parser.parse_args() 89 | if args.resume: 90 | args = parser.parse_args(namespace=ext.checkpoint.Checkpoint.load_config(args.resume)) 91 | return args 92 | 93 | def train(self): 94 | if self.cfg.test: 95 | self.validate() 96 | return 97 | # train model 98 | for epoch in range(self.cfg.start_epoch + 1, self.cfg.epochs): 99 | if self.cfg.lr_method != 'auto': 100 | self.scheduler.step() 101 | self.train_epoch(epoch) 102 | accuracy, val_loss = self.validate(epoch) 103 | self.saver.save_checkpoint(epoch=epoch, best_acc=self.best_acc) 104 | if self.cfg.lr_method == 'auto': 105 | self.scheduler.step(val_loss) 106 | # finish train 107 | now_date = time.strftime("%y-%m-%d_%H:%M:%S", time.localtime(time.time())) 108 | self.logger('==> end time: {}'.format(now_date)) 109 | new_log_filename = '{}_{}_{:5.2f}%.txt'.format(self.model_name, now_date, self.best_acc) 110 | self.logger('\n==> Network training completed. Copy log file to {}'.format(new_log_filename)) 111 | shutil.copy(self.logger.filename, os.path.join(self.result_path, new_log_filename)) 112 | return 113 | 114 | def train_epoch(self, epoch): 115 | self.logger('\nEpoch: {}, lr: {:.2g}, weight decay: {:.2g} on model {}'.format(epoch, 116 | self.optimizer.param_groups[0]['lr'], self.optimizer.param_groups[0]['weight_decay'], self.model_name)) 117 | self.model.train() 118 | train_loss = 0 119 | correct = 0 120 | total = 0 121 | progress_bar = ext.ProgressBar(len(self.train_loader)) 122 | for i, (inputs, targets) in enumerate(self.train_loader, 1): 123 | inputs, targets = inputs.to(self.device), targets.to(self.device) 124 | 125 | # compute output 126 | outputs = self.model(inputs) 127 | losses = self.criterion(outputs, targets) 128 | 129 | # compute gradient and do SGD step 130 | self.optimizer.zero_grad() 131 | losses.backward() 132 | self.optimizer.step() 133 | 134 | # measure accuracy and record loss 135 | train_loss += losses.item() * targets.size(0) 136 | pred = outputs.max(1, keepdim=True)[1] 137 | correct += pred.eq(targets.view_as(pred)).sum().item() 138 | total += targets.size(0) 139 | if i % 10 == 0 or i == len(self.train_loader): 140 | progress_bar.step('Loss: {:.5g} | Accuracy: {:.2f}%'.format(train_loss / total, 100. * correct / total), 141 | 10) 142 | train_loss /= total 143 | accuracy = 100. 
* correct / total 144 | self.vis.add_value('train loss', train_loss) 145 | self.vis.add_value('train accuracy', accuracy) 146 | self.logger( 147 | 'Train on epoch {}: average loss={:.5g}, accuracy={:.2f}% ({}/{}), time: {}'.format(epoch, train_loss, 148 | accuracy, correct, total, progress_bar.time_used())) 149 | return 150 | 151 | def validate(self, epoch=-1): 152 | test_loss = 0 153 | correct = 0 154 | total = 0 155 | progress_bar = ext.ProgressBar(len(self.val_loader)) 156 | self.model.eval() 157 | with torch.no_grad(): 158 | for inputs, targets in self.val_loader: 159 | inputs, targets = inputs.to(self.device), targets.to(self.device) 160 | outputs = self.model(inputs) 161 | test_loss += self.criterion(outputs, targets).item() * targets.size(0) 162 | prediction = outputs.max(1, keepdim=True)[1] 163 | correct += prediction.eq(targets.view_as(prediction)).sum().item() 164 | total += targets.size(0) 165 | progress_bar.step('Loss: {:.5g} | Accuracy: {:.2f}%'.format(test_loss / total, 100. * correct / total)) 166 | test_loss /= total 167 | accuracy = correct * 100. / total 168 | self.vis.add_value('test loss', test_loss) 169 | self.vis.add_value('test accuracy', accuracy) 170 | self.logger('Test on epoch {}: average loss={:.5g}, accuracy={:.2f}% ({}/{}), time: {}'.format(epoch, test_loss, 171 | accuracy, correct, total, progress_bar.time_used())) 172 | if not self.cfg.test and accuracy > self.best_acc: 173 | self.best_acc = accuracy 174 | self.saver.save_model('best.pth') 175 | self.logger('==> best accuracy: {:.2f}%'.format(self.best_acc)) 176 | return accuracy, test_loss 177 | 178 | 179 | if __name__ == '__main__': 180 | Cs = ClassificationSmall() 181 | Cs.train() 182 | -------------------------------------------------------------------------------- /cifar10/experiments/vgg/vgg_LargeLR_BN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=vgg \ 5 | --batch-size=256 \ 6 | --epochs=160 \ 7 | -oo=sgd \ 8 | -oc=momentum=0.9 \ 9 | -wd=0 \ 10 | --lr=1 \ 11 | --lr-method=steps \ 12 | --lr-steps=60,120 \ 13 | --lr-gamma=0.2 \ 14 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 15 | --norm=BN \ 16 | --norm-cfg=T=5,num_channels=512 \ 17 | --seed=1 \ 18 | --log-suffix=LargeLR \ 19 | $@ 20 | -------------------------------------------------------------------------------- /cifar10/experiments/vgg/vgg_LargeLR_ItN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=vgg \ 5 | --batch-size=256 \ 6 | --epochs=160 \ 7 | -oo=sgd \ 8 | -oc=momentum=0.9 \ 9 | -wd=0 \ 10 | --lr=1 \ 11 | --lr-method=steps \ 12 | --lr-steps=60,120 \ 13 | --lr-gamma=0.2 \ 14 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 15 | --norm=ItN \ 16 | --norm-cfg=T=5,num_channels=512 \ 17 | --seed=1 \ 18 | --log-suffix=LargeLR \ 19 | $@ 20 | -------------------------------------------------------------------------------- /cifar10/experiments/vgg/vgg_b1024_BN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 
3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=vgg \ 5 | --batch-size=1024 \ 6 | --epochs=160 \ 7 | -oo=sgd \ 8 | -oc=momentum=0.9 \ 9 | -wd=0 \ 10 | --lr=0.4 \ 11 | --lr-method=steps \ 12 | --lr-steps=60,120 \ 13 | --lr-gamma=0.2 \ 14 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 15 | --norm=BN \ 16 | --norm-cfg=T=5,num_channels=512 \ 17 | --seed=1 \ 18 | --log-suffix=b1024 \ 19 | --vis \ 20 | $@ 21 | -------------------------------------------------------------------------------- /cifar10/experiments/vgg/vgg_b1024_ItN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=vgg \ 5 | --batch-size=1024 \ 6 | --epochs=160 \ 7 | -oo=sgd \ 8 | -oc=momentum=0.9 \ 9 | -wd=0 \ 10 | --lr=0.4 \ 11 | --lr-method=steps \ 12 | --lr-steps=60,120 \ 13 | --lr-gamma=0.2 \ 14 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 15 | --norm=ItN \ 16 | --norm-cfg=T=5,num_channels=512 \ 17 | --seed=1 \ 18 | --log-suffix=b1024 \ 19 | --vis \ 20 | $@ 21 | -------------------------------------------------------------------------------- /cifar10/experiments/vgg/vgg_b16_BN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=vgg \ 5 | --batch-size=16 \ 6 | --epochs=160 \ 7 | -oo=sgd \ 8 | -oc=momentum=0.9 \ 9 | -wd=0 \ 10 | --lr=0.1 \ 11 | --lr-method=steps \ 12 | --lr-steps=60,120 \ 13 | --lr-gamma=0.2 \ 14 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 15 | --norm=BN \ 16 | --norm-cfg=T=5,num_channels=512 \ 17 | --seed=1 \ 18 | --log-suffix=b16 \ 19 | $@ 20 | -------------------------------------------------------------------------------- /cifar10/experiments/vgg/vgg_b16_ItN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=vgg \ 5 | --batch-size=16 \ 6 | --epochs=160 \ 7 | -oo=sgd \ 8 | -oc=momentum=0.9 \ 9 | -wd=0 \ 10 | --lr=0.1 \ 11 | --lr-method=steps \ 12 | --lr-steps=60,120 \ 13 | --lr-gamma=0.2 \ 14 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 15 | --norm=ItN \ 16 | --norm-cfg=T=5,num_channels=512 \ 17 | --seed=1 \ 18 | --log-suffix=b16 \ 19 | $@ 20 | -------------------------------------------------------------------------------- /cifar10/experiments/vgg/vgg_base_BN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=vgg \ 5 | --batch-size=256 \ 6 | --epochs=160 \ 7 | -oo=sgd \ 8 | -oc=momentum=0.9 \ 9 | -wd=0 \ 10 | --lr=0.1 \ 11 | --lr-method=steps \ 12 | --lr-steps=60,120 \ 13 | --lr-gamma=0.2 \ 14 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 15 | --norm=BN \ 16 | --norm-cfg=T=5,num_channels=512 \ 17 | --seed=1 \ 18 | --log-suffix=base \ 19 | --vis \ 20 | $@ 21 | -------------------------------------------------------------------------------- /cifar10/experiments/vgg/vgg_base_ItN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 
3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=vgg \ 5 | --batch-size=256 \ 6 | --epochs=160 \ 7 | -oo=sgd \ 8 | -oc=momentum=0.9 \ 9 | -wd=0 \ 10 | --lr=0.1 \ 11 | --lr-method=steps \ 12 | --lr-steps=60,120 \ 13 | --lr-gamma=0.2 \ 14 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 15 | --norm=ItN \ 16 | --norm-cfg=T=5,num_channels=512 \ 17 | --seed=1 \ 18 | --log-suffix=base \ 19 | --vis \ 20 | $@ 21 | -------------------------------------------------------------------------------- /cifar10/experiments/wrn/wrn_28_10_BN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=WRN_28_10 \ 5 | --arch-cfg=dropout=0.3 \ 6 | --batch-size=128 \ 7 | --epochs=200 \ 8 | -oo=sgd \ 9 | -oc=momentum=0.9 \ 10 | -wd=5e-4 \ 11 | --lr=0.1 \ 12 | --lr-method=steps \ 13 | --lr-steps=60,120,160 \ 14 | --lr-gamma=0.2 \ 15 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 16 | --norm=BN \ 17 | --norm-cfg=T=5,num_channels=64 \ 18 | --seed=1 \ 19 | --log-suffix=seed1 \ 20 | $@ 21 | -------------------------------------------------------------------------------- /cifar10/experiments/wrn/wrn_28_10_ItN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=WRN_28_10 \ 5 | --arch-cfg=dropout=0.3 \ 6 | --batch-size=128 \ 7 | --epochs=200 \ 8 | -oo=sgd \ 9 | -oc=momentum=0.9 \ 10 | -wd=5e-4 \ 11 | --lr=0.1 \ 12 | --lr-method=steps \ 13 | --lr-steps=60,120,160 \ 14 | --lr-gamma=0.2 \ 15 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 16 | --norm=ItN \ 17 | --norm-cfg=T=5,num_channels=64 \ 18 | --seed=1 \ 19 | --log-suffix=seed1 \ 20 | $@ 21 | -------------------------------------------------------------------------------- /cifar10/experiments/wrn/wrn_40_10_BN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=WRN_40_10 \ 5 | --arch-cfg=dropout=0.3 \ 6 | --batch-size=128 \ 7 | --epochs=200 \ 8 | -oo=sgd \ 9 | -oc=momentum=0.9 \ 10 | -wd=5e-4 \ 11 | --lr=0.1 \ 12 | --lr-method=steps \ 13 | --lr-steps=60,120,160 \ 14 | --lr-gamma=0.2 \ 15 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 16 | --norm=BN \ 17 | --norm-cfg=T=5,num_channels=64 \ 18 | --seed=1 \ 19 | --log-suffix=seed1 \ 20 | $@ 21 | -------------------------------------------------------------------------------- /cifar10/experiments/wrn/wrn_40_10_ItN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd "$(dirname $0)/../.." 
3 | CUDA_VISIBLE_DEVICES=0 python3 cifar10.py \ 4 | -a=WRN_40_10 \ 5 | --arch-cfg=dropout=0.3 \ 6 | --batch-size=128 \ 7 | --epochs=200 \ 8 | -oo=sgd \ 9 | -oc=momentum=0.9 \ 10 | -wd=5e-4 \ 11 | --lr=0.1 \ 12 | --lr-method=steps \ 13 | --lr-steps=60,120,160 \ 14 | --lr-gamma=0.2 \ 15 | --dataset-root=/home/lei/PycharmProjects/data/cifar10/ \ 16 | --norm=ItN \ 17 | --norm-cfg=T=5,num_channels=64 \ 18 | --seed=1 \ 19 | --log-suffix=seed1 \ 20 | $@ 21 | -------------------------------------------------------------------------------- /cifar10/mnist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import time 3 | import os 4 | import shutil 5 | import argparse 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.utils.data 10 | import torchvision.transforms as transforms 11 | from torchvision.utils import save_image 12 | 13 | import extension as ext 14 | 15 | 16 | class LeNet(nn.Module): 17 | def __init__(self): 18 | super(LeNet, self).__init__() 19 | self.net = ext.Sequential(ext.Conv2d(1, 6, 5, 1, 2, special_id=0), ext.NonLinear(nn.ReLU(True)), 20 | nn.MaxPool2d(2, 2), ext.Conv2d(6, 16, 5), ext.NonLinear(nn.ReLU(True)), 21 | nn.MaxPool2d(2, 2), ext.View(400), ext.Linear(400, 120), ext.NonLinear(nn.ReLU(True)), 22 | ext.Linear(120, 84), ext.NonLinear(nn.ReLU(True)), ext.Linear(84, 10, special_id=1)) 23 | 24 | def forward(self, input): 25 | return self.net(input) 26 | 27 | 28 | def to_img(x): 29 | x = 0.5 * (x + 1) 30 | x = x.clamp(0, 1) 31 | x = x.view(x.size(0), 1, 28, 28) 32 | return x 33 | 34 | 35 | class AutoEncoder(nn.Module): 36 | def __init__(self): 37 | super(AutoEncoder, self).__init__() 38 | self.encoder = ext.Sequential(ext.View(28 * 28), ext.Linear(28 * 28, 128, special_id=0), 39 | ext.NonLinear(nn.ReLU(True)), ext.Linear(128, 64), ext.NonLinear(nn.ReLU(True)), 40 | ext.Linear(64, 12), ext.NonLinear(nn.ReLU(True), special_id=1), 41 | ext.Linear(12, 3, special_id=1)) 42 | self.decoder = ext.Sequential(ext.Linear(3, 12, special_id=2), ext.NonLinear(nn.ReLU(True)), ext.Linear(12, 64), 43 | ext.NonLinear(nn.ReLU(True)), ext.Linear(64, 128), 44 | ext.NonLinear(nn.ReLU(True), special_id=3), 45 | ext.Linear(128, 28 * 28, special_id=3), nn.Tanh(), ext.View(1, 28, 28)) 46 | 47 | def forward(self, x): 48 | x = self.encoder(x) 49 | x = self.decoder(x) 50 | return x 51 | 52 | 53 | class MNIST: 54 | def __init__(self): 55 | self.cfg = self.add_arguments() 56 | self.model_name = self.cfg.arch + ext.quantization.setting(self.cfg) 57 | self.result_path = os.path.join(self.cfg.output, self.cfg.dataset, self.model_name) 58 | os.makedirs(self.result_path, exist_ok=True) 59 | self.logger = ext.logger.setting('log.txt', self.result_path, self.cfg.test, self.cfg.resume is not None) 60 | ext.trainer.setting(self.cfg) 61 | self.model = LeNet() if self.cfg.arch == 'LeNet' else AutoEncoder() 62 | self.logger('==> model [{}]: {}'.format(self.model_name, self.model)) 63 | self.optimizer = ext.optimizer.setting(self.model, self.cfg) 64 | self.scheduler = ext.scheduler.setting(self.optimizer, self.cfg) 65 | 66 | self.saver = ext.checkpoint.Checkpoint(self.model, self.cfg, self.optimizer, self.scheduler, self.result_path, 67 | not self.cfg.test) 68 | self.saver.load(self.cfg.load) 69 | 70 | # dataset loader 71 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) 72 | self.train_loader = ext.dataset.get_dataset_loader(self.cfg, transform, train=True) 73 | self.val_loader = 
ext.dataset.get_dataset_loader(self.cfg, transform, train=False) 74 | 75 | self.device = torch.device('cuda') 76 | # self.num_gpu = torch.cuda.device_count() 77 | # self.logger('==> use {:d} GPUs'.format(self.num_gpu)) 78 | # if self.num_gpu > 1: 79 | # self.model = torch.nn.DataParallel(self.model) 80 | self.model.cuda() 81 | 82 | self.best_acc = 0 83 | if self.cfg.resume: 84 | saved = self.saver.resume(self.cfg.resume) 85 | self.cfg.start_epoch = saved['epoch'] 86 | self.best_acc = saved['best_acc'] 87 | self.criterion = nn.CrossEntropyLoss() if self.cfg.arch == 'LeNet' else nn.MSELoss() 88 | 89 | self.vis = ext.visualization.setting(self.cfg, self.model_name, 90 | {'train loss': 'loss', 'test loss': 'loss', 'train accuracy': 'accuracy', 91 | 'test accuracy': 'accuracy'}) 92 | return 93 | 94 | def add_arguments(self): 95 | parser = argparse.ArgumentParser('MNIST Classification') 96 | model_names = ['LeNet', 'AE'] 97 | parser.add_argument('-a', '--arch', metavar='ARCH', default=model_names[0], choices=model_names, 98 | help='model architecture: ' + ' | '.join(model_names)) 99 | ext.trainer.add_arguments(parser) 100 | parser.set_defaults(epochs=10) 101 | ext.quantization.add_arguments(parser) 102 | ext.dataset.add_arguments(parser) 103 | parser.set_defaults(dataset='mnist', workers=1, batch_size=[64, 1000]) 104 | ext.scheduler.add_arguments(parser) 105 | parser.set_defaults(lr_method='fix', lr=1e-3) 106 | ext.optimizer.add_arguments(parser) 107 | parser.set_defaults(optimizer='adam', weight_decay=1e-5) 108 | ext.logger.add_arguments(parser) 109 | ext.checkpoint.add_arguments(parser) 110 | ext.visualization.add_arguments(parser) 111 | args = parser.parse_args() 112 | if args.resume: 113 | args = parser.parse_args(namespace=ext.checkpoint.Checkpoint.load_config(args.resume)) 114 | return args 115 | 116 | def train(self): 117 | if self.cfg.test: 118 | self.validate() 119 | return 120 | # train model 121 | for epoch in range(self.cfg.start_epoch + 1, self.cfg.epochs): 122 | if self.cfg.lr_method != 'auto': 123 | self.scheduler.step() 124 | self.train_epoch(epoch) 125 | accuracy, val_loss = self.validate(epoch) 126 | self.saver.save_checkpoint(epoch=epoch, best_acc=self.best_acc) 127 | if self.cfg.lr_method == 'auto': 128 | self.scheduler.step(val_loss) 129 | # finish train 130 | now_date = time.strftime("%y-%m-%d_%H:%M:%S", time.localtime(time.time())) 131 | self.logger('==> end time: {}'.format(now_date)) 132 | new_log_filename = '{}_{}_{:5.2f}%.txt'.format(self.model_name, now_date, self.best_acc) 133 | self.logger('\n==> Network training completed. 
Copy log file to {}'.format(new_log_filename)) 134 | shutil.copy(self.logger.filename, os.path.join(self.result_path, new_log_filename)) 135 | return 136 | 137 | def train_epoch(self, epoch): 138 | self.logger('\nEpoch: {}, lr: {:.2g}, weight decay: {:.2g} on model {}'.format(epoch, 139 | self.optimizer.param_groups[0]['lr'], self.optimizer.param_groups[0]['weight_decay'], self.model_name)) 140 | self.model.train() 141 | train_loss = 0 142 | correct = 0 143 | total = 0 144 | progress_bar = ext.ProgressBar(len(self.train_loader)) 145 | for i, (inputs, targets) in enumerate(self.train_loader, 1): 146 | inputs = inputs.to(self.device) 147 | targets = targets.to(self.device) if self.cfg.arch == 'LeNet' else inputs 148 | 149 | # compute output 150 | outputs = self.model(inputs) 151 | losses = self.criterion(outputs, targets) 152 | 153 | # compute gradient and do SGD step 154 | self.optimizer.zero_grad() 155 | losses.backward() 156 | self.optimizer.step() 157 | 158 | # measure accuracy and record loss 159 | train_loss += losses.item() * targets.size(0) 160 | if self.cfg.arch == 'LeNet': 161 | pred = outputs.max(1, keepdim=True)[1] 162 | correct += pred.eq(targets.view_as(pred)).sum().item() 163 | else: 164 | correct = -train_loss 165 | total += targets.size(0) 166 | if i % 10 == 0 or i == len(self.train_loader): 167 | progress_bar.step('Loss: {:.5g} | Accuracy: {:.2f}%'.format(train_loss / total, 100. * correct / total), 168 | 10) 169 | train_loss /= total 170 | accuracy = 100. * correct / total 171 | self.vis.add_value('train loss', train_loss) 172 | self.vis.add_value('train accuracy', accuracy) 173 | self.logger( 174 | 'Train on epoch {}: average loss={:.5g}, accuracy={:.2f}% ({}/{}), time: {}'.format(epoch, train_loss, 175 | accuracy, correct, total, progress_bar.time_used())) 176 | return 177 | 178 | def validate(self, epoch=-1): 179 | test_loss = 0 180 | correct = 0 181 | total = 0 182 | progress_bar = ext.ProgressBar(len(self.val_loader)) 183 | self.model.eval() 184 | with torch.no_grad(): 185 | for inputs, targets in self.val_loader: 186 | inputs = inputs.to(self.device) 187 | targets = targets.to(self.device) if self.cfg.arch == 'LeNet' else inputs 188 | outputs = self.model(inputs) 189 | test_loss += self.criterion(outputs, targets).item() * targets.size(0) 190 | if self.cfg.arch == 'LeNet': 191 | prediction = outputs.max(1, keepdim=True)[1] 192 | correct += prediction.eq(targets.view_as(prediction)).sum().item() 193 | else: 194 | correct = -test_loss 195 | total += targets.size(0) 196 | progress_bar.step('Loss: {:.5g} | Accuracy: {:.2f}%'.format(test_loss / total, 100. * correct / total)) 197 | test_loss /= total 198 | accuracy = correct * 100. 
/ total 199 | self.vis.add_value('test loss', test_loss) 200 | self.vis.add_value('test accuracy', accuracy) 201 | self.logger('Test on epoch {}: average loss={:.5g}, accuracy={:.2f}% ({}/{}), time: {}'.format(epoch, test_loss, 202 | accuracy, correct, total, progress_bar.time_used())) 203 | if not self.cfg.test and accuracy > self.best_acc: 204 | self.best_acc = accuracy 205 | self.saver.save_model('best.pth') 206 | self.logger('==> best accuracy: {:.2f}%'.format(self.best_acc)) 207 | if self.cfg.arch == 'AE': 208 | pic = to_img(outputs[:64].cpu().data) 209 | save_image(pic, os.path.join(self.result_path, 'result_{}.png').format(epoch)) 210 | return accuracy, test_loss 211 | 212 | 213 | if __name__ == '__main__': 214 | Cs = MNIST() 215 | Cs.train() 216 | -------------------------------------------------------------------------------- /cifar10/models/WRN.py: -------------------------------------------------------------------------------- 1 | """ 2 | wide residual network 3 | """ 4 | import extension as my 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | __all__ = ['WideResNet', 'WRN_28_10', 'WRN_40_10'] 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | def __init__(self, in_channels, out_channels, stride, drop_ratio=0.0): 16 | super(BasicBlock, self).__init__() 17 | self.bn1 = my.Norm(in_channels) 18 | self.relu1 = nn.ReLU(inplace=True) 19 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False) 20 | self.bn2 = my.Norm(out_channels) 21 | self.relu2 = nn.ReLU(inplace=True) 22 | self.dropout = nn.Dropout2d(p=drop_ratio) if drop_ratio > 0 else None 23 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False) 24 | if in_channels != out_channels: 25 | self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride, 0, bias=False) 26 | else: 27 | self.shortcut = None 28 | 29 | def forward(self, x): 30 | y = self.relu1(self.bn1(x)) 31 | z = self.relu2(self.bn2(self.conv1(y))) 32 | if self.dropout is not None: 33 | z = self.dropout(z) 34 | z = self.conv2(z) 35 | if self.shortcut is None: 36 | return x + z 37 | else: 38 | return z + self.shortcut(y) 39 | 40 | 41 | class WideResNet(nn.Module): 42 | def __init__(self, depth=28, widen_factor=1, num_classes=10, dropout=0.0): 43 | super(WideResNet, self).__init__() 44 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 45 | assert ((depth - 4) % 6 == 0) 46 | n = (depth - 4) / 6 47 | block = BasicBlock 48 | # 1st conv before any network block 49 | self.conv1 = nn.Conv2d(3, nChannels[0], 3, 1, 1, bias=False) 50 | # 1st block 51 | self.block1 = self._make_layer(n, block, nChannels[0], nChannels[1], 1, dropout) 52 | # 2nd block 53 | self.block2 = self._make_layer(n, block, nChannels[1], nChannels[2], 2, dropout) 54 | # 3rd block 55 | self.block3 = self._make_layer(n, block, nChannels[2], nChannels[3], 2, dropout) 56 | # global average pooling and classifier 57 | self.bn = my.Norm(nChannels[3]) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.pool = nn.AvgPool2d(8, 1) 60 | self.fc = nn.Linear(nChannels[3], num_classes) 61 | self.nChannels = nChannels[3] 62 | 63 | for m in self.modules(): 64 | if isinstance(m, nn.Conv2d): 65 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 66 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 67 | elif isinstance(m, nn.BatchNorm2d): 68 | m.weight.data.fill_(1) 69 | m.bias.data.zero_() 70 | elif isinstance(m, nn.Linear): 71 | m.bias.data.zero_() 72 | 73 | def _make_layer(self, n, block, in_planes, out_planes, stride, dropout): 74 | layers = [] 75 | for i in range(int(n)): 76 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropout)) 77 | return nn.Sequential(*layers) 78 | 79 | def forward(self, x): 80 | out = self.conv1(x) 81 | out = self.block1(out) 82 | out = self.block2(out) 83 | out = self.block3(out) 84 | out = self.relu(self.bn(out)) 85 | out = self.pool(out) 86 | out = out.view(-1, self.nChannels) 87 | return self.fc(out) 88 | 89 | 90 | def WRN_28_10(**kwargs): 91 | return WideResNet(28, 10, **kwargs) 92 | 93 | 94 | def WRN_40_10(**kwargs): 95 | return WideResNet(40, 10, **kwargs) 96 | 97 | 98 | if __name__ == '__main__': 99 | wrn = WideResNet(28, widen_factor=10) 100 | x = torch.randn(2, 3, 32, 32) 101 | y = wrn(x) 102 | print(wrn) 103 | -------------------------------------------------------------------------------- /cifar10/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .vgg import * 3 | from .WRN import * 4 | -------------------------------------------------------------------------------- /cifar10/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | resnet for cifar in pytorch 3 | 4 | Reference: 5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016. 6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016. 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import math 12 | import extension as my 13 | 14 | __all__ = ['ResNet', 'PreAct_ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet164', 15 | 'resnet1001', 'resnet1202', 'preact_resnet20', 'preact_resnet110', 'preact_resnet164', 'preact_resnet1001'] 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, in_planes, planes, stride=1, shortcut=None): 22 | super(BasicBlock, self).__init__() 23 | self.relu = nn.ReLU(True) 24 | self.conv1 = nn.Conv2d(in_planes, planes, 3, stride, 1, bias=False) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu1 = nn.ReLU(True) 27 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.shortcut = shortcut 30 | 31 | def forward(self, x): 32 | x = self.relu(x) 33 | residual = x if self.shortcut is None else self.shortcut(x) 34 | x = self.relu1(self.bn1(self.conv1(x))) 35 | x = self.bn2(self.conv2(x)) 36 | x += residual 37 | return x 38 | 39 | 40 | class Bottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, inplanes, planes, stride=1, shortcut=None): 44 | super(Bottleneck, self).__init__() 45 | self.relu = nn.ReLU(True) 46 | 47 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.relu1 = nn.ReLU(True) 50 | 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.relu2 = nn.ReLU(True) 54 | 55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * 4) 57 | 58 | self.shortcut = shortcut 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | x = 
self.relu(x) 63 | residual = x if self.shortcut is None else self.shortcut(x) 64 | x = self.relu1(self.bn1(self.conv1(x))) 65 | x = self.relu2(self.bn2(self.conv2(x))) 66 | x = self.bn3(self.conv3(x)) 67 | x += residual 68 | return x 69 | 70 | 71 | class PreActBasicBlock(nn.Module): 72 | expansion = 1 73 | 74 | def __init__(self, inplanes, planes, stride=1, shortcut=None): 75 | super(PreActBasicBlock, self).__init__() 76 | self.bn1 = nn.BatchNorm2d(inplanes) 77 | self.relu1 = nn.ReLU(True) 78 | self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 1, bias=False) 79 | self.bn2 = nn.BatchNorm2d(planes) 80 | self.relu2 = nn.ReLU(True) 81 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False) 82 | self.shortcut = shortcut 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | out = self.bn1(x) 87 | out = self.relu1(out) 88 | residual = x if self.shortcut is None else self.shortcut(out) 89 | out = self.conv1(out) 90 | 91 | out = self.bn2(out) 92 | out = self.relu2(out) 93 | out = self.conv2(out) 94 | 95 | out += residual 96 | return out 97 | 98 | 99 | class PreActBottleneck(nn.Module): 100 | expansion = 4 101 | 102 | def __init__(self, inplanes, planes, stride=1, shortcut=None): 103 | super(PreActBottleneck, self).__init__() 104 | self.bn1 = nn.BatchNorm2d(inplanes) 105 | self.relu1 = nn.ReLU(True) 106 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 107 | self.bn2 = nn.BatchNorm2d(planes) 108 | self.relu2 = nn.ReLU(True) 109 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 110 | self.bn3 = nn.BatchNorm2d(planes) 111 | self.relu3 = nn.ReLU(True) 112 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 113 | self.shortcut = shortcut 114 | self.stride = stride 115 | 116 | def forward(self, x): 117 | residual = x 118 | out = self.relu1(self.bn1(x)) 119 | if self.shortcut is not None: 120 | residual = self.shortcut(out) 121 | out = self.conv1(out) 122 | out = self.conv2(self.relu2(self.bn2(out))) 123 | out = self.conv3(self.relu3(self.bn3(out))) 124 | out += residual 125 | return out 126 | 127 | 128 | class ResNet(nn.Module): 129 | 130 | def __init__(self, block, layers, num_classes=10): 131 | super(ResNet, self).__init__() 132 | self.inplanes = 16 133 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 134 | self.bn1 = nn.BatchNorm2d(16) 135 | self.layer1 = self._make_layer(block, 16, layers[0]) 136 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 137 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 138 | self.relu = nn.ReLU(inplace=True) 139 | self.avgpool = nn.AvgPool2d(8, stride=1) 140 | self.fc = nn.Linear(64 * block.expansion, num_classes) 141 | 142 | for m in self.modules(): 143 | if isinstance(m, nn.Conv2d): 144 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 145 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 146 | elif isinstance(m, nn.BatchNorm2d) and m.affine: 147 | m.weight.data.fill_(1) 148 | m.bias.data.zero_() 149 | 150 | def _make_layer(self, block, planes, blocks, stride=1): 151 | shortcut = None 152 | if stride != 1 or self.inplanes != planes * block.expansion: 153 | shortcut = nn.Sequential( 154 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 155 | nn.BatchNorm2d(planes * block.expansion)) 156 | 157 | layers = [block(self.inplanes, planes, stride, shortcut)] 158 | self.inplanes = planes * block.expansion 159 | for _ in range(1, blocks): 160 | layers.append(block(self.inplanes, planes)) 161 | 162 | return nn.Sequential(*layers) 163 | 164 | def forward(self, x): 165 | x = self.conv1(x) 166 | x = self.bn1(x) 167 | 168 | x = self.layer1(x) 169 | x = self.layer2(x) 170 | x = self.layer3(x) 171 | 172 | x = self.avgpool(self.relu(x)) 173 | x = x.view(x.size(0), -1) 174 | x = self.fc(x) 175 | 176 | return x 177 | 178 | 179 | class PreAct_ResNet(nn.Module): 180 | 181 | def __init__(self, block, layers, num_classes=10): 182 | super(PreAct_ResNet, self).__init__() 183 | self.inplanes = 16 184 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 185 | self.layer1 = self._make_layer(block, 16, layers[0]) 186 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 187 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 188 | self.bn = nn.BatchNorm2d(64 * block.expansion) 189 | self.relu = nn.ReLU(inplace=True) 190 | self.avgpool = nn.AvgPool2d(8, stride=1) 191 | self.fc = nn.Linear(64 * block.expansion, num_classes) 192 | 193 | for m in self.modules(): 194 | if isinstance(m, nn.Conv2d): 195 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 196 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 197 | elif isinstance(m, nn.BatchNorm2d) and m.affine: 198 | m.weight.data.fill_(1) 199 | m.bias.data.zero_() 200 | 201 | def _make_layer(self, block, planes, blocks, stride=1): 202 | shortcut = None 203 | if stride != 1 or self.inplanes != planes * block.expansion: 204 | shortcut = nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False) 205 | 206 | layers = [block(self.inplanes, planes, stride, shortcut)] 207 | self.inplanes = planes * block.expansion 208 | for _ in range(1, blocks): 209 | layers.append(block(self.inplanes, planes)) 210 | return nn.Sequential(*layers) 211 | 212 | def forward(self, x): 213 | x = self.conv1(x) 214 | 215 | x = self.layer1(x) 216 | x = self.layer2(x) 217 | x = self.layer3(x) 218 | 219 | x = self.bn(x) 220 | x = self.relu(x) 221 | x = self.avgpool(x) 222 | x = x.view(x.size(0), -1) 223 | x = self.fc(x) 224 | 225 | return x 226 | 227 | 228 | def resnet20(**kwargs): 229 | model = ResNet(BasicBlock, [3, 3, 3], **kwargs) 230 | return model 231 | 232 | 233 | def resnet32(**kwargs): 234 | model = ResNet(BasicBlock, [5, 5, 5], **kwargs) 235 | return model 236 | 237 | 238 | def resnet44(**kwargs): 239 | model = ResNet(BasicBlock, [7, 7, 7], **kwargs) 240 | return model 241 | 242 | 243 | def resnet56(**kwargs): 244 | model = ResNet(BasicBlock, [9, 9, 9], **kwargs) 245 | return model 246 | 247 | 248 | def resnet110(**kwargs): 249 | model = ResNet(BasicBlock, [18, 18, 18], **kwargs) 250 | return model 251 | 252 | 253 | def resnet164(**kwargs): 254 | model = ResNet(Bottleneck, [18, 18, 18], **kwargs) 255 | return model 256 | 257 | 258 | def resnet1001(**kwargs): 259 | model = ResNet(Bottleneck, [111, 111, 111], **kwargs) 260 | return model 261 | 262 | 263 | def resnet1202(**kwargs): 264 | model = ResNet(BasicBlock, [200, 200, 200], **kwargs) 265 | return model 266 | 267 | 268 | def preact_resnet20(**kwargs): 269 | model = PreAct_ResNet(PreActBasicBlock, [3, 3, 3], **kwargs) 270 | return model 271 | 272 | 273 | def preact_resnet110(**kwargs): 274 | model = PreAct_ResNet(PreActBasicBlock, [18, 18, 18], **kwargs) 275 | return model 276 | 277 | 278 | def preact_resnet164(**kwargs): 279 | model = PreAct_ResNet(PreActBottleneck, [18, 18, 18], **kwargs) 280 | return model 281 | 282 | 283 | def preact_resnet1001(**kwargs): 284 | model = PreAct_ResNet(PreActBottleneck, [111, 111, 111], **kwargs) 285 | return model 286 | 287 | 288 | if __name__ == '__main__': 289 | net = preact_resnet110() 290 | y = net(torch.autograd.Variable(torch.randn(1, 3, 32, 32))) 291 | print(net) 292 | print(y.size()) 293 | -------------------------------------------------------------------------------- /cifar10/models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import extension as my 3 | 4 | 5 | __all__ = [ 6 | 'vgg', 7 | ] 8 | 9 | 10 | 11 | class VGG(nn.Module): 12 | 13 | def __init__(self, features, num_classes=10, init_weights=True): 14 | super(VGG, self).__init__() 15 | self.features = features 16 | #self.avgpool = nn.AdaptiveAvgPool2d((2,2)) 17 | self.avgpool = nn.AvgPool2d(2,2) 18 | self.classifier = nn.Sequential( 19 | nn.Linear(512, num_classes), 20 | ) 21 | if init_weights: 22 | self._initialize_weights() 23 | 24 | def forward(self, x): 25 | x = self.features(x) 26 | #print(x.size()) 27 | x = self.avgpool(x) 28 | #print(x.size()) 29 | x = x.view(x.size(0), -1) 30 | #print(x.size()) 31 | x = self.classifier(x) 32 | return x 33 | 34 | def _initialize_weights(self): 35 | for m in 
self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | #nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 38 | if m.bias is not None: 39 | nn.init.constant_(m.bias, 0) 40 | elif isinstance(m, nn.BatchNorm2d): 41 | nn.init.constant_(m.weight, 1) 42 | nn.init.constant_(m.bias, 0) 43 | elif isinstance(m, nn.Linear): 44 | nn.init.normal_(m.weight, 0, 0.01) 45 | nn.init.constant_(m.bias, 0) 46 | 47 | 48 | def make_layers(cfg, batch_norm=True): 49 | layers = [] 50 | in_channels = 3 51 | for v in cfg: 52 | if v == 'M': 53 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 54 | else: 55 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 56 | if batch_norm: 57 | layers += [conv2d, my.Norm(v), nn.ReLU(inplace=True)] 58 | else: 59 | layers += [conv2d, nn.ReLU(inplace=True)] 60 | in_channels = v 61 | return nn.Sequential(*layers) 62 | 63 | 64 | cfg = { 65 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 66 | } 67 | 68 | 69 | def vgg(**kwargs): 70 | model = VGG(make_layers(cfg['E']), **kwargs) 71 | return model 72 | -------------------------------------------------------------------------------- /extension/__init__.py: -------------------------------------------------------------------------------- 1 | from .progress_bar import ProgressBar 2 | from .trainer import * 3 | from . import scheduler, optimizer 4 | from . import logger, visualization, checkpoint 5 | from . import dataset, trainer 6 | from . import utils, normailzation 7 | 8 | # network modules 9 | from .layers import * 10 | from .normailzation import Norm 11 | -------------------------------------------------------------------------------- /extension/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | from .logger import get_logger 8 | from . import utils 9 | 10 | 11 | def add_arguments(parser: argparse.ArgumentParser): 12 | group = parser.add_argument_group('Save Options') 13 | group.add_argument('--resume', default="", metavar='PATH', type=utils.path, 14 | help='path to the checkpoint needed resume') 15 | group.add_argument('--load', default="", metavar='PATH', type=utils.path, help='The path to (pre-)trained model.') 16 | group.add_argument('--load-no-strict', default=True, action='store_false', 17 | help='The keys of loaded model may not exactly match the model\'s. 
(May be useful for finetuning.)') 18 | return 19 | 20 | 21 | def _strip_prefix_if_present(state_dict, prefix): 22 | keys = sorted(state_dict.keys()) 23 | if not all(key.startswith(prefix) for key in keys): 24 | return state_dict 25 | stripped_state_dict = OrderedDict() 26 | for key, value in state_dict.items(): 27 | stripped_state_dict[key.replace(prefix, "")] = value 28 | return stripped_state_dict 29 | 30 | 31 | class Checkpoint(object): 32 | checkpoint = None 33 | 34 | def __init__(self, model, cfg=None, optimizer=None, scheduler=None, save_dir="", save_to_disk=True, logger=None): 35 | self.model = model 36 | self.cfg = cfg 37 | self.optimizer = optimizer 38 | self.scheduler = scheduler 39 | self.save_dir = save_dir 40 | self.save_to_disk = save_to_disk and bool(self.save_dir) 41 | if logger is None: 42 | logger = get_logger() 43 | self.logger = logger 44 | 45 | def _check_name(self, name: str): 46 | if not name.endswith('.pth'): 47 | name = name + '.pth' 48 | return os.path.join(self.save_dir, name) 49 | 50 | def save_checkpoint(self, name='checkpoint.pth', **kwargs): 51 | if not self.save_to_disk: 52 | return 53 | save_file = self._check_name(name) 54 | data = {"model": self.model.state_dict()} 55 | if self.cfg is not None: 56 | data["cfg"] = self.cfg 57 | if self.optimizer is not None: 58 | data["optimizer"] = self.optimizer.state_dict() 59 | if self.scheduler is not None: 60 | data["scheduler"] = self.scheduler.state_dict() 61 | data.update(kwargs) 62 | self.logger("Saving checkpoint to {}".format(save_file)) 63 | 64 | torch.save(data, save_file)  # self.tag_last_checkpoint(save_file) 65 | 66 | def save_model(self, name='model.pth'): 67 | if not self.save_to_disk: 68 | return 69 | save_file = self._check_name(name) 70 | data = _strip_prefix_if_present(self.model.state_dict(), 'module.') 71 | self.logger("Saving model to {}".format(save_file)) 72 | torch.save(data, save_file) 73 | 74 | def load(self, f=None): 75 | # if self.has_checkpoint(): 76 | # override argument with existing checkpoint 77 | # f = self.get_checkpoint_file() 78 | if not f: 79 | # no checkpoint could be found 80 | # self.logger("No checkpoint found. Initializing model from scratch") 81 | return {} 82 | self.logger("==> Loading model from {}, strict: {}".format(f, self.cfg.load_no_strict)) 83 | checkpoint = torch.load(f, map_location=torch.device("cpu")) 84 | # if the state_dict comes from a model that was wrapped in a 85 | # DataParallel or DistributedDataParallel during serialization, 86 | # remove the "module" prefix before performing the matching 87 | loaded_state_dict = _strip_prefix_if_present(checkpoint, prefix="module.") 88 | self.model.load_state_dict(loaded_state_dict, strict=self.cfg.load_no_strict) 89 | 90 | return checkpoint 91 | 92 | def resume(self, f=None): 93 | # if self.has_checkpoint(): 94 | # override argument with existing checkpoint 95 | # f = self.get_checkpoint_file() 96 | if not f: 97 | # no checkpoint could be found 98 | # self.logger("No checkpoint found.
Initializing model from scratch") 99 | return {} 100 | self.logger("Loading checkpoint from {}".format(f)) 101 | if Checkpoint.checkpoint is not None: 102 | checkpoint = Checkpoint.checkpoint 103 | Checkpoint.checkpoint = None 104 | else: 105 | checkpoint = torch.load(f, map_location=torch.device("cpu")) 106 | 107 | # if the state_dict comes from a model that was wrapped in a 108 | # DataParallel or DistributedDataParallel during serialization, 109 | # remove the "module" prefix before performing the matching 110 | loaded_state_dict = _strip_prefix_if_present(checkpoint.pop("model"), prefix="module.") 111 | self.model.load_state_dict(loaded_state_dict) 112 | if "optimizer" in checkpoint and self.optimizer: 113 | self.logger("Loading optimizer from {}".format(f)) 114 | self.optimizer.load_state_dict(checkpoint.pop("optimizer")) 115 | if "scheduler" in checkpoint and self.scheduler: 116 | self.logger("Loading scheduler from {}".format(f)) 117 | self.scheduler.load_state_dict(checkpoint.pop("scheduler")) 118 | if "cfg" in checkpoint: 119 | checkpoint.pop("cfg") 120 | 121 | # return any further checkpoint data 122 | return checkpoint 123 | 124 | def has_checkpoint(self): 125 | save_file = os.path.join(self.save_dir, "last_checkpoint") 126 | return os.path.exists(save_file) 127 | 128 | def get_checkpoint_file(self): 129 | save_file = os.path.join(self.save_dir, "last_checkpoint") 130 | try: 131 | with open(save_file, "r") as f: 132 | last_saved = f.read() 133 | except IOError: 134 | # if file doesn't exist, maybe because it has just been 135 | # deleted by a separate process 136 | last_saved = "" 137 | return last_saved 138 | 139 | def tag_last_checkpoint(self, last_filename): 140 | save_file = os.path.join(self.save_dir, "last_checkpoint") 141 | with open(save_file, "w") as f: 142 | f.write(last_filename) 143 | 144 | @staticmethod 145 | def load_config(f=None): 146 | if f: 147 | Checkpoint.checkpoint = torch.load(f, map_location=torch.device("cpu")) 148 | if "cfg" in Checkpoint.checkpoint: 149 | print('Read config from checkpoint {}'.format(f)) 150 | return Checkpoint.checkpoint.pop("cfg") 151 | return None 152 | -------------------------------------------------------------------------------- /extension/dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torchvision 5 | import torch.utils.data 6 | from . import utils 7 | from .logger import get_logger 8 | from torchvision.datasets.folder import has_file_allowed_extension, default_loader, IMG_EXTENSIONS 9 | 10 | dataset_list = ['mnist', 'fashion-mnist', 'cifar10', 'ImageNet', 'folder'] 11 | 12 | 13 | def add_arguments(parser: argparse.ArgumentParser): 14 | group = parser.add_argument_group('Dataset Option') 15 | group.add_argument('--dataset', metavar='NAME', default='mnist', choices=dataset_list, 16 | help='The name of dataset in {' + ', '.join(dataset_list) + '}') 17 | group.add_argument('--dataset-root', metavar='PATH', default=os.path.expanduser('~/data/'), type=utils.path, 18 | help='The directory which contains the needed dataset.') 19 | group.add_argument('-b', '--batch-size', type=utils.str2list, default=[], metavar='NUMs', 20 | help='The size of a mini-batch') 21 | group.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='The number of data loading workers.') 22 | group.add_argument('--im-size', type=utils.str2tuple, default=(), metavar='NUMs', 23 | help='Resize images to a specific size. 
(default: no resize)') 24 | group.add_argument('--dataset-classes', type=int, default=None, help='The number of classes in dataset.') 25 | return group 26 | 27 | 28 | def make_dataset(dir, extensions): 29 | images = [] 30 | dir = os.path.expanduser(dir) 31 | for root, _, fnames in sorted(os.walk(dir)): 32 | for fname in sorted(fnames): 33 | if has_file_allowed_extension(fname, extensions): 34 | path = os.path.join(root, fname) 35 | images.append(path) 36 | 37 | return images 38 | 39 | 40 | class DatasetFlatFolder(torch.utils.data.Dataset): 41 | """A generic data loader where the samples are arranged in this way: :: 42 | 43 | root/xxx.ext 44 | root/xxy.ext 45 | root/xxz.ext 46 | 47 | Args: 48 | root (string): Root directory path. 49 | transform (callable, optional): A function/transform that takes in 50 | a sample and returns a transformed version. 51 | E.g, ``transforms.RandomCrop`` for images. 52 | target_transform (callable, optional): A function/transform that takes 53 | in the target and transforms it. 54 | loader (callable): A function to load a sample given its path. 55 | 56 | Attributes: 57 | samples (list): List of (sample path, class_index) tuples 58 | """ 59 | 60 | def __init__(self, root, transform=None, loader=default_loader): 61 | samples = make_dataset(root, IMG_EXTENSIONS) 62 | assert len(samples) > 0, "Found 0 files in: " + root + "\nSupported extensions are: " + ",".join(IMG_EXTENSIONS) 63 | self.root = root 64 | self.loader = loader 65 | self.extensions = IMG_EXTENSIONS 66 | self.samples = samples 67 | self.transform = transform 68 | 69 | def __getitem__(self, index): 70 | """ 71 | Args: 72 | index (int): Index 73 | 74 | Returns: 75 | tuple: 'sample' where target is class_index of the target class. 76 | """ 77 | path = self.samples[index] 78 | sample = self.loader(path) 79 | if self.transform is not None: 80 | sample = self.transform(sample) 81 | return sample 82 | 83 | def __len__(self): 84 | return len(self.samples) 85 | 86 | def __repr__(self): 87 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 88 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 89 | fmt_str += ' Root Location: {}\n'.format(self.root) 90 | tmp = ' Transforms (if any): ' 91 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 92 | return fmt_str 93 | 94 | 95 | def get_dataset_loader(args: argparse.Namespace, transforms=None, target_transform=None, train=True, use_cuda=True): 96 | args.dataset_root = os.path.expanduser(args.dataset_root) 97 | root = args.dataset_root 98 | assert os.path.exists(root), 'Please assign the correct dataset root path with --dataset-root ' 99 | if args.dataset != 'folder': 100 | root = os.path.join(root, args.dataset) 101 | 102 | if isinstance(transforms, list): 103 | transforms = torchvision.transforms.Compose(transforms) 104 | 105 | if args.dataset == 'mnist': 106 | if len(args.im_size) == 0: 107 | args.im_size = (1, 28, 28) 108 | args.dataset_classes = 10 109 | dataset = torchvision.datasets.mnist.MNIST(root, train, transforms, target_transform, download=True) 110 | elif args.dataset == 'fashion-mnist': 111 | if len(args.im_size) == 0: 112 | args.im_size = (1, 28, 28) 113 | args.dataset_classes = 10 114 | dataset = torchvision.datasets.FashionMNIST(root, train, transforms, target_transform, download=True) 115 | elif args.dataset == 'cifar10': 116 | if len(args.im_size) == 0: 117 | args.im_size = (3, 32, 32) 118 | args.dataset_classes = 10 119 | dataset = torchvision.datasets.CIFAR10(root, train, 
transforms, target_transform, download=True) 120 | elif args.dataset in ['ImageNet', 'folder']: 121 | if len(args.im_size) == 0: 122 | args.im_size = (3, 256, 256) 123 | args.dataset_classes = 1000 124 | root = os.path.join(root, 'train' if train else 'val') 125 | dataset = torchvision.datasets.ImageFolder(root, transforms, target_transform) 126 | else: 127 | raise FileNotFoundError('No such dataset') 128 | 129 | loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {} 130 | if len(args.batch_size) == 0: 131 | args.batch_size = [256, 256] 132 | elif len(args.batch_size) == 1: 133 | args.batch_size.append(args.batch_size[0]) 134 | dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size[not train], shuffle=train, 135 | drop_last=train, **loader_kwargs) 136 | LOG = get_logger() 137 | LOG('==> Dataset: {}'.format(dataset)) 138 | return dataset_loader 139 | -------------------------------------------------------------------------------- /extension/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .sequential import Sequential, NamedSequential 2 | from .scale import Scale 3 | from .view import View 4 | -------------------------------------------------------------------------------- /extension/layers/scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | 6 | class Scale(nn.Module): 7 | def __init__(self, num_features, dim=4): 8 | super(Scale, self).__init__() 9 | self.num_features = num_features 10 | shape = [1 for _ in range(dim)] 11 | shape[1] = self.num_features 12 | 13 | self.weight = Parameter(torch.Tensor(*shape)) 14 | self.bias = Parameter(torch.Tensor(*shape)) 15 | 16 | self.reset_parameters() 17 | 18 | def reset_parameters(self): 19 | # nn.init.uniform_(self.weight) 20 | nn.init.ones_(self.weight) 21 | nn.init.zeros_(self.bias) 22 | 23 | def forward(self, input): 24 | return input * self.weight + self.bias 25 | 26 | def extra_repr(self): 27 | return '{}'.format(self.num_features) 28 | 29 | 30 | if __name__ == '__main__': 31 | s = Scale(4) 32 | x = torch.ones(3, 4, 5, 6) 33 | print(s.weight.size()) 34 | nn.init.constant_(s.weight, 2) 35 | nn.init.constant_(s.bias, 1) 36 | y = s(x) 37 | print(y, y.size()) 38 | -------------------------------------------------------------------------------- /extension/layers/sequential.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from collections import OrderedDict 3 | 4 | 5 | def Sequential(*args): 6 | """ 7 | Return an nn.Sequential object that ignores the arguments which are not nn.Module instances, such as None. 8 | """ 9 | modules = [] 10 | for m in args: 11 | if isinstance(m, nn.Module): 12 | modules.append(m) 13 | return nn.Sequential(*modules) 14 | 15 | 16 | def NamedSequential(**kwargs): 17 | """ 18 | Return an nn.Sequential object that ignores the keyword arguments which are not nn.Module instances, such as None. 
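Example (an illustrative sketch; the module choices below are arbitrary): the bn=None entry is silently dropped, so the result only contains the 'conv' and 'relu' children.
    >>> NamedSequential(conv=nn.Conv2d(3, 32, 3), bn=None, relu=nn.ReLU())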
19 | """ 20 | modules = [] 21 | for k, v in kwargs.items(): 22 | if isinstance(v, nn.Module): 23 | modules.append((k, v)) 24 | return nn.Sequential(OrderedDict(modules)) 25 | 26 | 27 | if __name__ == '__main__': 28 | print(Sequential(nn.Conv2d(32, 3, 1, 1), nn.BatchNorm2d(32), None, nn.ReLU())) 29 | print(NamedSequential(conv1=nn.Conv2d(32, 3, 1, 1), bn=nn.BatchNorm2d(32), q=None, relu=nn.ReLU())) 30 | -------------------------------------------------------------------------------- /extension/layers/view.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class View(torch.nn.Module): 5 | """ 6 | Reshape the input tensor to a new tensor by using torch.view(). 7 | The new size does not include batch_size. 8 | """ 9 | 10 | def __init__(self, *new_size: int): 11 | super(View, self).__init__() 12 | self.new_size = new_size 13 | 14 | def forward(self, x: torch.Tensor): 15 | y = x.view(x.size(0), *self.new_size) 16 | return y 17 | 18 | def __repr__(self): 19 | return 'view{}'.format(self.new_size) 20 | -------------------------------------------------------------------------------- /extension/logger.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | 5 | logger = None 6 | 7 | 8 | class _Logger: 9 | def __init__(self, filename=None, path='.', only_print=False, append=False): 10 | os.makedirs(path, exist_ok=True) 11 | self.filename = os.path.join(path, filename) 12 | self.file = None 13 | if filename and not only_print: 14 | self.file = open(self.filename, 'a' if append else 'w') 15 | 16 | def __del__(self): 17 | if self.file: 18 | self.file.close() 19 | 20 | def __call__(self, msg='', end='\n', is_print=True, is_log=True): 21 | if is_print: 22 | print(msg, end=end) 23 | if is_log and self.file is not None: 24 | self.file.write(msg) 25 | self.file.write(end) 26 | self.file.flush() 27 | 28 | 29 | def add_arguments(parser: argparse.ArgumentParser): 30 | group = parser.add_argument_group('Logger Options') 31 | # group.add_argument('--log', metavar='PATH', default='./results', help='The root path of save log text and model') 32 | group.add_argument('--log-suffix', metavar='NAME', default='', help='the suffix of log path.') 33 | group.add_argument('--print-f', metavar='N', default=100, type=int, help='print frequency. 
(default: 100)') 34 | return 35 | 36 | 37 | def setting(filename=None, path='.', only_print=False, append=False): 38 | global logger 39 | logger = _Logger(filename, path, only_print, append) 40 | return logger 41 | 42 | 43 | def get_logger(): 44 | global logger 45 | if logger is None: 46 | warnings.warn('Logger is not set!') 47 | return print 48 | else: 49 | return logger 50 | -------------------------------------------------------------------------------- /extension/normailzation/__init__.py: -------------------------------------------------------------------------------- 1 | from .normailzation import add_arguments, setting, Norm 2 | -------------------------------------------------------------------------------- /extension/normailzation/center_normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | 6 | class CenterNorm(nn.Module): 7 | def __init__(self, num_features, momentum=0.1, dim=4, frozen=False, affine=True, *args, **kwargs): 8 | super(CenterNorm, self).__init__() 9 | self.frozen = frozen 10 | self.num_features = num_features 11 | self.momentum = momentum 12 | self.dim = dim 13 | self.shape = [1 for _ in range(dim)] 14 | self.shape[1] = self.num_features 15 | self.affine = affine 16 | if self.affine: 17 | self.bias = Parameter(torch.Tensor(*self.shape)) 18 | self.register_buffer('running_mean', torch.zeros(self.shape)) 19 | self.reset_parameters() 20 | 21 | def reset_parameters(self): 22 | if self.affine: 23 | nn.init.zeros_(self.bias) 24 | self.running_mean.zero_() 25 | 26 | def forward(self, input: torch.Tensor): 27 | assert input.size(1) == self.num_features and self.dim == input.dim() 28 | if self.training and not self.frozen: 29 | mean = input.mean(0, keepdim=True) 30 | for d in range(2, self.dim): 31 | mean = mean.mean(d, keepdim=True) 32 | output = input - mean 33 | self.running_mean = (1. - self.momentum) * self.running_mean + self.momentum * mean 34 | else: 35 | output = input - self.running_mean 36 | if self.affine: 37 | output = output + self.bias 38 | return output 39 | 40 | def extra_repr(self): 41 | return '{num_features}, momentum={momentum}, frozen={frozen}, affine={affine}'.format(**self.__dict__) 42 | 43 | 44 | if __name__ == '__main__': 45 | cn = CenterNorm(32) 46 | print(cn) 47 | print(cn.running_mean.size()) 48 | x = torch.randn(3, 32, 64, 64) + 1. 
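# Sanity check (descriptive comments added): the input batch has mean ~1; after
# CenterNorm the output mean should be ~0, while the running_mean buffer keeps
# the broadcastable shape (1, 32, 1, 1) and tracks the batch statistics.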
49 | print(x.mean()) 50 | y = cn(x) 51 | print(y.mean()) 52 | print(cn.running_mean.size()) 53 | -------------------------------------------------------------------------------- /extension/normailzation/dbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Parameter 4 | 5 | 6 | class DBN(nn.Module): 7 | def __init__(self, num_features, num_groups=32, num_channels=0, dim=4, eps=1e-5, momentum=0.1, affine=True, mode=0, 8 | *args, **kwargs): 9 | super(DBN, self).__init__() 10 | if num_channels > 0: 11 | num_groups = num_features // num_channels 12 | self.num_features = num_features 13 | self.num_groups = num_groups 14 | assert self.num_features % self.num_groups == 0 15 | self.dim = dim 16 | self.eps = eps 17 | self.momentum = momentum 18 | self.affine = affine 19 | self.mode = mode 20 | 21 | self.shape = [1] * dim 22 | self.shape[1] = num_features 23 | 24 | if self.affine: 25 | self.weight = Parameter(torch.Tensor(*self.shape)) 26 | self.bias = Parameter(torch.Tensor(*self.shape)) 27 | else: 28 | self.register_parameter('weight', None) 29 | self.register_parameter('bias', None) 30 | 31 | self.register_buffer('running_mean', torch.zeros(num_groups, 1)) 32 | self.register_buffer('running_projection', torch.eye(num_groups)) 33 | self.reset_parameters() 34 | 35 | # def reset_running_stats(self): 36 | # self.running_mean.zero_() 37 | # self.running_var.eye_(1) 38 | 39 | def reset_parameters(self): 40 | # self.reset_running_stats() 41 | if self.affine: 42 | nn.init.uniform_(self.weight) 43 | nn.init.zeros_(self.bias) 44 | 45 | def forward(self, input: torch.Tensor): 46 | size = input.size() 47 | assert input.dim() == self.dim and size[1] == self.num_features 48 | x = input.view(size[0] * size[1] // self.num_groups, self.num_groups, *size[2:]) 49 | training = self.mode > 0 or (self.mode == 0 and self.training) 50 | x = x.transpose(0, 1).contiguous().view(self.num_groups, -1) 51 | if training: 52 | mean = x.mean(1, keepdim=True) 53 | self.running_mean = (1. - self.momentum) * self.running_mean + self.momentum * mean 54 | x_mean = x - mean 55 | sigma = x_mean.matmul(x_mean.t()) / x.size(1) + self.eps * torch.eye(self.num_groups, device=input.device) 56 | # print('sigma size {}'.format(sigma.size())) 57 | u, eig, _ = sigma.svd() 58 | scale = eig.rsqrt() 59 | wm = u.matmul(scale.diag()).matmul(u.t()) 60 | self.running_projection = (1. - self.momentum) * self.running_projection + self.momentum * wm 61 | y = wm.matmul(x_mean) 62 | else: 63 | x_mean = x - self.running_mean 64 | y = self.running_projection.matmul(x_mean) 65 | output = y.view(self.num_groups, size[0] * size[1] // self.num_groups, *size[2:]).transpose(0, 1) 66 | output = output.contiguous().view_as(input) 67 | if self.affine: 68 | output = output * self.weight + self.bias 69 | return output 70 | 71 | def extra_repr(self): 72 | return '{num_features}, num_groups={num_groups}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 73 | 'mode={mode}'.format(**self.__dict__) 74 | 75 | 76 | class DBN2(DBN): 77 | """ 78 | In the evaluation phase, Sigma uses the running average. 
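Concretely: DBN above keeps a running average of the whitening matrix itself in `running_projection`, whereas DBN2 stores the running average of the covariance matrix there and recomputes the whitening matrix from it (via SVD) at evaluation time.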
79 | """ 80 | 81 | def forward(self, input: torch.Tensor): 82 | size = input.size() 83 | assert input.dim() == self.dim and size[1] == self.num_features 84 | x = input.view(size[0] * size[1] // self.num_groups, self.num_groups, *size[2:]) 85 | training = self.mode > 0 or (self.mode == 0 and self.training) 86 | x = x.transpose(0, 1).contiguous().view(self.num_groups, -1) 87 | mean = x.mean(1, keepdim=True) if training else self.running_mean 88 | x_mean = x - mean 89 | if training: 90 | self.running_mean = (1. - self.momentum) * self.running_mean + self.momentum * mean 91 | sigma = x_mean.matmul(x_mean.t()) / x.size(1) + self.eps * torch.eye(self.num_groups, device=input.device) 92 | self.running_projection = (1. - self.momentum) * self.running_projection + self.momentum * sigma 93 | else: 94 | sigma = self.running_projection 95 | u, eig, _ = sigma.svd() 96 | scale = eig.rsqrt() 97 | wm = u.matmul(scale.diag()).matmul(u.t()) 98 | y = wm.matmul(x_mean) 99 | output = y.view(self.num_groups, size[0] * size[1] // self.num_groups, *size[2:]).transpose(0, 1) 100 | output = output.contiguous().view_as(input) 101 | if self.affine: 102 | output = output * self.weight + self.bias 103 | return output 104 | 105 | 106 | if __name__ == '__main__': 107 | dbn = DBN(64, 32, affine=False, momentum=1.) 108 | x = torch.randn(2, 64, 7, 7) 109 | print(dbn) 110 | y = dbn(x) 111 | print('y size:', y.size()) 112 | y = y.view(y.size(0) * y.size(1) // dbn.num_groups, dbn.num_groups, *y.size()[2:]) 113 | y = y.transpose(0, 1).contiguous().view(dbn.num_groups, -1) 114 | print('y reshaped:', y.size()) 115 | z = y.matmul(y.t()) 116 | print('train mode:', z.diag()) 117 | dbn.eval() 118 | y = dbn(x) 119 | y = y.view(y.size(0) * y.size(1) // dbn.num_groups, dbn.num_groups, *y.size()[2:]) 120 | y = y.transpose(0, 1).contiguous().view(dbn.num_groups, -1) 121 | z = y.matmul(y.t()) 122 | print('eval mode:', z.diag()) 123 | print(__file__) 124 | -------------------------------------------------------------------------------- /extension/normailzation/group_batch_normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | 6 | 7 | class GroupBatchNorm(nn.Module): 8 | def __init__(self, num_features, num_groups=1, num_channels=0, dim=4, eps=1e-5, momentum=0.1, affine=True, mode=0, 9 | *args, **kwargs): 10 | """""" 11 | super(GroupBatchNorm, self).__init__() 12 | if num_channels > 0: 13 | assert num_features % num_channels == 0 14 | num_groups = num_features // num_channels 15 | assert num_features % num_groups == 0 16 | self.num_features = num_features 17 | self.num_groups = num_groups 18 | self.dim = dim 19 | self.eps = eps 20 | self.momentum = momentum 21 | self.affine = affine 22 | self.mode = mode 23 | self.shape = [1] * dim 24 | self.shape[1] = num_features 25 | 26 | if self.affine: 27 | self.weight = Parameter(torch.Tensor(*self.shape)) 28 | self.bias = Parameter(torch.Tensor(*self.shape)) 29 | else: 30 | self.register_parameter('weight', None) 31 | self.register_parameter('bias', None) 32 | 33 | self.register_buffer('running_mean', torch.zeros(num_groups)) 34 | self.register_buffer('running_var', torch.ones(num_groups)) 35 | self.reset_parameters() 36 | 37 | def reset_running_stats(self): 38 | self.running_mean.zero_() 39 | self.running_var.fill_(1) 40 | 41 | def reset_parameters(self): 42 | self.reset_running_stats() 43 | if self.affine: 44 | nn.init.uniform_(self.weight) 
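# Note: the scale parameter is initialized uniformly at random (as in DBN above), rather than to ones as nn.BatchNorm2d typically does.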
45 | nn.init.zeros_(self.bias) 46 | 47 | def forward(self, input: torch.Tensor): 48 | training = self.mode > 0 or (self.mode == 0 and self.training) 49 | assert input.dim() == self.dim and input.size(1) == self.num_features 50 | sizes = input.size() 51 | reshaped = input.view(sizes[0] * sizes[1] // self.num_groups, self.num_groups, *sizes[2:self.dim]) 52 | output = F.batch_norm(reshaped, self.running_mean, self.running_var, training=training, momentum=self.momentum, 53 | eps=self.eps) 54 | output = output.view_as(input) 55 | if self.affine: 56 | output = output * self.weight + self.bias 57 | return output 58 | 59 | def extra_repr(self): 60 | return '{num_features}, num_groups={num_groups}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 61 | 'mode={mode}'.format(**self.__dict__) 62 | 63 | 64 | if __name__ == '__main__': 65 | GBN = GroupBatchNorm(64, 16, momentum=1) 66 | print(GBN) 67 | # print(GBN.weight) 68 | # print(GBN.bias) 69 | x = torch.randn(4, 64, 32, 32) * 2 + 1 70 | print('x mean = {}, var = {}'.format(x.mean(), x.var())) 71 | y = GBN(x) 72 | print('y size = {}, mean = {}, var = {}'.format(y.size(), y.mean(), y.var())) 73 | print(GBN.running_mean, GBN.running_var) 74 | -------------------------------------------------------------------------------- /extension/normailzation/iterative_normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: Iterative Normalization: Beyond Standardization towards Efficient Whitening, CVPR 2019 3 | 4 | - Paper: 5 | - Code: https://github.com/huangleiBuaa/IterNorm 6 | """ 7 | import torch.nn 8 | from torch.nn import Parameter 9 | 10 | # import extension._bcnn as bcnn 11 | 12 | __all__ = ['iterative_normalization', 'IterNorm'] 13 | 14 | 15 | # 16 | # class iterative_normalization(torch.autograd.Function): 17 | # @staticmethod 18 | # def forward(ctx, *inputs): 19 | # result = bcnn.iterative_normalization_forward(*inputs) 20 | # ctx.save_for_backward(*result[:-1]) 21 | # return result[-1] 22 | # 23 | # @staticmethod 24 | # def backward(ctx, *grad_outputs): 25 | # grad, = grad_outputs 26 | # grad_input = bcnn.iterative_normalization_backward(grad, ctx.saved_variables) 27 | # return grad_input, None, None, None, None, None, None, None 28 | 29 | 30 | class iterative_normalization_py(torch.autograd.Function): 31 | @staticmethod 32 | def forward(ctx, *args, **kwargs): 33 | X, running_mean, running_wmat, nc, ctx.T, eps, momentum, training = args 34 | # change NxCxHxW to (G x D) x(NxHxW), i.e., g*d*m 35 | ctx.g = X.size(1) // nc 36 | x = X.transpose(0, 1).contiguous().view(ctx.g, nc, -1) 37 | _, d, m = x.size() 38 | saved = [] 39 | if training: 40 | # calculate centered activation by subtracted mini-batch mean 41 | mean = x.mean(-1, keepdim=True) 42 | xc = x - mean 43 | saved.append(xc) 44 | # calculate covariance matrix 45 | P = [None] * (ctx.T + 1) 46 | P[0] = torch.eye(d).to(X).expand(ctx.g, d, d) 47 | Sigma = torch.baddbmm(eps, P[0], 1. / m, xc, xc.transpose(1, 2)) 48 | # reciprocal of trace of Sigma: shape [g, 1, 1] 49 | rTr = (Sigma * P[0]).sum((1, 2), keepdim=True).reciprocal_() 50 | saved.append(rTr) 51 | Sigma_N = Sigma * rTr 52 | saved.append(Sigma_N) 53 | for k in range(ctx.T): 54 | P[k + 1] = torch.baddbmm(1.5, P[k], -0.5, torch.matrix_power(P[k], 3), Sigma_N) 55 | saved.extend(P) 56 | wm = P[ctx.T].mul_(rTr.sqrt()) # whiten matrix: the matrix inverse of Sigma, i.e., Sigma^{-1/2} 57 | running_mean.copy_(momentum * mean + (1. 
- momentum) * running_mean) 58 | running_wmat.copy_(momentum * wm + (1. - momentum) * running_wmat) 59 | else: 60 | xc = x - running_mean 61 | wm = running_wmat 62 | xn = wm.matmul(xc) 63 | Xn = xn.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous() 64 | ctx.save_for_backward(*saved) 65 | return Xn 66 | 67 | @staticmethod 68 | def backward(ctx, *grad_outputs): 69 | grad, = grad_outputs 70 | saved = ctx.saved_variables 71 | xc = saved[0] # centered input 72 | rTr = saved[1] # trace of Sigma 73 | sn = saved[2].transpose(-2, -1) # normalized Sigma 74 | P = saved[3:] # middle result matrix, 75 | g, d, m = xc.size() 76 | 77 | g_ = grad.transpose(0, 1).contiguous().view_as(xc) 78 | g_wm = g_.matmul(xc.transpose(-2, -1)) 79 | g_P = g_wm * rTr.sqrt() 80 | wm = P[ctx.T] 81 | g_sn = 0 82 | for k in range(ctx.T, 1, -1): 83 | P[k - 1].transpose_(-2, -1) 84 | P2 = P[k - 1].matmul(P[k - 1]) 85 | g_sn += P2.matmul(P[k - 1]).matmul(g_P) 86 | g_tmp = g_P.matmul(sn) 87 | g_P.baddbmm_(1.5, -0.5, g_tmp, P2) 88 | g_P.baddbmm_(1, -0.5, P2, g_tmp) 89 | g_P.baddbmm_(1, -0.5, P[k - 1].matmul(g_tmp), P[k - 1]) 90 | g_sn += g_P 91 | # g_sn = g_sn * rTr.sqrt() 92 | g_tr = ((-sn.matmul(g_sn) + g_wm.transpose(-2, -1).matmul(wm)) * P[0]).sum((1, 2), keepdim=True) * P[0] 93 | g_sigma = (g_sn + g_sn.transpose(-2, -1) + 2. * g_tr) * (-0.5 / m * rTr) 94 | # g_sigma = g_sigma + g_sigma.transpose(-2, -1) 95 | g_x = torch.baddbmm(wm.matmul(g_ - g_.mean(-1, keepdim=True)), g_sigma, xc) 96 | grad_input = g_x.view(grad.size(1), grad.size(0), *grad.size()[2:]).transpose(0, 1).contiguous() 97 | return grad_input, None, None, None, None, None, None, None 98 | 99 | 100 | class IterNorm(torch.nn.Module): 101 | def __init__(self, num_features, num_groups=1, num_channels=None, T=5, dim=4, eps=1e-5, momentum=0.1, affine=True, 102 | *args, **kwargs): 103 | super(IterNorm, self).__init__() 104 | # assert dim == 4, 'IterNorm is not support 2D' 105 | self.T = T 106 | self.eps = eps 107 | self.momentum = momentum 108 | self.num_features = num_features 109 | self.affine = affine 110 | self.dim = dim 111 | if num_channels is None: 112 | num_channels = (num_features - 1) // num_groups + 1 113 | num_groups = num_features // num_channels 114 | while num_features % num_channels != 0: 115 | num_channels //= 2 116 | num_groups = num_features // num_channels 117 | assert num_groups > 0 and num_features % num_groups == 0, "num features={}, num groups={}".format(num_features, 118 | num_groups) 119 | self.num_groups = num_groups 120 | self.num_channels = num_channels 121 | shape = [1] * dim 122 | shape[1] = self.num_features 123 | if self.affine: 124 | self.weight = Parameter(torch.Tensor(*shape)) 125 | self.bias = Parameter(torch.Tensor(*shape)) 126 | else: 127 | self.register_parameter('weight', None) 128 | self.register_parameter('bias', None) 129 | 130 | self.register_buffer('running_mean', torch.zeros(num_groups, num_channels, 1)) 131 | # running whiten matrix 132 | self.register_buffer('running_wm', torch.eye(num_channels).expand(num_groups, num_channels, num_channels)) 133 | self.reset_parameters() 134 | 135 | def reset_parameters(self): 136 | # self.reset_running_stats() 137 | if self.affine: 138 | torch.nn.init.ones_(self.weight) 139 | torch.nn.init.zeros_(self.bias) 140 | 141 | def forward(self, X: torch.Tensor): 142 | X_hat = iterative_normalization_py.apply(X, self.running_mean, self.running_wm, self.num_channels, self.T, 143 | self.eps, self.momentum, self.training) 144 | # affine 145 | if self.affine: 146 | return X_hat * 
self.weight + self.bias 147 | else: 148 | return X_hat 149 | 150 | def extra_repr(self): 151 | return '{num_features}, num_channels={num_channels}, T={T}, eps={eps}, ' \ 152 | 'momentum={momentum}, affine={affine}'.format(**self.__dict__) 153 | 154 | 155 | if __name__ == '__main__': 156 | ItN = IterNorm(64, num_groups=8, T=10, momentum=1, affine=False) 157 | print(ItN) 158 | ItN.train() 159 | #x = torch.randn(32, 64, 14, 14) 160 | x = torch.randn(128, 64) 161 | x.requires_grad_() 162 | y = ItN(x) 163 | z = y.transpose(0, 1).contiguous().view(x.size(1), -1) 164 | print(z.matmul(z.t()) / z.size(1)) 165 | 166 | y.sum().backward() 167 | print('x grad', x.grad.size()) 168 | 169 | ItN.eval() 170 | y = ItN(x) 171 | z = y.transpose(0, 1).contiguous().view(x.size(1), -1) 172 | print(z.matmul(z.t()) / z.size(1)) 173 | -------------------------------------------------------------------------------- /extension/normailzation/iterative_normalization_FlexGroup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: Iterative Normalization: Beyond Standardization towards Efficient Whitening, CVPR 2019 3 | 4 | - Paper: 5 | - Code: https://github.com/huangleiBuaa/IterNorm 6 | 7 | ***** 8 | This implementation allows the number of feature maps to not be divisible by the channel number per group. E.g., one can use a group size of 64 when the channel number is 80 (split as 64 + 16). 9 | 10 | """ 11 | import torch.nn 12 | from torch.nn import Parameter 13 | 14 | # import extension._bcnn as bcnn 15 | 16 | __all__ = ['iterative_normalization_FlexGroup', 'IterNorm'] 17 | 18 | 19 | # 20 | # class iterative_normalization(torch.autograd.Function): 21 | # @staticmethod 22 | # def forward(ctx, *inputs): 23 | # result = bcnn.iterative_normalization_forward(*inputs) 24 | # ctx.save_for_backward(*result[:-1]) 25 | # return result[-1] 26 | # 27 | # @staticmethod 28 | # def backward(ctx, *grad_outputs): 29 | # grad, = grad_outputs 30 | # grad_input = bcnn.iterative_normalization_backward(grad, ctx.saved_variables) 31 | # return grad_input, None, None, None, None, None, None, None 32 | 33 | 34 | class iterative_normalization_py(torch.autograd.Function): 35 | @staticmethod 36 | def forward(ctx, *args, **kwargs): 37 | X, running_mean, running_wmat, nc, ctx.T, eps, momentum, training = args 38 | # change NxCxHxW to Dx(NxHxW), i.e., d*m 39 | ctx.g = X.size(1) // nc 40 | x = X.transpose(0, 1).contiguous().view(nc, -1) 41 | d, m = x.size() 42 | saved = [] 43 | if training: 44 | # calculate centered activation by subtracted mini-batch mean 45 | mean = x.mean(-1, keepdim=True) 46 | xc = x - mean 47 | saved.append(xc) 48 | # calculate covariance matrix 49 | P = [None] * (ctx.T + 1) 50 | P[0] = torch.eye(d).to(X) 51 | Sigma = torch.addmm(eps, P[0], 1. / m, xc, xc.transpose(0, 1)) 52 | # reciprocal of trace of Sigma: shape [1, 1] 53 | rTr = (Sigma * P[0]).sum((0, 1), keepdim=True).reciprocal_() 54 | saved.append(rTr) 55 | Sigma_N = Sigma * rTr 56 | saved.append(Sigma_N) 57 | for k in range(ctx.T): 58 | P[k + 1] = torch.addmm(1.5, P[k], -0.5, torch.matrix_power(P[k], 3), Sigma_N) 59 | saved.extend(P) 60 | wm = P[ctx.T].mul_(rTr.sqrt()) # whiten matrix: the matrix inverse of Sigma, i.e., Sigma^{-1/2} 61 | running_mean.copy_(momentum * mean + (1. - momentum) * running_mean) 62 | running_wmat.copy_(momentum * wm + (1. 
- momentum) * running_wmat) 63 | else: 64 | xc = x - running_mean 65 | wm = running_wmat 66 | xn = wm.mm(xc) 67 | Xn = xn.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous() 68 | ctx.save_for_backward(*saved) 69 | return Xn 70 | 71 | @staticmethod 72 | def backward(ctx, *grad_outputs): 73 | grad, = grad_outputs 74 | saved = ctx.saved_variables 75 | xc = saved[0] # centered input 76 | rTr = saved[1] # trace of Sigma 77 | sn = saved[2].transpose(-2, -1) # normalized Sigma 78 | P = saved[3:] # middle result matrix, 79 | d, m = xc.size() 80 | 81 | g_ = grad.transpose(0, 1).contiguous().view_as(xc) 82 | g_wm = g_.mm(xc.transpose(-2, -1)) 83 | g_P = g_wm * rTr.sqrt() 84 | wm = P[ctx.T] 85 | g_sn = 0 86 | for k in range(ctx.T, 1, -1): 87 | P[k - 1].transpose_(-2, -1) 88 | P2 = P[k - 1].mm(P[k - 1]) 89 | g_sn += P2.mm(P[k - 1]).mm(g_P) 90 | g_tmp = g_P.mm(sn) 91 | g_P.addmm_(1.5, -0.5, g_tmp, P2) 92 | g_P.addmm_(1, -0.5, P2, g_tmp) 93 | g_P.addmm_(1, -0.5, P[k - 1].mm(g_tmp), P[k - 1]) 94 | g_sn += g_P 95 | # g_sn = g_sn * rTr.sqrt() 96 | g_tr = ((-sn.mm(g_sn) + g_wm.transpose(-2, -1).mm(wm)) * P[0]).sum((0, 1), keepdim=True) * P[0] 97 | g_sigma = (g_sn + g_sn.transpose(-2, -1) + 2. * g_tr) * (-0.5 / m * rTr) 98 | # g_sigma = g_sigma + g_sigma.transpose(-2, -1) 99 | g_x = torch.addmm(wm.mm(g_ - g_.mean(-1, keepdim=True)), g_sigma, xc) 100 | grad_input = g_x.view(grad.size(1), grad.size(0), *grad.size()[2:]).transpose(0, 1).contiguous() 101 | return grad_input, None, None, None, None, None, None, None 102 | 103 | 104 | class IterNorm_Single(torch.nn.Module): 105 | def __init__(self, num_features, num_groups=1, num_channels=None, T=5, dim=4, eps=1e-5, momentum=0.1, affine=True, 106 | *args, **kwargs): 107 | super(IterNorm_Single, self).__init__() 108 | # assert dim == 4, 'IterNorm is not support 2D' 109 | self.T = T 110 | self.eps = eps 111 | self.momentum = momentum 112 | self.num_features = num_features 113 | self.affine = affine 114 | self.dim = dim 115 | shape = [1] * dim 116 | shape[1] = self.num_features 117 | 118 | self.register_buffer('running_mean', torch.zeros(num_features, 1)) 119 | # running whiten matrix 120 | self.register_buffer('running_wm', torch.eye(num_features)) 121 | 122 | 123 | def forward(self, X: torch.Tensor): 124 | X_hat = iterative_normalization_py.apply(X, self.running_mean, self.running_wm, self.num_features, self.T, self.eps, self.momentum, self.training) 125 | return X_hat 126 | 127 | class IterNorm(torch.nn.Module): 128 | def __init__(self, num_features, num_groups=1, num_channels=None, T=5, dim=4, eps=1e-5, momentum=0.1, affine=True, 129 | *args, **kwargs): 130 | super(IterNorm, self).__init__() 131 | # assert dim == 4, 'IterNorm is not support 2D' 132 | self.T = T 133 | self.eps = eps 134 | self.momentum = momentum 135 | self.num_features = num_features 136 | self.num_channels = num_channels 137 | num_groups = (self.num_features-1) // self.num_channels + 1 138 | self.num_groups = num_groups 139 | self.iterNorm_Groups = torch.nn.ModuleList( 140 | [IterNorm_Single(num_features = self.num_channels, eps=eps, momentum=momentum, T=T) for _ in range(self.num_groups-1)] 141 | ) 142 | num_channels_last=self.num_features - self.num_channels * (self.num_groups -1) 143 | self.iterNorm_Groups.append(IterNorm_Single(num_features = num_channels_last, eps=eps, momentum=momentum, T=T)) 144 | 145 | self.affine = affine 146 | self.dim = dim 147 | shape = [1] * dim 148 | shape[1] = self.num_features 149 | if self.affine: 150 | self.weight = 
Parameter(torch.Tensor(*shape)) 151 | self.bias = Parameter(torch.Tensor(*shape)) 152 | else: 153 | self.register_parameter('weight', None) 154 | self.register_parameter('bias', None) 155 | self.reset_parameters() 156 | 157 | def reset_parameters(self): 158 | # self.reset_running_stats() 159 | if self.affine: 160 | torch.nn.init.ones_(self.weight) 161 | torch.nn.init.zeros_(self.bias) 162 | 163 | def forward(self, X: torch.Tensor): 164 | X_splits = torch.split(X, self.num_channels, dim=1) 165 | X_hat_splits = [] 166 | for i in range(self.num_groups): 167 | X_hat_tmp = self.iterNorm_Groups[i](X_splits[i]) 168 | X_hat_splits.append(X_hat_tmp) 169 | X_hat = torch.cat(X_hat_splits, dim=1) 170 | # affine 171 | if self.affine: 172 | return X_hat * self.weight + self.bias 173 | else: 174 | return X_hat 175 | 176 | def extra_repr(self): 177 | return '{num_features}, num_channels={num_channels}, T={T}, eps={eps}, ' \ 178 | 'momentum={momentum}, affine={affine}'.format(**self.__dict__) 179 | 180 | 181 | if __name__ == '__main__': 182 | ItN = IterNorm(16, num_channels=4, T=10, momentum=1, affine=False) 183 | print(ItN) 184 | ItN.train() 185 | #x = torch.randn(32, 64, 14, 14) 186 | x = torch.randn(32, 16) 187 | x.requires_grad_() 188 | y = ItN(x) 189 | z = y.transpose(0, 1).contiguous().view(x.size(1), -1) 190 | print(z.matmul(z.t()) / z.size(1)) 191 | 192 | y.sum().backward() 193 | print('x grad', x.grad.size()) 194 | 195 | ItN.eval() 196 | y = ItN(x) 197 | z = y.transpose(0, 1).contiguous().view(x.size(1), -1) 198 | print(z.matmul(z.t()) / z.size(1)) 199 | -------------------------------------------------------------------------------- /extension/normailzation/normailzation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.nn as nn 3 | from .center_normalization import CenterNorm 4 | from .group_batch_normalization import GroupBatchNorm 5 | from .iterative_normalization import IterNorm 6 | #from .iterative_normalization_FlexGroup import IterNorm 7 | from .dbn import DBN, DBN2 8 | from ..utils import str2dict 9 | 10 | 11 | def _GroupNorm(num_features, num_groups=32, eps=1e-5, affine=True, *args, **kwargs): 12 | return nn.GroupNorm(num_groups, num_features, eps=eps, affine=affine) 13 | 14 | 15 | def _LayerNorm(normalized_shape, eps=1e-5, affine=True, *args, **kwargs): 16 | return nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=affine) 17 | 18 | 19 | def _BatchNorm(num_features, dim=4, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, *args, **kwargs): 20 | return (nn.BatchNorm2d if dim == 4 else nn.BatchNorm1d)(num_features, eps=eps, momentum=momentum, affine=affine, 21 | track_running_stats=track_running_stats) 22 | 23 | 24 | def _InstanceNorm(num_features, dim=4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False, *args, 25 | **kwargs): 26 | return (nn.InstanceNorm2d if dim == 4 else nn.InstanceNorm1d)(num_features, eps=eps, momentum=momentum, 27 | affine=affine, 28 | track_running_stats=track_running_stats) 29 | 30 | 31 | class _config: 32 | norm = 'BN' 33 | norm_cfg = {} 34 | norm_methods = {'BN': _BatchNorm, 'GN': _GroupNorm, 'LN': _LayerNorm, 'IN': _InstanceNorm, 'CN': CenterNorm, 35 | 'None': None, 'GBN': GroupBatchNorm, 'DBN': DBN, 'DBN2': DBN2, 'ItN': IterNorm} 36 | 37 | 38 | def add_arguments(parser: argparse.ArgumentParser): 39 | group = parser.add_argument_group('Normalization Options') 40 | group.add_argument('--norm', default='BN', help='Use which normalization layers? 
{' + ', '.join( 41 | _config.norm_methods.keys()) + '}' + ' (default: {})'.format(_config.norm)) 42 | group.add_argument('--norm-cfg', type=str2dict, default={}, metavar='DICT', help='layers config.') 43 | return group 44 | 45 | 46 | def setting(cfg: argparse.Namespace): 47 | for key, value in vars(cfg).items(): 48 | if key in _config.__dict__: 49 | setattr(_config, key, value) 50 | return ('_' + _config.norm) if _config.norm != 'BN' else '' 51 | 52 | 53 | def Norm(*args, **kwargs): 54 | kwargs.update(_config.norm_cfg) 55 | if _config.norm == 'None': 56 | return None 57 | return _config.norm_methods[_config.norm](*args, **kwargs) 58 | -------------------------------------------------------------------------------- /extension/optimizer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from .utils import str2dict 4 | from .logger import get_logger 5 | 6 | _methods = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam, 'adamax': torch.optim.Adamax, 7 | 'RMSprop': torch.optim.RMSprop} 8 | 9 | 10 | def add_arguments(parser: argparse.ArgumentParser): 11 | group = parser.add_argument_group('Optimizer Option:') 12 | group.add_argument('-oo', '--optimizer', default='sgd', choices=_methods.keys(), 13 | help='the optimizer method to train the network {' + ', '.join(_methods.keys()) + '}') 14 | group.add_argument('-oc', '--optimizer-config', default={}, type=str2dict, metavar='DICT', 15 | help='The configuration for the optimizer') 16 | group.add_argument('-wd', '--weight-decay', default=0, type=float, metavar='FLOAT', 17 | help='weight decay (default: 0).') 18 | return 19 | 20 | 21 | def setting(model: torch.nn.Module, cfg: argparse.Namespace, **kwargs): 22 | if cfg.optimizer == 'sgd': 23 | kwargs.setdefault('momentum', 0.9) 24 | if hasattr(cfg, 'lr'): 25 | kwargs['lr'] = cfg.lr 26 | kwargs['weight_decay'] = cfg.weight_decay 27 | kwargs.update(cfg.optimizer_config) 28 | params = model.parameters() 29 | logger = get_logger() 30 | optimizer = _methods[cfg.optimizer](params, **kwargs) 31 | logger('==> Optimizer {}'.format(optimizer)) 32 | return optimizer 33 | -------------------------------------------------------------------------------- /extension/progress_bar.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import os 3 | import sys 4 | import time 5 | 6 | __all__ = ['ProgressBar', 'format_time'] 7 | 8 | 9 | def _get_terminal_size(): 10 | try: 11 | columns, lines = os.get_terminal_size() 12 | return int(columns) 13 | except OSError: 14 | return -1 15 | 16 | 17 | def format_time(seconds): 18 | days = int(seconds / 3600 / 24) 19 | seconds = seconds - days * 3600 * 24 20 | hours = int(seconds / 3600) 21 | seconds = seconds - hours * 3600 22 | minutes = int(seconds / 60) 23 | seconds = seconds - minutes * 60 24 | secondsf = int(seconds) 25 | seconds = seconds - secondsf 26 | millis = int(seconds * 1000) 27 | 28 | f = '' 29 | i = 1 30 | if days > 0: 31 | f += str(days) + 'D' 32 | i += 1 33 | if hours > 0 and i <= 2: 34 | f += str(hours) + 'h' 35 | i += 1 36 | if minutes > 0 and i <= 2: 37 | f += str(minutes) + 'm' 38 | i += 1 39 | if secondsf > 0 and i <= 2: 40 | f += str(secondsf) + 's' 41 | i += 1 42 | if millis > 0 and i <= 2: 43 | f += str(millis) + 'ms' 44 | i += 1 45 | if f == '': 46 | f = '0ms' 47 | return f 48 | 49 | 50 | class ProgressBar(object): 51 | def __init__(self, total=100, max_length=160): 52 | self.start_time = time.time() 53 | self.iter = 0 54 | self.total = total 55 
| self.max_length = max_length 56 | self.msg_on_bar = '' 57 | self.msg_end = '' 58 | self.bar_length = 80 59 | 60 | def reset(self): 61 | self.start_time = time.time() 62 | self.iter = 0 63 | 64 | def _deal_message(self): 65 | msg = self.msg_on_bar.strip().lstrip() 66 | if len(msg) > self.bar_length: 67 | msg = msg[0:self.bar_length - 3] 68 | msg += '...' 69 | # center message 70 | msg = ' ' * ((self.bar_length - len(msg)) // 2) + msg 71 | msg = msg + ' ' * (self.bar_length - len(msg)) 72 | self.msg_on_bar = msg 73 | 74 | def _raw_output(self): 75 | self.bar_length = 50 76 | show_len = int(self.iter / self.total * self.bar_length) 77 | msg = '\r|' + '>' * show_len + ' ' * (self.bar_length - show_len) 78 | msg += '| ' + self.msg_end + ' ' + self.msg_on_bar 79 | sys.stdout.write(msg) 80 | 81 | def step(self, msg='', add=1): 82 | """ 83 | :param add: How many iterations are executed? 84 | :param msg: the message that needs to be shown on the progress bar 85 | """ 86 | self.iter = min(self.iter + add, self.total) 87 | if not isinstance(msg, str): 88 | msg = '{}'.format(msg) 89 | self.msg_end = ' {}/{}'.format(self.iter, self.total) 90 | used_time = time.time() - self.start_time 91 | self.msg_end += ' {}'.format(format_time(used_time)) 92 | if self.iter != self.total: 93 | left_time = used_time / self.iter * (self.total - self.iter) 94 | self.msg_end += '<={}'.format(format_time(left_time)) 95 | self.msg_on_bar = msg 96 | 97 | columns = min(_get_terminal_size(), self.max_length) 98 | if columns < 0: 99 | self._raw_output() 100 | else: 101 | self.bar_length = columns - len(self.msg_end) 102 | self._linux_output() 103 | 104 | if self.iter == self.total: 105 | sys.stdout.write('\n') 106 | sys.stdout.flush() 107 | 108 | def time_used(self): 109 | used_time = time.time() - self.start_time 110 | return format_time(used_time) 111 | 112 | def _linux_output(self): 113 | show_len = int(self.iter / self.total * self.bar_length) 114 | self._deal_message() 115 | 116 | control = '\r' # return to the start of the line 117 | control += '\33[4m' # underline 118 | control += '\33[40;37m' # white text on a black background 119 | # control += '\33[7m' # reverse video 120 | # control += '\33[?25l' # hide the cursor 121 | control += self.msg_on_bar[0:show_len] 122 | # control += '\33[0m' # reset attributes 123 | control += '\33[47;30m' # black text on a white background 124 | control += self.msg_on_bar[show_len:self.bar_length] 125 | # control += '\33[K' # clear from the cursor to the end of the line 126 | control += '\33[0m' 127 | 128 | sys.stdout.write(control) 129 | sys.stdout.write(self.msg_end) 130 | 131 | 132 | if __name__ == '__main__': 133 | bar = ProgressBar() 134 | 135 | epoch = 0 136 | while True: 137 | bar.reset() 138 | for i in range(bar.total): 139 | bar.step(epoch) 140 | time.sleep(0.1) 141 | epoch += 1 142 | -------------------------------------------------------------------------------- /extension/scheduler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from torch.optim.lr_scheduler import * 4 | 5 | from .utils import str2list 6 | from .logger import get_logger 7 | 8 | _methods = ['fix', 'step', 'steps', 'ploy', 'auto', 'exp', 'user', 'cos', '1cycle'] 9 | 10 | 11 | def add_arguments(parser: argparse.ArgumentParser): 12 | # train learning rate 13 | group = parser.add_argument_group('Learning rate scheduler Option:') 14 | group.add_argument('--lr-method', default='step', choices=_methods, metavar='METHOD', 15 | help='The learning rate scheduler: {' + ', '.join(_methods) + '}') 16 | group.add_argument('--lr', default=0.1, type=float, metavar='LR', help='The initial learning rate (default: 0.1)') 17 
| group.add_argument('--lr-step', default=30, type=int, 18 | help='Every some epochs, the learning rate is multiplied by a factor (default: 30)') 19 | group.add_argument('--lr-gamma', default=0.1, type=float, help='The learning rate decay factor. (default: 0.1)') 20 | group.add_argument('--lr-steps', default=[], type=str2list, help='the step values for learning rate policy "steps"') 21 | return group 22 | 23 | 24 | def setting(optimizer, args, lr_func=None, **kwargs): 25 | lr_method = args.lr_method 26 | if lr_method == 'fix': 27 | scheduler = StepLR(optimizer, args.epochs, args.lr_gamma) 28 | elif lr_method == 'step': 29 | scheduler = StepLR(optimizer, args.lr_step, args.lr_gamma) 30 | elif lr_method == 'steps': 31 | scheduler = MultiStepLR(optimizer, args.lr_steps, args.lr_gamma) 32 | elif lr_method == 'ploy': 33 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 34 | lambda _epoch: (1. - _epoch / args.epochs) ** args.lr_gamma) 35 | elif lr_method == 'auto': 36 | scheduler = ReduceLROnPlateau(optimizer, factor=args.lr_gamma, patience=args.lr_step, verbose=True) 37 | elif lr_method == 'exp': 38 | scheduler = ExponentialLR(optimizer, args.lr_gamma) 39 | elif lr_method == 'user': 40 | scheduler = LambdaLR(optimizer, lr_func) 41 | elif lr_method == 'cos': 42 | scheduler = CosineAnnealingLR(optimizer, args.lr_step, args.lr_gamma) 43 | elif lr_method == '1cycle': 44 | gamma = (args.lr_gamma - args.lr) / args.lr_step 45 | 46 | def adjust(epoch): 47 | if epoch < args.lr_step * 2: 48 | return (args.lr_gamma - gamma * abs(epoch - args.lr_step)) / args.lr 49 | else: 50 | return (args.epochs - epoch) / (args.epochs - args.lr_step * 2) 51 | 52 | scheduler = LambdaLR(optimizer, adjust) 53 | else: 54 | raise NotImplementedError('Learning rate scheduler {} is not supported!'.format(lr_method)) 55 | LOG = get_logger() 56 | LOG('==> Scheduler: {}'.format(scheduler)) 57 | return scheduler 58 | -------------------------------------------------------------------------------- /extension/test/IterNorm_test.py: -------------------------------------------------------------------------------- 1 | from extension.normailzation.iterative_normalization import IterNorm 2 | import torch 3 | from extension.test.test_util import * 4 | 5 | 6 | class IterNorm_py(IterNorm): 7 | def forward(self, X: torch.Tensor): 8 | # change NxCxHxW to Cx(NxHxW), i.e., d*m 9 | x = X.transpose(0, 1).contiguous().view(self.num_groups, self.num_features // self.num_groups, -1) 10 | g, d, m = x.size() 11 | if self.training: 12 | # calculate centered activation by subtracted mini-batch mean 13 | mean = x.mean(-1, keepdim=True) 14 | x_c = x - mean 15 | # calculate covariance matrix 16 | Sigma = x_c.matmul(x_c.transpose(1, 2)) / m + self.eps * torch.eye(d, dtype=X.dtype, device=X.device) 17 | # Sigma = torch.eye(d).to(X) 18 | # torch.baddbmm(self.eps, Sigma, 1./m, x_c, x_c.transpose(1, 2)) 19 | # reciprocal of trace of Sigma: shape [g, 1, 1] 20 | rTr = x_c.new_empty(g, 1, 1) 21 | for i in range(g): 22 | rTr[i] = 1. / Sigma[i].trace() 23 | sigma_norm = Sigma * rTr 24 | P = [None] * (self.T + 1) 25 | P[0] = torch.eye(d).to(X).expand(g, d, d) 26 | for k in range(self.T): 27 | P[k + 1] = 0.5 * (3 * P[k] - torch.matrix_power(P[k], 3).matmul( 28 | sigma_norm)) # P[k + 1] = P[k].clone() # torch.baddbmm(1.5, P[k + 1], -0.5, torch.matrix_power(P[k], 3), sigma_norm) 29 | sigma_inv = P[self.T] * rTr.sqrt() 30 | self.running_mean = self.momentum * mean + (1. - self.momentum) * self.running_mean 31 | self.running_wm = self.momentum * sigma_inv + (1. 
- self.momentum) * self.running_wm 32 | else: 33 | x_c = x - self.running_mean 34 | sigma_inv = self.running_wm 35 | x_hat = sigma_inv.matmul(x_c) 36 | X_hat = x_hat.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous() 37 | 38 | # affine 39 | if self.affine: 40 | return X_hat * self.weight + self.bias 41 | else: 42 | return X_hat 43 | 44 | 45 | def test_IterNorm(test_number=100): 46 | device = torch.device('cuda') 47 | torch.set_default_dtype(torch.float64) 48 | batch_size = 16 49 | eps = 1e-6 if torch.get_default_dtype() == torch.float64 else 1e-4 50 | 51 | fm = Meter() 52 | bm = Meter() 53 | 54 | for i in range(test_number): 55 | torch.cuda.empty_cache() 56 | shape = batch_size, *rand_shapes('3d', 'input') 57 | T = random.randint(1, 10) 58 | num_channels = 2 ** random.randint(3, 6) 59 | print('run test [{}/{}], input shape: {}, T: {}, num_channels={} '.format(i + 1, test_number, shape, T, 60 | num_channels), end='\r') 61 | if shape[1] == 3: 62 | continue 63 | x1 = torch.randn(shape, device=device) 64 | x2 = x1.data.clone() 65 | x1.requires_grad_() 66 | x2.requires_grad_() 67 | g = torch.randn(shape, device=device) 68 | n1 = IterNorm_py(shape[1], num_channels=num_channels, T=T, dim=len(shape)).to(device) 69 | n2 = IterNorm(shape[1], num_channels=num_channels, T=T, dim=len(shape)).to(device) 70 | 71 | r1 = fm.run1(n1, x1) 72 | bm.run1(lambda: torch.autograd.backward(r1, g)) 73 | 74 | r2 = fm.run2(n2, x2) 75 | bm.run2(lambda: torch.autograd.backward(r2, g)) 76 | 77 | # z = r1.transpose(0, 1).contiguous().view(shape[1], -1) 78 | # z = z.matmul(z.t()) / z.size(1) - torch.eye(shape[1], device=device) 79 | # print('\n', z) 80 | # assert (z.abs() < 1e-4).sum().item() == 0 81 | # 82 | # z = r2.transpose(0, 1).contiguous().view(shape[1], -1) 83 | # z = z.matmul(z.t()) / z.size(1) - torch.eye(shape[1], device=device) 84 | # print(z) 85 | # assert (z.abs() < 1e-4).sum().item() == 0 86 | 87 | check(r1, r2, eps=eps) 88 | # check(x1.grad, x2.grad, eps=eps) 89 | check(n1.running_mean, n2.running_mean) 90 | check(n1.running_wm, n2.running_wm) 91 | 92 | del r1, r2, x1, x2, g, n1, n2 93 | 94 | print('\n\033[32mPass {} test!\033[0m'.format(test_number)) 95 | fm.print('IterNorm forward', 'py', 'c++') 96 | bm.print('IterNorm backward', 'py', 'c++') 97 | 98 | 99 | if __name__ == '__main__': 100 | seed = 0 101 | random.seed(seed) 102 | torch.manual_seed(seed) 103 | print("############# Test IterNorm #############") 104 | print('seed = {}'.format(seed)) 105 | test_IterNorm(1000) 106 | -------------------------------------------------------------------------------- /extension/test/test_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import random 4 | 5 | weight_3d_shapes = [(64, 3, 7, 7), (64, 3, 11, 11), (64, 64, 1, 1), (64, 64, 3, 3), (64, 256, 1, 1), (128, 64, 1, 1), 6 | (128, 64, 3, 3), (128, 128, 3, 3), (128, 256, 1, 1), (128, 512, 1, 1), (192, 64, 5, 5), 7 | (256, 64, 1, 1), (256, 128, 1, 1), (256, 128, 3, 3), (256, 256, 3, 3), (256, 384, 3, 3), 8 | (256, 512, 1, 1), (256, 1024, 1, 1), (384, 192, 3, 3), (512, 128, 1, 1), (512, 256, 1, 1), 9 | (512, 256, 1, 1), (512, 256, 3, 3), (512, 512, 3, 3), (512, 1024, 1, 1), (512, 2048, 1, 1), 10 | (1024, 256, 1, 1), (1024, 512, 1, 1), (2048, 512, 1, 1), (2048, 1024, 1, 1)] 11 | 12 | weight_2d_shapes = [(1000, 512), (1000, 2048), (1000, 4096), (4096, 4096), (4096, 9216)] 13 | 14 | input_3d_shapes = [(3, 224, 244), (64, 112, 112), (64, 56, 56), (96, 55, 55), (128, 112, 
112), (128, 56, 56), 15 | (128, 28, 28), (256, 56, 56), (256, 28, 28), (256, 27, 27), (256, 14, 14), (256, 7, 7), 16 | (384, 13, 13), (512, 28, 28), (512, 14, 14), (512, 7, 7), (1024, 14, 14), (2048, 7, 7)] 17 | input_2d_shapes = [(512 * 7 * 7,), (256 * 7 * 7,), (4096,), (1000,)] 18 | 19 | 20 | def rand_shapes(use="all", where='weight'): 21 | if where.startswith('w'): 22 | if use == 'all': 23 | return random.choice(weight_3d_shapes + weight_2d_shapes) 24 | elif use == '3d': 25 | return random.choice(weight_3d_shapes) 26 | else: 27 | return random.choice(weight_2d_shapes) 28 | else: 29 | if use == 'all': 30 | return random.choice(input_3d_shapes + input_2d_shapes) 31 | elif use == '3d': 32 | return random.choice(input_3d_shapes) 33 | else: 34 | return random.choice(input_2d_shapes) 35 | 36 | 37 | def check(x: torch.Tensor, y: torch.Tensor, eps=1e-6, msg='Check Failed!'): 38 | err = (x - y).abs() / x.abs().max() 39 | err = err.max() 40 | if err > eps: 41 | x = x.view(-1) 42 | y = y.view(-1) 43 | err_idx = (x - y).abs().topk(min(8, x.numel()))[1] 44 | print('') 45 | print('idx: {}'.format(err_idx.data.cpu())) 46 | print('x: {}'.format(x[err_idx].data.cpu())) 47 | print('y: {}'.format(y[err_idx].data.cpu())) 48 | print('Error {} > eps={}'.format(err, eps)) 49 | raise Exception(msg) 50 | 51 | 52 | class Meter: 53 | def __init__(self): 54 | self.t1 = 0 55 | self.t2 = 0 56 | self.cnt1 = 0 57 | self.cnt2 = 0 58 | 59 | def run1(self, func, *args): 60 | self.cnt1 += 1 61 | st = time.time() 62 | output = func(*args) 63 | torch.cuda.synchronize() 64 | self.t1 += time.time() - st 65 | return output 66 | 67 | def run2(self, func, *args): 68 | self.cnt2 += 1 69 | st = time.time() 70 | output = func(*args) 71 | torch.cuda.synchronize() 72 | self.t2 += time.time() - st 73 | return output 74 | 75 | def print(self, info='', name1='benchmark', name2='test'): 76 | self.t1 /= self.cnt1 77 | self.t2 /= self.cnt2 78 | unit = 's' 79 | if self.t1 < 1.0: 80 | self.t1 *= 1000 81 | self.t2 *= 1000 82 | unit = 'ms' 83 | if self.t1 < 1.0: 84 | self.t1 *= 1000 85 | self.t2 *= 1000 86 | unit = 'us' 87 | 88 | print('{}: {}: {:.2f} {}, {}: {:.2f} {} ({:.2f}x)'.format(info, name1, self.t1, unit, name2, self.t2, unit, 89 | self.t1 / self.t2)) 90 | -------------------------------------------------------------------------------- /extension/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import time 5 | import random 6 | import torch.backends.cudnn as cudnn 7 | import torch.optim.lr_scheduler 8 | 9 | from .logger import get_logger 10 | 11 | 12 | def add_arguments(parser: argparse.ArgumentParser): 13 | group = parser.add_argument_group('Train Option') 14 | group.add_argument('-n', '--epochs', default=90, type=int, metavar='N', help='The total number of training epochs.') 15 | group.add_argument('--start-epoch', default=-1, type=int, metavar='N', 16 | help='manual epoch number (useful on restarts)') 17 | group.add_argument('-o', '--output', default='./results', metavar='PATH', 18 | help='The root path to store results (default ./results)') 19 | group.add_argument('-t', '--test', action='store_true', help='Only test model on validation set?') 20 | group.add_argument('--seed', default=-1, type=int, help='manual seed') 21 | # group.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 22 | return 23 | 24 | 25 | def setting(cfg: argparse.Namespace): 26 | cudnn.benchmark = True 27 | logger = get_logger() 28 | logger('==> 
--------------------------------------------------------------------------------
/extension/trainer.py:
--------------------------------------------------------------------------------
import argparse
import torch
import os
import time
import random
import torch.backends.cudnn as cudnn
import torch.optim.lr_scheduler

from .logger import get_logger


def add_arguments(parser: argparse.ArgumentParser):
    group = parser.add_argument_group('Train Option')
    group.add_argument('-n', '--epochs', default=90, type=int, metavar='N',
                       help='The total number of training epochs.')
    group.add_argument('--start-epoch', default=-1, type=int, metavar='N',
                       help='Manual epoch number (useful on restarts).')
    group.add_argument('-o', '--output', default='./results', metavar='PATH',
                       help='The root path to store results (default ./results).')
    group.add_argument('-t', '--test', action='store_true', help='Only evaluate the model on the validation set.')
    group.add_argument('--seed', default=-1, type=int, help='Manual seed (a negative value picks one from the clock).')
    # group.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
    return


def setting(cfg: argparse.Namespace):
    cudnn.benchmark = True
    logger = get_logger()
    logger('==> args: {}'.format(cfg))
    logger('==> the results path: {}'.format(cfg.output))
    if not hasattr(cfg, 'seed') or cfg.seed < 0:
        cfg.seed = int(time.time())
    random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    logger('==> seed: {}'.format(cfg.seed))
    logger('==> PyTorch version: {}, cudnn version: {}'.format(torch.__version__, cudnn.version()))
    git_version = os.popen('git log --pretty=oneline | head -n 1').readline()[:-1]
    logger('==> git version: {}'.format(git_version))
    return

#
# class Trainer(object):
#
#     def __init__(self, num_model=1):
#         # config
#         self.parser = argparse.ArgumentParser(description='Trainer')
#         self.add_arguments()
#         self.args = self.parser.parse_args()
#
#         self.num_model = num_model
#
#         self.model_name = ''
#         self.model = torch.nn.Module() if self.num_model == 1 else []
#         self.quantization_cfg = quantization.setting(self.args)
#
#         self.result_path = ''
#         self.logger = None
#         self.vis = visualization.Visualization(False)
#
#         self.device = None
#         self.num_gpu = 0
#
#         self.train_transform = []
#         self.val_transform = []
#         self.train_loader = None
#         self.val_loader = None
#
#         self.lr_schedulers = []
#         self.optimizer = None
#         self.start_epoch = self.args.start_epoch if hasattr(self.args, 'start_epoch') else -1
#         self.criterion = None
#
#         self.start_time = time.time()
#         self.global_steps = 0
#         # self.set_model()
#         # self.set_optimizer()
#         # self.set_dataset()
#         # self.set_device()
#         # self.resume()
#         # self.set_lr_scheduler()
#         return
#
#     def train(self):
#         self.logger('\n++++++++++ train start (time: {}) ++++++++++'.format(
#             time.strftime("%y-%m-%d %H:%M:%S", time.localtime(time.time()))))
#         if self.args.evaluate:
#             self.validate()
#             return
#         self.start_time = time.time()
#         self.global_steps = 0
#         for epoch in range(self.start_epoch + 1, self.args.epochs):
#             if self.args.lr_method != 'auto':
#                 for i in range(len(self.lr_schedulers)):
#                     self.lr_schedulers[i].step(epoch=epoch)
#             self.train_epoch(epoch)
#             value = self.validate(epoch)
#             if self.args.lr_method == 'auto':
#                 for i in range(len(self.lr_schedulers)):
#                     self.lr_schedulers[i].step(value, epoch=epoch)
#             # self.save_checkpoint(epoch)
#         self.save()
#         now_date = time.strftime("%y-%m-%d_%H:%M:%S", time.localtime(time.time()))
#         new_log_filename = '{}_{}.txt'.format(self.model_name, now_date)
#         self.logger('\n==> Network training completed. Copy log file to {}'.format(new_log_filename))
#         self.logger.save(new_log_filename)
#         self.logger('\n++++++++++ train finished (time: {}) ++++++++++'.format(
#             time.strftime("%y-%m-%d %H:%M:%S", time.localtime(time.time()))))
#
#     def train_epoch(self, epoch):
#         return NotImplementedError
#
#     def validate(self, epoch=-1):
#         return NotImplementedError
#
#     def set_device(self):
#         self.device = torch.device('cuda')
#         self.num_gpu = torch.cuda.device_count()
#         if self.num_model == 1:
#             if self.num_gpu > 1:
#                 self.model = torch.nn.DataParallel(self.model)
#             self.model.cuda()
#         else:
#             for i in range(self.num_model):
#                 if self.num_gpu > 1:
#                     self.model[i] = torch.nn.DataParallel(self.model[i])
#                 self.model[i].cuda()
#         self.logger('==> use {:d} GPUs with cudnn {}'.format(self.num_gpu, cudnn.version()))
#
#     def set_dataset(self):
#         if not self.args.evaluate:
#             self.train_loader = dataset.get_dataset_loader(self.args, self.train_transform, train=True)
#             self.logger('==> Train Data Transforms: {}'.format(self.train_transform))
#         self.val_loader = dataset.get_dataset_loader(self.args, self.val_transform, train=False)
#         self.logger('==> Val Data Transforms: {}'.format(self.val_transform))
#         self.logger('==> Dataset: {}, image size: {}, classes: {}'.format(
#             self.args.dataset, self.args.im_size, self.args.dataset_classes))
#         return
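A minimal wiring sketch, assuming the extension package re-exports this module as ext.trainer (as the ImageNet/CIFAR scripts suggest) and that a logger has already been configured via extension.logger:

import argparse
import extension as ext

parser = argparse.ArgumentParser(description='demo training script')
ext.trainer.add_arguments(parser)   # adds -n/--epochs, -o/--output, --seed, ...
cfg = parser.parse_args(['-n', '10', '--seed', '42'])
ext.trainer.setting(cfg)            # enables cudnn.benchmark, seeds python/torch RNGs, logs the config
print(cfg.epochs, cfg.seed)         # -> 10 42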
--------------------------------------------------------------------------------
/extension/utils.py:
--------------------------------------------------------------------------------
import argparse
import os
import torch.nn as nn
import torch
import torch.nn.functional as F


class Shortcut(nn.Module):
    def __init__(self, block: nn.Module, shortcut=None):
        super(Shortcut, self).__init__()
        self.block = block
        self.shortcut = shortcut
        self.weight = nn.Parameter(torch.ones(1))  # self.weight.data.fill_(0.1)

    def forward(self, x):
        if self.shortcut is not None:
            return self.block(x) + self.shortcut(x)
        y = self.block(x)
        if x.size()[2:4] != y.size()[2:4]:
            x = F.adaptive_avg_pool2d(x, y.size()[2:4])
        # x = x * self.weight
        # match channel counts: truncate the identity path if it is wider,
        # otherwise add it only onto the first x.size(1) output channels
        if x.size(1) >= y.size(1):
            y += x[:, :y.size(1), :, :]
        else:
            y[:, :x.size(1), :, :] += x
        return y


class sign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *inputs):
        weight_f, ctx.slope, ctx.back_way = inputs
        weight_b = weight_f.sign()
        ctx.save_for_backward(weight_f)
        return weight_b

    @staticmethod
    def backward(ctx, *grads):
        grad, = grads
        grad = grad.clone()  # do not modify grad_output in place
        weight_f, = ctx.saved_tensors
        if ctx.back_way == 0:
            # based on HardTanh: pass the gradient through inside (-1, 1),
            # scale it by `slope` outside
            grad[weight_f.abs() >= 1.] *= ctx.slope
        elif ctx.back_way == 1:
            # based on a polynomial surrogate: scale by `slope` outside (-1, 1)
            # and by 2 - 2|w| inside. Chained comparisons such as
            # `0. <= weight_f < 1.` are invalid on tensors, so build the masks
            # explicitly.
            grad[weight_f.abs() >= 1.] *= ctx.slope
            pos = (0. <= weight_f) & (weight_f < 1.)
            neg = (-1. < weight_f) & (weight_f < 0.)
            grad[pos] *= 2 - 2 * weight_f[pos]
            grad[neg] *= 2 + 2 * weight_f[neg]
        # one gradient per forward input (weight_f, slope, back_way)
        return grad, None, None
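# --- Illustrative sketch, not part of the original module ----------------------
# How `sign` is meant to be used: sign.apply(weights, slope, back_way) binarizes
# the latent weights in the forward pass and routes gradients through the chosen
# surrogate in the backward pass.
def _demo_sign():
    w = torch.randn(8, requires_grad=True)
    w_b = sign.apply(w, 0.0, 1)  # entries of w_b are in {-1, 0, 1}
    loss = w_b.sum()
    loss.backward()
    # With back_way=1 and slope=0: w.grad is 0 where |w| >= 1,
    # and follows 2 - 2*|w| inside (-1, 1).
    print(w.data, w.grad)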

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Scale(nn.Module):
    def __init__(self, init_value=0.1):
        super(Scale, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(1))
        self.init_value = init_value
        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.fill_(self.init_value)

    def forward(self, input: torch.Tensor):
        return input * self.weight

    def extra_repr(self):
        return 'init_value={:.5g}'.format(self.init_value)


def str2num(s: str):
    s = s.strip()  # strip() returns a new string, so the result must be kept
    try:
        value = int(s)
    except ValueError:
        try:
            value = float(s)
        except ValueError:
            if s == 'True':
                value = True
            elif s == 'False':
                value = False
            elif s == 'None':
                value = None
            else:
                value = s
    return value


def str2bool(v):
    if not isinstance(v, str):
        return bool(v)
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Unsupported value encountered.')


def str2dict(s) -> dict:
    if s is None:
        return {}
    if not isinstance(s, str):
        return s
    s = s.split(',')
    d = {}
    for ss in s:
        if ss == '':
            continue
        ss = ss.split('=')
        assert len(ss) == 2
        key = ss[0].strip()
        value = str2num(ss[1])
        d[key] = value
    return d


def str2list(s: str) -> list:
    if not isinstance(s, str):
        return list(s)
    items = []
    s = s.split(',')
    for ss in s:
        if ss == '':
            continue
        items.append(str2num(ss))
    return items


def str2tuple(s: str) -> tuple:
    return tuple(str2list(s))


def extend_list(l: list, size: int):
    while len(l) < size:
        l.append(l[-1])
    return l[:size]


def path(p: str):
    return os.path.expanduser(p)
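The str2* helpers are written to serve as argparse type= converters (str2bool raises ArgumentTypeError on bad input). A few worked conversions, assuming the repo root is importable:

from extension.utils import str2bool, str2dict, str2list

print(str2bool('yes'))                      # True
print(str2dict('lr=0.1,T=5,affine=True'))   # {'lr': 0.1, 'T': 5, 'affine': True}
print(str2list('256,200'))                  # [256, 200]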
--------------------------------------------------------------------------------
/extension/visualization.py:
--------------------------------------------------------------------------------
import argparse

import torch
import numpy as np


def add_arguments(parser: argparse.ArgumentParser):
    group = parser.add_argument_group('Visualization Options')
    group.add_argument('--vis', action='store_true', help='Visualize the training process?')
    group.add_argument('--vis-port', default=6006, type=int, help='The visualization port (default 6006)')
    # group.add_argument('--vis-env', default=None, help='The env name visdom uses. Default: ')
    return


class Visualization:
    def __init__(self, cfg: argparse.Namespace):
        self.cfg = cfg
        self.viz = None
        self.env = None
        self.names = {}    # maps a curve name to the window (label) it belongs to
        self.values = {}
        self.windows = {}  # maps a window label to the list of curve names it shows
        self.cnt = {}
        self.num = {}

    def set(self, env_name, names: dict):
        if not self.cfg.vis:
            return
        try:
            import visdom
            self.env = env_name
            self.viz = visdom.Visdom(env=env_name, port=self.cfg.vis_port)
        except ImportError:
            print('visdom is not installed!')
            self.cfg.vis = False
            return
        self.names = names
        self.values = {}
        self.windows = {}
        self.cnt = {}
        self.num = {}
        for name, label in self.names.items():
            self.values[name] = 0
            self.cnt[label] = 0
            self.num[label] = 0
            self.windows.setdefault(label, [])
            self.windows[label].append(name)

        for label, names in self.windows.items():
            opts = dict(title=label, legend=names, showlegend=True,  # webgl=False,
                        # layoutopts={'plotly': {'legend': {'x': 0, 'y': 0}}},
                        # marginleft=0, marginright=0, margintop=10, marginbottom=0,
                        )

            init = np.ones((1, len(names)))  # a dummy first point just to create the window
            self.viz.line(init, init, win=label, opts=opts)

    def add_value(self, name, value):
        if not self.cfg.vis:
            return
        if isinstance(value, torch.Tensor):
            assert value.numel() == 1
            value = value.item()
        self.values[name] = value
        label = self.names[name]
        self.cnt[label] += 1
        # only redraw a window once every curve in it has received a new value
        if self.cnt[label] == len(self.windows[label]):
            y = np.array([[self.values[name] for name in self.windows[label]]])
            x = np.ones_like(y) * self.num[label]
            opts = dict(title=label, legend=self.windows[label], showlegend=True,  # webgl=False,
                        layoutopts={'plotly': {'legend': {'x': 0.05, 'y': 1}}},
                        # marginleft=0, marginright=0, margintop=10, marginbottom=0,
                        )
            self.viz.line(y, x, update='append' if self.num[label] else 'new', win=label, opts=opts)
            self.cnt[label] = 0
            self.num[label] += 1

    def clear(self, label):
        if not self.cfg.vis:
            return
        self.num[label] = 0

    def add_images(self, images, title='images', win='images', nrow=8):
        if self.cfg.vis:
            self.viz.images(images, win=win, nrow=nrow, opts={'title': title})

    def __del__(self):
        if self.viz:
            self.viz.save([self.env])


def setting(cfg: argparse.Namespace, env_name: str, names: dict):
    vis = Visualization(cfg)
    vis.set(env_name, names)
    return vis
--------------------------------------------------------------------------------
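Finally, a usage sketch for the visdom wrapper above (the environment and curve names are invented for illustration, and a visdom server must be listening on --vis-port):

import argparse
import extension.visualization as visualization

parser = argparse.ArgumentParser()
visualization.add_arguments(parser)
cfg = parser.parse_args(['--vis'])

# Two curves share the window labelled 'loss'; it redraws once both have new values.
vis = visualization.setting(cfg, 'demo_env', {'train_loss': 'loss', 'val_loss': 'loss'})
for epoch in range(3):
    vis.add_value('train_loss', 1.0 / (epoch + 1))
    vis.add_value('val_loss', 1.2 / (epoch + 1))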