├── .github └── workflows │ └── pythonpublish.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── aum ├── __init__.py ├── aum.py ├── dataset.py └── version.py ├── examples ├── cifar100 │ ├── README.md │ └── train.py └── paper_replication │ ├── README.md │ ├── large_dataset_aum.sh │ ├── large_dataset_baseline.sh │ ├── losses.py │ ├── models │ ├── __init__.py │ ├── conv4.py │ ├── densenet.py │ ├── lenet.py │ ├── resnet.py │ ├── vgg.py │ └── wide_resnet.py │ ├── runner.py │ ├── runner_testing.py │ ├── small_dataset_aum.sh │ ├── small_dataset_baseline.sh │ └── util.py ├── requirements.txt ├── setup.cfg ├── setup.py └── test ├── requirements-dev.txt └── test_aum.py /.github/workflows/pythonpublish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v1 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.7.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 24 | run: | 25 | python setup.py sdist bdist_wheel 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # python version 107 | .python-version 108 | 109 | # vscode 110 | .vscode 111 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 ASAPP Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AUM 2 | 3 | Pytorch Library for Area Under the Margin (AUM) Ranking, as proposed in the paper: 4 | [Identifying Mislabeled Data using the Area Under the Margin Ranking](https://arxiv.org/pdf/2001.10528.pdf) 5 | 6 | ## Install 7 | 8 | `pip install -U aum` 9 | 10 | ## Usage 11 | 12 | Instantiate an AUMCalculator object: 13 | 14 | ```python 15 | from aum import AUMCalculator 16 | 17 | save_dir = '~/Desktop' 18 | aum_calculator = AUMCalculator(save_dir, compressed=True) 19 | ``` 20 | Note: you can set `compressed` to `False` if you want to store the AUM metrics at every call to the update method. This will require considerably more space, however. 
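For example, a minimal sketch of the uncompressed setup (the same call with the flag flipped):

```python
# Keeps an AUMRecord per sample for every update call; `finalize()` then writes
# `full_aum_records.csv` in addition to `aum_values.csv`.
aum_calculator = AUMCalculator(save_dir, compressed=False)
```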
21 |
22 | You can then update aum rankings on batches of data during training with:
23 |
24 | ```python
25 | model.train()
26 | for batch in loader:
27 |     inputs, targets, sample_ids = batch
28 |
29 |     logits = model(inputs)
30 |
31 |     records = aum_calculator.update(logits, targets, sample_ids)
32 |
33 |     ...
34 | ```
35 |
36 | `records` is a dictionary mapping a sample_id to an `AUMRecord` containing the information below, including the AUM for the sample at this point in time.
37 |
38 | ```python
39 | @dataclass
40 | class AUMRecord:
41 |     """
42 |     Class for holding info around an aum update for a single sample
43 |     """
44 |     sample_id: Union[int, str]
45 |     num_measurements: int
46 |     target_logit: int
47 |     target_val: float
48 |     other_logit: int
49 |     other_val: float
50 |     margin: float
51 |     aum: float
52 | ```
53 |
54 | And once you are done training, you can generate a csv of ranked samples with their aum scores with:
55 |
56 | ```python
57 | aum_calculator.finalize()
58 | ```
59 |
60 | If you have a dataset that does not return sample_ids, you can wrap it in `DatasetWithIndex`. The last element of the tuple returned for a given sample will be its sample_id.
61 | ```python
62 | from aum import DatasetWithIndex
63 | from torch.utils.data import Dataset
64 |
65 | my_dataset = Dataset(...)
66 | my_dataset_with_index = DatasetWithIndex(my_dataset)
67 | ```
68 |
69 | ## Example Outputs
70 | Calling `finalize()` on an AUMCalculator will result in the creation of 1 or 2 csv files, depending on whether `compressed` was set to True or False.
71 |
72 | If AUMCalculator was instantiated with `compressed = True`, you will find a csv file titled `aum_values.csv` in the following format:
73 |
74 | | sample_id | aum    |
75 | |-----------|--------|
76 | | sample_1  | 1.205  |
77 | | sample_3  | 1.145  |
78 | | sample_2  | -3.785 |
79 |
80 | If AUMCalculator was instantiated with `compressed = False`, you will find a csv file titled `full_aum_records.csv` in addition to the `aum_values.csv`. `full_aum_records.csv` is in the following format:
81 |
82 | | sample_id | num_measurements | target_logit | target_val | other_logit | other_val | margin | aum    |
83 | |-----------|------------------|--------------|------------|-------------|-----------|--------|--------|
84 | | sample_1  | 1                | 0            | 3.74       | 10          | 2.48      | 1.26   | 1.26   |
85 | | sample_1  | 2                | 0            | 4.59       | 10          | 3.44      | 1.15   | 1.205  |
86 | | sample_2  | 1                | 1            | -0.09      | 0           | 3.11      | -3.20  | -3.20  |
87 | | sample_2  | 2                | 1            | -1.12      | 0           | 3.25      | -4.37  | -3.785 |
88 | | sample_3  | 1                | 6            | 3.39       | 10          | 1.62      | 1.77   | 1.77   |
89 | | sample_3  | 2                | 6            | 2.63       | 2           | 2.11      | 0.52   | 1.145  |
90 |
91 |
92 | ## Replicate results from the paper
93 | To replicate results, please refer to the [examples/paper_replication](examples/paper_replication) section.
94 |
95 | ## Example usage
96 | For a more basic example of using the `AUMCalculator` and `DatasetWithIndex` in a training script, please refer to the [examples/cifar100](examples/cifar100) section.
97 |
98 | ## Cite
99 | ```bibtex
100 | @article{pleiss2020identifying,
101 |     title={Identifying Mislabeled Data using the Area Under the Margin Ranking},
102 |     author={Geoff Pleiss and Tianyi Zhang and Ethan R. Elenberg and Kilian Q.
Weinberger}, 103 | journal={arXiv preprint arXiv:2001.10528}, 104 | year={2020} 105 | } 106 | ``` 107 | -------------------------------------------------------------------------------- /aum/__init__.py: -------------------------------------------------------------------------------- 1 | from .aum import AUMCalculator, AUMRecord 2 | from .dataset import DatasetWithIndex 3 | 4 | __all__ = ['AUMCalculator', 'AUMRecord', 'DatasetWithIndex'] 5 | -------------------------------------------------------------------------------- /aum/aum.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from collections import defaultdict, namedtuple 4 | from dataclasses import asdict, dataclass 5 | from pathlib import Path 6 | from typing import Dict, List, Optional, Union 7 | 8 | import pandas as pd 9 | import torch 10 | 11 | sample_identifier = Union[int, str] 12 | 13 | 14 | @dataclass 15 | class AUMRecord: 16 | """ 17 | Class for holding info around an aum update for a single sample 18 | """ 19 | sample_id: sample_identifier 20 | num_measurements: int 21 | target_logit: int 22 | target_val: float 23 | other_logit: int 24 | other_val: float 25 | margin: float 26 | aum: float 27 | 28 | 29 | class AUMCalculator(): 30 | def __init__(self, save_dir: str, compressed: bool = True): 31 | """ 32 | Intantiates the AUM object 33 | 34 | :param save_dir (str): Directory location of where to save out the final csv file(s) 35 | when calling `finalize` 36 | :param compressed (bool): Dictates how much information to store. If True, the object 37 | will only keep track of enough information to return the final AUM for each sample 38 | when `finalize` is called. 39 | If False, the object will keep track of the AUM value at each update call, 40 | storing an AUMRecord per sample per update call. This will also result in a 41 | `full_aum_records.csv` being saved out when calling `finalize`. 42 | Defaults to True. 43 | """ 44 | self.save_dir = save_dir 45 | self.counts = defaultdict(int) 46 | self.sums = defaultdict(float) 47 | 48 | self.compressed = compressed 49 | if not compressed: 50 | self.records = [] 51 | 52 | def update(self, logits: torch.Tensor, targets: torch.Tensor, 53 | sample_ids: List[sample_identifier]) -> Dict[sample_identifier, AUMRecord]: 54 | """ 55 | Updates the running totals and calculates the AUM values for the given samples 56 | 57 | :param logits (torch.Tensor): A 2 dimensional tensor where each row contains the logits 58 | for a given sample. 59 | :param targets (torch.Tensor): A 1 dimensional tensor containing the index of the target 60 | logit for a given sample. 61 | :param sample_ids (List[sample_identifier]): A list mapping each row of the logits & targets 62 | tensors to a sample id. This can be a list of ints or strings. 63 | 64 | :return (Dict[sample_identifier, AUMRecord]): A dictionary mapping each sample identifier 65 | to an AUMRecord. The AUMRecord contains the current AUM data for the given sample after 66 | this update step has been called. 
67 | """ 68 | 69 | target_values = logits.gather(1, targets.view(-1, 1)).squeeze() 70 | 71 | # mask out target values 72 | masked_logits = torch.scatter(logits, 1, targets.view(-1, 1), float('-inf')) 73 | other_logit_values, other_logit_index = masked_logits.max(1) 74 | other_logit_values = other_logit_values.squeeze() 75 | other_logit_index = other_logit_index.squeeze() 76 | 77 | margin_values = (target_values - other_logit_values).tolist() 78 | 79 | updated_aums = {} 80 | for i, (sample_id, margin) in enumerate(zip(sample_ids, margin_values)): 81 | self.counts[sample_id] += 1 82 | self.sums[sample_id] += margin 83 | 84 | record = AUMRecord(sample_id=sample_id, 85 | num_measurements=self.counts[sample_id], 86 | target_logit=targets[i].item(), 87 | target_val=target_values[i].item(), 88 | other_logit=other_logit_index[i].item(), 89 | other_val=other_logit_values[i].item(), 90 | margin=margin, 91 | aum=self.sums[sample_id] / self.counts[sample_id]) 92 | 93 | updated_aums[sample_id] = record 94 | if not self.compressed: 95 | self.records.append(record) 96 | 97 | return updated_aums 98 | 99 | def finalize(self, save_dir: Optional[str] = None) -> None: 100 | """ 101 | Calculates AUM for each sample given the data gathered on each update call. 102 | Outputs a `aum_values.csv` file containing the final AUM values for each sample. 103 | If `self.compressed` set to False, this will also output a `full_aum_records.csv` file 104 | containing AUM values for each sample at each update call. 105 | 106 | :param save_dir (Optional[str]): Allows the ability to overwrite the original save 107 | directory that was set on instantiation of the AUM object. When set to None, the 108 | directory set on instantiation will be used. Defaults to None. 109 | """ 110 | save_dir = save_dir or self.save_dir 111 | Path(save_dir).mkdir(parents=True, exist_ok=True) 112 | 113 | results = [{ 114 | 'sample_id': sample_id, 115 | 'aum': self.sums[sample_id] / self.counts[sample_id] 116 | } for sample_id in self.counts.keys()] 117 | 118 | result_df = pd.DataFrame(results).sort_values(by='aum', ascending=False) 119 | 120 | save_path = os.path.join(save_dir, 'aum_values.csv') 121 | result_df.to_csv(save_path, index=False) 122 | 123 | if not self.compressed: 124 | records_df = AUMCalculator.records_to_df(self.records) 125 | save_path = os.path.join(save_dir, 'full_aum_records.csv') 126 | records_df.to_csv(save_path, index=False) 127 | 128 | @staticmethod 129 | def records_to_df(records: List[AUMRecord]) -> pd.DataFrame: 130 | """ 131 | Converts a list of AUMRecords to a dataframe, sorted by sample_id & num_measurements 132 | 133 | :param records (List[AUMRecord]): A list of AUMRecords 134 | 135 | :return (pd.DataFrame): a dataframe, sorted by sample_id & num_measurements 136 | """ 137 | df = pd.DataFrame([asdict(record) for record in records]) 138 | df.sort_values(by=['sample_id', 'num_measurements'], inplace=True) 139 | return df 140 | -------------------------------------------------------------------------------- /aum/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class DatasetWithIndex(Dataset): 5 | """ 6 | A thin wrapper over a pytorch dataset that includes the sample index as the last element 7 | of the tuple returned. 
8 | """ 9 | def __init__(self, base_dataset: Dataset): 10 | self.base_dataset = base_dataset 11 | 12 | def __len__(self): 13 | return len(self.base_dataset) 14 | 15 | def __getitem__(self, index): 16 | return (*self.base_dataset[index], index) 17 | -------------------------------------------------------------------------------- /aum/version.py: -------------------------------------------------------------------------------- 1 | MAJOR = "1" 2 | MINOR = "0" 3 | PATCH = "2" 4 | 5 | VERSION = f'{MAJOR}.{MINOR}.{PATCH}' 6 | -------------------------------------------------------------------------------- /examples/cifar100/README.md: -------------------------------------------------------------------------------- 1 | # CIFAR-100 Example 2 | This is a simple example showing how to use the `AUMCalculator` and `DatasetWithIndex` in a training script. This script trains Resnet-34 on the CIFAR-100 dataset. At training completion, the aum artifacts will be located in the output directory. The samples with the lowest aum values are most likely mislabeled. 3 | 4 | ## Requirements 5 | - pytorch >= 1.3 6 | - torchvision >= 0.4 7 | - numpy 8 | - pandas 9 | - aum 10 | - tensorboard 11 | 12 | ## Usage 13 | You can call the script as follows: 14 | 15 | ```sh 16 | # for the compressed version of the AUMCalculator 17 | python train.py 18 | 19 | # For the uncompressed version of the AUMCalculator 20 | python train.py --detailed-aum 21 | ``` 22 | 23 | The script will run without any specified arguments as all have defaults, but to see all available arguments: 24 | ```sh 25 | python train.py --help 26 | ``` 27 | -------------------------------------------------------------------------------- /examples/cifar100/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import random 5 | import time 6 | from pathlib import Path 7 | from pprint import pprint 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from torch.utils.data import DataLoader 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torchvision import datasets, transforms 15 | from torchvision.models.resnet import resnet34 16 | 17 | from aum import AUMCalculator, DatasetWithIndex 18 | 19 | 20 | class AverageMeter(object): 21 | """ 22 | Computes and stores the average and current value 23 | Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py 24 | """ 25 | def __init__(self): 26 | self.reset() 27 | 28 | def reset(self): 29 | self.val = 0 30 | self.avg = 0 31 | self.sum = 0 32 | self.count = 0 33 | 34 | def update(self, val, n=1): 35 | self.val = val 36 | self.sum += val * n 37 | self.count += n 38 | self.avg = self.sum / self.count 39 | 40 | 41 | def set_seed(seed: int): 42 | """ 43 | Sets random, numpy, torch, and torch.cuda seeds 44 | """ 45 | 46 | random.seed(seed) 47 | np.random.seed(seed) 48 | torch.manual_seed(seed) 49 | torch.cuda.manual_seed_all(seed) 50 | 51 | 52 | def train_step(args, summary_writer, metrics, aum_calculator, log_interval, batch_step, num_batches, 53 | batch, epoch, num_epochs, global_step, model, optimizer, device): 54 | start = time.time() 55 | model.train() 56 | with torch.enable_grad(): 57 | optimizer.zero_grad() 58 | 59 | input, target, sample_ids = batch 60 | input = input.to(device) 61 | target = target.to(device) 62 | 63 | # Compute output 64 | output = model(input) 65 | loss = F.cross_entropy(output, target) 66 | 67 | # Compute gradient and optimize 68 | 
loss.backward() 69 | optimizer.step() 70 | 71 | # Measure accuracy & record loss 72 | end = time.time() 73 | batch_size = target.size(0) 74 | _, pred = output.data.cpu().topk(1, dim=1) 75 | error = torch.ne(pred.squeeze(), target.cpu()).float().sum().item() / batch_size 76 | 77 | metrics['error'].update(error, batch_size) 78 | metrics['loss'].update(loss.item(), batch_size) 79 | metrics['batch_time'].update(end - start) 80 | 81 | # Update AUM 82 | aum_calculator.update(output, target, sample_ids.tolist()) 83 | 84 | # log to tensorboard 85 | summary_writer.add_scalar('train/error', metrics['error'].val, global_step) 86 | summary_writer.add_scalar('train/loss', metrics['loss'].val, global_step) 87 | summary_writer.add_scalar('train/batch_time', metrics['batch_time'].val, global_step) 88 | 89 | # log to console 90 | if (batch_step + 1) % log_interval == 0: 91 | results = '\t'.join([ 92 | 'TRAIN', 93 | f'Epoch: [{epoch}/{num_epochs}]', 94 | f'Batch: [{batch_step}/{num_batches}]', 95 | f'Time: {metrics["batch_time"].val:.3f} ({metrics["batch_time"].avg:.3f})', 96 | f'Loss: {metrics["loss"].val:.3f} ({metrics["loss"].avg:.3f})', 97 | f'Error: {metrics["error"].val:.3f} ({metrics["error"].avg:.3f})', 98 | ]) 99 | print(results) 100 | 101 | 102 | def eval_step(args, regime, metrics, log_interval, batch_step, num_batches, batch, epoch, 103 | num_epochs, model, device): 104 | start = time.time() 105 | model.eval() 106 | with torch.no_grad(): 107 | input, target, sample_ids = batch 108 | input = input.to(device) 109 | target = target.to(device) 110 | 111 | # Compute output 112 | output = model(input) 113 | loss = F.cross_entropy(output, target) 114 | 115 | # Measure accuracy & record loss 116 | end = time.time() 117 | batch_size = target.size(0) 118 | _, pred = output.data.cpu().topk(1, dim=1) 119 | error = torch.ne(pred.squeeze(), target.cpu()).float().sum().item() / batch_size 120 | 121 | metrics['error'].update(error, batch_size) 122 | metrics['loss'].update(loss.item(), batch_size) 123 | metrics['batch_time'].update(end - start) 124 | 125 | # log to console 126 | if (batch_step + 1) % log_interval == 0: 127 | results = '\t'.join([ 128 | regime, 129 | f'Epoch: [{epoch}/{num_epochs}]', 130 | f'Batch: [{batch_step}/{num_batches}]', 131 | f'Time: {metrics["batch_time"].val:.3f} ({metrics["batch_time"].avg:.3f})', 132 | f'Loss: {metrics["loss"].val:.3f} ({metrics["loss"].avg:.3f})', 133 | f'Error: {metrics["error"].val:.3f} ({metrics["error"].avg:.3f})', 134 | ]) 135 | print(results) 136 | 137 | 138 | def parse_args(): 139 | parser = argparse.ArgumentParser() 140 | 141 | # Dataset 142 | parser.add_argument('--data-dir', type=str, default='./', help='where to download dataset') 143 | parser.add_argument('--valid-size', 144 | type=int, 145 | default=5000, 146 | help='num samples in validation set') 147 | 148 | # Output/logging file 149 | parser.add_argument('--log-interval', 150 | type=int, 151 | default=10, 152 | help='how many steps between logging to the console') 153 | parser.add_argument('--output-dir', 154 | type=str, 155 | default='./output', 156 | help='where to save out the model, must be an existing directory.') 157 | 158 | parser.add_argument('--detailed-aum', 159 | action='store_true', 160 | help='if set, the AUM calculations will be done in non-compressed mode') 161 | 162 | # Optimizer params 163 | parser.add_argument('--learning-rate', type=float, default=0.1, help='optimizer learning rate') 164 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum for optimizer') 165 
| 166 | # Training Regime params 167 | parser.add_argument('--num-epochs', 168 | type=int, 169 | default=150, 170 | help='number of epochs to train over') 171 | parser.add_argument('--train-batch-size', type=int, default=64, help='size of training batch') 172 | 173 | # Validation Regime params 174 | parser.add_argument('--val-batch-size', type=int, default=64, help='size of val batch') 175 | 176 | args = parser.parse_args() 177 | return args 178 | 179 | 180 | def main(args): 181 | pprint(vars(args)) 182 | 183 | # Setup experiment folder structure 184 | # Create output folder if it doesn't exist 185 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 186 | 187 | # save out args 188 | with open(os.path.join(args.output_dir, 'args.txt'), 'w+') as f: 189 | pprint(vars(args), f) 190 | 191 | # Setup summary writer 192 | summary_writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 'tb_logs')) 193 | 194 | # Set seeds 195 | set_seed(42) 196 | 197 | # Load dataset 198 | # Data transforms 199 | mean = [0.5071, 0.4867, 0.4408] 200 | stdv = [0.2675, 0.2565, 0.2761] 201 | train_transforms = transforms.Compose([ 202 | transforms.RandomCrop(32, padding=4), 203 | transforms.RandomHorizontalFlip(), 204 | transforms.ToTensor(), 205 | transforms.Normalize(mean=mean, std=stdv), 206 | ]) 207 | test_transforms = transforms.Compose([ 208 | transforms.ToTensor(), 209 | transforms.Normalize(mean=mean, std=stdv), 210 | ]) 211 | 212 | # Datasets 213 | train_set = datasets.CIFAR100(args.data_dir, 214 | train=True, 215 | transform=train_transforms, 216 | download=True) 217 | val_set = datasets.CIFAR100(args.data_dir, train=True, transform=test_transforms) 218 | test_set = datasets.CIFAR100(args.data_dir, train=False, transform=test_transforms) 219 | 220 | indices = torch.randperm(len(train_set)) 221 | train_indices = indices[:len(indices) - args.valid_size] 222 | valid_indices = indices[len(indices) - args.valid_size:] 223 | train_set = torch.utils.data.Subset(train_set, train_indices) 224 | val_set = torch.utils.data.Subset(val_set, valid_indices) 225 | 226 | train_set = DatasetWithIndex(train_set) 227 | val_set = DatasetWithIndex(val_set) 228 | test_set = DatasetWithIndex(test_set) 229 | 230 | val_loader = DataLoader(val_set, 231 | batch_size=args.val_batch_size, 232 | shuffle=False, 233 | pin_memory=(torch.cuda.is_available())) 234 | test_loader = DataLoader(test_set, 235 | batch_size=args.val_batch_size, 236 | shuffle=False, 237 | pin_memory=(torch.cuda.is_available())) 238 | 239 | # Load Model 240 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 241 | model = resnet34(num_classes=100) 242 | model = model.to(device) 243 | num_params = sum(x.numel() for x in model.parameters() if x.requires_grad) 244 | print(model) 245 | f'Number of parameters: {num_params}' 246 | 247 | # Create optimizer & lr scheduler 248 | parameters = [p for p in model.parameters() if p.requires_grad] 249 | optimizer = torch.optim.SGD(parameters, 250 | lr=args.learning_rate, 251 | momentum=args.momentum, 252 | nesterov=True) 253 | milestones = [0.5 * args.num_epochs, 0.75 * args.num_epochs] 254 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1) 255 | 256 | # Keep track of AUM 257 | aum_calculator = AUMCalculator(args.output_dir, compressed=(not args.detailed_aum)) 258 | 259 | # Keep track of things 260 | global_step = 0 261 | best_error = math.inf 262 | 263 | print('Beginning training') 264 | for epoch in range(args.num_epochs): 265 | 266 | train_loader = DataLoader(train_set, 
267 | batch_size=args.train_batch_size, 268 | shuffle=True, 269 | pin_memory=(torch.cuda.is_available()), 270 | num_workers=0) 271 | 272 | train_metrics = { 273 | 'loss': AverageMeter(), 274 | 'error': AverageMeter(), 275 | 'batch_time': AverageMeter() 276 | } 277 | num_batches = len(train_loader) 278 | for batch_step, batch in enumerate(train_loader): 279 | train_step(args, summary_writer, train_metrics, aum_calculator, args.log_interval, 280 | batch_step, num_batches, batch, epoch, args.num_epochs, global_step, model, 281 | optimizer, device) 282 | 283 | global_step += 1 284 | 285 | scheduler.step() 286 | 287 | val_metrics = { 288 | 'loss': AverageMeter(), 289 | 'error': AverageMeter(), 290 | 'batch_time': AverageMeter() 291 | } 292 | num_batches = len(val_loader) 293 | for batch_step, batch in enumerate(val_loader): 294 | eval_step(args, 'VAL', val_metrics, args.log_interval, batch_step, num_batches, batch, 295 | epoch, args.num_epochs, model, device) 296 | 297 | # log eval metrics to tensorboard 298 | summary_writer.add_scalar('val/error', val_metrics['error'].avg, global_step) 299 | summary_writer.add_scalar('val/loss', val_metrics['loss'].avg, global_step) 300 | summary_writer.add_scalar('val/batch_time', val_metrics['batch_time'].avg, global_step) 301 | 302 | # Save best model 303 | if val_metrics['error'].avg < best_error: 304 | best_error = val_metrics['error'].avg 305 | torch.save(model.state_dict(), os.path.join(args.output_dir, 'best.pt')) 306 | 307 | # Finalize aum calculator 308 | aum_calculator.finalize() 309 | 310 | # Eval best model on on test set 311 | model.load_state_dict(torch.load(os.path.join(args.output_dir, 'best.pt'))) 312 | test_metrics = {'loss': AverageMeter(), 'error': AverageMeter(), 'batch_time': AverageMeter()} 313 | num_batches = len(test_loader) 314 | for batch_step, batch in enumerate(test_loader): 315 | eval_step(args, 'TEST', test_metrics, args.log_interval, batch_step, num_batches, batch, -1, 316 | -1, model, device) 317 | 318 | # log eval metrics to tensorboard 319 | summary_writer.add_scalar('test/error', test_metrics['error'].avg, global_step) 320 | summary_writer.add_scalar('test/loss', test_metrics['loss'].avg, global_step) 321 | summary_writer.add_scalar('test/batch_time', test_metrics['batch_time'].avg, global_step) 322 | 323 | # log test metrics to console 324 | results = '\t'.join([ 325 | 'FINAL TEST RESULTS', 326 | f'Loss: {test_metrics["loss"].avg:.3f}', 327 | f'Error: {test_metrics["error"].avg:.3f}', 328 | ]) 329 | print(results) 330 | 331 | """ 332 | A demo to show how to calculate AUM while training a ResNet on CIFAR100. 333 | """ 334 | if __name__ == '__main__': 335 | args = parse_args() 336 | main(args) 337 | -------------------------------------------------------------------------------- /examples/paper_replication/README.md: -------------------------------------------------------------------------------- 1 | # Paper Replication 2 | 3 | ## Requirements 4 | - pytorch >=1.3 5 | - torchvision >= 0.4 6 | - numpy 7 | - pandas 8 | - tqdm 9 | - aum 10 | - fire (pip install fire) 11 | 12 | ## Datasets 13 | 14 | We run experiments on 3 **small** datasets... 15 | - cifar10 16 | - cifar100 17 | - tiny_imagenet 18 | 19 | ... 
and 2 **large** datasets
20 | - webvision50
21 | - clothing100k
22 |
23 | Download and untar the file here for all 5 datasets:
24 | https://drive.google.com/file/d/1rr2nvnnBMsbo1qcU3i3urJsDw86PJ9tR/view?usp=sharing
25 |
26 | Alternatively, if you just want to run CIFAR10 and CIFAR100, you don't need to download anything.
27 |
28 | ## Run the baseline models
29 | These scripts produce the "Baseline" result in our tables.
30 |
31 | ```sh
32 | # For `dataset=cifar10`, `dataset=cifar100`, or `dataset=tiny_imagenet`
33 | ./small_dataset_baseline.sh <datadir> <dataset> <seed> <perc_mislabeled> <noise_type>
34 | # For `dataset=webvision50`, `dataset=clothing100k`
35 | ./large_dataset_baseline.sh <datadir> <dataset> <seed>
36 | ```
37 |
38 | The arguments:
39 | - `<seed>` - set to something like `1`
40 | - `<perc_mislabeled>` - percentage of synthetic label noise to add (e.g. `0.2`). Set to `0` for no synthetic mislabeled data.
41 | - `<noise_type>` - either `uniform` or `flip`.
42 |
43 | Note that `./large_dataset_baseline` does not take a `<perc_mislabeled>` or `<noise_type>` argument.
44 |
45 | ## Run the AUM models
46 | These scripts produce the "AUM" result in our tables.
47 |
48 | ```sh
49 | # For `dataset=cifar10`, `dataset=cifar100`, or `dataset=tiny_imagenet`
50 | ./small_dataset_aum.sh <datadir> <dataset> <seed> <perc_mislabeled> <noise_type>
51 | # For `dataset=webvision50`, `dataset=clothing100k`
52 | ./large_dataset_aum.sh <datadir> <dataset> <seed>
53 | ```
54 |
-------------------------------------------------------------------------------- /examples/paper_replication/large_dataset_aum.sh: --------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | if [ "$#" -ne 3 ]; then
4 |     echo "You must enter exactly 3 command line arguments"
5 | fi
6 |
7 | datadir=$1
8 | dataset=$2
9 | seed=$3
10 | NETTYPE="resnet50"
11 |
12 | # General arguments for threshold sample trains
13 | args="--data ${datadir}/${dataset} --dataset ${dataset} --net_type ${NETTYPE}"
14 | args="${args} --seed ${seed} --num_valid 0 --use_threshold_samples"
15 | train_args="--num_epochs 60 --lr 0.1 --wd 1e-4 --batch_size 256"
16 | train_args="${train_args} --num_workers 0"
17 |
18 | # First threshold sample run
19 | savedir1="results/${dataset}_${NETTYPE}"
20 | savedir1="${savedir1}_threshold1_seed${seed}"
21 | cmd="python runner.py ${args} --save ${savedir1} --threshold_samples_set_idx 1 - train_for_aum_computation ${train_args} - done"
22 | echo $cmd
23 | if [ -z "${TESTRUN}" ]; then
24 |     mkdir -p $savedir1
25 |     echo $cmd > $savedir1/cmd.txt
26 |     eval $cmd
27 | fi
28 |
29 | # Second threshold sample run
30 | savedir2="results/${dataset}_${NETTYPE}"
31 | savedir2="${savedir2}_threshold2_seed${seed}"
32 | cmd="python runner.py ${args} --save ${savedir2} --threshold_samples_set_idx 2 - train_for_aum_computation ${train_args} - done"
33 | echo $cmd
34 | if [ -z "${TESTRUN}" ]; then
35 |     mkdir -p $savedir2
36 |     echo $cmd > $savedir2/cmd.txt
37 |     eval $cmd
38 | fi
39 |
40 | # Compute AUMs for first threshold sample run
41 | cmd="python runner.py ${args} --save ${savedir1} --threshold_samples_set_idx 1 - generate_aum_details - done"
42 | echo $cmd
43 | if [ -z "${TESTRUN}" ]; then
44 |     mkdir -p ${savedir1}
45 |     eval $cmd
46 | fi
47 |
48 | # Compute AUMs for the second threshold sample run
49 | cmd="python runner.py ${args} --save ${savedir2} --threshold_samples_set_idx 2 - generate_aum_details - done"
50 | echo $cmd
51 | if [ -z "${TESTRUN}" ]; then
52 |     mkdir -p ${savedir2}
53 |     eval $cmd
54 | fi
55 |
56 | # Remove the identified mislabeled samples and retrain
57 | savedir="results/${dataset}_${NETTYPE}"
58 | savedir="${savedir}_aumwtr_seed${seed}"
59 | args="--data ${datadir}/${dataset} --save ${savedir} --dataset ${dataset} --net_type ${NETTYPE}"
60 | args="${args} --seed ${seed} --num_valid 0"
61 | train_args="--num_epochs 180 --lr_drops 0.33,0.67 --lr 0.1 --wd 1e-4 --batch_size 256"
62 | train_args="${train_args} --num_workers 0 --aum_wtr ${savedir1},${savedir2}"
63 | cmd="python runner.py ${args} - train ${train_args} - done"
64 | echo $cmd
65 | if [ -z "${TESTRUN}" ]; then
66 |     mkdir -p $savedir
67 |     echo $cmd > $savedir/cmd.txt
68 |     eval $cmd
69 | fi
70 |
-------------------------------------------------------------------------------- /examples/paper_replication/large_dataset_baseline.sh: --------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | if [ "$#" -ne 3 ]; then
4 |     echo "You must enter exactly 3 command line arguments"
5 | fi
6 |
7 | datadir=$1
8 | DATASET=$2
9 | seed=$3
10 | NETTYPE="resnet50"
11 |
12 | savedir="results/${DATASET}_${NETTYPE}"
13 | savedir="${savedir}_baseline_seed${seed}"
14 |
15 | args="--data ${datadir}/${DATASET} --save ${savedir} --dataset ${DATASET} --net_type ${NETTYPE}"
16 | args="${args} --seed ${seed} --num_valid 0"
17 |
18 | train_args="--num_epochs 180 --lr_drops 0.33,0.67 --lr 0.1 --wd 1e-4 --batch_size 256"
19 | train_args="${train_args} --num_workers 0"
20 |
21 | cmd="python runner.py ${args} - train ${train_args} - done"
22 | echo $cmd
23 | if [ -z "${TESTRUN}" ]; then
24 |     mkdir -p $savedir
25 |     echo $cmd > $savedir/cmd.txt
26 |     eval $cmd
27 | fi
28 |
-------------------------------------------------------------------------------- /examples/paper_replication/losses.py: --------------------------------------------------------------------------------
1 | import torch.nn
2 | from torch.nn.functional import cross_entropy
3 |
4 |
5 | def reed_soft(logits, targets, beta=0.95, reduction='none'):
6 |     """
7 |     Soft version of Reed et al 2014.
8 |     Equivalent to entropy regularization.
9 |     """
10 |     assert reduction == 'none'  # stupid quick hack
11 |     cross_entropy_loss = cross_entropy(logits, targets, reduction=reduction)
12 |     probs = torch.softmax(logits, dim=-1)
13 |     entropy = -(probs.log() * probs).sum(dim=-1)
14 |     loss = cross_entropy_loss * beta - entropy * (1 - beta)
15 |     return loss
16 |
17 |
18 | def reed_hard(logits, targets, beta=0.8, reduction='none'):
19 |     """
20 |     Hard version of Reed et al 2014.
21 |     Bootstraps the loss with the model's most confident prediction.
22 |     """
23 |     cross_entropy_loss = cross_entropy(logits, targets, reduction=reduction)
24 |     most_confident_probs = torch.softmax(logits, dim=-1).max(dim=-1)[0]
25 |     loss = cross_entropy_loss * beta - (1 - beta) * most_confident_probs.log()
26 |     return loss
27 |
28 |
29 | losses = {
30 |     "cross-entropy": cross_entropy,
31 |     "reed-soft": reed_soft,
32 |     "reed-hard": reed_hard,
33 | }
34 |
35 | __all__ = ["losses"]
36 |
-------------------------------------------------------------------------------- /examples/paper_replication/models/__init__.py: --------------------------------------------------------------------------------
1 | from models.conv4 import Conv4
2 | from models.densenet import DenseNet
3 | from models.lenet import LeNet, LeNetMNIST
4 | from models.resnet import ResNet
5 | from models.vgg import VGG
6 | from models.wide_resnet import WideResNet
7 |
8 | models = {
9 |     "densenet": DenseNet,
10 |     "resnet": ResNet,
11 |     "wide_resnet": WideResNet,
12 |     "vgg": VGG,
13 |     "lenet": LeNet,
14 |     "lenet_mnist": LeNetMNIST,
15 |     "conv4": Conv4,
16 | }
17 |
18 | __all__ = [
19 |     "DenseNet",
20 |     "WideResNet",
21 |     "LeNet",
22 |     "models",
23 | ]
24 |
-------------------------------------------------------------------------------- /examples/paper_replication/models/conv4.py: --------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn
3 |
4 |
5 | class Conv4(nn.Module):
6 |     def __init__(self, num_classes, net_dataset="mnist"):
7 |         super().__init__()
8 |         final_size = 7 if net_dataset == "mnist" else 8
9 |         input_channel = 1 if net_dataset == "mnist" else 3
10 |         self.conv1 = nn.Conv2d(input_channel, 32, kernel_size=3, padding=1)
11 |         self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
12 |         self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
13 |         self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
14 |         self.fc1 = nn.Linear(64 * final_size * final_size, 512)
15 |         self.fc2 = nn.Linear(512, num_classes)
16 |
17 |     def forward(self, x):
18 |         x = F.relu(self.conv1(x))
19 |         x = F.relu(self.conv2(x))
20 |         x = F.max_pool2d(x, 2)
21 |         x = F.relu(self.conv3(x))
22 |         x = F.relu(self.conv4(x))
23 |         x = F.max_pool2d(x, 2)
24 |         x = x.view(x.size(0), -1)
25 |         x = F.relu(self.fc1(x))
26 |         x = self.fc2(x)
27 |         return x
28 |
-------------------------------------------------------------------------------- /examples/paper_replication/models/densenet.py: --------------------------------------------------------------------------------
1 | # This implementation is based on the DenseNet-BC implementation in torchvision
2 | # https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
3 |
4 | import math
5 | from collections import OrderedDict
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.checkpoint as cp
11 |
12 |
13 | def _bn_function_factory(norm, relu, conv):
14 |     def bn_function(*inputs):
15 |         concated_features = torch.cat(inputs, 1)
16 |         bottleneck_output = conv(relu(norm(concated_features)))
17 |         return bottleneck_output
18 |
19 |     return bn_function
20 |
21 |
22 | class _DenseLayer(nn.Module):
23 |     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False):
24 |         super(_DenseLayer, self).__init__()
25 |         self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
26 |         self.add_module('relu1', nn.ReLU(inplace=True)),
27 |         self.add_module(
28 |             'conv1',
29 |             nn.Conv2d(num_input_features,
30 |                       bn_size * growth_rate,
31 |
kernel_size=1, 32 | stride=1, 33 | bias=False)), 34 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 35 | self.add_module('relu2', nn.ReLU(inplace=True)), 36 | self.add_module( 37 | 'conv2', 38 | nn.Conv2d(bn_size * growth_rate, 39 | growth_rate, 40 | kernel_size=3, 41 | stride=1, 42 | padding=1, 43 | bias=False)), 44 | self.drop_rate = drop_rate 45 | self.efficient = efficient 46 | 47 | def forward(self, *prev_features): 48 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 49 | if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 50 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 51 | else: 52 | bottleneck_output = bn_function(*prev_features) 53 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 54 | if self.drop_rate > 0: 55 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 56 | return new_features 57 | 58 | 59 | class _Transition(nn.Sequential): 60 | def __init__(self, num_input_features, num_output_features): 61 | super(_Transition, self).__init__() 62 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 63 | self.add_module('relu', nn.ReLU(inplace=True)) 64 | self.add_module( 65 | 'conv', 66 | nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) 67 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 68 | 69 | 70 | class _DenseBlock(nn.Module): 71 | def __init__(self, 72 | num_layers, 73 | num_input_features, 74 | bn_size, 75 | growth_rate, 76 | drop_rate, 77 | efficient=False): 78 | super(_DenseBlock, self).__init__() 79 | for i in range(num_layers): 80 | layer = _DenseLayer( 81 | num_input_features + i * growth_rate, 82 | growth_rate=growth_rate, 83 | bn_size=bn_size, 84 | drop_rate=drop_rate, 85 | efficient=efficient, 86 | ) 87 | self.add_module('denselayer%d' % (i + 1), layer) 88 | 89 | def forward(self, init_features): 90 | features = [init_features] 91 | for name, layer in self.named_children(): 92 | new_features = layer(*features) 93 | features.append(new_features) 94 | return torch.cat(features, 1) 95 | 96 | 97 | class DenseNet(nn.Module): 98 | r"""Densenet-BC model class, based on 99 | `"Densely Connected Convolutional Networks" ` 100 | Args: 101 | growth_rate (int) - how many filters to add each layer (`k` in paper) 102 | block_config (list of 3 or 4 ints) - how many layers in each pooling block 103 | num_init_features (int) - the number of filters to learn in the first convolution layer 104 | bn_size (int) - multiplicative factor for number of bottle neck layers 105 | (i.e. bn_size * k features in the bottleneck layer) 106 | drop_rate (float) - dropout rate after each dense layer 107 | num_classes (int) - number of classification classes 108 | small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger. 109 | efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower. 
110 | """ 111 | def __init__(self, 112 | growth_rate=12, 113 | block_config=(16, 16, 16), 114 | compression=0.5, 115 | num_init_features=24, 116 | bn_size=4, 117 | drop_rate=0, 118 | initial_stride=1, 119 | num_classes=10, 120 | small_inputs=True, 121 | efficient=False): 122 | 123 | super(DenseNet, self).__init__() 124 | assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1' 125 | self.avgpool_size = 8 if small_inputs else 7 126 | 127 | # First convolution 128 | if small_inputs: 129 | self.features = nn.Sequential( 130 | OrderedDict([('conv0', 131 | nn.Conv2d(3, 132 | num_init_features, 133 | kernel_size=3, 134 | stride=initial_stride, 135 | padding=1, 136 | bias=False))])) 137 | else: 138 | self.features = nn.Sequential( 139 | OrderedDict([ 140 | ('conv0', 141 | nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, 142 | bias=False)), 143 | ])) 144 | self.features.add_module('norm0', nn.BatchNorm2d(num_init_features)) 145 | self.features.add_module('relu0', nn.ReLU(inplace=True)) 146 | self.features.add_module( 147 | 'pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)) 148 | 149 | # Each denseblock 150 | num_features = num_init_features 151 | for i, num_layers in enumerate(block_config): 152 | block = _DenseBlock( 153 | num_layers=num_layers, 154 | num_input_features=num_features, 155 | bn_size=bn_size, 156 | growth_rate=growth_rate, 157 | drop_rate=drop_rate, 158 | efficient=efficient, 159 | ) 160 | self.features.add_module('denseblock%d' % (i + 1), block) 161 | num_features = num_features + num_layers * growth_rate 162 | if i != len(block_config) - 1: 163 | trans = _Transition(num_input_features=num_features, 164 | num_output_features=int(num_features * compression)) 165 | self.features.add_module('transition%d' % (i + 1), trans) 166 | num_features = int(num_features * compression) 167 | 168 | # Final batch norm 169 | self.features.add_module('norm_final', nn.BatchNorm2d(num_features)) 170 | 171 | # Linear layer 172 | self.classifier = nn.Linear(num_features, num_classes) 173 | 174 | # Initialization 175 | for name, param in self.named_parameters(): 176 | if 'conv' in name and 'weight' in name: 177 | n = param.size(0) * param.size(2) * param.size(3) 178 | param.data.normal_().mul_(math.sqrt(2. 
/ n)) 179 | elif 'norm' in name and 'weight' in name: 180 | param.data.fill_(1) 181 | elif 'norm' in name and 'bias' in name: 182 | param.data.fill_(0) 183 | elif 'classifier' in name and 'bias' in name: 184 | param.data.fill_(0) 185 | 186 | @property 187 | def num_classes(self): 188 | return self.classifier.weight.size(-2) 189 | 190 | @property 191 | def num_features(self): 192 | return self.classifier.weight.size(-1) 193 | 194 | def extract_features(self, x): 195 | features = self.features(x) 196 | out = F.relu(features, inplace=True) 197 | out = F.avg_pool2d(out, kernel_size=self.avgpool_size).view(features.size(0), -1) 198 | return out 199 | 200 | def forward(self, x): 201 | return self.classifier(self.extract_features(x)) 202 | -------------------------------------------------------------------------------- /examples/paper_replication/models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import nn 3 | 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self, num_classes): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, kernel_size=5) 9 | self.conv2 = nn.Conv2d(6, 16, kernel_size=5) 10 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, num_classes) 13 | 14 | def forward(self, x): 15 | x = F.relu(self.conv1(x)) 16 | x = F.max_pool2d(x, 2) 17 | x = F.relu(self.conv2(x)) 18 | x = F.max_pool2d(x, 2) 19 | x = x.view(x.size(0), -1) 20 | x = F.relu(self.fc1(x)) 21 | x = F.relu(self.fc2(x)) 22 | x = self.fc3(x) 23 | return x 24 | 25 | 26 | class LeNetMNIST(nn.Module): 27 | def __init__(self, num_classes): 28 | super(LeNetMNIST, self).__init__() 29 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 30 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 31 | self.fc1 = nn.Linear(4 * 4 * 50, 500) 32 | self.fc2 = nn.Linear(500, 10) 33 | 34 | def forward(self, x): 35 | x = F.relu(self.conv1(x)) 36 | x = F.max_pool2d(x, 2, 2) 37 | x = F.relu(self.conv2(x)) 38 | x = F.max_pool2d(x, 2, 2) 39 | x = x.view(-1, 4 * 4 * 50) 40 | x = F.relu(self.fc1(x)) 41 | x = self.fc2(x) 42 | return x 43 | -------------------------------------------------------------------------------- /examples/paper_replication/models/resnet.py: -------------------------------------------------------------------------------- 1 | # This implementation is based on the DenseNet implementation in torchvision 2 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | 4 | import math 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from torchvision.models.resnet import conv3x3 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, inplanes, planes, stride=1, downsample=None): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = conv3x3(inplanes, planes, stride) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | self.conv2 = conv3x3(planes, planes) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.relu2 = nn.ReLU(inplace=True) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | if self.downsample is not None: 29 | x = self.downsample(x) 30 | # TODO: fix the bug of original Stochatic depth 31 | residual = self.conv1(residual) 32 | residual = self.bn1(residual) 33 | residual = self.relu1(residual) 34 | residual = self.conv2(residual) 35 | residual = self.bn2(residual) 36 | x = x + residual 37 | x = self.relu2(x) 38 | 39 | return x 40 | 41 | 
42 | class DownsampleB(nn.Module): 43 | def __init__(self, nIn, nOut, stride): 44 | super(DownsampleB, self).__init__() 45 | self.avg = nn.AvgPool2d(stride) 46 | self.expand_ratio = nOut // nIn 47 | 48 | def forward(self, x): 49 | x = self.avg(x) 50 | return torch.cat([x] + [x.mul(0)] * (self.expand_ratio - 1), 1) 51 | 52 | 53 | class ResNet(nn.Module): 54 | '''Small ResNet for CIFAR & SVHN ''' 55 | def __init__(self, depth=32, block=BasicBlock, initial_stride=1, num_classes=10): 56 | assert (depth - 2) % 6 == 0, 'depth should be one of 6N+2' 57 | super(ResNet, self).__init__() 58 | n = (depth - 2) // 6 59 | self.inplanes = 16 60 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=initial_stride, padding=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(16) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.layer1 = self._make_layer(block, 16, n) 64 | self.layer2 = self._make_layer(block, 32, n, stride=2) 65 | self.layer3 = self._make_layer(block, 64, n, stride=2) 66 | self.avgpool = nn.AvgPool2d(8) 67 | self.fc = nn.Linear(64 * block.expansion, num_classes) 68 | 69 | for m in self.modules(): 70 | if isinstance(m, nn.Conv2d): 71 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 72 | m.weight.data.normal_(0, math.sqrt(2. / n)) 73 | elif isinstance(m, nn.BatchNorm2d): 74 | m.weight.data.fill_(1) 75 | m.bias.data.zero_() 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride=1): 78 | downsample = None 79 | if stride != 1 or self.inplanes != planes * block.expansion: 80 | downsample = DownsampleB(self.inplanes, planes * block.expansion, stride) 81 | 82 | layers = [block(self.inplanes, planes, stride, downsample=downsample)] 83 | self.inplanes = planes * block.expansion 84 | for _ in range(1, num_blocks): 85 | layers.append(block(self.inplanes, planes)) 86 | 87 | return nn.Sequential(*layers) 88 | 89 | @property 90 | def classifier(self): 91 | return self.fc 92 | 93 | @property 94 | def num_classes(self): 95 | return self.fc.weight.size(-2) 96 | 97 | @property 98 | def num_features(self): 99 | return self.fc.weight.size(-1) 100 | 101 | def extract_features(self, x): 102 | x = self.conv1(x) 103 | x = self.bn1(x) 104 | x = self.relu(x) 105 | 106 | x = self.layer1(x) 107 | x = self.layer2(x) 108 | x = self.layer3(x) 109 | 110 | x = self.avgpool(x) 111 | x = x.view(x.size(0), -1) 112 | return x 113 | 114 | def forward(self, x): 115 | return self.fc(self.extract_features(x)) 116 | -------------------------------------------------------------------------------- /examples/paper_replication/models/vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision.git 3 | ''' 4 | import math 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class VGG(nn.Module): 10 | ''' 11 | VGG model 12 | ''' 13 | def __init__(self, num_classes=10, depth=11): 14 | super(VGG, self).__init__() 15 | self.features = make_layers(depth) 16 | self.classifier = nn.Sequential( 17 | nn.Dropout(), 18 | nn.Linear(512, 512), 19 | nn.ReLU(True), 20 | nn.Dropout(), 21 | nn.Linear(512, 512), 22 | nn.ReLU(True), 23 | nn.Linear(512, num_classes), 24 | ) 25 | # Initialize weights 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 29 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 30 | m.bias.data.zero_() 31 | 32 | def forward(self, x): 33 | x = self.features(x) 34 | x = x.view(x.size(0), -1) 35 | x = self.classifier(x) 36 | return x 37 | 38 | 39 | def make_layers(depth, batch_norm=False): 40 | layers = [] 41 | in_channels = 3 42 | for v in cfg[depth]: 43 | if v == 'M': 44 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 45 | else: 46 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 47 | if batch_norm: 48 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 49 | else: 50 | layers += [conv2d, nn.ReLU(inplace=True)] 51 | in_channels = v 52 | return nn.Sequential(*layers) 53 | 54 | 55 | cfg = { 56 | 11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 57 | 13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 58 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 59 | 19: [ 60 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 61 | 512, 'M' 62 | ], 63 | } 64 | 65 | __all__ = ["VGG", "make_layers", "cfg"] 66 | -------------------------------------------------------------------------------- /examples/paper_replication/models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 9 | 10 | 11 | def conv_init(m): 12 | classname = m.__class__.__name__ 13 | if classname.find("Conv") != -1: 14 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 15 | init.constant(m.bias, 0) 16 | elif classname.find("BatchNorm") != -1: 17 | init.constant(m.weight, 1) 18 | init.constant(m.bias, 0) 19 | 20 | 21 | class WideBasic(nn.Module): 22 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 23 | super().__init__() 24 | self.bn1 = nn.BatchNorm2d(in_planes) 25 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 26 | self.dropout = nn.Dropout(p=dropout_rate) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 29 | 30 | self.shortcut = nn.Sequential() 31 | if stride != 1 or in_planes != planes: 32 | self.shortcut = nn.Sequential( 33 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), ) 34 | 35 | def forward(self, x): 36 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 37 | out = self.conv2(F.relu(self.bn2(out))) 38 | out += self.shortcut(x) 39 | return out 40 | 41 | 42 | class WideResNet(nn.Module): 43 | def __init__(self, num_classes, depth=28, widen_factor=10, dropout_rate=0.3): 44 | super(WideResNet, self).__init__() 45 | self.in_planes = 16 46 | 47 | assert ((depth - 4) % 6 == 0), "Wide-resnet depth should be 6n+4" 48 | n = (depth - 4) // 6 49 | k = widen_factor 50 | 51 | nStages = [16, 16 * k, 32 * k, 64 * k] 52 | 53 | self.conv1 = conv3x3(3, nStages[0]) 54 | self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1) 55 | self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2) 56 | self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2) 57 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 58 | self.linear = nn.Linear(nStages[3], num_classes) 59 | 60 | def _wide_layer(self, block, planes, num_blocks, 
dropout_rate, stride): 61 | strides = [stride] + [1] * (num_blocks - 1) 62 | layers = [] 63 | 64 | for stride in strides: 65 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 66 | self.in_planes = planes 67 | 68 | return nn.Sequential(*layers) 69 | 70 | @property 71 | def classifier(self): 72 | return self.linear 73 | 74 | @property 75 | def num_classes(self): 76 | return self.linear.weight.size(-2) 77 | 78 | @property 79 | def num_features(self): 80 | return self.linear.weight.size(-1) 81 | 82 | def extract_features(self, x): 83 | out = self.conv1(x) 84 | out = self.layer1(out) 85 | out = self.layer2(out) 86 | out = self.layer3(out) 87 | out = F.relu(self.bn1(out)) 88 | out = F.avg_pool2d(out, 8) 89 | out = out.view(out.size(0), -1) 90 | return out 91 | 92 | def forward(self, x): 93 | return self.linear(self.extract_features(x)) 94 | -------------------------------------------------------------------------------- /examples/paper_replication/runner.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import random 5 | import shutil 6 | import sys 7 | from collections import OrderedDict 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | 13 | import fire 14 | import tqdm 15 | import util 16 | from aum import AUMCalculator 17 | from losses import losses 18 | from models import models 19 | from torchvision import datasets 20 | from torchvision import models as tvmodels 21 | from torchvision import transforms 22 | 23 | 24 | class _Dataset(torch.utils.data.Dataset): 25 | """ 26 | A wrapper around existing torch datasets to add purposefully mislabeled samplesa and threshold samples. 27 | 28 | :param :obj:`torch.utils.data.Dataset` base_dataset: Dataset to wrap 29 | :param :obj:`torch.LongTensor` indices: List of indices of base_dataset to include (used to create valid. sets) 30 | :param dict flip_dict: (optional) List mapping sample indices to their (incorrect) assigned label 31 | :param bool use_threshold_samples: (default False) Whether or not to add threshold samples to this datasets 32 | :param bool threshold_samples_set_idx: (default 1) Which set of threshold samples to use. 
33 | """ 34 | def __init__(self, 35 | base_dataset, 36 | indices=None, 37 | flip_dict=None, 38 | use_threshold_samples=False, 39 | threshold_samples_set_idx=1): 40 | super().__init__() 41 | self.dataset = base_dataset 42 | self.flip_dict = flip_dict or {} 43 | self.indices = torch.arange(len(self.dataset)) if indices is None else indices 44 | 45 | # Create optional extra class (for threshold samples) 46 | self.use_threshold_samples = use_threshold_samples 47 | if use_threshold_samples: 48 | num_threshold_samples = len(self.indices) // (self.targets.max().item() + 1) 49 | start_index = (threshold_samples_set_idx - 1) * num_threshold_samples 50 | end_index = (threshold_samples_set_idx) * num_threshold_samples 51 | self.threshold_sample_indices = torch.randperm(len(self.indices))[start_index:end_index] 52 | 53 | @property 54 | def targets(self): 55 | """ 56 | (Hidden) ground-truth labels 57 | """ 58 | if not hasattr(self, "_target_memo"): 59 | try: 60 | self.__target_memo = torch.tensor(self.dataset.targets)[self.indices] 61 | except Exception: 62 | self.__target_memo = torch.tensor([target 63 | for _, target in self.dataset])[self.indices] 64 | if torch.is_tensor(self.__target_memo): 65 | return self.__target_memo 66 | else: 67 | return torch.tensor(self.__target_memo) 68 | 69 | @property 70 | def assigned_targets(self): 71 | """ 72 | (Potentially incorrect) assigned labels 73 | """ 74 | if not hasattr(self, "_assigned_target_memo"): 75 | self._assigned_target_memo = self.targets.clone() 76 | 77 | # Change labels of mislabeled samples 78 | if self.flip_dict is not None: 79 | for i, idx in enumerate(self.indices.tolist()): 80 | if idx in self.flip_dict.keys(): 81 | self._assigned_target_memo[i] = self.flip_dict[idx] 82 | 83 | # Change labels of threshold samples 84 | if self.use_threshold_samples: 85 | extra_class = (self.targets.max().item() + 1) 86 | self._assigned_target_memo[self.threshold_sample_indices] = extra_class 87 | return self._assigned_target_memo 88 | 89 | def __len__(self): 90 | return len(self.indices) 91 | 92 | def __getitem__(self, index): 93 | input, _ = self.dataset[self.indices[index].item()] 94 | target = self.assigned_targets[index].item() 95 | res = input, target, index 96 | return res 97 | 98 | 99 | class Runner(object): 100 | """ 101 | Main module for running experiments. Can call `load`, `save`, `train`, `test`, etc. 102 | 103 | :param str data: Directory to load data from 104 | :param str save: Directory to save model/results 105 | :param str dataset: (cifar10, cifar100, tiny_imagenet, webvision50, clothing100k) 106 | 107 | :param int num_valid: (default 5000) What size validation set to use (comes from train set, indices determined by seed) 108 | :param int seed: (default 0) Random seed 109 | :param int split_seed: (default 0) Which random seed to use for creating trian/val split and for flipping random labels. 110 | If this arg is not supplied, the split_seed will come from the `seed` arg. 111 | 112 | :param float perc_mislabeled: (default 0.) How many samples will be intentionally mislabeled. 113 | Default is 0. - i.e. regular training without flipping any labels. 114 | :param str noise_type: (uniform, flip) Mislabeling noise model to use. 
115 | 116 | :param bool use_threshold_samples: (default False) Whether to add indictaor samples 117 | :param bool threshold_samples_set_idx: (default 1) Which set of threshold samples to use (based on index) 118 | 119 | :param str loss_type: (default cross-entropy) Loss type 120 | :param bool oracle_training: (default False) If true, the network will be trained only on clean data 121 | (i.e. all training points with flipped labels will be discarded). 122 | 123 | :param str net_type: (resnet, densenet, wide_resnet) Which network to use. 124 | :param **model_args: Additional argumets to pass to the model 125 | """ 126 | def __init__(self, 127 | data, 128 | save, 129 | dataset="cifar10", 130 | num_valid=5000, 131 | seed=0, 132 | split_seed=None, 133 | noise_type="uniform", 134 | perc_mislabeled=0., 135 | use_threshold_samples=False, 136 | threshold_samples_set_idx=1, 137 | loss_type="cross-entropy", 138 | oracle_training=False, 139 | net_type="resnet", 140 | pretrained=False, 141 | **model_args): 142 | if not os.path.exists(save): 143 | os.makedirs(save) 144 | if not os.path.isdir(save): 145 | raise Exception('%s is not a dir' % save) 146 | self.data = data 147 | self.savedir = save 148 | self.perc_mislabeled = perc_mislabeled 149 | self.noise_type = noise_type 150 | self.dataset = dataset 151 | self.net_type = net_type 152 | self.num_valid = num_valid 153 | self.use_threshold_samples = use_threshold_samples 154 | self.threshold_samples_set_idx = threshold_samples_set_idx 155 | self.split_seed = split_seed if split_seed is not None else seed 156 | self.seed = seed 157 | self.loss_func = losses[loss_type] 158 | self.oracle_training = oracle_training 159 | self.pretrained = pretrained 160 | 161 | # Seed 162 | torch.manual_seed(0) 163 | torch.cuda.manual_seed_all(0) 164 | random.seed(0) 165 | 166 | # Logging 167 | self.timestring = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 168 | logging.basicConfig( 169 | format='%(message)s', 170 | handlers=[ 171 | logging.StreamHandler(sys.stdout), 172 | logging.FileHandler(os.path.join(self.savedir, 'log-%s.log' % self.timestring)), 173 | ], 174 | level=logging.INFO, 175 | ) 176 | logging.info('Data dir:\t%s' % data) 177 | logging.info('Save dir:\t%s\n' % save) 178 | 179 | # Make model 180 | self.num_classes = self.test_set.targets.max().item() + 1 181 | if use_threshold_samples: 182 | self.num_classes += 1 183 | self.num_data = len(self.train_set) 184 | logging.info(f"\nDataset: {self.dataset}") 185 | logging.info(f"Num train: {self.num_data}") 186 | logging.info(f"Num valid: {self.num_valid}") 187 | logging.info(f"Extra class: {self.use_threshold_samples}") 188 | logging.info(f"Num classes: {self.num_classes}") 189 | if self.perc_mislabeled: 190 | logging.info(f"Noise type: {self.noise_type}") 191 | logging.info(f"Flip perc: {self.perc_mislabeled}\n") 192 | if self.oracle_training: 193 | logging.info(f"Training with Oracle Only") 194 | 195 | # Model 196 | if self.dataset == "imagenet" or "webvision" in self.dataset or "clothing" in self.dataset: 197 | big_models = dict((key, val) for key, val in tvmodels.__dict__.items()) 198 | self.model = big_models[self.net_type](pretrained=False, num_classes=self.num_classes) 199 | if self.pretrained: 200 | try: 201 | self.model.load_state_dict( 202 | big_models[self.net_type](pretrained=True).state_dict(), strict=False) 203 | except RuntimeError: 204 | pass 205 | # Fix pooling issues 206 | if "inception" in self.net_type: 207 | self.avgpool_1a = torch.nn.AdaptiveAvgPool2d((1, 1)) 208 | else: 209 | 
self.model = models[self.net_type](
210 |                 num_classes=self.num_classes,
211 |                 initial_stride=(2 if "tiny" in self.dataset.lower() else 1),
212 |                 **model_args)
213 |         logging.info(f"Model type: {self.net_type}")
214 |         logging.info(f"Model args:")
215 |         for key, val in model_args.items():
216 |             logging.info(f" - {key}: {val}")
217 |         logging.info(f"Loss type: {loss_type}")
218 |         logging.info("")
219 |
220 |     def _make_datasets(self):
221 |         try:
222 |             dataset_cls = getattr(datasets, self.dataset.upper())
223 |             self.big_model = False
224 |         except Exception:
225 |             dataset_cls = datasets.ImageFolder
226 |             if "tiny" in self.dataset.lower():
227 |                 self.big_model = False
228 |             else:
229 |                 self.big_model = True
230 |
231 |         # Get constants
232 |         if dataset_cls == datasets.ImageFolder:
233 |             tmp_set = dataset_cls(root=os.path.join(self.data, "train"))
234 |         else:
235 |             tmp_set = dataset_cls(root=self.data, train=True, download=True)
236 |             if self.dataset.upper() == 'CIFAR10':
237 |                 tmp_set.targets = tmp_set.train_labels
238 |         num_train = len(tmp_set) - self.num_valid
239 |         num_valid = self.num_valid
240 |         num_classes = int(max(tmp_set.targets)) + 1
241 |
242 |         # Create train/valid split
243 |         torch.manual_seed(self.split_seed)
244 |         torch.cuda.manual_seed_all(self.split_seed)
245 |         random.seed(self.split_seed)
246 |         train_indices, valid_indices = torch.randperm(num_train + num_valid).split(
247 |             [num_train, num_valid])
248 |
249 |         # dataset indices flip
250 |         flip_dict = {}
251 |         if self.perc_mislabeled:
252 |             # Generate noisy labels from random transitions
253 |             transition_matrix = torch.eye(num_classes)
254 |             if self.noise_type == "uniform":
255 |                 transition_matrix.mul_(1 - self.perc_mislabeled * (num_classes / (num_classes - 1)))
256 |                 transition_matrix.add_(self.perc_mislabeled / (num_classes - 1))
257 |             elif self.noise_type == "flip":
258 |                 source_classes = torch.arange(num_classes)
259 |                 target_classes = (source_classes + 1).fmod(num_classes)
260 |                 transition_matrix.mul_(1 - self.perc_mislabeled)
261 |                 transition_matrix[source_classes, target_classes] = self.perc_mislabeled
262 |             else:
263 |                 raise ValueError(f"Unknown noise type {self.noise_type}")
264 |             true_targets = (torch.tensor(tmp_set.targets) if hasattr(tmp_set, "targets") else
265 |                             torch.tensor([target for _, target in tmp_set]))
266 |             transition_targets = torch.distributions.Categorical(
267 |                 probs=transition_matrix[true_targets, :]).sample()
268 |             # Create a dictionary of transitions
269 |             if not self.oracle_training:
270 |                 flip_indices = torch.nonzero(transition_targets != true_targets).squeeze(-1)
271 |                 flip_targets = transition_targets[flip_indices]
272 |                 for index, target in zip(flip_indices, flip_targets):
273 |                     flip_dict[index.item()] = target.item()
274 |             else:
275 |                 # In the oracle setting, don't add transitions
276 |                 oracle_indices = torch.nonzero(transition_targets == true_targets).squeeze(-1)
277 |                 train_indices = torch.from_numpy(
278 |                     np.intersect1d(oracle_indices.numpy(), train_indices.numpy())).long()
279 |
280 |         # Reset the seed for dataset/initializations
281 |         torch.manual_seed(self.split_seed)
282 |         torch.cuda.manual_seed_all(self.split_seed)
283 |         random.seed(self.split_seed)
284 |
285 |         # Define transforms
286 |         if self.big_model:
287 |             normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
288 |             test_transforms = transforms.Compose([
289 |                 transforms.Resize(256),
290 |                 transforms.CenterCrop(227 if "inception" in self.net_type else 224),
291 |                 transforms.ToTensor(),
292 |                 normalize,
293
| ]) 294 | train_transforms = transforms.Compose([ 295 | transforms.RandomResizedCrop(227 if "inception" in self.net_type else 224), 296 | transforms.RandomHorizontalFlip(), 297 | transforms.ToTensor(), 298 | normalize, 299 | ]) 300 | elif self.dataset == "tiny_imagenet": 301 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 302 | test_transforms = transforms.Compose([ 303 | transforms.ToTensor(), 304 | normalize, 305 | ]) 306 | train_transforms = transforms.Compose([ 307 | transforms.RandomCrop(64, padding=8), 308 | transforms.RandomHorizontalFlip(), 309 | test_transforms, 310 | ]) 311 | elif self.dataset == "cifar10": 312 | normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467], 313 | std=[0.2471, 0.2435, 0.2616]) 314 | test_transforms = transforms.Compose([ 315 | transforms.ToTensor(), 316 | normalize, 317 | ]) 318 | train_transforms = transforms.Compose([ 319 | transforms.RandomCrop(32, padding=4), 320 | transforms.RandomHorizontalFlip(), 321 | test_transforms, 322 | ]) 323 | elif self.dataset == "cifar100": 324 | normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467], 325 | std=[0.2471, 0.2435, 0.2616]) 326 | test_transforms = transforms.Compose([ 327 | transforms.ToTensor(), 328 | normalize, 329 | ]) 330 | train_transforms = transforms.Compose([ 331 | transforms.RandomCrop(32, padding=4), 332 | transforms.RandomHorizontalFlip(), 333 | test_transforms, 334 | ]) 335 | elif self.dataset == "mnist": 336 | normalize = transforms.Normalize(mean=(0.1307, ), std=(0.3081, )) 337 | test_transforms = transforms.Compose([ 338 | transforms.ToTensor(), 339 | normalize, 340 | ]) 341 | train_transforms = test_transforms 342 | else: 343 | raise ValueError(f"Unknown dataset {self.dataset}") 344 | 345 | # Get train set 346 | if dataset_cls == datasets.ImageFolder: 347 | self._train_set_memo = _Dataset( 348 | dataset_cls( 349 | root=os.path.join(self.data, "train"), 350 | transform=train_transforms, 351 | ), 352 | flip_dict=flip_dict, 353 | indices=train_indices, 354 | use_threshold_samples=self.use_threshold_samples, 355 | threshold_samples_set_idx=self.threshold_samples_set_idx, 356 | ) 357 | if os.path.exists(os.path.join(self.data, "test")): 358 | self._valid_set_memo = _Dataset( 359 | dataset_cls(root=os.path.join(self.data, "val"), transform=test_transforms)) 360 | self._test_set_memo = _Dataset( 361 | dataset_cls(root=os.path.join(self.data, "test"), transform=test_transforms)) 362 | else: 363 | self._valid_set_memo = _Dataset( 364 | dataset_cls(root=os.path.join(self.data, "train"), transform=test_transforms), 365 | indices=valid_indices, 366 | ) if len(valid_indices) else None 367 | self._test_set_memo = _Dataset( 368 | dataset_cls(root=os.path.join(self.data, "val"), transform=test_transforms)) 369 | else: 370 | self._train_set_memo = _Dataset( 371 | dataset_cls(root=self.data, train=True, transform=train_transforms), 372 | flip_dict=flip_dict, 373 | indices=train_indices, 374 | use_threshold_samples=self.use_threshold_samples, 375 | threshold_samples_set_idx=self.threshold_samples_set_idx, 376 | ) 377 | self._valid_set_memo = _Dataset(dataset_cls( 378 | root=self.data, train=True, transform=test_transforms), 379 | indices=valid_indices) if len(valid_indices) else None 380 | self._test_set_memo = _Dataset( 381 | dataset_cls(root=self.data, train=False, transform=test_transforms)) 382 | 383 | @property 384 | def test_set(self): 385 | if not hasattr(self, "_test_set_memo"): 386 | self._make_datasets() 387 | return self._test_set_memo 388 | 
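To make the label-noise model in `_make_datasets` above concrete, here is a small self-contained sketch with made-up numbers (4 classes, 20% mislabeling): the "uniform" matrix spreads the flip probability evenly over all other classes, the "flip" matrix sends it entirely to the next class, and the observed labels are then drawn row-wise from the resulting categorical distributions.

```python
import torch

num_classes, perc = 4, 0.2
true_targets = torch.tensor([0, 1, 2, 3, 0, 1])

# "uniform" noise: keep the true label with prob 0.8, move to each of the
# other 3 classes with prob 0.2 / 3
uniform = torch.eye(num_classes)
uniform.mul_(1 - perc * (num_classes / (num_classes - 1)))
uniform.add_(perc / (num_classes - 1))

# "flip" noise: keep the true label with prob 0.8, move to class (c + 1) % 4
# with prob 0.2
flip = torch.eye(num_classes).mul_(1 - perc)
flip[torch.arange(num_classes), (torch.arange(num_classes) + 1) % num_classes] = perc

# Observed (possibly mislabeled) targets, one categorical draw per sample
observed = torch.distributions.Categorical(probs=uniform[true_targets]).sample()
```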
389 | @property 390 | def train_set(self): 391 | if not hasattr(self, "_train_set_memo"): 392 | self._make_datasets() 393 | return self._train_set_memo 394 | 395 | @property 396 | def valid_set(self): 397 | if not hasattr(self, "_valid_set_memo"): 398 | self._make_datasets() 399 | return self._valid_set_memo 400 | 401 | def generate_aum_details(self, load=None): 402 | """ 403 | Script for accumulating both aum values and other sample details at the end of training. 404 | It makes a dataframe that contains AUMs Clean for all samples 405 | The results are saved to the file `aum_details.csv` in the model folder. 406 | 407 | :param str load: (optional) If set to some value - it will assemble aum info from the model stored in the `load` folder. 408 | Otherwise - it will comptue aums from the runner's model. 409 | 410 | :return: self 411 | """ 412 | 413 | load = load or self.savedir 414 | train_data = torch.load(os.path.join(load, "train_data.pth")) 415 | aum_data = pd.read_csv(os.path.join(load, "aum_values.csv")) 416 | 417 | # HACK: fix for old version of the code 418 | if "assigned_targets" not in train_data: 419 | train_data["assigned_targets"] = train_data["observed_targets"] 420 | 421 | true_targets = train_data["true_targets"] 422 | assigned_targets = train_data["assigned_targets"] 423 | is_threshold_sample = assigned_targets.gt(true_targets.max()) 424 | label_flipped = torch.ne(true_targets, assigned_targets) 425 | 426 | # Where to store result 427 | result = {} 428 | 429 | # Add index of samples 430 | result["Index"] = torch.arange(train_data["assigned_targets"].size(-1)) 431 | 432 | # Add label flipped info 433 | result["True Target"] = true_targets 434 | result["Observed Target"] = assigned_targets 435 | result["Label Flipped"] = label_flipped 436 | result["Is Threshold Sample"] = is_threshold_sample 437 | 438 | # Add AUM 439 | aum_data = aum_data.set_index('sample_id') 440 | aum_data = aum_data.reindex(list(range(train_data["assigned_targets"].size(-1)))) 441 | aum_list = aum_data['aum'].to_list() 442 | result["AUM"] = torch.tensor(aum_list) 443 | 444 | # Add AUM "worse than random" (AUM_WTR) score 445 | # i.e. - is the AUM worse than 99% of threshold samples? 446 | if is_threshold_sample.sum().item(): 447 | aum_wtr = torch.lt( 448 | result["AUM"].view(-1, 1), 449 | result["AUM"][is_threshold_sample].view(1, -1), 450 | ).float().mean(dim=-1).gt(0.01).float() 451 | result["AUM_WTR"] = aum_wtr 452 | else: 453 | result["AUM_WTR"] = torch.ones_like(result["AUM"]) 454 | 455 | df = pd.DataFrame(result) 456 | df.set_index( 457 | ["Index", "True Target", "Observed Target", "Label Flipped", "Is Threshold Sample"], 458 | inplace=True) 459 | df.to_csv(os.path.join(load, "aum_details.csv")) 460 | return self 461 | 462 | def done(self): 463 | "Break out of the runner" 464 | return None 465 | 466 | def load(self, save=None, suffix=""): 467 | """ 468 | Load a previously saved model state dict. 469 | 470 | :param str save: (optional) Which folder to load the saved model from. 471 | Will default to the current runner's save dir. 472 | :param str suffix: (optional) Which model file to load (e.g. "model.pth.last"). 473 | By default will load "model.pth" which contains the early-stopped model. 
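For reference, a toy sketch of how `generate_aum_details` above aligns the `aum_values.csv` written during training with the dataset order; the sample ids and AUM values below are made up.

```python
import pandas as pd
import torch

# Hypothetical contents of aum_values.csv (one row per sample seen in training)
aum_data = pd.DataFrame({"sample_id": [2, 0, 1], "aum": [0.7, -1.3, 2.1]})

# Re-align AUMs with dataset index order; samples never seen become NaN
aum_data = aum_data.set_index("sample_id").reindex(range(4))
aum = torch.tensor(aum_data["aum"].to_list())  # tensor([-1.3, 2.1, 0.7, nan])
```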
474 | """ 475 | save = save or self.savedir 476 | state_dict = torch.load(os.path.join(save, f"model.pth{suffix}"), 477 | map_location=torch.device('cpu')) 478 | self.model.load_state_dict(state_dict, strict=False) 479 | return self 480 | 481 | def save(self, save=None, suffix=""): 482 | """ 483 | Save the current state dict 484 | 485 | :param str save: (optional) Which folder to save the model to. 486 | Will default to the current runner's save dir. 487 | :param str suffix: (optional) A suffix to append to the save name. 488 | """ 489 | save = save or self.savedir 490 | torch.save(self.model.state_dict(), os.path.join(save, f"model.pth{suffix}")) 491 | return self 492 | 493 | def subset(self, perc, aum_files=None): 494 | """ 495 | Use only a subset of the training set 496 | If aum files are supplied, then drop samples with the lowest aum. 497 | Otherwise, drop samples at random. 498 | 499 | :param float perc: What percentage of the set to use 500 | :param str aum_files: 501 | """ 502 | if aum_files is None: 503 | torch.manual_seed(self.seed) 504 | torch.cuda.manual_seed_all(self.seed) 505 | random.seed(self.seed) 506 | order = torch.randperm(len(self.train_set)) 507 | else: 508 | counts = torch.zeros(len(self.train_set)) 509 | aums = torch.zeros(len(self.train_set)) 510 | if isinstance(aum_files, str): 511 | aum_files = aum_files.split(",") 512 | for sub_aum_file in aum_files: 513 | aums_path = os.path.join(sub_aum_file, "aum_details.csv") 514 | if not os.path.exists(aums_path): 515 | self.compute_aums(load=sub_aum_file) 516 | aums_data = pd.read_csv(aums_path).drop( 517 | ["True Target", "Observed Target", "Label Flipped"], axis=1) 518 | counts += torch.tensor(~aums_data["Is Threshold Sample"].values).float() 519 | aums += torch.tensor(aums_data["AUM"].values * 520 | ~aums_data["Is Threshold Sample"].values).float() 521 | counts.clamp_min_(1) 522 | aums = aums.div_(counts) 523 | order = aums.argsort(descending=True) 524 | 525 | num_samples = int(len(self.train_set) * perc) 526 | self.train_set.indices = self.train_set.indices[order[:num_samples]] 527 | logging.info(f"Reducing training set from {len(order)} to {len(self.train_set)}") 528 | if aum_files is not None: 529 | logging.info( 530 | f"Average AUM: {aums[order[:num_samples]].mean().item()} (from {aums.mean().item()}" 531 | ) 532 | return self 533 | 534 | def test(self, 535 | model=None, 536 | split="test", 537 | batch_size=512, 538 | dataset=None, 539 | epoch=None, 540 | num_workers=0): 541 | """ 542 | Testing script 543 | """ 544 | stats = ['error', 'top5_error', 'loss'] 545 | meters = [util.AverageMeter() for _ in stats] 546 | result_class = util.result_class(stats) 547 | 548 | # Get model 549 | if model is None: 550 | model = self.model 551 | # Model on cuda 552 | if torch.cuda.is_available(): 553 | model = model.cuda() 554 | if torch.cuda.is_available() and torch.cuda.device_count() > 1: 555 | model = torch.nn.DataParallel(model).cuda() 556 | 557 | # Get dataset/loader 558 | if dataset is None: 559 | try: 560 | dataset = getattr(self, f"{split}_set") 561 | except Exception: 562 | raise ValueError(f"Invalid split '{split}'") 563 | loader = tqdm.tqdm(torch.utils.data.DataLoader(dataset, 564 | batch_size=batch_size, 565 | shuffle=False, 566 | num_workers=num_workers), 567 | desc=split.title()) 568 | 569 | # For storing results 570 | all_losses = [] 571 | all_confs = [] 572 | all_preds = [] 573 | all_targets = [] 574 | 575 | # Model on train mode 576 | model.eval() 577 | with torch.no_grad(): 578 | for inputs, targets, indices in 
loader: 579 | # Get types right 580 | if torch.cuda.is_available(): 581 | inputs = inputs.cuda() 582 | targets = targets.cuda() 583 | 584 | # Calculate loss 585 | outputs = model(inputs) 586 | losses = self.loss_func(outputs, targets, reduction="none") 587 | confs, preds = outputs.topk(5, dim=-1, largest=True, sorted=True) 588 | is_correct = preds.eq(targets.unsqueeze(-1)).float() 589 | loss = losses.mean() 590 | error = 1 - is_correct[:, 0].mean() 591 | top5_error = 1 - is_correct.sum(dim=-1).mean() 592 | 593 | # measure and record stats 594 | batch_size = inputs.size(0) 595 | stat_vals = [error.item(), top5_error.item(), loss.item()] 596 | for stat_val, meter in zip(stat_vals, meters): 597 | meter.update(stat_val, batch_size) 598 | 599 | # Record losses 600 | all_losses.append(losses.cpu()) 601 | all_confs.append(confs[:, 0].cpu()) 602 | all_preds.append(preds[:, 0].cpu()) 603 | all_targets.append(targets.cpu()) 604 | 605 | # log stats 606 | res = dict((name, f"{meter.val:.3f} ({meter.avg:.3f})") 607 | for name, meter in zip(stats, meters)) 608 | loader.set_postfix(**res) 609 | 610 | # Save the outputs 611 | pd.DataFrame({ 612 | "Loss": torch.cat(all_losses).numpy(), 613 | "Prediction": torch.cat(all_preds).numpy(), 614 | "Confidence": torch.cat(all_confs).numpy(), 615 | "Label": torch.cat(all_targets).numpy(), 616 | }).to_csv(os.path.join(self.savedir, f"results_{split}.csv"), index_label="index") 617 | 618 | # Return summary statistics and outputs 619 | return result_class(*[meter.avg for meter in meters]) 620 | 621 | def train_for_aum_computation(self, 622 | num_epochs=150, 623 | batch_size=64, 624 | lr=0.1, 625 | wd=1e-4, 626 | momentum=0.9, 627 | **kwargs): 628 | """ 629 | Helper training script - this trains models that will be specifically used for AUL computations 630 | 631 | :param int num_epochs: (default 150) (This corresponds roughly to how 632 | many epochs a normal model is trained for before the lr drop.) 633 | :param int batch_size: (default 64) (The batch size is intentionally 634 | lower - this makes the network less likely to memorize.) 635 | :param float lr: Learning rate 636 | :param float wd: Weight decay 637 | :param float momentum: Momentum 638 | """ 639 | return self.train(num_epochs=num_epochs, 640 | batch_size=batch_size, 641 | test_at_end=False, 642 | lr=lr, 643 | wd=wd, 644 | momentum=momentum, 645 | lr_drops=[], 646 | **kwargs) 647 | 648 | def train(self, 649 | num_epochs=300, 650 | batch_size=256, 651 | test_at_end=True, 652 | lr=0.1, 653 | wd=1e-4, 654 | momentum=0.9, 655 | lr_drops=[0.5, 0.75], 656 | aum_wtr=False, 657 | rand_weight=False, 658 | **kwargs): 659 | """ 660 | Training script 661 | 662 | :param int num_epochs: (default 300) 663 | :param int batch_size: (default 256) 664 | :param float lr: Learning rate 665 | :param float wd: Weight decay 666 | :param float momentum: Momentum 667 | :param list lr_drops: When to drop the learning rate (by a factor of 10) as a percentage of total training time. 668 | 669 | :param str aum_wtr: (optional) The path of the model/results directory to load AUM_WTR weights from. 670 | :param bool rand_weight (optional, default false): uses rectified normal random weighting if True. 
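A minimal sketch (made-up batch of four losses) of how the weighting options documented above enter the loss in `train_epoch` below: weights are normalized to sum to one and combined with the per-sample losses by a dot product, so the default uniform weighting is just the ordinary mean, while `rand_weight` draws rectified-normal weights.

```python
import torch

losses = torch.tensor([0.9, 0.3, 1.2, 0.1])  # per-sample losses (hypothetical)

# rand_weight=True: rectified-normal random weights, renormalized to sum to 1
w = torch.randn(4).clamp_min_(0)
w = w / w.sum().clamp_min(1e-10)
loss = torch.dot(w, losses)  # weighted training loss

# default weighting: uniform 1/N weights, so the dot product is the plain mean
w_default = torch.full((4,), 1 / 4)
assert torch.isclose(torch.dot(w_default, losses), losses.mean())
```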
671 | """ 672 | # Model 673 | model = self.model 674 | if torch.cuda.is_available(): 675 | model = model.cuda() 676 | if torch.cuda.device_count() > 1: 677 | model = torch.nn.DataParallel(model).cuda() 678 | 679 | # Optimizer 680 | optimizer = torch.optim.SGD(model.parameters(), 681 | lr=lr, 682 | weight_decay=wd, 683 | momentum=momentum, 684 | nesterov=True) 685 | milestones = [int(lr_drop * num_epochs) for lr_drop in (lr_drops or [])] 686 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 687 | milestones=milestones, 688 | gamma=0.1) 689 | logging.info(f"\nOPTIMIZER:\n{optimizer}") 690 | logging.info(f"SCHEDULER:\n{scheduler.milestones}") 691 | 692 | # Initialize AUM caluclator object 693 | aum_calculator = AUMCalculator(save_dir=self.savedir, compressed=False) 694 | 695 | train_data = OrderedDict() 696 | train_data["train_indices"] = self.train_set.indices 697 | train_data["valid_indices"] = (self.valid_set.indices if self.valid_set is not None else 698 | torch.tensor([], dtype=torch.long)) 699 | train_data["true_targets"] = self.train_set.targets 700 | train_data["assigned_targets"] = self.train_set.assigned_targets 701 | 702 | # Storage to log results 703 | results = [] 704 | 705 | # Train model 706 | best_error = 1 707 | for epoch in range(num_epochs): 708 | train_results = self.train_epoch(model=model, 709 | optimizer=optimizer, 710 | epoch=epoch, 711 | num_epochs=num_epochs, 712 | batch_size=batch_size, 713 | aum_calculator=aum_calculator, 714 | aum_wtr=aum_wtr, 715 | rand_weight=rand_weight, 716 | **kwargs) 717 | if self.valid_set is not None: 718 | valid_results = self.test(model=model, 719 | split="valid", 720 | batch_size=batch_size, 721 | epoch=epoch, 722 | **kwargs) 723 | else: 724 | valid_results = self.test(model, 725 | split="test", 726 | batch_size=batch_size, 727 | epoch=epoch, 728 | **kwargs) 729 | scheduler.step() 730 | 731 | # Determine if model is the best 732 | if self.valid_set is not None: 733 | self.save() 734 | elif best_error > valid_results.error: 735 | best_error = valid_results.error 736 | logging.info('New best error: %.4f' % valid_results.error) 737 | self.save() 738 | 739 | # Log results 740 | logging.info(f"\nTraining {repr(train_results)}") 741 | logging.info(f"\nValidation {repr(valid_results)}") 742 | logging.info('') 743 | results.append( 744 | OrderedDict([("epoch", f"{epoch + 1:03d}"), 745 | *[(f"train_{field}", val) for field, val in train_results.items()], 746 | *[(f"valid_{field}", val) for field, val in valid_results.items()]])) 747 | pd.DataFrame(results).set_index("epoch").to_csv( 748 | os.path.join(self.savedir, "train_log.csv")) 749 | 750 | # Save metadata around train set (like which labels were flipped) 751 | torch.save(train_data, os.path.join(self.savedir, "train_data.pth")) 752 | 753 | # Once we're finished training calculate aum 754 | aum_calculator.finalize() 755 | 756 | # Maybe test (last epoch) 757 | if test_at_end and self.valid_set is not None: 758 | test_results = self.test(model=model, **kwargs) 759 | logging.info(f"\nTest (no early stopping) {repr(test_results)}") 760 | shutil.copyfile(os.path.join(self.savedir, "results_test.csv"), 761 | os.path.join(self.savedir, "results_test_noearlystop.csv")) 762 | results.append( 763 | OrderedDict([(f"test_{field}", val) for field, val in test_results.items()])) 764 | pd.DataFrame(results).set_index("epoch").to_csv( 765 | os.path.join(self.savedir, "train_log.csv")) 766 | 767 | # Load best model 768 | self.save(suffix=".last") 769 | self.load() 770 | 771 | # Maybe test (best 
epoch) 772 | if test_at_end and self.valid_set is not None: 773 | test_results = self.test(model=model, **kwargs) 774 | logging.info(f"\nEarly Stopped Model Test {repr(test_results)}") 775 | results.append( 776 | OrderedDict([(f"test_best_{field}", val) for field, val in test_results.items()])) 777 | pd.DataFrame(results).set_index("epoch").to_csv(os.path.join(self.savedir, "train_log.csv")) 778 | 779 | return self 780 | 781 | def train_epoch(self, 782 | model, 783 | optimizer, 784 | epoch, 785 | num_epochs, 786 | batch_size=256, 787 | num_workers=0, 788 | aum_calculator=None, 789 | aum_wtr=False, 790 | rand_weight=False): 791 | stats = ["error", "loss"] 792 | meters = [util.AverageMeter() for _ in stats] 793 | result_class = util.result_class(stats) 794 | 795 | # Weighting - set up from GMM 796 | # NOTE: This is only used when removing threshold samples 797 | # TODO: some of this probably needs to be changed? 798 | if aum_wtr: 799 | counts = torch.zeros(len(self.train_set)) 800 | bad_probs = torch.zeros(len(self.train_set)) 801 | if isinstance(aum_wtr, str): 802 | aum_wtr = aum_wtr.split(",") 803 | for sub_aum_wtr in aum_wtr: 804 | aums_path = os.path.join(sub_aum_wtr, "aum_details.csv") 805 | if not os.path.exists(aums_path): 806 | self.generate_aum_details(load=sub_aum_wtr) 807 | aums_data = pd.read_csv(aums_path).drop( 808 | ["True Target", "Observed Target", "Label Flipped"], axis=1) 809 | counts += torch.tensor(~aums_data["Is Threshold Sample"].values).float() 810 | bad_probs += torch.tensor(aums_data["AUM_WTR"].values * 811 | ~aums_data["Is Threshold Sample"].values).float() 812 | counts.clamp_min_(1) 813 | good_probs = (1 - bad_probs / counts).to(next(model.parameters()).dtype).ceil() 814 | if torch.cuda.is_available(): 815 | good_probs = good_probs.cuda() 816 | logging.info(f"AUM WTR Score") 817 | logging.info(f"(Num samples removed: {good_probs.ne(1.).sum().item()})") 818 | elif rand_weight: 819 | logging.info("Rectified Normal Random Weighting") 820 | else: 821 | logging.info("Standard weighting") 822 | 823 | # Setup loader 824 | train_set = self.train_set 825 | loader = tqdm.tqdm(torch.utils.data.DataLoader(train_set, 826 | batch_size=batch_size, 827 | shuffle=True, 828 | num_workers=num_workers), 829 | desc=f"Train (Epoch {epoch + 1}/{num_epochs})") 830 | 831 | # Model on train mode 832 | model.train() 833 | for inputs, targets, indices in loader: 834 | optimizer.zero_grad() 835 | 836 | # Get types right 837 | if torch.cuda.is_available(): 838 | inputs = inputs.cuda() 839 | targets = targets.cuda() 840 | 841 | # Compute output and losses 842 | outputs = model(inputs) 843 | losses = self.loss_func(outputs, targets, reduction="none") 844 | preds = outputs.argmax(dim=-1) 845 | 846 | # Compute loss weights 847 | if aum_wtr: 848 | weights = good_probs[indices.to(good_probs.device)] 849 | weights = weights.div(weights.sum()) 850 | elif rand_weight: 851 | weights = torch.randn(targets.size(), dtype=outputs.dtype, 852 | device=outputs.device).clamp_min_(0) 853 | weights = weights.div(weights.sum().clamp_min_(1e-10)) 854 | else: 855 | weights = torch.ones(targets.size(), dtype=outputs.dtype, 856 | device=outputs.device).div_(targets.numel()) 857 | 858 | # Backward through model 859 | loss = torch.dot(weights, losses) 860 | error = torch.ne(targets, preds).float().mean() 861 | loss.backward() 862 | 863 | # Update the model 864 | optimizer.step() 865 | 866 | # Update AUM values (after the first epoch due to variability of random initialization) 867 | if aum_calculator and epoch > 0: 868 
| aum_calculator.update(logits=outputs.detach().cpu().half().float(), 869 | targets=targets.detach().cpu(), 870 | sample_ids=indices.tolist()) 871 | 872 | # measure and record stats 873 | batch_size = outputs.size(0) 874 | stat_vals = [error.item(), loss.item()] 875 | for stat_val, meter in zip(stat_vals, meters): 876 | meter.update(stat_val, batch_size) 877 | 878 | # log stats 879 | res = dict( 880 | (name, f"{meter.val:.3f} ({meter.avg:.3f})") for name, meter in zip(stats, meters)) 881 | loader.set_postfix(**res) 882 | 883 | # Return summary statistics 884 | return result_class(*[meter.avg for meter in meters]) 885 | 886 | 887 | if __name__ == "__main__": 888 | fire.Fire(Runner) 889 | -------------------------------------------------------------------------------- /examples/paper_replication/runner_testing.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import shutil 5 | import sys 6 | from collections import OrderedDict 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | 12 | import fire 13 | import tqdm 14 | import util 15 | from aum import AUMCalculator 16 | from losses import losses 17 | from models import models 18 | from torchvision import datasets 19 | from torchvision import models as tvmodels 20 | from torchvision import transforms 21 | 22 | 23 | class _Dataset(torch.utils.data.Dataset): 24 | """ 25 | A wrapper around existing torch datasets to add purposefully mislabeled samplesa and threshold samples. 26 | 27 | :param :obj:`torch.utils.data.Dataset` base_dataset: Dataset to wrap 28 | :param :obj:`torch.LongTensor` indices: List of indices of base_dataset to include (used to create valid. sets) 29 | :param dict flip_dict: (optional) List mapping sample indices to their (incorrect) assigned label 30 | :param bool use_threshold_samples: (default False) Whether or not to add threshold samples to this datasets 31 | :param bool threshold_samples_set_idx: (default 1) Which set of threshold samples to use. 
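Since every public `Runner` method returns `self`, the replication workflow can be driven either through the `fire.Fire(Runner)` entry point of runner.py above or directly from Python. A rough sketch with placeholder paths and hyperparameters (not the paper's exact settings); CIFAR-10 is downloaded into the data directory on first use:

```python
from runner import Runner  # assumes examples/paper_replication is the working directory

runner = Runner(data="/tmp/data", save="/tmp/aum_run",  # placeholder paths
                dataset="cifar10", perc_mislabeled=0.2,
                use_threshold_samples=True)
runner.train_for_aum_computation(num_epochs=150)  # train without the lr drops
runner.generate_aum_details()                     # writes aum_details.csv to the save dir
```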
32 | """ 33 | def __init__(self, 34 | base_dataset, 35 | indices=None, 36 | flip_dict=None, 37 | use_threshold_samples=False, 38 | threshold_samples_set_idx=1): 39 | super().__init__() 40 | self.dataset = base_dataset 41 | self.flip_dict = flip_dict or {} 42 | self.indices = torch.arange(len(self.dataset)) if indices is None else indices 43 | 44 | # Create optional extra class (for threshold samples) 45 | self.use_threshold_samples = use_threshold_samples 46 | if use_threshold_samples: 47 | num_threshold_samples = len(self.indices) // (self.targets.max().item() + 1) 48 | start_index = (threshold_samples_set_idx - 1) * num_threshold_samples 49 | end_index = (threshold_samples_set_idx) * num_threshold_samples 50 | self.threshold_sample_indices = torch.randperm(len(self.indices))[start_index:end_index] 51 | 52 | @property 53 | def targets(self): 54 | """ 55 | (Hidden) ground-truth labels 56 | """ 57 | if not hasattr(self, "_target_memo"): 58 | try: 59 | self.__target_memo = torch.tensor(self.dataset.targets)[self.indices] 60 | except Exception: 61 | self.__target_memo = torch.tensor([target 62 | for _, target in self.dataset])[self.indices] 63 | if torch.is_tensor(self.__target_memo): 64 | return self.__target_memo 65 | else: 66 | return torch.tensor(self.__target_memo) 67 | 68 | @property 69 | def assigned_targets(self): 70 | """ 71 | (Potentially incorrect) assigned labels 72 | """ 73 | if not hasattr(self, "_assigned_target_memo"): 74 | self._assigned_target_memo = self.targets.clone() 75 | 76 | # Change labels of mislabeled samples 77 | if self.flip_dict is not None: 78 | for i, idx in enumerate(self.indices.tolist()): 79 | if idx in self.flip_dict.keys(): 80 | self._assigned_target_memo[i] = self.flip_dict[idx] 81 | 82 | # Change labels of threshold samples 83 | if self.use_threshold_samples: 84 | extra_class = (self.targets.max().item() + 1) 85 | self._assigned_target_memo[self.threshold_sample_indices] = extra_class 86 | return self._assigned_target_memo 87 | 88 | def __len__(self): 89 | return len(self.indices) 90 | 91 | def __getitem__(self, index): 92 | input, _ = self.dataset[self.indices[index].item()] 93 | target = self.assigned_targets[index].item() 94 | res = input, target, index 95 | return res 96 | 97 | 98 | class Runner(object): 99 | """ 100 | Main module for running experiments. Can call `load`, `save`, `train`, `test`, etc. 101 | 102 | :param str data: Directory to load data from 103 | :param str save: Directory to save model/results 104 | :param str dataset: (cifar10, cifar100, tiny_imagenet, webvision50, clothing100k) 105 | 106 | :param int num_valid: (default 5000) What size validation set to use (comes from train set, indices determined by seed) 107 | :param int seed: (default 0) Random seed 108 | :param int split_seed: (default 0) Which random seed to use for creating trian/val split and for flipping random labels. 109 | If this arg is not supplied, the split_seed will come from the `seed` arg. 110 | 111 | :param float perc_mislabeled: (default 0.) How many samples will be intentionally mislabeled. 112 | Default is 0. - i.e. regular training without flipping any labels. 113 | :param str noise_type: (uniform, flip) Mislabeling noise model to use. 
114 | 115 | :param bool use_threshold_samples: (default False) Whether to add indictaor samples 116 | :param bool threshold_samples_set_idx: (default 1) Which set of threshold samples to use (based on index) 117 | 118 | :param str loss_type: (default cross-entropy) Loss type 119 | :param bool oracle_training: (default False) If true, the network will be trained only on clean data 120 | (i.e. all training points with flipped labels will be discarded). 121 | 122 | :param str net_type: (resnet, densenet, wide_resnet) Which network to use. 123 | :param **model_args: Additional argumets to pass to the model 124 | """ 125 | def __init__(self, 126 | data, 127 | save, 128 | dataset="cifar10", 129 | num_valid=5000, 130 | seed=0, 131 | split_seed=None, 132 | noise_type="uniform", 133 | perc_mislabeled=0., 134 | use_threshold_samples=False, 135 | threshold_samples_set_idx=1, 136 | loss_type="cross-entropy", 137 | oracle_training=False, 138 | net_type="resnet", 139 | pretrained=False, 140 | **model_args): 141 | if not os.path.exists(save): 142 | os.makedirs(save) 143 | if not os.path.isdir(save): 144 | raise Exception('%s is not a dir' % save) 145 | self.data = data 146 | self.savedir = save 147 | self.perc_mislabeled = perc_mislabeled 148 | self.noise_type = noise_type 149 | self.dataset = dataset 150 | self.net_type = net_type 151 | self.num_valid = num_valid 152 | self.use_threshold_samples = use_threshold_samples 153 | self.threshold_samples_set_idx = threshold_samples_set_idx 154 | self.split_seed = split_seed if split_seed is not None else seed 155 | self.seed = seed 156 | self.loss_func = losses[loss_type] 157 | self.oracle_training = oracle_training 158 | self.pretrained = pretrained 159 | 160 | # Seed 161 | torch.manual_seed(0) 162 | 163 | # Logging 164 | self.timestring = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 165 | logging.basicConfig( 166 | format='%(message)s', 167 | handlers=[ 168 | logging.StreamHandler(sys.stdout), 169 | logging.FileHandler(os.path.join(self.savedir, 'log-%s.log' % self.timestring)), 170 | ], 171 | level=logging.INFO, 172 | ) 173 | logging.info('Data dir:\t%s' % data) 174 | logging.info('Save dir:\t%s\n' % save) 175 | 176 | # Make model 177 | self.num_classes = self.test_set.targets.max().item() + 1 178 | if use_threshold_samples: 179 | self.num_classes += 1 180 | self.num_data = len(self.train_set) 181 | logging.info(f"\nDataset: {self.dataset}") 182 | logging.info(f"Num train: {self.num_data}") 183 | logging.info(f"Num valid: {self.num_valid}") 184 | logging.info(f"Extra class: {self.use_threshold_samples}") 185 | logging.info(f"Num classes: {self.num_classes}") 186 | if self.perc_mislabeled: 187 | logging.info(f"Noise type: {self.noise_type}") 188 | logging.info(f"Flip perc: {self.perc_mislabeled}\n") 189 | if self.oracle_training: 190 | logging.info(f"Training with Oracle Only") 191 | 192 | # Model 193 | if self.dataset == "imagenet" or "webvision" in self.dataset or "clothing" in self.dataset: 194 | big_models = dict((key, val) for key, val in tvmodels.__dict__.items()) 195 | self.model = big_models[self.net_type](pretrained=False, num_classes=self.num_classes) 196 | if self.pretrained: 197 | try: 198 | self.model.load_state_dict( 199 | big_models[self.net_type](pretrained=True).state_dict(), strict=False) 200 | except RuntimeError: 201 | pass 202 | # Fix pooling issues 203 | if "inception" in self.net_type: 204 | self.avgpool_1a = torch.nn.AdaptiveAvgPool2d((1, 1)) 205 | else: 206 | self.model = models[self.net_type]( 207 | 
num_classes=self.num_classes, 208 | initial_stride=(2 if "tiny" in self.dataset.lower() else 1), 209 | **model_args) 210 | logging.info(f"Model type: {self.net_type}") 211 | logging.info(f"Model args:") 212 | for key, val in model_args.items(): 213 | logging.info(f" - {key}: {val}") 214 | logging.info(f"Loss type: {loss_type}") 215 | logging.info("") 216 | 217 | def _make_datasets(self): 218 | try: 219 | dataset_cls = getattr(datasets, self.dataset.upper()) 220 | self.big_model = False 221 | except Exception: 222 | dataset_cls = datasets.ImageFolder 223 | if "tiny" in self.dataset.lower(): 224 | self.big_model = False 225 | else: 226 | self.big_model = True 227 | 228 | # Get constants 229 | if dataset_cls == datasets.ImageFolder: 230 | tmp_set = dataset_cls(root=os.path.join(self.data, "train")) 231 | else: 232 | tmp_set = dataset_cls(root=self.data, train=True, download=True) 233 | if self.dataset.upper() == 'CIFAR10': 234 | tmp_set.targets = tmp_set.train_labels 235 | num_train = len(tmp_set) - self.num_valid 236 | num_valid = self.num_valid 237 | num_classes = int(max(tmp_set.targets)) + 1 238 | 239 | # Create train/valid split 240 | torch.manual_seed(self.split_seed) 241 | train_indices, valid_indices = torch.randperm(num_train + num_valid).split( 242 | [num_train, num_valid]) 243 | 244 | # dataset indices flip 245 | flip_dict = {} 246 | if self.perc_mislabeled: 247 | # Generate noisy labels from random transitions 248 | transition_matrix = torch.eye(num_classes) 249 | if self.noise_type == "uniform": 250 | transition_matrix.mul_(1 - self.perc_mislabeled * (num_classes / (num_classes - 1))) 251 | transition_matrix.add_(self.perc_mislabeled / (num_classes - 1)) 252 | elif self.noise_type == "flip": 253 | source_classes = torch.arange(num_classes) 254 | target_classes = (source_classes + 1).fmod(num_classes) 255 | transition_matrix.mul_(1 - self.perc_mislabeled) 256 | transition_matrix[source_classes, target_classes] = self.perc_mislabeled 257 | else: 258 | raise ValueError(f"Unknonwn noise type {self.noise}") 259 | true_targets = (torch.tensor(tmp_set.targets) if hasattr(tmp_set, "targets") else 260 | torch.tensor([target for _, target in self])) 261 | transition_targets = torch.distributions.Categorical( 262 | probs=transition_matrix[true_targets, :]).sample() 263 | # Create a dictionary of transitions 264 | if not self.oracle_training: 265 | flip_indices = torch.nonzero(transition_targets != true_targets).squeeze(-1) 266 | flip_targets = transition_targets[flip_indices] 267 | for index, target in zip(flip_indices, flip_targets): 268 | flip_dict[index.item()] = target.item() 269 | else: 270 | # In the oracle setting, don't add transitions 271 | oracle_indices = torch.nonzero(transition_targets == true_targets).squeeze(-1) 272 | train_indices = torch.from_numpy( 273 | np.intersect1d(oracle_indices.numpy(), train_indices.numpy())).long() 274 | 275 | # Reset the seed for dataset/initializations 276 | torch.manual_seed(self.seed) 277 | 278 | # Define trainsforms 279 | if self.big_model: 280 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 281 | test_transforms = transforms.Compose([ 282 | transforms.Resize(256), 283 | transforms.CenterCrop(227 if "inception" in self.net_type else 224), 284 | transforms.ToTensor(), 285 | normalize, 286 | ]) 287 | train_transforms = transforms.Compose([ 288 | transforms.RandomResizedCrop(227 if "inception" in self.net_type else 224), 289 | transforms.RandomHorizontalFlip(), 290 | transforms.ToTensor(), 291 | 
normalize, 292 | ]) 293 | elif self.dataset == "tiny_imagenet": 294 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 295 | test_transforms = transforms.Compose([ 296 | transforms.ToTensor(), 297 | normalize, 298 | ]) 299 | train_transforms = transforms.Compose([ 300 | transforms.RandomCrop(64, padding=8), 301 | transforms.RandomHorizontalFlip(), 302 | test_transforms, 303 | ]) 304 | elif self.dataset == "cifar10": 305 | normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467], 306 | std=[0.2471, 0.2435, 0.2616]) 307 | test_transforms = transforms.Compose([ 308 | transforms.ToTensor(), 309 | normalize, 310 | ]) 311 | train_transforms = transforms.Compose([ 312 | transforms.RandomCrop(32, padding=4), 313 | transforms.RandomHorizontalFlip(), 314 | test_transforms, 315 | ]) 316 | elif self.dataset == "cifar100": 317 | normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467], 318 | std=[0.2471, 0.2435, 0.2616]) 319 | test_transforms = transforms.Compose([ 320 | transforms.ToTensor(), 321 | normalize, 322 | ]) 323 | train_transforms = transforms.Compose([ 324 | transforms.RandomCrop(32, padding=4), 325 | transforms.RandomHorizontalFlip(), 326 | test_transforms, 327 | ]) 328 | elif self.dataset == "mnist": 329 | normalize = transforms.Normalize(mean=(0.1307, ), std=(0.3081, )) 330 | test_transforms = transforms.Compose([ 331 | transforms.ToTensor(), 332 | normalize, 333 | ]) 334 | train_transforms = test_transforms 335 | else: 336 | raise ValueError(f"Unknown dataset {self.dataset}") 337 | 338 | # Get train set 339 | if dataset_cls == datasets.ImageFolder: 340 | self._train_set_memo = _Dataset( 341 | dataset_cls( 342 | root=os.path.join(self.data, "train"), 343 | transform=train_transforms, 344 | ), 345 | flip_dict=flip_dict, 346 | indices=train_indices, 347 | use_threshold_samples=self.use_threshold_samples, 348 | threshold_samples_set_idx=self.threshold_samples_set_idx, 349 | ) 350 | if os.path.exists(os.path.join(self.data, "test")): 351 | self._valid_set_memo = _Dataset( 352 | dataset_cls(root=os.path.join(self.data, "val"), transform=test_transforms)) 353 | self._test_set_memo = _Dataset( 354 | dataset_cls(root=os.path.join(self.data, "test"), transform=test_transforms)) 355 | else: 356 | self._valid_set_memo = _Dataset( 357 | dataset_cls(root=os.path.join(self.data, "train"), transform=test_transforms), 358 | indices=valid_indices, 359 | ) if len(valid_indices) else None 360 | self._test_set_memo = _Dataset( 361 | dataset_cls(root=os.path.join(self.data, "val"), transform=test_transforms)) 362 | else: 363 | self._train_set_memo = _Dataset( 364 | dataset_cls(root=self.data, train=True, transform=train_transforms), 365 | flip_dict=flip_dict, 366 | indices=train_indices, 367 | use_threshold_samples=self.use_threshold_samples, 368 | threshold_samples_set_idx=self.threshold_samples_set_idx, 369 | ) 370 | self._valid_set_memo = _Dataset(dataset_cls( 371 | root=self.data, train=True, transform=test_transforms), 372 | indices=valid_indices) if len(valid_indices) else None 373 | self._test_set_memo = _Dataset( 374 | dataset_cls(root=self.data, train=False, transform=test_transforms)) 375 | 376 | @property 377 | def test_set(self): 378 | if not hasattr(self, "_test_set_memo"): 379 | self._make_datasets() 380 | return self._test_set_memo 381 | 382 | @property 383 | def train_set(self): 384 | if not hasattr(self, "_train_set_memo"): 385 | self._make_datasets() 386 | return self._train_set_memo 387 | 388 | @property 389 | def valid_set(self): 390 | if 
not hasattr(self, "_valid_set_memo"): 391 | self._make_datasets() 392 | return self._valid_set_memo 393 | 394 | def generate_aum_details(self, load=None): 395 | """ 396 | Script for accumulating both aum values and other sample details at the end of training. 397 | It makes a dataframe that contains AUMs Clean for all samples 398 | The results are saved to the file `aum_details.csv` in the model folder. 399 | 400 | :param str load: (optional) If set to some value - it will assemble aum info from the model stored in the `load` folder. 401 | Otherwise - it will comptue aums from the runner's model. 402 | 403 | :return: self 404 | """ 405 | 406 | load = load or self.savedir 407 | train_data = torch.load(os.path.join(load, "train_data.pth")) 408 | aum_data = pd.read_csv(os.path.join(load, "aum_values.csv")) 409 | 410 | # HACK: fix for old version of the code 411 | if "assigned_targets" not in train_data: 412 | train_data["assigned_targets"] = train_data["observed_targets"] 413 | 414 | true_targets = train_data["true_targets"] 415 | assigned_targets = train_data["assigned_targets"] 416 | is_threshold_sample = assigned_targets.gt(true_targets.max()) 417 | label_flipped = torch.ne(true_targets, assigned_targets) 418 | 419 | # Where to store result 420 | result = {} 421 | 422 | # Add index of samples 423 | result["Index"] = torch.arange(train_data["assigned_targets"].size(-1)) 424 | 425 | # Add label flipped info 426 | result["True Target"] = true_targets 427 | result["Observed Target"] = assigned_targets 428 | result["Label Flipped"] = label_flipped 429 | result["Is Threshold Sample"] = is_threshold_sample 430 | 431 | # Add AUM 432 | aum_data = aum_data.set_index('sample_id') 433 | aum_data = aum_data.reindex(list(range(train_data["assigned_targets"].size(-1)))) 434 | aum_list = aum_data['aum'].to_list() 435 | result["AUM_LIB"] = torch.tensor(aum_list) 436 | 437 | ####################################### 438 | ## OLD WAY OF DOING THINGS ############ 439 | ####################################### 440 | start_epoch = 1 441 | 442 | # Get a mask for epochs that we recorded data for 443 | # This is for if we stopped training early 444 | mask = train_data["correct_confs"].float().sum(-1).gt(0) 445 | mask[:start_epoch] = False 446 | 447 | # Get the losses 448 | losses = train_data["correct_confs"][mask].float().clamp_min(1e-10).log().mul(-1) 449 | 450 | # Add Loss 451 | result["Loss"] = losses[-1] 452 | 453 | # Add confidence 454 | result["Correct Logit"] = train_data["correct_logits"][mask][-1] 455 | result["Incorrect Logit"] = train_data["incorrect_logits"][mask][-1] 456 | result["Margin"] = (result["Correct Logit"].float() - 457 | result["Incorrect Logit"].float()).half() 458 | result["AUM"] = torch.sub(train_data["correct_logits"][mask].float(), 459 | train_data["incorrect_logits"][mask].float()).mean(0) 460 | 461 | ####################################### 462 | ## OLD WAY OF DOING THINGS ############ 463 | ####################################### 464 | 465 | 466 | # Add AUM "worse than random" (AUM_WTR) score 467 | # i.e. - is the AUM worse than 99% of threshold samples? 
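A small numeric illustration of this rule with made-up AUM values: a sample is flagged (AUM_WTR = 1) when more than 1% of the threshold samples have a higher AUM, i.e. it is kept only if its AUM clears roughly the 99th percentile of the threshold-sample AUM distribution.

```python
import torch

aum = torch.tensor([5.0, -2.0, 0.5, -4.0])            # per-sample AUMs (hypothetical)
thr_aum = torch.tensor([-3.0, -1.0, 0.0, 1.0, 2.0])   # threshold-sample AUMs (hypothetical)

# Fraction of threshold samples that beat each sample's AUM
frac_above = torch.lt(aum.view(-1, 1), thr_aum.view(1, -1)).float().mean(dim=-1)
aum_wtr = frac_above.gt(0.01).float()  # 1.0 = likely mislabeled, 0.0 = clears the cutoff
print(aum_wtr)  # tensor([0., 1., 1., 1.])
```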
468 | if is_threshold_sample.sum().item(): 469 | aum_wtr = torch.lt( 470 | result["AUM"].view(-1, 1), 471 | result["AUM"][is_threshold_sample].view(1, -1), 472 | ).float().mean(dim=-1).gt(0.01).float() 473 | result["AUM_WTR"] = aum_wtr 474 | else: 475 | result["AUM_WTR"] = torch.ones_like(result["AUM"]) 476 | 477 | df = pd.DataFrame(result) 478 | df.set_index( 479 | ["Index", "True Target", "Observed Target", "Label Flipped", "Is Threshold Sample"], 480 | inplace=True) 481 | df.to_csv(os.path.join(load, "aum_details.csv")) 482 | return self 483 | 484 | def done(self): 485 | "Break out of the runner" 486 | return None 487 | 488 | def load(self, save=None, suffix=""): 489 | """ 490 | Load a previously saved model state dict. 491 | 492 | :param str save: (optional) Which folder to load the saved model from. 493 | Will default to the current runner's save dir. 494 | :param str suffix: (optional) Which model file to load (e.g. "model.pth.last"). 495 | By default will load "model.pth" which contains the early-stopped model. 496 | """ 497 | save = save or self.savedir 498 | state_dict = torch.load(os.path.join(save, f"model.pth{suffix}"), 499 | map_location=torch.device('cpu')) 500 | self.model.load_state_dict(state_dict, strict=False) 501 | return self 502 | 503 | def save(self, save=None, suffix=""): 504 | """ 505 | Save the current state dict 506 | 507 | :param str save: (optional) Which folder to save the model to. 508 | Will default to the current runner's save dir. 509 | :param str suffix: (optional) A suffix to append to the save name. 510 | """ 511 | save = save or self.savedir 512 | torch.save(self.model.state_dict(), os.path.join(save, f"model.pth{suffix}")) 513 | return self 514 | 515 | def subset(self, perc, aum_files=None): 516 | """ 517 | Use only a subset of the training set 518 | If aum files are supplied, then drop samples with the lowest aum. 519 | Otherwise, drop samples at random. 
520 | 521 | :param float perc: What percentage of the set to use 522 | :param str aum_files: 523 | """ 524 | if aum_files is None: 525 | torch.manual_seed(self.seed) 526 | order = torch.randperm(len(self.train_set)) 527 | else: 528 | counts = torch.zeros(len(self.train_set)) 529 | aums = torch.zeros(len(self.train_set)) 530 | if isinstance(aum_files, str): 531 | aum_files = aum_files.split(",") 532 | for sub_aum_file in aum_files: 533 | aums_path = os.path.join(sub_aum_file, "aum_details.csv") 534 | if not os.path.exists(aums_path): 535 | self.compute_aums(load=sub_aum_file) 536 | aums_data = pd.read_csv(aums_path).drop( 537 | ["True Target", "Observed Target", "Label Flipped"], axis=1) 538 | counts += torch.tensor(~aums_data["Is Threshold Sample"].values).float() 539 | aums += torch.tensor(aums_data["AUM"].values * 540 | ~aums_data["Is Threshold Sample"].values).float() 541 | counts.clamp_min_(1) 542 | aums = aums.div_(counts) 543 | order = aums.argsort(descending=True) 544 | 545 | num_samples = int(len(self.train_set) * perc) 546 | self.train_set.indices = self.train_set.indices[order[:num_samples]] 547 | logging.info(f"Reducing training set from {len(order)} to {len(self.train_set)}") 548 | if aum_files is not None: 549 | logging.info( 550 | f"Average AUM: {aums[order[:num_samples]].mean().item()} (from {aums.mean().item()}" 551 | ) 552 | return self 553 | 554 | def test(self, 555 | model=None, 556 | split="test", 557 | batch_size=512, 558 | dataset=None, 559 | epoch=None, 560 | num_workers=0): 561 | """ 562 | Testing script 563 | """ 564 | stats = ['error', 'top5_error', 'loss'] 565 | meters = [util.AverageMeter() for _ in stats] 566 | result_class = util.result_class(stats) 567 | 568 | # Get model 569 | if model is None: 570 | model = self.model 571 | # Model on cuda 572 | if torch.cuda.is_available(): 573 | model = model.cuda() 574 | if torch.cuda.is_available() and torch.cuda.device_count() > 1: 575 | model = torch.nn.DataParallel(model).cuda() 576 | 577 | # Get dataset/loader 578 | if dataset is None: 579 | try: 580 | dataset = getattr(self, f"{split}_set") 581 | except Exception: 582 | raise ValueError(f"Invalid split '{split}'") 583 | loader = tqdm.tqdm(torch.utils.data.DataLoader(dataset, 584 | batch_size=batch_size, 585 | shuffle=False, 586 | num_workers=num_workers), 587 | desc=split.title()) 588 | 589 | # For storing results 590 | all_losses = [] 591 | all_confs = [] 592 | all_preds = [] 593 | all_targets = [] 594 | 595 | # Model on train mode 596 | model.eval() 597 | with torch.no_grad(): 598 | for inputs, targets, indices in loader: 599 | # Get types right 600 | if torch.cuda.is_available(): 601 | inputs = inputs.cuda() 602 | targets = targets.cuda() 603 | 604 | # Calculate loss 605 | outputs = model(inputs) 606 | losses = self.loss_func(outputs, targets, reduction="none") 607 | confs, preds = outputs.topk(5, dim=-1, largest=True, sorted=True) 608 | is_correct = preds.eq(targets.unsqueeze(-1)).float() 609 | loss = losses.mean() 610 | error = 1 - is_correct[:, 0].mean() 611 | top5_error = 1 - is_correct.sum(dim=-1).mean() 612 | 613 | # measure and record stats 614 | batch_size = inputs.size(0) 615 | stat_vals = [error.item(), top5_error.item(), loss.item()] 616 | for stat_val, meter in zip(stat_vals, meters): 617 | meter.update(stat_val, batch_size) 618 | 619 | # Record losses 620 | all_losses.append(losses.cpu()) 621 | all_confs.append(confs[:, 0].cpu()) 622 | all_preds.append(preds[:, 0].cpu()) 623 | all_targets.append(targets.cpu()) 624 | 625 | # print stats 626 | res = 
dict((name, f"{meter.val:.3f} ({meter.avg:.3f})") 627 | for name, meter in zip(stats, meters)) 628 | loader.set_postfix(**res) 629 | 630 | # Save the outputs 631 | pd.DataFrame({ 632 | "Loss": torch.cat(all_losses).numpy(), 633 | "Prediction": torch.cat(all_preds).numpy(), 634 | "Confidence": torch.cat(all_confs).numpy(), 635 | "Label": torch.cat(all_targets).numpy(), 636 | }).to_csv(os.path.join(self.savedir, f"results_{split}.csv"), index_label="index") 637 | 638 | # Return summary statistics and outputs 639 | return result_class(*[meter.avg for meter in meters]) 640 | 641 | def train_for_aum_computation(self, 642 | num_epochs=150, 643 | batch_size=64, 644 | lr=0.1, 645 | wd=1e-4, 646 | momentum=0.9, 647 | **kwargs): 648 | """ 649 | Helper training script - this trains models that will be specifically used for AUL computations 650 | 651 | :param int num_epochs: (default 150) (This corresponds roughly to how 652 | many epochs a normal model is trained for before the lr drop.) 653 | :param int batch_size: (default 64) (The batch size is intentionally 654 | lower - this makes the network less likely to memorize.) 655 | :param float lr: Learning rate 656 | :param float wd: Weight decay 657 | :param float momentum: Momentum 658 | """ 659 | return self.train(num_epochs=num_epochs, 660 | batch_size=batch_size, 661 | test_at_end=False, 662 | lr=lr, 663 | wd=wd, 664 | momentum=momentum, 665 | lr_drops=[], 666 | **kwargs) 667 | 668 | def train(self, 669 | num_epochs=300, 670 | batch_size=256, 671 | test_at_end=True, 672 | lr=0.1, 673 | wd=1e-4, 674 | momentum=0.9, 675 | lr_drops=[0.5, 0.75], 676 | aum_wtr=False, 677 | rand_weight=False, 678 | **kwargs): 679 | """ 680 | Training script 681 | 682 | :param int num_epochs: (default 300) 683 | :param int batch_size: (default 256) 684 | :param float lr: Learning rate 685 | :param float wd: Weight decay 686 | :param float momentum: Momentum 687 | :param list lr_drops: When to drop the learning rate (by a factor of 10) as a percentage of total training time. 688 | 689 | :param str aum_wtr: (optional) The path of the model/results directory to load AUM_WTR weights from. 690 | :param bool rand_weight (optional, default false): uses rectified normal random weighting if True. 
691 | """ 692 | # Model 693 | model = self.model 694 | if torch.cuda.is_available(): 695 | model = model.cuda() 696 | if torch.cuda.device_count() > 1: 697 | model = torch.nn.DataParallel(model).cuda() 698 | 699 | # Optimizer 700 | optimizer = torch.optim.SGD(model.parameters(), 701 | lr=lr, 702 | weight_decay=wd, 703 | momentum=momentum, 704 | nesterov=True) 705 | milestones = [int(lr_drop * num_epochs) for lr_drop in (lr_drops or [])] 706 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 707 | milestones=milestones, 708 | gamma=0.1) 709 | logging.info(f"\nOPTIMIZER:\n{optimizer}") 710 | logging.info(f"SCHEDULER:\n{scheduler.milestones}") 711 | 712 | # Initialize AUM caluclator object 713 | aum_calculator = AUMCalculator(save_dir=self.savedir) 714 | 715 | train_data = OrderedDict() 716 | train_data["train_indices"] = self.train_set.indices 717 | train_data["valid_indices"] = (self.valid_set.indices if self.valid_set is not None else 718 | torch.tensor([], dtype=torch.long)) 719 | train_data["true_targets"] = self.train_set.targets 720 | train_data["assigned_targets"] = self.train_set.assigned_targets 721 | 722 | ####################################### 723 | ## OLD WAY OF DOING THINGS ############ 724 | ####################################### 725 | train_data["preds"] = torch.zeros(num_epochs, len(self.train_set), dtype=torch.long) 726 | train_data["correct_confs"] = torch.zeros(num_epochs, len(self.train_set), dtype=torch.half) 727 | train_data["correct_logits"] = torch.zeros(num_epochs, 728 | len(self.train_set), 729 | dtype=torch.half) 730 | train_data["incorrect_confs"] = torch.zeros(num_epochs, 731 | len(self.train_set), 732 | dtype=torch.half) 733 | train_data["incorrect_logits"] = torch.zeros(num_epochs, 734 | len(self.train_set), 735 | dtype=torch.half) 736 | train_data["top_incorrect"] = torch.zeros(num_epochs, len(self.train_set), dtype=torch.long) 737 | 738 | ####################################### 739 | ## OLD WAY OF DOING THINGS ############ 740 | ####################################### 741 | 742 | 743 | # Storage to log results 744 | results = [] 745 | 746 | # Train model 747 | best_error = 1 748 | for epoch in range(num_epochs): 749 | train_results = self.train_epoch( 750 | model=model, 751 | optimizer=optimizer, 752 | epoch=epoch, 753 | num_epochs=num_epochs, 754 | batch_size=batch_size, 755 | aum_calculator=aum_calculator, 756 | correct_confs_storage=train_data["correct_confs"], 757 | correct_logits_storage=train_data["correct_logits"], 758 | incorrect_confs_storage=train_data["incorrect_confs"], 759 | incorrect_logits_storage=train_data["incorrect_logits"], 760 | preds_storage=train_data["preds"], 761 | top_incorrect_storage=train_data["top_incorrect"], 762 | aum_wtr=aum_wtr, 763 | rand_weight=rand_weight, 764 | **kwargs) 765 | if self.valid_set is not None: 766 | valid_results = self.test(model=model, 767 | split="valid", 768 | batch_size=batch_size, 769 | epoch=epoch, 770 | **kwargs) 771 | else: 772 | valid_results = self.test(model, 773 | split="test", 774 | batch_size=batch_size, 775 | epoch=epoch, 776 | **kwargs) 777 | scheduler.step() 778 | 779 | # Determine if model is the best 780 | if self.valid_set is not None: 781 | self.save() 782 | elif best_error > valid_results.error: 783 | best_error = valid_results.error 784 | logging.info('New best error: %.4f' % valid_results.error) 785 | self.save() 786 | 787 | # Log results 788 | logging.info(f"\nTraining {repr(train_results)}") 789 | logging.info(f"\nValidation {repr(valid_results)}") 790 | 
logging.info('') 791 | results.append( 792 | OrderedDict([("epoch", f"{epoch + 1:03d}"), 793 | *[(f"train_{field}", val) for field, val in train_results.items()], 794 | *[(f"valid_{field}", val) for field, val in valid_results.items()]])) 795 | pd.DataFrame(results).set_index("epoch").to_csv( 796 | os.path.join(self.savedir, "train_log.csv")) 797 | 798 | # Save metadata around train set (like which labels were flipped) 799 | torch.save(train_data, os.path.join(self.savedir, "train_data.pth")) 800 | 801 | # Once we're finished training calculate aum 802 | aum_calculator.finalize() 803 | 804 | # Maybe test (last epoch) 805 | if test_at_end and self.valid_set is not None: 806 | test_results = self.test(model=model, **kwargs) 807 | logging.info(f"\nTest (no early stopping) {repr(test_results)}") 808 | shutil.copyfile(os.path.join(self.savedir, "results_test.csv"), 809 | os.path.join(self.savedir, "results_test_noearlystop.csv")) 810 | results.append( 811 | OrderedDict([(f"test_{field}", val) for field, val in test_results.items()])) 812 | pd.DataFrame(results).set_index("epoch").to_csv( 813 | os.path.join(self.savedir, "train_log.csv")) 814 | 815 | # Load best model 816 | self.save(suffix=".last") 817 | self.load() 818 | 819 | # Maybe test (best epoch) 820 | if test_at_end and self.valid_set is not None: 821 | test_results = self.test(model=model, **kwargs) 822 | logging.info(f"\nEarly Stopped Model Test {repr(test_results)}") 823 | results.append( 824 | OrderedDict([(f"test_best_{field}", val) for field, val in test_results.items()])) 825 | pd.DataFrame(results).set_index("epoch").to_csv(os.path.join(self.savedir, "train_log.csv")) 826 | 827 | return self 828 | 829 | def train_epoch(self, 830 | model, 831 | optimizer, 832 | epoch, 833 | num_epochs, 834 | batch_size=256, 835 | num_workers=0, 836 | aum_calculator=None, 837 | correct_confs_storage=None, 838 | correct_logits_storage=None, 839 | incorrect_confs_storage=None, 840 | incorrect_logits_storage=None, 841 | top_incorrect_storage=None, 842 | preds_storage=None, 843 | aum_wtr=False, 844 | rand_weight=False): 845 | stats = ["error", "loss"] 846 | meters = [util.AverageMeter() for _ in stats] 847 | result_class = util.result_class(stats) 848 | 849 | # Weighting - set up from GMM 850 | # NOTE: This is only used when removing threshold samples 851 | # TODO: some of this probably needs to be changed? 
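# Roughly, and assuming the "AUM_WTR" column of aum_details.csv holds values in [0, 1]:
# the block below averages each sample's AUM_WTR score over the run directories listed in
# `aum_wtr` (ignoring threshold samples), then collapses it with good_probs = ceil(1 - average)
# into a per-sample 0/1 mask. Samples with good_probs == 0 (i.e. flagged with score 1 in every
# run) receive zero loss weight in the weighting step further down.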
852 | if aum_wtr: 853 | counts = torch.zeros(len(self.train_set)) 854 | bad_probs = torch.zeros(len(self.train_set)) 855 | if isinstance(aum_wtr, str): 856 | aum_wtr = aum_wtr.split(",") 857 | for sub_aum_wtr in aum_wtr: 858 | aums_path = os.path.join(sub_aum_wtr, "aum_details.csv") 859 | if not os.path.exists(aums_path): 860 | self.generate_aum_details(load=sub_aum_wtr) 861 | aums_data = pd.read_csv(aums_path).drop( 862 | ["True Target", "Observed Target", "Label Flipped"], axis=1) 863 | counts += torch.tensor(~aums_data["Is Threshold Sample"].values).float() 864 | bad_probs += torch.tensor(aums_data["AUM_WTR"].values * 865 | ~aums_data["Is Threshold Sample"].values).float() 866 | counts.clamp_min_(1) 867 | good_probs = (1 - bad_probs / counts).to(next(model.parameters()).dtype).ceil() 868 | if torch.cuda.is_available(): 869 | good_probs = good_probs.cuda() 870 | logging.info(f"AUM WTR Score") 871 | logging.info(f"(Num samples removed: {good_probs.ne(1.).sum().item()})") 872 | elif rand_weight: 873 | logging.info("Rectified Normal Random Weighting") 874 | else: 875 | logging.info("Standard weighting") 876 | 877 | # Setup loader 878 | train_set = self.train_set 879 | loader = tqdm.tqdm(torch.utils.data.DataLoader(train_set, 880 | batch_size=batch_size, 881 | shuffle=True, 882 | num_workers=num_workers), 883 | desc=f"Train (Epoch {epoch + 1}/{num_epochs})") 884 | 885 | # Model on train mode 886 | model.train() 887 | for inputs, targets, indices in loader: 888 | optimizer.zero_grad() 889 | 890 | # Get types right 891 | if torch.cuda.is_available(): 892 | inputs = inputs.cuda() 893 | targets = targets.cuda() 894 | 895 | # Compute output and losses 896 | outputs = model(inputs) 897 | probs = torch.softmax(outputs, dim=-1) # OLD WAY OF DOING THINGS 898 | losses = self.loss_func(outputs, targets, reduction="none") 899 | preds = outputs.argmax(dim=-1) 900 | 901 | ####################################### 902 | ## OLD WAY OF DOING THINGS ############ 903 | ####################################### 904 | 905 | # Get the correct logit/conf 906 | index = torch.arange(targets.size(0)) 907 | correct_logits = outputs[index, targets] 908 | correct_confs = probs[index, targets] 909 | 910 | # Get the incorrect logit/conf 911 | one_hot = torch.nn.functional.one_hot(targets, 912 | num_classes=self.num_classes).type_as(probs) 913 | top_incorrect = (probs - one_hot).argmax(dim=-1) 914 | incorrect_logits = outputs[index, top_incorrect] 915 | incorrect_confs = probs[index, top_incorrect] 916 | 917 | ####################################### 918 | ## OLD WAY OF DOING THINGS ############ 919 | ####################################### 920 | 921 | 922 | # Compute loss weights 923 | if aum_wtr: 924 | weights = good_probs[indices.to(good_probs.device)] 925 | weights = weights.div(weights.sum()) 926 | elif rand_weight: 927 | weights = torch.randn(targets.size(), dtype=outputs.dtype, 928 | device=outputs.device).clamp_min_(0) 929 | weights = weights.div(weights.sum().clamp_min_(1e-10)) 930 | else: 931 | weights = torch.ones(targets.size(), dtype=outputs.dtype, 932 | device=outputs.device).div_(targets.numel()) 933 | 934 | # Backward through model 935 | loss = torch.dot(weights, losses) 936 | error = torch.ne(targets, preds).float().mean() 937 | loss.backward() 938 | 939 | # Update the model 940 | optimizer.step() 941 | 942 | # Update AUM values (after the first epoch due to variability of random initialization) 943 | if aum_calculator and epoch > 0: 944 | aum_calculator.update(logits=outputs.detach().cpu().half().float(), 945 
| targets=targets.detach().cpu(),
946 | sample_ids=indices.tolist())
947 |
948 | # measure and record stats
949 | batch_size = outputs.size(0)
950 | stat_vals = [error.item(), loss.item()]
951 | for stat_val, meter in zip(stat_vals, meters):
952 | meter.update(stat_val, batch_size)
953 |
954 | #######################################
955 | ## OLD WAY OF DOING THINGS ############
956 | #######################################
957 |
958 | # Record losses
959 | # TODO: Rip this out and replace with AUM
960 | if correct_confs_storage is not None:
961 | correct_confs_storage[epoch, indices] = correct_confs.detach().cpu().half()
962 | if correct_logits_storage is not None:
963 | correct_logits_storage[epoch, indices] = correct_logits.detach().cpu().half()
964 | if incorrect_confs_storage is not None:
965 | incorrect_confs_storage[epoch, indices] = incorrect_confs.detach().cpu().half()
966 | if incorrect_logits_storage is not None:
967 | incorrect_logits_storage[epoch, indices] = incorrect_logits.detach().cpu().half()
968 | if top_incorrect_storage is not None:
969 | top_incorrect_storage[epoch, indices] = top_incorrect.detach().cpu()
970 | if preds_storage is not None:
971 | preds_storage[epoch, indices] = preds.detach().cpu()
972 |
973 | #######################################
974 | ## OLD WAY OF DOING THINGS ############
975 | #######################################
976 |
977 | # print stats
978 | res = dict(
979 | (name, f"{meter.val:.3f} ({meter.avg:.3f})") for name, meter in zip(stats, meters))
980 | loader.set_postfix(**res)
981 |
982 | # Return summary statistics
983 | return result_class(*[meter.avg for meter in meters])
984 |
985 |
986 | if __name__ == "__main__":
987 | fire.Fire(Runner)
988 |
--------------------------------------------------------------------------------
/examples/paper_replication/small_dataset_aum.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | if [ "$#" -ne 6 ]; then
4 | echo "You must enter exactly 6 command line arguments"
5 | fi
6 |
7 | datadir=$1
8 | dataset=$2
9 | seed=$3
10 | perc_mislabeled=$4
11 | noise_type=$5
12 | result_dir=$6
13 | NETTYPE="resnet"
14 | depth=32
15 |
16 | # General arguments for threshold sample trains
17 | args="--data ${datadir}/${dataset} --dataset ${dataset} --net_type ${NETTYPE} --depth ${depth}"
18 | args="${args} --perc_mislabeled ${perc_mislabeled} --noise_type ${noise_type} --seed ${seed} --use_threshold_samples"
19 | train_args="--num_epochs 150 --lr 0.1 --wd 1e-4 --batch_size 64 --num_workers 0"
20 | train_args="${train_args}"
21 |
22 | # First threshold sample run
23 | savedir1="${result_dir}/results/${dataset}_${NETTYPE}${depth}"
24 | savedir1="${savedir1}_percmislabeled${perc_mislabeled}_${noise_type}_threshold1_seed${seed}"
25 | cmd="python runner.py ${args} --save ${savedir1} --threshold_samples_set_idx 1 - train_for_aum_computation ${train_args} - done"
26 | echo $cmd
27 | if [ -z "${TESTRUN}" ]; then
28 | mkdir -p $savedir1
29 | echo $cmd > $savedir1/cmd.txt
30 | eval $cmd
31 | fi
32 |
33 | # Second threshold sample run
34 | savedir2="${result_dir}/results/${dataset}_${NETTYPE}${depth}"
35 | savedir2="${savedir2}_percmislabeled${perc_mislabeled}_${noise_type}_threshold2_seed${seed}"
36 | cmd="python runner.py ${args} --save ${savedir2} --threshold_samples_set_idx 2 - train_for_aum_computation ${train_args} - done"
37 | echo $cmd
38 | if [ -z "${TESTRUN}" ]; then
39 | mkdir -p $savedir2
40 | echo $cmd > $savedir2/cmd.txt
41 | eval $cmd
42 | fi
43 |
44 | # Compute
AUMs for first threshold sample run 45 | cmd="python runner.py ${args} --save ${savedir1} --threshold_samples_set_idx 1 - generate_aum_details - done" 46 | echo $cmd 47 | if [ -z "${TESTRUN}" ]; then 48 | mkdir -p ${savedir1} 49 | eval $cmd 50 | fi 51 | 52 | # Compute AUMs for the second threshold sample run 53 | cmd="python runner.py ${args} --save ${savedir2} --threshold_samples_set_idx 2 - generate_aum_details - done" 54 | echo $cmd 55 | if [ -z "${TESTRUN}" ]; then 56 | mkdir -p ${savedir2} 57 | eval $cmd 58 | fi 59 | 60 | # Remove the identified mislabeled saples and retrain 61 | savedir="${result_dir}/results/${dataset}_${NETTYPE}${depth}" 62 | savedir="${savedir}_percmislabeled${perc_mislabeled}_${noise_type}_aumwtr_seed${seed}" 63 | args="--data ${datadir}/${dataset} --save ${savedir} --dataset ${dataset} --net_type ${NETTYPE} --depth ${depth}" 64 | args="${args} --perc_mislabeled ${perc_mislabeled} --noise_type ${noise_type} --seed ${seed}" 65 | train_args="--num_epochs 300 --lr 0.1 --wd 1e-4 --batch_size 256" 66 | train_args="${train_args} --aum_wtr ${savedir1},${savedir2}" 67 | cmd="python runner.py ${args} - train ${train_args} - done" 68 | echo $cmd 69 | if [ -z "${TESTRUN}" ]; then 70 | mkdir -p $savedir 71 | echo $cmd > $savedir/cmd.txt 72 | eval $cmd 73 | fi 74 | -------------------------------------------------------------------------------- /examples/paper_replication/small_dataset_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$#" -ne 5 ]; then 4 | echo "You must enter exactly 5 command line arguments" 5 | fi 6 | 7 | datadir=$1 8 | dataset=$2 9 | seed=$3 10 | perc_mislabeled=$4 11 | noise_type=$5 12 | NETTYPE="resnet" 13 | depth=32 14 | 15 | savedir="results/${dataset}_${NETTYPE}${depth}" 16 | savedir="${savedir}_percmislabeled${perc_mislabeled}_${noise_type}_baseline_seed${seed}" 17 | 18 | args="--data ${datadir}/${dataset} --save ${savedir} --dataset ${dataset} --net_type ${NETTYPE} --depth ${depth}" 19 | args="${args} --perc_mislabeled ${perc_mislabeled} --noise_type ${noise_type} --seed ${seed}" 20 | 21 | train_args="--num_epochs 300 --lr 0.1 --wd 1e-4 --batch_size 256" 22 | train_args="${train_args}" 23 | 24 | cmd="python runner.py ${args} - train ${train_args} - done" 25 | echo $cmd 26 | if [ -z "${TESTRUN}" ]; then 27 | mkdir -p $savedir 28 | echo $cmd > $savedir/cmd.txt 29 | eval $cmd 30 | fi 31 | -------------------------------------------------------------------------------- /examples/paper_replication/util.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | 5 | 6 | class AverageMeter(object): 7 | """ 8 | Computes and stores the average and current value 9 | Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py 10 | """ 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | self.avg = self.sum / self.count 25 | 26 | 27 | class Welford(object): 28 | """ 29 | Computes and stores a running average and variance 30 | """ 31 | def __init__(self): 32 | self.reset() 33 | 34 | def reset(self): 35 | self._count = 0 36 | self._mean = None 37 | self._sum_sq = None 38 | 39 | # for a new value newValue, compute the new count, new mean, the new M2. 
40 | # mean accumulates the mean of the entire dataset 41 | # M2 aggregates the squared distance from the mean 42 | # count aggregates the number of samples seen so far 43 | def update(self, new_value, batch=True): 44 | if isinstance(new_value, torch.autograd.Variable): 45 | new_value = new_value.data 46 | if not batch: 47 | new_value = new_value.unsqueeze(0) 48 | 49 | self._mean = new_value.new( 50 | *list(new_value.size())[1:]).zero_() if self._mean is None else self._mean 51 | self._sum_sq = new_value.new( 52 | *list(new_value.size())[1:]).zero_() if self._sum_sq is None else self._sum_sq 53 | 54 | for item in new_value: 55 | self._count += 1 56 | delta = item - self._mean 57 | self._mean += (item - self._mean) / float(self._count) 58 | self._sum_sq += delta * (item - self._mean) 59 | 60 | @property 61 | def mean(self): 62 | return self._mean 63 | 64 | @property 65 | def var(self): 66 | return self._sum_sq / (self._count - 1) 67 | 68 | @property 69 | def std(self): 70 | return self.var.sqrt() 71 | 72 | 73 | def result_class(fields): 74 | class Result(namedtuple('Result', fields)): 75 | def items(self): 76 | for field in self._fields: 77 | yield (field, getattr(self, field)) 78 | 79 | def to_str(self): 80 | return ",".join(str(item) for item in self) 81 | 82 | def __repr__(self): 83 | res = 'Results:\n' 84 | fieldstrs = [] 85 | for key in self._fields: 86 | fieldstrs.append(' - %s: %s' % (key, repr(getattr(self, key)))) 87 | res = res + '\n'.join(fieldstrs) 88 | return res 89 | 90 | return Result 91 | 92 | 93 | def output_class(fields): 94 | class Output(namedtuple('Output', fields)): 95 | def __repr__(self): 96 | res = 'Outputs:\n' 97 | fieldstrs = [] 98 | for key in self._fields: 99 | fieldstrs.append(' - %s: %s' % (key, repr(getattr(self, key).size()))) 100 | res = res + '\n'.join(fieldstrs) 101 | return res 102 | 103 | return Output 104 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.3.1 2 | pandas>=0.25.3 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | 4 | [yapf] 5 | based_on_style = pep8 6 | COLUMN_LIMIT = 100 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | setup.py 3 | """ 4 | 5 | import os 6 | from typing import Dict 7 | 8 | from setuptools import find_packages, setup 9 | 10 | NAME = "aum" 11 | AUTHOR = "ASAPP Inc." 12 | EMAIL = "jshapiro@asapp.com" 13 | DESCRIPTION = "Library for calculating area under the margin ranking." 14 | 15 | 16 | def readme(): 17 | with open('README.md', encoding='utf-8') as f: 18 | return f.read() 19 | 20 | 21 | def required(): 22 | with open('requirements.txt') as f: 23 | return f.read().splitlines() 24 | 25 | 26 | # So that we don't import flambe. 27 | VERSION: Dict[str, str] = {} 28 | with open("aum/version.py", "r") as version_file: 29 | exec(version_file.read(), VERSION) 30 | 31 | setup( 32 | name=NAME, 33 | version=os.environ.get("TAG_VERSION", VERSION['VERSION']), 34 | description=DESCRIPTION, 35 | 36 | # Author information 37 | author=AUTHOR, 38 | author_email=EMAIL, 39 | 40 | # What is packaged here. 
41 | packages=find_packages(), 42 | install_requires=required(), 43 | include_package_data=True, 44 | python_requires='>=3.7', 45 | zip_safe=True) 46 | -------------------------------------------------------------------------------- /test/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest~=5.3.5 2 | yapf~=0.29.0 -------------------------------------------------------------------------------- /test/test_aum.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict 3 | 4 | import pandas as pd 5 | import pytest 6 | import torch 7 | 8 | from aum import AUMCalculator, AUMRecord 9 | 10 | 11 | @pytest.fixture(scope='module') 12 | def aum_data(): 13 | inputs = [] 14 | outputs = [] 15 | 16 | logits_1 = torch.tensor([[1., 2., 3.], [6., 5., 4.]]) 17 | targets_1 = torch.tensor([1, 0]) 18 | sample_ids_1 = ['a', 'b'] 19 | inputs.append({'logits': logits_1, 'targets': targets_1, 'sample_ids': sample_ids_1}) 20 | 21 | logits_2 = torch.tensor([[7., 8., 9.], [12., 11., 10.]]) 22 | targets_2 = torch.tensor([2, 1]) 23 | sample_ids_2 = ['b', 'c'] 24 | inputs.append({'logits': logits_2, 'targets': targets_2, 'sample_ids': sample_ids_2}) 25 | 26 | outputs.append({ 27 | 'a': 28 | AUMRecord(sample_id='a', 29 | num_measurements=1, 30 | target_logit=1, 31 | target_val=2., 32 | other_logit=2, 33 | other_val=3, 34 | margin=-1., 35 | aum=-1.), 36 | 'b': 37 | AUMRecord(sample_id='b', 38 | num_measurements=1, 39 | target_logit=0, 40 | target_val=6., 41 | other_logit=1, 42 | other_val=5., 43 | margin=1., 44 | aum=1.) 45 | }) 46 | 47 | outputs.append({ 48 | 'b': 49 | AUMRecord(sample_id='b', 50 | num_measurements=2, 51 | target_logit=2, 52 | target_val=9., 53 | other_logit=1, 54 | other_val=8., 55 | margin=1., 56 | aum=1.), 57 | 'c': 58 | AUMRecord(sample_id='c', 59 | num_measurements=1, 60 | target_logit=1, 61 | target_val=11., 62 | other_logit=0, 63 | other_val=12., 64 | margin=-1., 65 | aum=-1.) 
66 | })
67 |
68 | return (inputs, outputs)
69 |
70 |
71 | def test_aum_update(aum_data):
72 | inputs, outputs = aum_data
73 | aum_calculator = AUMCalculator(save_dir=None)
74 |
75 | expected_results = aum_calculator.update(inputs[0]['logits'], inputs[0]['targets'],
76 | inputs[0]['sample_ids'])
77 | assert expected_results == outputs[0]
78 |
79 | expected_results = aum_calculator.update(inputs[1]['logits'], inputs[1]['targets'],
80 | inputs[1]['sample_ids'])
81 | assert expected_results == outputs[1]
82 |
83 |
84 | def test_aum_finalize(tmp_path, aum_data):
85 | inputs, outputs = aum_data
86 | save_dir = tmp_path.as_posix()
87 | aum_calculator = AUMCalculator(save_dir=save_dir, compressed=False)
88 |
89 | for data in inputs:
90 | aum_calculator.update(data['logits'], data['targets'], data['sample_ids'])
91 |
92 | aum_calculator.finalize()
93 | final_vals = pd.read_csv(os.path.join(save_dir, 'aum_values.csv'))
94 | detailed_vals = pd.read_csv(os.path.join(save_dir, 'full_aum_records.csv'))
95 |
96 | # Let's first verify the detailed vals
97 | records = []
98 | for output in outputs:
99 | records.extend(output.values())
100 |
101 | expected_detailed_vals = pd.DataFrame([
102 | asdict(record) for record in records
103 | ]).sort_values(by=['sample_id', 'num_measurements']).reset_index(drop=True)
104 | assert detailed_vals.equals(expected_detailed_vals)
105 |
106 | # Now let's verify the final vals
107 | final_dict = {record.sample_id: record.aum for record in records}
108 | expected_final_vals = []
109 | for key, val in final_dict.items():
110 | expected_final_vals.append({'sample_id': key, 'aum': val})
111 | expected_final_vals = pd.DataFrame(expected_final_vals).sort_values(
112 | by='aum', ascending=False).reset_index(drop=True)
113 |
114 | assert final_vals.equals(expected_final_vals)
115 |
--------------------------------------------------------------------------------
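
As a sanity check on the fixture above: a margin is the logit of the assigned target minus the largest *other* logit, and a sample's AUM is the running average of those margins over update calls (the definition from the paper). The sketch below is standalone and illustrative only - the `margin` helper is not part of the library's API - but it reproduces the values asserted in `aum_data` by hand.

```python
import torch


def margin(logits: torch.Tensor, target: int) -> float:
    """Margin = assigned-target logit minus the largest non-target logit."""
    other = torch.cat([logits[:target], logits[target + 1:]])
    return (logits[target] - other.max()).item()


# Sample 'a': logits [1, 2, 3], target 1 -> 2 - 3 = -1 (AUM after one update: -1)
print(margin(torch.tensor([1., 2., 3.]), 1))     # -1.0
# Sample 'b', first update: logits [6, 5, 4], target 0 -> 6 - 5 = 1
print(margin(torch.tensor([6., 5., 4.]), 0))     # 1.0
# Sample 'b', second update: logits [7, 8, 9], target 2 -> 9 - 8 = 1
print(margin(torch.tensor([7., 8., 9.]), 2))     # 1.0, so AUM = (1 + 1) / 2 = 1
# Sample 'c': logits [12, 11, 10], target 1 -> 11 - 12 = -1
print(margin(torch.tensor([12., 11., 10.]), 1))  # -1.0
```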