├── .gitignore
├── LICENSE
├── README.md
├── extensions
    ├── __init__.py
    ├── data_parallel.py
    ├── model_refinery_wrapper.py
    └── refinery_loss.py
├── imagenet.py
├── models
    ├── __init__.py
    ├── alexnet.py
    ├── blocks.py
    ├── model_factory.py
    └── resnet50.py
├── opts.py
├── requirements.txt
├── test.py
├── train.py
└── utils.py


/.gitignore:
--------------------------------------------------------------------------------
*.pyc
__pycache__


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The following license governs your NON-COMMERCIAL use of the Software. Commercial use is strictly prohibited.

This Software License Agreement (“License”) is a legal agreement between you and the Xnor.ai Inc. (“Xnor” or “we”) for the software made available in this GitHub repository, including all source code, and any related materials, documentation, files, media, and any updates or revisions we may provide (“Software”).

By using, copying, reproducing, adapting, modifying, distributing, or otherwise utilizing this Software, you agree to be bound by the terms of this License. If you do not agree, do not use the Software.

1. PERMISSIONS AND RESTRICTIONS:
(a) You may copy, reproduce, adapt, modify, distribute, or otherwise utilize this Software solely for any NON-COMMERCIAL purpose. Examples of acceptable, non-commercial, use include research and education, and the inclusion of the Software into materials used for those purposes.
(b) Any and all commercial use of this Software, or any adapted, modified, or derivative works thereof, is strictly prohibited. Prohibited commercial use includes, but is not limited to: selling, leasing, or licensing the Software for monetary or other commercial gain, using the Software in connection with business functions or operations, using Software to develop a commercial product, or embedding or installing the Software into products for your own commercial gain or for the commercial gain of third parties.
(c) If you are uncertain as to whether your contemplated use of the Software is permissible, please do not use this Software and instead contact Xnor for further information.

2. CONDITIONS FOR USE:
We require you also comply with the following conditions of use for the Software:
(a) Do not remove any copyright or other notices from the Software.
(b) Provide notice that this Software is used under this License, and that all use is for non-commercial purposes.
(c) Any distribution of the Software or any derivative works of the Software must be under the same terms and conditions as in this License. Furthermore, the disclaimer of warranties in Section 3(a) of this License must be included in any distribution of the Software to third parties.
If you adapt, modify, or create derivative versions of the Software, indicate that you’ve modified the original Software and provide the date(s) of such modifications.
(d) Grant back to Xnor, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, distribute, make and have made, and transfer your modifications to and/or derivative works of the Software source code or data, for any purpose.
Acknowledge that any feedback about the Software provided by you to us is voluntarily given, and Xnor shall be free to use the feedback as it sees fit without obligation or restriction of any kind.

4. ADDITIONAL TERMS:
(a) THE SOFTWARE COMES “AS IS”, WITH NO WARRANTIES. THIS MEANS NO EXPRESS, IMPLIED OR STATUTORY WARRANTY, INCLUDING WITHOUT LIMITATION, WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE, ANY WARRANTY AGAINST INTERFERENCE WITH YOUR ENJOYMENT OF THE SOFTWARE OR ANY WARRANTY OF TITLE OR NON-INFRINGEMENT. THERE IS NO WARRANTY THAT THIS SOFTWARE WILL FULFILL ANY OF YOUR PARTICULAR PURPOSES OR NEEDS.
(b) NEITHER XNOR NOR ANY CONTRIBUTOR TO THE SOFTWARE WILL BE LIABLE FOR ANY DAMAGES RELATED TO THE SOFTWARE OR THIS LICENSE, INCLUDING DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL OR INCIDENTAL DAMAGES, TO THE MAXIMUM EXTENT THE LAW PERMITS, NO MATTER WHAT LEGAL THEORY IT IS BASED ON. ALSO, YOU MUST PASS THIS LIMITATION OF LIABILITY ON WHENEVER YOU DISTRIBUTE THE SOFTWARE OR DERIVATIVE WORKS.
(c) We have no duty of reasonable care or lack of negligence, and we are not obligated to (and will not) provide technical support for the Software.
(d) If you breach this License or if you sue anyone over any intellectual property regarding the Software, or another party’s use thereof, this License and your rights herein shall terminate automatically. We may also terminate this license for convenience upon notice to you. Upon termination of this License for any reason, you will immediately cease all use and distribution of the Software and destroy any copies or portions of the Software in your possession. Section 4 of this License shall survive any termination of this License.
(e) As between the parties, you acknowledge that Xnor owns all copyright and other intellectual property rights in the Software, including Software or portions thereof incorporated in derivative works made under this License. .
(f) The Software may be subject to U.S. export jurisdiction at the time it is licensed to you, and it may be subject to additional export or import laws in other places. You agree to comply with all such laws and regulations that may apply to the Software after delivery of the software to you.
(g) All rights not expressly granted to you in this License are reserved. This License does not convey an ownership of any right to you or any third party. There are no implied licenses under this Licence.
(h) This License shall be construed and controlled by the laws of the State of Washington, USA, without regard to conflicts of law. If any provision of this License shall be deemed unenforceable or contrary to law, the rest of this License shall remain in full effect and interpreted in an enforceable manner that most nearly captures the intent of the original language.


By downloading this software you acknowledge that you read and agreed all the terms in this license.


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Label Refinery: *Improving ImageNet Classification through Label Progression*
By [Hessam Bagherinezhad](http://homes.cs.washington.edu/~hessam/),
[Maxwell Horton](http://homes.cs.washington.edu/~mchorton/),
[Mohammad Rastegari](http://www.umiacs.umd.edu/~mrastega/),
and [Ali Farhadi](http://homes.cs.washington.edu/~ali/).

### Introduction

This is a PyTorch training script that can be used to train image classifiers on
ImageNet. The purpose of this repository is to back the experimental results
presented in the Label Refinery paper, which is
[published on arXiv](https://arxiv.org/abs/1805.02641).

Label Refinery is a training mechanism that can be used to train any
classification model. Label Refinery improves the quality of the labels, and
therefore the quality of the models trained with those labels.
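
Concretely, the refined label for a training crop is the refinery network's softmax output for that crop. The model being trained minimizes the cross-entropy between that soft target and its own predicted distribution. The repository does this with PyTorch tensors in `extensions/refinery_loss.py`. The sketch below restates the objective in plain Python, with no dependencies, assuming each batch is a list of per-class score lists. The helper names (`softmax`, `log_softmax`, `soft_cross_entropy`) are hypothetical and are not part of this repository.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn raw (pre-softmax) scores into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Log-probabilities computed with the log-sum-exp trick."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def soft_cross_entropy(model_logits: Sequence[Sequence[float]],
                       refined_labels: Sequence[Sequence[float]]) -> float:
    """Batch mean of -sum_c q_c * log p_c (hypothetical helper).

    q is a refined label (each row sums to one) and p is the softmax of the
    model's pre-softmax output for the same crop.
    """
    assert len(model_logits) == len(refined_labels) > 0
    total = 0.0
    for logits, q in zip(model_logits, refined_labels):
        log_p = log_softmax(logits)
        total += -sum(qc * lpc for qc, lpc in zip(q, log_p))
    return total / len(model_logits)


# Usage: the refinery's scores for one crop become a soft target.
refinery_logits = [[4.0, 1.0, 0.0]]
refined = [softmax(row) for row in refinery_logits]
loss = soft_cross_entropy([[2.0, 0.5, -1.0]], refined)
print(loss)
```

With a one-hot row in `refined_labels`, this reduces to the standard cross-entropy against the ground-truth class. With soft targets, it differs from the KL divergence only by the entropy of the refined labels. That entropy does not depend on the trained model's parameters, so both losses give the same gradients.
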

Using Label Refinery improves the state-of-the-art accuracy of a variety of
network architectures:

Model          | Paper Number :: Top-1 | Our Impl. :: Top-1 | Label Refinery :: Top-1
-------------- |:---------------------:|:-------------------:|:-----------------------:
AlexNet        | 59.3                  | 57.93               | **66.28**
MobileNet      | 70.6                  | 68.53               | **73.39**
MobileNet0.75  | 68.4                  | 65.93               | **70.92**
MobileNet0.5   | 63.7                  | 63.03               | **66.66**
MobileNet0.25  | 50.6                  | 50.65               | **54.62**
ResNet-50      | N/A                   | 75.7                | **76.5**
ResNet-34      | N/A                   | 73.39               | **75.06**
ResNet-18      | N/A                   | 69.7                | **72.52**
ResNetXnor-50  | N/A                   | 63.1                | **70.34**
VGG-16         | 73                    | 70.1                | **75**
VGG-19         | 72.7                  | 71.39               | **75.46**
Darknet19      | 72.9                  | 70.6                | **74.47**

For the complete list of results and some analysis, please refer to
[our paper](https://arxiv.org/abs/1805.02641).

### Usage
#### Prerequisites
To use this source code you need Python 3.5+, a copy of the ImageNet 2012
dataset, and a few Python 3 packages. The full set of Python dependencies is
listed in [requirements.txt](requirements.txt) for CUDA 8 users. If you're not
using CUDA, or are using a different version of CUDA, change the `torch==0.4.0`
line to the URL of your desired PyTorch 0.4 wheel. You can install them all
through pip3:
```
pip3 install -r requirements.txt
```

#### Train models

You can train models either with the standard labels, or with the refined
labels.
To train `AlexNet` with the standard labels:
```
./train.py --model AlexNet --imagenet /path/to/imagenet2012
```
To train `AlexNet` with refined labels generated by a trained `AlexNet` Label
Refinery:
```
./train.py --model AlexNet --imagenet /path/to/imagenet2012 --label-refinery-model AlexNet --label-refinery-state-file /path/to/trained/alexnet.pytar
```

#### Test models

To test a trained AlexNet model:
```
./test.py --model AlexNet --model-state-file /path/to/alexnet.pytar --imagenet /path/to/imagenet2012
```


#### Pre-trained weights

Model                  | Description                                           | Top-1  | Link
-------------------- |:-----------------------------------------------------:|:------:|:------:
`AlexNet^1`            | AlexNet trained with standard labels.                 | 57.93  | [get](https://storage.googleapis.com/xnorai-public/downloads/label-refinery/alexnet%5E1.pytar)
`AlexNet^2`            | AlexNet trained with labels refined by `AlexNet^1`.   | 59.97  | [get](https://storage.googleapis.com/xnorai-public/downloads/label-refinery/alexnet%5E2.pytar)
`AlexNet^3`            | AlexNet trained with labels refined by `AlexNet^2`.   | 60.87  | [get](https://storage.googleapis.com/xnorai-public/downloads/label-refinery/alexnet%5E3.pytar)
`AlexNet^4`            | AlexNet trained with labels refined by `AlexNet^3`.   | 61.22  | [get](https://storage.googleapis.com/xnorai-public/downloads/label-refinery/alexnet%5E4.pytar)
`AlexNet^5`            | AlexNet trained with labels refined by `AlexNet^4`.   | 61.37  | [get](https://storage.googleapis.com/xnorai-public/downloads/label-refinery/alexnet%5E5.pytar)
`AlexNet By ResNet-50` | AlexNet trained with labels refined by `ResNet-50`.   | 66.28  | [get](https://storage.googleapis.com/xnorai-public/downloads/label-refinery/alexnet-from-resnet50.pytar)
`ResNet-50`            | ResNet-50 trained with standard labels.               | 75.7   | [get](https://storage.googleapis.com/xnorai-public/downloads/label-refinery/resnet50.pytar)

#### License
By downloading this software you acknowledge that you have read and agreed to
all the terms in the `LICENSE` file.


--------------------------------------------------------------------------------
/extensions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hessamb/label-refinery/e64f194df362d6c6b9a3250948620a2d8b003894/extensions/__init__.py
--------------------------------------------------------------------------------
/extensions/data_parallel.py:
--------------------------------------------------------------------------------
__author__ = "Hessam Bagherinezhad "

from torch import nn


class DataParallel(nn.DataParallel):
    """An extension of nn.DataParallel.

    The only extensions are:
        1) If an attribute is missing in an object of this class, it will look
           for it in the wrapped module. This is useful for getting `LR_REGIME`
           of the wrapped module for example.
        2) state_dict() of this class calls the wrapped module's state_dict(),
           hence the weights can be transferred from a data parallel wrapped
           module to a single gpu module.
    """

    def __getattr__(self, name):
        # This method is only called if the attribute doesn't exist in the
        # DataParallel object. We first ask the superclass for the attribute;
        # if it can't find it, we ask the underlying module wrapped by this
        # DataParallel to get the attribute.
        try:
            return super().__getattr__(name)
        except AttributeError:
            underlying_module = super().__getattr__('module')
            return getattr(underlying_module, name)

    def state_dict(self, *args, **kwargs):
        return self.module.state_dict(*args, **kwargs)


--------------------------------------------------------------------------------
/extensions/model_refinery_wrapper.py:
--------------------------------------------------------------------------------
from torch import nn
from torch.nn import functional as F


class ModelRefineryWrapper(nn.Module):
    """Convenient wrapper class to train a model with a label refinery."""

    def __init__(self, model, label_refinery):
        super().__init__()
        self.model = model
        self.label_refinery = label_refinery

        # Since we don't want to back-prop through the label_refinery network,
        # make the parameters of the teacher network not require gradients. This
        # saves some GPU memory.
        for param in self.label_refinery.parameters():
            param.requires_grad = False

    @property
    def LR_REGIME(self):
        # Training with a label refinery does not change the learning rate
        # regime. Return the wrapped model's lr regime.
        return self.model.LR_REGIME

    def state_dict(self):
        return self.model.state_dict()

    def forward(self, input):
        if self.training:
            refined_labels = self.label_refinery(input)
            refined_labels = F.softmax(refined_labels, dim=1)
            model_output = self.model(input)
            return (model_output, refined_labels)
        else:
            return self.model(input)


--------------------------------------------------------------------------------
/extensions/refinery_loss.py:
--------------------------------------------------------------------------------
__author__ = "Hessam Bagherinezhad "

import torch
from torch.nn import functional as F
from torch.nn.modules import loss


class RefineryLoss(loss._Loss):
    """The KL-Divergence loss for the model and refined labels output.

    output must be a pair of (model_output, refined_labels), both NxC tensors.
    The rows of refined_labels must all add up to one (probability scores);
    however, model_output must be the pre-softmax output of the network."""

    def forward(self, output, target):
        if not self.training:
            # Loss is normal cross entropy loss between the model output and the
            # target.
            return F.cross_entropy(output, target,
                                   size_average=self.size_average)

        assert type(output) == tuple and len(output) == 2 and output[0].size() == \
            output[1].size(), "output must be a pair of tensors of same size."

        # Target is ignored at training time. Loss is the cross entropy between
        # the refined labels and the model output. It differs from their KL
        # divergence only by the entropy of the refined labels, which is
        # constant with respect to the model's parameters, so both give the
        # same gradients.
        model_output, refined_labels = output
        if refined_labels.requires_grad:
            raise ValueError("Refined labels should not require gradients.")

        model_output_log_prob = F.log_softmax(model_output, dim=1)
        del model_output

        # Loss is -dot(model_output_log_prob, refined_labels). Prepare tensors
        # for batch matrix multiplication.
        refined_labels = refined_labels.unsqueeze(1)
        model_output_log_prob = model_output_log_prob.unsqueeze(2)

        # Compute the loss, and average/sum for the batch.
        cross_entropy_loss = -torch.bmm(refined_labels, model_output_log_prob)
        if self.size_average:
            cross_entropy_loss = cross_entropy_loss.mean()
        else:
            cross_entropy_loss = cross_entropy_loss.sum()
        # Return a pair of (loss_output, model_output). Model output will be
        # used for top-1 and top-5 evaluation.
        model_output_log_prob = model_output_log_prob.squeeze(2)
        return (cross_entropy_loss, model_output_log_prob)


--------------------------------------------------------------------------------
/imagenet.py:
--------------------------------------------------------------------------------
"""Dataset class for loading imagenet data."""

import os

from torch.utils import data as data_utils
from torchvision import datasets as torch_datasets
from torchvision import transforms


def get_train_loader(imagenet_path, batch_size, num_workers):
    train_dataset = ImageNet(imagenet_path, is_train=True)
    return data_utils.DataLoader(
        train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True,
        num_workers=num_workers)


def get_val_loader(imagenet_path, batch_size, num_workers):
    val_dataset = ImageNet(imagenet_path, is_train=False)
    return data_utils.DataLoader(
        val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True,
        num_workers=num_workers)


class ImageNet(torch_datasets.ImageFolder):
    """Dataset class for ImageNet dataset.

    Arguments:
        root_dir (str): Path to the dataset root directory, which must contain
            train/ and val/ directories.
        is_train (bool): Whether to read training or validation images.
    """
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root_dir, is_train):
        if is_train:
            root_dir = os.path.join(root_dir, 'train')
            transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(ImageNet.MEAN, ImageNet.STD),
            ])
        else:
            root_dir = os.path.join(root_dir, 'val')
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(ImageNet.MEAN, ImageNet.STD),
            ])
        super().__init__(root_dir, transform=transform)


--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hessamb/label-refinery/e64f194df362d6c6b9a3250948620a2d8b003894/models/__init__.py
--------------------------------------------------------------------------------
/models/alexnet.py:
--------------------------------------------------------------------------------
"""AlexNet architecture pytorch model."""

from torch import nn

from models import blocks


class AlexNet(nn.Module):
    """This is the original AlexNet architecture, and not the version introduced
    in the "one weird trick" paper."""
    LR_REGIME = [1, 140, 0.01, 141, 170, 0.001, 171, 200, 0.0001]

    def __init__(self):
        super().__init__()
        self.conv1 = blocks.Conv2dBnRelu(3, 96, 11, 4, 2,
                                         pooling=nn.MaxPool2d(2))
        self.conv2 = blocks.Conv2dBnRelu(96, 256, 5, 1, 2,
                                         pooling=nn.MaxPool2d(2))
        self.conv3 = blocks.Conv2dBnRelu(256, 384, 3, 1, 1)
        self.conv4 = blocks.Conv2dBnRelu(384, 384, 3, 1, 1)
        self.conv5 = blocks.Conv2dBnRelu(384, 256, 3, 1, 1,
                                         pooling=nn.MaxPool2d(2))

        self.fc6 = blocks.LinearBnRelu(256 * 6 * 6, 4096)
        self.fc7 = blocks.LinearBnRelu(4096, 4096)
        self.fc8 = nn.Linear(4096, 1000, bias=False)

    def convolutions(self, x):
        return nn.Sequential(self.conv1, self.conv2, self.conv3, self.conv4,
                             self.conv5)(x)

    def fully_connecteds(self, x):
        return nn.Sequential(self.fc6, self.fc7, self.fc8)(x)

    def forward(self, x):
        x = self.convolutions(x)
        x = x.view(x.size(0), -1)
        x = self.fully_connecteds(x)
        return x


--------------------------------------------------------------------------------
/models/blocks.py:
--------------------------------------------------------------------------------
"""A list of commonly used building blocks."""

from torch import nn


class Conv2dBnRelu(nn.Module):
    """A commonly used building block: Conv -> BN -> (optional pooling) -> ReLU"""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, bias=True, pooling=None,
                 activation=nn.ReLU(inplace=True)):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels)
        self.pooling = pooling
        self.activation = activation

    def forward(self, x):
        x = self.bn(self.conv(x))
        if self.pooling is not None:
            x = self.pooling(x)
        return self.activation(x)


class LinearBnRelu(nn.Module):
    """A commonly used building block: FC -> BN -> ReLU"""

    def __init__(self, in_features, out_features, bias=True,
                 activation=nn.ReLU(inplace=True)):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.bn = nn.BatchNorm1d(out_features)
        self.activation = activation

    def forward(self, x):
        return self.activation(self.bn(self.linear(x)))


--------------------------------------------------------------------------------
/models/model_factory.py:
--------------------------------------------------------------------------------
"""Utility functions to construct a model."""

import torch
from torch import nn

from extensions import data_parallel
from extensions import model_refinery_wrapper
from extensions import refinery_loss
from models import alexnet
from models import resnet50


MODEL_NAME_MAP = {
    'AlexNet': alexnet.AlexNet,
    'ResNet50': resnet50.ResNet50,
}


def _create_single_cpu_model(model_name, state_file=None):
    if model_name not in MODEL_NAME_MAP:
        raise ValueError("Model {} is invalid. Pick from {}.".format(
            model_name, sorted(MODEL_NAME_MAP.keys())))
    model_class = MODEL_NAME_MAP[model_name]
    model = model_class()
    if state_file is not None:
        # Load all tensors onto the CPU so that checkpoints saved on a GPU can
        # also be loaded on a CPU-only machine; create_model() moves the model
        # to the GPU afterwards if GPUs are requested.
        model.load_state_dict(torch.load(
            state_file, map_location=lambda storage, location: storage))
    return model


def create_model(model_name, model_state_file=None, gpus=[], label_refinery=None,
                 label_refinery_state_file=None):
    model = _create_single_cpu_model(model_name, model_state_file)
    if label_refinery is not None:
        assert label_refinery_state_file is not None, "Refinery state is None."
        label_refinery = _create_single_cpu_model(
            label_refinery, label_refinery_state_file)
        model = model_refinery_wrapper.ModelRefineryWrapper(model, label_refinery)
        loss = refinery_loss.RefineryLoss()
    else:
        loss = nn.CrossEntropyLoss()

    if len(gpus) > 0:
        model = model.cuda()
        loss = loss.cuda()
    if len(gpus) > 1:
        model = data_parallel.DataParallel(model, device_ids=gpus)
    return model, loss


--------------------------------------------------------------------------------
/models/resnet50.py:
--------------------------------------------------------------------------------
"""ResNet-50 architecture pytorch model."""

from torch import nn
from torch.nn import functional as F


class ResNet50(nn.Module):
    LR_REGIME = [1, 140, 0.1, 141, 170, 0.01, 171, 200, 0.001]

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self._current_planes = 64
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(planes=64, num_layers=3)
        self.layer2 = self._make_layer(planes=128, num_layers=4, stride=2)
        self.layer3 = self._make_layer(planes=256, num_layers=6, stride=2)
        self.layer4 = self._make_layer(planes=512, num_layers=3, stride=2)
        self.fc = nn.Linear(512 * 4, 1000)

    def _make_layer(self, planes, num_layers, stride=1):
        layers = []
        # Add blocks one by one
        for i in range(num_layers):
            # Apply the stride on the first block of the series
            block_stride = stride if i == 0 else 1
            # If the input size or channel count is changing, downsample the
            # residual with a strided 1x1 convolution.
            if block_stride != 1 or self._current_planes != planes * 4:
                downsample = nn.Sequential(
                    nn.Conv2d(self._current_planes, planes * 4, kernel_size=1,
                              stride=block_stride, bias=False),
                    nn.BatchNorm2d(planes * 4))
            else:
                downsample = None
            layers.append(Bottleneck(
                self._current_planes, planes, stride=block_stride,
                downsample=downsample))
            self._current_planes = planes * 4
        # Make a sequential of all blocks
        return nn.Sequential(*layers)

    def classifier(self, x):
        return self.fc(x)

    def feats(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = F.avg_pool2d(x, kernel_size=7)
        x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        x = self.feats(x)
        x = self.classifier(x)
        return x


class Bottleneck(nn.Module):

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.downsample = downsample

    def forward(self, x):
        if self.downsample is not None:
            # Compute residual at the beginning so we can free the memory of x,
            # if not needed.
            residual = self.downsample(x)
        else:
            residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x, inplace=True)

        x = self.conv3(x)
        x = self.bn3(x)

        x += residual
        x = F.relu(x, inplace=True)
        return x


--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
from models import model_factory


def add_general_flags(parser):
    parser.add_argument('--save', default='checkpoints',
                        help="Path to the directory to save logs and "
                        "checkpoints.")
    parser.add_argument('--gpus', '--gpu', nargs='+', default=[0], type=int,
                        help="The GPU(s) on which the model should run. The "
                        "first GPU will be the main one.")
    parser.add_argument('--cpu', action='store_const', const=[],
                        dest='gpus', help="If set, no gpus will be used.")


def add_dataset_flags(parser):
    parser.add_argument('--imagenet', required=True, help="Path to ImageNet's "
                        "root directory holding 'train/' and 'val/' "
                        "directories.")
    parser.add_argument('--batch-size', default=256, help="Batch size to use "
                        "distributed over all GPUs.", type=int)
    parser.add_argument('--num-workers', '-j', default=8, help="Number of "
                        "data loading processes to use for loading data and "
                        "transforming.", type=int)


def add_model_flags(parser):
    parser.add_argument('--model', required=True, help="The model architecture "
                        "name.", choices=sorted(model_factory.MODEL_NAME_MAP.keys()))
    parser.add_argument('--model-state-file', default=None, help="Path to model"
                        " state file to initialize the model.")


def add_label_refinery_flags(parser):
    parser.add_argument('--label-refinery-model', default=None, help="The "
                        "model that will generate refined labels per crop.",
                        choices=sorted(model_factory.MODEL_NAME_MAP.keys()))
    parser.add_argument('--label-refinery-state-file', default=None,
                        help="Path to label refinery model state file.")


def add_training_flags(parser):
    parser.add_argument('--lr-regime', default=None, nargs='+', type=float,
                        help="If set, it will override the default learning "
                        "rate regime of the model. The regime must be given "
                        "as a flat list of [start, end, lr, ...] triples.")
    parser.add_argument('--momentum', default=0.9, type=float,
                        help="The momentum of the optimization.")
    parser.add_argument('--weight-decay', default=0, type=float,
                        help="The weight decay of the optimization.")


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.14.3
Pillow==5.1.0
six==1.11.0
# Change the next line to your desired pytorch 0.4 wheel url, if you're not
# using cuda 8.
torch==0.4.0
torchvision==0.2.1


--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""Script to test a pytorch model on ImageNet's validation set."""

import argparse
import logging
import pprint
import sys
import time

import torch

import imagenet
from models import model_factory
import opts
import utils


def parse_args(argv):
    """Parse arguments @argv and return the flags needed for testing."""
    parser = argparse.ArgumentParser(description=__doc__, allow_abbrev=False)

    group = parser.add_argument_group('General Options')
    opts.add_general_flags(group)

    group = parser.add_argument_group('Dataset Options')
    opts.add_dataset_flags(group)

    group = parser.add_argument_group('Model Options')
    opts.add_model_flags(group)

    args = parser.parse_args(argv)

    if args.model_state_file is None:
        parser.error("You should set --model-state-file to reload a model "
                     "state.")

    return args


def test_for_one_epoch(model, loss, test_loader, epoch_number):
    model.eval()
    loss.eval()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    timestamp = time.time()
    for i, (images, labels) in enumerate(test_loader):
        batch_size = images.size(0)

        if utils.is_model_cuda(model):
            # `async` is a reserved keyword since Python 3.7; PyTorch 0.4
            # renamed this argument to `non_blocking`.
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

        # Record data time
        data_time_meter.update(time.time() - timestamp)

        # Forward pass without computing gradients.
        with torch.no_grad():
            outputs = model(images)
            loss_output = loss(outputs, labels)

        # Sometimes the loss function returns a modified version of the output,
        # which must be used to compute the model accuracy.
        if isinstance(loss_output, tuple):
            loss_value, outputs = loss_output
        else:
            loss_value = loss_output

        # Record loss and model accuracy.
        loss_meter.update(loss_value.item(), batch_size)
        top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

        # Record batch time
        batch_time_meter.update(time.time() - timestamp)
        timestamp = time.time()

        logging.info(
            'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
            'Time {batch_time.value:.2f} ({batch_time.average:.2f}) '
            'Data {data_time.value:.2f} ({data_time.average:.2f}) '
            'Loss {loss.value:.3f} {{{loss.average:.3f}, {loss.average_recent:.3f}}} '
            'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}} '
            'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}} '.format(
                epoch=epoch_number, batch=i + 1, epoch_size=len(test_loader),
                batch_time=batch_time_meter, data_time=data_time_meter,
                loss=loss_meter, top1=top1_meter, top5=top5_meter))
    # Log the overall test stats
    logging.info(
        'Epoch: [{epoch}] -- TESTING SUMMARY\t'
        'Time {batch_time.sum:.2f} '
        'Data {data_time.sum:.2f} '
        'Loss {loss.average:.3f} '
        'Top-1 {top1.average:.2f} '
        'Top-5 {top5.average:.2f} '.format(
            epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter,
            loss=loss_meter, top1=top1_meter, top5=top5_meter))


def main(argv):
    """Run the test script with command line arguments @argv."""
    args = parse_args(argv)
    utils.general_setup(args.save, args.gpus)

    logging.info("Arguments parsed.\n{}".format(pprint.pformat(vars(args))))

    # Create the validation data loader.
    val_loader = imagenet.get_val_loader(args.imagenet, args.batch_size,
                                         args.num_workers)
    # Create model and the loss.
    model, loss = model_factory.create_model(
        args.model, args.model_state_file, args.gpus)
    logging.info("Model:\n{}".format(model))

    # Test for one epoch.
    test_for_one_epoch(model, loss, val_loader, epoch_number=1)


if __name__ == '__main__':
    main(sys.argv[1:])


--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""Script to train a model through refined labels on ImageNet's train set."""

import argparse
import logging
import pprint
import os
import sys
import time

import torch
from torch import nn

import imagenet
from models import model_factory
import opts
import test
import utils


def parse_args(argv):
    """Parse arguments @argv and return the flags needed for training."""
    parser = argparse.ArgumentParser(description=__doc__, allow_abbrev=False)

    group = parser.add_argument_group('General Options')
    opts.add_general_flags(group)

    group = parser.add_argument_group('Dataset Options')
    opts.add_dataset_flags(group)

    group = parser.add_argument_group('Model Options')
    opts.add_model_flags(group)

    group = parser.add_argument_group('Label Refinery Options')
    opts.add_label_refinery_flags(group)

    group = parser.add_argument_group('Training Options')
    opts.add_training_flags(group)

    args = parser.parse_args(argv)

    if args.label_refinery_model is not None and args.label_refinery_state_file is None:
        parser.error("You should set --label-refinery-state-file
if " 44 | "--label-refinery-model is set.") 45 | 46 | return args 47 | 48 | 49 | class LearningRateRegime: 50 | """Encapsulates the learning rate regime for training a model. 51 | 52 | Args: 53 | @regime (list): A flat list of numbers that is grouped into triples 54 | (start, end, lr). The intervals are inclusive (lr is used for 55 | start <= epoch <= end). The first interval must start at 1, and each 56 | later interval must start right after the end of the previous one. 57 | """ 58 | 59 | def __init__(self, regime): 60 | if len(regime) % 3 != 0: 61 | raise ValueError("Regime length should be divisible by 3.") 62 | intervals = list(zip(regime[0::3], regime[1::3], regime[2::3])) 63 | self._validate_intervals(intervals) 64 | self.intervals = intervals 65 | self.num_epochs = intervals[-1][1] 66 | 67 | @classmethod 68 | def _validate_intervals(cls, intervals): 69 | if type(intervals) is not list: 70 | raise TypeError("Intervals must be a list of triples.") 71 | elif len(intervals) == 0: 72 | raise ValueError("Intervals must be a non-empty list.") 73 | elif intervals[0][0] != 1: 74 | raise ValueError("Intervals must start from 1: {}".format(intervals)) 75 | elif any(end < start for (start, end, lr) in intervals): 76 | raise ValueError("End of intervals must be greater than or equal to their" 77 | " start: {}".format(intervals)) 78 | elif any(intervals[i][1] + 1 != intervals[i + 1][0] 79 | for i in range(len(intervals) - 1)): 80 | raise ValueError("Start of each interval must be the end of its " 81 | "previous interval plus one: {}".format(intervals)) 82 | 83 | def get_lr(self, epoch): 84 | for (start, end, lr) in self.intervals: 85 | if start <= epoch <= end: 86 | return lr 87 | raise ValueError("Invalid epoch {} for regime {!r}".format( 88 | epoch, self.intervals)) 89 | 90 | 91 | def _set_learning_rate(optimizer, lr): 92 | for param_group in optimizer.param_groups: 93 | param_group['lr'] = lr 94 | 95 | 96 | def _get_learning_rate(optimizer): 97 | return max(param_group['lr'] for param_group in optimizer.param_groups) 98
| 99 | 100 | def train_for_one_epoch(model, loss, train_loader, optimizer, epoch_number): 101 | model.train() 102 | loss.train() 103 | 104 | data_time_meter = utils.AverageMeter() 105 | batch_time_meter = utils.AverageMeter() 106 | loss_meter = utils.AverageMeter(recent=100) 107 | top1_meter = utils.AverageMeter(recent=100) 108 | top5_meter = utils.AverageMeter(recent=100) 109 | 110 | timestamp = time.time() 111 | for i, (images, labels) in enumerate(train_loader): 112 | batch_size = images.size(0) 113 | 114 | if utils.is_model_cuda(model): 115 | images = images.cuda(non_blocking=True) 116 | labels = labels.cuda(non_blocking=True) 117 | 118 | # Record data time 119 | data_time_meter.update(time.time() - timestamp) 120 | 121 | # Forward pass, backward pass, and update parameters. 122 | outputs = model(images) 123 | loss_output = loss(outputs, labels) 124 | 125 | # Sometimes the loss function returns a modified version of the output, 126 | # which must be used to compute the model accuracy. 127 | if isinstance(loss_output, tuple): 128 | loss_value, outputs = loss_output 129 | else: 130 | loss_value = loss_output 131 | loss_value.backward() 132 | 133 | # Update parameters and reset gradients. 134 | optimizer.step() 135 | optimizer.zero_grad() 136 | 137 | # Record loss and model accuracy.
138 | loss_meter.update(loss_value.item(), batch_size) 139 | top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5)) 140 | top1_meter.update(top1, batch_size) 141 | top5_meter.update(top5, batch_size) 142 | 143 | # Record batch time 144 | batch_time_meter.update(time.time() - timestamp) 145 | timestamp = time.time() 146 | 147 | logging.info( 148 | 'Epoch: [{epoch}][{batch}/{epoch_size}]\t' 149 | 'Time {batch_time.value:.2f} ({batch_time.average:.2f}) ' 150 | 'Data {data_time.value:.2f} ({data_time.average:.2f}) ' 151 | 'Loss {loss.value:.3f} {{{loss.average:.3f}, {loss.average_recent:.3f}}} ' 152 | 'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}} ' 153 | 'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}} ' 154 | 'LR {lr:.5f}'.format( 155 | epoch=epoch_number, batch=i + 1, epoch_size=len(train_loader), 156 | batch_time=batch_time_meter, data_time=data_time_meter, 157 | loss=loss_meter, top1=top1_meter, top5=top5_meter, 158 | lr=_get_learning_rate(optimizer))) 159 | # Log the overall train stats 160 | logging.info( 161 | 'Epoch: [{epoch}] -- TRAINING SUMMARY\t' 162 | 'Time {batch_time.sum:.2f} ' 163 | 'Data {data_time.sum:.2f} ' 164 | 'Loss {loss.average:.3f} ' 165 | 'Top-1 {top1.average:.2f} ' 166 | 'Top-5 {top5.average:.2f} '.format( 167 | epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter, 168 | loss=loss_meter, top1=top1_meter, top5=top5_meter)) 169 | 170 | 171 | def save_checkpoint(checkpoints_dir, model, optimizer, epoch): 172 | model_state_file = os.path.join(checkpoints_dir, 'model_state_{:02}.pytar'.format(epoch)) 173 | optim_state_file = os.path.join(checkpoints_dir, 'optim_state_{:02}.pytar'.format(epoch)) 174 | torch.save(model.state_dict(), model_state_file) 175 | torch.save(optimizer.state_dict(), optim_state_file) 176 | 177 | 178 | def create_optimizer(model, momentum=0.9, weight_decay=0): 179 | # Get model parameters that require a gradient. 
180 | model_trainable_parameters = filter(lambda x: x.requires_grad, model.parameters()) 181 | optimizer = torch.optim.SGD(model_trainable_parameters, lr=0, 182 | momentum=momentum, weight_decay=weight_decay) 183 | return optimizer 184 | 185 | 186 | def main(argv): 187 | """Run the training script with command line arguments @argv.""" 188 | args = parse_args(argv) 189 | utils.general_setup(args.save, args.gpus) 190 | 191 | logging.info("Arguments parsed.\n{}".format(pprint.pformat(vars(args)))) 192 | 193 | # Create the train and the validation data loaders. 194 | train_loader = imagenet.get_train_loader(args.imagenet, args.batch_size, 195 | args.num_workers) 196 | val_loader = imagenet.get_val_loader(args.imagenet, args.batch_size, 197 | args.num_workers) 198 | # Create model with optional label refinery. 199 | model, loss = model_factory.create_model( 200 | args.model, args.model_state_file, args.gpus, args.label_refinery_model, 201 | args.label_refinery_state_file) 202 | logging.info("Model:\n{}".format(model)) 203 | 204 | if args.lr_regime is None: 205 | lr_regime = model.LR_REGIME 206 | else: 207 | lr_regime = args.lr_regime 208 | regime = LearningRateRegime(lr_regime) 209 | # Train and test for needed number of epochs. 
210 | optimizer = create_optimizer(model, args.momentum, args.weight_decay) 211 | for epoch in range(1, regime.num_epochs + 1): 212 | lr = regime.get_lr(epoch) 213 | _set_learning_rate(optimizer, lr) 214 | train_for_one_epoch(model, loss, train_loader, optimizer, epoch) 215 | test.test_for_one_epoch(model, loss, val_loader, epoch) 216 | save_checkpoint(args.save, model, optimizer, epoch) 217 | 218 | 219 | if __name__ == '__main__': 220 | main(sys.argv[1:]) 221 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import os 4 | import sys 5 | 6 | import torch 7 | 8 | 9 | def general_setup(checkpoints_dir=None, gpus=()): 10 | if checkpoints_dir is not None: 11 | os.makedirs(checkpoints_dir, exist_ok=True) 12 | if len(gpus) > 0: 13 | torch.cuda.set_device(gpus[0]) 14 | # Set up Python's logging module. 15 | log_formatter = logging.Formatter( 16 | '%(levelname)s %(asctime)-20s:\t %(message)s') 17 | root_logger = logging.getLogger() 18 | root_logger.setLevel(logging.INFO) 19 | # Add a console handler to write to stdout. 20 | console_handler = logging.StreamHandler(sys.stdout) 21 | console_handler.setFormatter(log_formatter) 22 | root_logger.addHandler(console_handler) 23 | if checkpoints_dir is not None: # Also write logs to <checkpoints_dir>/log.txt. 24 | log_filepath = os.path.join(checkpoints_dir, 'log.txt') 25 | file_handler = logging.FileHandler(log_filepath) 26 | file_handler.setFormatter(log_formatter) 27 | root_logger.addHandler(file_handler) 28 | 29 | 30 | def is_model_cuda(model): 31 | # Check if the first parameter is on cuda.
32 | return next(model.parameters()).is_cuda 33 | 34 | 35 | def topk_accuracy(outputs, labels, recalls=(1, 5)): 36 | """Return the top-k accuracies (in percent) for each k in @recalls.""" 37 | 38 | _, num_classes = outputs.size() 39 | maxk = min(max(recalls), num_classes) 40 | 41 | _, pred = outputs.topk(maxk, dim=1, largest=True, sorted=True) 42 | correct = (pred == labels[:,None].expand_as(pred)).float() 43 | 44 | topk_accuracy = [] 45 | for recall in recalls: 46 | topk_accuracy.append(100 * correct[:, :recall].sum(1).mean().item()) 47 | return topk_accuracy 48 | 49 | 50 | class AverageMeter: 51 | """Helper class to track the running average of a sequence (and, optionally, 52 | the average of its @recent most recent updates).""" 53 | 54 | def __init__(self, recent=None): 55 | self._recent = recent 56 | if recent is not None: 57 | self._q = collections.deque() 58 | self.reset() 59 | 60 | def reset(self): 61 | self.value = 0 62 | self.sum = 0 63 | self.count = 0 64 | if self._recent is not None: 65 | self.sum_recent = 0 66 | self.count_recent = 0 67 | self._q.clear() 68 | 69 | def update(self, value, n=1): 70 | self.value = value 71 | self.sum += value * n 72 | self.count += n 73 | 74 | if self._recent is not None: 75 | self.sum_recent += value * n 76 | self.count_recent += n 77 | self._q.append((n, value)) 78 | while len(self._q) > self._recent: 79 | (n, value) = self._q.popleft() 80 | self.sum_recent -= value * n 81 | self.count_recent -= n 82 | 83 | @property 84 | def average(self): 85 | if self.count > 0: 86 | return self.sum / self.count 87 | else: 88 | return 0 89 | 90 | @property 91 | def average_recent(self): 92 | if self.count_recent > 0: 93 | return self.sum_recent / self.count_recent 94 | else: 95 | return 0 96 | --------------------------------------------------------------------------------
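`LearningRateRegime` in train.py takes a flat list of numbers and groups it into inclusive `(start, end, lr)` triples. The standalone sketch below mirrors that grouping, validation, and lookup so the format can be tried without PyTorch. The free functions `parse_regime` and `get_lr` are hypothetical names used only for illustration, and the regime values are made up; they are not the repository's `model.LR_REGIME`.

```python
"""Standalone sketch of the flat lr-regime format used by train.py.

parse_regime/get_lr are hypothetical helpers mirroring LearningRateRegime;
the regime values in the demo are illustrative, not taken from the repo.
"""


def parse_regime(regime):
    """Group a flat [start, end, lr, start, end, lr, ...] list into triples."""
    if len(regime) % 3 != 0:
        raise ValueError("Regime length should be divisible by 3.")
    intervals = list(zip(regime[0::3], regime[1::3], regime[2::3]))
    # Same rules as LearningRateRegime._validate_intervals.
    if not intervals or intervals[0][0] != 1:
        raise ValueError("Intervals must be non-empty and start from 1.")
    if any(end < start for start, end, _ in intervals):
        raise ValueError("Each interval must end at or after its start.")
    if any(prev[1] + 1 != nxt[0] for prev, nxt in zip(intervals, intervals[1:])):
        raise ValueError("Each interval must start right after the previous one ends.")
    return intervals


def get_lr(intervals, epoch):
    """Return the lr of the interval containing @epoch (bounds are inclusive)."""
    for start, end, lr in intervals:
        if start <= epoch <= end:
            return lr
    raise ValueError("Invalid epoch {} for regime {!r}".format(epoch, intervals))


if __name__ == "__main__":
    # Hypothetical regime: 0.1 for epochs 1-30, 0.01 for 31-60, 0.001 for 61-90.
    intervals = parse_regime([1, 30, 0.1, 31, 60, 0.01, 61, 90, 0.001])
    num_epochs = intervals[-1][1]  # Like LearningRateRegime.num_epochs.
    print(num_epochs)
    print([get_lr(intervals, e) for e in (1, 30, 31, 90)])
```

As in train.py, the number of training epochs falls out of the end of the last interval, so a regime with a gap (for example, one interval ending at 10 and the next starting at 12) is rejected up front rather than failing partway through training.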