├── .gitignore ├── README.md ├── data └── .gitkeep ├── models └── .gitkeep ├── nuswide.py ├── onnx_export.py ├── onnx_validate.py ├── requirements.txt ├── reshape.py ├── train.py └── voc.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pth.tar 3 | *.onnx 4 | *.engine 5 | *.jpg 6 | *.png 7 | cat_dog* 8 | plants* 9 | data/ 10 | models/ 11 | logs/ 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ImageNet training in PyTorch 2 | 3 | This implements training of popular model architectures, such as ResNet, AlexNet, and VGG on the ImageNet dataset. 4 | 5 | ## Requirements 6 | 7 | - Install PyTorch ([pytorch.org](http://pytorch.org)) 8 | - `pip install -r requirements.txt` 9 | - Download the ImageNet dataset and move validation images to labeled subfolders 10 | - To do this, you can use the following script: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh 11 | 12 | ## Training 13 | 14 | To train a model, run `main.py` with the desired model architecture and the path to the ImageNet dataset: 15 | 16 | ```bash 17 | python main.py -a resnet18 [imagenet-folder with train and val folders] 18 | ``` 19 | 20 | The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs. This is appropriate for ResNet and models with batch normalization, but too high for AlexNet and VGG. Use 0.01 as the initial learning rate for AlexNet or VGG: 21 | 22 | ```bash 23 | python main.py -a alexnet --lr 0.01 [imagenet-folder with train and val folders] 24 | ``` 25 | 26 | ## Multi-processing Distributed Data Parallel Training 27 | 28 | You should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance. 
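For orientation, the NCCL setup that `main.py` performs under the hood amounts to initializing one process-group member per GPU worker; a minimal sketch (not the script's exact code, and the address/port is a placeholder for your `--dist-url`):

```python
import torch.distributed as dist

world_size = 2  # total number of processes across all nodes
rank = 0        # this process's global index (0..world_size-1)

# join the NCCL process group; every worker must call this with the same
# init_method and world_size before any collective communication happens
dist.init_process_group(backend='nccl',
                        init_method='tcp://127.0.0.1:23456',  # your --dist-url
                        world_size=world_size,
                        rank=rank)
```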
29 | 30 | ### Single node, multiple GPUs: 31 | 32 | ```bash 33 | python main.py -a resnet50 --dist-url 'tcp://127.0.0.1:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 [imagenet-folder with train and val folders] 34 | ``` 35 | 36 | ### Multiple nodes: 37 | 38 | Node 0: 39 | ```bash 40 | python main.py -a resnet50 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 2 --rank 0 [imagenet-folder with train and val folders] 41 | ``` 42 | 43 | Node 1: 44 | ```bash 45 | python main.py -a resnet50 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 2 --rank 1 [imagenet-folder with train and val folders] 46 | ``` 47 | 48 | ## Usage 49 | 50 | ``` 51 | usage: main.py [-h] [--arch ARCH] [-j N] [--epochs N] [--start-epoch N] [-b N] 52 | [--lr LR] [--momentum M] [--weight-decay W] [--print-freq N] 53 | [--resume PATH] [-e] [--pretrained] [--world-size WORLD_SIZE] 54 | [--rank RANK] [--dist-url DIST_URL] 55 | [--dist-backend DIST_BACKEND] [--seed SEED] [--gpu GPU] 56 | [--multiprocessing-distributed] 57 | DIR 58 | 59 | PyTorch ImageNet Training 60 | 61 | positional arguments: 62 | DIR path to dataset 63 | 64 | optional arguments: 65 | -h, --help show this help message and exit 66 | --arch ARCH, -a ARCH model architecture: alexnet | densenet121 | 67 | densenet161 | densenet169 | densenet201 | 68 | resnet101 | resnet152 | resnet18 | resnet34 | 69 | resnet50 | squeezenet1_0 | squeezenet1_1 | vgg11 | 70 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19 71 | | vgg19_bn (default: resnet18) 72 | -j N, --workers N number of data loading workers (default: 4) 73 | --epochs N number of total epochs to run 74 | --start-epoch N manual epoch number (useful on restarts) 75 | -b N, --batch-size N mini-batch size (default: 256), this is the total 76 | batch size of all GPUs on the current node when using 77 | Data Parallel or Distributed Data Parallel 78 | --lr LR, --learning-rate LR 79 | initial learning rate 80 | --momentum M momentum 81 | --weight-decay W, --wd W 82 | weight decay (default: 1e-4) 83 | --print-freq N, -p N print frequency (default: 10) 84 | --resume PATH path to latest checkpoint (default: none) 85 | -e, --evaluate evaluate model on validation set 86 | --pretrained use pre-trained model 87 | --world-size WORLD_SIZE 88 | number of nodes for distributed training 89 | --rank RANK node rank for distributed training 90 | --dist-url DIST_URL url used to set up distributed training 91 | --dist-backend DIST_BACKEND 92 | distributed backend 93 | --seed SEED seed for initializing training. 94 | --gpu GPU GPU id to use. 95 | --multiprocessing-distributed 96 | Use multi-processing distributed training to launch N 97 | processes per node, which has N GPUs. 
This is the 98 | fastest way to use PyTorch for either single node or 99 | multi node data parallel training 100 | ``` 101 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dusty-nv/pytorch-classification/3e9cf8c4003311009539a6c101d156c919fe2250/data/.gitkeep -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dusty-nv/pytorch-classification/3e9cf8c4003311009539a6c101d156c919fe2250/models/.gitkeep -------------------------------------------------------------------------------- /nuswide.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import csv 4 | import glob 5 | 6 | import torch 7 | import numpy as np 8 | 9 | from PIL import Image 10 | 11 | 12 | class NUSWideDataset(torch.utils.data.Dataset): 13 | """ 14 | Dataloader for NUS-WIDE multi-label classification dataset 15 | https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html 16 | 17 | TODO: support custom labels and class culling 18 | """ 19 | def __init__(self, root, set, transform=None, target_transform=None): 20 | """ 21 | Load either the 'trainval' or 'test' set 22 | """ 23 | self.root = root 24 | self.path_images = os.path.join(root, 'images') 25 | self.set = set 26 | self.transform = transform 27 | self.target_transform = target_transform 28 | 29 | # load the class labels 30 | #with open(os.path.join(root, 'labels.txt'), 'r') as file: 31 | # self.classes = file.read().splitlines() 32 | 33 | # load the available images 34 | self.fn_map = {} 35 | 36 | for fn in glob.glob(os.path.join(root, 'images/*.jpg')): 37 | tmp = os.path.basename(fn).split('_')[1] 38 | self.fn_map[tmp] = fn 39 | 40 | # load class labels from CSV 41 | self.images = self.read_labels() 42 | 43 | print(f"=> NUS-WIDE classification set={set} classes={len(self.classes)} images={len(self.images)}") 44 | 45 | def __getitem__(self, index): 46 | path, target = self.images[index] 47 | img = Image.open(path).convert('RGB') 48 | 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | 52 | if self.target_transform is not None: 53 | target = self.target_transform(target) 54 | 55 | return img, target 56 | 57 | def __len__(self): 58 | return len(self.images) 59 | 60 | def read_image_list(self): 61 | imagelist = {} 62 | hash2ids = {} 63 | 64 | if self.set == "trainval": 65 | path = os.path.join(self.root, "ImageList", "TrainImagelist.txt") 66 | elif self.set == "test": 67 | path = os.path.join(self.root, "ImageList", "TestImagelist.txt") 68 | else: 69 | raise ValueError(f"invalid set '{self.set}' (should be either 'trainval' or 'test')") 70 | 71 | with open(path, 'r') as f: 72 | for i, line in enumerate(f): 73 | line = line.split('\\')[-1] 74 | start = line.index('_') 75 | end = line.index('.') 76 | imagelist[i] = line[start+1:end] 77 | hash2ids[line[start+1:end]] = i 78 | 79 | return imagelist 80 | 81 | def read_labels(self, header=True): 82 | images = [] 83 | num_categories = 0 84 | imagelist = self.read_image_list() 85 | 86 | file = os.path.join(self.root, 'classification_labels', 'classification_' + self.set + '.csv') 87 | print(f"=> loading {file}") 88 | 89 | with open(file, 'r') as f: 90 | reader = csv.reader(f) 91 | rownum = 0 92 | 
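# NOTE: the loop below assumes the CSV layout used by the repo's
# classification_<set>.csv files, as inferred from the parsing code:
# the first row is a header holding the class names, and each data row
# is <image index>,<label>,<label>,... with per-class -1/0/1 flags
# (the -1 entries get clamped to 0 below, since BCELoss expects 0/1 targets)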
for row in reader: 93 | if header and rownum == 0: 94 | header = row 95 | self.classes = header[1:] 96 | else: 97 | if num_categories == 0: 98 | num_categories = len(row) - 1 99 | name = int(row[0]) 100 | labels = (np.asarray(row[1:num_categories + 1])).astype(np.float32) # BCELoss requires float 101 | labels = torch.from_numpy(labels) 102 | labels[labels==-1] = 0 # TODO should these remain as -1? 103 | name2 = self.fn_map[imagelist[name]] 104 | item = (name2, labels) 105 | images.append(item) 106 | rownum += 1 107 | return images 108 | 109 | def get_class_distribution(self): 110 | distribution = [0] * len(self.classes) 111 | 112 | for _, labels in self.images: 113 | for n, label in enumerate(labels): 114 | if label > 0: 115 | distribution[n] += 1 116 | 117 | return distribution 118 | 119 | 120 | if __name__ == "__main__": 121 | import argparse 122 | 123 | parser = argparse.ArgumentParser() 124 | 125 | parser.add_argument('--data', type=str, default='.') 126 | parser.add_argument('--set', type=str, default='trainval') 127 | parser.add_argument('--load-data', action='store_true') 128 | parser.add_argument('--distribution', action='store_true') 129 | 130 | args = parser.parse_args() 131 | print(args) 132 | 133 | # load the dataset 134 | dataset = NUSWideDataset(args.data, args.set) 135 | 136 | # verify that all images load 137 | if args.load_data: 138 | for idx, (img, target) in enumerate(dataset): 139 | print(f"loaded image {idx} dims={img.size} classes={[dataset.classes[n] for n in target.nonzero(as_tuple=True)[0]]}") 140 | #print(f"labels: {target}") 141 | 142 | # get the class distributions 143 | if args.distribution: 144 | print("=> computing class distributions:") 145 | 146 | distribution = dataset.get_class_distribution() 147 | total_labels = 0 148 | 149 | for n, count in enumerate(distribution): 150 | print(f" class {n} {dataset.classes[n]} - {count}") 151 | total_labels += count 152 | 153 | print(f"=> NUS-WIDE classification set={dataset.set} classes={len(dataset.classes)} images={len(dataset.images)} labels={total_labels}") 154 | -------------------------------------------------------------------------------- /onnx_export.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # converts a saved PyTorch model to ONNX format 3 | import os 4 | import argparse 5 | 6 | import torch 7 | import torchvision.models as models 8 | 9 | from reshape import reshape_model 10 | 11 | model_names = sorted(name for name in models.__dict__ 12 | if name.islower() and not name.startswith("__") 13 | and callable(models.__dict__[name])) 14 | 15 | # parse command line 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--input', type=str, default='model_best.pth.tar', help="path to input PyTorch model (default: model_best.pth.tar)") 18 | parser.add_argument('--output', type=str, default='', help="desired path of converted ONNX model (default: <ARCH>.onnx)") 19 | parser.add_argument('--model-dir', type=str, default='', help="directory to look for the input PyTorch model in, and export the converted ONNX model to (if --output doesn't specify a directory)") 20 | parser.add_argument('--no-activation', action='store_true', help="disable adding Softmax or Sigmoid layer to model (default is to add it)") 21 | 22 | args = parser.parse_args() 23 | print(args) 24 | 25 | # format input model path 26 | if args.model_dir: 27 | args.model_dir = os.path.expanduser(args.model_dir) 28 | args.input = os.path.join(args.model_dir, args.input) 29 | 30 | # 
set the device 31 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 32 | print('=> running on device ' + str(device)) 33 | 34 | # load the model checkpoint 35 | print('=> loading checkpoint: ' + args.input) 36 | checkpoint = torch.load(args.input, map_location=device) 37 | arch = checkpoint['arch'] 38 | 39 | # create the model architecture 40 | print('=> using model: ' + arch) 41 | model = models.__dict__[arch](pretrained=True) 42 | 43 | # reshape the model's output 44 | model = reshape_model(model, arch, checkpoint['num_classes']) 45 | 46 | # load the model weights 47 | model.load_state_dict(checkpoint['state_dict']) 48 | 49 | # add softmax or sigmoid layer 50 | if not args.no_activation: 51 | if checkpoint.get('multi_label', False): 52 | print('=> adding nn.Sigmoid layer to multi-label model') 53 | model = torch.nn.Sequential(model, torch.nn.Sigmoid()) 54 | else: 55 | print('=> adding nn.Softmax layer to model') 56 | model = torch.nn.Sequential(model, torch.nn.Softmax(1)) 57 | 58 | model.to(device) 59 | model.eval() 60 | 61 | print(model) 62 | 63 | # create example image data 64 | resolution = checkpoint['resolution'] 65 | input = torch.ones((1, 3, resolution, resolution)).to(device) 66 | print('=> input size: {:d}x{:d}'.format(resolution, resolution)) 67 | 68 | # format output model path 69 | if not args.output: 70 | args.output = arch + '.onnx' 71 | 72 | if args.model_dir and args.output.find('/') == -1 and args.output.find('\\') == -1: 73 | args.output = os.path.join(args.model_dir, args.output) 74 | 75 | # export the model 76 | input_names = [ "input_0" ] 77 | output_names = [ "output_0" ] 78 | 79 | print('=> exporting model to ONNX...') 80 | torch.onnx.export(model, input, args.output, verbose=True, input_names=input_names, output_names=output_names) 81 | print('=> model exported to: {:s}'.format(args.output)) 82 | 83 | 84 | -------------------------------------------------------------------------------- /onnx_validate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Check that an ONNX model is valid and well-formed. 
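# Typical invocation (the model path here is only an example):
#
#   $ python3 onnx_validate.py resnet18.onnx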
4 | # 5 | # Before running this script, install the following: 6 | # 7 | # $ sudo apt-get install protobuf-compiler libprotoc-dev 8 | # $ pip install onnx 9 | # 10 | import onnx 11 | import argparse 12 | import sys 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('model', type=str, nargs='?', default='resnet18.onnx', help='path to ONNX model to validate') 16 | args = parser.parse_args(sys.argv[1:]) 17 | 18 | # Load the ONNX model 19 | model = onnx.load(args.model) 20 | 21 | # Print a human readable representation of the graph 22 | print('Network Graph:') 23 | print(onnx.helper.printable_graph(model.graph)) 24 | print('') 25 | 26 | # Print model metadata 27 | print('ONNX version: ' + onnx.__version__) 28 | print('IR version: {:d}'.format(model.ir_version)) 29 | print('Producer name: ' + model.producer_name) 30 | print('Producer version: ' + model.producer_version) 31 | print('Model version: {:d}'.format(model.model_version)) 32 | print('') 33 | 34 | # Check that the IR is well formed 35 | print('Checking model IR...') 36 | onnx.checker.check_model(model) 37 | print('The model was checked successfully!') 38 | 39 | 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | tensorboard -------------------------------------------------------------------------------- /reshape.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn 4 | 5 | # 6 | # reshape the model for N classes 7 | # 8 | def reshape_model(model, arch, num_classes): 9 | """Reshape a model's output layers for the given number of classes""" 10 | 11 | # reshape output layers for the dataset 12 | if arch.startswith("resnet"): 13 | model.fc = torch.nn.Linear(model.fc.in_features, num_classes) 14 | print("=> reshaped ResNet fully-connected layer with: " + str(model.fc)) 15 | 16 | elif arch.startswith("alexnet"): 17 | model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes) 18 | print("=> reshaped AlexNet classifier layer with: " + str(model.classifier[6])) 19 | 20 | elif arch.startswith("vgg"): 21 | model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes) 22 | print("=> reshaped VGG classifier layer with: " + str(model.classifier[6])) 23 | 24 | elif arch.startswith("squeezenet"): 25 | model.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1)) 26 | model.num_classes = num_classes 27 | print("=> reshaped SqueezeNet classifier layer with: " + str(model.classifier[1])) 28 | 29 | elif arch.startswith("densenet"): 30 | model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes) 31 | print("=> reshaped DenseNet classifier layer with: " + str(model.classifier)) 32 | 33 | elif arch.startswith("efficientnet"): 34 | model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes) 35 | print(f"=> reshaped {arch} classifier layer with: " + str(model.classifier[1])) 36 | 37 | elif arch.startswith("mobilenet"): 38 | model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes) 39 | print(f"=> reshaped {arch} classifier layer with: " + str(model.classifier[-1])) 40 | 41 | elif arch.startswith("inception"): 42 | model.AuxLogits.fc = torch.nn.Linear(model.AuxLogits.fc.in_features, num_classes) 43 | model.fc = torch.nn.Linear(model.fc.in_features, num_classes) 44 | 45 | print("=> 
reshaped Inception aux-logits layer with: " + str(model.AuxLogits.fc)) 46 | print("=> reshaped Inception fully-connected layer with: " + str(model.fc)) 47 | 48 | elif arch.startswith("googlenet"): 49 | if model.aux_logits: 50 | from torchvision.models.googlenet import InceptionAux 51 | 52 | model.aux1 = InceptionAux(512, num_classes) 53 | model.aux2 = InceptionAux(528, num_classes) 54 | 55 | print("=> reshaped GoogLeNet aux-logits layers with: ") 56 | print(" " + str(model.aux1)) 57 | print(" " + str(model.aux2)) 58 | 59 | model.fc = torch.nn.Linear(model.fc.in_features, num_classes) 60 | print("=> reshaped GoogLeNet fully-connected layer with: " + str(model.fc)) 61 | 62 | else: 63 | raise ValueError(f"classifier reshaping not supported for {arch}") 64 | 65 | return model 66 | 67 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Note -- this training script is tweaked from the original at: 4 | # https://github.com/pytorch/examples/tree/master/imagenet 5 | # 6 | # For a step-by-step guide to transfer learning with PyTorch, see: 7 | # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html 8 | # 9 | import argparse 10 | import os 11 | import random 12 | 13 | import time 14 | import shutil 15 | import warnings 16 | import datetime 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.parallel 21 | import torch.nn.functional as F 22 | import torch.backends.cudnn as cudnn 23 | import torch.optim 24 | import torch.utils.data 25 | import torchvision.transforms as transforms 26 | import torchvision.datasets as datasets 27 | import torchvision.models as models 28 | 29 | from torch.utils.tensorboard import SummaryWriter 30 | 31 | from voc import VOCDataset 32 | from nuswide import NUSWideDataset 33 | from reshape import reshape_model 34 | 35 | 36 | # get the available network architectures 37 | model_names = sorted(name for name in models.__dict__ 38 | if name.islower() and not name.startswith("__") 39 | and callable(models.__dict__[name])) 40 | 41 | 42 | # parse command-line arguments 43 | parser = argparse.ArgumentParser(description='PyTorch Image Classifier Training') 44 | 45 | parser.add_argument('data', metavar='DIR', 46 | help='path to dataset') 47 | parser.add_argument('--dataset-type', type=str, default='folder', 48 | choices=['folder', 'nuswide', 'voc'], 49 | help='specify the dataset type (default: folder)') 50 | parser.add_argument('--multi-label', action='store_true', 51 | help='multi-label model (aka image tagging)') 52 | parser.add_argument('--multi-label-threshold', type=float, default=0.5, 53 | help='confidence threshold for counting a prediction as correct') 54 | parser.add_argument('--model-dir', type=str, default='models', 55 | help='path to desired output directory for saving model ' 56 | 'checkpoints (default: models/)') 57 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 58 | choices=model_names, 59 | help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)') 60 | parser.add_argument('--resolution', default=224, type=int, metavar='N', 61 | help='input NxN image resolution of model (default: 224x224) ' 62 | 'note that Inception models should use 299x299') 63 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 64 | help='number of data loading workers (default: 2)') 65 | parser.add_argument('--epochs', 
default=35, type=int, metavar='N', 66 | help='number of total epochs to run') 67 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 68 | help='manual epoch number (useful on restarts)') 69 | parser.add_argument('-b', '--batch-size', default=8, type=int, metavar='N', 70 | help='mini-batch size (default: 8)') 71 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 72 | metavar='LR', help='initial learning rate', dest='lr') 73 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 74 | help='momentum') 75 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 76 | metavar='W', help='weight decay (default: 1e-4)', 77 | dest='weight_decay') 78 | parser.add_argument('-p', '--print-freq', default=10, type=int, 79 | metavar='N', help='print frequency (default: 10)') 80 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 81 | help='path to latest checkpoint (default: none)') 82 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 83 | help='evaluate model on validation set') 84 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', default=True, 85 | help='use pre-trained model') 86 | parser.add_argument('--seed', default=None, type=int, 87 | help='seed for initializing training') 88 | parser.add_argument('--gpu', default=0, type=int, 89 | help='GPU ID to use (default: 0)') 90 | 91 | args = parser.parse_args() 92 | 93 | 94 | # open tensorboard logger (to model_dir/tensorboard) 95 | tensorboard = SummaryWriter(log_dir=os.path.join(args.model_dir, "tensorboard", f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")) 96 | print(f"To start tensorboard run: tensorboard --logdir={os.path.join(args.model_dir, 'tensorboard')}") 97 | 98 | # variable for storing the best model accuracy so far 99 | best_accuracy = 0 100 | 101 | 102 | def main(args): 103 | """ 104 | Load dataset, setup model, and train for N epochs 105 | """ 106 | global best_accuracy 107 | 108 | if args.seed is not None: 109 | random.seed(args.seed) 110 | torch.manual_seed(args.seed) 111 | cudnn.deterministic = True 112 | warnings.warn('You have chosen to seed training. ' 113 | 'This will turn on the CUDNN deterministic setting, ' 114 | 'which can slow down your training considerably! 
' 115 | 'You may see unexpected behavior when restarting ' 116 | 'from checkpoints.') 117 | 118 | if args.gpu is not None: 119 | print(f"=> using GPU {args.gpu} ({torch.cuda.get_device_name(args.gpu)})") 120 | 121 | # setup data transformations 122 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 123 | std=[0.229, 0.224, 0.225]) 124 | 125 | train_transforms = transforms.Compose([ 126 | transforms.RandomResizedCrop(args.resolution), 127 | transforms.RandomHorizontalFlip(), 128 | transforms.ToTensor(), 129 | normalize, 130 | ]) 131 | 132 | val_transforms = transforms.Compose([ 133 | transforms.Resize(args.resolution), 134 | transforms.CenterCrop(args.resolution), 135 | transforms.ToTensor(), 136 | normalize, 137 | ]) 138 | 139 | # load the dataset 140 | if args.dataset_type == 'folder': 141 | train_dataset = datasets.ImageFolder(os.path.join(args.data, 'train'), train_transforms) 142 | val_dataset = datasets.ImageFolder(os.path.join(args.data, 'val'), val_transforms) 143 | elif args.dataset_type == 'nuswide': 144 | train_dataset = NUSWideDataset(args.data, 'trainval', train_transforms) 145 | val_dataset = NUSWideDataset(args.data, 'test', val_transforms) 146 | elif args.dataset_type == 'voc': 147 | train_dataset = VOCDataset(args.data, 'trainval', train_transforms) 148 | val_dataset = VOCDataset(args.data, 'val', val_transforms) 149 | 150 | if (args.dataset_type == 'nuswide' or args.dataset_type == 'voc') and (not args.multi_label): 151 | raise ValueError("nuswide or voc datasets should be run with --multi-label") 152 | 153 | print(f"=> dataset classes: {len(train_dataset.classes)} {train_dataset.classes}") 154 | 155 | train_loader = torch.utils.data.DataLoader( 156 | train_dataset, batch_size=args.batch_size, shuffle=True, 157 | num_workers=args.workers, pin_memory=True) 158 | 159 | val_loader = torch.utils.data.DataLoader( 160 | val_dataset, batch_size=args.batch_size, shuffle=False, 161 | num_workers=args.workers, pin_memory=True) 162 | 163 | # create or load the model if using pre-trained (the default) 164 | if args.pretrained: 165 | print(f"=> using pre-trained model '{args.arch}'") 166 | model = models.__dict__[args.arch](pretrained=True) 167 | else: 168 | print(f"=> creating model '{args.arch}'") 169 | model = models.__dict__[args.arch]() 170 | 171 | # reshape the model for the number of classes in the dataset 172 | model = reshape_model(model, args.arch, len(train_dataset.classes)) 173 | 174 | # define loss function (criterion) and optimizer 175 | if args.multi_label: 176 | criterion = nn.BCEWithLogitsLoss() 177 | else: 178 | criterion = nn.CrossEntropyLoss() 179 | 180 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 181 | momentum=args.momentum, 182 | weight_decay=args.weight_decay) 183 | 184 | # transfer the model to the GPU that it should be run on 185 | if args.gpu is not None: 186 | torch.cuda.set_device(args.gpu) 187 | model = model.cuda(args.gpu) 188 | criterion = criterion.cuda(args.gpu) 189 | 190 | # optionally resume from a checkpoint 191 | if args.resume: 192 | if os.path.isfile(args.resume): 193 | print(f"=> loading checkpoint '{args.resume}'") 194 | checkpoint = torch.load(args.resume) 195 | args.start_epoch = checkpoint['epoch'] + 1 196 | #best_accuracy = checkpoint['best_accuracy'] 197 | #if args.gpu is not None: 198 | # best_accuracy = best_accuracy.to(args.gpu) # best_accuracy may be from a checkpoint from a different GPU 199 | model.load_state_dict(checkpoint['state_dict']) 200 | optimizer.load_state_dict(checkpoint['optimizer']) 201 | 
print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") 202 | else: 203 | print(f"=> no checkpoint found at '{args.resume}'") 204 | 205 | cudnn.benchmark = True 206 | 207 | # if in evaluation mode, only run validation 208 | if args.evaluate: 209 | validate(val_loader, model, criterion, 0) 210 | return 211 | 212 | # train for the specified number of epochs 213 | for epoch in range(args.start_epoch, args.epochs): 214 | # decay the learning rate 215 | adjust_learning_rate(optimizer, epoch) 216 | 217 | # train for one epoch 218 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch) 219 | 220 | # evaluate on validation set 221 | val_loss, val_acc = validate(val_loader, model, criterion, epoch) 222 | 223 | # remember best acc@1 and save checkpoint 224 | is_best = val_acc > best_accuracy 225 | best_accuracy = max(val_acc, best_accuracy) 226 | 227 | print(f"=> Epoch {epoch}") 228 | print(f" * Train Loss {train_loss:.4e}") 229 | print(f" * Train Accuracy {train_acc:.4f}") 230 | print(f" * Val Loss {val_loss:.4e}") 231 | print(f" * Val Accuracy {val_acc:.4f}{'*' if is_best else ''}") 232 | 233 | save_checkpoint({ 234 | 'epoch': epoch, 235 | 'arch': args.arch, 236 | 'resolution': args.resolution, 237 | 'classes': train_dataset.classes, 238 | 'num_classes': len(train_dataset.classes), 239 | 'multi_label': args.multi_label, 240 | 'state_dict': model.state_dict(), 241 | 'accuracy': {'train': train_acc, 'val': val_acc}, 242 | 'loss' : {'train': train_loss, 'val': val_loss}, 243 | 'optimizer' : optimizer.state_dict(), 244 | }, is_best) 245 | 246 | 247 | def train(train_loader, model, criterion, optimizer, epoch): 248 | """ 249 | Train one epoch over the dataset 250 | """ 251 | batch_time = AverageMeter('Time', ':6.3f') 252 | data_time = AverageMeter('Data', ':6.3f') 253 | losses = AverageMeter('Loss', ':.4e') 254 | acc = AverageMeter('Accuracy', ':7.3f') 255 | 256 | progress = ProgressMeter( 257 | len(train_loader), 258 | [batch_time, data_time, losses, acc], 259 | prefix=f"Epoch: [{epoch}]") 260 | 261 | # switch to train mode 262 | model.train() 263 | 264 | # get the start time 265 | epoch_start = time.time() 266 | end = epoch_start 267 | 268 | # train over each image batch from the dataset 269 | for i, (images, target) in enumerate(train_loader): 270 | # measure data loading time 271 | data_time.update(time.time() - end) 272 | 273 | if args.gpu is not None: 274 | images = images.cuda(args.gpu, non_blocking=True) 275 | target = target.cuda(args.gpu, non_blocking=True) 276 | 277 | # compute output 278 | output = model(images) 279 | loss = criterion(output, target) 280 | 281 | # record loss and measure accuracy 282 | losses.update(loss.item(), images.size(0)) 283 | acc.update(accuracy(output, target), images.size(0)) 284 | 285 | # compute gradient and do SGD step 286 | optimizer.zero_grad() 287 | loss.backward() 288 | optimizer.step() 289 | 290 | # measure elapsed time 291 | batch_time.update(time.time() - end) 292 | end = time.time() 293 | 294 | if i % args.print_freq == 0 or i == len(train_loader)-1: 295 | progress.display(i) 296 | 297 | print(f"Epoch: [{epoch}] completed, elapsed time {time.time() - epoch_start:6.3f} seconds") 298 | 299 | tensorboard.add_scalar('Loss/train', losses.avg, epoch) 300 | tensorboard.add_scalar('Accuracy/train', acc.avg, epoch) 301 | 302 | return losses.avg, acc.avg 303 | 304 | 305 | def validate(val_loader, model, criterion, epoch): 306 | """ 307 | Measure model performance across the val dataset 308 | """ 309 | batch_time = 
AverageMeter('Time', ':6.3f') 310 | losses = AverageMeter('Loss', ':.4e') 311 | acc = AverageMeter('Accuracy', ':7.3f') 312 | 313 | progress = ProgressMeter( 314 | len(val_loader), 315 | [batch_time, losses, acc], 316 | prefix='Val: ') 317 | 318 | # switch to evaluate mode 319 | model.eval() 320 | 321 | with torch.no_grad(): 322 | end = time.time() 323 | for i, (images, target) in enumerate(val_loader): 324 | if args.gpu is not None: 325 | images = images.cuda(args.gpu, non_blocking=True) 326 | target = target.cuda(args.gpu, non_blocking=True) 327 | 328 | # compute output 329 | output = model(images) 330 | loss = criterion(output, target) 331 | 332 | # record loss and measure accuracy 333 | losses.update(loss.item(), images.size(0)) 334 | acc.update(accuracy(output, target), images.size(0)) 335 | 336 | # measure elapsed time 337 | batch_time.update(time.time() - end) 338 | end = time.time() 339 | 340 | if i % args.print_freq == 0 or i == len(val_loader)-1: 341 | progress.display(i) 342 | 343 | tensorboard.add_scalar('Loss/val', losses.avg, epoch) 344 | tensorboard.add_scalar('Accuracy/val', acc.avg, epoch) 345 | 346 | return losses.avg, acc.avg 347 | 348 | 349 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar', labels_filename='labels.txt'): 350 | """ 351 | Save a model checkpoint file, along with the best-performing model if applicable 352 | """ 353 | if args.model_dir: 354 | model_dir = os.path.expanduser(args.model_dir) 355 | 356 | if not os.path.exists(model_dir): 357 | os.makedirs(model_dir, exist_ok=True) 358 | 359 | filename = os.path.join(model_dir, filename) 360 | best_filename = os.path.join(model_dir, best_filename) 361 | labels_filename = os.path.join(model_dir, labels_filename) 362 | 363 | # save the checkpoint 364 | torch.save(state, filename) 365 | 366 | # earmark the best checkpoint 367 | if is_best: 368 | shutil.copyfile(filename, best_filename) 369 | print(f"saved best model to: {best_filename}") 370 | else: 371 | print(f"saved checkpoint to: {filename}") 372 | 373 | # save labels.txt on the first epoch 374 | if state['epoch'] == 0: 375 | with open(labels_filename, 'w') as file: 376 | for label in state['classes']: 377 | file.write(f"{label}\n") 378 | print(f"saved class labels to: {labels_filename}") 379 | 380 | 381 | def adjust_learning_rate(optimizer, epoch): 382 | """ 383 | Sets the learning rate to the initial LR decayed by 10 every 30 epochs 384 | """ 385 | lr = args.lr * (0.1 ** (epoch // 30)) 386 | for param_group in optimizer.param_groups: 387 | param_group['lr'] = lr 388 | 389 | 390 | def accuracy(output, target): 391 | """ 392 | Computes the accuracy of predictions vs groundtruth 393 | """ 394 | with torch.no_grad(): 395 | if args.multi_label: 396 | output = torch.sigmoid(output) 397 | preds = ((output >= args.multi_label_threshold) == target.bool()) # https://medium.com/@yrodriguezmd/tackling-the-accuracy-multi-metric-9e2356f62513 398 | 399 | # https://stackoverflow.com/a/61585551 400 | #output[output >= args.multi_label_threshold] = 1 401 | #output[output < args.multi_label_threshold] = 0 402 | #preds = (output == target) 403 | else: 404 | output = F.softmax(output, dim=-1) 405 | _, preds = torch.max(output, dim=-1) 406 | preds = (preds == target) 407 | 408 | return preds.float().mean().cpu().item() * 100.0 409 | 410 | 411 | class AverageMeter(object): 412 | """ 413 | Computes and stores the average and current value 414 | """ 415 | def __init__(self, name, fmt=':f'): 416 | self.name = name 417 | self.fmt = fmt 418 | 
self.reset() 419 | 420 | def reset(self): 421 | self.val = 0 422 | self.avg = 0 423 | self.sum = 0 424 | self.count = 0 425 | 426 | def update(self, val, n=1): 427 | self.val = val 428 | self.sum += val * n 429 | self.count += n 430 | self.avg = self.sum / self.count 431 | 432 | def __str__(self): 433 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 434 | return fmtstr.format(**self.__dict__) 435 | 436 | 437 | class ProgressMeter(object): 438 | """ 439 | Progress metering 440 | """ 441 | def __init__(self, num_batches, meters, prefix=""): 442 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 443 | self.meters = meters 444 | self.prefix = prefix 445 | 446 | def display(self, batch): 447 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 448 | entries += [str(meter) for meter in self.meters] 449 | print(' '.join(entries)) 450 | 451 | def _get_batch_fmtstr(self, num_batches): 452 | num_digits = len(str(num_batches)) 453 | fmt = '{:' + str(num_digits) + 'd}' 454 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 455 | 456 | 457 | if __name__ == '__main__': 458 | main(args) 459 | -------------------------------------------------------------------------------- /voc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | import numpy as np 5 | import xml.dom.minidom 6 | 7 | from PIL import Image 8 | 9 | 10 | class VOCDataset(torch.utils.data.Dataset): 11 | """ 12 | Multi-label classification dataset for Pascal VOC (http://host.robots.ox.ac.uk/pascal/VOC/) 13 | This extracts objects (from the object detection benchmark) and uses them as image tags. 14 | """ 15 | def __init__(self, root, set, transform=None, target_transform=None, use_difficult=False): 16 | """ 17 | Load the dataset (tested on VOC2012) 18 | 19 | Parameters 20 | root (string) -- path to the VOC2007 or VOC2012 dataset, containing the following sub-directories: 21 | Annotations, ImageSets, JPEGImages, SegmentationClass, SegmentationObject 22 | 23 | set (string) -- the data subset, which corresponds to one of the files under ImageSets/Main/ 24 | like 'train', 'trainval', 'test', 'val', etc. 
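use_difficult (bool) -- when False (the default), objects annotated as 'difficult' in the XML are skipped when building the image tags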
25 | """ 26 | self.root = root 27 | self.set = set 28 | self.transform = transform 29 | self.target_transform = target_transform 30 | 31 | self.classes = [ 32 | 'aeroplane', 'bicycle', 'bird', 'boat', 33 | 'bottle', 'bus', 'car', 'cat', 'chair', 34 | 'cow', 'diningtable', 'dog', 'horse', 35 | 'motorbike', 'person', 'pottedplant', 36 | 'sheep', 'sofa', 'train', 'tvmonitor' 37 | ] 38 | 39 | # load the image ID's 40 | with open(os.path.join(self.root, f"ImageSets/Main/{set}.txt")) as file: 41 | self.id_list = file.read().splitlines() 42 | 43 | # load the labels 44 | self.labels = [] 45 | 46 | for id in self.id_list: 47 | xml_tree = xml.dom.minidom.parse(os.path.join(self.root, 'Annotations', f'{id}.xml')) 48 | xml_root = xml_tree.documentElement 49 | 50 | objects = xml_root.getElementsByTagName('object') 51 | labels = np.zeros(len(self.classes), dtype=np.float32) 52 | 53 | for obj in objects: 54 | if (not use_difficult) and (obj.getElementsByTagName('difficult')[0].firstChild.data == '1'): 55 | continue 56 | 57 | tag = obj.getElementsByTagName('name')[0].firstChild.data.lower() 58 | labels[self.classes.index(tag)] = 1.0 59 | 60 | self.labels.append(torch.from_numpy(labels)) 61 | 62 | def __getitem__(self, index): 63 | image = Image.open(os.path.join(self.root, 'JPEGImages', f'{self.id_list[index]}.jpg')).convert('RGB') 64 | labels = self.labels[index] 65 | 66 | if self.transform: 67 | image = self.transform(image) 68 | 69 | if self.target_transform: 70 | labels = self.target_transform(labels) 71 | 72 | return image, labels 73 | 74 | def __len__(self): 75 | return len(self.id_list) 76 | 77 | def get_class_distribution(self): 78 | with torch.no_grad(): 79 | distribution = torch.zeros(len(self.classes)) 80 | for labels in self.labels: 81 | distribution += labels 82 | 83 | return distribution.to(dtype=torch.int64).tolist() 84 | 85 | 86 | if __name__ == "__main__": 87 | import argparse 88 | 89 | parser = argparse.ArgumentParser() 90 | 91 | parser.add_argument('--data', type=str, default='.') 92 | parser.add_argument('--set', type=str, default='trainval') 93 | parser.add_argument('--load-data', action='store_true') 94 | parser.add_argument('--distribution', action='store_true') 95 | 96 | args = parser.parse_args() 97 | print(args) 98 | 99 | # load the dataset 100 | dataset = VOCDataset(args.data, args.set) 101 | print(f"=> loaded VOC dataset set={dataset.set} classes={len(dataset.classes)} images={len(dataset)}") 102 | 103 | # verify that all images load 104 | if args.load_data: 105 | for idx, (img, target) in enumerate(dataset): 106 | print(f"loaded image {idx} dims={img.size} classes={[dataset.classes[n] for n in target.nonzero(as_tuple=True)[0]]}") 107 | #print(f"labels: {target}") 108 | 109 | # get the class distributions 110 | if args.distribution: 111 | print("=> computing class distributions:") 112 | 113 | distribution = dataset.get_class_distribution() 114 | total_labels = 0 115 | 116 | for n, count in enumerate(distribution): 117 | print(f" class {n} {dataset.classes[n]} - {count}") 118 | total_labels += count 119 | 120 | print(f"=> loaded VOC dataset set={dataset.set} classes={len(dataset.classes)} images={len(dataset)} labels={total_labels}") 121 | --------------------------------------------------------------------------------