├── .gitignore
├── LICENSE
├── README.md
├── config.yaml
├── examples
│   └── lsq
│       └── resnet18_a2w3_imagenet.yaml
├── logging.conf
├── main.py
├── model
│   ├── __init__.py
│   ├── model.py
│   ├── resnet.py
│   └── resnet_cifar.py
├── process.py
├── quan
│   ├── __init__.py
│   ├── func.py
│   ├── quantizer
│   │   ├── __init__.py
│   │   ├── lsq.py
│   │   └── quantizer.py
│   └── utils.py
└── util
    ├── __init__.py
    ├── checkpoint.py
    ├── config.py
    ├── data_loader.py
    ├── lr_scheduler.py
    └── monitor.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Editors & IDEs 3 | 4 | # gedit: backup files 5 | *~ 6 | 7 | # VIM: swap files 8 | [._]*.s[a-w][a-z] 9 | [._]s[a-w][a-z] 10 | 11 | # VIM: session 12 | Session.vim 13 | 14 | # VIM: netrw.vim: Network oriented reading, writing, browsing (eg: ftp scp) 15 | .netrwhist 16 | 17 | # Intellij IDEA 18 | .idea/ 19 | 20 | # ============================================================================== 21 | # Language & EDA Tools 22 | 23 | # Python 24 | *.py[cod] 25 | __pycache__/ 26 | 27 | # ============================================================================== 28 | # Operating Systems 29 | 30 | # Folder view configuration files 31 | *.DS_store 32 | *.DS_store? 33 | Desktop.ini 34 | 35 | # Thumbnail cache files 36 | ._* 37 | Thumbs.db 38 | 39 | # Files that might appear on external disks 40 | .Spotlight-V100 41 | .Trashes 42 | 43 | # NFS Temp Files 44 | .nfs* 45 | 46 | # ============================================================================== 47 | # This Project 48 | 49 | # Working directory 50 | out/ 51 | 52 | # SVN 53 | .svn/ 54 | 55 | *.log 56 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Haozhe Zhu (zhutmost@outlook.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LSQ-Net: Learned Step Size Quantization 2 | 3 | ## Introduction 4 | 5 | This is an unofficial implementation of LSQ-Net, a deep neural network quantization framework. 6 | LSQ-Net was proposed by Steven K. Esser et al. from IBM. The paper can be found at [arXiv:1902.08153](https://arxiv.org/abs/1902.08153).
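In a nutshell, LSQ makes the quantizer step size a learnable parameter. A minimal sketch of the quantize-dequantize forward pass (the full version, with gradient scaling and the straight-through estimator, lives in `quan/quantizer/lsq.py`):

```python
import torch

def lsq_forward(x, s, Qn, Qp):
    # v_bar = clamp(x / s, Qn, Qp); v_hat = round(v_bar) * s
    return torch.clamp(x / s, Qn, Qp).round() * s
```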
7 | 8 | There are some minor differences between my implementation and the original paper, which are described in detail below. 9 | 10 | If this repository is helpful to you, please star it. 11 | 12 | ## Results and Models 13 | 14 | Here are some experiment results. 15 | We will release more quantized models with different configurations soon. 16 | 17 | All these models can be downloaded from [Dropbox](https://www.dropbox.com/sh/un1k74qael1k6mx/AADroPMhvCrd1szG6HUYO_N3a?dl=0). 18 | 19 | | Network | Config. File | Model | Bitwidth (W/A) | Top-1 Acc. (%) | Top-5 Acc. (%) | 20 | |:---------:|:--------:|:-------------:|:---------------:|:--------------:|:--------------:| 21 | | ResNet-18 | [link](examples/lsq/resnet18_a2w3_imagenet.yaml) | [link](https://www.dropbox.com/sh/a5spn8boovfhjrj/AAD-Ureq7DpMKOujPdH4l0jVa?dl=0) | 3/2 | 66.9 | 87.2 | 22 | 23 | ## User Guide 24 | 25 | ### Install Dependencies 26 | 27 | First, install the library dependencies within an Anaconda environment. 28 | 29 | ```bash 30 | # Create an environment with Python 3.8 31 | conda create -n lsq python=3.8 32 | # PyTorch GPU version >= 1.5 33 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch 34 | # Tensorboard visualization tool 35 | conda install tensorboard 36 | # Miscellaneous 37 | conda install scikit-learn pyyaml munch 38 | ``` 39 | 40 | ### Run Scripts with Your Configurations 41 | 42 | This program uses YAML files as its input. A template, which is also the default configuration, is provided as `config.yaml`. 43 | 44 | If you want to change the behaviour of this program, copy `config.yaml` somewhere else, modify your copy, and then run `main.py` with the modified configuration file. 45 | 46 | ``` 47 | python main.py /path/to/your/config/file.yaml 48 | ``` 49 | 50 | The options modified in your YAML file will overwrite the default settings. For details, please read the comments in `config.yaml`. After every epoch, the program will automatically store the best model parameters as a checkpoint. You can set the option `resume.path: /path/to/checkpoint.pth.tar` in the YAML file to resume the training process, or to evaluate the accuracy of the quantized model. 51 | 52 | You can find some example configuration files in the [examples](examples) folder. 53 | 54 | ## Implementation Differences From the Original Paper 55 | 56 | The LSQ-Net paper has two main versions, [v1](https://arxiv.org/pdf/1902.08153v1.pdf) and [v2](https://arxiv.org/pdf/1902.08153v2.pdf). 57 | To improve accuracy, the authors expanded the quantization space in the v2 version. 58 | Recently they released a new version, [v3](https://arxiv.org/pdf/1902.08153v3.pdf), which fixed some typos in the v2 version. 59 | 60 | My implementation generally follows the v2 version, except for the following points. 61 | 62 | ### Initial Values of the Quantization Step Size 63 | 64 | The authors use `Tensor(v.abs().mean() * 2 / sqrt(Qp))` as the initial value of the step size in both weight and activation quantization layers, where `Qp` is the upper bound of the quantization space, and `v` is the initial weight values or the first batch of activations. 65 | 66 | In my implementation, the step sizes in weight quantization layers are initialized in the same way, but in activation quantization layers, the step sizes are initialized as `Tensor(1.0)`. 67 |
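For reference, this initialization maps onto the code as follows — a minimal sketch of `init_from` in `quan/quantizer/lsq.py` (per-layer case):

```python
import torch

def init_step_size(v: torch.Tensor, Qp: int) -> torch.Tensor:
    # s = 2 * mean(|v|) / sqrt(Qp), as described above
    return v.detach().abs().mean() * 2 / (Qp ** 0.5)
```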
68 | ### Supported Models 69 | 70 | Currently, only ResNet is supported. 71 | For the ImageNet dataset, the ResNet-18/34/50/101/152 models are copied from the torchvision model zoo. 72 | For the CIFAR10 dataset, the models are adapted from [Yerlan Idelbayev's contribution](https://github.com/akamaster/pytorch_resnet_cifar10), including ResNet-20/32/44/56/110/1202. 73 | 74 | Thanks to the non-invasive nature of the framework, it is easy to add new architectures besides ResNet (see the sketch at the end of this README). 75 | All you need to do is paste your model code into the `model` folder, and then add a corresponding entry in `model/model.py`. 76 | The quantization framework will then automatically replace the layer types listed in `quan/func.py` with their quantized counterparts. 77 | 78 | ## Contributing Guide 79 | 80 | I am not a professional algorithm researcher, and I only have very limited GPU resources. Thus, I may not spend much time further optimizing the accuracy. 81 | 82 | However, if you find any bugs in my code or have any ideas to improve the quantization results, please feel free to open an issue. I will be glad to join the discussion. 83 |
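As mentioned in the Supported Models section, adding a new architecture takes two small steps. A hypothetical sketch (the `MyNet` name and file are illustrative, not part of this repository):

```python
# Step 1: drop the model code into model/, e.g. model/mynet.py
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # nn.Conv2d / nn.Linear are the layer types quan/func.py knows how to
        # swap for their quantized wrappers (QuanConv2d / QuanLinear).
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.pool(self.conv(x)).flatten(1)
        return self.fc(x)

# Step 2: register it in create_model() in model/model.py, e.g.
#     elif args.arch == 'mynet':
#         model = MyNet(num_classes=args.dataloader.num_classes)
```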
-------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | #=============================================================================== 2 | # Default Configuration for LSQ-Net 3 | #=============================================================================== 4 | # Please do NOT modify this file directly. If you want to modify configurations, 5 | # please: 6 | # 1. Create a new YAML file and copy the options you need from below into it. 7 | # 2. Modify these options in your YAML file. 8 | # 3. Run main.py with your configuration file on the command line, like this: 9 | # $ python main.py path/to/your/config/file 10 | # The options modified in your configuration file will overwrite those in this 11 | # file. 12 | #============================ Environment ====================================== 13 | 14 | # Experiment name 15 | name: MyProject 16 | 17 | # Name of output directory. Checkpoints and logs will be saved at `pwd`/output_dir 18 | output_dir: out 19 | 20 | # Device to be used 21 | device: 22 | # Use CPU or GPU (choices: cpu, cuda) 23 | type: cuda 24 | # GPU device IDs to be used. Only valid when device.type is 'cuda' 25 | gpu: [0, 1] 26 | 27 | # Dataset loader 28 | dataloader: 29 | # Dataset to train/validate (choices: imagenet, cifar10) 30 | dataset: imagenet 31 | # Number of categories in the specified dataset (choices: 1000, 10) 32 | num_classes: 1000 33 | # Path to dataset directory 34 | path: /localhome/fair/Dataset/imagenet 35 | # Size of mini-batch 36 | batch_size: 64 37 | # Number of data loading workers 38 | workers: 32 39 | # Seeds random generators in a deterministic way (i.e., sets all the seeds to 0). 40 | # Please keep it true when resuming the experiment from a checkpoint 41 | deterministic: true 42 | # Load the model without DataParallel wrapping it 43 | serialized: false 44 | # Portion of training dataset to set aside for validation (range: [0, 1)) 45 | val_split: 0.05 46 | 47 | resume: 48 | # Path to a checkpoint to be loaded. Leave blank to skip 49 | path: 50 | # Resume model parameters only 51 | lean: false 52 | 53 | log: 54 | # Number of best scores to track and report 55 | num_best_scores: 3 56 | # Print frequency 57 | print_freq: 20 58 | 59 | #============================ Model ============================================ 60 | 61 | # Supported model architecture 62 | # choices: 63 | # ImageNet: 64 | # resnet18, resnet34, resnet50, resnet101, resnet152 65 | # CIFAR10: 66 | # resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202 67 | arch: resnet18 68 | 69 | # Use pre-trained model 70 | pre_trained: true 71 | 72 | #============================ Quantization ===================================== 73 | 74 | quan: 75 | act: # (default for all layers) 76 | # Quantizer type (choices: lsq) 77 | mode: lsq 78 | # Bit width of quantized activation 79 | bit: 3 80 | # Each output channel uses its own scaling factor 81 | per_channel: false 82 | # Whether to use symmetric quantization 83 | symmetric: false 84 | # Quantize all the numbers to non-negative 85 | all_positive: true 86 | weight: # (default for all layers) 87 | # Quantizer type (choices: lsq) 88 | mode: lsq 89 | # Bit width of quantized weight 90 | bit: 3 91 | # Each output channel uses its own scaling factor 92 | per_channel: true 93 | # Whether to use symmetric quantization 94 | symmetric: false 95 | # Whether to quantize all the numbers to non-negative 96 | all_positive: false 97 | excepts: 98 | # Specify quantized bit width for some layers, like this: 99 | conv1: 100 | act: 101 | all_positive: false 102 | weight: 103 | bit: 104 | fc: 105 | act: 106 | bit: 107 | weight: 108 | bit: 109 | 110 | #============================ Training / Evaluation ============================ 111 | 112 | # Evaluate the model without training 113 | # If this field is true, all the following options will be ignored 114 | eval: false 115 | 116 | epochs: 90 117 | 118 | optimizer: 119 | learning_rate: 0.01 120 | momentum: 0.9 121 | weight_decay: 0.0001 122 | 123 | # Learning rate scheduler 124 | lr_scheduler: 125 | # Update learning rate per batch or epoch 126 | update_per_batch: true 127 | 128 | # Uncomment one of the following options to select a learning rate schedule 129 | 130 | # Fixed learning rate 131 | mode: fixed 132 | 133 | # Step decay 134 | # mode: step 135 | # step_size: 30 136 | # gamma: 0.1 137 | 138 | # Multi-step decay 139 | # mode: multi_step 140 | # milestones: [30, ] 141 | # gamma: 0.1 142 | 143 | # Exponential decay 144 | # mode: exp 145 | # gamma: 0.95 146 | 147 | # Cosine annealing 148 | # mode: cos 149 | # lr_min: 0 150 | # cycle: 0.95 151 | 152 | # Cosine annealing with warm restarts 153 | # mode: cos_warm_restarts 154 | # lr_min: 0 155 | # cycle: 5 156 | # cycle_scale: 2 157 | # amp_scale: 0.5 158 |
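For example, a minimal user configuration that overrides only a few of these defaults might look like this (an illustrative file, not shipped with the repo; every omitted option keeps its default value from `config.yaml`):

```yaml
name: MyExperiment
dataloader:
  dataset: cifar10
  num_classes: 10
  path: /path/to/cifar10
quan:
  act:
    bit: 4
  weight:
    bit: 4
```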
-------------------------------------------------------------------------------- /examples/lsq/resnet18_a2w3_imagenet.yaml: -------------------------------------------------------------------------------- 1 | # Experiment name 2 | name: ResNet18_ImageNet_a2w3 3 | 4 | # Dataset loader 5 | dataloader: 6 | # Dataset to train/validate (choices: imagenet, cifar10) 7 | dataset: imagenet 8 | # Number of categories in the specified dataset (choices: 1000, 10) 9 | num_classes: 1000 10 | # Path to dataset directory 11 | path: /localhome/fair/Dataset/imagenet 12 | # Size of mini-batch 13 | batch_size: 256 14 | # Portion of training dataset to set aside for validation (range: [0, 1)) 15 | val_split: 0. 16 | 17 | resume: 18 | # Path to a checkpoint to be loaded. Leave blank to skip 19 | path: 20 | # Resume model parameters only 21 | lean: false 22 | 23 | #============================ Model ============================================ 24 | 25 | # Supported model architecture 26 | # choices: 27 | # ImageNet: 28 | # resnet18, resnet34, resnet50, resnet101, resnet152 29 | # CIFAR10: 30 | # resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202 31 | arch: resnet18 32 | 33 | # Use pre-trained model 34 | pre_trained: true 35 | 36 | #============================ Quantization ===================================== 37 | 38 | quan: 39 | act: # (default for all layers) 40 | # Quantizer type (choices: lsq) 41 | mode: lsq 42 | # Bit width of quantized activation 43 | bit: 2 44 | # Each output channel uses its own scaling factor 45 | per_channel: false 46 | # Whether to use symmetric quantization 47 | symmetric: false 48 | # Quantize all the numbers to non-negative 49 | all_positive: true 50 | weight: # (default for all layers) 51 | # Quantizer type (choices: lsq) 52 | mode: lsq 53 | # Bit width of quantized weight 54 | bit: 3 55 | # Each output channel uses its own scaling factor 56 | per_channel: false 57 | # Whether to use symmetric quantization 58 | symmetric: true 59 | # Whether to quantize all the numbers to non-negative 60 | all_positive: false 61 | excepts: 62 | # Specify quantized bit width for some layers, like this: 63 | conv1: 64 | act: 65 | bit: 66 | all_positive: false 67 | weight: 68 | bit: 69 | fc: 70 | act: 71 | bit: 72 | weight: 73 | bit: 74 | 75 | #============================ Training / Evaluation ============================ 76 | 77 | # Evaluate the model without training 78 | # If this field is true, all the following options will be ignored 79 | eval: false 80 | 81 | epochs: 120 82 | 83 | optimizer: 84 | learning_rate: 0.01 85 | momentum: 0.9 86 | weight_decay: 0.0001 87 | 88 | # Learning rate scheduler 89 | lr_scheduler: 90 | # Update learning rate per batch or epoch 91 | update_per_batch: true 92 | 93 | # Uncomment one of the following options to select a learning rate schedule 94 | 95 | # Fixed learning rate 96 | # mode: fixed 97 | 98 | # Step decay 99 | # mode: step 100 | # step_size: 30 101 | # gamma: 0.1 102 | 103 | # Multi-step decay 104 | mode: multi_step 105 | milestones: [30, 60, 90] 106 | gamma: 0.1 107 | 108 | # Exponential decay 109 | # mode: exp 110 | # gamma: 0.95 111 | 112 | # Cosine annealing 113 | # mode: cos 114 | # lr_min: 0 115 | # cycle: 0.95 116 | 117 | # Cosine annealing with warm restarts 118 | # mode: cos_warm_restarts 119 | # lr_min: 0 120 | # cycle: 5 121 | # cycle_scale: 2 122 | # amp_scale: 0.5 123 | -------------------------------------------------------------------------------- /logging.conf: -------------------------------------------------------------------------------- 1 | [formatters] 2 | keys: simple, time_simple 3 | 4 | [handlers] 5 | keys: console, file 6 | 7 | [loggers] 8 | keys: root 9 | 10 | [formatter_simple] 11 | format: %(levelname)s - %(message)s 12 | 13 | [formatter_time_simple] 14 | format: %(asctime)s - %(levelname)s - %(message)s 15 | 16 | [handler_console] 17 | class: StreamHandler 18 | propagate: 0 19 | args: [] 20 | formatter: simple 21 | 22 | [handler_file] 23 | class: FileHandler 24 | mode: 'w' 25 | args: ('%(logfilename)s', 'w') 26 | formatter: time_simple 27 | 28 | [logger_root] 29 | level: INFO 30 | propagate: 1 31 | handlers: console, file 32 |
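A note on the `%(logfilename)s` placeholder above: it is substituted at runtime when `util/config.py` loads this file. Roughly (a minimal sketch of what `init_logger` does; the log path shown is illustrative):

```python
import logging.config

# init_logger() in util/config.py passes the actual log path as a default:
logging.config.fileConfig('logging.conf', defaults={'logfilename': 'out/run.log'})
logging.getLogger().info('Logging initialized')
```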
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | import torch as t 5 | import yaml 6 | 7 | import process 8 | import quan 9 | import util 10 | from model import create_model 11 | 12 | 13 | def main(): 14 | script_dir = Path.cwd() 15 | args = util.get_config(default_file=script_dir / 'config.yaml') 16 | 17 | output_dir = script_dir / args.output_dir 18 | output_dir.mkdir(exist_ok=True) 19 | 20 | log_dir = util.init_logger(args.name, output_dir, script_dir / 'logging.conf') 21 | logger = logging.getLogger() 22 | 23 | with open(log_dir / "args.yaml", "w") as yaml_file: # dump experiment config 24 | yaml.safe_dump(args, yaml_file) 25 | 26 | pymonitor = util.ProgressMonitor(logger) 27 | tbmonitor = util.TensorBoardMonitor(logger, log_dir) 28 | monitors = [pymonitor, tbmonitor] 29 | 30 | if args.device.type == 'cpu' or not t.cuda.is_available() or args.device.gpu == []: 31 | args.device.gpu = [] 32 | else: 33 | available_gpu = t.cuda.device_count() 34 | for dev_id in args.device.gpu: 35 | if dev_id >= available_gpu: 36 | logger.error('GPU device ID {0} requested, but only {1} devices available' .format(dev_id, available_gpu)) 37 | exit(1) 38 | 39 | # Set the default device to the first GPU on the list 40 | t.cuda.set_device(args.device.gpu[0]) 41 | # Enable the cuDNN built-in auto-tuner to accelerate training; note that it 42 | # may introduce small run-to-run fluctuations. 43 | t.backends.cudnn.benchmark = True 44 | t.backends.cudnn.deterministic = False 45 | 46 | # Initialize data loader 47 | train_loader, val_loader, test_loader = util.load_data(args.dataloader) 48 | logger.info('Dataset `%s` size:' % args.dataloader.dataset + 49 | '\n Training Set = %d (%d)' % (len(train_loader.sampler), len(train_loader)) + 50 | '\n Validation Set = %d (%d)' % (len(val_loader.sampler), len(val_loader)) + 51 | '\n Test Set = %d (%d)' % (len(test_loader.sampler), len(test_loader))) 52 | 53 | # Create the model 54 | model = create_model(args) 55 | 56 | modules_to_replace = quan.find_modules_to_quantize(model, args.quan) 57 | model = quan.replace_module_by_names(model, modules_to_replace) 58 | tbmonitor.writer.add_graph(model, input_to_model=train_loader.dataset[0][0].unsqueeze(0)) 59 | logger.info('Inserted quantizers into the original model') 60 | 61 | if args.device.gpu and not args.dataloader.serialized: 62 | model = t.nn.DataParallel(model, device_ids=args.device.gpu) 63 | 64 | model.to(args.device.type) 65 | 66 | start_epoch = 0 67 | if args.resume.path: 68 | model, start_epoch, _ = util.load_checkpoint( 69 | model, args.resume.path, args.device.type, lean=args.resume.lean) 70 | 71 | # Define loss function (criterion) and optimizer 72 | criterion = t.nn.CrossEntropyLoss().to(args.device.type) 73 | 74 | # optimizer = t.optim.Adam(model.parameters(), lr=args.optimizer.learning_rate) 75 | optimizer = t.optim.SGD(model.parameters(), 76 | lr=args.optimizer.learning_rate, 77 | momentum=args.optimizer.momentum, 78 | weight_decay=args.optimizer.weight_decay) 79 | lr_scheduler = util.lr_scheduler(optimizer, 80 | batch_size=train_loader.batch_size, 81 | num_samples=len(train_loader.sampler), 82 | **args.lr_scheduler) 83 | logger.info(('Optimizer: %s' % optimizer).replace('\n', '\n' + ' ' * 11)) 84 | logger.info('LR scheduler: %s\n' % lr_scheduler) 85 | 86 | perf_scoreboard = process.PerformanceScoreboard(args.log.num_best_scores) 87 | 88 | if args.eval: 89 | process.validate(test_loader, model, criterion, -1, monitors, args)
90 | else: # training 91 | if args.resume.path or args.pre_trained: 92 | logger.info('>>>>>>>> Epoch -1 (pre-trained model evaluation)') 93 | top1, top5, _ = process.validate(val_loader, model, criterion, 94 | start_epoch - 1, monitors, args) 95 | perf_scoreboard.update(top1, top5, start_epoch - 1) 96 | for epoch in range(start_epoch, args.epochs): 97 | logger.info('>>>>>>>> Epoch %3d' % epoch) 98 | t_top1, t_top5, t_loss = process.train(train_loader, model, criterion, optimizer, 99 | lr_scheduler, epoch, monitors, args) 100 | v_top1, v_top5, v_loss = process.validate(val_loader, model, criterion, epoch, monitors, args) 101 | 102 | tbmonitor.writer.add_scalars('Train_vs_Validation/Loss', {'train': t_loss, 'val': v_loss}, epoch) 103 | tbmonitor.writer.add_scalars('Train_vs_Validation/Top1', {'train': t_top1, 'val': v_top1}, epoch) 104 | tbmonitor.writer.add_scalars('Train_vs_Validation/Top5', {'train': t_top5, 'val': v_top5}, epoch) 105 | 106 | perf_scoreboard.update(v_top1, v_top5, epoch) 107 | is_best = perf_scoreboard.is_best(epoch) 108 | util.save_checkpoint(epoch, args.arch, model, {'top1': v_top1, 'top5': v_top5}, is_best, args.name, log_dir) 109 | 110 | logger.info('>>>>>>>> Epoch -1 (final model evaluation)') 111 | process.validate(test_loader, model, criterion, -1, monitors, args) 112 | 113 | tbmonitor.writer.close() # close the TensorBoard 114 | logger.info('Program completed successfully ... exiting ...') 115 | logger.info('If you have any questions or suggestions, please visit: github.com/zhutmost/lsq-net') 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import create_model 2 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .resnet import * 4 | from .resnet_cifar import * 5 | 6 | 7 | def create_model(args): 8 | logger = logging.getLogger() 9 | 10 | model = None 11 | if args.dataloader.dataset == 'imagenet': 12 | if args.arch == 'resnet18': 13 | model = resnet18(pretrained=args.pre_trained) 14 | elif args.arch == 'resnet34': 15 | model = resnet34(pretrained=args.pre_trained) 16 | elif args.arch == 'resnet50': 17 | model = resnet50(pretrained=args.pre_trained) 18 | elif args.arch == 'resnet101': 19 | model = resnet101(pretrained=args.pre_trained) 20 | elif args.arch == 'resnet152': 21 | model = resnet152(pretrained=args.pre_trained) 22 | elif args.dataloader.dataset == 'cifar10': 23 | if args.arch == 'resnet20': 24 | model = resnet20(pretrained=args.pre_trained) 25 | elif args.arch == 'resnet32': 26 | model = resnet32(pretrained=args.pre_trained) 27 | elif args.arch == 'resnet44': 28 | model = resnet44(pretrained=args.pre_trained) 29 | elif args.arch == 'resnet56': 30 | model = resnet56(pretrained=args.pre_trained) 31 | elif args.arch == 'resnet110': 32 | model = resnet110(pretrained=args.pre_trained) 33 | elif args.arch == 'resnet1202': 34 | model = resnet1202(pretrained=args.pre_trained) 35 | 36 | if model is None: 37 | logger.error('Model architecture `%s` for `%s` dataset is not supported' % (args.arch, args.dataloader.dataset)) 38 | exit(-1) 39 | 40 | msg = 'Created `%s` model for `%s` dataset' % (args.arch, args.dataloader.dataset) 41 | msg += '\n Use pre-trained model = %s' % args.pre_trained 42 |
logger.info(msg) 43 | 44 | return model 45 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 7 | 'wide_resnet50_2', 'wide_resnet101_2'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=dilation, groups=groups, bias=False, dilation=dilation) 26 | 27 | 28 | def conv1x1(in_planes, out_planes, stride=1): 29 | """1x1 convolution""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 37 | base_width=64, dilation=1, norm_layer=None): 38 | super(BasicBlock, self).__init__() 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | if groups != 1 or base_width != 64: 42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 43 | if dilation > 1: 44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = norm_layer(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = norm_layer(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | identity = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | identity = self.downsample(x) 66 | 67 | out += identity 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 75 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 76 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 77 | # This variant is also known as ResNet V1.5 and improves accuracy according to 78 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
79 | 80 | expansion = 4 81 | 82 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 83 | base_width=64, dilation=1, norm_layer=None): 84 | super(Bottleneck, self).__init__() 85 | if norm_layer is None: 86 | norm_layer = nn.BatchNorm2d 87 | width = int(planes * (base_width / 64.)) * groups 88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 89 | self.conv1 = conv1x1(inplanes, width) 90 | self.bn1 = norm_layer(width) 91 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 92 | self.bn2 = norm_layer(width) 93 | self.conv3 = conv1x1(width, planes * self.expansion) 94 | self.bn3 = norm_layer(planes * self.expansion) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.downsample = downsample 97 | self.stride = stride 98 | 99 | def forward(self, x): 100 | identity = x 101 | 102 | out = self.conv1(x) 103 | out = self.bn1(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv2(out) 107 | out = self.bn2(out) 108 | out = self.relu(out) 109 | 110 | out = self.conv3(out) 111 | out = self.bn3(out) 112 | 113 | if self.downsample is not None: 114 | identity = self.downsample(x) 115 | 116 | out += identity 117 | out = self.relu(out) 118 | 119 | return out 120 | 121 | 122 | class ResNet(nn.Module): 123 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 124 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 125 | norm_layer=None): 126 | super(ResNet, self).__init__() 127 | if norm_layer is None: 128 | norm_layer = nn.BatchNorm2d 129 | self._norm_layer = norm_layer 130 | 131 | self.inplanes = 64 132 | self.dilation = 1 133 | if replace_stride_with_dilation is None: 134 | # each element in the tuple indicates if we should replace 135 | # the 2x2 stride with a dilated convolution instead 136 | replace_stride_with_dilation = [False, False, False] 137 | if len(replace_stride_with_dilation) != 3: 138 | raise ValueError("replace_stride_with_dilation should be None " 139 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 140 | self.groups = groups 141 | self.base_width = width_per_group 142 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 143 | bias=False) 144 | self.bn1 = norm_layer(self.inplanes) 145 | self.relu = nn.ReLU(inplace=True) 146 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 147 | self.layer1 = self._make_layer(block, 64, layers[0]) 148 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 149 | dilate=replace_stride_with_dilation[0]) 150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 151 | dilate=replace_stride_with_dilation[1]) 152 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 153 | dilate=replace_stride_with_dilation[2]) 154 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 155 | self.fc = nn.Linear(512 * block.expansion, num_classes) 156 | 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 160 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 161 | nn.init.constant_(m.weight, 1) 162 | nn.init.constant_(m.bias, 0) 163 | 164 | # Zero-initialize the last BN in each residual branch, 165 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
166 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 167 | if zero_init_residual: 168 | for m in self.modules(): 169 | if isinstance(m, Bottleneck): 170 | nn.init.constant_(m.bn3.weight, 0) 171 | elif isinstance(m, BasicBlock): 172 | nn.init.constant_(m.bn2.weight, 0) 173 | 174 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 175 | norm_layer = self._norm_layer 176 | downsample = None 177 | previous_dilation = self.dilation 178 | if dilate: 179 | self.dilation *= stride 180 | stride = 1 181 | if stride != 1 or self.inplanes != planes * block.expansion: 182 | downsample = nn.Sequential( 183 | conv1x1(self.inplanes, planes * block.expansion, stride), 184 | norm_layer(planes * block.expansion), 185 | ) 186 | 187 | layers = [] 188 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 189 | self.base_width, previous_dilation, norm_layer)) 190 | self.inplanes = planes * block.expansion 191 | for _ in range(1, blocks): 192 | layers.append(block(self.inplanes, planes, groups=self.groups, 193 | base_width=self.base_width, dilation=self.dilation, 194 | norm_layer=norm_layer)) 195 | 196 | return nn.Sequential(*layers) 197 | 198 | def _forward_impl(self, x): 199 | # See note [TorchScript super()] 200 | x = self.conv1(x) 201 | x = self.bn1(x) 202 | x = self.relu(x) 203 | x = self.maxpool(x) 204 | 205 | x = self.layer1(x) 206 | x = self.layer2(x) 207 | x = self.layer3(x) 208 | x = self.layer4(x) 209 | 210 | x = self.avgpool(x) 211 | x = torch.flatten(x, 1) 212 | x = self.fc(x) 213 | 214 | return x 215 | 216 | def forward(self, x): 217 | return self._forward_impl(x) 218 | 219 | 220 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 221 | model = ResNet(block, layers, **kwargs) 222 | 223 | if pretrained: 224 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 225 | model.load_state_dict(state_dict) 226 | return model 227 | 228 | 229 | def resnet18(pretrained=False, progress=True, **kwargs): 230 | r"""ResNet-18 model from 231 | `"Deep Residual Learning for Image Recognition" `_ 232 | 233 | Args: 234 | pretrained (bool): If True, returns a model pre-trained on ImageNet 235 | progress (bool): If True, displays a progress bar of the download to stderr 236 | """ 237 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 238 | **kwargs) 239 | 240 | 241 | def resnet34(pretrained=False, progress=True, **kwargs): 242 | r"""ResNet-34 model from 243 | `"Deep Residual Learning for Image Recognition" `_ 244 | 245 | Args: 246 | pretrained (bool): If True, returns a model pre-trained on ImageNet 247 | progress (bool): If True, displays a progress bar of the download to stderr 248 | """ 249 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 250 | **kwargs) 251 | 252 | 253 | def resnet50(pretrained=False, progress=True, **kwargs): 254 | r"""ResNet-50 model from 255 | `"Deep Residual Learning for Image Recognition" `_ 256 | 257 | Args: 258 | pretrained (bool): If True, returns a model pre-trained on ImageNet 259 | progress (bool): If True, displays a progress bar of the download to stderr 260 | """ 261 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 262 | **kwargs) 263 | 264 | 265 | def resnet101(pretrained=False, progress=True, **kwargs): 266 | r"""ResNet-101 model from 267 | `"Deep Residual Learning for Image Recognition" `_ 268 | 269 | Args: 270 | pretrained (bool): If True, returns a model pre-trained on ImageNet 271 | 
progress (bool): If True, displays a progress bar of the download to stderr 272 | """ 273 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 274 | **kwargs) 275 | 276 | 277 | def resnet152(pretrained=False, progress=True, **kwargs): 278 | r"""ResNet-152 model from 279 | `"Deep Residual Learning for Image Recognition" `_ 280 | 281 | Args: 282 | pretrained (bool): If True, returns a model pre-trained on ImageNet 283 | progress (bool): If True, displays a progress bar of the download to stderr 284 | """ 285 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 286 | **kwargs) 287 | 288 | 289 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 290 | r"""ResNeXt-50 32x4d model from 291 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 292 | 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained on ImageNet 295 | progress (bool): If True, displays a progress bar of the download to stderr 296 | """ 297 | kwargs['groups'] = 32 298 | kwargs['width_per_group'] = 4 299 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 300 | pretrained, progress, **kwargs) 301 | 302 | 303 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 304 | r"""ResNeXt-101 32x8d model from 305 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 306 | 307 | Args: 308 | pretrained (bool): If True, returns a model pre-trained on ImageNet 309 | progress (bool): If True, displays a progress bar of the download to stderr 310 | """ 311 | kwargs['groups'] = 32 312 | kwargs['width_per_group'] = 8 313 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 314 | pretrained, progress, **kwargs) 315 | 316 | 317 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 318 | r"""Wide ResNet-50-2 model from 319 | `"Wide Residual Networks" `_ 320 | 321 | The model is the same as ResNet except for the bottleneck number of channels 322 | which is twice larger in every block. The number of channels in outer 1x1 323 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 324 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 325 | 326 | Args: 327 | pretrained (bool): If True, returns a model pre-trained on ImageNet 328 | progress (bool): If True, displays a progress bar of the download to stderr 329 | """ 330 | kwargs['width_per_group'] = 64 * 2 331 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 332 | pretrained, progress, **kwargs) 333 | 334 | 335 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 336 | r"""Wide ResNet-101-2 model from 337 | `"Wide Residual Networks" `_ 338 | 339 | The model is the same as ResNet except for the bottleneck number of channels 340 | which is twice larger in every block. The number of channels in outer 1x1 341 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 342 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
343 | 344 | Args: 345 | pretrained (bool): If True, returns a model pre-trained on ImageNet 346 | progress (bool): If True, displays a progress bar of the download to stderr 347 | """ 348 | kwargs['width_per_group'] = 64 * 2 349 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 350 | pretrained, progress, **kwargs) 351 | -------------------------------------------------------------------------------- /model/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | 4 | The implementation and structure of this file are hugely influenced by [2], 5 | which is implemented for ImageNet and doesn't have option A for identity. 6 | Moreover, most of the implementations on the web are copy-paste from 7 | torchvision's resnet and have the wrong number of params. 8 | 9 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following 10 | numbers of layers and parameters: 11 | 12 | name | layers | params 13 | ResNet20 | 20 | 0.27M 14 | ResNet32 | 32 | 0.46M 15 | ResNet44 | 44 | 0.66M 16 | ResNet56 | 56 | 0.85M 17 | ResNet110 | 110 | 1.7M 18 | ResNet1202| 1202 | 19.4M 19 | 20 | which this implementation indeed has. 21 | 22 | Reference: 23 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 25 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 26 | 27 | If you use this implementation in your work, please don't forget to mention the 28 | author, Yerlan Idelbayev. 29 | """ 30 | from collections import OrderedDict 31 | 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | import torch.nn.init as init 35 | from torch.hub import load_state_dict_from_url 36 | 37 | __all__ = ['resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 38 | 39 | model_urls = { 40 | 'resnet20': 'https://raw.githubusercontent.com/akamaster/pytorch_resnet_cifar10/master/pretrained_models/resnet20-12fca82f.th', 41 | 'resnet32': 'https://raw.githubusercontent.com/akamaster/pytorch_resnet_cifar10/master/pretrained_models/resnet32-d509ac18.th', 42 | 'resnet44': 'https://raw.githubusercontent.com/akamaster/pytorch_resnet_cifar10/master/pretrained_models/resnet44-014dd654.th', 43 | 'resnet56': 'https://raw.githubusercontent.com/akamaster/pytorch_resnet_cifar10/master/pretrained_models/resnet56-4bfd9763.th', 44 | 'resnet110': 'https://raw.githubusercontent.com/akamaster/pytorch_resnet_cifar10/master/pretrained_models/resnet110-1d1ed7c2.th', 45 | 'resnet1202': 'https://raw.githubusercontent.com/akamaster/pytorch_resnet_cifar10/master/pretrained_models/resnet1202-f3b1deed.th', 46 | } 47 | 48 | 49 | def _weights_init(m): 50 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 51 | init.kaiming_normal_(m.weight) 52 | 53 | 54 | class LambdaLayer(nn.Module): 55 | def __init__(self, lambd): 56 | super().__init__() 57 | self.lambd = lambd 58 | 59 | def forward(self, x): 60 | return self.lambd(x) 61 | 62 | 63 | class BasicBlock(nn.Module): 64 | expansion = 1 65 | 66 | def __init__(self, in_planes, planes, stride=1, option='A'): 67 | super(BasicBlock, self).__init__() 68 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 69 | self.bn1 = nn.BatchNorm2d(planes) 70 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | 73 | self.shortcut =
nn.Sequential() 74 | if stride != 1 or in_planes != planes: 75 | if option == 'A': 76 | """ 77 | For CIFAR10 ResNet paper uses option A. 78 | """ 79 | self.shortcut = LambdaLayer(lambda x: 80 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 81 | 0)) 82 | elif option == 'B': 83 | self.shortcut = nn.Sequential( 84 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 85 | nn.BatchNorm2d(self.expansion * planes) 86 | ) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.bn2(self.conv2(out)) 91 | out += self.shortcut(x) 92 | out = F.relu(out) 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | def __init__(self, block, num_blocks, num_classes=10): 98 | super(ResNet, self).__init__() 99 | self.in_planes = 16 100 | 101 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 102 | self.bn1 = nn.BatchNorm2d(16) 103 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 104 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 105 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 106 | self.linear = nn.Linear(64, num_classes) 107 | 108 | self.apply(_weights_init) 109 | 110 | def _make_layer(self, block, planes, num_blocks, stride): 111 | strides = [stride] + [1] * (num_blocks - 1) 112 | layers = [] 113 | for stride in strides: 114 | layers.append(block(self.in_planes, planes, stride)) 115 | self.in_planes = planes * block.expansion 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | out = F.relu(self.bn1(self.conv1(x))) 121 | out = self.layer1(out) 122 | out = self.layer2(out) 123 | out = self.layer3(out) 124 | out = F.avg_pool2d(out, out.size()[3]) 125 | out = out.view(out.size(0), -1) 126 | out = self.linear(out) 127 | return out 128 | 129 | 130 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 131 | model = ResNet(block, layers, **kwargs) 132 | 133 | if pretrained: 134 | s = load_state_dict_from_url(model_urls[arch], progress=progress) 135 | state_dict = OrderedDict() 136 | for k, v in s['state_dict'].items(): 137 | if k.startswith('module.'): 138 | state_dict[k[7:]] = v 139 | model.load_state_dict(state_dict) 140 | return model 141 | 142 | 143 | def resnet20(pretrained=False, progress=True): 144 | return _resnet('resnet20', BasicBlock, [3, 3, 3], pretrained, progress) 145 | 146 | 147 | def resnet32(pretrained=False, progress=True): 148 | return _resnet('resnet32', BasicBlock, [5, 5, 5], pretrained, progress) 149 | 150 | 151 | def resnet44(pretrained=False, progress=True): 152 | return _resnet('resnet44', BasicBlock, [7, 7, 7], pretrained, progress) 153 | 154 | 155 | def resnet56(pretrained=False, progress=True): 156 | return _resnet('resnet56', BasicBlock, [9, 9, 9], pretrained, progress) 157 | 158 | 159 | def resnet110(pretrained=False, progress=True): 160 | return _resnet('resnet110', BasicBlock, [18, 18, 18], pretrained, progress) 161 | 162 | 163 | def resnet1202(pretrained=False, progress=True): 164 | return _resnet('resnet1202', BasicBlock, [200, 200, 200], pretrained, progress) 165 | -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import operator 4 | import time 5 | 6 | import torch as t 7 | 8 | from util import AverageMeter 9 | 10 | __all__ = ['train', 'validate', 'PerformanceScoreboard'] 11 | 
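# Note on accuracy() below: given logits of shape (batch_size, num_classes) and
# integer targets, it returns a list of top-k accuracies in percent, e.g.
# accuracy(outputs, targets, topk=(1, 5)) -> [tensor([68.75]), tensor([93.75])]
# (illustrative values for a batch where 68.75% of top-1 predictions are correct).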
12 | logger = logging.getLogger() 13 | 14 | 15 | def accuracy(output, target, topk=(1,)): 16 | """Computes the accuracy over the k top predictions for the specified values of k""" 17 | with t.no_grad(): 18 | maxk = max(topk) 19 | batch_size = target.size(0) 20 | 21 | _, pred = output.topk(maxk, 1, True, True) 22 | pred = pred.t() 23 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 24 | 25 | res = [] 26 | for k in topk: 27 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 28 | res.append(correct_k.mul_(100.0 / batch_size)) 29 | return res 30 | 31 | 32 | def train(train_loader, model, criterion, optimizer, lr_scheduler, epoch, monitors, args): 33 | losses = AverageMeter() 34 | top1 = AverageMeter() 35 | top5 = AverageMeter() 36 | batch_time = AverageMeter() 37 | 38 | total_sample = len(train_loader.sampler) 39 | batch_size = train_loader.batch_size 40 | steps_per_epoch = math.ceil(total_sample / batch_size) 41 | logger.info('Training: %d samples (%d per mini-batch)', total_sample, batch_size) 42 | 43 | model.train() 44 | end_time = time.time() 45 | for batch_idx, (inputs, targets) in enumerate(train_loader): 46 | inputs = inputs.to(args.device.type) 47 | targets = targets.to(args.device.type) 48 | 49 | outputs = model(inputs) 50 | loss = criterion(outputs, targets) 51 | 52 | acc1, acc5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 53 | losses.update(loss.item(), inputs.size(0)) 54 | top1.update(acc1.item(), inputs.size(0)) 55 | top5.update(acc5.item(), inputs.size(0)) 56 | 57 | if lr_scheduler is not None: 58 | lr_scheduler.step(epoch=epoch, batch=batch_idx) 59 | 60 | optimizer.zero_grad() 61 | loss.backward() 62 | optimizer.step() 63 | 64 | batch_time.update(time.time() - end_time) 65 | end_time = time.time() 66 | 67 | if (batch_idx + 1) % args.log.print_freq == 0: 68 | for m in monitors: 69 | m.update(epoch, batch_idx + 1, steps_per_epoch, 'Training', { 70 | 'Loss': losses, 71 | 'Top1': top1, 72 | 'Top5': top5, 73 | 'BatchTime': batch_time, 74 | 'LR': optimizer.param_groups[0]['lr'] 75 | }) 76 | 77 | logger.info('==> Top1: %.3f Top5: %.3f Loss: %.3f\n', 78 | top1.avg, top5.avg, losses.avg) 79 | return top1.avg, top5.avg, losses.avg 80 | 81 | 82 | def validate(data_loader, model, criterion, epoch, monitors, args): 83 | losses = AverageMeter() 84 | top1 = AverageMeter() 85 | top5 = AverageMeter() 86 | batch_time = AverageMeter() 87 | 88 | total_sample = len(data_loader.sampler) 89 | batch_size = data_loader.batch_size 90 | steps_per_epoch = math.ceil(total_sample / batch_size) 91 | 92 | logger.info('Validation: %d samples (%d per mini-batch)', total_sample, batch_size) 93 | 94 | model.eval() 95 | end_time = time.time() 96 | for batch_idx, (inputs, targets) in enumerate(data_loader): 97 | with t.no_grad(): 98 | inputs = inputs.to(args.device.type) 99 | targets = targets.to(args.device.type) 100 | 101 | outputs = model(inputs) 102 | loss = criterion(outputs, targets) 103 | 104 | acc1, acc5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 105 | losses.update(loss.item(), inputs.size(0)) 106 | top1.update(acc1.item(), inputs.size(0)) 107 | top5.update(acc5.item(), inputs.size(0)) 108 | batch_time.update(time.time() - end_time) 109 | end_time = time.time() 110 | 111 | if (batch_idx + 1) % args.log.print_freq == 0: 112 | for m in monitors: 113 | m.update(epoch, batch_idx + 1, steps_per_epoch, 'Validation', { 114 | 'Loss': losses, 115 | 'Top1': top1, 116 | 'Top5': top5, 117 | 'BatchTime': batch_time 118 | }) 119 | 120 | logger.info('==> Top1: %.3f Top5: %.3f Loss: 
%.3f\n', top1.avg, top5.avg, losses.avg) 121 | return top1.avg, top5.avg, losses.avg 122 | 123 | 124 | class PerformanceScoreboard: 125 | def __init__(self, num_best_scores): 126 | self.board = list() 127 | self.num_best_scores = num_best_scores 128 | 129 | def update(self, top1, top5, epoch): 130 | """ Update the list of top training scores achieved so far, and log the best scores so far""" 131 | self.board.append({'top1': top1, 'top5': top5, 'epoch': epoch}) 132 | 133 | # Keep scoreboard sorted from best to worst, and sort by top1, top5 and epoch 134 | curr_len = min(self.num_best_scores, len(self.board)) 135 | self.board = sorted(self.board, 136 | key=operator.itemgetter('top1', 'top5', 'epoch'), 137 | reverse=True)[0:curr_len] 138 | for idx in range(curr_len): 139 | score = self.board[idx] 140 | logger.info('Scoreboard best %d ==> Epoch [%d][Top1: %.3f Top5: %.3f]', 141 | idx + 1, score['epoch'], score['top1'], score['top5']) 142 | 143 | def is_best(self, epoch): 144 | return self.board[0]['epoch'] == epoch 145 | -------------------------------------------------------------------------------- /quan/__init__.py: -------------------------------------------------------------------------------- 1 | from .func import * 2 | from .utils import * 3 | -------------------------------------------------------------------------------- /quan/func.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | 3 | 4 | class QuanConv2d(t.nn.Conv2d): 5 | def __init__(self, m: t.nn.Conv2d, quan_w_fn=None, quan_a_fn=None): 6 | assert type(m) == t.nn.Conv2d 7 | super().__init__(m.in_channels, m.out_channels, m.kernel_size, 8 | stride=m.stride, 9 | padding=m.padding, 10 | dilation=m.dilation, 11 | groups=m.groups, 12 | bias=True if m.bias is not None else False, 13 | padding_mode=m.padding_mode) 14 | self.quan_w_fn = quan_w_fn 15 | self.quan_a_fn = quan_a_fn 16 | 17 | self.weight = t.nn.Parameter(m.weight.detach()) 18 | self.quan_w_fn.init_from(m.weight) 19 | if m.bias is not None: 20 | self.bias = t.nn.Parameter(m.bias.detach()) 21 | 22 | def forward(self, x): 23 | quantized_weight = self.quan_w_fn(self.weight) 24 | quantized_act = self.quan_a_fn(x) 25 | return self._conv_forward(quantized_act, quantized_weight) 26 | 27 | 28 | class QuanLinear(t.nn.Linear): 29 | def __init__(self, m: t.nn.Linear, quan_w_fn=None, quan_a_fn=None): 30 | assert type(m) == t.nn.Linear 31 | super().__init__(m.in_features, m.out_features, 32 | bias=True if m.bias is not None else False) 33 | self.quan_w_fn = quan_w_fn 34 | self.quan_a_fn = quan_a_fn 35 | 36 | self.weight = t.nn.Parameter(m.weight.detach()) 37 | self.quan_w_fn.init_from(m.weight) 38 | if m.bias is not None: 39 | self.bias = t.nn.Parameter(m.bias.detach()) 40 | 41 | def forward(self, x): 42 | quantized_weight = self.quan_w_fn(self.weight) 43 | quantized_act = self.quan_a_fn(x) 44 | return t.nn.functional.linear(quantized_act, quantized_weight, self.bias) 45 | 46 | 47 | QuanModuleMapping = { 48 | t.nn.Conv2d: QuanConv2d, 49 | t.nn.Linear: QuanLinear 50 | } 51 | -------------------------------------------------------------------------------- /quan/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .lsq import LsqQuan 2 | from .quantizer import IdentityQuan 3 | -------------------------------------------------------------------------------- /quan/quantizer/lsq.py: -------------------------------------------------------------------------------- 1 | import torch 
as t 2 | 3 | from .quantizer import Quantizer 4 | 5 | 6 | def grad_scale(x, scale): # identity in forward; scales the gradient by `scale` in backward 7 | y = x 8 | y_grad = x * scale 9 | return (y - y_grad).detach() + y_grad 10 | 11 | 12 | def round_pass(x): # round in forward; straight-through (identity) gradient in backward 13 | y = x.round() 14 | y_grad = x 15 | return (y - y_grad).detach() + y_grad 16 | 17 | 18 | class LsqQuan(Quantizer): 19 | def __init__(self, bit, all_positive=False, symmetric=False, per_channel=True): 20 | super().__init__(bit) 21 | 22 | if all_positive: 23 | assert not symmetric, "Positive quantization cannot be symmetric" 24 | # unsigned activation is quantized to [0, 2^b-1] 25 | self.thd_neg = 0 26 | self.thd_pos = 2 ** bit - 1 27 | else: 28 | if symmetric: 29 | # signed weight/activation is quantized to [-2^(b-1)+1, 2^(b-1)-1] 30 | self.thd_neg = - 2 ** (bit - 1) + 1 31 | self.thd_pos = 2 ** (bit - 1) - 1 32 | else: 33 | # signed weight/activation is quantized to [-2^(b-1), 2^(b-1)-1] 34 | self.thd_neg = - 2 ** (bit - 1) 35 | self.thd_pos = 2 ** (bit - 1) - 1 36 | 37 | self.per_channel = per_channel 38 | self.s = t.nn.Parameter(t.ones(1)) 39 | 40 | def init_from(self, x, *args, **kwargs): 41 | if self.per_channel: 42 | self.s = t.nn.Parameter( 43 | x.detach().abs().mean(dim=list(range(1, x.dim())), keepdim=True) * 2 / (self.thd_pos ** 0.5)) 44 | else: 45 | self.s = t.nn.Parameter(x.detach().abs().mean() * 2 / (self.thd_pos ** 0.5)) 46 | 47 | def forward(self, x): 48 | # NOTE: the same gradient scale g = 1/sqrt(Qp * nelement) is used for 49 | # both per-channel and per-layer step sizes, so no branch is needed. 50 | s_grad_scale = 1.0 / ((self.thd_pos * x.numel()) ** 0.5) 51 | 52 | s_scale = grad_scale(self.s, s_grad_scale) 53 | 54 | x = x / s_scale 55 | x = t.clamp(x, self.thd_neg, self.thd_pos) 56 | x = round_pass(x) 57 | x = x * s_scale 58 | return x 59 |
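A quick end-to-end check of `LsqQuan` (a usage sketch, per-layer mode with random weights):

```python
import torch as t

q = LsqQuan(bit=3, all_positive=False, symmetric=False, per_channel=False)
w = t.randn(64, 32)
q.init_from(w)    # s = 2 * mean(|w|) / sqrt(Qp)
w_q = q(w)        # quantize-dequantize with straight-through gradients
levels = (w_q / q.s).round().unique()
assert levels.min().item() >= q.thd_neg and levels.max().item() <= q.thd_pos
```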
-------------------------------------------------------------------------------- /quan/quantizer/quantizer.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | 3 | 4 | class Quantizer(t.nn.Module): 5 | def __init__(self, bit): 6 | super().__init__() 7 | 8 | def init_from(self, x, *args, **kwargs): 9 | pass 10 | 11 | def forward(self, x): 12 | raise NotImplementedError 13 | 14 | 15 | class IdentityQuan(Quantizer): 16 | def __init__(self, bit=None, *args, **kwargs): 17 | super().__init__(bit) 18 | assert bit is None, 'The bit-width of identity quantizer must be None' 19 | 20 | def forward(self, x): 21 | return x 22 | -------------------------------------------------------------------------------- /quan/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch as t # used below; previously pulled in implicitly via the star import 3 | from .func import * 4 | from .quantizer import * 5 | 6 | 7 | def quantizer(default_cfg, this_cfg=None): 8 | target_cfg = dict(default_cfg) 9 | if this_cfg is not None: 10 | for k, v in this_cfg.items(): 11 | target_cfg[k] = v 12 | 13 | if target_cfg['bit'] is None: 14 | q = IdentityQuan 15 | elif target_cfg['mode'] == 'lsq': 16 | q = LsqQuan 17 | else: 18 | raise ValueError('Cannot find quantizer `%s`' % target_cfg['mode']) 19 | 20 | target_cfg.pop('mode') 21 | return q(**target_cfg) 22 | 23 | 24 | def find_modules_to_quantize(model, quan_scheduler): 25 | replaced_modules = dict() 26 | for name, module in model.named_modules(): 27 | if type(module) in QuanModuleMapping.keys(): 28 | if name in quan_scheduler.excepts: 29 | replaced_modules[name] = QuanModuleMapping[type(module)]( 30 | module, 31 | quan_w_fn=quantizer(quan_scheduler.weight, 32 | quan_scheduler.excepts[name].weight), 33 | quan_a_fn=quantizer(quan_scheduler.act, 34 | quan_scheduler.excepts[name].act) 35 | ) 36 | else: 37 | replaced_modules[name] = QuanModuleMapping[type(module)]( 38 | module, 39 | quan_w_fn=quantizer(quan_scheduler.weight), 40 | quan_a_fn=quantizer(quan_scheduler.act) 41 | ) 42 | elif name in quan_scheduler.excepts: 43 | logging.warning('Cannot find module %s in the model, skip it' % name) 44 | 45 | return replaced_modules 46 | 47 | 48 | def replace_module_by_names(model, modules_to_replace): 49 | def helper(child: t.nn.Module): 50 | for n, c in child.named_children(): 51 | if type(c) in QuanModuleMapping.keys(): 52 | for full_name, m in model.named_modules(): 53 | if c is m: 54 | child.add_module(n, modules_to_replace.pop(full_name)) 55 | break 56 | else: 57 | helper(c) 58 | 59 | helper(model) 60 | return model 61 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint import load_checkpoint, save_checkpoint 2 | from .config import init_logger, get_config 3 | from .data_loader import load_data 4 | from .lr_scheduler import lr_scheduler 5 | from .monitor import ProgressMonitor, TensorBoardMonitor, AverageMeter 6 | -------------------------------------------------------------------------------- /util/checkpoint.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch as t 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def save_checkpoint(epoch, arch, model, extras=None, is_best=None, name=None, output_dir='.'): 10 | """Save a PyTorch training checkpoint 11 | Args: 12 | epoch: current epoch number 13 | arch: name of the network architecture/topology 14 | model: a PyTorch model 15 | extras: optional dict with additional user-defined data to be saved in the checkpoint. 16 | Will be saved under the key 'extras' 17 | is_best: If true, will save a copy of the checkpoint with the suffix 'best' 18 | name: the name of the checkpoint file 19 | output_dir: directory in which to save the checkpoint 20 | """ 21 | if not os.path.isdir(output_dir): 22 | raise IOError('Checkpoint directory does not exist at', os.path.abspath(output_dir)) 23 | 24 | if extras is None: 25 | extras = {} 26 | if not isinstance(extras, dict): 27 | raise TypeError('extras must be either a dict or None') 28 | 29 | filename = 'checkpoint.pth.tar' if name is None else name + '_checkpoint.pth.tar' 30 | filepath = os.path.join(output_dir, filename) 31 | filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar' 32 | filepath_best = os.path.join(output_dir, filename_best) 33 | 34 | checkpoint = { 35 | 'epoch': epoch, 36 | 'state_dict': model.state_dict(), 37 | 'arch': arch, 38 | 'extras': extras, 39 | } 40 | 41 | msg = 'Saving checkpoint to:\n' 42 | msg += ' Current: %s\n' % filepath 43 | t.save(checkpoint, filepath) 44 | if is_best: 45 | msg += ' Best: %s\n' % filepath_best 46 | t.save(checkpoint, filepath_best) 47 | logger.info(msg) 48 | 49 | 50 | def load_checkpoint(model, chkp_file, model_device=None, strict=False, lean=False): 51 | """Load a PyTorch training checkpoint. 52 | Args: 53 | model: the PyTorch model to which we will load the parameters. You can 54 | specify model=None if the checkpoint contains enough metadata to infer 55 | the model. The order of the arguments is misleading and clunky, and is 56 | kept this way for backward compatibility.


def load_checkpoint(model, chkp_file, model_device=None, strict=False, lean=False):
    """Load a PyTorch training checkpoint.

    Args:
        model: the PyTorch model into which the checkpoint parameters will be loaded.
            The order of the arguments is kept this way for backward compatibility.
        chkp_file: the checkpoint file
        model_device [str]: if set, call model.to($model_device)
            This should be set to either 'cpu' or 'cuda'.
        strict: if True, the keys of the checkpoint's state dict must exactly match the model's
        lean: if set, read only the 'state_dict' field of the checkpoint into the model
    :returns: updated model, start_epoch, extras
    """
    if not os.path.isfile(chkp_file):
        raise IOError('Cannot find a checkpoint at %s' % chkp_file)

    checkpoint = t.load(chkp_file, map_location=lambda storage, loc: storage)

    if 'state_dict' not in checkpoint:
        raise ValueError('Checkpoint must contain model parameters')

    extras = checkpoint.get('extras', None)

    arch = checkpoint.get('arch', '_nameless_')

    checkpoint_epoch = checkpoint.get('epoch', None)
    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0

    anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
    if anomalous_keys:
        missing_keys, unexpected_keys = anomalous_keys
        if unexpected_keys:
            logger.warning("The loaded checkpoint (%s) contains %d unexpected state keys" %
                           (chkp_file, len(unexpected_keys)))
        if missing_keys:
            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" %
                             (chkp_file, len(missing_keys)))

    if model_device is not None:
        model.to(model_device)

    if lean:
        logger.info("Loaded checkpoint %s model (next epoch %d) from %s", arch, 0, chkp_file)
        return model, 0, None
    else:
        logger.info("Loaded checkpoint %s model (next epoch %d) from %s", arch, start_epoch, chkp_file)
        return model, start_epoch, extras
--------------------------------------------------------------------------------
/util/config.py:
--------------------------------------------------------------------------------
import argparse
import logging
import logging.config
import os
import time

import munch
import yaml


def merge_nested_dict(d, other):
    # Recursively merge `other` into `d`; values from `other` take precedence.
    new = dict(d)
    for k, v in other.items():
        if d.get(k, None) is not None and type(v) is dict:
            new[k] = merge_nested_dict(d[k], v)
        else:
            new[k] = v
    return new
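
# Example (illustrative): merging {'quan': {'act': {'bit': 3}}} into a default of
# {'quan': {'act': {'bit': 8, 'mode': 'lsq'}}} yields
# {'quan': {'act': {'bit': 3, 'mode': 'lsq'}}}; nested keys are overridden
# individually instead of replacing the whole sub-dict.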


def get_config(default_file):
    p = argparse.ArgumentParser(description='Learned Step Size Quantization')
    p.add_argument('config_file', metavar='PATH', nargs='+',
                   help='path to a configuration file')
    arg = p.parse_args()

    with open(default_file) as yaml_file:
        cfg = yaml.safe_load(yaml_file)

    for f in arg.config_file:
        if not os.path.isfile(f):
            raise FileNotFoundError('Cannot find a configuration file at %s' % f)
        with open(f) as yaml_file:
            c = yaml.safe_load(yaml_file)
        cfg = merge_nested_dict(cfg, c)

    return munch.munchify(cfg)


def init_logger(experiment_name, output_dir, cfg_file=None):
    time_str = time.strftime("%Y%m%d-%H%M%S")
    exp_full_name = time_str if experiment_name is None else experiment_name + '_' + time_str
    log_dir = output_dir / exp_full_name
    log_dir.mkdir(exist_ok=True)
    log_file = log_dir / (exp_full_name + '.log')
    logging.config.fileConfig(cfg_file, defaults={'logfilename': str(log_file)})
    logger = logging.getLogger()
    logger.info('Log file for this run: ' + str(log_file))
    return log_dir
--------------------------------------------------------------------------------
/util/data_loader.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import torch as t
import torch.utils.data
import torchvision as tv
from sklearn.model_selection import train_test_split


def __balance_val_split(dataset, val_split=0.):
    # Split off a validation subset whose class distribution matches the
    # full training set (stratified split).
    targets = np.array(dataset.targets)
    train_indices, val_indices = train_test_split(
        np.arange(targets.shape[0]),
        test_size=val_split,
        stratify=targets
    )
    train_dataset = t.utils.data.Subset(dataset, indices=train_indices)
    val_dataset = t.utils.data.Subset(dataset, indices=val_indices)
    return train_dataset, val_dataset
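
# Example (illustrative): with val_split=0.1 on the 50,000-image CIFAR-10 training
# set, 45,000 images are kept for training and 5,000 are held out for validation,
# and stratify=targets keeps all 10 classes equally represented in both subsets.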


def __deterministic_worker_init_fn(worker_id, seed=0):
    # Note: every worker is seeded with the same value, so data loading is
    # reproducible across runs.
    import random
    random.seed(seed)
    np.random.seed(seed)
    t.manual_seed(seed)


def load_data(cfg):
    if cfg.val_split < 0 or cfg.val_split >= 1:
        raise ValueError('val_split should be in the range of [0, 1) but got %.3f' % cfg.val_split)

    # ImageNet channel statistics; also applied to CIFAR-10 below
    tv_normalize = tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])
    if cfg.dataset == 'imagenet':
        train_transform = tv.transforms.Compose([
            tv.transforms.RandomResizedCrop(224),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv_normalize
        ])
        val_transform = tv.transforms.Compose([
            tv.transforms.Resize(256),
            tv.transforms.CenterCrop(224),
            tv.transforms.ToTensor(),
            tv_normalize
        ])

        train_set = tv.datasets.ImageFolder(
            root=os.path.join(cfg.path, 'train'), transform=train_transform)
        test_set = tv.datasets.ImageFolder(
            root=os.path.join(cfg.path, 'val'), transform=val_transform)

    elif cfg.dataset == 'cifar10':
        train_transform = tv.transforms.Compose([
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.RandomCrop(32, 4),
            tv.transforms.ToTensor(),
            tv_normalize
        ])
        val_transform = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv_normalize
        ])

        train_set = tv.datasets.CIFAR10(cfg.path, train=True, transform=train_transform, download=True)
        test_set = tv.datasets.CIFAR10(cfg.path, train=False, transform=val_transform, download=True)

    else:
        raise ValueError('load_data does not support dataset %s' % cfg.dataset)

    if cfg.val_split != 0:
        train_set, val_set = __balance_val_split(train_set, cfg.val_split)
    else:
        # In this case, use the test set for validation
        val_set = test_set

    worker_init_fn = None
    if cfg.deterministic:
        worker_init_fn = __deterministic_worker_init_fn

    train_loader = t.utils.data.DataLoader(
        train_set, cfg.batch_size, shuffle=True, num_workers=cfg.workers, pin_memory=True, worker_init_fn=worker_init_fn)
    val_loader = t.utils.data.DataLoader(
        val_set, cfg.batch_size, num_workers=cfg.workers, pin_memory=True, worker_init_fn=worker_init_fn)
    test_loader = t.utils.data.DataLoader(
        test_set, cfg.batch_size, num_workers=cfg.workers, pin_memory=True, worker_init_fn=worker_init_fn)

    return train_loader, val_loader, test_loader
--------------------------------------------------------------------------------
/util/lr_scheduler.py:
--------------------------------------------------------------------------------
import math


def lr_scheduler(optimizer, mode, batch_size=None, num_samples=None, update_per_batch=False, **kwargs):
    # batch_size and num_samples are only needed when the learning rate is
    # updated after every batch rather than after every epoch
    if update_per_batch:
        assert isinstance(batch_size, int) and isinstance(num_samples, int)

    if mode == 'fixed':
        scheduler = FixedLr
    elif mode == 'step':
        scheduler = StepLr
    elif mode == 'multi_step':
        scheduler = MultiStepLr
    elif mode == 'exp':
        scheduler = ExponentialLr
    elif mode == 'cos':
        scheduler = CosineLr
    elif mode == 'cos_warm_restarts':
        scheduler = CosineWarmRestartsLr
    else:
        raise ValueError('LR scheduler `%s` is not supported' % mode)

    return scheduler(optimizer=optimizer, batch_size=batch_size, num_samples=num_samples,
                     update_per_batch=update_per_batch, **kwargs)
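
# When update_per_batch is True, the schedulers below treat the epoch index as a
# fraction. Example (illustrative numbers): with batch_size=256 and
# num_samples=1281167 (the size of the ImageNet training set), batch 2502 of
# epoch 10 maps to an effective epoch of 10 + 2502 * 256 / 1281167, i.e. about 10.5.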


class LrScheduler:
    def __init__(self, optimizer, batch_size, num_samples, update_per_batch):
        self.optimizer = optimizer
        self.current_lr = self.get_lr()
        self.base_lr = self.get_lr()
        self.num_groups = len(self.base_lr)

        self.batch_size = batch_size
        self.num_samples = num_samples
        self.update_per_batch = update_per_batch

    def get_lr(self):
        return [g['lr'] for g in self.optimizer.param_groups]

    def set_lr(self, lr):
        for i in range(self.num_groups):
            self.current_lr[i] = lr[i]
            self.optimizer.param_groups[i]['lr'] = lr[i]

    def step(self, epoch, batch):
        raise NotImplementedError

    def __str__(self):
        s = '`%s`' % self.__class__.__name__
        s += '\n  Update per batch: %s' % self.update_per_batch
        for i in range(self.num_groups):
            s += '\n  Group %d: %g' % (i, self.current_lr[i])
        return s


class FixedLr(LrScheduler):
    def step(self, epoch, batch):
        pass


class LambdaLr(LrScheduler):
    def __init__(self, lr_lambda, **kwargs):
        super(LambdaLr, self).__init__(**kwargs)
        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
            self.lr_lambdas = [lr_lambda] * self.num_groups
        else:
            if len(lr_lambda) != self.num_groups:
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    self.num_groups, len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)

    def step(self, epoch, batch):
        if self.update_per_batch:
            epoch = epoch + batch * self.batch_size / self.num_samples
        for i in range(self.num_groups):
            func = self.lr_lambdas[i]
            self.current_lr[i] = func(epoch) * self.base_lr[i]
        self.set_lr(self.current_lr)


class StepLr(LrScheduler):
    def __init__(self, step_size=30, gamma=0.1, **kwargs):
        super(StepLr, self).__init__(**kwargs)
        self.step_size = step_size
        self.gamma = gamma

    def step(self, epoch, batch):
        for i in range(self.num_groups):
            self.current_lr[i] = self.base_lr[i] * (self.gamma ** (epoch // self.step_size))
        self.set_lr(self.current_lr)


class MultiStepLr(LrScheduler):
    def __init__(self, milestones=(30,), gamma=0.1, **kwargs):
        super(MultiStepLr, self).__init__(**kwargs)
        self.milestones = milestones
        self.gamma = gamma

    def step(self, epoch, batch):
        # Decay once for every milestone that has already been passed.
        n = sum([1 for m in self.milestones if m <= epoch])
        scale = self.gamma ** n
        for i in range(self.num_groups):
            self.current_lr[i] = self.base_lr[i] * scale
        self.set_lr(self.current_lr)


class ExponentialLr(LrScheduler):
    def __init__(self, gamma=0.95, **kwargs):
        super(ExponentialLr, self).__init__(**kwargs)
        self.gamma = gamma

    def step(self, epoch, batch):
        if self.update_per_batch:
            epoch = epoch + batch * self.batch_size / self.num_samples
        for i in range(self.num_groups):
            self.current_lr[i] = self.base_lr[i] * (self.gamma ** epoch)
        self.set_lr(self.current_lr)


class CosineLr(LrScheduler):
    def __init__(self, lr_min=0., cycle=90, **kwargs):
        super(CosineLr, self).__init__(**kwargs)
        self.min_lr = lr_min
        self.cycle = cycle

    def step(self, epoch, batch):
        if self.update_per_batch:
            epoch = epoch + batch * self.batch_size / self.num_samples
        if epoch > self.cycle:
            epoch = self.cycle
        for i in range(self.num_groups):
            self.current_lr[i] = self.min_lr + 0.5 * (self.base_lr[i] - self.min_lr) \
                                 * (1 + math.cos(math.pi * epoch / self.cycle))
        self.set_lr(self.current_lr)


class CosineWarmRestartsLr(LrScheduler):
    def __init__(self, lr_min=0., cycle=5, cycle_scale=2., amp_scale=0.5, **kwargs):
        super(CosineWarmRestartsLr, self).__init__(**kwargs)
        self.min_lr = lr_min
        self.cycle = cycle
        self.cycle_scale = cycle_scale
        self.amp_scale = amp_scale

    def step(self, epoch, batch):
        if self.update_per_batch:
            epoch = epoch + batch * self.batch_size / self.num_samples

        # Find the current restart cycle: each cycle is cycle_scale times longer
        # and amp_scale times lower in amplitude than the previous one.
        curr_cycle = self.cycle
        curr_amp = 1.
        while epoch >= curr_cycle:
            epoch = epoch - curr_cycle
            curr_cycle *= self.cycle_scale
            curr_amp *= self.amp_scale

        for i in range(self.num_groups):
            self.current_lr[i] = self.min_lr + 0.5 * curr_amp * (self.base_lr[i] - self.min_lr) \
                                 * (1 + math.cos(math.pi * epoch / curr_cycle))
        self.set_lr(self.current_lr)
--------------------------------------------------------------------------------
/util/monitor.py:
--------------------------------------------------------------------------------
from torch.utils.tensorboard import SummaryWriter

__all__ = ['ProgressMonitor', 'TensorBoardMonitor', 'AverageMeter']


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, fmt='%.6f'):
        self.fmt = fmt
        self.val = self.avg = self.sum = self.count = 0

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        s = self.fmt % self.avg
        return s
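
# Example (illustrative): meter = AverageMeter(); meter.update(0.7, n=32);
# meter.update(0.5, n=32) leaves meter.avg == 0.6 and str(meter) == '0.600000'.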


class Monitor:
    """This is an abstract interface for data loggers

    Train monitors log the progress of the training process to some backend.
    This backend can be a file, a web service, or some other means to collect
    and/or display the training progress.
    """

    def __init__(self):
        pass

    def update(self, epoch, step_idx, step_num, prefix, meter_dict):
        raise NotImplementedError


class ProgressMonitor(Monitor):
    def __init__(self, logger):
        super().__init__()
        self.logger = logger

    def update(self, epoch, step_idx, step_num, prefix, meter_dict):
        msg = prefix
        if epoch > -1:
            msg += ' [%d][%5d/%5d]  ' % (epoch, step_idx, int(step_num))
        else:
            msg += ' [%5d/%5d]  ' % (step_idx, int(step_num))
        for k, v in meter_dict.items():
            msg += k + ' '
            if isinstance(v, AverageMeter):
                msg += str(v)
            else:
                msg += '%.6f' % v
            msg += '  '
        self.logger.info(msg)


class TensorBoardMonitor(Monitor):
    def __init__(self, logger, log_dir):
        super().__init__()
        self.writer = SummaryWriter(log_dir / 'tb_runs')
        logger.info('TensorBoard data directory: %s/tb_runs' % log_dir)

    def update(self, epoch, step_idx, step_num, prefix, meter_dict):
        # Flatten (epoch, step) into a single global step index for TensorBoard.
        current_step = epoch * step_num + step_idx
        for k, v in meter_dict.items():
            val = v.val if isinstance(v, AverageMeter) else v
            self.writer.add_scalar(prefix + '/' + k, val, current_step)
--------------------------------------------------------------------------------