├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dataset ├── cifar_dataset.py └── imagenet_dataset.py ├── main_cifar.py ├── main_imagenet.py ├── media ├── data_parametres_neurips19_poster.pdf ├── histogram_class_temperature_over_iterations.png └── method_overview.png ├── models └── wide_resnet.py ├── optimizer └── sparse_sgd.py └── utils.py /ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Software have utilized the following copyrighted material, the use of which is hereby acknowledged. 3 | 4 | _____________________ 5 | Bumsoo Kim (https://github.com/meliketoy/wide-resnet.pytorch) 6 | We have used WideResNet implementation from this repository. 7 | 8 | MIT License 9 | Copyright (c) 2018 Bumsoo Kim 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in all 19 | copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 27 | SOFTWARE. 
28 | 29 | _____________________ 30 | Konstantin Lopuhin (https://github.com/TeamHG-Memex/tensorboard_logger) 31 | We have used tensorboard_logger to log training information. 32 | 33 | MIT License 34 | Copyright (c) 2016, Konstantin Lopuhin 35 | 36 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), 37 | to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, 38 | sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 39 | 40 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 41 | 42 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 44 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 45 | DEALINGS IN THE SOFTWARE. 46 | 47 | _____________________ 48 | PyTorch (https://pytorch.org) 49 | We use PyTorch as the training framework for training our models with data parameters. 50 | We have derived SparseSGD optimizer from SGD optimizer so as to support training with sparse gradient updates. 
51 | 52 | From PyTorch: 53 | 54 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 55 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 56 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 57 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 58 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 59 | Copyright (c) 2011-2013 NYU (Clement Farabet) 60 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 61 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 62 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 63 | 64 | From Caffe2: 65 | 66 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 67 | 68 | All contributions by Facebook: 69 | Copyright (c) 2016 Facebook Inc. 70 | 71 | All contributions by Google: 72 | Copyright (c) 2015 Google Inc. 73 | All rights reserved. 74 | 75 | All contributions by Yangqing Jia: 76 | Copyright (c) 2015 Yangqing Jia 77 | All rights reserved. 78 | 79 | All contributions from Caffe: 80 | Copyright(c) 2013, 2014, 2015, the respective contributors 81 | All rights reserved. 82 | 83 | All other contributions: 84 | Copyright(c) 2015, 2016 the respective contributors 85 | All rights reserved. 86 | 87 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 88 | copyright over their contributions to Caffe2. The project versioning records 89 | all such contribution and copyright details. If a contributor wants to further 90 | mark their specific copyright on a particular contribution, they should 91 | indicate their copyright solely in the commit message of the change when it is 92 | committed. 93 | 94 | All rights reserved. 95 | 96 | Redistribution and use in source and binary forms, with or without 97 | modification, are permitted provided that the following conditions are met: 98 | 99 | 1. 
Redistributions of source code must retain the above copyright 100 | notice, this list of conditions and the following disclaimer. 101 | 102 | 2. Redistributions in binary form must reproduce the above copyright 103 | notice, this list of conditions and the following disclaimer in the 104 | documentation and/or other materials provided with the distribution. 105 | 106 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 107 | and IDIAP Research Institute nor the names of its contributors may be 108 | used to endorse or promote products derived from this software without 109 | specific prior written permission. 110 | 111 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 112 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 113 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 114 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 115 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 116 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 117 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 118 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 119 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 120 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 121 | POSSIBILITY OF SUCH DAMAGE. 122 | 123 | _____________________ 124 | Soumith Chintala (TorchVision: https://github.com/pytorch/vision/) 125 | TorchVision is used for handling data loading and IO utilities, which is distributed under BSD 3-Clause License. 126 | We have derived ImageFolderWithIdx dataset class from datasets.ImageFolder class. 127 | 128 | Copyright (c) Soumith Chintala 2016, 129 | All rights reserved. 
130 | 131 | Redistribution and use in source and binary forms, with or without 132 | modification, are permitted provided that the following conditions are met: 133 | 134 | * Redistributions of source code must retain the above copyright notice, this 135 | list of conditions and the following disclaimer. 136 | 137 | * Redistributions in binary form must reproduce the above copyright notice, 138 | this list of conditions and the following disclaimer in the documentation 139 | and/or other materials provided with the distribution. 140 | 141 | * Neither the name of the copyright holder nor the names of its 142 | contributors may be used to endorse or promote products derived from 143 | this software without specific prior written permission. 144 | 145 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 146 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 147 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 148 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 149 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 150 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 151 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 152 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 153 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 154 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
155 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | ## Before you get started 6 | 7 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2019 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software.
9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 
31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED IN THIS REPOSITORY: 43 | 44 | This software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 46 | ------------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data Parameters: A New Family of Parameters for Learning a Differentiable Curriculum 2 | This repository accompanies the research paper, 3 | [Data Parameters: A New Family of Parameters for Learning a Differentiable Curriculum]( 4 | https://papers.nips.cc/paper/9289-data-parameters-a-new-family-of-parameters-for-learning-a-differentiable-curriculum) 5 | (accepted at NeurIPS 2019). The online copy of the poster is available 6 | [here](./media/data_parametres_neurips19_poster.pdf). 
7 | 8 | ## Citation 9 | If you find this code useful in your research, please cite: 10 | ``` 11 | @inproceedings{saxena2019data, 12 | title={Data Parameters: A New Family of Parameters for Learning a Differentiable Curriculum}, 13 | author={Saxena, Shreyas and Tuzel, Oncel and DeCoste, Dennis}, 14 | booktitle={NeurIPS}, 15 | year={2019} 16 | } 17 | ``` 18 | ## Data Parameters 19 | In the paper cited above, we introduce a new family of parameters termed "data parameters". 20 | Specifically, we equip each class and each training data point with a learnable parameter (a data parameter), which governs 21 | its importance during different stages of training. The data parameters are learnt with gradient descent alongside the 22 | model parameters, thereby yielding a curriculum that evolves over the course of training. 23 | Importantly, data parameters are not used at inference time, and hence do not alter the model's 24 | complexity or run-time at inference. 25 | 26 | ![Training Overview](media/method_overview.png) 27 | 28 | 29 | ## Setup and Requirements 30 | This code was developed and tested on an NVIDIA V100 GPU in the following environment. 31 | 32 | - Ubuntu 18.04 33 | - Python 3.6.9 34 | - PyTorch 1.2.0 35 | - Torchvision 0.4 36 | 37 | ## Getting Started 38 | In addition to the requirements above, you need a local copy of the ImageNet dataset. 39 | Its path is passed to main_imagenet.py with the --data argument. 40 | 41 | 42 | ### Training a model with data parameters 43 | An existing DNN training pipeline can be adapted to use data parameters with very little modification. 44 | - The first modification is a change to the data loader. Standard data loaders return (x_i, y_i) as 45 | a tuple; we additionally need to return the index i of the sample. This is a one-line change in the \__getitem\__ function (see 46 | [cifar_dataset.py](dataset/cifar_dataset.py) or [imagenet_dataset.py](dataset/imagenet_dataset.py)).
Note that this 47 | change is required only for an instance-level curriculum; a class-level curriculum can be implemented without this 48 | modification. 49 | 50 | - The second (and crucial) change is to the optimizer. Standard optimizers in deep learning frameworks like 51 | PyTorch are built for model parameters. They assume that at each iteration, all parameters are involved in the 52 | computational graph and receive a gradient. Therefore, at each iteration, all parameters undergo a weight decay 53 | penalty, along with an update to their corresponding momentum buffers. These assumptions and updates are valid for model 54 | parameters, but not for data parameters: at each iteration, only a subset of data parameters is part of the 55 | computational graph (those corresponding to the classes and instances in the minibatch). Using standard optimizers 56 | from PyTorch for data parameters would apply a weight decay penalty to all data parameters at each iteration, and 57 | would therefore nullify the learnt curriculum. To circumvent this issue, we apply weight decay explicitly to the 58 | subset of data parameters participating in the minibatch. We have also implemented a SparseSGD optimizer, which performs 59 | a sparse update of the momentum buffer, updating it only for data parameters present in the computational graph. 60 | More information can be found in [sparse_sgd.py](optimizer/sparse_sgd.py). 61 | 62 | - Apart from these changes, the only other changes required are instantiating the data parameters and rescaling the logits with the data 63 | parameters in the forward pass. Since data parameters interact with the model only at the last layer, in practice there is 64 | negligible overhead in training time. 65 | 66 | - Three properties of data parameters can be tuned: their initialization, learning rate, and weight decay. 67 | In practice, we initialize data parameters to 1.0 (so that training starts with the standard softmax loss).
68 | This leaves us with two hyper-parameters whose values can be set by grid search. In our experiments, we found data 69 | parameters to be robust to variations in these hyper-parameters. 70 | 71 | Below we provide example commands along with the hyper-parameters to reproduce the results on ImageNet and 72 | noisy CIFAR100 from the paper. 73 | 74 | ### ImageNet 75 | #### Baseline 76 | ``` 77 | python main_imagenet.py \ 78 | --arch 'resnet18' \ 79 | --gpu 0 \ 80 | --data 'path/to/imagenet' 81 | ``` 82 | This command trains the ResNet18 architecture on the ImageNet dataset without data parameters, 83 | which gives the baseline performance to compare against. 84 | Running it, you should obtain 70.2% validation accuracy at epoch 100. 85 | 86 | 87 | #### Train with class level parameters 88 | To train ResNet18 with class-level parameters, you can use this command: 89 | ``` 90 | python main_imagenet.py \ 91 | --arch 'resnet18' \ 92 | --data 'path/to/imagenet' \ 93 | --init_class_param 1.0 \ 94 | --lr_class_param 0.1 \ 95 | --wd_class_param 1e-4 \ 96 | --learn_class_parameters 97 | ``` 98 | Note that the learning rate, weight decay and initial value of the class parameters are specified 99 | using --lr_class_param, --wd_class_param and --init_class_param respectively. 100 | Running this command with the hyper-parameters specified above, you should obtain 70.5% validation accuracy 101 | at epoch 100. You can run it with different values of lr_class_param and wd_class_param 102 | to build more intuition about data parameters. 103 | 104 | To facilitate introspection, the training script logs the histogram, mean, maximum and minimum values of the 105 | data parameters to TensorBoard for visualization. For example, the figure below shows the histogram 106 | of class-level parameters (x-axis) over the course of training (y-axis).
The parameters of each class vary at the start 107 | of training, but towards the end of training they all converge to a similar value (indicating that all classes were given 108 | close to equal importance at convergence). 109 | ![Histogram of class level parameters](./media/histogram_class_temperature_over_iterations.png) 110 | 111 | #### Joint training with class and instance level parameters 112 | As mentioned in the paper, class and instance level parameters can also be trained jointly. 113 | To train ResNet18 with both, you can use this command: 114 | ``` 115 | python main_imagenet.py \ 116 | --arch 'resnet18' \ 117 | --data 'path/to/imagenet' \ 118 | --init_class_param 1.0 \ 119 | --lr_class_param 0.1 \ 120 | --wd_class_param 1e-4 \ 121 | --init_inst_param 0.001 \ 122 | --lr_inst_param 0.8 \ 123 | --wd_inst_param 1e-8 \ 124 | --learn_class_parameters \ 125 | --learn_inst_parameters 126 | ``` 127 | For the joint training setup, class and instance level parameters are initialized to 1.0 and 0.001 respectively, so that 128 | their initial sum is close to 1.0. We did not experiment with other initialization schemes, but data parameters can 129 | be initialized with any positive value. Running this command with the hyper-parameters specified 130 | above, you should obtain 70.8% validation accuracy at epoch 100. 131 | 132 | 133 | ### CIFAR100 Noisy Data 134 | [cifar_dataset.py](dataset/cifar_dataset.py) extends the standard CIFAR100 dataset from torchvision to allow corruption 135 | of a subset of the data with a uniform label swap. 136 | 137 | #### Baseline 138 | ``` 139 | python main_cifar.py \ 140 | --rand_fraction 0.4 141 | ``` 142 | This command trains the WideResNet28_10 architecture on the CIFAR100 dataset (corruption rate = 40%) without data parameters. 143 | This experiment gives the baseline performance without data parameters.
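The following sketch illustrates the corruption described above. It reimplements the idea behind `corrupt_fraction_of_data` as a standalone NumPy function, with the images in the leading `rand_fraction` of the training set shuffled among themselves and the labels left in place. The helper name `corrupt_fraction` is hypothetical and is not part of the repository. The toy integer "images", each storing its own true class, are also an assumption, chosen so that the resulting label noise can be measured directly.

```python
import numpy as np


def corrupt_fraction(data, rand_fraction, rng):
    """Hypothetical helper: shuffle the leading `rand_fraction` of `data` among
    themselves. Labels are not touched, so the shuffled samples end up paired
    with (mostly) wrong labels, while the remaining samples stay clean."""
    assert 0.0 <= rand_fraction <= 1.0
    n_corrupt = int(np.floor(len(data) * rand_fraction))
    perm = rng.permutation(n_corrupt)
    corrupted = data.copy()
    corrupted[:n_corrupt] = data[:n_corrupt][perm]
    return corrupted


# Toy setup (assumption): 10,000 samples with random labels over 100 classes,
# where each "image" is simply the integer id of its true class.
rng = np.random.default_rng(0)
labels = rng.integers(0, 100, size=10_000)
noisy = corrupt_fraction(labels.copy(), rand_fraction=0.4, rng=rng)

# Fraction of samples whose "image" no longer matches its label.
mismatch = np.mean(noisy != labels)
print(f"label noise: {mismatch:.3f}")
```

Under these assumptions, the measured noise comes out slightly below `rand_fraction`. A shuffled image occasionally lands on a sample of its own class (roughly 1 time in 100 here) or stays in place.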
144 | Running the baseline command above, you should obtain 50.0% accuracy at convergence (see Table 2 in the paper). 145 | 146 | #### Train with instance level parameters 147 | Since the noise in this dataset is at the instance level, we can train the DNN model with instance level parameters 148 | to learn an instance-specific curriculum. The curriculum should learn to ignore the corrupt samples in the 149 | dataset. 150 | ``` 151 | python main_cifar.py \ 152 | --rand_fraction 0.4 \ 153 | --init_inst_param 1.0 \ 154 | --lr_inst_param 0.2 \ 155 | --wd_inst_param 0.0 \ 156 | --learn_inst_parameters 157 | ``` 158 | Running this command with instance level parameters, you should obtain 71% accuracy at epoch 84. 159 | For results on noisy datasets, we always perform early stopping at epoch 84 (set by cross-validation). 160 | Running with the same instance-parameter hyper-parameters at 20% and 80% corruption rates, you should obtain 75% 161 | and 35% accuracy respectively. 162 | 163 | 164 | 165 | ## License 166 | This code is released under the [LICENSE](LICENSE) terms. 167 | -------------------------------------------------------------------------------- /dataset/cifar_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torchvision.datasets import CIFAR100 7 | import numpy as np 8 | 9 | 10 | class CIFAR100WithIdx(CIFAR100): 11 | """ 12 | Extends the CIFAR100 dataset to also yield the index of each element in addition to the image and target label.
13 | """ 14 | 15 | def __init__(self, 16 | root, 17 | train=True, 18 | transform=None, 19 | target_transform=None, 20 | download=False, 21 | rand_fraction=0.0): 22 | super(CIFAR100WithIdx, self).__init__(root=root, 23 | train=train, 24 | transform=transform, 25 | target_transform=target_transform, 26 | download=download) 27 | 28 | assert (rand_fraction <= 1.0) and (rand_fraction >= 0.0) 29 | self.rand_fraction = rand_fraction 30 | 31 | if self.rand_fraction > 0.0: 32 | self.data = self.corrupt_fraction_of_data() 33 | 34 | def corrupt_fraction_of_data(self): 35 | """Corrupts fraction of train data by permuting image-label pairs.""" 36 | 37 | # Check if we are not corrupting test data 38 | assert self.train is True, 'We should not corrupt test data.' 39 | 40 | nr_points = len(self.data) 41 | nr_corrupt_instances = int(np.floor(nr_points * self.rand_fraction)) 42 | print('Randomizing {} fraction of data == {} / {}'.format(self.rand_fraction, 43 | nr_corrupt_instances, 44 | nr_points)) 45 | # We will corrupt the top fraction data points 46 | corrupt_data = self.data[:nr_corrupt_instances, :, :, :] 47 | clean_data = self.data[nr_corrupt_instances:, :, :, :] 48 | 49 | # Corrupting data 50 | rand_idx = np.random.permutation(np.arange(len(corrupt_data))) 51 | corrupt_data = corrupt_data[rand_idx, :, :, :] 52 | 53 | # Adding corrupt and clean data back together 54 | return np.vstack((corrupt_data, clean_data)) 55 | 56 | def __getitem__(self, index): 57 | """ 58 | Args: 59 | index (int): index of element to be fetched 60 | 61 | Returns: 62 | tuple: (sample, target, index) where index is the index of this sample in dataset. 63 | """ 64 | img, target = super().__getitem__(index) 65 | return img, target, index 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /dataset/imagenet_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 
3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved. 4 | # 5 | import torchvision.datasets as datasets 6 | from PIL import Image 7 | 8 | 9 | def pil_loader(path): 10 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 11 | with open(path, 'rb') as f: 12 | img = Image.open(f) 13 | return img.convert('RGB') 14 | 15 | 16 | def accimage_loader(path): 17 | import accimage 18 | try: 19 | return accimage.Image(path) 20 | except IOError: 21 | # fall back to PIL Image 22 | return pil_loader(path) 23 | 24 | 25 | def default_loader(path): 26 | from torchvision import get_image_backend 27 | if get_image_backend() == 'accimage': 28 | return accimage_loader(path) 29 | else: 30 | return pil_loader(path) 31 | 32 | 33 | class ImageFolderWithIdx(datasets.ImageFolder): 34 | """ 35 | Extends the ImageFolder dataset to also yield the index of each element in addition to the image and target label. 36 | 37 | Args: 38 | root (string): Root directory path. 39 | transform (callable, optional): A function/transform that takes in a PIL image 40 | and returns a transformed version. E.g., ``transforms.RandomCrop`` 41 | target_transform (callable, optional): A function/transform that takes in the 42 | target and transforms it. 43 | loader (callable, optional): A function to load an image given its path. 44 | 45 | Attributes: 46 | classes (list): List of the class names. 47 | class_to_idx (dict): Dict with items (class_name, class_index). 48 | imgs (list): List of (image path, class_index) tuples 49 | """ 50 | def __init__(self, 51 | root, 52 | transform=None, 53 | target_transform=None, 54 | loader=default_loader): 55 | super(ImageFolderWithIdx, self).__init__(root=root, 56 | transform=transform, 57 | target_transform=target_transform, 58 | loader=loader) 59 | 60 | def __getitem__(self, index): 61 | """ 62 | Args: 63 | index (int): Index 64 | 65 | Returns: 66 | tuple: (sample, target, index) where index is the index of this sample in dataset.
67 | """ 68 | sample, target = super().__getitem__(index) 69 | return sample, target, index 70 | -------------------------------------------------------------------------------- /main_cifar.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved. 4 | # 5 | import time 6 | import argparse 7 | 8 | import numpy as np 9 | import torch 10 | import torch.optim 11 | import torch.nn as nn 12 | import torch.utils.data 13 | import torch.nn.parallel 14 | import torch.utils.data.distributed 15 | import torchvision.transforms as transforms 16 | from tensorboard_logger import log_value 17 | 18 | import utils 19 | from dataset.cifar_dataset import CIFAR100WithIdx 20 | from models.wide_resnet import WideResNet28_10 21 | 22 | 23 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 24 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 25 | help='number of data loading workers (default: 4)') 26 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 27 | help='number of total epochs to run') 28 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 29 | help='manual epoch number (useful on restarts)') 30 | parser.add_argument('--restart', default=False, const=True, action='store_const', 31 | help='Erase log and saved checkpoints and restart training') 32 | parser.add_argument('-b', '--batch-size', default=128, type=int, 33 | metavar='N', help='mini-batch size (default: 128)') 34 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 35 | metavar='LR', help='initial learning rate for model parameters', dest='lr') 36 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 37 | parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, 38 | metavar='W', help='weight decay (default: 5e-4)', 39 | 
39 | dest='weight_decay') 40 | parser.add_argument('--rand_fraction', default=0.0, type=float, help='Fraction of training samples whose labels are corrupted') 41 | parser.add_argument('-p', '--print-freq', default=10, type=int, 42 | metavar='N', help='print frequency (in iterations)') 43 | parser.add_argument('--seed', default=1, type=int, 44 | help='seed for initializing training') 45 | parser.add_argument('--learn_class_parameters', default=False, const=True, action='store_const', 46 | help='Learn temperature per class') 47 | parser.add_argument('--learn_inst_parameters', default=False, const=True, action='store_const', 48 | help='Learn temperature per instance') 49 | parser.add_argument('--skip_clamp_data_param', default=False, const=True, action='store_const', 50 | help='Do not clamp data parameters during optimization') 51 | parser.add_argument('--lr_class_param', default=0.1, type=float, help='Learning rate for class parameters') 52 | parser.add_argument('--lr_inst_param', default=0.1, type=float, help='Learning rate for instance parameters') 53 | parser.add_argument('--wd_class_param', default=0.0, type=float, help='Weight decay for class parameters') 54 | parser.add_argument('--wd_inst_param', default=0.0, type=float, help='Weight decay for instance parameters') 55 | parser.add_argument('--init_class_param', default=1.0, type=float, help='Initial value for class parameters') 56 | parser.add_argument('--init_inst_param', default=1.0, type=float, help='Initial value for instance parameters') 57 | 58 | 59 | def adjust_learning_rate(model_initial_lr, optimizer, gamma, step): 60 | """Sets the learning rate of model parameters to model_initial_lr * gamma ** step (staircase decay). 61 | 62 | Args: 63 | model_initial_lr (float) : initial learning rate for model parameters 64 | optimizer (class derived under torch.optim): torch optimizer.
65 | gamma (float): multiplicative factor applied to the learning rate of model parameters at each decay step 66 | step (int) : number of decay steps taken so far in the staircase learning rate schedule 67 | """ 68 | lr = model_initial_lr * (gamma ** step) 69 | for param_group in optimizer.param_groups: 70 | param_group['lr'] = lr 71 | 72 | 73 | def get_train_and_val_loader(args): 74 | """Constructs data loaders for the CIFAR100 train and validation sets. 75 | 76 | Args: 77 | args (argparse.Namespace): parsed command-line arguments. 78 | 79 | Returns: 80 | train_loader (torch.utils.data.DataLoader): data loader for CIFAR100 train data. 81 | val_loader (torch.utils.data.DataLoader): data loader for CIFAR100 val data. 82 | """ 83 | print('==> Preparing data for CIFAR100..') 84 | 85 | transform_train = transforms.Compose([ 86 | transforms.RandomCrop(32, padding=4), 87 | transforms.RandomHorizontalFlip(), 88 | transforms.ToTensor(), 89 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 90 | ]) 91 | transform_val = transforms.Compose([ 92 | transforms.ToTensor(), 93 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 94 | ]) 95 | 96 | trainset = CIFAR100WithIdx(root='/tmp/data', 97 | train=True, 98 | download=True, 99 | transform=transform_train, 100 | rand_fraction=args.rand_fraction) 101 | valset = CIFAR100WithIdx(root='/tmp/data', 102 | train=False, 103 | download=True, 104 | transform=transform_val) 105 | train_loader = torch.utils.data.DataLoader(trainset, 106 | batch_size=args.batch_size, 107 | shuffle=True, 108 | num_workers=args.workers) 109 | val_loader = torch.utils.data.DataLoader(valset, 110 | batch_size=100, 111 | shuffle=False, 112 | num_workers=args.workers) 113 | 114 | return train_loader, val_loader 115 | 116 | 117 | def get_model_and_loss_criterion(args): 118 | """Initializes DNN model and loss function. 119 | 120 | Args: 121 | args (argparse.Namespace): 122 | 123 | Returns: 124 | model (torch.nn.Module): DNN model.
125 | criterion (torch.nn.modules.loss): cross entropy loss 126 | """ 127 | print('Building WideResNet28_10') 128 | args.arch = 'WideResNet28_10' 129 | model = WideResNet28_10(num_classes=args.nr_classes) 130 | if args.device == 'cuda': 131 | model = model.cuda() 132 | criterion = nn.CrossEntropyLoss().cuda() 133 | else: 134 | criterion = nn.CrossEntropyLoss() 135 | 136 | return model, criterion 137 | 138 | 139 | def validate(args, val_loader, model, criterion, epoch): 140 | """Evaluates model on validation set and logs score on tensorboard. 141 | 142 | Args: 143 | args (argparse.Namespace): 144 | val_loader (torch.utils.data.dataloader): dataloader for validation set. 145 | model (torch.nn.Module): DNN model. 146 | criterion (torch.nn.modules.loss): cross entropy loss 147 | epoch (int): current epoch 148 | """ 149 | losses = utils.AverageMeter('Loss', ':.4e') 150 | top1 = utils.AverageMeter('Acc@1', ':6.2f') 151 | # switch to evaluate mode 152 | model.eval() 153 | 154 | with torch.no_grad(): 155 | for i, (inputs, target, _) in enumerate(val_loader): 156 | if args.device == 'cuda': 157 | inputs = inputs.cuda() 158 | target = target.cuda() 159 | 160 | # compute output 161 | logits = model(inputs) 162 | loss = criterion(logits, target) 163 | 164 | # measure accuracy and record loss 165 | acc1 = utils.compute_topk_accuracy(logits, target, topk=(1, )) 166 | losses.update(loss.item(), inputs.size(0)) 167 | top1.update(acc1[0].item(), inputs.size(0)) 168 | 169 | print('Test-Epoch-{}: Acc:{}, Loss:{}'.format(epoch, top1.avg, losses.avg)) 170 | 171 | # Logging results on tensorboard 172 | log_value('val/accuracy', top1.avg, step=epoch) 173 | log_value('val/loss', losses.avg, step=epoch) 174 | 175 | 176 | def train_for_one_epoch(args, 177 | train_loader, 178 | model, 179 | criterion, 180 | optimizer, 181 | epoch, 182 | global_iter, 183 | optimizer_data_parameters, 184 | data_parameters, 185 | config): 186 | """Train model for one epoch on the train set. 
187 | 188 | Args: 189 | args (argparse.Namespace): 190 | train_loader (torch.utils.data.dataloader): dataloader for train set. 191 | model (torch.nn.Module): DNN model. 192 | criterion (torch.nn.modules.loss): cross entropy loss. 193 | optimizer (torch.optim.SGD): optimizer for model parameters. 194 | epoch (int): current epoch. 195 | global_iter (int): current iteration count. 196 | optimizer_data_parameters (tuple of SparseSGD): SparseSGD optimizers for class and instance data parameters. 197 | data_parameters (tuple of torch.Tensor): class and instance level data parameters. 198 | config (dict): config file for the experiment. 199 | 200 | Returns: 201 | global_iter (int): updated iteration count after one epoch. 202 | """ 203 | 204 | # Initialize counters 205 | losses = utils.AverageMeter('Loss', ':.4e') 206 | top1 = utils.AverageMeter('Acc@1', ':6.2f') 207 | 208 | # Unpack data parameters 209 | optimizer_class_param, optimizer_inst_param = optimizer_data_parameters 210 | class_parameters, inst_parameters = data_parameters 211 | 212 | # Switch to train mode 213 | model.train() 214 | start_epoch_time = time.time() 215 | for i, (inputs, target, index_dataset) in enumerate(train_loader): 216 | global_iter = global_iter + 1 217 | inputs, target = inputs.to(args.device), target.to(args.device) 218 | 219 | # Flush the gradient buffer for model and data-parameters 220 | optimizer.zero_grad() 221 | if args.learn_class_parameters: 222 | optimizer_class_param.zero_grad() 223 | if args.learn_inst_parameters: 224 | optimizer_inst_param.zero_grad() 225 | 226 | # Compute logits 227 | logits = model(inputs) 228 | 229 | if args.learn_class_parameters or args.learn_inst_parameters: 230 | # Compute data parameters for instances in the minibatch 231 | class_parameter_minibatch = class_parameters[target] 232 | inst_parameter_minibatch = inst_parameters[index_dataset] 233 | data_parameter_minibatch = utils.get_data_param_for_minibatch( 234 | args,
235 | class_param_minibatch=class_parameter_minibatch, 236 | inst_param_minibatch=inst_parameter_minibatch) 237 | 238 | # Compute logits scaled by data parameters 239 | logits = logits / data_parameter_minibatch 240 | 241 | loss = criterion(logits, target) 242 | 243 | # Apply weight decay on data parameters 244 | if args.learn_class_parameters or args.learn_inst_parameters: 245 | loss = utils.apply_weight_decay_data_parameters(args, loss, 246 | class_parameter_minibatch=class_parameter_minibatch, 247 | inst_parameter_minibatch=inst_parameter_minibatch) 248 | 249 | # Compute gradient and do SGD step 250 | loss.backward() 251 | optimizer.step() 252 | if args.learn_class_parameters: 253 | optimizer_class_param.step() 254 | if args.learn_inst_parameters: 255 | optimizer_inst_param.step() 256 | 257 | # Clamp class and instance level parameters within the bounds given in config 258 | if args.learn_class_parameters or args.learn_inst_parameters: 259 | utils.clamp_data_parameters(args, class_parameters, config, inst_parameters) 260 | 261 | # Measure accuracy and record loss 262 | acc1 = utils.compute_topk_accuracy(logits, target, topk=(1, )) 263 | losses.update(loss.item(), inputs.size(0)) 264 | top1.update(acc1[0].item(), inputs.size(0)) 265 | 266 | # Log stats for data parameters and loss every few iterations 267 | if i % args.print_freq == 0: 268 | utils.log_intermediate_iteration_stats(args, class_parameters, epoch, 269 | global_iter, inst_parameters, 270 | losses, top1) 271 | 272 | # Print and log stats for the epoch 273 | print('Time for epoch: {}'.format(time.time() - start_epoch_time)) 274 | print('Train-Epoch-{}: Acc:{}, Loss:{}'.format(epoch, top1.avg, losses.avg)) 275 | log_value('train/accuracy', top1.avg, step=epoch) 276 | log_value('train/loss', losses.avg, step=epoch) 277 | 278 | return global_iter 279 | 280 | 281 | def main_worker(args, config): 282 | """Trains model on CIFAR100 using data parameters. 283 | 284 | Args: 285 | args (argparse.Namespace): 286 | config (dict):
config file for the experiment. 287 | """ 288 | global_iter = 0 289 | learning_rate_schedule = np.array([80, 100, 160]) 290 | 291 | # Create model 292 | model, loss_criterion = get_model_and_loss_criterion(args) 293 | 294 | # Define optimizer 295 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 296 | momentum=args.momentum, 297 | weight_decay=args.weight_decay) 298 | 299 | # Get train and validation dataset loader 300 | train_loader, val_loader = get_train_and_val_loader(args) 301 | 302 | # Initialize class- and instance-level temperatures on the same device as the model 303 | (class_parameters, inst_parameters, 304 | optimizer_class_param, optimizer_inst_param) = utils.get_class_inst_data_params_n_optimizer( 305 | args=args, 306 | nr_classes=args.nr_classes, 307 | nr_instances=len(train_loader.dataset), 308 | device=args.device 309 | ) 310 | 311 | # Training loop 312 | for epoch in range(args.start_epoch, args.epochs): 313 | 314 | # Adjust learning rate for model parameters (also at the start epoch, so restarts resume with the scheduled rate) 315 | if epoch == args.start_epoch or epoch in learning_rate_schedule: 316 | adjust_learning_rate(model_initial_lr=args.lr, 317 | optimizer=optimizer, 318 | gamma=0.1, 319 | step=np.sum(epoch >= learning_rate_schedule)) 320 | 321 | # Train for one epoch 322 | global_iter = train_for_one_epoch( 323 | args=args, 324 | train_loader=train_loader, 325 | model=model, 326 | criterion=loss_criterion, 327 | optimizer=optimizer, 328 | epoch=epoch, 329 | global_iter=global_iter, 330 | optimizer_data_parameters=(optimizer_class_param, optimizer_inst_param), 331 | data_parameters=(class_parameters, inst_parameters), 332 | config=config) 333 | 334 | # Evaluate on validation set 335 | validate(args, val_loader, model, loss_criterion, epoch) 336 | 337 | # Save artifacts 338 | utils.save_artifacts(args, epoch, model, class_parameters, inst_parameters) 339 | 340 | # Log temperature stats over epochs 341 | if args.learn_class_parameters: 342 | utils.log_stats(data=torch.exp(class_parameters), 343 | name='epochs_stats_class_parameter', 344 | step=epoch) 345 | if
args.learn_inst_parameters: 346 | utils.log_stats(data=torch.exp(inst_parameters), 347 | name='epoch_stats_inst_parameter', 348 | step=epoch) 349 | 350 | if args.rand_fraction > 0.0: 351 | # Some train labels are corrupted; log instance parameter stats separately for clean and corrupt data 352 | nr_corrupt_instances = int(np.floor(len(train_loader.dataset) * args.rand_fraction)) 353 | # Corrupted samples occupy the first nr_corrupt_instances indices of the dataset 354 | utils.log_stats(data=torch.exp(inst_parameters[:nr_corrupt_instances]), 355 | name='epoch_stats_corrupt_inst_parameter', 356 | step=epoch) 357 | utils.log_stats(data=torch.exp(inst_parameters[nr_corrupt_instances:]), 358 | name='epoch_stats_clean_inst_parameter', 359 | step=epoch) 360 | 361 | 362 | def main(): 363 | args = parser.parse_args() 364 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 365 | args.log_dir = './logs_CL_CIFAR' 366 | args.save_dir = './weights_CL_CIFAR' 367 | args.nr_classes = 100 # Number of classes in CIFAR100 368 | utils.generate_log_dir(args) 369 | utils.generate_save_dir(args) 370 | 371 | config = {} 372 | config['clamp_inst_sigma'] = {} 373 | config['clamp_inst_sigma']['min'] = np.log(1/20) 374 | config['clamp_inst_sigma']['max'] = np.log(20) 375 | config['clamp_cls_sigma'] = {} 376 | config['clamp_cls_sigma']['min'] = np.log(1/20) 377 | config['clamp_cls_sigma']['max'] = np.log(20) 378 | utils.save_config(args.save_dir, config) 379 | 380 | # Set seed for reproducibility 381 | utils.set_seed(args) 382 | 383 | # Simply call main_worker function 384 | main_worker(args, config) 385 | 386 | 387 | if __name__ == '__main__': 388 | main() 389 | -------------------------------------------------------------------------------- /main_imagenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved.
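Both training scripts divide the logits of each sample by a temperature built from the learnable class- and instance-level data parameters. The parameters live in log space (the scripts log `torch.exp(...)` of them) and are clamped to [log(1/20), log(20)] through `config`. The plain-Python sketch below illustrates that computation for a single sample. It assumes the temperature is `exp(class_param) + exp(inst_param)` when both levels are learned; that combination is an assumption about `utils.get_data_param_for_minibatch`, which is not shown here, and all function names are hypothetical.

```python
import math

# Clamping bounds in log space, mirroring config['clamp_*_sigma'] in main()
LOG_SIGMA_MIN = math.log(1 / 20)
LOG_SIGMA_MAX = math.log(20)


def clamp_log_sigma(value):
    """Clamp a log-temperature into [log(1/20), log(20)]."""
    return max(LOG_SIGMA_MIN, min(LOG_SIGMA_MAX, value))


def sample_temperature(log_class_sigma, log_inst_sigma):
    """Hypothetical combination of class and instance temperatures (assumed: exp + exp)."""
    return math.exp(clamp_log_sigma(log_class_sigma)) + math.exp(clamp_log_sigma(log_inst_sigma))


def tempered_cross_entropy(logits, target, sigma):
    """Cross-entropy of softmax(logits / sigma) for one sample with integer label `target`."""
    scaled = [z / sigma for z in logits]
    # Subtract the max before exponentiating for numerical stability
    m = max(scaled)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scaled))
    return log_sum_exp - scaled[target]
```

Dividing by a larger temperature flattens the softmax and shrinks the gradient with respect to the logits by a factor of `1 / sigma`, which is how a large learned temperature down-weights a sample.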
4 | # 5 | import os 6 | import time 7 | import argparse 8 | 9 | import numpy as np 10 | import torch 11 | import torch.optim 12 | import torch.nn as nn 13 | import torch.utils.data 14 | import torch.nn.parallel 15 | import torchvision.models as models 16 | import torch.utils.data.distributed 17 | import torchvision.transforms as transforms 18 | from tensorboard_logger import log_value 19 | 20 | import utils 21 | from dataset.imagenet_dataset import ImageFolderWithIdx 22 | model_names = sorted(name for name in models.__dict__ 23 | if name.islower() and not name.startswith("__") 24 | and callable(models.__dict__[name])) 25 | 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training With Data Parameters') 28 | parser.add_argument('--data', metavar='DIR', 29 | help='path to dataset (with training/ and validation/ subdirectories)') 30 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 31 | choices=model_names, 32 | help='model architecture: ' + 33 | ' | '.join(model_names) + 34 | ' (default: resnet18)') 35 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 36 | help='number of data loading workers (default: 32)') 37 | parser.add_argument('--job_name', default='temp', help='Job name used to create save directories for ' 38 | 'checkpoints and logs') 39 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 40 | help='total number of epochs to run') 41 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 42 | help='manual epoch number (useful on restarts)') 43 | parser.add_argument('--restart', default=False, const=True, action='store_const', 44 | help='Erase log and saved checkpoints and restart training') 45 | parser.add_argument('-b', '--batch-size', default=256, type=int, 46 | metavar='N', 47 | help='mini-batch size (default: 256), this is the total ' 48 | 'batch size of all GPUs on the current node when ' 49 | 'using Data Parallel or Distributed Data Parallel') 50 | parser.add_argument('--lr',
'--learning-rate', default=0.1, type=float, 51 | metavar='LR', help='initial learning rate for model parameters', dest='lr') 52 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 53 | help='momentum') 54 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 55 | metavar='W', help='weight decay (default: 1e-4)', 56 | dest='weight_decay') 57 | parser.add_argument('-p', '--print-freq', default=10, type=int, 58 | metavar='N', help='print frequency (in iterations)') 59 | parser.add_argument('--seed', default=1, type=int, 60 | help='seed for initializing training') 61 | parser.add_argument('--gpu', default=None, type=int, 62 | help='GPU id to use. If not set, the model is split across all GPUs with DataParallel.') 63 | parser.add_argument('--learn_class_parameters', default=False, const=True, action='store_const', 64 | help='Learn temperature per class') 65 | parser.add_argument('--learn_inst_parameters', default=False, const=True, action='store_const', 66 | help='Learn temperature per instance') 67 | parser.add_argument('--skip_clamp_data_param', default=False, const=True, action='store_const', 68 | help='Do not clamp data parameters during optimization') 69 | parser.add_argument('--lr_class_param', default=0.1, type=float, help='Learning rate for class parameters') 70 | parser.add_argument('--lr_inst_param', default=0.1, type=float, help='Learning rate for instance parameters') 71 | parser.add_argument('--wd_class_param', default=0.0, type=float, help='Weight decay for class parameters') 72 | parser.add_argument('--wd_inst_param', default=0.0, type=float, help='Weight decay for instance parameters') 73 | parser.add_argument('--init_class_param', default=1.0, type=float, help='Initial value for class parameters') 74 | parser.add_argument('--init_inst_param', default=1.0, type=float, help='Initial value for instance parameters') 75 | parser.add_argument('--lr_drop_epoch_step', default=30, type=int, help='Number of epochs after which the learning rate of ' 76 | 'model parameters is multiplied by 0.1') 77 | 78 | 79 | def
adjust_learning_rate(optimizer, epoch, args): 80 | """Sets the learning rate to the initial learning rate decayed by 10 every args.lr_drop_epoch_step epochs. 81 | 82 | Args: 83 | optimizer (class derived under torch.optim): torch optimizer. 84 | epoch (int): current epoch count. 85 | args (argparse.Namespace): 86 | """ 87 | lr = args.lr * (0.1 ** (epoch // args.lr_drop_epoch_step)) 88 | for param_group in optimizer.param_groups: 89 | param_group['lr'] = lr 90 | 91 | 92 | def get_train_and_val_loader(args): 93 | """Constructs data loaders for the ImageNet train and validation sets. 94 | 95 | Args: 96 | args (argparse.Namespace): 97 | 98 | Returns: 99 | train_loader (torch.utils.data.DataLoader): data loader for train data. 100 | val_loader (torch.utils.data.DataLoader): data loader for validation data. 101 | """ 102 | traindir = os.path.join(args.data, 'training') 103 | valdir = os.path.join(args.data, 'validation') 104 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 105 | std=[0.229, 0.224, 0.225]) 106 | 107 | # Instead of the ImageFolder dataset class, we use ImageFolderWithIdx, which is derived 108 | # from the former and also returns the dataset index of each item in the minibatch.
109 | train_dataset = ImageFolderWithIdx( 110 | traindir, 111 | transforms.Compose([ 112 | transforms.RandomResizedCrop(224), 113 | transforms.RandomHorizontalFlip(), 114 | transforms.ToTensor(), 115 | normalize, 116 | ])) 117 | train_loader = torch.utils.data.DataLoader( 118 | train_dataset, 119 | batch_size=args.batch_size, 120 | shuffle=True, 121 | num_workers=args.workers, 122 | pin_memory=True) 123 | val_loader = torch.utils.data.DataLoader( 124 | ImageFolderWithIdx(valdir, transforms.Compose([ 125 | transforms.Resize(256), 126 | transforms.CenterCrop(224), 127 | transforms.ToTensor(), 128 | normalize, 129 | ])), 130 | batch_size=args.batch_size, 131 | shuffle=False, 132 | num_workers=args.workers, 133 | pin_memory=True) 134 | return train_loader, val_loader 135 | 136 | 137 | def get_model_and_loss_criterion(args): 138 | """Initializes the DNN model and loss function on a single GPU, or on multiple GPUs with data parallelism. 139 | 140 | Args: 141 | args (argparse.Namespace): 142 | 143 | Returns: 144 | model (torch.nn.Module): DNN model.
145 | criterion (torch.nn.modules.loss): cross entropy loss 146 | """ 147 | print("=> creating model '{}'".format(args.arch)) 148 | model = models.__dict__[args.arch]() 149 | if args.gpu is not None: 150 | print("Using GPU: {} for training".format(args.gpu)) 151 | torch.cuda.set_device(args.gpu) 152 | model = model.cuda(args.gpu) 153 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 154 | else: 155 | print('Splitting model across all GPUs with data parallelism') 156 | criterion = nn.CrossEntropyLoss().cuda() 157 | # DataParallel will divide and allocate batch_size to all available GPUs 158 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 159 | model.features = torch.nn.DataParallel(model.features) 160 | model.cuda() 161 | else: 162 | model = torch.nn.DataParallel(model).cuda() 163 | return model, criterion 164 | 165 | 166 | def validate(args, val_loader, model, criterion, epoch): 167 | """Evaluates model on validation set and logs score on tensorboard. 168 | 169 | Args: 170 | args (argparse.Namespace): 171 | val_loader (torch.utils.data.dataloader): dataloader for validation set. 172 | model (torch.nn.Module): DNN model. 
173 | criterion (torch.nn.modules.loss): cross entropy loss 174 | epoch (int): current epoch 175 | """ 176 | losses = utils.AverageMeter('Loss', ':.4e') 177 | top1 = utils.AverageMeter('Acc@1', ':6.2f') 178 | top5 = utils.AverageMeter('Acc@5', ':6.2f') 179 | # switch to evaluate mode 180 | model.eval() 181 | 182 | with torch.no_grad(): 183 | for i, (inputs, target, _) in enumerate(val_loader): 184 | if args.gpu is not None: 185 | inputs = inputs.cuda(args.gpu, non_blocking=True) 186 | target = target.cuda(args.gpu, non_blocking=True) 187 | else: 188 | inputs = inputs.cuda() 189 | target = target.cuda() 190 | 191 | # compute output 192 | logits = model(inputs) 193 | loss = criterion(logits, target) 194 | 195 | # measure accuracy and record loss 196 | acc1, acc5 = utils.compute_topk_accuracy(logits, target, topk=(1, 5)) 197 | losses.update(loss.item(), inputs.size(0)) 198 | top1.update(acc1[0], inputs.size(0)) 199 | top5.update(acc5[0], inputs.size(0)) 200 | 201 | print(' * Validation Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)) 202 | 203 | # Logging results on tensorboard 204 | log_value('val/accuracy_top1', top1.avg, step=epoch) 205 | log_value('val/accuracy_top5', top5.avg, step=epoch) 206 | log_value('val/loss', losses.avg, step=epoch) 207 | 208 | 209 | def train_for_one_epoch(args, 210 | train_loader, 211 | model, 212 | criterion, 213 | optimizer, 214 | epoch, 215 | global_iter, 216 | optimizer_data_parameters, 217 | data_parameters, 218 | config): 219 | """Train model for one epoch on the train set. 220 | 221 | Args: 222 | args (argparse.Namespace): 223 | train_loader (torch.utils.data.dataloader): dataloader for train set. 224 | model (torch.nn.Module): DNN model. 225 | criterion (torch.nn.modules.loss): cross entropy loss. 226 | optimizer (torch.optim.SGD): optimizer for model parameters. 227 | epoch (int): current epoch. 228 | global_iter (int): current iteration count. 
229 | optimizer_data_parameters (tuple of SparseSGD): SparseSGD optimizers for class and instance data parameters. 230 | data_parameters (tuple of torch.Tensor): class and instance level data parameters. 231 | config (dict): config file for the experiment. 232 | 233 | Returns: 234 | global_iter (int): updated iteration count after one epoch. 235 | """ 236 | 237 | # Initialize counters 238 | losses = utils.AverageMeter('Loss', ':.4e') 239 | top1 = utils.AverageMeter('Acc@1', ':6.2f') 240 | top5 = utils.AverageMeter('Acc@5', ':6.2f') 241 | 242 | # Unpack data parameters 243 | optimizer_class_param, optimizer_inst_param = optimizer_data_parameters 244 | class_parameters, inst_parameters = data_parameters 245 | 246 | # Switch to train mode 247 | model.train() 248 | start_epoch_time = time.time() 249 | 250 | for i, (inputs, target, index_dataset) in enumerate(train_loader): 251 | global_iter = global_iter + 1 252 | 253 | if args.gpu is not None: 254 | inputs = inputs.cuda(args.gpu, non_blocking=True) 255 | target = target.cuda(args.gpu, non_blocking=True) 256 | else: 257 | inputs = inputs.cuda() 258 | target = target.cuda() 259 | 260 | # Flush the gradient buffer for model and data-parameters 261 | optimizer.zero_grad() 262 | if args.learn_class_parameters: 263 | optimizer_class_param.zero_grad() 264 | if args.learn_inst_parameters: 265 | optimizer_inst_param.zero_grad() 266 | 267 | logits = model(inputs) 268 | 269 | if args.learn_class_parameters or args.learn_inst_parameters: 270 | # Compute data parameters for instances in the minibatch 271 | class_parameter_minibatch = class_parameters[target] 272 | inst_parameter_minibatch = inst_parameters[index_dataset] 273 | data_parameter_minibatch = utils.get_data_param_for_minibatch( 274 | args, 275 | class_param_minibatch=class_parameter_minibatch, 276 | inst_param_minibatch=inst_parameter_minibatch) 277 | 278 | # Compute logits scaled by data parameters 279 | logits = logits / data_parameter_minibatch 280 | 281 | loss =
criterion(logits, target) 282 | 283 | # Apply weight decay on data parameters 284 | if args.learn_class_parameters or args.learn_inst_parameters: 285 | loss = utils.apply_weight_decay_data_parameters(args, loss, 286 | class_parameter_minibatch=class_parameter_minibatch, 287 | inst_parameter_minibatch=inst_parameter_minibatch) 288 | 289 | # Compute gradient and do SGD step 290 | loss.backward() 291 | optimizer.step() 292 | if args.learn_class_parameters: 293 | optimizer_class_param.step() 294 | if args.learn_inst_parameters: 295 | optimizer_inst_param.step() 296 | 297 | # Clamp class and instance level parameters within certain bounds 298 | if args.learn_class_parameters or args.learn_inst_parameters: 299 | utils.clamp_data_parameters(args, class_parameters, config, inst_parameters) 300 | 301 | # Measure accuracy and record loss 302 | acc1, acc5 = utils.compute_topk_accuracy(logits, target, topk=(1, 5)) 303 | losses.update(loss.item(), inputs.size(0)) 304 | top1.update(acc1[0], inputs.size(0)) 305 | top5.update(acc5[0], inputs.size(0)) 306 | 307 | # Log stats for data parameters and loss every few iterations 308 | if i % args.print_freq == 0: 309 | utils.log_intermediate_iteration_stats(args, class_parameters, epoch, 310 | global_iter, inst_parameters, 311 | losses, top1, top5) 312 | # Print and log stats for the epoch 313 | print('Train-Epoch-{}: Acc-5:{}, Acc-1:{}, Loss:{}'.format(epoch, top5.avg, top1.avg, losses.avg)) 314 | print('Time for 1 epoch: {}'.format(time.time() - start_epoch_time)) 315 | log_value('train/accuracy_top5', top5.avg, step=epoch) 316 | log_value('train/accuracy_top1', top1.avg, step=epoch) 317 | log_value('train/loss', losses.avg, step=epoch) 318 | 319 | return global_iter 320 | 321 | 322 | def main_worker(args, config): 323 | """Trains model on ImageNet using data parameters 324 | 325 | Args: 326 | args (argparse.Namespace): 327 | config (dict): config file for the experiment. 
328 | """ 329 | global_iter = 0 330 | 331 | # Create model 332 | model, loss_criterion = get_model_and_loss_criterion(args) 333 | 334 | # Define optimizer 335 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 336 | momentum=args.momentum, 337 | weight_decay=args.weight_decay) 338 | 339 | # Get train and validation dataset loader 340 | train_loader, val_loader = get_train_and_val_loader(args) 341 | 342 | # Initialize class and instance based temperature 343 | (class_parameters, inst_parameters, 344 | optimizer_class_param, optimizer_inst_param) = utils.get_class_inst_data_params_n_optimizer( 345 | args=args, 346 | nr_classes=1000, 347 | nr_instances=len(train_loader.dataset), 348 | device='cuda' 349 | ) 350 | # Training loop 351 | for epoch in range(args.start_epoch, args.epochs): 352 | adjust_learning_rate(optimizer, epoch, args) 353 | 354 | # Train for one epoch 355 | global_iter = train_for_one_epoch( 356 | args=args, 357 | train_loader=train_loader, 358 | model=model, 359 | criterion=loss_criterion, 360 | optimizer=optimizer, 361 | epoch=epoch, 362 | global_iter=global_iter, 363 | optimizer_data_parameters=(optimizer_class_param, optimizer_inst_param), 364 | data_parameters=(class_parameters, inst_parameters), 365 | config=config) 366 | 367 | # Evaluate on validation set 368 | validate(args, val_loader, model, loss_criterion, epoch) 369 | 370 | # Save artifacts 371 | utils.save_artifacts(args, epoch, model, class_parameters, inst_parameters) 372 | 373 | # Log temperature stats over epochs 374 | if args.learn_class_parameters: 375 | utils.log_stats(data=torch.exp(class_parameters), 376 | name='epochs_stats_class_parameter', 377 | step=epoch) 378 | if args.learn_inst_parameters: 379 | utils.log_stats(data=torch.exp(inst_parameters), 380 | name='epochs_stats_inst_parameter', 381 | step=epoch) 382 | 383 | 384 | def main(): 385 | args = parser.parse_args() 386 | args.log_dir = './logs_CL_imagenet' 387 | args.save_dir = './weights_CL_imagenet' 388 | 
utils.generate_log_dir(args) 389 | utils.generate_save_dir(args) 390 | 391 | config = {} 392 | config['clamp_inst_sigma'] = {} 393 | config['clamp_inst_sigma']['min'] = np.log(1/20) 394 | config['clamp_inst_sigma']['max'] = np.log(20) 395 | config['clamp_cls_sigma'] = {} 396 | config['clamp_cls_sigma']['min'] = np.log(1/20) 397 | config['clamp_cls_sigma']['max'] = np.log(20) 398 | utils.save_config(args.save_dir, config) 399 | 400 | # Set seed for reproducibility 401 | utils.set_seed(args) 402 | 403 | # Simply call main_worker function 404 | main_worker(args, config) 405 | 406 | 407 | if __name__ == '__main__': 408 | main() 409 | -------------------------------------------------------------------------------- /media/data_parametres_neurips19_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-data-parameters/c2931c12d2c6100105c4dcc14df2797e39772098/media/data_parametres_neurips19_poster.pdf -------------------------------------------------------------------------------- /media/histogram_class_temperature_over_iterations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-data-parameters/c2931c12d2c6100105c4dcc14df2797e39772098/media/histogram_class_temperature_over_iterations.png -------------------------------------------------------------------------------- /media/method_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-data-parameters/c2931c12d2c6100105c4dcc14df2797e39772098/media/method_overview.png -------------------------------------------------------------------------------- /models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved. 
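`Wide_ResNet` below requires `depth = 6n + 4`, builds three stages of `n` `WideBasic` blocks with widths `[16, 16k, 32k, 64k]`, and ends with `F.avg_pool2d(out, 8)`. The sketch below reproduces that arithmetic for a 32x32 CIFAR input to show why an 8x8 pooling window collapses the last feature map to 1x1. `wide_resnet_shape` is a hypothetical helper, not part of the repository.

```python
def wide_resnet_shape(depth, widen_factor, input_size=32):
    """Return (blocks per stage, stage widths, final feature-map size) for Wide_ResNet."""
    # Same constraint as the assert in Wide_ResNet.__init__
    if (depth - 4) % 6 != 0:
        raise ValueError('Wide-resnet depth should be 6n+4')
    n = (depth - 4) // 6
    k = widen_factor
    n_stages = [16, 16 * k, 32 * k, 64 * k]

    def conv3x3_out(size, stride):
        # 3x3 convolution with padding=1: floor((size + 2 - 3) / stride) + 1
        return (size + 2 - 3) // stride + 1

    size = conv3x3_out(input_size, 1)  # stem conv1, stride 1
    # Only the first block of layer1/2/3 strides (1, 2, 2); every other 3x3 conv keeps the size
    for stride in (1, 2, 2):
        size = conv3x3_out(size, stride)
    return n, n_stages, size
```

For WRN-28-10 this gives 4 blocks per stage, stage widths `[16, 160, 320, 640]` (so `self.linear` sees 640 features), and an 8x8 map before pooling.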
4 | # 5 | 6 | # File imported/derived from: https://github.com/meliketoy/wide-resnet.pytorch 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.init as init 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | import numpy as np 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 18 | 19 | 20 | def conv_init(m): 21 | classname = m.__class__.__name__ 22 | if classname.find('Conv') != -1: 23 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 24 | init.constant_(m.bias, 0) 25 | elif classname.find('BatchNorm') != -1: 26 | init.constant_(m.weight, 1) 27 | init.constant_(m.bias, 0) 28 | 29 | 30 | class WideBasic(nn.Module): 31 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 32 | super(WideBasic, self).__init__() 33 | self.bn1 = nn.BatchNorm2d(in_planes) 34 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 35 | self.dropout = nn.Dropout(p=dropout_rate) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 38 | 39 | self.shortcut = nn.Sequential() 40 | if stride != 1 or in_planes != planes: 41 | self.shortcut = nn.Sequential( 42 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 43 | ) 44 | 45 | def forward(self, x): 46 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 47 | out = self.conv2(F.relu(self.bn2(out))) 48 | out += self.shortcut(x) 49 | 50 | return out 51 | 52 | class Wide_ResNet(nn.Module): 53 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 54 | super(Wide_ResNet, self).__init__() 55 | self.in_planes = 16 56 | 57 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 58 | n = int((depth-4)/6) 59 | k = widen_factor 60 | 61 | print('| Wide-Resnet %dx%d' %(depth, k)) 62 | nStages = [16, 16*k, 32*k, 64*k] 63 | 64 | self.conv1 =
conv3x3(3,nStages[0]) 65 | self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1) 66 | self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2) 67 | self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2) 68 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 69 | self.linear = nn.Linear(nStages[3], num_classes) 70 | 71 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 72 | strides = [stride] + [1]*(num_blocks-1) 73 | layers = [] 74 | 75 | for stride in strides: 76 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 77 | self.in_planes = planes 78 | 79 | return nn.Sequential(*layers) 80 | 81 | def forward(self, x): 82 | out = self.conv1(x) 83 | out = self.layer1(out) 84 | out = self.layer2(out) 85 | out = self.layer3(out) 86 | out = F.relu(self.bn1(out)) 87 | out = F.avg_pool2d(out, 8) 88 | out = out.view(out.size(0), -1) 89 | out = self.linear(out) 90 | 91 | return out 92 | 93 | def WideResNet28_10(num_classes=10, dropout_rate=0.0): 94 | assert dropout_rate == 0.0, 'We have decided not to use dropout' 95 | if dropout_rate > 0.0: 96 | print('We are going to instantiate WideResNet with dropout rate {}'.format(dropout_rate)) 97 | return Wide_ResNet(depth=28, widen_factor=10, dropout_rate=dropout_rate, num_classes=num_classes) 98 | 99 | 100 | if __name__ == '__main__': 101 | net=Wide_ResNet(28, 10, 0.3, 10) 102 | y = net(Variable(torch.randn(1,3,32,32))) 103 | 104 | print(y.size()) 105 | -------------------------------------------------------------------------------- /optimizer/sparse_sgd.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved. 
4 | #
5 | import torch
6 |
7 |
8 | class SparseSGD(torch.optim.SGD):
9 |     """
10 |     This class implements SGD for optimizing parameters where only a few parameters obtain a gradient at each iteration.
11 |     More specifically, we zero out the update to the parameter and its momentum buffer for entries with zero gradient.
12 |
13 |     Args:
14 |         params (iterable): iterable of parameters to optimize or dicts defining parameter groups
15 |         lr (float): learning rate
16 |         momentum (float, optional): momentum factor (default: 0)
17 |         weight_decay (float, optional): weight decay (L2 penalty); must be 0 for this optimizer (default: 0)
18 |         dampening (float, optional): dampening for momentum (default: 0)
19 |         nesterov (bool, optional): enables Nesterov momentum (default: False)
20 |         skip_update_zero_grad (bool, optional): if True, we zero out the update to state and momentum buffer
21 |             for parameters that are not in the computation graph (i.e., that receive zero gradient).
22 |     """
23 |
24 |     def __init__(self, params, lr=0, momentum=0, dampening=0,
25 |                  weight_decay=0, nesterov=False, skip_update_zero_grad=False):
26 |         super(SparseSGD, self).__init__(params,
27 |                                         lr=lr,
28 |                                         momentum=momentum,
29 |                                         dampening=dampening,
30 |                                         weight_decay=weight_decay,
31 |                                         nesterov=nesterov)
32 |
33 |         self.skip_update_zero_grad = skip_update_zero_grad
34 |         str_disp = ' ' if self.skip_update_zero_grad else ' "not" '
35 |         print('Warning: skip_update_zero_grad set to {}. '
36 |               'We will{}zero out update to state and momentum buffer '
37 |               'for parameters with zero gradient. '.format(self.skip_update_zero_grad,
38 |                                                            str_disp))
39 |         assert weight_decay == 0, 'Weight decay for optimizer should be set to 0. ' \
40 |                                   'For data parameters, we explicitly invoke weight decay on ' \
41 |                                   'the subset of data parameters in the computation graph.'
42 |
43 |     def step(self, closure=None):
44 |         """Performs a single optimization step.
45 |
46 |         Arguments:
47 |             closure (callable, optional): A closure that reevaluates the model
48 |                 and returns the loss.
49 |         """
50 |         loss = None
51 |         if closure is not None:
52 |             loss = closure()
53 |
54 |         for group in self.param_groups:
55 |             weight_decay = group['weight_decay']
56 |             momentum = group['momentum']
57 |             dampening = group['dampening']
58 |             nesterov = group['nesterov']
59 |
60 |             for p in group['params']:
61 |                 if p.grad is None:
62 |                     continue
63 |                 d_p = p.grad.data
64 |
65 |                 # Snapshot the parameter values before the update
66 |                 p_before_update = p.data.clone()
67 |
68 |                 if weight_decay != 0:
69 |                     d_p.add_(p.data, alpha=weight_decay)
70 |                 if momentum != 0:
71 |                     param_state = self.state[p]
72 |                     if 'momentum_buffer' not in param_state:
73 |                         # Initializes momentum buffer
74 |                         buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
75 |                         buf.mul_(momentum).add_(d_p)
76 |                         buf_before_update = None
77 |                     else:
78 |                         buf = param_state['momentum_buffer']
79 |                         buf_before_update = buf.data.clone()
80 |                         buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
81 |                     if nesterov:
82 |                         d_p = d_p.add(buf, alpha=momentum)
83 |                     else:
84 |                         d_p = buf
85 |
86 |                 p.data.add_(d_p, alpha=-group['lr'])
87 |
88 |                 # Revert the parameter and momentum buffer for entries with zero gradient
89 |                 if self.skip_update_zero_grad:
90 |                     indices_without_grad = torch.abs(p.grad) == 0.0
91 |
92 |                     # Restore parameters that received zero gradient to their pre-update values
93 |                     p.data[indices_without_grad] = p_before_update.data[indices_without_grad]
94 |
95 |                     # Restore the momentum buffer for entries without gradient (checking momentum first avoids a NameError when momentum == 0)
96 |                     if (momentum != 0) and (buf_before_update is not None):
97 |                         param_state['momentum_buffer'].data[indices_without_grad] = \
98 |                             buf_before_update.data[indices_without_grad]
99 |         return loss
100 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved.
4 | #
5 | """Utility functions for training DNNs with data parameters."""
6 | import os
7 | import json
8 | import shutil
9 | import random
10 |
11 | import torch
12 | import numpy as np
13 | from tensorboard_logger import configure, log_value, log_histogram
14 |
15 | from optimizer.sparse_sgd import SparseSGD
16 |
17 |
18 | class AverageMeter(object):
19 |     """Computes and stores the average and current value."""
20 |     def __init__(self, name, fmt=':f'):
21 |         self.name = name
22 |         self.fmt = fmt
23 |         self.reset()
24 |
25 |     def reset(self):
26 |         self.val = 0
27 |         self.avg = 0
28 |         self.sum = 0
29 |         self.count = 0
30 |
31 |     def update(self, val, n=1):
32 |         self.val = val
33 |         self.sum += val * n
34 |         self.count += n
35 |         self.avg = self.sum / self.count
36 |
37 |     def __str__(self):
38 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
39 |         return fmtstr.format(**self.__dict__)
40 |
41 |
42 | def compute_topk_accuracy(prediction, target, topk=(1,)):
43 |     """Computes the accuracy over the k top predictions for the specified values of k.
44 |
45 |     Args:
46 |         prediction (torch.Tensor): N*C tensor, contains logits for N samples over C classes.
47 |         target (torch.Tensor): labels for each row in prediction.
48 |         topk (tuple of int): different values of k for which top-k accuracy should be computed.
49 |
50 |     Returns:
51 |         result (list of torch.Tensor): top-k accuracy, in percent, for each k in topk.
52 |     """
53 |     with torch.no_grad():
54 |         maxk = max(topk)
55 |         batch_size = target.size(0)
56 |
57 |         _, pred = prediction.topk(maxk, 1, True, True)
58 |         pred = pred.t()
59 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
60 |
61 |         result = []
62 |         for k in topk:
63 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
64 |             result.append(correct_k.mul_(100.0 / batch_size))
65 |         return result
66 |
67 |
68 | def save_artifacts(args, epoch, model, class_parameters, inst_parameters):
69 |     """Saves model and data parameters.
70 |
71 |     Args:
72 |         args (argparse.Namespace): parsed command-line arguments (uses args.arch and args.save_dir).
73 |         epoch (int): current epoch.
74 |         model (torch.nn.Module): DNN model.
75 |         class_parameters (torch.Tensor): class level data parameters.
76 |         inst_parameters (torch.Tensor): instance level data parameters.
77 |     """
78 |     artifacts = {
79 |         'epoch': epoch + 1,
80 |         'arch': args.arch,
81 |         'state_dict': model.state_dict(),
82 |         'class_parameters': class_parameters.cpu().detach().numpy(),
83 |         'inst_parameters': inst_parameters.cpu().detach().numpy()
84 |     }
85 |
86 |     file_path = args.save_dir + '/epoch_{}.pth.tar'.format(epoch)
87 |     torch.save(obj=artifacts, f=file_path)
88 |
89 |
90 | def save_config(save_dir, cfg):
91 |     """Save config file to disk at save_dir.
92 |
93 |     Args:
94 |         save_dir (str): path to directory.
95 |         cfg (dict): config file.
96 |     """
97 |     save_path = save_dir + '/config.json'
98 |     if os.path.isfile(save_path):
99 |         raise Exception("Expected an empty folder but found an existing config file; aborting.")
100 |     with open(save_path, 'w') as outfile:
101 |         json.dump(cfg, outfile)
102 |
103 |
104 | def generate_save_dir(args):
105 |     """Create the directory in which model artifacts (checkpoints and config) are saved."""
106 |
107 |     print('\nModel artifacts (checkpoints and config) are going to be saved in: {}'.format(args.save_dir))
108 |     if os.path.exists(args.save_dir):
109 |         if args.restart:
110 |             print('Deleting old model artifacts found in: {}'.format(args.save_dir))
111 |             shutil.rmtree(args.save_dir)
112 |             os.makedirs(args.save_dir)
113 |         else:
114 |             error = 'Old artifacts found in {}; pass --restart flag to erase'.format(args.save_dir)
115 |             raise Exception(error)
116 |     else:
117 |         os.makedirs(args.save_dir)
118 |
119 |
120 | def generate_log_dir(args):
121 |     """Create the tensorboard log directory and configure tensorboard_logger to write to it."""
122 |
123 |     print('\nLog is going to be saved in: {}'.format(args.log_dir))
124 |
125 |     if os.path.exists(args.log_dir):
126 |         if args.restart:
127 |             print('Deleting old log found in: {}'.format(args.log_dir))
128 |             shutil.rmtree(args.log_dir)
129 |             configure(args.log_dir, flush_secs=10)
130 |         else:
131 |             error = 'Old log found in {}; pass --restart flag to erase'.format(args.log_dir)
132 |             raise Exception(error)
133 |     else:
134 |         configure(args.log_dir, flush_secs=10)
135 |
136 |
137 | def set_seed(args):
138 |     """Set seed to ensure deterministic runs.
139 |
140 |     Note: Making torch deterministic can slow down training.
141 |     """
142 |     random.seed(args.seed)
143 |     torch.manual_seed(args.seed)
144 |     torch.cuda.manual_seed(args.seed)
145 |     np.random.seed(args.seed)
146 |     torch.backends.cudnn.deterministic = True
147 |     torch.backends.cudnn.benchmark = False
148 |
149 |
150 | def get_class_inst_data_params_n_optimizer(args,
151 |                                            nr_classes,
152 |                                            nr_instances,
153 |                                            device):
154 |     """Returns class and instance level data parameters and their corresponding optimizers.
155 |
156 |     Args:
157 |         args (argparse.Namespace): parsed command-line arguments.
158 |         nr_classes (int): number of classes in dataset.
159 |         nr_instances (int): number of instances in dataset.
160 |         device (str): device on which data parameters should be placed.
161 |
162 |     Returns:
163 |         class_parameters (torch.Tensor): class level data parameters.
164 |         inst_parameters (torch.Tensor): instance level data parameters.
165 |         optimizer_class_param (SparseSGD): Sparse SGD optimizer for class parameters.
166 |         optimizer_inst_param (SparseSGD): Sparse SGD optimizer for instance parameters.
167 |     """
168 |     # class-parameter
169 |     class_parameters = torch.tensor(np.ones(nr_classes) * np.log(args.init_class_param),
170 |                                     dtype=torch.float32,
171 |                                     requires_grad=args.learn_class_parameters,
172 |                                     device=device)
173 |     optimizer_class_param = SparseSGD([class_parameters],
174 |                                       lr=args.lr_class_param,
175 |                                       momentum=0.9,
176 |                                       skip_update_zero_grad=True)
177 |     if args.learn_class_parameters:
178 |         print('Initialized class_parameters with: {}'.format(args.init_class_param))
179 |         print('optimizer_class_param:')
180 |         print(optimizer_class_param)
181 |
182 |     # instance-parameter
183 |     inst_parameters = torch.tensor(np.ones(nr_instances) * np.log(args.init_inst_param),
184 |                                    dtype=torch.float32,
185 |                                    requires_grad=args.learn_inst_parameters,
186 |                                    device=device)
187 |     optimizer_inst_param = SparseSGD([inst_parameters],
188 |                                      lr=args.lr_inst_param,
189 |                                      momentum=0.9,
190 |                                      skip_update_zero_grad=True)
191 |     if args.learn_inst_parameters:
192 |         print('Initialized inst_parameters with: {}'.format(args.init_inst_param))
193 |         print('optimizer_inst_param:')
194 |         print(optimizer_inst_param)
195 |
196 |     return class_parameters, inst_parameters, optimizer_class_param, optimizer_inst_param
197 |
198 |
199 | def get_data_param_for_minibatch(args,
200 |                                  class_param_minibatch,
201 |                                  inst_param_minibatch):
202 |     """Returns the effective data parameter for instances in a minibatch as per the specified curriculum.
203 |
204 |     Args:
205 |         args (argparse.Namespace): parsed command-line arguments (uses the learn_class_parameters and learn_inst_parameters flags).
206 |         class_param_minibatch (torch.Tensor): class level parameters for samples in minibatch.
207 |         inst_param_minibatch (torch.Tensor): instance level parameters for samples in minibatch.
208 |
209 |     Returns:
210 |         effective_data_param_minibatch (torch.Tensor or float): data parameter for each sample in the minibatch, shape (N, 1); 1.0 when no data parameters are learned.
211 |     """
212 |     sigma_class_minibatch = torch.exp(class_param_minibatch).view(-1, 1)
213 |     sigma_inst_minibatch = torch.exp(inst_param_minibatch).view(-1, 1)
214 |
215 |     if args.learn_class_parameters and args.learn_inst_parameters:
216 |         # Joint curriculum
217 |         effective_data_param_minibatch = sigma_class_minibatch + sigma_inst_minibatch
218 |     elif args.learn_class_parameters:
219 |         # Class level curriculum
220 |         effective_data_param_minibatch = sigma_class_minibatch
221 |     elif args.learn_inst_parameters:
222 |         # Instance level curriculum
223 |         effective_data_param_minibatch = sigma_inst_minibatch
224 |     else:
225 |         # This corresponds to the baseline case without data parameters
226 |         effective_data_param_minibatch = 1.0
227 |
228 |     return effective_data_param_minibatch
229 |
230 |
231 | def apply_weight_decay_data_parameters(args, loss, class_parameter_minibatch, inst_parameter_minibatch):
232 |     """Applies weight decay on class and instance level data parameters.
233 |
234 |     We apply weight decay on only those data parameters which participate in a mini-batch.
235 |     To apply weight decay to a subset of data parameters, we explicitly include an l2 penalty on these
236 |     data parameters in the computational graph. Note that the l2 penalty is applied in the log domain, which
237 |     encourages data parameters (sigma = exp(log-parameter)) to stay close to 1 and prevents them from
238 |     obtaining very high or low values.
239 |
240 |     Args:
241 |         args (argparse.Namespace): parsed command-line arguments (uses the learn_*_parameters flags and wd_*_param values).
242 |         loss (torch.Tensor): loss of the DNN model from the forward pass.
243 |         class_parameter_minibatch (torch.Tensor): class level parameters for samples in minibatch.
244 |         inst_parameter_minibatch (torch.Tensor): instance level parameters for samples in minibatch.
245 |
246 |     Returns:
247 |         loss (torch.Tensor): loss augmented with l2 penalty on data parameters.
248 |     """
249 |
250 |     # Loss due to weight decay on instance-parameters
251 |     if args.learn_inst_parameters and args.wd_inst_param > 0.0:
252 |         loss = loss + 0.5 * args.wd_inst_param * (inst_parameter_minibatch ** 2).sum()
253 |
254 |     # Loss due to weight decay on class-parameters
255 |     if args.learn_class_parameters and args.wd_class_param > 0.0:
256 |         # (We apply weight-decay to only those classes which are present in the mini-batch)
257 |         loss = loss + 0.5 * args.wd_class_param * (class_parameter_minibatch ** 2).sum()
258 |
259 |     return loss
260 |
261 |
262 | def clamp_data_parameters(args, class_parameters, config, inst_parameters):
263 |     """Clamps class and instance level parameters to the range given in config (in the log domain).
264 |
265 |     Args:
266 |         args (argparse.Namespace): parsed command-line arguments (uses skip_clamp_data_param and the learn_*_parameters flags).
267 |         class_parameters (torch.Tensor): class level parameters.
268 |         config (dict): config file for the experiment.
269 |         inst_parameters (torch.Tensor): instance level parameters.
270 |     """
271 |     if args.skip_clamp_data_param is False:
272 |         if args.learn_inst_parameters:
273 |             # Project the log-sigmas into the configured range
274 |             inst_parameters.data = inst_parameters.data.clamp_(
275 |                 min=config['clamp_inst_sigma']['min'],
276 |                 max=config['clamp_inst_sigma']['max'])
277 |         if args.learn_class_parameters:
278 |             # Project the log-sigmas into the configured range
279 |             class_parameters.data = class_parameters.data.clamp_(
280 |                 min=config['clamp_cls_sigma']['min'],
281 |                 max=config['clamp_cls_sigma']['max'])
282 |
283 |
284 | def log_stats(data, name, step):
285 |     """Logs statistics on tensorboard for data tensor.
286 |
287 |     Args:
288 |         data (torch.Tensor): torch tensor.
289 |         name (str): name under which stats for the tensor should be logged.
290 |         step (int): step used for logging
291 |     """
292 |     log_value('{}/highest'.format(name), torch.max(data).item(), step=step)
293 |     log_value('{}/lowest'.format(name), torch.min(data).item(), step=step)
294 |     log_value('{}/mean'.format(name), torch.mean(data).item(), step=step)
295 |     log_value('{}/std'.format(name), torch.std(data).item(), step=step)
296 |     log_histogram('{}'.format(name), data.data.cpu().numpy(), step=step)
297 |
298 |
299 | def log_intermediate_iteration_stats(args, class_parameters, epoch, global_iter,
300 |                                      inst_parameters, losses, top1=None, top5=None):
301 |     """Log stats for data parameters and loss on tensorboard."""
302 |     if top5 is not None:
303 |         log_value('train_iteration_stats/accuracy_top5', top5.avg, step=global_iter)
304 |     if top1 is not None:
305 |         log_value('train_iteration_stats/accuracy_top1', top1.avg, step=global_iter)
306 |     log_value('train_iteration_stats/loss', losses.avg, step=global_iter)
307 |     log_value('train_iteration_stats/epoch', epoch, step=global_iter)
308 |
309 |     # Log temperature stats
310 |     if args.learn_class_parameters:
311 |         log_stats(data=torch.exp(class_parameters),
312 |                   name='iter_stats_class_parameter',
313 |                   step=global_iter)
314 |     if args.learn_inst_parameters:
315 |         log_stats(data=torch.exp(inst_parameters),
316 |                   name='iter_stats_inst_parameter',
317 |                   step=global_iter)
318 |
319 |
320 |
--------------------------------------------------------------------------------
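The `skip_update_zero_grad` behavior of `SparseSGD.step` in `optimizer/sparse_sgd.py` can be illustrated without PyTorch. The following is a minimal pure-Python sketch, assuming a flat list of scalar parameters with `dampening=0`, `nesterov=False` and `weight_decay=0`; `sparse_momentum_step` is a hypothetical helper written for illustration, not part of this repository.

```python
from typing import List, Tuple


def sparse_momentum_step(params: List[float], grads: List[float], bufs: List[float],
                         lr: float, momentum: float) -> Tuple[List[float], List[float]]:
    """One SGD-with-momentum step that leaves zero-gradient entries untouched.

    Hypothetical helper mimicking SparseSGD.step with skip_update_zero_grad=True,
    dampening=0, nesterov=False and weight_decay=0: entries whose gradient is
    exactly zero keep both their old value and their old momentum buffer.
    """
    new_params, new_bufs = [], []
    for p, g, b in zip(params, grads, bufs):
        if g == 0.0:
            # Entry not in the mini-batch: skip the update entirely, so a stale
            # momentum buffer does not keep moving it.
            new_params.append(p)
            new_bufs.append(b)
        else:
            b_new = momentum * b + g           # buf = momentum * buf + grad
            new_params.append(p - lr * b_new)  # p = p - lr * buf
            new_bufs.append(b_new)
    return new_params, new_bufs


if __name__ == "__main__":
    params, bufs = [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]
    # Step 1: entries 0 and 2 receive a gradient, entry 1 does not.
    params, bufs = sparse_momentum_step(params, [1.0, 0.0, 2.0], bufs, lr=0.1, momentum=0.9)
    print(params)  # [-0.1, 0.0, -0.2]
    # Step 2: only entry 2 receives a gradient. Entry 0 stays at -0.1 even though its
    # momentum buffer is 1.0; plain momentum SGD would move it by lr * momentum * 1.0.
    params, bufs = sparse_momentum_step(params, [0.0, 0.0, 1.0], bufs, lr=0.1, momentum=0.9)
    print(params[0])  # -0.1
```

The point of the design: an instance-level data parameter whose sample is absent from a mini-batch gets zero gradient, yet with ordinary momentum SGD it would keep drifting on its stale momentum. Reverting both the value and the buffer keeps it frozen until its sample is seen again.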
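The data-parameter helpers in `utils.py` (`get_data_param_for_minibatch`, `apply_weight_decay_data_parameters` and `clamp_data_parameters`) all operate on log-space parameters, with the clamp range set to `np.log(1/20)` and `np.log(20)` in the training script's config. The sketch below restates them for a single sample in pure Python; all helper names are hypothetical. How the training scripts consume the effective data parameter is not shown in this part of the repository, so `tempered_cross_entropy` assumes it acts as a temperature that divides the logits before the softmax (consistent with the "class temperature" naming under `media/`). Treat that part as an assumption.

```python
import math
from typing import Iterable, List

# Clamp bounds mirroring the config: log(1/20) <= log-sigma <= log(20).
LOG_SIGMA_MIN = math.log(1 / 20)
LOG_SIGMA_MAX = math.log(20)


def clamp_log_sigma(log_sigma: float) -> float:
    """Project a log-space data parameter into [log(1/20), log(20)]."""
    return min(max(log_sigma, LOG_SIGMA_MIN), LOG_SIGMA_MAX)


def effective_sigma(log_class: float, log_inst: float,
                    learn_class: bool = True, learn_inst: bool = True) -> float:
    """Single-sample analogue of get_data_param_for_minibatch."""
    if learn_class and learn_inst:
        return math.exp(log_class) + math.exp(log_inst)  # joint curriculum
    if learn_class:
        return math.exp(log_class)  # class level curriculum
    if learn_inst:
        return math.exp(log_inst)  # instance level curriculum
    return 1.0  # baseline without data parameters


def weight_decay_penalty(log_params: Iterable[float], wd: float) -> float:
    """0.5 * wd * sum(p ** 2), computed on the log-space parameters."""
    return 0.5 * wd * sum(p * p for p in log_params)


def tempered_cross_entropy(logits: List[float], label: int, sigma: float) -> float:
    """Cross-entropy of softmax(logits / sigma); the temperature role of sigma is assumed."""
    scaled = [z / sigma for z in logits]
    m = max(scaled)  # subtract the max before exponentiating, for numerical stability
    log_partition = m + math.log(sum(math.exp(s - m) for s in scaled))
    return log_partition - scaled[label]
```

Under that assumption, a larger sigma flattens the softmax, which lowers the loss on a sample whose label disagrees with the logits, while the log-domain weight decay pulls each sigma back towards 1 (log-parameter 0) and the clamp bounds it to [1/20, 20].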