├── .gitignore ├── LICENSE ├── README.md ├── args.py ├── data ├── __init__.py ├── dataset_1.py └── dataset_2.py ├── main.py ├── model.py ├── models ├── __init__.py ├── resnet.py └── vgg.py ├── pretrained_models └── download.sh ├── test.py ├── tools ├── __init__.py └── visualize.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Shi Husen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Driver Posture Classification 2 | 3 | This is PyTorch code for the **Driver Posture Classification** task. We use the [AUC Distracted Driver Dataset](https://devyhia.github.io/projects/auc-distracted-driver-dataset). The dataset was captured to advance the state of the art in distracted-driver detection. Here are some samples from the dataset: 4 | 
*(sample images omitted)* 5 | 6 | 
7 | 8 | The task is to classify an image into one of several pre-defined categories, such as "Drive Safe", "Talk Passenger", "Text Right", and "Drink". We use a pretrained resnet34 model to achieve performance comparable to the original paper [Real-time Distracted Driver Posture Classification](https://arxiv.org/abs/1706.09498). The classification accuracy is about 95%. 9 | 10 | 11 | ## Usage 12 | 13 | ### Requirements 14 | 15 | * python 3.5+ 16 | * pytorch 0.4 17 | * visdom (optional) 18 | 19 | 20 | ### Steps 21 | 22 | 0. Download the dataset and its training and testing splits (train.csv and test.csv). Put them in the same directory. 23 | 1. Clone the repository: 24 | 25 | `git clone https://github.com/husencd/DriverPostureClassification.git` 26 | 27 | `cd DriverPostureClassification` 28 | 29 | 2. Download the resnet models pretrained on ImageNet from the official [PyTorch model URLs](https://download.pytorch.org/models/): 30 | 31 | `cd pretrained_models` 32 | 33 | `sh download.sh` 34 | 35 | 3. Now you can train/fine-tune the model: 36 | 37 | `cd ..` 38 | 39 | `python main.py [--model resnet] [--model_depth 34]` 40 | 41 | If you want to monitor the training process, start the visdom server: 42 | 43 | `python -m visdom.server` 44 | 45 | 46 | ## Reference 47 | 48 | * Our code is partially based on https://github.com/chenyuntc/pytorch-best-practice. 49 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser(description='PyTorch Driver Posture Classification') 6 | 7 | # path 8 | parser.add_argument('--data_path', default='/home/husencd/Downloads/dataset/driver', type=str, 9 | help='Driver data directory path') 10 | parser.add_argument('--root_path', default='/home/husencd/husen/pytorch/learn/DriverPostureClassification', type=str, 11 | help='Project root directory path') 12 | parser.add_argument('--result_path', default='results', type=str, 13 | help='Result directory path') 14 | parser.add_argument('--checkpoint_path', default='checkpoints', type=str, 15 | help='Checkpoint directory path (snapshot)') 16 | parser.add_argument('--resume_path', default='', type=str, 17 | help='Saved model (checkpoint) path of previous training') 18 | 19 | # I/O 20 | parser.add_argument('--input_size', default=224, type=int, 21 | help='Input size of image') 22 | parser.add_argument('--n_classes', default=1000, type=int, 23 | help='Number of classes (ImageNet: 1000)') 24 | parser.add_argument('--n_finetune_classes', default=10, type=int, 25 | help='Number of classes for fine-tuning; n_classes should match the dataset used for pre-training') 26 | 27 | # batch size and epoch 28 | parser.add_argument('--batch_size', default=64, type=int, 29 | help='Batch size') 30 | parser.add_argument('--test_batch_size', default=64, type=int, 31 | help='Test batch size') 32 | parser.add_argument('--epochs', default=50, type=int, 33 | help='Number of total epochs to run') 34 | parser.add_argument('--begin_epoch', default=1, type=int, 35 | help='Training begins at this epoch. 
The previously trained model indicated by resume_path is loaded.') 36 | 37 | # about model configuration 38 | parser.add_argument('--model', default='resnet', type=str, 39 | help='(vgg | resnet | resnext | densenet)') 40 | parser.add_argument('--model_depth', default=34, type=int, 41 | help='Depth of resnet (10 | 18 | 34 | 50 | 101 | 152)') 42 | 43 | # about optimizer 44 | parser.add_argument('--lr', default=0.001, type=float, 45 | help='Initial learning rate (divided by 10 during training by the LR schedule)') 46 | parser.add_argument('--lr_mult1', default=0.1, type=float, 47 | help='Learning-rate multiplier for the pre-trained layers') 48 | parser.add_argument('--lr_mult2', default=1, type=float, 49 | help='Learning-rate multiplier for the newly-created layers') 50 | parser.add_argument('--lr_patience', default=10, type=int, 51 | help='Patience of LR scheduler. See documentation of ReduceLROnPlateau.') 52 | parser.add_argument('--momentum', default=0.9, type=float, 53 | help='Momentum') 54 | parser.add_argument('--weight_decay', default=5e-4, type=float, 55 | help='Weight decay') 56 | 57 | # train, val, test, fine-tune 58 | parser.add_argument('--train', action='store_true', default=True, 59 | help='If true, training is performed.') 60 | parser.add_argument('--val', action='store_true', default=True, 61 | help='If true, validation is performed.') 62 | parser.add_argument('--test', action='store_true', default=True, 63 | help='If true, test is performed.') 64 | parser.add_argument('--finetune', action='store_true', default=True, 65 | help='If True, fine-tune a model that has been pre-trained on ImageNet') 66 | parser.add_argument('--ft_begin_index', default=0, type=int, 67 | help='Begin block index of fine-tuning') 68 | 69 | # training log and checkpoint 70 | parser.add_argument('--log_interval', default=10, type=int, 71 | help='How many batches to wait before logging training status') 72 | parser.add_argument('--checkpoint_interval', default=20, type=int, 73 | help='A checkpoint is saved every this many epochs.') 74 | 75 | # about device 76 | parser.add_argument('--use_cuda', action='store_true', default=True, 77 | help='If False, cuda is not used.') 78 | parser.add_argument('--num_workers', default=4, type=int, 79 | help='Number of threads for multi-thread loading') 80 | 81 | # random number seed 82 | parser.add_argument('--manual_seed', default=1, type=int, 83 | help='Manually set random seed') 84 | 85 | # visdom 86 | parser.add_argument('--env', default='default', type=str, 87 | help='Visdom environment') 88 | 89 | args = parser.parse_args() 90 | 91 | return args 92 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_1 import Driver # Divide the dataset into 2 parts only, i.e. train set and test set. 2 | # from .dataset_2 import Driver # Divide the dataset into 3 parts, i.e. train set, val set and test set. 3 | -------------------------------------------------------------------------------- /data/dataset_1.py: -------------------------------------------------------------------------------- 1 | """ 2 | Divide the dataset into 2 parts only, i.e. train set and test set. 
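Each csv is read from `root`, with the first row treated as a header; every other row is parsed as "<relative image path>,<integer label>" (see __init__ below).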
3 | """ 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import os 8 | from PIL import Image 9 | import numpy as np 10 | 11 | 12 | class Driver(Dataset): 13 | def __init__(self, root, transform=None, target_transform=None, train=True, test=False): 14 | self.root = root 15 | self.transform = transform 16 | self.target_transform = target_transform 17 | self.train = train 18 | self.test = test 19 | 20 | if self.test: 21 | with open(os.path.join(self.root, 'test.csv'), 'r') as f: 22 | lines = f.readlines()[1:] 23 | dataset = [] 24 | for line in lines: 25 | dataset.append(line.strip().split(',')) 26 | else: 27 | with open(os.path.join(self.root, 'train.csv'), 'r') as f: 28 | lines = f.readlines()[1:] 29 | dataset = [] 30 | for line in lines: 31 | dataset.append(line.strip().split(',')) 32 | 33 | dataset = np.array(dataset) 34 | self.imgs = list(map(lambda x: os.path.join(self.root, x), dataset[:, 0])) 35 | self.target = list(map(int, dataset[:, 1])) 36 | 37 | if transform is None: 38 | normalize = transforms.Normalize( 39 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 40 | 41 | if self.test: 42 | self.transform = transforms.Compose([ 43 | transforms.Resize(256), 44 | transforms.CenterCrop(224), 45 | transforms.ToTensor(), normalize 46 | ]) 47 | else: 48 | self.transform = transforms.Compose([ 49 | transforms.Resize(256), 50 | transforms.RandomResizedCrop(224, scale=(0.25, 1)), 51 | transforms.RandomHorizontalFlip(), 52 | transforms.ToTensor(), normalize 53 | ]) 54 | 55 | def __getitem__(self, index): 56 | img_path = self.imgs[index] 57 | target = self.target[index] 58 | img = Image.open(img_path) 59 | 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | 63 | if self.target_transform is not None: 64 | target = self.target_transform(target) 65 | 66 | return img, target 67 | 68 | def __len__(self): 69 | return len(self.imgs) 70 | 71 | 72 | if __name__ == '__main__': 73 | driver = Driver('/home/husencd/Downloads/dataset/driver', train=True) 74 | print(driver.__getitem__(1)) 75 | print(driver.__len__()) # 12977 76 | driver = Driver('/home/husencd/Downloads/dataset/driver', train=False, test=True) 77 | print(driver.__len__()) # 4331 78 | -------------------------------------------------------------------------------- /data/dataset_2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Divide the dataset into 3 parts, i.e. train set, val set and test set. 
3 | """ 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import os 8 | from PIL import Image 9 | import numpy as np 10 | 11 | 12 | class Driver(Dataset): 13 | def __init__(self, root, transform=None, target_transform=None, train=True, test=False): 14 | self.root = root 15 | self.transform = transform 16 | self.target_transform = target_transform 17 | self.train = train 18 | self.test = test 19 | 20 | if self.test: 21 | with open(os.path.join(self.root, 'test.csv'), 'r') as f: 22 | lines = f.readlines()[1:] 23 | dataset = [] 24 | for line in lines: 25 | dataset.append(line.strip().split(',')) 26 | else: 27 | with open(os.path.join(self.root, 'train.csv'), 'r') as f: 28 | lines = f.readlines()[1:] 29 | dataset = [] 30 | for line in lines: 31 | dataset.append(line.strip().split(',')) 32 | 33 | num_train = int(0.7 * len(dataset)) 34 | import random 35 | random.seed(1) 36 | for _ in range(10): 37 | dataset = random.sample(dataset, len(dataset)) 38 | if self.train: 39 | dataset = dataset[:num_train] 40 | else: 41 | dataset = dataset[num_train:] 42 | 43 | dataset = np.array(dataset) 44 | self.imgs = list(map(lambda x: os.path.join(self.root, x), dataset[:, 0])) 45 | self.target = list(map(int, dataset[:, 1])) 46 | 47 | if transform is None: 48 | normalize = transforms.Normalize( 49 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | if self.test or (not self.train): 52 | self.transform = transforms.Compose([ 53 | transforms.Resize(256), 54 | transforms.CenterCrop(224), 55 | transforms.ToTensor(), normalize 56 | ]) 57 | else: 58 | self.transform = transforms.Compose([ 59 | transforms.Resize(256), 60 | transforms.RandomResizedCrop(224, scale=(0.25, 1)), 61 | transforms.RandomHorizontalFlip(), 62 | transforms.ToTensor(), normalize 63 | ]) 64 | 65 | def __getitem__(self, index): 66 | img_path = self.imgs[index] 67 | target = self.target[index] 68 | img = Image.open(img_path) 69 | 70 | if self.transform is not None: 71 | img = self.transform(img) 72 | 73 | if self.target_transform is not None: 74 | target = self.target_transform(target) 75 | 76 | return img, target 77 | 78 | def __len__(self): 79 | return len(self.imgs) 80 | 81 | 82 | if __name__ == '__main__': 83 | driver = Driver('/home/husencd/Downloads/dataset/driver', train=True) 84 | print(driver.__getitem__(1)) 85 | print(driver.__len__()) # 9083 86 | driver = Driver('/home/husencd/Downloads/dataset/driver', train=False, test=False) 87 | print(driver.__len__()) # 3894 88 | driver = Driver('/home/husencd/Downloads/dataset/driver', train=False, test=True) 89 | print(driver.__len__()) # 4331 90 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | 6 | import os 7 | import json 8 | 9 | from args import parse_args 10 | from model import get_model_param 11 | from data import Driver 12 | from utils import Logger 13 | from tools import Visualizer 14 | from train import train_epoch, val_epoch 15 | import test 16 | 17 | best_prec1 = 0 18 | best_epoch = 1 19 | 20 | 21 | def main(): 22 | global args, best_prec1, best_epoch 23 | args = parse_args() 24 | 25 | if args.root_path != '': 26 | args.result_path = os.path.join(args.root_path, args.result_path) 27 | args.checkpoint_path = os.path.join(args.root_path, args.checkpoint_path) 28 | if not 
os.path.exists(args.result_path): 29 | os.mkdir(args.result_path) 30 | if not os.path.exists(args.checkpoint_path): 31 | os.mkdir(args.checkpoint_path) 32 | if args.resume_path: 33 | args.resume_path = os.path.join(args.checkpoint_path, args.resume_path) 34 | 35 | args.arch = '{}{}'.format(args.model, args.model_depth) 36 | 37 | torch.manual_seed(args.manual_seed) 38 | 39 | args.use_cuda = args.use_cuda and torch.cuda.is_available() 40 | 41 | device = torch.device("cuda" if args.use_cuda else "cpu") 42 | 43 | # create model 44 | model, parameters = get_model_param(args) 45 | print(model) 46 | model = model.to(device) 47 | 48 | with open(os.path.join(args.result_path, 'args.json'), 'w') as args_file: 49 | json.dump(vars(args), args_file) 50 | 51 | # define loss function (criterion) and optimizer 52 | criterion = nn.CrossEntropyLoss().to(device) 53 | optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 54 | # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, threshold=0.001, patience=args.lr_patience) 55 | 56 | lr_mult = [] 57 | for param_group in optimizer.param_groups: 58 | lr_mult.append(param_group['lr']) 59 | 60 | # optionally resume from a checkpoint 61 | if args.resume_path: 62 | if os.path.isfile(args.resume_path): 63 | print("=> loading checkpoint '{}'...".format(args.resume_path)) 64 | checkpoint = torch.load(args.resume_path) 65 | args.begin_epoch = checkpoint['epoch'] + 1 66 | model.load_state_dict(checkpoint['model']) 67 | optimizer.load_state_dict(checkpoint['optimizer']) 68 | else: 69 | print("=> no checkpoint found at '{}'".format(args.resume_path)) 70 | 71 | if args.train: 72 | train_dataset = Driver(root=args.data_path, train=True, test=False) 73 | train_loader = DataLoader( 74 | dataset=train_dataset, 75 | batch_size=args.batch_size, 76 | shuffle=True, 77 | num_workers=args.num_workers) 78 | train_logger = Logger( 79 | os.path.join(args.result_path, 'train.log'), 80 | ['epoch', 'loss', 'top1', 'top3', 'lr']) 81 | train_batch_logger = Logger( 82 | os.path.join(args.result_path, 'train_batch.log'), 83 | ['epoch', 'batch', 'iter', 'loss', 'top1', 'top3', 'lr']) 84 | 85 | if args.val: 86 | val_dataset = Driver(root=args.data_path, train=False, test=True) 87 | val_loader = DataLoader( 88 | dataset=val_dataset, 89 | batch_size=args.test_batch_size, 90 | shuffle=False, 91 | num_workers=args.num_workers) 92 | val_logger = Logger( 93 | os.path.join(args.result_path, 'val.log'), 94 | ['epoch', 'loss', 'top1', 'top3']) 95 | 96 | print('=> Start running...') 97 | vis = Visualizer(env=args.env) 98 | for epoch in range(args.begin_epoch, args.epochs + 1): 99 | if args.train: 100 | adjust_learning_rate(optimizer, epoch, lr_mult, args) 101 | train_epoch(epoch, train_loader, model, criterion, optimizer, args, device, train_logger, train_batch_logger, vis) 102 | print('\n') 103 | 104 | if args.val: 105 | val_loss, val_prec1 = val_epoch(epoch, val_loader, model, criterion, args, device, val_logger, vis) 106 | print('\n') 107 | # remember best prec@1 and save checkpoint 108 | if val_prec1 > best_prec1: 109 | best_prec1 = val_prec1 110 | best_epoch = epoch 111 | print('=> Saving current best model...\n') 112 | save_file_path = os.path.join(args.result_path, 'save_best_{}_{}.pth'.format(args.arch, epoch)) 113 | checkpoint = { 114 | 'arch': args.arch, 115 | 'epoch': best_epoch, 116 | 'best_prec1': best_prec1, 117 | 'model': model.state_dict(), 118 | 'optimizer': optimizer.state_dict() 119 | } 120 | 
torch.save(checkpoint, save_file_path) 121 | 122 | # if args.train and args.val: 123 | # scheduler.step(val_loss) 124 | 125 | if args.test: 126 | test_dataset = Driver(root=args.data_path, train=False, test=True) 127 | test_loader = DataLoader( 128 | dataset=test_dataset, 129 | batch_size=args.test_batch_size, 130 | shuffle=False, 131 | num_workers=args.num_workers) 132 | # # if you only test the model, you need to set the "best_epoch" manually 133 | # best_epoch = 10 # set manually 134 | saved_model_path = os.path.join(args.result_path, 'save_best_{}_{}.pth'.format(args.arch, best_epoch)) 135 | print("Using '{}' for test...".format(saved_model_path)) 136 | checkpoint = torch.load(saved_model_path) 137 | model.load_state_dict(checkpoint['model']) 138 | test.test(test_loader, model, args, device) 139 | 140 | 141 | def adjust_learning_rate(optimizer, epoch, lr_mult, args): 142 | """Sets the learning rate to the initial LR decayed by 10 every 20 epochs""" 143 | lr = args.lr * (0.1**((epoch - 1) // 20)) 144 | for i, param_group in enumerate(optimizer.param_groups): 145 | if args.finetune and args.ft_begin_index: 146 | param_group['lr'] = lr * lr_mult[i] 147 | else: 148 | param_group['lr'] = lr 149 | 150 | 151 | if __name__ == '__main__': 152 | main() 153 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import os 5 | from models import resnet 6 | 7 | model_path = { 8 | 'resnet18': 'resnet18-5c106cde.pth', 9 | 'resnet34': 'resnet34-333f7ec4.pth', 10 | 'resnet50': 'resnet50-19c8e357.pth', 11 | 'resnet101': 'resnet101-5d3b4d8f.pth', 12 | 'resnet152': 'resnet152-b121ed2d.pth', 13 | } 14 | 15 | 16 | def get_model_param(args): 17 | # assert args.model in ['resnet', 'vgg'] 18 | 19 | if args.model == 'resnet': 20 | assert args.model_depth in [18, 34, 50, 101, 152] 21 | 22 | from models.resnet import get_fine_tuning_parameters 23 | 24 | if args.model_depth == 18: 25 | model = resnet.resnet18(pretrained=False, input_size=args.input_size, num_classes=args.n_classes) 26 | elif args.model_depth == 34: 27 | model = resnet.resnet34(pretrained=False, input_size=args.input_size, num_classes=args.n_classes) 28 | elif args.model_depth == 50: 29 | model = resnet.resnet50(pretrained=False, input_size=args.input_size, num_classes=args.n_classes) 30 | elif args.model_depth == 101: 31 | model = resnet.resnet101(pretrained=False, input_size=args.input_size, num_classes=args.n_classes) 32 | elif args.model_depth == 152: 33 | model = resnet.resnet152(pretrained=False, input_size=args.input_size, num_classes=args.n_classes) 34 | 35 | # elif args.model == 'vgg': 36 | # pass 37 | 38 | # Load pretrained model here 39 | if args.finetune: 40 | pretrained_model = model_path[args.arch] 41 | args.pretrain_path = os.path.join(args.root_path, 'pretrained_models', pretrained_model) 42 | print("=> loading pretrained model '{}'...".format(pretrained_model)) 43 | 44 | model.load_state_dict(torch.load(args.pretrain_path)) 45 | 46 | # Only modify the last layer 47 | if args.model == 'resnet': 48 | model.fc = nn.Linear(model.fc.in_features, args.n_finetune_classes) 49 | # elif args.model == 'vgg': 50 | # pass 51 | 52 | parameters = get_fine_tuning_parameters(model, args.ft_begin_index, args.lr_mult1, args.lr_mult2) 53 | return model, parameters 54 | 55 | return model, model.parameters() 56 | 
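57 | 58 | if __name__ == '__main__': 59 |     # Minimal usage sketch: call get_model_param the way main.py does, with a 60 |     # hand-built Namespace mirroring the defaults in args.py; finetune=False 61 |     # here so that no pretrained weights need to exist on disk. 62 |     from argparse import Namespace 63 |     args = Namespace(model='resnet', model_depth=34, input_size=224, 64 |                      n_classes=1000, finetune=False) 65 |     model, parameters = get_model_param(args) 66 |     print(sum(p.numel() for p in model.parameters()))  # ~21.8M params for resnet34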
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # from .resnet import * 2 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Almost the same with the offical resnet. 3 | Except that: 4 | We allow different input size, which is 224 by default. 5 | And we can get fine-tuning parameters based on ft_begin_index. 6 | ''' 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.model_zoo as model_zoo 11 | import math 12 | 13 | 14 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 15 | 'resnet152'] 16 | 17 | 18 | model_urls = { 19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 24 | } 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, in_planes, planes, stride=1, downsample=None): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = conv3x3(in_planes, planes, stride) 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | 65 | class Bottleneck(nn.Module): 66 | expansion = 4 67 | 68 | def __init__(self, in_planes, planes, stride=1, downsample=None): 69 | super(Bottleneck, self).__init__() 70 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(planes) 72 | self.conv2 = conv3x3(planes, planes, stride=stride) 73 | self.bn2 = nn.BatchNorm2d(planes) 74 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 75 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | residual = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | class ResNet(nn.Module): 104 | 105 | def __init__(self, block, num_blocks, input_size=224, num_classes=1000): 106 | super(ResNet, self).__init__() 107 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 108 | bias=False) 109 | 
self.bn1 = nn.BatchNorm2d(64) 110 | self.relu = nn.ReLU(inplace=True) 111 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 112 | self.in_planes = 64 # initial value 113 | self.layer1 = self._make_layer(block, 64, num_blocks[0]) 114 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 115 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 116 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 117 | self.avgpool = nn.AvgPool2d(int(math.ceil(input_size / 32)), stride=1) 118 | self.fc = nn.Linear(512 * block.expansion, num_classes) 119 | 120 | # for m in self.modules(): 121 | # if isinstance(m, nn.Conv2d): 122 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 123 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 124 | # elif isinstance(m, nn.BatchNorm2d): 125 | # m.weight.data.fill_(1) 126 | # m.bias.data.zero_() 127 | 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 131 | if m.bias is not None: 132 | nn.init.constant_(m.bias, 0) 133 | elif isinstance(m, nn.BatchNorm2d): 134 | nn.init.constant_(m.weight, 1) 135 | nn.init.constant_(m.bias, 0) 136 | elif isinstance(m, nn.Linear): 137 | nn.init.normal_(m.weight, 0, 0.01) 138 | nn.init.constant_(m.bias, 0) 139 | 140 | def _make_layer(self, block, planes, num_blocks, stride=1): 141 | downsample = None 142 | if stride != 1 or self.in_planes != planes * block.expansion: 143 | downsample = nn.Sequential( 144 | nn.Conv2d(self.in_planes, planes * block.expansion, 145 | kernel_size=1, stride=stride, bias=False), 146 | nn.BatchNorm2d(planes * block.expansion), 147 | ) 148 | 149 | layers = [] 150 | layers.append(block(self.in_planes, planes, stride, downsample)) 151 | self.in_planes = planes * block.expansion # update 152 | for _ in range(1, num_blocks): 153 | layers.append(block(self.in_planes, planes)) 154 | 155 | return nn.Sequential(*layers) 156 | 157 | def forward(self, x): 158 | x = self.conv1(x) 159 | x = self.bn1(x) 160 | x = self.relu(x) 161 | x = self.maxpool(x) 162 | 163 | x = self.layer1(x) 164 | x = self.layer2(x) 165 | x = self.layer3(x) 166 | x = self.layer4(x) 167 | 168 | x = self.avgpool(x) 169 | x = x.view(x.size(0), -1) 170 | x = self.fc(x) 171 | 172 | return x 173 | 174 | 175 | def resnet18(pretrained=False, **kwargs): 176 | """Constructs a ResNet-18 model. 177 | 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 182 | if pretrained: 183 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 184 | return model 185 | 186 | 187 | def resnet34(pretrained=False, **kwargs): 188 | """Constructs a ResNet-34 model. 189 | 190 | Args: 191 | pretrained (bool): If True, returns a model pre-trained on ImageNet 192 | """ 193 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 194 | if pretrained: 195 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 196 | return model 197 | 198 | 199 | def resnet50(pretrained=False, **kwargs): 200 | """Constructs a ResNet-50 model. 201 | 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | """ 205 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 206 | if pretrained: 207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 208 | return model 209 | 210 | 211 | def resnet101(pretrained=False, **kwargs): 212 | """Constructs a ResNet-101 model. 
213 | 214 | Args: 215 | pretrained (bool): If True, returns a model pre-trained on ImageNet 216 | """ 217 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 218 | if pretrained: 219 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 220 | return model 221 | 222 | 223 | def resnet152(pretrained=False, **kwargs): 224 | """Constructs a ResNet-152 model. 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | """ 229 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 230 | if pretrained: 231 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 232 | return model 233 | 234 | 235 | def get_fine_tuning_parameters(model, ft_begin_index=0, lr_mult1=0.1, lr_mult2=1): 236 | if ft_begin_index == 0: 237 | return model.parameters() 238 | 239 | ft_module_names = [] 240 | for i in range(ft_begin_index, 5): 241 | ft_module_names.append('layer{}'.format(i)) 242 | ft_module_names.append('fc') 243 | 244 | parameters = [] 245 | for name, params in model.named_parameters(): 246 | flag = False 247 | for ft_module in ft_module_names: 248 | if ft_module in name: 249 | flag = True 250 | parameters.append({'params': params, 'lr': lr_mult2})  # 'lr' holds the multiplier; adjust_learning_rate() in main.py rescales it to base_lr * multiplier every epoch 251 | break 252 | if not flag: 253 | parameters.append({'params': params, 'lr': lr_mult1})  # earlier layers get the smaller multiplier 254 | 255 | return parameters 256 | 257 | 258 | if __name__ == '__main__': 259 | model = resnet18(input_size=224, num_classes=10) 260 | x = torch.rand(1, 3, 224, 224) 261 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 262 | x = x.to(device) 263 | model = model.to(device) 264 | y = model(x) 265 | print(torch.nn.functional.softmax(y, dim=1)) 266 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Almost the same as the official vgg, 3 | except that: 4 | we allow a different input size (224 by default). 
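(The first classifier Linear layer is sized as 512 * last_size * last_size, with last_size = ceil(input_size / 32), so it matches the final feature map.)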
5 | ''' 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.utils.model_zoo as model_zoo 10 | import math 11 | 12 | 13 | __all__ = [ 14 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 15 | 'vgg19_bn', 'vgg19', 16 | ] 17 | 18 | 19 | model_urls = { 20 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 21 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 22 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 23 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 24 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 25 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 26 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 27 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 28 | } 29 | 30 | 31 | class VGG(nn.Module): 32 | 33 | def __init__(self, features, input_size=224, num_classes=1000, init_weights=True): 34 | super(VGG, self).__init__() 35 | self.features = features 36 | last_size = int(math.ceil(input_size / 32)) 37 | self.classifier = nn.Sequential( 38 | nn.Linear(512 * last_size * last_size, 4096), 39 | nn.ReLU(True), 40 | nn.Dropout(), 41 | nn.Linear(4096, 4096), 42 | nn.ReLU(True), 43 | nn.Dropout(), 44 | nn.Linear(4096, num_classes), 45 | ) 46 | if init_weights: 47 | self._initialize_weights() 48 | 49 | def forward(self, x): 50 | x = self.features(x) 51 | x = x.view(x.size(0), -1) 52 | x = self.classifier(x) 53 | return x 54 | 55 | def _initialize_weights(self): 56 | for m in self.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 59 | if m.bias is not None: 60 | nn.init.constant_(m.bias, 0) 61 | elif isinstance(m, nn.BatchNorm2d): 62 | nn.init.constant_(m.weight, 1) 63 | nn.init.constant_(m.bias, 0) 64 | elif isinstance(m, nn.Linear): 65 | nn.init.normal_(m.weight, 0, 0.01) 66 | nn.init.constant_(m.bias, 0) 67 | 68 | 69 | def make_layers(cfg, batch_norm=False): 70 | layers = [] 71 | in_channels = 3 72 | for v in cfg: 73 | if v == 'M': 74 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 75 | else: 76 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 77 | if batch_norm: 78 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 79 | else: 80 | layers += [conv2d, nn.ReLU(inplace=True)] 81 | in_channels = v 82 | return nn.Sequential(*layers) 83 | 84 | 85 | cfg = { 86 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 87 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 88 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 89 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 90 | } 91 | 92 | 93 | def vgg11(pretrained=False, **kwargs): 94 | """VGG 11-layer model (configuration "A") 95 | 96 | Args: 97 | pretrained (bool): If True, returns a model pre-trained on ImageNet 98 | """ 99 | if pretrained: 100 | kwargs['init_weights'] = False 101 | model = VGG(make_layers(cfg['A']), **kwargs) 102 | if pretrained: 103 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 104 | return model 105 | 106 | 107 | def vgg11_bn(pretrained=False, **kwargs): 108 | """VGG 11-layer model (configuration "A") with batch normalization 109 | 110 | Args: 111 | pretrained (bool): If True, returns a model pre-trained on ImageNet 112 | """ 113 | if 
pretrained: 114 | kwargs['init_weights'] = False 115 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 116 | if pretrained: 117 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 118 | return model 119 | 120 | 121 | def vgg13(pretrained=False, **kwargs): 122 | """VGG 13-layer model (configuration "B") 123 | 124 | Args: 125 | pretrained (bool): If True, returns a model pre-trained on ImageNet 126 | """ 127 | if pretrained: 128 | kwargs['init_weights'] = False 129 | model = VGG(make_layers(cfg['B']), **kwargs) 130 | if pretrained: 131 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 132 | return model 133 | 134 | 135 | def vgg13_bn(pretrained=False, **kwargs): 136 | """VGG 13-layer model (configuration "B") with batch normalization 137 | 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on ImageNet 140 | """ 141 | if pretrained: 142 | kwargs['init_weights'] = False 143 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 144 | if pretrained: 145 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 146 | return model 147 | 148 | 149 | def vgg16(pretrained=False, **kwargs): 150 | """VGG 16-layer model (configuration "D") 151 | 152 | Args: 153 | pretrained (bool): If True, returns a model pre-trained on ImageNet 154 | """ 155 | if pretrained: 156 | kwargs['init_weights'] = False 157 | model = VGG(make_layers(cfg['D']), **kwargs) 158 | if pretrained: 159 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 160 | return model 161 | 162 | 163 | def vgg16_bn(pretrained=False, **kwargs): 164 | """VGG 16-layer model (configuration "D") with batch normalization 165 | 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | """ 169 | if pretrained: 170 | kwargs['init_weights'] = False 171 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 172 | if pretrained: 173 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 174 | return model 175 | 176 | 177 | def vgg19(pretrained=False, **kwargs): 178 | """VGG 19-layer model (configuration "E") 179 | 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | if pretrained: 184 | kwargs['init_weights'] = False 185 | model = VGG(make_layers(cfg['E']), **kwargs) 186 | if pretrained: 187 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 188 | return model 189 | 190 | 191 | def vgg19_bn(pretrained=False, **kwargs): 192 | """VGG 19-layer model (configuration 'E') with batch normalization 193 | 194 | Args: 195 | pretrained (bool): If True, returns a model pre-trained on ImageNet 196 | """ 197 | if pretrained: 198 | kwargs['init_weights'] = False 199 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 202 | return model 203 | 204 | 205 | if __name__ == '__main__': 206 | model = vgg16_bn(input_size=224, num_classes=10) 207 | x = torch.rand(1, 3, 224, 224) 208 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 209 | x = x.to(device) 210 | model = model.to(device) 211 | y = model(x) 212 | print(torch.nn.functional.softmax(y, dim=1)) 213 | -------------------------------------------------------------------------------- /pretrained_models/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | echo "Downloading ResNet models pretrained on ImageNet..." 
4 | 5 | wget -N https://download.pytorch.org/models/resnet18-5c106cde.pth 6 | wget -N https://download.pytorch.org/models/resnet34-333f7ec4.pth 7 | wget -N https://download.pytorch.org/models/resnet50-19c8e357.pth 8 | wget -N https://download.pytorch.org/models/resnet101-5d3b4d8f.pth 9 | wget -N https://download.pytorch.org/models/resnet152-b121ed2d.pth -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from utils import AverageMeter, calculate_accuracy 4 | 5 | 6 | def test(data_loader, model, args, device): 7 | batch_time = AverageMeter() 8 | data_time = AverageMeter() 9 | top1 = AverageMeter() 10 | top3 = AverageMeter() 11 | 12 | # switch to evaluate mode 13 | model.eval() 14 | 15 | end_time = time.time() 16 | for i, (input, target) in enumerate(data_loader): 17 | # measure data loading time 18 | data_time.update(time.time() - end_time) 19 | 20 | input = input.to(device) 21 | target = target.to(device) 22 | 23 | # compute output and loss 24 | output = model(input) 25 | 26 | # measure accuracy and record loss 27 | prec1, prec3 = calculate_accuracy(output, target, topk=(1, 3)) 28 | # prec1[0]: convert torch.Size([1]) to torch.Size([]) 29 | top1.update(prec1[0].item(), input.size(0)) 30 | top3.update(prec3[0].item(), input.size(0)) 31 | 32 | # measure elapsed time 33 | batch_time.update(time.time() - end_time) 34 | end_time = time.time() 35 | 36 | if (i + 1) % args.log_interval == 0: 37 | print('Test Iter [{0}/{1}]\t' 38 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 39 | 'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t' 40 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 41 | 'Data Time {data_time.val:.3f} ({data_time.avg:.3f})'.format( 42 | i + 1, 43 | len(data_loader), 44 | top1=top1, 45 | top3=top3, 46 | batch_time=batch_time, 47 | data_time=data_time)) 48 | 49 | print(' * Prec@1 {top1.avg:.2f}% | Prec@3 {top3.avg:.2f}%'.format( 50 | top1=top1, top3=top3)) 51 | 52 | return top1.avg 53 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualize import Visualizer 2 | -------------------------------------------------------------------------------- /tools/visualize.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import visdom 4 | import time 5 | import numpy as np 6 | 7 | 8 | class Visualizer(object): 9 | 10 | def __init__(self, env='default', **kwargs): 11 | self.vis = visdom.Visdom(env=env, **kwargs) 12 | self.index = {} 13 | self.log_text = '' 14 | 15 | def reinit(self, env='default', **kwargs): 16 | self.vis = visdom.Visdom(env=env, **kwargs) 17 | return self 18 | 19 | def plot_many(self, d): 20 | ''' 21 | @params d: dict (name, value) i.e. 
('loss', 0.11) 22 | ''' 23 | for k, v in d.items(): 24 | self.plot(k, v) 25 | 26 | def img_many(self, d): 27 | for k, v in d.items(): 28 | self.img(k, v) 29 | 30 | def plot(self, name, y, **kwargs): 31 | ''' 32 | self.plot('loss', 1.00) 33 | ''' 34 | x = self.index.get(name, 0) 35 | self.vis.line(Y=np.array([y]), 36 | X=np.array([x]), 37 | win=name, 38 | opts=dict(title=name), 39 | update=None if x == 0 else 'append', 40 | **kwargs) 41 | self.index[name] = x + 1 42 | 43 | def img(self, name, img_, **kwargs): 44 | ''' 45 | self.img('input_img', torch.Tensor(64, 64)) 46 | self.img('input_imgs', torch.Tensor(3, 64, 64)) 47 | self.img('input_imgs', torch.Tensor(100, 1, 64, 64)) 48 | self.img('input_imgs', torch.Tensor(100, 3, 64, 64), nrows=10) 49 | ''' 50 | self.vis.images(img_.cpu().numpy(), 51 | win=name, 52 | opts=dict(title=name), 53 | **kwargs) 54 | 55 | def log(self, info, win='log_text'): 56 | ''' 57 | self.log({'loss': 1, 'lr': 0.0001}) 58 | ''' 59 | self.log_text += ('[{time}] {info} <br>
'.format( 60 | time=time.strftime('%m%d_%H%M%S'), 61 | info=info)) 62 | self.vis.text(self.log_text, win) 63 | 64 | def __getattr__(self, name): 65 | return getattr(self.vis, name) 66 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import os 4 | 5 | from utils import AverageMeter, calculate_accuracy 6 | 7 | 8 | def train_epoch(epoch, data_loader, model, criterion, optimizer, args, device, 9 | epoch_logger, batch_logger, vis): 10 | batch_time = AverageMeter() 11 | data_time = AverageMeter() 12 | losses = AverageMeter() 13 | top1 = AverageMeter() 14 | top3 = AverageMeter() 15 | 16 | # switch to train mode 17 | model.train() 18 | 19 | end_time = time.time() 20 | for i, (input, target) in enumerate(data_loader): 21 | # measure data loading time 22 | data_time.update(time.time() - end_time) 23 | 24 | input = input.to(device) 25 | target = target.to(device) 26 | 27 | # compute output and loss 28 | output = model(input) 29 | loss = criterion(output, target) 30 | 31 | # measure accuracy and record loss 32 | prec1, prec3 = calculate_accuracy(output, target, topk=(1, 3)) 33 | losses.update(loss.item(), input.size(0)) 34 | # prec1[0]: convert torch.Size([1]) to torch.Size([]) 35 | top1.update(prec1[0].item(), input.size(0)) 36 | top3.update(prec3[0].item(), input.size(0)) 37 | """ 38 | a = np.array([1, 2, 3]) 39 | b = torch.from_numpy(a) # tensor([ 1, 2, 3]) 40 | c = b.sum() # tensor(6) 41 | d = b.sum(0) # tensor(6) 42 | e = b.sum(0, keepdim=True) # tensor([ 6]), torch.Size([1]) 43 | e[0] # tensor(6), torch.Size([]) 44 | e.item() # 6 45 | """ 46 | 47 | # compute gradient and do SGD step 48 | optimizer.zero_grad() 49 | loss.backward() 50 | optimizer.step() 51 | 52 | # measure elapsed time 53 | batch_time.update(time.time() - end_time) 54 | end_time = time.time() 55 | 56 | if (i + 1) % args.log_interval == 0: 57 | print('Train Epoch [{0}/{1}]([{2}/{3}])\t' 58 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 59 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 60 | 'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t' 61 | 'LR {lr:f}\t' 62 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 63 | 'Data Time {data_time.val:.3f} ({data_time.avg:.3f})'.format( 64 | epoch, 65 | args.epochs, 66 | i + 1, 67 | len(data_loader), 68 | loss=losses, 69 | top1=top1, 70 | top3=top3, 71 | lr=optimizer.param_groups[0]['lr'], 72 | batch_time=batch_time, 73 | data_time=data_time)) 74 | 75 | batch_logger.log({ 76 | 'epoch': epoch, 77 | 'batch': i + 1, 78 | 'iter': (epoch - 1) * len(data_loader) + (i + 1), 79 | 'loss': losses.val, 80 | 'top1': top1.val, 81 | 'top3': top3.val, 82 | 'lr': optimizer.param_groups[0]['lr'] 83 | }) 84 | 85 | epoch_logger.log({ 86 | 'epoch': epoch, 87 | 'loss': losses.avg, 88 | 'top1': top1.avg, 89 | 'top3': top3.avg, 90 | 'lr': optimizer.param_groups[0]['lr'] 91 | }) 92 | 93 | if epoch % args.checkpoint_interval == 0: 94 | save_file_path = os.path.join(args.checkpoint_path, 'save_{}_{}.pth'.format(args.arch, epoch)) 95 | checkpoint = { 96 | 'epoch': epoch, 97 | 'arch': args.arch, 98 | 'model': model.state_dict(), 99 | 'optimizer': optimizer.state_dict(), 100 | } 101 | torch.save(checkpoint, save_file_path) 102 | 103 | vis.plot('Train loss', losses.avg) 104 | vis.plot('Train accu', top1.avg) 105 | vis.log("epoch:{epoch}, lr:{lr}, loss:{loss}, accu:{accu}".format( 106 | epoch=epoch, 107 | lr=optimizer.param_groups[0]['lr'], 108 | loss=losses.avg, 
109 | accu=top1.avg)) 110 | 111 | 112 | def val_epoch(epoch, data_loader, model, criterion, args, device, epoch_logger, vis): 113 | batch_time = AverageMeter() 114 | data_time = AverageMeter() 115 | losses = AverageMeter() 116 | top1 = AverageMeter() 117 | top3 = AverageMeter() 118 | 119 | # switch to evaluate mode 120 | model.eval() 121 | 122 | end_time = time.time() 123 | for i, (input, target) in enumerate(data_loader): 124 | # measure data loading time 125 | data_time.update(time.time() - end_time) 126 | 127 | input = input.to(device) 128 | target = target.to(device) 129 | 130 | # compute output and loss 131 | output = model(input) 132 | loss = criterion(output, target) 133 | 134 | # measure accuracy and record loss 135 | prec1, prec3 = calculate_accuracy(output, target, topk=(1, 3)) 136 | losses.update(loss.item(), input.size(0)) 137 | top1.update(prec1[0].item(), input.size(0)) 138 | top3.update(prec3[0].item(), input.size(0)) 139 | 140 | # measure elapsed time 141 | batch_time.update(time.time() - end_time) 142 | end_time = time.time() 143 | 144 | if (i + 1) % args.log_interval == 0: 145 | print('Valid Epoch [{0}/{1}]([{2}/{3}])\t' 146 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 147 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 148 | 'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t' 149 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 150 | 'Data Time {data_time.val:.3f} ({data_time.avg:.3f})'.format( 151 | epoch, 152 | args.epochs, 153 | i + 1, 154 | len(data_loader), 155 | loss=losses, 156 | top1=top1, 157 | top3=top3, 158 | batch_time=batch_time, 159 | data_time=data_time)) 160 | 161 | print(' * Prec@1 {top1.avg:.2f}% | Prec@3 {top3.avg:.2f}%'.format( 162 | top1=top1, top3=top3)) 163 | 164 | epoch_logger.log({ 165 | 'epoch': epoch, 166 | 'loss': losses.avg, 167 | 'top1': top1.avg, 168 | 'top3': top3.avg 169 | }) 170 | 171 | vis.plot('Val loss', losses.avg) 172 | vis.plot('Val accu', top1.avg) 173 | 174 | return losses.avg, top1.avg 175 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import torch 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | class Logger(object): 24 | """Outputs log files""" 25 | def __init__(self, path, header): 26 | self.log_file = open(path, 'w') 27 | self.logger = csv.writer(self.log_file, delimiter='\t') 28 | self.logger.writerow(header) 29 | self.header = header 30 | 31 | def __del__(self): 32 | self.log_file.close() 33 | 34 | def log(self, values): 35 | write_values = [] 36 | for col in self.header: 37 | assert col in values 38 | write_values.append(values[col]) 39 | 40 | self.logger.writerow(write_values) 41 | self.log_file.flush() 42 | 43 | 44 | def calculate_accuracy(output, target, topk=(1,)): 45 | """Computes the precision@k for the specified values of k""" 46 | with torch.no_grad(): 47 | maxk = max(topk) 48 | batch_size = target.size(0) 49 | 50 | _, pred = output.topk(maxk, dim=1, largest=True, sorted=True) # batch_size x maxk 51 | pred = pred.t() # transpose, maxk x batch_size 52 | # target.view(1, -1): convert (batch_size,) to 1 x batch_size 53 | # expand_as: 
convert 1 x batch_size to maxk x batch_size 54 | correct = pred.eq(target.view(1, -1).expand_as(pred)) # maxk x batch_size 55 | 56 | res = [] 57 | for k in topk: 58 | # correct[:k] converts "maxk x batch_size" to "k x batch_size" 59 | # view(-1) converts "k x batch_size" to "(k x batch_size,)" 60 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 61 | res.append(correct_k.mul_(100.0 / batch_size)) 62 | return res 63 | --------------------------------------------------------------------------------