├── README.md
├── cifar_train_eval.py
├── cifar_train_eval_dist.py
├── imgnet_train_eval.py
├── imgnet_train_eval_dist.py
├── mnist_train_eval.py
├── nets
│   ├── cifar_resnet.py
│   ├── cifar_vgg.py
│   ├── imgnet_alexnet.py
│   ├── imgnet_mobilenet_v1.py
│   ├── imgnet_mobilenet_v2.py
│   ├── imgnet_resnet.py
│   ├── imgnet_vgg.py
│   └── mnist_lenet.py
└── utils
    ├── preprocessing.py
    ├── summary.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
# PyTorch simple classification baselines

This repository contains simple PyTorch implementations of LeNet-5 (MNIST), ResNet (CIFAR, ImageNet), AlexNet (ImageNet), VGG-16 (CIFAR, ImageNet), and MobileNet-v1/v2 (ImageNet) baselines.
Both **nn.DataParallel** and **nn.parallel.DistributedDataParallel** versions are provided for multi-GPU training. I highly recommend nn.parallel.DistributedDataParallel, since it is considerably faster than nn.DataParallel.

## Requirements:
- python>=3.5
- pytorch>=0.4.1 (>=1.1.0 for the DistributedDataParallel version)
- tensorboardX (optional)

## Train

### single GPU or multi GPU using nn.DataParallel
* ```python mnist_train_eval.py```
* ```python cifar_train_eval.py```
* ```python imgnet_train_eval.py```

### multi GPU using nn.parallel.DistributedDataParallel
* ```python -m torch.distributed.launch --nproc_per_node 2 cifar_train_eval.py --dist --gpus 0,1```
* ```python -m torch.distributed.launch --nproc_per_node 2 imgnet_train_eval.py --dist --gpus 0,1```

## Results:

### MNIST:
Model|Accuracy
:---:|:---:
LeNet-5|99.26%

### CIFAR-10:
Model|Accuracy
:---:|:---:
ResNet-20|92.09%
ResNet-56|93.68%
VGG-16|93.99%

### ImageNet2012:
Model|Top-1 Accuracy|Top-5 Accuracy
:---:|:---:|:---:
ResNet-18|69.67%|89.29%
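
## Evaluate a saved checkpoint

Checkpoints are saved via `model.state_dict()` of the **nn.DataParallel**/**nn.parallel.DistributedDataParallel** wrapper, so every key carries a `module.` prefix. A minimal loading sketch (assuming the default `cifar_train_eval.py` run, which trains VGG-16 and writes `./ckpt/vgg16_baseline/checkpoint.t7`):

```
import torch
from nets.cifar_vgg import vgg16

model = vgg16()
state_dict = torch.load('./ckpt/vgg16_baseline/checkpoint.t7', map_location='cpu')
# strip the 'module.' prefix added by the parallel wrapper
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
```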
--------------------------------------------------------------------------------
/cifar_train_eval.py:
--------------------------------------------------------------------------------
import os
import time
import argparse
from datetime import datetime
from contextlib import ExitStack

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist

import torchvision

from nets.cifar_vgg import vgg16
from nets.cifar_resnet import resnet20, resnet56

from utils.utils import DisablePrint
from utils.summary import SummaryWriter
from utils.preprocessing import cifar_transform

# Training settings
parser = argparse.ArgumentParser(description='classification_baselines')

parser.add_argument('--dist', action='store_true')
parser.add_argument('--local_rank', type=int, default=0)

parser.add_argument('--root_dir', type=str, default='./')
parser.add_argument('--data_dir', type=str, default='./data')
parser.add_argument('--log_name', type=str, default='vgg16_baseline')
parser.add_argument('--pretrain', action='store_true', default=False)
parser.add_argument('--pretrain_dir', type=str, default='')

parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--wd', type=float, default=5e-4)

parser.add_argument('--train_batch_size', type=int, default=256)
parser.add_argument('--test_batch_size', type=int, default=200)
parser.add_argument('--max_epochs', type=int, default=200)

parser.add_argument('--log_interval', type=int, default=10)
parser.add_argument('--gpus', type=str, default='0')
parser.add_argument('--num_workers', type=int, default=0)

cfg = parser.parse_args()

cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name)
cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.log_name)

os.makedirs(cfg.log_dir, exist_ok=True)
os.makedirs(cfg.ckpt_dir, exist_ok=True)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpus


def main():
    num_gpus = torch.cuda.device_count()
    if cfg.dist:
        device = torch.device('cuda:%d' % cfg.local_rank)
        torch.cuda.set_device(cfg.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=num_gpus, rank=cfg.local_rank)
    else:
        device = torch.device('cuda')

    # Data
    print('==> Preparing data ...')
    dataset = torchvision.datasets.CIFAR10
    trainset = dataset(root=cfg.data_dir, train=True, download=True,
                       transform=cifar_transform(is_training=True))
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset,
                                                                    num_replicas=num_gpus,
                                                                    rank=cfg.local_rank)
    # in distributed mode every process sees 1/num_gpus of the global batch
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=cfg.train_batch_size // num_gpus
                                               if cfg.dist else cfg.train_batch_size,
                                               shuffle=not cfg.dist,
                                               num_workers=cfg.num_workers,
                                               sampler=train_sampler if cfg.dist else None)

    testset = dataset(root=cfg.data_dir, train=False,
                      transform=cifar_transform(is_training=False))
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=cfg.test_batch_size,
                                              shuffle=False,
                                              num_workers=cfg.num_workers)

    print('==> Building model ...')
    model = vgg16()
    model = model.to(device)
    if cfg.dist:
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[cfg.local_rank],
                                                    output_device=cfg.local_rank)
    else:
        model = nn.DataParallel(model).to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=cfg.wd)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 60, 0.1)
    criterion = torch.nn.CrossEntropyLoss()

    summary_writer = SummaryWriter(cfg.log_dir)

    # Training
    def train(epoch):
        print('\nEpoch: %d' % epoch)
        model.train()

        start_time = time.time()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if cfg.local_rank == 0 and batch_idx % cfg.log_interval == 0:
                step = len(train_loader) * epoch + batch_idx
                duration = time.time() - start_time

                print('%s epoch: %d step: %d cls_loss= %.5f (%d samples/sec)' %
                      (datetime.now(), epoch, batch_idx, loss.item(),
                       cfg.train_batch_size * cfg.log_interval / duration))

                start_time = time.time()
                summary_writer.add_scalar('cls_loss', loss.item(), step)
                summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], step)
    def test(epoch):
        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                correct += predicted.eq(targets).cpu().sum().item()

        acc = 100. * correct / len(test_loader.dataset)
        if cfg.local_rank == 0:
            print('%s Precision@1 ==> %.2f%% \n' % (datetime.now(), acc))
            summary_writer.add_scalar('Precision@1', acc, global_step=epoch)

    for epoch in range(cfg.max_epochs):
        train_sampler.set_epoch(epoch)
        train(epoch)
        test(epoch)
        lr_scheduler.step(epoch)
        if cfg.local_rank == 0:
            torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, 'checkpoint.t7'))
            print('checkpoint saved to %s !' % os.path.join(cfg.ckpt_dir, 'checkpoint.t7'))

    summary_writer.close()


if __name__ == '__main__':
    with ExitStack() as stack:
        if cfg.local_rank != 0:
            stack.enter_context(DisablePrint())
        main()
--------------------------------------------------------------------------------
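Note: every script in this repository parses `--pretrain` and `--pretrain_dir` but never reads them. If you want the flag to take effect, a minimal sketch of hypothetical wiring (placed inside `main()` after the model is built and before it is wrapped for parallelism) would be:

```
if cfg.pretrain:
    # assumes a checkpoint written by these scripts; keys may carry a 'module.' prefix
    state_dict = torch.load(os.path.join(cfg.pretrain_dir, 'checkpoint.t7'), map_location='cpu')
    state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    print('loaded pretrained weights from %s' % cfg.pretrain_dir)
```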
/cifar_train_eval_dist.py:
--------------------------------------------------------------------------------
import os
import time
import argparse
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist

import torchvision

from nets.cifar_vgg import *
from nets.cifar_resnet import *
from utils.preprocessing import *
from utils.summary import SummaryWriter

# Training settings
parser = argparse.ArgumentParser(description='classification_baselines')

parser.add_argument('--local_rank', type=int, default=0)

parser.add_argument('--root_dir', type=str, default='./')
parser.add_argument('--data_dir', type=str, default='./data')
parser.add_argument('--log_name', type=str, default='vgg16_baseline_np3')
parser.add_argument('--pretrain', action='store_true', default=False)
parser.add_argument('--pretrain_dir', type=str, default='')

parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--wd', type=float, default=5e-4)

parser.add_argument('--train_batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=200)
parser.add_argument('--max_epochs', type=int, default=200)

parser.add_argument('--log_interval', type=int, default=10)
parser.add_argument('--gpus', type=str, default='2,3')
parser.add_argument('--num_workers', type=int, default=5)

cfg = parser.parse_args()

cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name)
cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.log_name)

os.makedirs(cfg.log_dir, exist_ok=True)
os.makedirs(cfg.ckpt_dir, exist_ok=True)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpus


def main():
    device = torch.device('cuda:%d' % cfg.local_rank)
    num_gpus = torch.cuda.device_count()

    torch.cuda.set_device(cfg.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=num_gpus, rank=cfg.local_rank)

    dataset = torchvision.datasets.CIFAR10

    # Data
    print('==> Preparing data..')
    trainset = dataset(root=cfg.data_dir, train=True, download=True,
                       transform=cifar_transform(is_training=True))
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset,
                                                                    num_replicas=num_gpus,
                                                                    rank=cfg.local_rank)
    # each process gets 1/num_gpus of the global batch; the sampler handles shuffling
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=cfg.train_batch_size // num_gpus,
                                               shuffle=False,
                                               num_workers=cfg.num_workers,
                                               sampler=train_sampler)

    testset = dataset(root=cfg.data_dir, train=False,
                      transform=cifar_transform(is_training=False))
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=cfg.test_batch_size,
                                              shuffle=False,
                                              num_workers=cfg.num_workers)

    print('==> Building model..')
    # model = resnet20()
    model = vgg16()
    model = model.to(device)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[cfg.local_rank],
                                                output_device=cfg.local_rank)

    optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=cfg.wd)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [60, 120, 180], 0.1)
    criterion = torch.nn.CrossEntropyLoss()

    summary_writer = SummaryWriter(cfg.log_dir)

    # Training
    def train(epoch):
        print('\nEpoch: %d' % epoch)
        model.train()

        start_time = time.time()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if cfg.local_rank == 0 and batch_idx % cfg.log_interval == 0:
                step = len(train_loader) * epoch + batch_idx
                duration = time.time() - start_time

                print('%s epoch: %d step: %d cls_loss= %.5f (%d samples/sec)' %
                      (datetime.now(), epoch, batch_idx, loss.item(),
                       cfg.train_batch_size * cfg.log_interval / duration))

                start_time = time.time()
                summary_writer.add_scalar('cls_loss', loss.item(), step)
                summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], step)
    def test(epoch):
        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                correct += predicted.eq(targets).cpu().sum().item()

        acc = 100. * correct / len(test_loader.dataset)
        print('%s Precision@1 ==> %.2f%% \n' % (datetime.now(), acc))
        summary_writer.add_scalar('Precision@1', acc, global_step=epoch)

    for epoch in range(cfg.max_epochs):
        train_sampler.set_epoch(epoch)
        train(epoch)
        test(epoch)
        lr_scheduler.step(epoch)
        if cfg.local_rank == 0:
            torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, 'checkpoint.t7'))
            print('checkpoint saved to %s !' % os.path.join(cfg.ckpt_dir, 'checkpoint.t7'))

    summary_writer.close()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
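A note on the *_dist.py variants: they assume one process per visible GPU, spawned through torch.distributed.launch (which supplies --local_rank), and --nproc_per_node must match the number of GPUs in --gpus, because each process derives world_size from torch.cuda.device_count(). Illustrative invocations:

```
python -m torch.distributed.launch --nproc_per_node 2 cifar_train_eval_dist.py --gpus 2,3
python -m torch.distributed.launch --nproc_per_node 2 imgnet_train_eval_dist.py --gpus 0,1
```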
/imgnet_train_eval.py:
--------------------------------------------------------------------------------
import os
import time
import argparse
from tqdm import tqdm
from PIL import ImageFile
from datetime import datetime
from contextlib import ExitStack

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torchvision.datasets as datasets

from nets.imgnet_vgg import vgg16
from nets.imgnet_alexnet import alexnet
from nets.imgnet_resnet import resnet18, resnet34, resnet50

from utils.utils import DisablePrint
from utils.summary import SummaryWriter
from utils.preprocessing import imgnet_transform

ImageFile.LOAD_TRUNCATED_IMAGES = True
torch.backends.cudnn.benchmark = True

# Training settings
parser = argparse.ArgumentParser(description='classification_baselines')

parser.add_argument('--dist', action='store_true')
parser.add_argument('--local_rank', type=int, default=0)

parser.add_argument('--root_dir', type=str, default='./')
parser.add_argument('--data_dir', type=str, default='./data')
parser.add_argument('--log_name', type=str, default='alexnet_baseline')
parser.add_argument('--pretrain', action='store_true', default=False)
parser.add_argument('--pretrain_dir', type=str, default='./ckpt/')

parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--wd', type=float, default=5e-4)

parser.add_argument('--train_batch_size', type=int, default=256)
parser.add_argument('--test_batch_size', type=int, default=200)
parser.add_argument('--max_epochs', type=int, default=100)

parser.add_argument('--log_interval', type=int, default=10)
parser.add_argument('--gpus', type=str, default='0')
parser.add_argument('--num_workers', type=int, default=20)

cfg = parser.parse_args()

cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name)
cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.log_name)

os.makedirs(cfg.log_dir, exist_ok=True)
os.makedirs(cfg.ckpt_dir, exist_ok=True)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpus


def main():
    num_gpus = torch.cuda.device_count()
    if cfg.dist:
        device = torch.device('cuda:%d' % cfg.local_rank)
        torch.cuda.set_device(cfg.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=num_gpus, rank=cfg.local_rank)
    else:
        device = torch.device('cuda')

    print('==> Preparing data ...')
    traindir = os.path.join(cfg.data_dir, 'train')
    train_dataset = datasets.ImageFolder(traindir, imgnet_transform(is_training=True))
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                    num_replicas=num_gpus,
                                                                    rank=cfg.local_rank)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=cfg.train_batch_size // num_gpus
                                               if cfg.dist else cfg.train_batch_size,
                                               shuffle=not cfg.dist,
                                               num_workers=cfg.num_workers,
                                               sampler=train_sampler if cfg.dist else None,
                                               pin_memory=True)

    evaldir = os.path.join(cfg.data_dir, 'val')
    val_dataset = datasets.ImageFolder(evaldir, imgnet_transform(is_training=False))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=cfg.test_batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers,
                                             pin_memory=True)

    # create model
    print('==> Building model ...')
    model = resnet50()  # pick the architecture to train here
    model = model.to(device)
    if cfg.dist:
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[cfg.local_rank],
                                                    output_device=cfg.local_rank)
    else:
        model = torch.nn.DataParallel(model)

    optimizer = torch.optim.SGD(model.parameters(), cfg.lr, momentum=0.9, weight_decay=cfg.wd)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 90], 0.1)
    criterion = torch.nn.CrossEntropyLoss()

    summary_writer = SummaryWriter(cfg.log_dir)

    def train(epoch):
        # switch to train mode
        model.train()

        start_time = time.time()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if cfg.local_rank == 0 and batch_idx % cfg.log_interval == 0:
                step = len(train_loader) * epoch + batch_idx
                duration = time.time() - start_time

                print('%s epoch: %d step: %d cls_loss= %.5f (%d samples/sec)' %
                      (datetime.now(), epoch, batch_idx, loss.item(),
                       cfg.train_batch_size * cfg.log_interval / duration))

                start_time = time.time()
                summary_writer.add_scalar('cls_loss', loss.item(), step)
                summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], step)
    def validate(epoch):
        # switch to evaluate mode
        model.eval()
        top1 = 0
        top5 = 0
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader):
                inputs, targets = inputs.to(device), targets.to(device)

                # compute output
                output = model(inputs)

                # top-5 predictions: pred has shape [5, batch]; row k holds every sample's (k+1)-th guess
                _, pred = output.topk(5, dim=1, largest=True, sorted=True)
                pred = pred.t()
                correct = pred.eq(targets.view(1, -1).expand_as(pred))

                top1 += correct[:1].view(-1).float().sum(0, keepdim=True).item()
                top5 += correct[:5].view(-1).float().sum(0, keepdim=True).item()

        top1 *= 100 / len(val_dataset)
        top5 *= 100 / len(val_dataset)
        print('%s Precision@1 ==> %.2f%% Precision@5 ==> %.2f%%\n' % (datetime.now(), top1, top5))

        summary_writer.add_scalar('Precision@1', top1, epoch)
        summary_writer.add_scalar('Precision@5', top5, epoch)

    for epoch in range(cfg.max_epochs):
        train_sampler.set_epoch(epoch)
        train(epoch)
        validate(epoch)
        lr_scheduler.step(epoch)
        if cfg.local_rank == 0:
            torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, 'checkpoint.t7'))
            print('checkpoint saved to %s !' % os.path.join(cfg.ckpt_dir, 'checkpoint.t7'))

    summary_writer.close()


if __name__ == '__main__':
    with ExitStack() as stack:
        if cfg.local_rank != 0:
            stack.enter_context(DisablePrint())
        main()
--------------------------------------------------------------------------------
/imgnet_train_eval_dist.py:
--------------------------------------------------------------------------------
import os
import time
import argparse
from tqdm import tqdm
from datetime import datetime
from PIL import ImageFile

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torchvision.datasets as datasets

from nets.imgnet_alexnet import *
from nets.imgnet_resnet import *
from nets.imgnet_vgg import *

from utils.preprocessing import *
from utils.summary import SummaryWriter

ImageFile.LOAD_TRUNCATED_IMAGES = True
torch.backends.cudnn.benchmark = True

# Training settings
parser = argparse.ArgumentParser(description='classification_baselines')

parser.add_argument('--local_rank', type=int, default=0)

parser.add_argument('--root_dir', type=str, default='./')
parser.add_argument('--data_dir', type=str, default='./data')
parser.add_argument('--log_name', type=str, default='alexnet_baseline')
parser.add_argument('--pretrain', action='store_true', default=False)
parser.add_argument('--pretrain_dir', type=str, default='./ckpt/')

parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--wd', type=float, default=5e-4)

parser.add_argument('--train_batch_size', type=int, default=256)
parser.add_argument('--test_batch_size', type=int, default=200)
parser.add_argument('--max_epochs', type=int, default=100)

parser.add_argument('--log_interval', type=int, default=10)
parser.add_argument('--gpus', type=str, default='0')
parser.add_argument('--num_workers', type=int, default=20)

cfg = parser.parse_args()

cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name)
cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.log_name)

os.makedirs(cfg.log_dir, exist_ok=True)
os.makedirs(cfg.ckpt_dir, exist_ok=True)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpus
def main():
    device = torch.device('cuda:%d' % cfg.local_rank)
    num_gpus = torch.cuda.device_count()

    torch.cuda.set_device(cfg.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=num_gpus, rank=cfg.local_rank)

    print('Prepare dataset ...')
    traindir = os.path.join(cfg.data_dir, 'train')
    train_dataset = datasets.ImageFolder(traindir, imgnet_transform(is_training=True))
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                    num_replicas=num_gpus,
                                                                    rank=cfg.local_rank)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=cfg.train_batch_size // num_gpus,
                                               shuffle=False,
                                               num_workers=cfg.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    evaldir = os.path.join(cfg.data_dir, 'val')
    val_dataset = datasets.ImageFolder(evaldir, imgnet_transform(is_training=False))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=cfg.test_batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers,
                                             pin_memory=True)

    # create model
    print('=> creating model resnet18')
    model = resnet18()
    model = model.to(device)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[cfg.local_rank],
                                                output_device=cfg.local_rank)

    optimizer = torch.optim.SGD(model.parameters(), cfg.lr, momentum=0.9, weight_decay=cfg.wd)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 90], 0.1)
    criterion = torch.nn.CrossEntropyLoss()

    summary_writer = SummaryWriter(cfg.log_dir)

    def train(epoch):
        # switch to train mode
        model.train()

        start_time = time.time()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if cfg.local_rank == 0 and batch_idx % cfg.log_interval == 0:
                step = len(train_loader) * epoch + batch_idx
                duration = time.time() - start_time

                print('%s epoch: %d step: %d cls_loss= %.5f (%d samples/sec)' %
                      (datetime.now(), epoch, batch_idx, loss.item(),
                       cfg.train_batch_size * cfg.log_interval / duration))

                start_time = time.time()
                summary_writer.add_scalar('cls_loss', loss.item(), step)
                summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], step)

    def validate(epoch):
        # switch to evaluate mode
        model.eval()
        top1 = 0
        top5 = 0
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader):
                inputs, targets = inputs.to(device), targets.to(device)

                # compute output
                output = model(inputs)

                # measure top-1/top-5 accuracy
                _, pred = output.topk(5, dim=1, largest=True, sorted=True)
                pred = pred.t()
                correct = pred.eq(targets.view(1, -1).expand_as(pred))

                top1 += correct[:1].view(-1).float().sum(0, keepdim=True).item()
                top5 += correct[:5].view(-1).float().sum(0, keepdim=True).item()

        top1 *= 100 / len(val_dataset)
        top5 *= 100 / len(val_dataset)
        print('%s Precision@1 ==> %.2f%% Precision@5 ==> %.2f%%\n' % (datetime.now(), top1, top5))

        summary_writer.add_scalar('Precision@1', top1, epoch)
        summary_writer.add_scalar('Precision@5', top5, epoch)
    for epoch in range(cfg.max_epochs):
        train_sampler.set_epoch(epoch)
        train(epoch)
        validate(epoch)
        lr_scheduler.step(epoch)
        if cfg.local_rank == 0:
            torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, 'checkpoint.t7'))
            print('checkpoint saved to %s !' % os.path.join(cfg.ckpt_dir, 'checkpoint.t7'))

    summary_writer.close()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/mnist_train_eval.py:
--------------------------------------------------------------------------------
import os
import time
import argparse
from datetime import datetime

import torch
import torch.optim as optim
from torchvision import datasets

from nets.mnist_lenet import *
from utils.preprocessing import *
from utils.summary import SummaryWriter

# Training settings
parser = argparse.ArgumentParser(description='classification_baselines')

parser.add_argument('--root_dir', type=str, default='./')
parser.add_argument('--data_dir', type=str, default='./data')
parser.add_argument('--log_name', type=str, default='lenet_baseline')
parser.add_argument('--pretrain', action='store_true', default=False)
parser.add_argument('--pretrain_dir', type=str, default='./ckpt/')

parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--wd', type=float, default=1e-5)

parser.add_argument('--train_batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=500)
parser.add_argument('--max_epochs', type=int, default=30)

parser.add_argument('--log_interval', type=int, default=10)
parser.add_argument('--use_gpu', type=str, default='0')
parser.add_argument('--num_workers', type=int, default=0)

cfg = parser.parse_args()

cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name)
cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.log_name)

os.makedirs(cfg.log_dir, exist_ok=True)
os.makedirs(cfg.ckpt_dir, exist_ok=True)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.use_gpu


def main():
    train_dataset = datasets.MNIST(cfg.data_dir, train=True, download=True,
                                   transform=minst_transform(is_training=True))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=cfg.train_batch_size,
                                               shuffle=True,
                                               num_workers=cfg.num_workers, pin_memory=True)

    test_dataset = datasets.MNIST(cfg.data_dir, train=False, download=True,
                                  transform=minst_transform(is_training=False))
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=cfg.test_batch_size,
                                              shuffle=False,
                                              num_workers=cfg.num_workers, pin_memory=True)

    model = LeNet().cuda()

    optimizer = optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.8, weight_decay=cfg.wd)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 10, 0.1)
    criterion = torch.nn.CrossEntropyLoss()

    summary_writer = SummaryWriter(cfg.log_dir)

    def train(epoch):
        model.train()

        start_time = time.time()
        for batch_idx, (inputs, target) in enumerate(train_loader):
            output = model(inputs.cuda())
            loss = criterion(output, target.cuda())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % cfg.log_interval == 0:
                step = len(train_loader) * epoch + batch_idx
                duration = time.time() - start_time

                print('%s epoch: %d step: %d cls_loss= %.5f (%d samples/sec)' %
                      (datetime.now(), epoch, batch_idx, loss.item(),
                       cfg.train_batch_size * cfg.log_interval / duration))
                start_time = time.time()
                summary_writer.add_scalar('cls_loss', loss.item(), step)
                summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], step)

    def test(epoch):
        model.eval()
        correct = 0
        with torch.no_grad():
            for inputs, target in test_loader:
                output = model(inputs.cuda())
                pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
                correct += pred.eq(target.cuda().view_as(pred)).cpu().sum().item()

        acc = 100. * correct / len(test_loader.dataset)
        print('%s Precision@1 ==> %.2f%% \n' % (datetime.now(), acc))
        summary_writer.add_scalar('Precision@1', acc, global_step=epoch)

    for epoch in range(cfg.max_epochs):
        lr_scheduler.step(epoch)
        train(epoch)
        test(epoch)
        torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, 'checkpoint.t7'))
        print('checkpoint saved to %s !' % os.path.join(cfg.ckpt_dir, 'checkpoint.t7'))

    summary_writer.close()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/nets/cifar_resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class PreActBlock_conv(nn.Module):

    def __init__(self, in_planes, out_planes, stride=1):
        super(PreActBlock_conv, self).__init__()
        self.bn0 = nn.BatchNorm2d(in_planes)
        self.conv0 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.conv1 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)

        self.skip_conv = None
        if stride != 1:
            self.skip_conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
            self.skip_bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        out = F.relu(self.bn0(x))

        if self.skip_conv is not None:
            shortcut = self.skip_conv(out)
            shortcut = self.skip_bn(shortcut)
        else:
            shortcut = x

        out = self.conv0(out)
        out = F.relu(self.bn1(out))
        out = self.conv1(out)
        out += shortcut
        return out


class PreActResNet(nn.Module):
    def __init__(self, block, num_units, num_classes):
        super(PreActResNet, self).__init__()

        self.in_planes = 16

        self.conv0 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        last_c = 16

        strides = [1] * num_units[0] + \
                  [2] + [1] * (num_units[1] - 1) + \
                  [2] + [1] * (num_units[2] - 1)
        channels = [16] * num_units[0] + \
                   [32] * num_units[1] + \
                   [64] * num_units[2]

        self.blocks = nn.ModuleList()
        for channel, stride in zip(channels, strides):
            self.blocks.append(block(last_c, channel, stride))
            last_c = channel

        self.bn = nn.BatchNorm2d(64)
        self.logit = nn.Linear(64, num_classes)

    def forward(self, x):
        out = self.conv0(x)
        for block in self.blocks:
            out = block(out)
        out = self.bn(out)
        out = out.mean(3).mean(2)  # global average pooling
        out = self.logit(out)
        return out


def resnet20(num_classes=10):
    return PreActResNet(PreActBlock_conv, [3, 3, 3], num_classes=num_classes)


def resnet56(num_classes=10):
    return PreActResNet(PreActBlock_conv, [9, 9, 9], num_classes=num_classes)


if __name__ == '__main__':

    def hook(self, input, output):
        print(output.data.cpu().numpy().shape)


    net = resnet56()
    for m in net.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            m.register_forward_hook(hook)

    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())
--------------------------------------------------------------------------------
/nets/cifar_vgg.py:
--------------------------------------------------------------------------------
import math
import os

import torch
import torch.nn as nn


class standard_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(standard_block, self).__init__()
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv2d(x)
        out = self.bn(out)
        out = self.relu(out)
        return out


class VGG(nn.Module):
    def __init__(self, conv_config, fc_config, num_classes):
        super(VGG, self).__init__()
        layers = []
        in_channels = 3

        for v in conv_config:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [standard_block(in_channels, v)]
                in_channels = v

        self.conv = nn.Sequential(*layers)

        self.fc = nn.Sequential(nn.Dropout(),
                                nn.Linear(in_channels, fc_config[0]),
                                nn.ReLU(True),
                                nn.Dropout(),
                                nn.Linear(fc_config[0], fc_config[1]),
                                nn.ReLU(True),
                                nn.Linear(fc_config[1], num_classes))

        # Initialize weights (He initialization for the conv layers)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # m.bias.data.zero_()

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def vgg16(num_classes=10):
    """VGG 16-layer model (configuration "D")"""
    return VGG([64, 64, 'M',
                128, 128, 'M',
                256, 256, 256, 'M',
                512, 512, 512, 'M',
                512, 512, 512, 'M'],
               [512, 512], num_classes)


if __name__ == '__main__':
    def hook(self, input, output):
        print(output.data.cpu().numpy().shape)

    net = vgg16()

    os.makedirs('../ckpt/vgg16_baseline_p', exist_ok=True)  # make sure the target directory exists
    torch.save(net.state_dict(), '../ckpt/vgg16_baseline_p/init.t7')

    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.register_forward_hook(hook)

    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())
--------------------------------------------------------------------------------
/nets/imgnet_alexnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()

        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(96)

        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(256)

        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, padding=1)

        self.fc6 = nn.Linear(256 * 6 * 6, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.logit = nn.Linear(4096, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        x = x.view(x.size(0), -1)
        x = F.dropout(x)
        x = F.relu(self.fc6(x))
        x = F.dropout(x)
        x = F.relu(self.fc7(x))
        x = self.logit(x)

        return x


def alexnet():
    return AlexNet()


if __name__ == '__main__':
    def hook(self, input, output):
        print(output.data.cpu().numpy().shape)

    net = alexnet()

    for m in net.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            m.register_forward_hook(hook)

    y = net(torch.randn(1, 3, 224, 224))
    print(y.size())
--------------------------------------------------------------------------------
/nets/imgnet_mobilenet_v1.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    """Depthwise 3x3 convolution followed by a pointwise 1x1 convolution."""

    def __init__(self, in_planes, out_planes, stride=1):
        super(Block, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1,
                               groups=in_planes, bias=False)
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out


class MobileNet(nn.Module):
    def __init__(self, conv_cfg, num_classes=1000):
        super(MobileNet, self).__init__()
        # NOTE: the original MobileNet-v1 stem uses stride 2 here; stride 1 keeps larger feature maps
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        layers = []
        in_planes = 32
        for x in conv_cfg:
            out_planes = x if isinstance(x, int) else x[0]
            stride = 1 if isinstance(x, int) else x[1]
            layers.append(Block(in_planes, out_planes, stride))
            in_planes = out_planes

        self.conv = nn.Sequential(*layers)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.conv(out)
        out = out.mean(2).mean(2)  # global average pooling
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def mobilenet_v1():
    # (128, 2) means conv planes=128, conv stride=2; by default conv stride=1
    return MobileNet(conv_cfg=[64,
                               (128, 2), 128,
                               (256, 2), 256,
                               (512, 2), 512, 512, 512, 512, 512,
                               (1024, 2), 1024])


if __name__ == '__main__':
    def hook(self, input, output):
        print(output.data.cpu().numpy().shape)


    net = mobilenet_v1()
    for m in net.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            m.register_forward_hook(hook)

    y = net(torch.randn(1, 3, 224, 224))
    print(y.size())
--------------------------------------------------------------------------------
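For intuition, the parameter savings of the depthwise-separable `Block` above versus a single standard 3x3 convolution can be checked directly (a standalone sketch; the 128-to-256 channel sizes are arbitrary, and BatchNorm parameters are ignored):

```
import torch.nn as nn

cin, cout = 128, 256
standard = nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False)
depthwise = nn.Conv2d(cin, cin, kernel_size=3, padding=1, groups=cin, bias=False)
pointwise = nn.Conv2d(cin, cout, kernel_size=1, bias=False)


def count(*modules):
    return sum(p.numel() for m in modules for p in m.parameters())


print(count(standard))              # 294912 = 128 * 256 * 3 * 3
print(count(depthwise, pointwise))  # 33920 = 128 * 3 * 3 + 128 * 256
```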
/nets/imgnet_mobilenet_v2.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class InvertedResidual(nn.Module):
    def __init__(self, inplanes, planes, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.skip = stride == 1 and inplanes == planes
        hidden_dim = round(inplanes * expand_ratio)

        # pw
        self.conv1 = \
            nn.Sequential(nn.Conv2d(inplanes, hidden_dim, 1, 1, 0, bias=False),
                          nn.BatchNorm2d(hidden_dim),
                          nn.ReLU6(inplace=True)) \
            if expand_ratio > 1 else nn.Sequential()
        # dw
        self.conv2 = \
            nn.Sequential(nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                          nn.BatchNorm2d(hidden_dim),
                          nn.ReLU6(inplace=True))
        # pw-linear
        self.conv3 = \
            nn.Sequential(nn.Conv2d(hidden_dim, planes, 1, 1, 0, bias=False),
                          nn.BatchNorm2d(planes))

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.skip:
            out = x + out

        return out


class MobileNetV2(nn.Module):
    def __init__(self, block=InvertedResidual, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        input_channel = int(32 * width_mult)
        last_channel = int(1280 * width_mult)
        # [expand_ratio, channel, num_blocks, stride]
        inverted_residual_setting = [[1, 16, 1, 1],
                                     [6, 24, 2, 2],
                                     [6, 32, 3, 2],
                                     [6, 64, 4, 2],
                                     [6, 96, 3, 1],
                                     [6, 160, 3, 2],
                                     [6, 320, 1, 1]]

        # building first layer
        self.conv_first = nn.Conv2d(3, input_channel, 3, 2, 1, bias=False)
        self.bn_first = nn.BatchNorm2d(input_channel)

        # building inverted residual blocks
        self.features = nn.ModuleList()
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                self.features.append(block(input_channel, output_channel, s if i == 0 else 1, expand_ratio=t))
                input_channel = output_channel

        # building last several layers
        self.conv_last = nn.Conv2d(input_channel, last_channel, 1, 1, 0, bias=False)
        self.bn_last = nn.BatchNorm2d(last_channel)

        # building classifier
        self.classifier = nn.Sequential(nn.Dropout(0.2),
                                        nn.Linear(last_channel, n_class))

        self._initialize_weights()

    def forward(self, x):
        x = self.conv_first(x)
        x = F.relu6(self.bn_first(x), inplace=True)

        for block in self.features:
            x = block(x)

        x = self.conv_last(x)
        x = F.relu6(self.bn_last(x), inplace=True)

        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


if __name__ == '__main__':
    def hook(self, input, output):
        print(output.data.cpu().numpy().shape)


    net = MobileNetV2()
    for m in net.modules():
        m.register_forward_hook(hook)

    y = net(torch.randn(1, 3, 224, 224))
    print(y.size())
--------------------------------------------------------------------------------
/nets/imgnet_resnet.py:
--------------------------------------------------------------------------------
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.skip_conv = None
        if stride != 1 or inplanes != planes:
            self.skip_conv = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
            self.skip_bn = nn.BatchNorm2d(planes)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out, inplace=True)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.skip_conv is not None:
            residual = self.skip_conv(x)
            residual = self.skip_bn(residual)

        out += residual
        out = F.relu(out, inplace=True)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)

        self.skip_conv = None
        if stride != 1 or inplanes != planes * 4:
            self.skip_conv = nn.Conv2d(inplanes, planes * 4, kernel_size=1, stride=stride, bias=False)
            self.skip_bn = nn.BatchNorm2d(planes * 4)
    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out, inplace=True)

        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out, inplace=True)

        # third stage of the bottleneck: 1x1 conv expanding to planes * 4
        out = self.conv3(out)
        out = self.bn3(out)

        if self.skip_conv is not None:
            residual = self.skip_conv(x)
            residual = self.skip_bn(residual)

        out += residual
        out = F.relu(out, inplace=True)
        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        strides = [1] * layers[0] + \
                  [2] + [1] * (layers[1] - 1) + \
                  [2] + [1] * (layers[2] - 1) + \
                  [2] + [1] * (layers[3] - 1)
        out_channels = [64] * layers[0] + \
                       [128] * layers[1] + \
                       [256] * layers[2] + \
                       [512] * layers[3]

        self.layers = nn.ModuleList()
        last_c = 64
        for channel, stride in zip(out_channels, strides):
            self.layers.append(block(last_c, channel, stride))
            last_c = channel * block.expansion

        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x), inplace=True)
        x = self.maxpool(x)

        for layer in self.layers:
            x = layer(x)

        x = x.mean(3).mean(2)
        x = self.fc(x)

        return x


def resnet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)


def resnet34(num_classes=1000):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)


def resnet50(num_classes=1000):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)


if __name__ == '__main__':
    def hook(self, input, output):
        print(output.data.cpu().numpy().shape)


    net = resnet18()

    print('total num of parameters: %.2fM' %
          (sum(p.numel() for p in net.parameters()) / 1e6))

    for m in net.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            m.register_forward_hook(hook)

    y = net(torch.randn(1, 3, 224, 224))
    print(y.size())
--------------------------------------------------------------------------------
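With the third bottleneck stage fixed above (the original forward applied conv2/bn2 twice and clobbered `out` in the skip branch), a quick sanity check is to compare parameter counts against the torchvision reference models (a sketch, assuming torchvision is installed and this is run from the repo root; the totals should line up since the architectures match up to naming):

```
import torchvision.models

from nets.imgnet_resnet import resnet18, resnet50

for name, ours, ref in [('resnet18', resnet18(), torchvision.models.resnet18()),
                        ('resnet50', resnet50(), torchvision.models.resnet50())]:
    n_ours = sum(p.numel() for p in ours.parameters())
    n_ref = sum(p.numel() for p in ref.parameters())
    print('%s: ours %d, torchvision %d' % (name, n_ours, n_ref))
```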
/nets/imgnet_vgg.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn


class standard_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(standard_block, self).__init__()
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv2d(x)
        out = self.bn(out)
        out = self.relu(out)
        return out


class VGG(nn.Module):
    def __init__(self, conv_config, fc_config, num_classes=1000):
        super(VGG, self).__init__()
        layers = []
        in_channels = 3

        for v in conv_config:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [standard_block(in_channels, v)]
                in_channels = v

        self.conv = nn.Sequential(*layers)

        self.fc = nn.Sequential(nn.Linear(in_channels * 7 * 7, fc_config[0]),
                                nn.ReLU(True),
                                nn.Dropout(),
                                nn.Linear(fc_config[0], fc_config[1]),
                                nn.ReLU(True),
                                nn.Dropout(),
                                nn.Linear(fc_config[1], num_classes))

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def vgg16():
    """VGG 16-layer model (configuration "D") with batch normalization"""
    return VGG([64, 64, 'M',
                128, 128, 'M',
                256, 256, 256, 'M',
                512, 512, 512, 'M',
                512, 512, 512, 'M'],
               [4096, 4096])


if __name__ == '__main__':
    def hook(self, input, output):
        print(output.data.cpu().numpy().shape)

    net = vgg16()
    for m in net.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            m.register_forward_hook(hook)

    y = net(torch.randn(1, 3, 224, 224))
    print(y.size())
--------------------------------------------------------------------------------
/nets/mnist_lenet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, bias=False)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, bias=False)
        self.fc1 = nn.Linear(800, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(F.max_pool2d(out, 2))
        out = self.conv2(out)
        out = F.relu(F.max_pool2d(out, 2))
        out = out.view(-1, 800)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out


if __name__ == '__main__':
    def hook(self, input, output):
        print(output.data.cpu().numpy().shape)


    net = LeNet()
    for m in net.modules():
        m.register_forward_hook(hook)

    y = net(torch.randn(1, 1, 28, 28))
    print(y.size())
--------------------------------------------------------------------------------
/utils/preprocessing.py:
--------------------------------------------------------------------------------
import torchvision.transforms as transforms


def minst_transform(is_training=True):
    # MNIST uses the same transform for training and evaluation
    return transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))])


def cifar_transform(is_training=True):
    # Data
    if is_training:
        transform_list = transforms.Compose([transforms.RandomHorizontalFlip(),
                                             transforms.Pad(4, padding_mode='reflect'),
                                             transforms.RandomCrop(32, padding=0),
                                             transforms.ToTensor(),
                                             transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                  (0.2023, 0.1994, 0.2010))])
    else:
        transform_list = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                  (0.2023, 0.1994, 0.2010))])

    return transform_list


def imgnet_transform(is_training=True):
    if is_training:
        transform_list = transforms.Compose([transforms.RandomResizedCrop(224),
                                             transforms.RandomHorizontalFlip(),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])
    else:
        transform_list = transforms.Compose([transforms.Resize(256),
                                             transforms.CenterCrop(224),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])
    return transform_list
--------------------------------------------------------------------------------
/utils/summary.py:
--------------------------------------------------------------------------------
import os
import pickle
from datetime import datetime

# return a fake SummaryWriter if tensorboardX is not installed

try:
    from tensorboardX import SummaryWriter
except ImportError:
    class SummaryWriter:
        def __init__(self, log_dir=None, comment='', **kwargs):
            print('\nunable to import tensorboardX, log will be recorded in pickle format!\n')
            self.log_dir = log_dir if log_dir is not None else './logs'
            os.makedirs(self.log_dir, exist_ok=True)
            self.logs = {'comment': comment}

        def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
            if tag in self.logs:
                self.logs[tag].append((scalar_value, global_step, walltime))
            else:
                self.logs[tag] = [(scalar_value, global_step, walltime)]

        def close(self):
            timestamp = str(datetime.now()).replace(' ', '_').replace(':', '_')
            with open(os.path.join(self.log_dir, 'log_%s.pickle' % timestamp), 'wb') as handle:
                pickle.dump(self.logs, handle, protocol=pickle.HIGHEST_PROTOCOL)

if __name__ == '__main__':
    sw = SummaryWriter()
    sw.close()
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import os
import sys


class DisablePrint:
    """Context manager that silences stdout (used on non-zero ranks)."""

    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._original_stdout
--------------------------------------------------------------------------------
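When tensorboardX is not installed, the fallback SummaryWriter in utils/summary.py dumps its scalars as a pickle on close(). A minimal sketch for reading such a log back (the path below is just an example; each tag maps to a list of (scalar_value, global_step, walltime) tuples):

```
import glob
import pickle

log_file = sorted(glob.glob('./logs/vgg16_baseline/log_*.pickle'))[-1]  # example path
with open(log_file, 'rb') as handle:
    logs = pickle.load(handle)

for value, step, _ in logs['cls_loss']:
    print(step, value)
```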