├── .gitignore
├── LICENSE
├── README.md
├── args.py
├── data
├── __init__.py
├── dataset_1.py
└── dataset_2.py
├── main.py
├── model.py
├── models
├── __init__.py
├── resnet.py
└── vgg.py
├── pretrained_models
└── download.sh
├── test.py
├── tools
├── __init__.py
└── visualize.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Shi Husen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Driver Posture Classification
2 |
3 | This is a PyTorch implementation of the **Driver Posture Classification** task. We use the [AUC Distracted Driver Dataset](https://devyhia.github.io/projects/auc-distracted-driver-dataset), which was collected to advance the state of the art in distracted-driver detection. Here are some samples from the dataset:
4 |
5 |
6 |
7 |
8 | The task is to classify an image into one of several pre-defined categories, such as "Drive Safe", "Talk Passenger", "Text Right", and "Drink". We use a pretrained resnet34 model to achieve performance comparable to the original paper, [Real-time Distracted Driver Posture Classification](https://arxiv.org/abs/1706.09498). The classification accuracy is about 95%.
9 |
10 |
11 | ## Usage
12 |
13 | ### Requirements
14 |
15 | * python 3.5+
16 | * pytorch 0.4
17 | * visdom (optional)
18 |
19 |
20 | ### Steps
21 |
22 | 0. Download the dataset and its training and testing splits (train.csv and test.csv), and put them together in one directory.
23 | 1. Clone the repository
24 |
25 | `git clone https://github.com/husencd/DriverPostureClassification.git`
26 |
27 | `cd DriverPostureClassification`
28 |
29 | 2. Download the resnet models pretrained on ImageNet from the [pytorch official model urls](https://download.pytorch.org/models/).
30 |
31 | `cd pretrained_models`
32 |
33 | `sh download.sh`
34 |
35 | 3. Now you can train/fine-tune the model
36 |
37 | `cd ..`
38 |
39 | `python main.py [--model resnet] [--model_depth 34]`
40 |
41 |    If you want to monitor the training process, start a visdom server:
42 |
43 | `python -m visdom.server`
44 |
45 |
46 | ## Reference
47 |
48 | * Our code is partially based on https://github.com/chenyuntc/pytorch-best-practice.
49 |
--------------------------------------------------------------------------------
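For reference, `main.py` exposes more knobs than the two shown above; all of them are defined in `args.py` (next section). A fuller invocation might look like this (values are illustrative, not tuned recommendations):

`python main.py --model resnet --model_depth 50 --lr 0.001 --batch_size 32 --epochs 50 --env driver`
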
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def parse_args():
5 | parser = argparse.ArgumentParser(description='PyTorch Driver Posture Classification')
6 |
7 | # path
8 | parser.add_argument('--data_path', default='/home/husencd/Downloads/dataset/driver', type=str,
9 | help='Driver data directory path')
10 | parser.add_argument('--root_path', default='/home/husencd/husen/pytorch/learn/DriverPostureClassification', type=str,
11 | help='Project root directory path')
12 | parser.add_argument('--result_path', default='results', type=str,
13 | help='Result directory path')
14 | parser.add_argument('--checkpoint_path', default='checkpoints', type=str,
15 | help='Checkpoint directory path (snapshot)')
16 | parser.add_argument('--resume_path', default='', type=str,
17 | help='Saved model (checkpoint) path of previous training')
18 |
19 | # I/O
20 | parser.add_argument('--input_size', default=224, type=int,
21 | help='Input size of image')
22 | parser.add_argument('--n_classes', default=1000, type=int,
23 |                         help='Number of classes (ImageNet: 1000)')
24 | parser.add_argument('--n_finetune_classes', default=10, type=int,
25 |                         help='Number of classes for fine-tuning; n_classes should match the pre-training dataset')
26 |
27 | # batch size and epoch
28 | parser.add_argument('--batch_size', default=64, type=int,
29 | help='Batch Size')
30 | parser.add_argument('--test_batch_size', default=64, type=int,
31 | help='Test batch Size')
32 | parser.add_argument('--epochs', default=50, type=int,
33 | help='Number of total epochs to run')
34 | parser.add_argument('--begin_epoch', default=1, type=int,
35 |                         help='Training begins at this epoch. Previously trained model indicated by resume_path is loaded.')
36 |
37 | # about model configuration
38 | parser.add_argument('--model', default='resnet', type=str,
39 | help='(vgg | resnet | resnext | densenet)')
40 | parser.add_argument('--model_depth', default=34, type=int,
41 | help='Depth of resnet (10 | 18 | 34 | 50 | 101 | 152)')
42 |
43 | # about optimizer
44 | parser.add_argument('--lr', default=0.001, type=float,
45 |                         help='Initial learning rate (divided by 10 every 20 epochs by adjust_learning_rate)')
46 | parser.add_argument('--lr_mult1', default=0.1, type=float,
47 | help='Multiplication factor of learning rate in those pre-trained layers')
48 | parser.add_argument('--lr_mult2', default=1, type=float,
49 | help='Multiplication factor of learning rate in those newly-created layers')
50 | parser.add_argument('--lr_patience', default=10, type=int,
51 | help='Patience of LR scheduler. See documentation of ReduceLROnPlateau.')
52 | parser.add_argument('--momentum', default=0.9, type=float,
53 | help='Momentum')
54 | parser.add_argument('--weight_decay', default=5e-4, type=float,
55 | help='Weight decay')
56 |
57 |     # train, val, test, fine-tune (note: these store_true flags default to True, so they are effectively always on; edit the defaults to disable a phase)
58 | parser.add_argument('--train', action='store_true', default=True,
59 | help='If true, training is performed.')
60 | parser.add_argument('--val', action='store_true', default=True,
61 | help='If true, validation is performed.')
62 | parser.add_argument('--test', action='store_true', default=True,
63 | help='If true, test is performed.')
64 | parser.add_argument('--finetune', action='store_true', default=True,
65 | help='If True, fine-tune on a model that has been pre-trained on ImageNet')
66 | parser.add_argument('--ft_begin_index', default=0, type=int,
67 | help='Begin block index of fine-tuning')
68 |
69 | # training log and checkpoint
70 | parser.add_argument('--log_interval', default=10, type=int,
71 | help='How many batches to wait before logging training status')
72 | parser.add_argument('--checkpoint_interval', default=20, type=int,
73 |                         help='Trained model is saved every this many epochs.')
74 |
75 | # about device
76 | parser.add_argument('--use_cuda', action='store_true', default=True,
77 |                         help='Use cuda (default True; main.py falls back to cpu if cuda is unavailable)')
78 | parser.add_argument('--num_workers', default=4, type=int,
79 |                         help='Number of worker processes for data loading')
80 |
81 | # random number seed
82 | parser.add_argument('--manual_seed', default=1, type=int,
83 | help='Manually set random seed')
84 |
85 | # visdom
86 | parser.add_argument('--env', default='default', type=str,
87 |                         help='Visdom environment')
88 |
89 | args = parser.parse_args()
90 |
91 | return args
92 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset_1 import Driver # Divide the dataset into 2 parts only, i.e. train set and test set.
2 | # from .dataset_2 import Driver # Divide the dataset into 3 parts, i.e. train set, val set and test set.
3 |
--------------------------------------------------------------------------------
/data/dataset_1.py:
--------------------------------------------------------------------------------
1 | """
2 | Divide the dataset into 2 parts only, i.e. train set and test set.
3 | """
4 |
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 | import os
8 | from PIL import Image
9 | import numpy as np
10 |
11 |
12 | class Driver(Dataset):
13 | def __init__(self, root, transform=None, target_transform=None, train=True, test=False):
14 | self.root = root
15 | self.transform = transform
16 | self.target_transform = target_transform
17 | self.train = train
18 | self.test = test
19 |
20 | if self.test:
21 | with open(os.path.join(self.root, 'test.csv'), 'r') as f:
22 | lines = f.readlines()[1:]
23 | dataset = []
24 | for line in lines:
25 | dataset.append(line.strip().split(','))
26 | else:
27 | with open(os.path.join(self.root, 'train.csv'), 'r') as f:
28 | lines = f.readlines()[1:]
29 | dataset = []
30 | for line in lines:
31 | dataset.append(line.strip().split(','))
32 |
33 | dataset = np.array(dataset)
34 | self.imgs = list(map(lambda x: os.path.join(self.root, x), dataset[:, 0]))
35 | self.target = list(map(int, dataset[:, 1]))
36 |
37 | if transform is None:
38 | normalize = transforms.Normalize(
39 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
40 |
41 | if self.test:
42 | self.transform = transforms.Compose([
43 | transforms.Resize(256),
44 | transforms.CenterCrop(224),
45 | transforms.ToTensor(), normalize
46 | ])
47 | else:
48 | self.transform = transforms.Compose([
49 | transforms.Resize(256),
50 | transforms.RandomResizedCrop(224, scale=(0.25, 1)),
51 | transforms.RandomHorizontalFlip(),
52 | transforms.ToTensor(), normalize
53 | ])
54 |
55 | def __getitem__(self, index):
56 | img_path = self.imgs[index]
57 | target = self.target[index]
58 |         img = Image.open(img_path).convert('RGB')  # ensure 3 channels for the normalize transform
59 |
60 | if self.transform is not None:
61 | img = self.transform(img)
62 |
63 | if self.target_transform is not None:
64 | target = self.target_transform(target)
65 |
66 | return img, target
67 |
68 | def __len__(self):
69 | return len(self.imgs)
70 |
71 |
72 | if __name__ == '__main__':
73 | driver = Driver('/home/husencd/Downloads/dataset/driver', train=True)
74 | print(driver.__getitem__(1))
75 | print(driver.__len__()) # 12977
76 | driver = Driver('/home/husencd/Downloads/dataset/driver', train=False, test=True)
77 | print(driver.__len__()) # 4331
78 |
--------------------------------------------------------------------------------
/data/dataset_2.py:
--------------------------------------------------------------------------------
1 | """
2 | Divide the dataset into 3 parts, i.e. train set, val set and test set.
3 | """
4 |
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 | import os
8 | from PIL import Image
9 | import numpy as np
10 |
11 |
12 | class Driver(Dataset):
13 | def __init__(self, root, transform=None, target_transform=None, train=True, test=False):
14 | self.root = root
15 | self.transform = transform
16 | self.target_transform = target_transform
17 | self.train = train
18 | self.test = test
19 |
20 | if self.test:
21 | with open(os.path.join(self.root, 'test.csv'), 'r') as f:
22 | lines = f.readlines()[1:]
23 | dataset = []
24 | for line in lines:
25 | dataset.append(line.strip().split(','))
26 | else:
27 | with open(os.path.join(self.root, 'train.csv'), 'r') as f:
28 | lines = f.readlines()[1:]
29 | dataset = []
30 | for line in lines:
31 | dataset.append(line.strip().split(','))
32 |
33 | num_train = int(0.7 * len(dataset))
34 | import random
35 | random.seed(1)
36 | for _ in range(10):
37 | dataset = random.sample(dataset, len(dataset))
38 | if self.train:
39 | dataset = dataset[:num_train]
40 | else:
41 | dataset = dataset[num_train:]
42 |
43 | dataset = np.array(dataset)
44 | self.imgs = list(map(lambda x: os.path.join(self.root, x), dataset[:, 0]))
45 | self.target = list(map(int, dataset[:, 1]))
46 |
47 | if transform is None:
48 | normalize = transforms.Normalize(
49 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50 |
51 | if self.test or (not self.train):
52 | self.transform = transforms.Compose([
53 | transforms.Resize(256),
54 | transforms.CenterCrop(224),
55 | transforms.ToTensor(), normalize
56 | ])
57 | else:
58 | self.transform = transforms.Compose([
59 | transforms.Resize(256),
60 | transforms.RandomResizedCrop(224, scale=(0.25, 1)),
61 | transforms.RandomHorizontalFlip(),
62 | transforms.ToTensor(), normalize
63 | ])
64 |
65 | def __getitem__(self, index):
66 | img_path = self.imgs[index]
67 | target = self.target[index]
68 |         img = Image.open(img_path).convert('RGB')  # ensure 3 channels for the normalize transform
69 |
70 | if self.transform is not None:
71 | img = self.transform(img)
72 |
73 | if self.target_transform is not None:
74 | target = self.target_transform(target)
75 |
76 | return img, target
77 |
78 | def __len__(self):
79 | return len(self.imgs)
80 |
81 |
82 | if __name__ == '__main__':
83 | driver = Driver('/home/husencd/Downloads/dataset/driver', train=True)
84 | print(driver.__getitem__(1))
85 | print(driver.__len__()) # 9083
86 | driver = Driver('/home/husencd/Downloads/dataset/driver', train=False, test=False)
87 | print(driver.__len__()) # 3894
88 | driver = Driver('/home/husencd/Downloads/dataset/driver', train=False, test=True)
89 | print(driver.__len__()) # 4331
90 |
--------------------------------------------------------------------------------
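A property of `dataset_2.py` worth noting: every `Driver` instantiation reseeds with `random.seed(1)` and applies the same repeated `random.sample` shuffle before slicing at the 70% mark, so separate train and val objects always partition the data identically and without overlap. A minimal sketch of that invariant (using a single `random.shuffle` in place of the repeated resampling):

```python
import random

def split(n, train_frac=0.7, seed=1):
    """Deterministic index split mirroring dataset_2's seeded shuffle."""
    idx = list(range(n))
    random.seed(seed)
    random.shuffle(idx)   # stand-in for dataset_2's repeated random.sample
    cut = int(train_frac * n)
    return idx[:cut], idx[cut:]

train_a, val_a = split(12977)
train_b, val_b = split(12977)
assert train_a == train_b and val_a == val_b   # reproducible across runs
assert not set(train_a) & set(val_a)           # train and val are disjoint
```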
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.utils.data import DataLoader
5 |
6 | import os
7 | import json
8 |
9 | from args import parse_args
10 | from model import get_model_param
11 | from data import Driver
12 | from utils import Logger
13 | from tools import Visualizer
14 | from train import train_epoch, val_epoch
15 | import test
16 |
17 | best_prec1 = 0
18 | best_epoch = 1
19 |
20 |
21 | def main():
22 | global args, best_prec1, best_epoch
23 | args = parse_args()
24 |
25 | if args.root_path != '':
26 | args.result_path = os.path.join(args.root_path, args.result_path)
27 | args.checkpoint_path = os.path.join(args.root_path, args.checkpoint_path)
28 | if not os.path.exists(args.result_path):
29 | os.mkdir(args.result_path)
30 | if not os.path.exists(args.checkpoint_path):
31 | os.mkdir(args.checkpoint_path)
32 | if args.resume_path:
33 | args.resume_path = os.path.join(args.checkpoint_path, args.resume_path)
34 |
35 | args.arch = '{}{}'.format(args.model, args.model_depth)
36 |
37 | torch.manual_seed(args.manual_seed)
38 |
39 | args.use_cuda = args.use_cuda and torch.cuda.is_available()
40 |
41 | device = torch.device("cuda" if args.use_cuda else "cpu")
42 |
43 | # create model
44 | model, parameters = get_model_param(args)
45 | print(model)
46 | model = model.to(device)
47 |
48 | with open(os.path.join(args.result_path, 'args.json'), 'w') as args_file:
49 | json.dump(vars(args), args_file)
50 |
51 | # define loss function (criterion) and optimizer
52 | criterion = nn.CrossEntropyLoss().to(device)
53 | optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
54 | # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, threshold=0.001, patience=args.lr_patience)
55 |
56 |     lr_mult = []  # for fine-tuning, each group's 'lr' field holds a multiplier set by get_fine_tuning_parameters
57 | for param_group in optimizer.param_groups:
58 | lr_mult.append(param_group['lr'])
59 |
60 | # optionally resume from a checkpoint
61 | if args.resume_path:
62 | if os.path.isfile(args.resume_path):
63 | print("=> loading checkpoint '{}'...".format(args.resume_path))
64 | checkpoint = torch.load(args.resume_path)
65 | args.begin_epoch = checkpoint['epoch'] + 1
66 | model.load_state_dict(checkpoint['model'])
67 | optimizer.load_state_dict(checkpoint['optimizer'])
68 | else:
69 | print("=> no checkpoint found at '{}'".format(args.resume_path))
70 |
71 | if args.train:
72 | train_dataset = Driver(root=args.data_path, train=True, test=False)
73 | train_loader = DataLoader(
74 | dataset=train_dataset,
75 | batch_size=args.batch_size,
76 | shuffle=True,
77 | num_workers=args.num_workers)
78 | train_logger = Logger(
79 | os.path.join(args.result_path, 'train.log'),
80 | ['epoch', 'loss', 'top1', 'top3', 'lr'])
81 | train_batch_logger = Logger(
82 | os.path.join(args.result_path, 'train_batch.log'),
83 | ['epoch', 'batch', 'iter', 'loss', 'top1', 'top3', 'lr'])
84 |
85 | if args.val:
86 | val_dataset = Driver(root=args.data_path, train=False, test=True)
87 | val_loader = DataLoader(
88 | dataset=val_dataset,
89 | batch_size=args.test_batch_size,
90 | shuffle=False,
91 | num_workers=args.num_workers)
92 | val_logger = Logger(
93 | os.path.join(args.result_path, 'val.log'),
94 | ['epoch', 'loss', 'top1', 'top3'])
95 |
96 | print('=> Start running...')
97 | vis = Visualizer(env=args.env)
98 | for epoch in range(args.begin_epoch, args.epochs + 1):
99 | if args.train:
100 | adjust_learning_rate(optimizer, epoch, lr_mult, args)
101 | train_epoch(epoch, train_loader, model, criterion, optimizer, args, device, train_logger, train_batch_logger, vis)
102 | print('\n')
103 |
104 | if args.val:
105 | val_loss, val_prec1 = val_epoch(epoch, val_loader, model, criterion, args, device, val_logger, vis)
106 | print('\n')
107 | # remember best prec@1 and save checkpoint
108 | if val_prec1 > best_prec1:
109 | best_prec1 = val_prec1
110 | best_epoch = epoch
111 | print('=> Saving current best model...\n')
112 | save_file_path = os.path.join(args.result_path, 'save_best_{}_{}.pth'.format(args.arch, epoch))
113 | checkpoint = {
114 | 'arch': args.arch,
115 | 'epoch': best_epoch,
116 | 'best_prec1': best_prec1,
117 | 'model': model.state_dict(),
118 | 'optimizer': optimizer.state_dict()
119 | }
120 | torch.save(checkpoint, save_file_path)
121 |
122 | # if args.train and args.val:
123 | # scheduler.step(val_loss)
124 |
125 | if args.test:
126 | test_dataset = Driver(root=args.data_path, train=False, test=True)
127 | test_loader = DataLoader(
128 | dataset=test_dataset,
129 | batch_size=args.test_batch_size,
130 | shuffle=False,
131 | num_workers=args.num_workers)
132 |     # If you only test the model, set "best_epoch" manually, e.g.:
133 |     # best_epoch = 10
134 | saved_model_path = os.path.join(args.result_path, 'save_best_{}_{}.pth'.format(args.arch, best_epoch))
135 | print("Using '{}' for test...".format(saved_model_path))
136 | checkpoint = torch.load(saved_model_path)
137 | model.load_state_dict(checkpoint['model'])
138 | test.test(test_loader, model, args, device)
139 |
140 |
141 | def adjust_learning_rate(optimizer, epoch, lr_mult, args):
142 |     """Sets the learning rate to the initial LR decayed by 10 every 20 epochs"""
143 | lr = args.lr * (0.1**((epoch - 1) // 20))
144 | for i, param_group in enumerate(optimizer.param_groups):
145 | if args.finetune and args.ft_begin_index:
146 | param_group['lr'] = lr * lr_mult[i]
147 | else:
148 | param_group['lr'] = lr
149 |
150 |
151 | if __name__ == '__main__':
152 | main()
153 |
--------------------------------------------------------------------------------
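The step decay in `adjust_learning_rate` divides the base rate by 10 every 20 epochs. With the default `--lr 0.001` and `--epochs 50`, the schedule works out as below (a quick check of the formula):

```python
# lr = args.lr * 0.1 ** ((epoch - 1) // 20), with args.lr = 0.001:
#   epochs  1-20 -> 0.001
#   epochs 21-40 -> 0.0001
#   epochs 41-50 -> 0.00001
for epoch in (1, 20, 21, 40, 41, 50):
    print(epoch, 0.001 * 0.1 ** ((epoch - 1) // 20))
```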
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import os
5 | from models import resnet
6 |
7 | model_path = {
8 | 'resnet18': 'resnet18-5c106cde.pth',
9 | 'resnet34': 'resnet34-333f7ec4.pth',
10 | 'resnet50': 'resnet50-19c8e357.pth',
11 | 'resnet101': 'resnet101-5d3b4d8f.pth',
12 | 'resnet152': 'resnet152-b121ed2d.pth',
13 | }
14 |
15 |
16 | def get_model_param(args):
17 |     assert args.model in ['resnet'], 'Only resnet is implemented in get_model_param for now'
18 |
19 | if args.model == 'resnet':
20 | assert args.model_depth in [18, 34, 50, 101, 152]
21 |
22 | from models.resnet import get_fine_tuning_parameters
23 |
24 | if args.model_depth == 18:
25 | model = resnet.resnet18(pretrained=False, input_size=args.input_size, num_classes=args.n_classes)
26 | elif args.model_depth == 34:
27 | model = resnet.resnet34(pretrained=False, input_size=args.input_size, num_classes=args.n_classes)
28 | elif args.model_depth == 50:
29 | model = resnet.resnet50(pretrained=False, input_size=args.input_size, num_classes=args.n_classes)
30 | elif args.model_depth == 101:
31 | model = resnet.resnet101(pretrained=False, input_size=args.input_size, num_classes=args.n_classes)
32 | elif args.model_depth == 152:
33 | model = resnet.resnet152(pretrained=False, input_size=args.input_size, num_classes=args.n_classes)
34 |
35 | # elif args.model == 'vgg':
36 | # pass
37 |
38 | # Load pretrained model here
39 | if args.finetune:
40 | pretrained_model = model_path[args.arch]
41 | args.pretrain_path = os.path.join(args.root_path, 'pretrained_models', pretrained_model)
42 | print("=> loading pretrained model '{}'...".format(pretrained_model))
43 |
44 | model.load_state_dict(torch.load(args.pretrain_path))
45 |
46 | # Only modify the last layer
47 | if args.model == 'resnet':
48 | model.fc = nn.Linear(model.fc.in_features, args.n_finetune_classes)
49 | # elif args.model == 'vgg':
50 | # pass
51 |
52 | parameters = get_fine_tuning_parameters(model, args.ft_begin_index, args.lr_mult1, args.lr_mult2)
53 | return model, parameters
54 |
55 | return model, model.parameters()
56 |
--------------------------------------------------------------------------------
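The fine-tuning path above is subtle: `get_fine_tuning_parameters` stores the raw multipliers (`lr_mult1`, `lr_mult2`) in each param group's `'lr'` field, `main.py` records them into its `lr_mult` list, and `adjust_learning_rate` overwrites every group with `lr * lr_mult[i]` before any training step. Under the defaults (`--lr 0.001`, `--lr_mult1 0.1`, `--lr_mult2 1`) and a nonzero `ft_begin_index`, the effective rates at epoch 1 come out as in this sketch (two representative groups shown; in reality there is one group per parameter tensor):

```python
# Effective per-group learning rates at epoch 1 (sketch, default flags).
base_lr = 0.001
lr_mult = [0.1, 1.0]   # [pre-trained layers, newly-created layers]
effective = [base_lr * m for m in lr_mult]
print(effective)       # [0.0001, 0.001]
```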
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # from .resnet import *
2 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Almost the same as the official ResNet implementation.
3 | Except that:
4 | We allow a different input size, which is 224 by default.
5 | And we can collect fine-tuning parameters based on ft_begin_index.
6 | '''
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.model_zoo as model_zoo
11 | import math
12 |
13 |
14 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
15 | 'resnet152']
16 |
17 |
18 | model_urls = {
19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
24 | }
25 |
26 |
27 | def conv3x3(in_planes, out_planes, stride=1):
28 | """3x3 convolution with padding"""
29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30 | padding=1, bias=False)
31 |
32 |
33 | class BasicBlock(nn.Module):
34 | expansion = 1
35 |
36 | def __init__(self, in_planes, planes, stride=1, downsample=None):
37 | super(BasicBlock, self).__init__()
38 | self.conv1 = conv3x3(in_planes, planes, stride)
39 | self.bn1 = nn.BatchNorm2d(planes)
40 | self.relu = nn.ReLU(inplace=True)
41 | self.conv2 = conv3x3(planes, planes)
42 | self.bn2 = nn.BatchNorm2d(planes)
43 | self.downsample = downsample
44 | self.stride = stride
45 |
46 | def forward(self, x):
47 | residual = x
48 |
49 | out = self.conv1(x)
50 | out = self.bn1(out)
51 | out = self.relu(out)
52 |
53 | out = self.conv2(out)
54 | out = self.bn2(out)
55 |
56 | if self.downsample is not None:
57 | residual = self.downsample(x)
58 |
59 | out += residual
60 | out = self.relu(out)
61 |
62 | return out
63 |
64 |
65 | class Bottleneck(nn.Module):
66 | expansion = 4
67 |
68 | def __init__(self, in_planes, planes, stride=1, downsample=None):
69 | super(Bottleneck, self).__init__()
70 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
71 | self.bn1 = nn.BatchNorm2d(planes)
72 | self.conv2 = conv3x3(planes, planes, stride=stride)
73 | self.bn2 = nn.BatchNorm2d(planes)
74 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
75 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
76 | self.relu = nn.ReLU(inplace=True)
77 | self.downsample = downsample
78 | self.stride = stride
79 |
80 | def forward(self, x):
81 | residual = x
82 |
83 | out = self.conv1(x)
84 | out = self.bn1(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv2(out)
88 | out = self.bn2(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv3(out)
92 | out = self.bn3(out)
93 |
94 | if self.downsample is not None:
95 | residual = self.downsample(x)
96 |
97 | out += residual
98 | out = self.relu(out)
99 |
100 | return out
101 |
102 |
103 | class ResNet(nn.Module):
104 |
105 | def __init__(self, block, num_blocks, input_size=224, num_classes=1000):
106 | super(ResNet, self).__init__()
107 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
108 | bias=False)
109 | self.bn1 = nn.BatchNorm2d(64)
110 | self.relu = nn.ReLU(inplace=True)
111 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
112 | self.in_planes = 64 # initial value
113 | self.layer1 = self._make_layer(block, 64, num_blocks[0])
114 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
115 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
116 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
117 | self.avgpool = nn.AvgPool2d(int(math.ceil(input_size / 32)), stride=1)
118 | self.fc = nn.Linear(512 * block.expansion, num_classes)
119 |
120 | # for m in self.modules():
121 | # if isinstance(m, nn.Conv2d):
122 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
123 | # m.weight.data.normal_(0, math.sqrt(2. / n))
124 | # elif isinstance(m, nn.BatchNorm2d):
125 | # m.weight.data.fill_(1)
126 | # m.bias.data.zero_()
127 |
128 | for m in self.modules():
129 | if isinstance(m, nn.Conv2d):
130 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
131 | if m.bias is not None:
132 | nn.init.constant_(m.bias, 0)
133 | elif isinstance(m, nn.BatchNorm2d):
134 | nn.init.constant_(m.weight, 1)
135 | nn.init.constant_(m.bias, 0)
136 | elif isinstance(m, nn.Linear):
137 | nn.init.normal_(m.weight, 0, 0.01)
138 | nn.init.constant_(m.bias, 0)
139 |
140 | def _make_layer(self, block, planes, num_blocks, stride=1):
141 | downsample = None
142 | if stride != 1 or self.in_planes != planes * block.expansion:
143 | downsample = nn.Sequential(
144 | nn.Conv2d(self.in_planes, planes * block.expansion,
145 | kernel_size=1, stride=stride, bias=False),
146 | nn.BatchNorm2d(planes * block.expansion),
147 | )
148 |
149 | layers = []
150 | layers.append(block(self.in_planes, planes, stride, downsample))
151 | self.in_planes = planes * block.expansion # update
152 | for _ in range(1, num_blocks):
153 | layers.append(block(self.in_planes, planes))
154 |
155 | return nn.Sequential(*layers)
156 |
157 | def forward(self, x):
158 | x = self.conv1(x)
159 | x = self.bn1(x)
160 | x = self.relu(x)
161 | x = self.maxpool(x)
162 |
163 | x = self.layer1(x)
164 | x = self.layer2(x)
165 | x = self.layer3(x)
166 | x = self.layer4(x)
167 |
168 | x = self.avgpool(x)
169 | x = x.view(x.size(0), -1)
170 | x = self.fc(x)
171 |
172 | return x
173 |
174 |
175 | def resnet18(pretrained=False, **kwargs):
176 | """Constructs a ResNet-18 model.
177 |
178 | Args:
179 | pretrained (bool): If True, returns a model pre-trained on ImageNet
180 | """
181 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
182 | if pretrained:
183 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
184 | return model
185 |
186 |
187 | def resnet34(pretrained=False, **kwargs):
188 | """Constructs a ResNet-34 model.
189 |
190 | Args:
191 | pretrained (bool): If True, returns a model pre-trained on ImageNet
192 | """
193 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
194 | if pretrained:
195 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
196 | return model
197 |
198 |
199 | def resnet50(pretrained=False, **kwargs):
200 | """Constructs a ResNet-50 model.
201 |
202 | Args:
203 | pretrained (bool): If True, returns a model pre-trained on ImageNet
204 | """
205 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
206 | if pretrained:
207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
208 | return model
209 |
210 |
211 | def resnet101(pretrained=False, **kwargs):
212 | """Constructs a ResNet-101 model.
213 |
214 | Args:
215 | pretrained (bool): If True, returns a model pre-trained on ImageNet
216 | """
217 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
218 | if pretrained:
219 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
220 | return model
221 |
222 |
223 | def resnet152(pretrained=False, **kwargs):
224 | """Constructs a ResNet-152 model.
225 |
226 | Args:
227 | pretrained (bool): If True, returns a model pre-trained on ImageNet
228 | """
229 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
230 | if pretrained:
231 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
232 | return model
233 |
234 |
235 | def get_fine_tuning_parameters(model, ft_begin_index=0, lr_mult1=0.1, lr_mult2=1):
236 | if ft_begin_index == 0:
237 | return model.parameters()
238 |
239 | ft_module_names = []
240 | for i in range(ft_begin_index, 5):
241 | ft_module_names.append('layer{}'.format(i))
242 | ft_module_names.append('fc')
243 |
244 | parameters = []
245 | for name, params in model.named_parameters():
246 | flag = False
247 | for ft_module in ft_module_names:
248 | if ft_module in name:
249 | flag = True
250 |                 parameters.append({'params': params, 'lr': lr_mult2})  # 'lr' holds a multiplier, rescaled in main.adjust_learning_rate
251 | break
252 | if not flag:
253 |             parameters.append({'params': params, 'lr': lr_mult1})  # 'lr' holds a multiplier, rescaled in main.adjust_learning_rate
254 |
255 | return parameters
256 |
257 |
258 | if __name__ == '__main__':
259 | model = resnet18(input_size=224, num_classes=10)
260 | x = torch.rand(1, 3, 224, 224)
261 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
262 | x = x.to(device)
263 | model = model.to(device)
264 | y = model(x)
265 | print(torch.nn.functional.softmax(y, dim=1))
266 |
--------------------------------------------------------------------------------
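As a concrete example of `get_fine_tuning_parameters`: with `ft_begin_index=3`, the name filter matches `layer3`, `layer4`, and `fc`, so those parameters receive `lr_mult2`, while everything earlier (`conv1`, `bn1`, `layer1`, `layer2`) receives `lr_mult1`. A sketch (assuming the repository root is on the Python path):

```python
from models.resnet import resnet34, get_fine_tuning_parameters

model = resnet34(input_size=224, num_classes=10)
groups = get_fine_tuning_parameters(model, ft_begin_index=3,
                                    lr_mult1=0.1, lr_mult2=1)
# Each group's 'lr' field holds the multiplier at this point.
print({g['lr'] for g in groups})   # {0.1, 1}
```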
/models/vgg.py:
--------------------------------------------------------------------------------
1 | '''
2 | Almost the same as the official VGG implementation.
3 | Except that:
4 | We allow a different input size, which is 224 by default.
5 | '''
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.utils.model_zoo as model_zoo
10 | import math
11 |
12 |
13 | __all__ = [
14 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
15 | 'vgg19_bn', 'vgg19',
16 | ]
17 |
18 |
19 | model_urls = {
20 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
21 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
22 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
23 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
24 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
25 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
26 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
27 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
28 | }
29 |
30 |
31 | class VGG(nn.Module):
32 |
33 | def __init__(self, features, input_size=224, num_classes=1000, init_weights=True):
34 | super(VGG, self).__init__()
35 | self.features = features
36 | last_size = int(math.ceil(input_size / 32))
37 | self.classifier = nn.Sequential(
38 | nn.Linear(512 * last_size * last_size, 4096),
39 | nn.ReLU(True),
40 | nn.Dropout(),
41 | nn.Linear(4096, 4096),
42 | nn.ReLU(True),
43 | nn.Dropout(),
44 | nn.Linear(4096, num_classes),
45 | )
46 | if init_weights:
47 | self._initialize_weights()
48 |
49 | def forward(self, x):
50 | x = self.features(x)
51 | x = x.view(x.size(0), -1)
52 | x = self.classifier(x)
53 | return x
54 |
55 | def _initialize_weights(self):
56 | for m in self.modules():
57 | if isinstance(m, nn.Conv2d):
58 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
59 | if m.bias is not None:
60 | nn.init.constant_(m.bias, 0)
61 | elif isinstance(m, nn.BatchNorm2d):
62 | nn.init.constant_(m.weight, 1)
63 | nn.init.constant_(m.bias, 0)
64 | elif isinstance(m, nn.Linear):
65 | nn.init.normal_(m.weight, 0, 0.01)
66 | nn.init.constant_(m.bias, 0)
67 |
68 |
69 | def make_layers(cfg, batch_norm=False):
70 | layers = []
71 | in_channels = 3
72 | for v in cfg:
73 | if v == 'M':
74 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
75 | else:
76 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
77 | if batch_norm:
78 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
79 | else:
80 | layers += [conv2d, nn.ReLU(inplace=True)]
81 | in_channels = v
82 | return nn.Sequential(*layers)
83 |
84 |
85 | cfg = {
86 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
87 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
88 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
89 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
90 | }
91 |
92 |
93 | def vgg11(pretrained=False, **kwargs):
94 | """VGG 11-layer model (configuration "A")
95 |
96 | Args:
97 | pretrained (bool): If True, returns a model pre-trained on ImageNet
98 | """
99 | if pretrained:
100 | kwargs['init_weights'] = False
101 | model = VGG(make_layers(cfg['A']), **kwargs)
102 | if pretrained:
103 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
104 | return model
105 |
106 |
107 | def vgg11_bn(pretrained=False, **kwargs):
108 | """VGG 11-layer model (configuration "A") with batch normalization
109 |
110 | Args:
111 | pretrained (bool): If True, returns a model pre-trained on ImageNet
112 | """
113 | if pretrained:
114 | kwargs['init_weights'] = False
115 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
116 | if pretrained:
117 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
118 | return model
119 |
120 |
121 | def vgg13(pretrained=False, **kwargs):
122 | """VGG 13-layer model (configuration "B")
123 |
124 | Args:
125 | pretrained (bool): If True, returns a model pre-trained on ImageNet
126 | """
127 | if pretrained:
128 | kwargs['init_weights'] = False
129 | model = VGG(make_layers(cfg['B']), **kwargs)
130 | if pretrained:
131 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
132 | return model
133 |
134 |
135 | def vgg13_bn(pretrained=False, **kwargs):
136 | """VGG 13-layer model (configuration "B") with batch normalization
137 |
138 | Args:
139 | pretrained (bool): If True, returns a model pre-trained on ImageNet
140 | """
141 | if pretrained:
142 | kwargs['init_weights'] = False
143 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
144 | if pretrained:
145 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
146 | return model
147 |
148 |
149 | def vgg16(pretrained=False, **kwargs):
150 | """VGG 16-layer model (configuration "D")
151 |
152 | Args:
153 | pretrained (bool): If True, returns a model pre-trained on ImageNet
154 | """
155 | if pretrained:
156 | kwargs['init_weights'] = False
157 | model = VGG(make_layers(cfg['D']), **kwargs)
158 | if pretrained:
159 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
160 | return model
161 |
162 |
163 | def vgg16_bn(pretrained=False, **kwargs):
164 | """VGG 16-layer model (configuration "D") with batch normalization
165 |
166 | Args:
167 | pretrained (bool): If True, returns a model pre-trained on ImageNet
168 | """
169 | if pretrained:
170 | kwargs['init_weights'] = False
171 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
172 | if pretrained:
173 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
174 | return model
175 |
176 |
177 | def vgg19(pretrained=False, **kwargs):
178 | """VGG 19-layer model (configuration "E")
179 |
180 | Args:
181 | pretrained (bool): If True, returns a model pre-trained on ImageNet
182 | """
183 | if pretrained:
184 | kwargs['init_weights'] = False
185 | model = VGG(make_layers(cfg['E']), **kwargs)
186 | if pretrained:
187 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
188 | return model
189 |
190 |
191 | def vgg19_bn(pretrained=False, **kwargs):
192 | """VGG 19-layer model (configuration 'E') with batch normalization
193 |
194 | Args:
195 | pretrained (bool): If True, returns a model pre-trained on ImageNet
196 | """
197 | if pretrained:
198 | kwargs['init_weights'] = False
199 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
200 | if pretrained:
201 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
202 | return model
203 |
204 |
205 | if __name__ == '__main__':
206 | model = vgg16_bn(input_size=224, num_classes=10)
207 | x = torch.rand(1, 3, 224, 224)
208 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
209 | x = x.to(device)
210 | model = model.to(device)
211 | y = model(x)
212 | print(torch.nn.functional.softmax(y, dim=1))
213 |
--------------------------------------------------------------------------------
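`model.py` currently replaces only the ResNet `fc` layer and leaves the VGG branch as a TODO. For VGG the final layer lives at `classifier[6]`, so the analogous swap would look like this sketch (not wired into `model.py`):

```python
import torch.nn as nn
from models.vgg import vgg16_bn

model = vgg16_bn(input_size=224, num_classes=1000)
# Replace only the last classifier layer for a 10-class task.
model.classifier[6] = nn.Linear(model.classifier[6].in_features, 10)
```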
/pretrained_models/download.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | echo "Downloading resnet models pretrained on ImageNet..."
4 |
5 | wget -N https://download.pytorch.org/models/resnet18-5c106cde.pth
6 | wget -N https://download.pytorch.org/models/resnet34-333f7ec4.pth
7 | wget -N https://download.pytorch.org/models/resnet50-19c8e357.pth
8 | wget -N https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
9 | wget -N https://download.pytorch.org/models/resnet152-b121ed2d.pth
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from utils import AverageMeter, calculate_accuracy
4 |
5 |
6 | def test(data_loader, model, args, device):
7 | batch_time = AverageMeter()
8 | data_time = AverageMeter()
9 | top1 = AverageMeter()
10 | top3 = AverageMeter()
11 |
12 | # switch to evaluate mode
13 | model.eval()
14 |
15 | end_time = time.time()
16 | for i, (input, target) in enumerate(data_loader):
17 | # measure data loading time
18 | data_time.update(time.time() - end_time)
19 |
20 | input = input.to(device)
21 | target = target.to(device)
22 |
23 | # compute output and loss
24 | output = model(input)
25 |
26 | # measure accuracy and record loss
27 | prec1, prec3 = calculate_accuracy(output, target, topk=(1, 3))
28 | # prec1[0]: convert torch.Size([1]) to torch.Size([])
29 | top1.update(prec1[0].item(), input.size(0))
30 | top3.update(prec3[0].item(), input.size(0))
31 |
32 | # measure elapsed time
33 | batch_time.update(time.time() - end_time)
34 | end_time = time.time()
35 |
36 | if (i + 1) % args.log_interval == 0:
37 | print('Test Iter [{0}/{1}]\t'
38 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
39 | 'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t'
40 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
41 | 'Data Time {data_time.val:.3f} ({data_time.avg:.3f})'.format(
42 | i + 1,
43 | len(data_loader),
44 | top1=top1,
45 | top3=top3,
46 | batch_time=batch_time,
47 | data_time=data_time))
48 |
49 | print(' * Prec@1 {top1.avg:.2f}% | Prec@3 {top3.avg:.2f}%'.format(
50 | top1=top1, top3=top3))
51 |
52 | return top1.avg
53 |
--------------------------------------------------------------------------------
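One refinement the evaluation loops here do not include: `model.eval()` switches layer behavior (dropout, batch norm) but does not disable autograd, so the forward passes above still build a computation graph. Wrapping inference in `torch.no_grad()` avoids that memory cost; a self-contained sketch of the pattern:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 3)      # stand-in for the real network
model.eval()
x = torch.rand(4, 8)
with torch.no_grad():
    output = model(x)        # forward pass records no gradient graph
print(output.requires_grad)  # False
```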
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .visualize import Visualizer
2 |
--------------------------------------------------------------------------------
/tools/visualize.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import visdom
4 | import time
5 | import numpy as np
6 |
7 |
8 | class Visualizer(object):
9 |
10 | def __init__(self, env='default', **kwargs):
11 | self.vis = visdom.Visdom(env=env, **kwargs)
12 | self.index = {}
13 | self.log_text = ''
14 |
15 | def reinit(self, env='default', **kwargs):
16 | self.vis = visdom.Visdom(env=env, **kwargs)
17 | return self
18 |
19 | def plot_many(self, d):
20 | '''
21 | @params d: dict (name, value) i.e. ('loss', 0.11)
22 | '''
23 | for k, v in d.items():
24 | self.plot(k, v)
25 |
26 | def img_many(self, d):
27 | for k, v in d.items():
28 | self.img(k, v)
29 |
30 | def plot(self, name, y, **kwargs):
31 | '''
32 | self.plot('loss', 1.00)
33 | '''
34 | x = self.index.get(name, 0)
35 | self.vis.line(Y=np.array([y]),
36 | X=np.array([x]),
37 | win=name,
38 | opts=dict(title=name),
39 | update=None if x == 0 else 'append',
40 | **kwargs)
41 | self.index[name] = x + 1
42 |
43 | def img(self, name, img_, **kwargs):
44 | '''
45 | self.img('input_img', torch.Tensor(64, 64))
46 | self.img('input_imgs', torch.Tensor(3, 64, 64))
47 | self.img('input_imgs', torch.Tensor(100, 1, 64, 64))
48 | self.img('input_imgs', torch.Tensor(100, 3, 64, 64), nrows=10)
49 | '''
50 | self.vis.images(img_.cpu().numpy(),
51 | win=name,
52 | opts=dict(title=name),
53 | **kwargs)
54 |
55 | def log(self, info, win='log_text'):
56 | '''
57 | self.log({'loss': 1, 'lr': 0.0001})
58 | '''
59 |         self.log_text += ('[{time}] {info} <br>'.format(
60 | time=time.strftime('%m%d_%H%M%S'),
61 | info=info))
62 | self.vis.text(self.log_text, win)
63 |
64 | def __getattr__(self, name):
65 | return getattr(self.vis, name)
66 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import os
4 |
5 | from utils import AverageMeter, calculate_accuracy
6 |
7 |
8 | def train_epoch(epoch, data_loader, model, criterion, optimizer, args, device,
9 | epoch_logger, batch_logger, vis):
10 | batch_time = AverageMeter()
11 | data_time = AverageMeter()
12 | losses = AverageMeter()
13 | top1 = AverageMeter()
14 | top3 = AverageMeter()
15 |
16 | # switch to train mode
17 | model.train()
18 |
19 | end_time = time.time()
20 | for i, (input, target) in enumerate(data_loader):
21 | # measure data loading time
22 | data_time.update(time.time() - end_time)
23 |
24 | input = input.to(device)
25 | target = target.to(device)
26 |
27 | # compute output and loss
28 | output = model(input)
29 | loss = criterion(output, target)
30 |
31 | # measure accuracy and record loss
32 | prec1, prec3 = calculate_accuracy(output, target, topk=(1, 3))
33 | losses.update(loss.item(), input.size(0))
34 | # prec1[0]: convert torch.Size([1]) to torch.Size([])
35 | top1.update(prec1[0].item(), input.size(0))
36 | top3.update(prec3[0].item(), input.size(0))
37 |         # Shape reference for the .item() calls above:
38 |         #   a = np.array([1, 2, 3])
39 |         #   b = torch.from_numpy(a)      # tensor([1, 2, 3])
40 |         #   c = b.sum()                  # tensor(6)
41 |         #   d = b.sum(0)                 # tensor(6)
42 |         #   e = b.sum(0, keepdim=True)   # tensor([6]), torch.Size([1])
43 |         #   e[0]                         # tensor(6), torch.Size([])
44 |         #   e.item()                     # 6
45 |
46 |
47 | # compute gradient and do SGD step
48 | optimizer.zero_grad()
49 | loss.backward()
50 | optimizer.step()
51 |
52 | # measure elapsed time
53 | batch_time.update(time.time() - end_time)
54 | end_time = time.time()
55 |
56 | if (i + 1) % args.log_interval == 0:
57 | print('Train Epoch [{0}/{1}]([{2}/{3}])\t'
58 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
59 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
60 | 'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t'
61 | 'LR {lr:f}\t'
62 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
63 | 'Data Time {data_time.val:.3f} ({data_time.avg:.3f})'.format(
64 | epoch,
65 | args.epochs,
66 | i + 1,
67 | len(data_loader),
68 | loss=losses,
69 | top1=top1,
70 | top3=top3,
71 | lr=optimizer.param_groups[0]['lr'],
72 | batch_time=batch_time,
73 | data_time=data_time))
74 |
75 | batch_logger.log({
76 | 'epoch': epoch,
77 | 'batch': i + 1,
78 | 'iter': (epoch - 1) * len(data_loader) + (i + 1),
79 | 'loss': losses.val,
80 | 'top1': top1.val,
81 | 'top3': top3.val,
82 | 'lr': optimizer.param_groups[0]['lr']
83 | })
84 |
85 | epoch_logger.log({
86 | 'epoch': epoch,
87 | 'loss': losses.avg,
88 | 'top1': top1.avg,
89 | 'top3': top3.avg,
90 | 'lr': optimizer.param_groups[0]['lr']
91 | })
92 |
93 | if epoch % args.checkpoint_interval == 0:
94 | save_file_path = os.path.join(args.checkpoint_path, 'save_{}_{}.pth'.format(args.arch, epoch))
95 | checkpoint = {
96 | 'epoch': epoch,
97 | 'arch': args.arch,
98 | 'model': model.state_dict(),
99 | 'optimizer': optimizer.state_dict(),
100 | }
101 | torch.save(checkpoint, save_file_path)
102 |
103 | vis.plot('Train loss', losses.avg)
104 | vis.plot('Train accu', top1.avg)
105 | vis.log("epoch:{epoch}, lr:{lr}, loss:{loss}, accu:{accu}".format(
106 | epoch=epoch,
107 | lr=optimizer.param_groups[0]['lr'],
108 | loss=losses.avg,
109 | accu=top1.avg))
110 |
111 |
112 | def val_epoch(epoch, data_loader, model, criterion, args, device, epoch_logger, vis):
113 | batch_time = AverageMeter()
114 | data_time = AverageMeter()
115 | losses = AverageMeter()
116 | top1 = AverageMeter()
117 | top3 = AverageMeter()
118 |
119 | # switch to evaluate mode
120 | model.eval()
121 |
122 | end_time = time.time()
123 | for i, (input, target) in enumerate(data_loader):
124 | # measure data loading time
125 | data_time.update(time.time() - end_time)
126 |
127 | input = input.to(device)
128 | target = target.to(device)
129 |
130 | # compute output and loss
131 | output = model(input)
132 | loss = criterion(output, target)
133 |
134 | # measure accuracy and record loss
135 | prec1, prec3 = calculate_accuracy(output, target, topk=(1, 3))
136 | losses.update(loss.item(), input.size(0))
137 | top1.update(prec1[0].item(), input.size(0))
138 | top3.update(prec3[0].item(), input.size(0))
139 |
140 | # measure elapsed time
141 | batch_time.update(time.time() - end_time)
142 | end_time = time.time()
143 |
144 | if (i + 1) % args.log_interval == 0:
145 | print('Valid Epoch [{0}/{1}]([{2}/{3}])\t'
146 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
147 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
148 | 'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t'
149 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
150 | 'Data Time {data_time.val:.3f} ({data_time.avg:.3f})'.format(
151 | epoch,
152 | args.epochs,
153 | i + 1,
154 | len(data_loader),
155 | loss=losses,
156 | top1=top1,
157 | top3=top3,
158 | batch_time=batch_time,
159 | data_time=data_time))
160 |
161 | print(' * Prec@1 {top1.avg:.2f}% | Prec@3 {top3.avg:.2f}%'.format(
162 | top1=top1, top3=top3))
163 |
164 | epoch_logger.log({
165 | 'epoch': epoch,
166 | 'loss': losses.avg,
167 | 'top1': top1.avg,
168 | 'top3': top3.avg
169 | })
170 |
171 | vis.plot('Val loss', losses.avg)
172 | vis.plot('Val accu', top1.avg)
173 |
174 | return losses.avg, top1.avg
175 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import torch
3 |
4 |
5 | class AverageMeter(object):
6 | """Computes and stores the average and current value"""
7 | def __init__(self):
8 | self.reset()
9 |
10 | def reset(self):
11 | self.val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 |
16 | def update(self, val, n=1):
17 | self.val = val
18 | self.sum += val * n
19 | self.count += n
20 | self.avg = self.sum / self.count
21 |
22 |
23 | class Logger(object):
24 | """Outputs log files"""
25 | def __init__(self, path, header):
26 | self.log_file = open(path, 'w')
27 | self.logger = csv.writer(self.log_file, delimiter='\t')
28 | self.logger.writerow(header)
29 | self.header = header
30 |
31 |     def __del__(self):
32 | self.log_file.close()
33 |
34 | def log(self, values):
35 | write_values = []
36 | for col in self.header:
37 | assert col in values
38 | write_values.append(values[col])
39 |
40 | self.logger.writerow(write_values)
41 | self.log_file.flush()
42 |
43 |
44 | def calculate_accuracy(output, target, topk=(1,)):
45 | """Computes the precision@k for the specified values of k"""
46 | with torch.no_grad():
47 | maxk = max(topk)
48 | batch_size = target.size(0)
49 |
50 | _, pred = output.topk(maxk, dim=1, largest=True, sorted=True) # batch_size x maxk
51 | pred = pred.t() # transpose, maxk x batch_size
52 | # target.view(1, -1): convert (batch_size,) to 1 x batch_size
53 | # expand_as: convert 1 x batch_size to maxk x batch_size
54 | correct = pred.eq(target.view(1, -1).expand_as(pred)) # maxk x batch_size
55 |
56 | res = []
57 | for k in topk:
58 | # correct[:k] converts "maxk x batch_size" to "k x batch_size"
59 | # view(-1) converts "k x batch_size" to "(k x batch_size,)"
60 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
61 | res.append(correct_k.mul_(100.0 / batch_size))
62 | return res
63 |
--------------------------------------------------------------------------------
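A worked example of `calculate_accuracy`: for a batch of 4 samples over 5 classes, prec@1 counts exact argmax hits and prec@3 counts targets appearing anywhere in the top 3, both as percentages of the batch. A quick sketch:

```python
import torch
from utils import calculate_accuracy

# Logits for a batch of 4 samples over 5 classes (rows are samples).
output = torch.tensor([[9., 1., 0., 0., 0.],   # top-1 = class 0
                       [1., 9., 0., 0., 0.],   # top-1 = class 1
                       [0., 2., 9., 1., 0.],   # top-3 = {2, 1, 3}
                       [9., 2., 1., 0., 0.]])  # top-3 = {0, 1, 2}
target = torch.tensor([0, 1, 3, 4])
prec1, prec3 = calculate_accuracy(output, target, topk=(1, 3))
print(prec1.item(), prec3.item())              # 50.0 75.0
```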