├── .github ├── dependabot.yml └── workflows │ ├── release.yaml │ └── test.yaml ├── .gitignore ├── LICENSE ├── README.md ├── examples ├── 1_cifar_100_train_loop_exposed.py ├── 2_cifar_100_trainer.py ├── 3_cifar_100N_Fine_train_loop_exposed.py └── CIFAR-100_human.pt ├── pyproject.toml ├── src └── gradient_agreement_filtering │ ├── __init__.py │ └── gaf.py └── test ├── run_sweeps.sh └── test_gaf.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" # Location of package manifests 6 | schedule: 7 | interval: "weekly" -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build-and-publish: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v3 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: '3.9' # Adjust the Python version as needed 19 | 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install build 24 | 25 | - name: Build package 26 | run: python -m build 27 | 28 | - name: Publish package to PyPI 29 | uses: pypa/gh-action-pypi-publish@v1.5.0 30 | with: 31 | password: ${{ secrets.PYPI_API_TOKEN }} 32 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Unit 2 | 3 | # Run on pushes to main branch 4 | on: 5 | push: 6 | branches: 7 | - main 8 | 9 | 10 | jobs: 11 | test: 12 | # Test Python 3.11 and above 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | version: [3.11, 3.12] 17 | 18 | steps: 19 | - name: Checkout code 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.version }} 26 | 27 | - name: Install dependencies 28 | run: | 29 | pip install -U pip 30 | pip install -e ".[dev]" 31 | pip install coverage coveralls # Install coverage and coveralls libraries 32 | 33 | - name: Run tests with coverage 34 | run: | 35 | coverage run -m pytest # Adjust this command if you're using a different test runner 36 | 37 | # Uncomment the following steps if you want to generate a coverage report and upload it to Coveralls 38 | 39 | # - name: Generate coverage report 40 | # run: coverage report 41 | 42 | # - name: Upload coverage to Coveralls 43 | # env: 44 | # COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} # Add your Coveralls token to repository secrets 45 | # run: coveralls -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE files 2 | .idea/ 3 | .vscode/ 4 | .python-version 5 | 6 | # Wandb files 7 | wandb/* 8 | 9 | # OS Files 10 | *.DS_Store 11 | 12 | # Model files 13 | data/* 14 | checkpoints/* 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | 
develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | cover/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | db.sqlite3-journal 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | .pybuilder/ 91 | target/ 92 | 93 | # Jupyter Notebook 94 | .ipynb_checkpoints 95 | 96 | # IPython 97 | profile_default/ 98 | ipython_config.py 99 | 100 | # pyenv 101 | # For a library or package, you might want to ignore these files since the code is 102 | # intended to run in multiple environments; otherwise, check them in: 103 | # .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # poetry 113 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 114 | # This is especially recommended for binary packages to ensure reproducibility, and is more 115 | # commonly ignored for libraries. 116 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 117 | #poetry.lock 118 | 119 | # pdm 120 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 121 | #pdm.lock 122 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 123 | # in version control. 124 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 125 | .pdm.toml 126 | .pdm-python 127 | .pdm-build/ 128 | 129 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 130 | __pypackages__/ 131 | 132 | # Celery stuff 133 | celerybeat-schedule 134 | celerybeat.pid 135 | 136 | # SageMath parsed files 137 | *.sage.py 138 | 139 | # Environments 140 | .env 141 | .venv 142 | env/ 143 | venv/ 144 | ENV/ 145 | env.bak/ 146 | venv.bak/ 147 | 148 | # Spyder project settings 149 | .spyderproject 150 | .spyproject 151 | 152 | # Rope project settings 153 | .ropeproject 154 | 155 | # mkdocs documentation 156 | /site 157 | 158 | # mypy 159 | .mypy_cache/ 160 | .dmypy.json 161 | dmypy.json 162 | 163 | # Pyre type checker 164 | .pyre/ 165 | 166 | # pytype static type analyzer 167 | .pytype/ 168 | 169 | # Cython debug symbols 170 | cython_debug/ 171 | 172 | # PyCharm 173 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 174 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 175 | # and can be added to the global gitignore or merged into this file. For a more nuclear 176 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 177 | #.idea/ 178 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Fchaubard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gradient Agreement Filtering (GAF) 2 | 3 | This package implements the Gradient Agreement Filtering (GAF) optimization algorithm. 4 | 5 | GAF is a novel optimization algorithm that improves gradient-based optimization by filtering out gradients of data batches that do not agree with each other and nearly eliminates the need for a validation set without risk of overfitting (even with noisy labels). It bolts on top of existing optimization procedures such as SGD, SGD with Nesterov momentum, Adam, AdamW, RMSProp, etc and outperforms in all cases. Full paper here: 6 | ``` 7 | https://arxiv.org/pdf/2412.18052 8 | ``` 9 | 10 | ## Features 11 | 12 | The package provides a number of features for the example implementation of GAF. Specifically, it: 13 | 14 | - Implements Gradient Agreement Filtering based on cosine distance. 
15 | - Supports multiple optimizers: SGD, SGD with Nesterov momentum, Adam, AdamW, RMSProp.
16 | - Provides examples of applying GAF to train image classifiers on the CIFAR-100 and CIFAR-100N-Fine datasets.
17 | - Allows for label noise injection by flipping a percentage of labels.
18 | - Customizable hyperparameters via command-line arguments.
19 | - Logging and tracking with [Weights & Biases (wandb)](https://wandb.ai/).
20 |
21 | ## Installation
22 |
23 | There are a few ways to install the package and run the examples.
24 |
25 | **PyPI Installation**
26 |
27 | If you only wish to take advantage of the package and not the examples, you can install it via PyPI:
28 |
29 | ```bash
30 | pip install gradient-agreement-filtering
31 | ```
32 |
33 | **Local Installation**
34 |
35 | To install the package and run the examples locally, execute the following commands:
36 |
37 |
38 | ```bash
39 | git clone https://github.com/Fchaubard/gradient_agreement_filtering.git
40 | cd gradient_agreement_filtering
41 | pip install .
42 | ```
43 |
44 | If you wish to install with wandb support, use the following for the last step instead:
45 |
46 | ```bash
47 | pip install ".[all]"
48 | ```
49 |
50 | **Local Installation (Development)** If you wish to install the package with added
51 | development dependencies, you can execute the following commands:
52 |
53 | ```bash
54 | git clone https://github.com/Fchaubard/gradient_agreement_filtering.git
55 | cd gradient_agreement_filtering
56 | pip install ".[dev]"
57 | ```
58 |
59 |
60 |
61 | ## Usage
62 |
63 | We provide two ways to easily incorporate GAF into your existing training.
64 | 1. `step_GAF()`:
65 | If you want to use GAF inside your existing train loop, you can just replace your typical:
66 |
67 | ```
68 | ...
69 | optimizer.zero_grad()
70 | outputs = model(batch)
71 | loss = criterion(outputs, labels)
72 | loss.backward()
73 | optimizer.step()
74 | ...
75 | ```
76 |
77 | with one call to step_GAF() as per below:
78 |
79 | ```
80 | from gradient_agreement_filtering import step_GAF
81 | ...
82 | results = step_GAF(model,
83 |                    optimizer,
84 |                    criterion,
85 |                    list_of_microbatches,
86 |                    wandb=True,
87 |                    verbose=True,
88 |                    cos_distance_thresh=0.97,
89 |                    device=gpu_device)
90 | ...
91 | ```
92 |
93 | 2. `train_GAF()`:
94 |
95 | If you want to use GAF as the train loop, you can just replace your typical Hugging Face / Keras style interface:
96 |
97 | ```
98 | trainer.Train()
99 | ```
100 |
101 | with one call to train_GAF() as per below:
102 |
103 | ```
104 | from gradient_agreement_filtering import train_GAF
105 | ...
106 | train_GAF(model,
107 |           args,
108 |           train_dataset,
109 |           val_dataset,
110 |           optimizer,
111 |           criterion,
112 |           wandb=True,
113 |           verbose=True,
114 |           cos_distance_thresh=0.97,
115 |           device=gpu_device)
116 | ...
117 | ```
118 |
119 | ### NOTE: Running with wandb
120 |
121 | If you want to run with wandb, you will need to set your WANDB_API_KEY. You can do this in a few ways:
122 |
123 | 1. Log in on the system first, then run the script:
124 | ```bash
125 | wandb login
126 | ```
127 |
128 | 2. Add the following line to the top of your .py file:
129 |
130 | ```python
131 | os.environ["WANDB_API_KEY"] = ""
132 | ```
133 |
134 | 3. Or prepend any of the calls below with:
135 |
136 | ```bash
137 | WANDB_API_KEY= python *.py
138 | ```
139 |
140 | ## Examples
141 |
142 | We provide three examples to demonstrate the use of GAF in training image classifiers on CIFAR-100 and CIFAR-100N-Fine datasets.
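Before walking through the example scripts, the snippet below is a minimal, self-contained sketch of the core idea behind `step_GAF()`: compute the gradient on two independent microbatches, measure the cosine distance between the two gradient vectors, and only apply an (averaged) update when the distance is at or below `cos_distance_thresh`. It uses plain PyTorch with a toy linear model and random data; the helper `flat_grad` and the toy setup are illustrative only and are not part of this package's API, and the packaged implementation handles additional details.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(10, 2)                      # toy stand-in for ResNet18
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
cos_distance_thresh = 0.97                    # tau in the paper

def flat_grad(loss):
    """Gradient of `loss` w.r.t. all model parameters, flattened into one vector."""
    grads = torch.autograd.grad(loss, list(model.parameters()))
    return torch.cat([g.reshape(-1) for g in grads])

# Two independent microbatches (random stand-ins for two IID CIFAR microbatches).
x1, y1 = torch.randn(8, 10), torch.randint(0, 2, (8,))
x2, y2 = torch.randn(8, 10), torch.randint(0, 2, (8,))

g1 = flat_grad(criterion(model(x1), y1))
g2 = flat_grad(criterion(model(x2), y2))

# Cosine distance between the two microbatch gradients.
cos_dist = 1.0 - F.cosine_similarity(g1, g2, dim=0)

if cos_dist <= cos_distance_thresh:
    # The gradients agree: write the averaged gradient back into .grad and step.
    avg = 0.5 * (g1 + g2)
    offset = 0
    for p in model.parameters():
        n = p.numel()
        p.grad = avg[offset:offset + n].view_as(p).clone()
        offset += n
    optimizer.step()
    optimizer.zero_grad()
# Otherwise the gradients disagree and the update is skipped for this step.
```

The packaged `step_GAF()` generalizes this to an arbitrary list of microbatches and returns the metrics (train loss, train accuracy, cosine distance, agreement count) that the example scripts below log.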
143 |
144 | ### 1_cifar_100_train_loop_exposed.py
145 |
146 | This file uses `step_GAF()` to train a ResNet18 model on the CIFAR-100 dataset using PyTorch, with the ability to add noise to the labels to observe how GAF performs under noisy conditions. The code supports various optimizers and configurations, allowing you to experiment with different settings to understand the impact of GAF on model training.
147 |
148 | Example call:
149 | ```bash
150 | python examples/1_cifar_100_train_loop_exposed.py --GAF True --optimizer "SGD+Nesterov+val_plateau" --learning_rate 0.01 --momentum 0.9 --nesterov True --wandb True --verbose True --num_samples_per_class_per_batch 1 --num_batches_to_force_agreement 2 --label_error_percentage 0.15 --cos_distance_thresh 0.97
151 | ```
152 |
153 | ### 2_cifar_100_trainer.py
154 | This file uses `train_GAF()` to train a ResNet18 model on the CIFAR-100 dataset using PyTorch, simply to show how the trainer-style interface works.
155 |
156 | Example call:
157 | ```
158 | python examples/2_cifar_100_trainer.py
159 | ```
160 |
161 | ### 3_cifar_100N_Fine_train_loop_exposed.py
162 |
163 | This file uses `step_GAF()` to train a ResNet34 model on the CIFAR-100N-Fine dataset using PyTorch to observe how GAF performs under typical labeling noise. The code supports various optimizers and configurations, allowing you to experiment with different settings to understand the impact of GAF on model training.
164 |
165 | Example call:
166 | ```bash
167 | python examples/3_cifar_100N_Fine_train_loop_exposed.py --GAF True --optimizer "SGD+Nesterov+val_plateau" --cifarn True --learning_rate 0.01 --momentum 0.9 --nesterov True --wandb True --verbose True --num_samples_per_class_per_batch 2 --num_batches_to_force_agreement 2 --cos_distance_thresh 0.97
168 | ```
169 |
170 | ## Running sweeps
171 |
172 | We also provide a shell script to run a sweep for convenience. The script spawns screen sessions and randomly allocates the runs to GPUs, and it can be run multiple times. Make sure you update the script with your own WANDB_API_KEY before running. Here is how you run it:
173 |
174 | ### test/run_sweeps.sh
175 |
176 | Example call:
177 | ```bash
178 | cd test
179 | chmod +x run_sweeps.sh
180 | ./run_sweeps.sh
181 | ```
182 |
183 | ## BibTeX
184 |
185 | If GAF is useful in your own research, please use the following BibTeX entry:
186 |
187 | ```
188 | @misc{chaubard2024gradientaveragingparalleloptimization,
189 |       title={Beyond Gradient Averaging in Parallel Optimization: Improved Robustness through Gradient Agreement Filtering},
190 |       author={Francois Chaubard and Duncan Eddy and Mykel J. Kochenderfer},
191 |       year={2024},
192 |       eprint={2412.18052},
193 |       archivePrefix={arXiv},
194 |       primaryClass={cs.LG},
195 |       url={https://arxiv.org/abs/2412.18052},
196 | }
197 | ```
198 |
199 | ## Acknowledgements
200 |
201 | We would like to acknowledge and thank Alex Tzikas, Harrison Delecki, and Francois Chollet, who provided invaluable help through discussions and feedback.
202 |
203 |
204 | ## License
205 |
206 | This package is licensed under the MIT license. See [LICENSE](LICENSE) for details.
207 |
208 |
--------------------------------------------------------------------------------
/examples/1_cifar_100_train_loop_exposed.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to train ResNet18 on CIFAR-100 with Gradient Agreement Filtering (GAF) and various optimizers.
3 |
4 | This script allows for experimentation with different optimizers and GAF settings.
It supports label noise injection and logs metrics using Weights & Biases (wandb). 5 | 6 | We expose the train loop to allow maximum flexibility. 7 | 8 | Usage: 9 | python examples/1_cifar_100_train_loop_exposed.py [OPTIONS] 10 | 11 | Example: 12 | python examples/1_cifar_100_train_loop_exposed.py --GAF True --optimizer "SGD+Nesterov+val_plateau" --learning_rate 0.01 --momentum 0.9 --nesterov True --wandb True --verbose True --num_samples_per_class_per_batch 2 --num_batches_to_force_agreement 2 --label_error_percentage 0.15 --cos_distance_thresh 0.97 13 | 14 | Author: 15 | Francois Chaubard 16 | 17 | Date: 18 | 2024-12-03 19 | """ 20 | import torch 21 | import torch.nn as nn 22 | from torchvision import datasets, transforms, models 23 | import torch.optim as optim 24 | from torch.utils.data import DataLoader, Subset 25 | from collections import defaultdict 26 | import numpy as np 27 | import random 28 | import os 29 | import argparse 30 | import time 31 | 32 | # Try to import wandb 33 | try: 34 | import wandb 35 | except ImportError: 36 | wandb = None 37 | 38 | from gradient_agreement_filtering import step_GAF 39 | 40 | 41 | #################################################################################################### 42 | 43 | # (OPTIONAL) WANDB SETUP 44 | 45 | # Ensure to set your WandB API key as an environment variable or directly in the code 46 | # os.environ["WANDB_API_KEY"] = "" 47 | 48 | #################################################################################################### 49 | 50 | # DEFINITIONS 51 | 52 | def str2bool(v): 53 | """Parse boolean values from the command line.""" 54 | if isinstance(v, bool): 55 | return v 56 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 57 | return True 58 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 59 | return False 60 | else: 61 | raise argparse.ArgumentTypeError('Boolean value expected.') 62 | 63 | 64 | def sample_iid_mbs(full_dataset, class_indices, batch_size): 65 | """ 66 | Samples an IID minibatch for standard training. 67 | 68 | Args: 69 | full_dataset (Dataset): The full training dataset. 70 | class_indices (dict): A mapping from class labels to data indices. 71 | batch_size (int): The size of the batch to sample. 72 | 73 | Returns: 74 | Subset: A subset of the dataset representing the minibatch. 75 | """ 76 | num_classes = len(class_indices) 77 | samples_per_class = batch_size // num_classes 78 | batch_indices = [] 79 | for cls in class_indices: 80 | indices = random.sample(class_indices[cls], samples_per_class) 81 | batch_indices.extend(indices) 82 | # If batch_size is not divisible by num_classes, fill the rest randomly 83 | remaining = batch_size - len(batch_indices) 84 | if remaining > 0: 85 | all_indices = [idx for idx in range(len(full_dataset))] 86 | batch_indices.extend(random.sample(all_indices, remaining)) 87 | # Create a Subset 88 | batch = Subset(full_dataset, batch_indices) 89 | return batch 90 | 91 | def flip_labels(train_dataset, label_error_percentage=0.1, num_classes=100): 92 | """ 93 | Flips a percentage of labels in the training dataset to simulate label noise. 94 | 95 | Args: 96 | train_dataset (Dataset): The training dataset. 97 | label_error_percentage (float): The percentage of labels to flip (between 0 and 1). 98 | num_classes (int): The total number of classes. 99 | 100 | Returns: 101 | Dataset: The training dataset with labels flipped. 
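    Example (illustrative, not executed; mirrors the call made in the main block
    below with a concrete noise level):

        train_dataset = flip_labels(train_dataset, label_error_percentage=0.15,
                                    num_classes=100)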
102 | """ 103 | num_samples = len(train_dataset.targets) 104 | num_to_flip = int(label_error_percentage * num_samples) 105 | all_indices = list(range(num_samples)) 106 | flip_indices = random.sample(all_indices, num_to_flip) 107 | 108 | for idx in flip_indices: 109 | original_label = train_dataset.targets[idx] 110 | # Exclude the original label to ensure the label is actually changed 111 | wrong_labels = list(range(num_classes)) 112 | wrong_labels.remove(original_label) 113 | new_label = random.choice(wrong_labels) 114 | train_dataset.targets[idx] = new_label 115 | 116 | return train_dataset 117 | 118 | def sample_iid_mbs_for_GAF(full_dataset, class_indices, n, m): 119 | """ 120 | Samples 'n' independent minibatches, each containing an equal number of samples from each class, m. 121 | 122 | Args: 123 | full_dataset (Dataset): The full training dataset. 124 | class_indices (dict): A mapping from class labels to data indices. 125 | n (int): The number of microbatches to sample. 126 | m (int): The number of images per class to sample per microbatch. 127 | 128 | Returns: 129 | list: A list of Subsets representing the minibatches. 130 | """ 131 | # Initialize a list to hold indices for each batch 132 | batch_indices_list = [[] for _ in range(n)] 133 | for clazz in class_indices: 134 | num_samples_per_class = m # Adjust if you want more samples per class per batch 135 | total_samples_needed = num_samples_per_class * n 136 | available_indices = class_indices[clazz] 137 | # Ensure there are enough indices 138 | if len(available_indices) < total_samples_needed: 139 | multiples = (total_samples_needed // len(available_indices)) + 1 140 | extended_indices = (available_indices * multiples)[:total_samples_needed] 141 | else: 142 | extended_indices = random.sample(available_indices, total_samples_needed) 143 | for i in range(n): 144 | start_idx = i * num_samples_per_class 145 | end_idx = start_idx + num_samples_per_class 146 | batch_indices_list[i].extend(extended_indices[start_idx:end_idx]) 147 | # Create Subsets for each batch 148 | batches = [Subset(full_dataset, batch_indices) for batch_indices in batch_indices_list] 149 | return batches 150 | 151 | 152 | def evaluate(model, dataloader, device): 153 | """ 154 | Evaluates the model on the validation or test dataset. 155 | 156 | Args: 157 | model (nn.Module): The model to evaluate. 158 | dataloader (DataLoader): The DataLoader for the dataset. 159 | device (torch.device): The device to use. 160 | 161 | Returns: 162 | tuple: Average loss and top-1 accuracy. 
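    Example (illustrative, not executed; model, test_loader, and device are built
    in the main block below, and the loss is computed with the module-level criterion):

        val_loss, val_accuracy = evaluate(model, test_loader, device)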
163 | """ 164 | model.eval() 165 | total_loss = 0.0 166 | correct_top1 = 0 167 | total = 0 168 | with torch.no_grad(): 169 | for data in dataloader: 170 | images, labels = data[0].to(device), data[1].to(device) 171 | outputs = model(images) 172 | loss = criterion(outputs, labels) 173 | total_loss += loss.item() * images.size(0) 174 | _, predicted = torch.max(outputs.data, 1) 175 | total += labels.size(0) 176 | correct_top1 += (predicted == labels).sum().item() 177 | avg_loss = total_loss / total 178 | accuracy_top1 = correct_top1 / total 179 | return avg_loss, accuracy_top1 180 | 181 | #################################################################################################### 182 | 183 | # MAIN TRAINING LOOP 184 | 185 | if __name__ == "__main__": 186 | 187 | # Define the list of available optimizer types 188 | optimizer_types = ["SGD", "SGD+Nesterov", "SGD+Nesterov+val_plateau", "Adam", "AdamW", "RMSProp"] 189 | 190 | parser = argparse.ArgumentParser(description='Train ResNet18 on CIFAR-100 with various optimizers and GAF.') 191 | 192 | # General training parameters 193 | parser.add_argument('--GAF', type=str2bool, default=True, help='Enable Gradient Agreement Filtering (True or False)') 194 | parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate for the optimizer') 195 | parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay factor') 196 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training') 197 | parser.add_argument('--num_val_epochs', type=int, default=2, help='Number of epochs between validation checks') 198 | parser.add_argument('--optimizer', type=str, default='SGD', choices=optimizer_types, help='Optimizer type to use') 199 | parser.add_argument('--num_batches_to_force_agreement', type=int, default=2, help='Number of batches to compute gradients for agreement (must be > 1)') 200 | parser.add_argument('--epochs', type=int, default=10000, help='Total number of training epochs') 201 | parser.add_argument('--num_samples_per_class_per_batch', type=int, default=1, help='Number of samples per class per batch when using GAF') 202 | parser.add_argument('--label_error_percentage', type=float, default=0, help='Percentage of labels to flip in the training dataset to simulate label noise (between 0 and 1)') 203 | parser.add_argument('--cos_distance_thresh', type=float, default=1, help='Threshold for cosine distance in gradient agreement filtering. 
Tau in the paper.') 204 | 205 | # Optimizer-specific parameters 206 | parser.add_argument('--momentum', type=float, default=0.0, help='Momentum factor for SGD and RMSProp optimizers') 207 | parser.add_argument('--nesterov', type=str2bool, default=False, help='Use Nesterov momentum (True or False)') 208 | parser.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999), help='Betas for Adam and AdamW optimizers') 209 | parser.add_argument('--eps', type=float, default=1e-8, help='Epsilon value for optimizers') 210 | parser.add_argument('--alpha', type=float, default=0.99, help='Alpha value for RMSProp optimizer') 211 | parser.add_argument('--centered', type=str2bool, default=False, help='Centered RMSProp (True or False)') 212 | parser.add_argument('--scheduler_patience', type=int, default=100, help='Patience for ReduceLROnPlateau scheduler') 213 | parser.add_argument('--scheduler_factor', type=int, default=0.1, help='Discount factor for ReduceLROnPlateau scheduler') 214 | 215 | # logging 216 | parser.add_argument('--wandb', type=str2bool, default=False, help='Log to wandb (True or False)') 217 | parser.add_argument('--verbose', type=str2bool, default=False, help='Print out logs (True or False)') 218 | 219 | # Parse arguments 220 | args = parser.parse_args() 221 | config = vars(args) 222 | 223 | # Set unused optimizer-specific configs to 'NA' 224 | optimizer = config['optimizer'] 225 | all_params = ['momentum', 'nesterov', 'betas', 'eps', 'alpha', 'centered', 'scheduler_patience', 'scheduler_factor'] 226 | 227 | # Define which parameters are used by each optimizer 228 | optimizer_params = { 229 | 'SGD': ['momentum', 'nesterov'], 230 | 'SGD+Nesterov': ['momentum', 'nesterov'], 231 | 'SGD+Nesterov+val_plateau': ['momentum', 'nesterov', 'scheduler_patience', 'scheduler_factor'], 232 | 'Adam': ['betas', 'eps'], 233 | 'AdamW': ['betas', 'eps'], 234 | 'RMSProp': ['momentum', 'eps', 'alpha', 'centered'], 235 | } 236 | 237 | # Get the list of parameters used by the selected optimizer 238 | used_params = optimizer_params.get(optimizer, []) 239 | 240 | # Set unused parameters to 'NA' 241 | for param in all_params: 242 | if param not in used_params: 243 | config[param] = 'NA' 244 | 245 | # Check for available device (GPU or CPU) 246 | if torch.cuda.is_available(): 247 | num_gpus = torch.cuda.device_count() 248 | device_index = random.randint(0, num_gpus - 1) # Pick a random device index 249 | device = torch.device(f"cuda:{device_index}") 250 | elif torch.mps.is_available(): 251 | device = torch.device("mps") 252 | else: 253 | device = torch.device("cpu") 254 | 255 | if config['verbose']: 256 | print(f"Using device: {device}") 257 | 258 | # Set random seeds for reproducibility 259 | torch.manual_seed(0) 260 | np.random.seed(0) 261 | random.seed(0) 262 | 263 | # Data transformations for training and testing 264 | transform_train = transforms.Compose([ 265 | transforms.RandomCrop(32, padding=4), 266 | transforms.RandomHorizontalFlip(), 267 | transforms.ToTensor(), 268 | transforms.Normalize((0.5071, 0.4867, 0.4408), 269 | (0.2675, 0.2565, 0.2761)), 270 | ]) 271 | 272 | transform_test = transforms.Compose([ 273 | transforms.ToTensor(), 274 | transforms.Normalize((0.5071, 0.4867, 0.4408), 275 | (0.2675, 0.2565, 0.2761)), 276 | ]) 277 | 278 | # Load CIFAR-100 dataset 279 | train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) 280 | test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test) 281 | 282 | # Create 
a mapping from class labels to indices for sampling 283 | class_indices = defaultdict(list) 284 | for idx, (_, label) in enumerate(train_dataset): 285 | class_indices[label].append(idx) 286 | 287 | # Initialize the model (ResNet18) and move it to the device 288 | model = models.resnet18(num_classes=100) 289 | model = model.to(device) 290 | 291 | # Define the loss function (CrossEntropyLoss) 292 | criterion = nn.CrossEntropyLoss() 293 | 294 | # Initialize the optimizer based on the selected type and parameters 295 | if config['optimizer'] == 'SGD': 296 | optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], 297 | momentum=config['momentum'], 298 | weight_decay=config['weight_decay'], 299 | nesterov=config['nesterov']) 300 | elif config['optimizer'] == 'SGD+Nesterov': 301 | optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], 302 | momentum=config['momentum'], 303 | weight_decay=config['weight_decay'], 304 | nesterov=True) 305 | elif config['optimizer'] == 'SGD+Nesterov+val_plateau': 306 | optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], 307 | momentum=config['momentum'], 308 | weight_decay=config['weight_decay'], 309 | nesterov=True) 310 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=config['scheduler_patience'], factor=config['scheduler_factor']) 311 | elif config['optimizer'] == 'Adam': 312 | optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], 313 | betas=tuple(config['betas']), 314 | eps=config['eps'], 315 | weight_decay=config['weight_decay']) 316 | elif config['optimizer'] == 'AdamW': 317 | optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'], 318 | betas=tuple(config['betas']), 319 | eps=config['eps'], 320 | weight_decay=config['weight_decay']) 321 | elif config['optimizer'] == 'RMSProp': 322 | optimizer = optim.RMSprop(model.parameters(), lr=config['learning_rate'], 323 | alpha=config['alpha'], 324 | eps=config['eps'], 325 | weight_decay=config['weight_decay'], 326 | momentum=config['momentum'], 327 | centered=config['centered']) 328 | else: 329 | raise ValueError(f"Unsupported optimizer type: {config['optimizer']}") 330 | 331 | # Test DataLoader 332 | test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2) 333 | 334 | if config['wandb']: 335 | 336 | # Raise an error if wandb is not installed 337 | if wandb is None: 338 | raise ImportError("wandb is not installed. 
Please install it using 'pip install wandb'.") 339 | 340 | # Set up WandB project and run names 341 | model_name = 'ResNet18' 342 | dataset_name = 'CIFAR100' 343 | project_name = f"{model_name}_{dataset_name}" 344 | name_prefix = 'GAF' if config['GAF'] else 'NO_GAF' 345 | run_name = f"{name_prefix}_opt_{config['optimizer']}_lr_{config['learning_rate']}_bs_{config['batch_size']}" 346 | 347 | # Initialize WandB 348 | wandb.init(project=project_name, name=run_name, config=config) 349 | config = wandb.config 350 | 351 | # Create checkpoints directory 352 | checkpoint_dir = './checkpoints/' 353 | os.makedirs(checkpoint_dir, exist_ok=True) 354 | 355 | start_timestamp = time.strftime("%Y%m%d_%H%M%S") 356 | 357 | # If label error percentage is specified, flip labels to introduce label noise 358 | if config['label_error_percentage'] and config['label_error_percentage'] > 0: 359 | if 0 < config['label_error_percentage'] < 1: 360 | train_dataset = flip_labels(train_dataset, label_error_percentage=config['label_error_percentage'], num_classes=len(class_indices)) 361 | else: 362 | raise ValueError(f"label_error_percentage needs to be between 0 and 1. Given label_error_percentage={config['label_error_percentage']}") 363 | 364 | iteration = 0 365 | # Training loop 366 | try: 367 | for epoch in range(config['epochs']): 368 | model.train() 369 | running_loss = 0.0 370 | correct_top1 = 0 371 | total = 0 372 | i=0 373 | 374 | # Calculate total iterations per epoch 375 | if config['GAF']: 376 | iterations_per_epoch = len(train_dataset) // (len(class_indices) * config['num_samples_per_class_per_batch']) 377 | else: 378 | iterations_per_epoch = len(train_dataset) // config['batch_size'] 379 | 380 | while i < iterations_per_epoch: 381 | i+=1 382 | if config['GAF']: 383 | # Sample microbatches for GAF 384 | mbs = sample_iid_mbs_for_GAF(train_dataset, class_indices, config['num_batches_to_force_agreement'], config['num_samples_per_class_per_batch']) 385 | # Run GAF to update the model 386 | result = step_GAF(model, optimizer, criterion, mbs, verbose=config['verbose'], device=device) 387 | 388 | 389 | # # Update metrics 390 | running_loss += result['train_loss'] / (len(class_indices)*config['num_batches_to_force_agreement']) 391 | total += 1 392 | correct_top1 += result['train_accuracy'] 393 | 394 | # Log metrics 395 | if config['verbose']: 396 | print(f'Epoch: {epoch:7d}, Iteration: {iteration:7d}, Train Loss: {result["train_loss"]:.9f}, Train Acc: {result["train_accuracy"]:.4f}, Costine Distance: {result["cosine_distance"]:.4f}, Agreement Count: {result["agreed_count"]:d}') 397 | 398 | # Log to wandb 399 | if config['wandb']: 400 | wandb.log(result) 401 | 402 | else: 403 | # Sample a minibatch for standard training 404 | batch = sample_iid_mbs(train_dataset, class_indices, config['batch_size']) 405 | loader = DataLoader(batch, batch_size=len(batch), shuffle=False) 406 | data = next(iter(loader)) 407 | images, labels = data[0].to(device), data[1].to(device) 408 | # Forward and backward passes 409 | optimizer.zero_grad() 410 | outputs = model(images) 411 | loss = criterion(outputs, labels) 412 | loss.backward() 413 | optimizer.step() 414 | 415 | # Update metrics 416 | running_loss += loss.item() * images.size(0) 417 | _, predicted = torch.max(outputs.data, 1) 418 | total += labels.size(0) 419 | correct_top1 += (predicted == labels).sum().item() 420 | 421 | # print for baseline 422 | message = {'train_loss': loss.item(), 423 | 'train_accuracy': (predicted == labels).sum().item() / labels.size(0), 424 | 'iteration': 
iteration} 425 | 426 | # Log metrics to wandb 427 | if config['wandb']: 428 | try: 429 | wandb.log(message) 430 | except Exception as e: 431 | print(f"Failed to log to wandb: {e}") 432 | if config['verbose']: 433 | print(f'Epoch: {epoch:7d}, Iteration: {iteration:7d}, Train Loss: {message["train_loss"]:.9f}, Train Acc: {message["train_accuracy"]:.4f}') 434 | 435 | iteration += 1 436 | 437 | # Perform validation every num_val_epochs iterations 438 | if epoch % config['num_val_epochs'] == 0 and total > 0: 439 | # Compute training metrics 440 | train_loss = running_loss / total 441 | train_accuracy = correct_top1 / total 442 | 443 | # Evaluate on the validation/test set 444 | val_loss, val_accuracy = evaluate(model, test_loader, device) 445 | message = {'train_loss': train_loss, 446 | 'train_accuracy': train_accuracy, 447 | 'val_loss': val_loss, 448 | 'val_accuracy': val_accuracy, 449 | 'epoch': epoch, 450 | 'iteration': iteration} 451 | 452 | # Log metrics to wandb 453 | if config['wandb']: 454 | try: 455 | wandb.log(message) 456 | except Exception as e: 457 | print(f"Failed to log to wandb: {e}") 458 | 459 | if config['verbose']: 460 | print(f'Epoch: {epoch:7d}, Train Loss: {train_loss:.9f}, Train Acc: {train_accuracy:.4f}, Val Loss: {val_loss:.9f}, Val Acc: {val_accuracy:.4f}') 461 | 462 | # Reset running metrics 463 | running_loss = 0.0 464 | correct_top1 = 0 465 | total = 0 466 | 467 | # Save the latest checkpoint 468 | timestamp = time.strftime("%Y%m%d_%H%M%S") 469 | checkpoint_name = f"cifar_100_{start_timestamp}_checkpoint_{timestamp}.pt" 470 | 471 | checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) 472 | try: 473 | torch.save(model.state_dict(), checkpoint_path) 474 | except Exception as e: 475 | print(f"Failed to save checkpoint: {e}") 476 | 477 | # Adjust learning rate if scheduler is used 478 | if config['optimizer'] == 'SGD+Nesterov+val_plateau': 479 | scheduler.step(val_loss) 480 | 481 | except KeyboardInterrupt: 482 | print("Training interrupted. Saving checkpoint...") 483 | timestamp = time.strftime("%Y%m%d_%H%M%S") 484 | checkpoint_name = f"cifar_100_{start_timestamp}_interrupt_{timestamp}.pt" 485 | checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) 486 | try: 487 | torch.save(model.state_dict(), checkpoint_path) 488 | print(f"Checkpoint saved at {checkpoint_path}") 489 | except Exception as e: 490 | print(f"Failed to save checkpoint: {e}") 491 | exit(0) -------------------------------------------------------------------------------- /examples/2_cifar_100_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to train ResNet18 on CIFAR-100 with Gradient Agreement Filtering (GAF) with test_GAF(). 
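The whole training run is driven by one call to train_GAF(), which owns the train
loop. The call made near the bottom of this file has the form:

    train_GAF(model, args, train_dataset, val_dataset, optimizer, criterion,
              use_wandb=args.wandb, verbose=args.verbose,
              cos_distance_thresh=args.cos_distance_thresh, device=device)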
3 | 4 | Usage: 5 | python examples/2_cifar_100_trainer.py 6 | 7 | Author: 8 | Francois Chaubard 9 | 10 | Date: 11 | 2024-12-03 12 | """ 13 | import torch 14 | import torch.nn as nn 15 | from torchvision import datasets, transforms, models 16 | import torch.optim as optim 17 | from torch.utils.data import DataLoader, Subset 18 | import numpy as np 19 | import random 20 | import os 21 | import time 22 | from argparse import Namespace 23 | 24 | # Try to import wandb 25 | try: 26 | import wandb 27 | except ImportError: 28 | wandb = None 29 | 30 | from gradient_agreement_filtering import train_GAF 31 | 32 | 33 | # Ensure to set your WandB API key as an environment variable or directly in the code 34 | # os.environ["WANDB_API_KEY"] = "your_wandb_api_key_here" 35 | 36 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | 38 | # Check for available device (GPU or CPU) 39 | if torch.cuda.is_available(): 40 | num_gpus = torch.cuda.device_count() 41 | device_index = random.randint(0, num_gpus - 1) # Pick a random device index 42 | device = torch.device(f"cuda:{device_index}") 43 | elif torch.mps.is_available(): 44 | device = torch.device("mps") 45 | else: 46 | device = torch.device("cpu") 47 | 48 | print(f"Using device: {device}") 49 | 50 | 51 | # Data setup 52 | transform = transforms.Compose([ 53 | transforms.ToTensor(), 54 | transforms.Normalize((0.5071, 0.4867, 0.4408), 55 | (0.2675, 0.2565, 0.2761)), 56 | ]) 57 | train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform) 58 | val_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform) 59 | 60 | # Model 61 | model = models.resnet18(num_classes=100) 62 | 63 | # Define args 64 | args = Namespace( 65 | epochs=100, 66 | batch_size=128, 67 | num_batches_to_force_agreement=2, 68 | learning_rate=0.01, 69 | momentum = 0.9, 70 | weight_decay = 1e-2, 71 | cos_distance_thresh = 0.98, 72 | wandb = True, 73 | verbose = True 74 | ) 75 | 76 | optimizer = optim.SGD(model.parameters(), 77 | lr=args.learning_rate, 78 | momentum=args.momentum, 79 | weight_decay=args.weight_decay, 80 | nesterov=True) 81 | 82 | criterion = nn.CrossEntropyLoss() 83 | 84 | if wandb: 85 | # Raise an error if wandb is not installed 86 | if wandb is None: 87 | raise ImportError("wandb is not installed. 
Please install it using 'pip install wandb'.") 88 | 89 | # Set up WandB project and run names 90 | model_name = 'ResNet18' 91 | dataset_name = 'CIFAR100' 92 | project_name = f"{model_name}_{dataset_name}" 93 | run_name = f"example_run" 94 | wandb.init(project=project_name, name=run_name, config=args) 95 | 96 | 97 | try: 98 | train_GAF(model, 99 | args, 100 | train_dataset, 101 | val_dataset, 102 | optimizer, 103 | criterion, 104 | use_wandb=args.wandb, 105 | verbose=args.verbose, 106 | cos_distance_thresh=args.cos_distance_thresh, 107 | device=device) 108 | except KeyboardInterrupt: 109 | # Save the model if the training is interrupted 110 | timestamp = time.strftime("%Y%m%d_%H%M%S") 111 | 112 | checkpoint_dir = './checkpoints/' 113 | os.makedirs(checkpoint_dir, exist_ok=True) 114 | 115 | checkpoint_name = f"cifar_100_{timestamp}.pt" 116 | 117 | checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) 118 | 119 | try: 120 | torch.save(model.state_dict(), checkpoint_path) 121 | print(f"Checkpoint saved at {checkpoint_path}") 122 | except Exception as e: 123 | print(f"Failed to save checkpoint: {e}") 124 | if wandb: 125 | wandb.finish() 126 | print('WandB run finished') 127 | print('Training interrupted') 128 | -------------------------------------------------------------------------------- /examples/3_cifar_100N_Fine_train_loop_exposed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to train ResNet34-PreAct (as per the standard benchmark) on CIFAR-100N-Fine with Gradient Agreement Filtering (GAF) and various optimizers. 3 | 4 | This script allows for experimentation with different optimizers and GAF settings. It supports label noise injection and logs metrics using Weights & Biases (wandb). 5 | 6 | We expose the train loop to allow maximum flexibility. 
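Each GAF iteration in the train loop below follows this basic pattern (simplified;
the actual loop also tracks metrics, saves checkpoints, and optionally logs to wandb):

    mbs = sample_iid_mbs_for_GAF(train_dataset, class_indices,
                                 config['num_batches_to_force_agreement'],
                                 config['num_samples_per_class_per_batch'])
    result = step_GAF(model, optimizer, criterion, mbs,
                      use_wandb=config['wandb'], verbose=config['verbose'],
                      device=device)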
7 | 8 | Usage: 9 | python examples/3_cifar_100N_Fine_train_loop_exposed.py [OPTIONS] 10 | 11 | Example: 12 | python examples/3_cifar_100N_Fine_train_loop_exposed.py --GAF True --optimizer "SGD+Nesterov+val_plateau" --cifarn True --learning_rate 0.01 --momentum 0.9 --nesterov True --wandb True --verbose True --num_samples_per_class_per_batch 2 --num_batches_to_force_agreement 2 --cos_distance_thresh 0.97 13 | 14 | Author: 15 | Francois Chaubard 16 | 17 | Date: 18 | 2024-12-03 19 | """ 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import torch.optim as optim 24 | from torchvision import datasets, transforms 25 | from torch.utils.data import DataLoader, Subset 26 | from collections import defaultdict 27 | import numpy as np 28 | import random 29 | import os 30 | import argparse 31 | import time 32 | 33 | # Try to import wandb 34 | try: 35 | import wandb 36 | except ImportError: 37 | wandb = None 38 | 39 | from gradient_agreement_filtering import step_GAF 40 | 41 | #################################################################################################### 42 | 43 | # (OPTIONAL) WANDB SETUP 44 | 45 | # Ensure to set your WandB API key as an environment variable or directly in the code 46 | # os.environ["WANDB_API_KEY"] = "" 47 | 48 | #################################################################################################### 49 | 50 | # DEFINITIONS 51 | 52 | def str2bool(v): 53 | """Parse boolean values from the command line.""" 54 | if isinstance(v, bool): 55 | return v 56 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 57 | return True 58 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 59 | return False 60 | else: 61 | raise argparse.ArgumentTypeError('Boolean value expected.') 62 | 63 | 64 | def sample_iid_mbs(full_dataset, class_indices, batch_size): 65 | """ 66 | Samples an IID minibatch for standard training. 67 | 68 | Args: 69 | full_dataset (Dataset): The full training dataset. 70 | class_indices (dict): A mapping from class labels to data indices. 71 | batch_size (int): The size of the batch to sample. 72 | 73 | Returns: 74 | Subset: A subset of the dataset representing the minibatch. 75 | """ 76 | num_classes = len(class_indices) 77 | samples_per_class = batch_size // num_classes 78 | batch_indices = [] 79 | for cls in class_indices: 80 | indices = random.sample(class_indices[cls], samples_per_class) 81 | batch_indices.extend(indices) 82 | # If batch_size is not divisible by num_classes, fill the rest randomly 83 | remaining = batch_size - len(batch_indices) 84 | if remaining > 0: 85 | all_indices = [idx for idx in range(len(full_dataset))] 86 | batch_indices.extend(random.sample(all_indices, remaining)) 87 | # Create a Subset 88 | batch = Subset(full_dataset, batch_indices) 89 | return batch 90 | 91 | 92 | def sample_iid_mbs_for_GAF(full_dataset, class_indices, n, m): 93 | """ 94 | Samples 'n' independent minibatches, each containing an equal number of samples from each class, m. 95 | 96 | Args: 97 | full_dataset (Dataset): The full training dataset. 98 | class_indices (dict): A mapping from class labels to data indices. 99 | n (int): The number of microbatches to sample. 100 | m (int): The number of images per class to sample per microbatch. 101 | 102 | Returns: 103 | list: A list of Subsets representing the minibatches. 
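    Example (illustrative, not executed; train_dataset and class_indices are built
    in the main block below):

        mbs = sample_iid_mbs_for_GAF(train_dataset, class_indices, n=2, m=1)
        # -> a list of 2 Subsets, each with 1 image per class (100 images for CIFAR-100)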
104 | """ 105 | # Initialize a list to hold indices for each batch 106 | batch_indices_list = [[] for _ in range(n)] 107 | for clazz in class_indices: 108 | num_samples_per_class = m # Adjust if you want more samples per class per batch 109 | total_samples_needed = num_samples_per_class * n 110 | available_indices = class_indices[clazz] 111 | # Ensure there are enough indices 112 | if len(available_indices) < total_samples_needed: 113 | multiples = (total_samples_needed // len(available_indices)) + 1 114 | extended_indices = (available_indices * multiples)[:total_samples_needed] 115 | else: 116 | extended_indices = random.sample(available_indices, total_samples_needed) 117 | for i in range(n): 118 | start_idx = i * num_samples_per_class 119 | end_idx = start_idx + num_samples_per_class 120 | batch_indices_list[i].extend(extended_indices[start_idx:end_idx]) 121 | # Create Subsets for each batch 122 | batches = [Subset(full_dataset, batch_indices) for batch_indices in batch_indices_list] 123 | return batches 124 | 125 | 126 | def evaluate(model, dataloader, device): 127 | """ 128 | Evaluates the model on the validation or test dataset. 129 | 130 | Args: 131 | model (nn.Module): The model to evaluate. 132 | dataloader (DataLoader): The DataLoader for the dataset. 133 | device (torch.device): The device to use. 134 | 135 | Returns: 136 | tuple: Average loss and top-1 accuracy. 137 | """ 138 | model.eval() 139 | total_loss = 0.0 140 | correct_top1 = 0 141 | total = 0 142 | with torch.no_grad(): 143 | for data in dataloader: 144 | images, labels = data[0].to(device), data[1].to(device) 145 | outputs = model(images) 146 | loss = criterion(outputs, labels) 147 | total_loss += loss.item() * images.size(0) 148 | _, predicted = torch.max(outputs.data, 1) 149 | total += labels.size(0) 150 | correct_top1 += (predicted == labels).sum().item() 151 | avg_loss = total_loss / total 152 | accuracy_top1 = correct_top1 / total 153 | return avg_loss, accuracy_top1 154 | 155 | 156 | '''ResNet in PyTorch. 157 | 158 | For Pre-activation ResNet, see 'preact_resnet.py'. 159 | 160 | Reference: 161 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 162 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 163 | ''' 164 | 165 | def conv3x3(in_planes, out_planes, stride=1): 166 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 167 | 168 | 169 | class BasicBlock(nn.Module): 170 | expansion = 1 171 | 172 | def __init__(self, in_planes, planes, stride=1): 173 | super(BasicBlock, self).__init__() 174 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 175 | self.bn1 = nn.BatchNorm2d(planes) 176 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 177 | self.bn2 = nn.BatchNorm2d(planes) 178 | 179 | self.shortcut = nn.Sequential() 180 | if stride != 1 or in_planes != self.expansion*planes: 181 | self.shortcut = nn.Sequential( 182 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 183 | nn.BatchNorm2d(self.expansion*planes) 184 | ) 185 | 186 | def forward(self, x): 187 | out = F.relu(self.bn1(self.conv1(x))) 188 | out = self.bn2(self.conv2(out)) 189 | out += self.shortcut(x) 190 | out = F.relu(out) 191 | return out 192 | 193 | 194 | class PreActBlock(nn.Module): 195 | '''Pre-activation version of the BasicBlock.''' 196 | expansion = 1 197 | 198 | def __init__(self, in_planes, planes, stride=1): 199 | super(PreActBlock, self).__init__() 200 | self.bn1 = nn.BatchNorm2d(in_planes) 201 | self.conv1 = conv3x3(in_planes, planes, stride) 202 | self.bn2 = nn.BatchNorm2d(planes) 203 | self.conv2 = conv3x3(planes, planes) 204 | 205 | self.shortcut = nn.Sequential() 206 | if stride != 1 or in_planes != self.expansion*planes: 207 | self.shortcut = nn.Sequential( 208 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 209 | ) 210 | 211 | def forward(self, x): 212 | out = F.relu(self.bn1(x)) 213 | shortcut = self.shortcut(out) 214 | out = self.conv1(out) 215 | out = self.conv2(F.relu(self.bn2(out))) 216 | out += shortcut 217 | return out 218 | 219 | 220 | class Bottleneck(nn.Module): 221 | expansion = 4 222 | 223 | def __init__(self, in_planes, planes, stride=1): 224 | super(Bottleneck, self).__init__() 225 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 226 | self.bn1 = nn.BatchNorm2d(planes) 227 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 228 | self.bn2 = nn.BatchNorm2d(planes) 229 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 230 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 231 | 232 | self.shortcut = nn.Sequential() 233 | if stride != 1 or in_planes != self.expansion*planes: 234 | self.shortcut = nn.Sequential( 235 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 236 | nn.BatchNorm2d(self.expansion*planes) 237 | ) 238 | 239 | def forward(self, x): 240 | out = F.relu(self.bn1(self.conv1(x))) 241 | out = F.relu(self.bn2(self.conv2(out))) 242 | out = self.bn3(self.conv3(out)) 243 | out += self.shortcut(x) 244 | out = F.relu(out) 245 | return out 246 | 247 | 248 | class PreActBottleneck(nn.Module): 249 | '''Pre-activation version of the original Bottleneck module.''' 250 | expansion = 4 251 | 252 | def __init__(self, in_planes, planes, stride=1): 253 | super(PreActBottleneck, self).__init__() 254 | self.bn1 = nn.BatchNorm2d(in_planes) 255 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 256 | self.bn2 = nn.BatchNorm2d(planes) 257 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 258 
| self.bn3 = nn.BatchNorm2d(planes) 259 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 260 | 261 | self.shortcut = nn.Sequential() 262 | if stride != 1 or in_planes != self.expansion*planes: 263 | self.shortcut = nn.Sequential( 264 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 265 | ) 266 | 267 | def forward(self, x): 268 | out = F.relu(self.bn1(x)) 269 | shortcut = self.shortcut(out) 270 | out = self.conv1(out) 271 | out = self.conv2(F.relu(self.bn2(out))) 272 | out = self.conv3(F.relu(self.bn3(out))) 273 | out += shortcut 274 | return out 275 | 276 | class ResNet(nn.Module): 277 | def __init__(self, block, num_blocks, num_classes=10): 278 | super(ResNet, self).__init__() 279 | self.in_planes = 64 280 | 281 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 282 | self.bn1 = nn.BatchNorm2d(64) 283 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 284 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 285 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 286 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 287 | self.linear = nn.Linear(512*block.expansion, num_classes) 288 | 289 | def _make_layer(self, block, planes, num_blocks, stride): 290 | strides = [stride] + [1]*(num_blocks-1) 291 | layers = [] 292 | for stride in strides: 293 | layers.append(block(self.in_planes, planes, stride)) 294 | self.in_planes = planes * block.expansion 295 | return nn.Sequential(*layers) 296 | 297 | def forward(self, x): 298 | out = F.relu(self.bn1(self.conv1(x))) 299 | out = self.layer1(out) 300 | out = self.layer2(out) 301 | out = self.layer3(out) 302 | out = self.layer4(out) 303 | out = F.avg_pool2d(out, 4) 304 | out = out.view(out.size(0), -1) 305 | out = self.linear(out) 306 | return out 307 | 308 | def PreResNet18(num_classes): 309 | return ResNet(PreActBlock, [2,2,2,2],num_classes=num_classes) 310 | 311 | def ResNet18(num_classes): 312 | return ResNet(BasicBlock, [2,2,2,2],num_classes=num_classes) 313 | 314 | def ResNet34(num_classes): 315 | return ResNet(BasicBlock, [3,4,6,3],num_classes=num_classes) 316 | 317 | def ResNet50(num_classes): 318 | return ResNet(Bottleneck, [3,4,6,3],num_classes=num_classes) 319 | 320 | def ResNet101(num_classes): 321 | return ResNet(Bottleneck, [3,4,23,3],num_classes=num_classes) 322 | 323 | def ResNet152(num_classes): 324 | return ResNet(Bottleneck, [3,8,36,3],num_classes=num_classes) 325 | 326 | #################################################################################################### 327 | 328 | # MAIN TRAINING LOOP 329 | 330 | if __name__ == '__main__': 331 | 332 | # Define the list of available optimizer types 333 | optimizer_types = ["SGD", "SGD+Nesterov", "SGD+Nesterov+val_plateau", "Adam", "AdamW", "RMSProp"] 334 | 335 | parser = argparse.ArgumentParser(description='Train ResNet18 on CIFAR-100 with various optimizers and GAF.') 336 | 337 | # General training parameters 338 | parser.add_argument('--GAF', type=str2bool, default=True, help='Enable Gradient Agreement Filtering (True or False)') 339 | parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate for the optimizer') 340 | parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay factor') 341 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training') 342 | parser.add_argument('--num_val_epochs', type=int, default=2, help='Number 
of epochs between validation checks') 343 | parser.add_argument('--optimizer', type=str, default='SGD', choices=optimizer_types, help='Optimizer type to use') 344 | parser.add_argument('--num_batches_to_force_agreement', type=int, default=2, help='Number of batches to compute gradients for agreement (must be > 1)') 345 | parser.add_argument('--epochs', type=int, default=10000, help='Total number of training epochs') 346 | parser.add_argument('--num_samples_per_class_per_batch', type=int, default=1, help='Number of samples per class per batch when using GAF') 347 | parser.add_argument('--cos_distance_thresh', type=float, default=1, help='Threshold for cosine distance in gradient agreement filtering. Tau in the paper.') 348 | parser.add_argument('--dummy', type=bool, default=False, help='if we should use dummy data or not') 349 | parser.add_argument('--cifarn', type=bool, default=True, help='if we should use CIFARN labels or not') 350 | parser.add_argument('--cifarn_noisy_data_file_path', type=str, default="./examples/CIFAR-100_human.pt", help='the path to the noisy labeling file') 351 | 352 | # Optimizer-specific parameters 353 | parser.add_argument('--momentum', type=float, default=0.0, help='Momentum factor for SGD and RMSProp optimizers') 354 | parser.add_argument('--nesterov', type=str2bool, default=False, help='Use Nesterov momentum (True or False)') 355 | parser.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999), help='Betas for Adam and AdamW optimizers') 356 | parser.add_argument('--eps', type=float, default=1e-8, help='Epsilon value for optimizers') 357 | parser.add_argument('--alpha', type=float, default=0.99, help='Alpha value for RMSProp optimizer') 358 | parser.add_argument('--centered', type=str2bool, default=False, help='Centered RMSProp (True or False)') 359 | parser.add_argument('--scheduler_patience', type=int, default=100, help='Patience for ReduceLROnPlateau scheduler') 360 | parser.add_argument('--scheduler_factor', type=int, default=0.1, help='Discount factor for ReduceLROnPlateau scheduler') 361 | 362 | # logging 363 | parser.add_argument('--wandb', type=str2bool, default=False, help='Log to wandb (True or False)') 364 | parser.add_argument('--verbose', type=str2bool, default=False, help='Print out logs (True or False)') 365 | 366 | # Parse arguments 367 | args = parser.parse_args() 368 | config = vars(args) 369 | 370 | # Set unused optimizer-specific configs to 'NA' 371 | optimizer = config['optimizer'] 372 | all_params = ['momentum', 'nesterov', 'betas', 'eps', 'alpha', 'centered', 'scheduler_patience', 'scheduler_factor'] 373 | 374 | # Define which parameters are used by each optimizer 375 | optimizer_params = { 376 | 'SGD': ['momentum', 'nesterov'], 377 | 'SGD+Nesterov': ['momentum', 'nesterov'], 378 | 'SGD+Nesterov+val_plateau': ['momentum', 'nesterov', 'scheduler_patience', 'scheduler_factor'], 379 | 'Adam': ['betas', 'eps'], 380 | 'AdamW': ['betas', 'eps'], 381 | 'RMSProp': ['momentum', 'eps', 'alpha', 'centered'], 382 | } 383 | 384 | # Get the list of parameters used by the selected optimizer 385 | used_params = optimizer_params.get(optimizer, []) 386 | 387 | # Set unused parameters to 'NA' 388 | for param in all_params: 389 | if param not in used_params: 390 | config[param] = 'NA' 391 | 392 | # Check for available device (GPU or CPU) 393 | if torch.cuda.is_available(): 394 | num_gpus = torch.cuda.device_count() 395 | device_index = random.randint(0, num_gpus - 1) # Pick a random device index 396 | device = torch.device(f"cuda:{device_index}") 
397 | elif torch.mps.is_available(): 398 | device = torch.device("mps") 399 | else: 400 | device = torch.device("cpu") 401 | 402 | if config['verbose']: 403 | print(f"Using device: {device}") 404 | 405 | # Set random seeds for reproducibility 406 | torch.manual_seed(0) 407 | np.random.seed(0) 408 | random.seed(0) 409 | 410 | #### 411 | 412 | 413 | # Data transformations 414 | transform_train = transforms.Compose([ 415 | transforms.RandomCrop(32, padding=4), 416 | transforms.RandomHorizontalFlip(), 417 | transforms.ToTensor(), 418 | transforms.Normalize((0.5071, 0.4867, 0.4408), 419 | (0.2675, 0.2565, 0.2761)), 420 | ]) 421 | 422 | transform_test = transforms.Compose([ 423 | transforms.ToTensor(), 424 | transforms.Normalize((0.5071, 0.4867, 0.4408), 425 | (0.2675, 0.2565, 0.2761)), 426 | ]) 427 | 428 | 429 | if config['dummy']==True: 430 | print('using dummy data') 431 | train_dataset = datasets.FakeData( 432 | size=50000, # Match CIFAR-100 training set size 433 | image_size=(3, 32, 32), # Match CIFAR-100 image size 434 | num_classes=100, 435 | transform=transform_train, # Use the same training transforms 436 | random_offset=random.randint(0, 1000000) 437 | ) 438 | # Create fake test data 439 | test_dataset = datasets.FakeData( 440 | size=10000, # Match CIFAR-100 test set size 441 | image_size=(3, 32, 32), 442 | num_classes=100, 443 | transform=transform_test, # Use the same test transforms 444 | random_offset=random.randint(0, 1000000) 445 | ) 446 | 447 | 448 | elif config['cifarn']==True: 449 | print('using CIFAR100N data') 450 | train_dataset = datasets.CIFAR100( 451 | root='./data', 452 | train=True, 453 | download=True, 454 | transform=transform_train 455 | ) 456 | 457 | # Load the noisy labels 458 | dd = torch.load(config['cifarn_noisy_data_file_path']) # this should be part of the repo 459 | noisy_label = dd['noisy_label'] 460 | # Replace the training labels with noisy labels 461 | train_dataset.targets = noisy_label 462 | # Load CIFAR-100 test dataset (clean labels) 463 | test_dataset = datasets.CIFAR100( 464 | root='./data', 465 | train=False, 466 | download=True, 467 | transform=transform_test 468 | ) 469 | 470 | else: 471 | print('using CIFAR100 data') 472 | train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) 473 | test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test) 474 | 475 | 476 | # Create a mapping from class labels to indices for sampling 477 | class_indices = defaultdict(list) 478 | for idx, (_, label) in enumerate(train_dataset): 479 | class_indices[label].append(idx) 480 | 481 | # Initialize the model (ResNet34) and move it to the device 482 | model = ResNet34(num_classes=100) 483 | model = model.to(device) 484 | 485 | # Define the loss function (CrossEntropyLoss) 486 | criterion = nn.CrossEntropyLoss() 487 | 488 | # Initialize the optimizer based on the selected type and parameters 489 | if config['optimizer'] == 'SGD': 490 | optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], 491 | momentum=config['momentum'], 492 | weight_decay=config['weight_decay'], 493 | nesterov=config['nesterov']) 494 | elif config['optimizer'] == 'SGD+Nesterov': 495 | optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], 496 | momentum=config['momentum'], 497 | weight_decay=config['weight_decay'], 498 | nesterov=True) 499 | elif config['optimizer'] == 'SGD+Nesterov+val_plateau': 500 | optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], 501 | 
momentum=config['momentum'], 502 | weight_decay=config['weight_decay'], 503 | nesterov=True) 504 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=config['scheduler_patience'], factor=config['scheduler_factor']) 505 | elif config['optimizer'] == 'Adam': 506 | optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], 507 | betas=tuple(config['betas']), 508 | eps=config['eps'], 509 | weight_decay=config['weight_decay']) 510 | elif config['optimizer'] == 'AdamW': 511 | optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'], 512 | betas=tuple(config['betas']), 513 | eps=config['eps'], 514 | weight_decay=config['weight_decay']) 515 | elif config['optimizer'] == 'RMSProp': 516 | optimizer = optim.RMSprop(model.parameters(), lr=config['learning_rate'], 517 | alpha=config['alpha'], 518 | eps=config['eps'], 519 | weight_decay=config['weight_decay'], 520 | momentum=config['momentum'], 521 | centered=config['centered']) 522 | else: 523 | raise ValueError(f"Unsupported optimizer type: {config['optimizer']}") 524 | 525 | # Test DataLoader 526 | test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2) 527 | 528 | if config['wandb']: 529 | # Raise an error if wandb is not installed 530 | if wandb is None: 531 | raise ImportError("wandb is not installed. Please install it using 'pip install wandb'.") 532 | 533 | # Set up WandB project and run names 534 | model_name = 'ResNet34' 535 | dataset_name = 'CIFAR100N-Fine' 536 | project_name = f"{model_name}_{dataset_name}" 537 | 538 | # Initialize WandB 539 | wandb.init(project=project_name, config=config) 540 | config = wandb.config 541 | 542 | # Create checkpoints directory 543 | checkpoint_dir = './checkpoints/' 544 | os.makedirs(checkpoint_dir, exist_ok=True) 545 | 546 | start_timestamp = time.strftime("%Y%m%d_%H%M%S") 547 | 548 | iteration = 0 549 | try: 550 | # Training loop 551 | for epoch in range(config['epochs']): 552 | model.train() 553 | running_loss = 0.0 554 | correct_top1 = 0 555 | total = 0 556 | i = 0 557 | 558 | # Calculate total iterations per epoch 559 | if config['GAF']: 560 | iterations_per_epoch = len(train_dataset) // (len(class_indices) * config['num_samples_per_class_per_batch']) 561 | else: 562 | iterations_per_epoch = len(train_dataset) // config['batch_size'] 563 | 564 | while i < iterations_per_epoch: 565 | i += 1 566 | if config['GAF']: 567 | # Sample microbatches for GAF 568 | mbs = sample_iid_mbs_for_GAF(train_dataset, class_indices, config['num_batches_to_force_agreement'], config['num_samples_per_class_per_batch']) 569 | # Run GAF to update the model 570 | result = step_GAF(model, optimizer, criterion, mbs, use_wandb=config['wandb'], verbose=config['verbose'], device=device) 571 | 572 | # Update metrics 573 | running_loss += result['train_loss'] / (len(class_indices) * config['num_batches_to_force_agreement']) 574 | total += 1 575 | correct_top1 += result['train_accuracy'] 576 | 577 | if config['verbose']: 578 | print(f'Epoch: {epoch:7d}, Iteration: {iteration:7d}, Train Loss: {result["train_loss"]:.9f}, Train Acc: {result["train_accuracy"]:.4f}, Costine Distance: {result["cosine_distance"]:.4f}, Agreement Count: {result["agreed_count"]:d}') 579 | 580 | # Log to wandb 581 | if config['wandb']: 582 | wandb.log(result) 583 | 584 | else: 585 | # Sample a minibatch for standard training 586 | batch = sample_iid_mbs(train_dataset, class_indices, config['batch_size']) 587 | loader = DataLoader(batch, batch_size=len(batch), 
shuffle=False) 588 | data = next(iter(loader)) 589 | images, labels = data[0].to(device), data[1].to(device) 590 | # Forward and backward passes 591 | optimizer.zero_grad() 592 | outputs = model(images) 593 | loss = criterion(outputs, labels) 594 | loss.backward() 595 | optimizer.step() 596 | 597 | # Update metrics 598 | running_loss += loss.item() * images.size(0) 599 | _, predicted = torch.max(outputs.data, 1) 600 | total += labels.size(0) 601 | correct_top1 += (predicted == labels).sum().item() 602 | 603 | # print for baseline 604 | message = {'train_loss': loss.item(), 605 | 'train_accuracy': (predicted == labels).sum().item() / labels.size(0), 606 | 'iteration': iteration} 607 | 608 | # Log metrics to wandb 609 | if config['wandb']: 610 | try: 611 | wandb.log(message) 612 | except Exception as e: 613 | print(f"Failed to log to wandb: {e}") 614 | if config['verbose']: 615 | # Print formatted metrics 616 | print(f'Epoch: {epoch:7d}, Iteration: {iteration:7d}, Train Loss: {message["train_loss"]:.9f}, Train Acc: {message["train_accuracy"]:.4f}') 617 | 618 | iteration += 1 619 | 620 | # Perform validation every num_val_epochs iterations 621 | if epoch % config['num_val_epochs'] == 0 and total > 0: 622 | # Compute training metrics 623 | train_loss = running_loss / total 624 | train_accuracy = correct_top1 / total 625 | 626 | # Evaluate on the validation/test set 627 | val_loss, val_accuracy = evaluate(model, test_loader, device) 628 | message = {'train_loss': train_loss, 629 | 'train_accuracy': train_accuracy, 630 | 'val_loss': val_loss, 631 | 'val_accuracy': val_accuracy, 632 | 'epoch': epoch, 633 | 'iteration': iteration} 634 | # Log metrics to wandb 635 | if config['wandb']: 636 | try: 637 | wandb.log(message) 638 | except Exception as e: 639 | print(f"Failed to log to wandb: {e}") 640 | 641 | # Print formatted metrics 642 | print(f'Epoch: {epoch:7d}, Train Loss: {train_loss:.9f}, Train Acc: {train_accuracy:.4f}, Val Loss: {val_loss:.9f}, Val Acc: {val_accuracy:.4f}') 643 | 644 | # Reset running metrics 645 | running_loss = 0.0 646 | correct_top1 = 0 647 | total = 0 648 | # Save the latest checkpoint 649 | timestamp = time.strftime("%Y%m%d_%H%M%S") 650 | checkpoint_name = f"cifar_100n_{start_timestamp}_checkpoint_{timestamp}.pt" 651 | 652 | checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) 653 | try: 654 | torch.save(model.state_dict(), checkpoint_path) 655 | except Exception as e: 656 | print(f"Failed to save checkpoint: {e}") 657 | 658 | # Adjust learning rate if scheduler is used 659 | if config['optimizer'] == 'SGD+Nesterov+val_plateau': 660 | scheduler.step(val_loss) 661 | 662 | except KeyboardInterrupt: 663 | print("Training interrupted. 
Saving checkpoint...") 664 | timestamp = time.strftime("%Y%m%d_%H%M%S") 665 | checkpoint_name = f"cifar_100n_{start_timestamp}_interrupt_{timestamp}.pt" 666 | checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) 667 | try: 668 | torch.save(model.state_dict(), checkpoint_path) 669 | print(f"Checkpoint saved at {checkpoint_path}") 670 | except Exception as e: 671 | print(f"Failed to save checkpoint: {e}") 672 | exit(0) 673 | -------------------------------------------------------------------------------- /examples/CIFAR-100_human.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fchaubard/gradient_agreement_filtering/de0fd3d0604ead1471edf855fad8344caae3e0c0/examples/CIFAR-100_human.pt -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "gradient_agreement_filtering" 3 | version = "0.1.3" 4 | description = "Gradient Agreement Filtering (GAF) Package" 5 | authors = [ 6 | { name="Francois Chaubard", email="fchaubard@gmail.com" } 7 | ] 8 | license = { file = "LICENSE" } 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | dependencies = [ 12 | "torch>=2.0.0", 13 | "torchvision>=0.16.0", 14 | "numpy>=1.2.0", 15 | ] 16 | 17 | # Add dev dependencies 18 | [project.optional-dependencies] 19 | dev = [ 20 | "pytest>=8.3.4", 21 | "pytest-cov>=6.0.0", 22 | "wandb>=0.19.0", 23 | ] 24 | 25 | all = [ 26 | "wandb>=0.19.0", 27 | ] 28 | 29 | [tool.setuptools] 30 | py-modules = ['gradient_agreement_filtering'] 31 | 32 | [tool.setuptools.packages.find] 33 | where = ["src"] 34 | 35 | [tool.setuptools.package-data] 36 | "*" = ["*"] 37 | 38 | [project.urls] 39 | Homepage = "https://github.com/Fchaubard/gradient_agreement_filtering" 40 | -------------------------------------------------------------------------------- /src/gradient_agreement_filtering/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # Import major function directly in init so that they can be loaded directly from the package 3 | from .gaf import ( 4 | step_GAF, 5 | train_GAF, 6 | ) -------------------------------------------------------------------------------- /src/gradient_agreement_filtering/gaf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of GAF in two methods. 3 | 4 | Method 1: step_GAF() 5 | Usage: 6 | To be used inside a train loop in lieu of optimizer.step() 7 | 8 | Method 2: train_GAF() 9 | Usage: 10 | To be used inside as the train loop in lieu of huggingface's trainer.train() 11 | 12 | For more information, see documentation below. 13 | 14 | Author: 15 | Francois Chaubard 16 | 17 | Date: 18 | 2024-12-03 19 | """ 20 | 21 | import torch 22 | from torch.utils.data import DataLoader, Subset 23 | import random 24 | 25 | def _filter_gradients_cosine_sim(G1, G2, cos_distance_thresh): 26 | """ 27 | Filters gradients based on cosine similarity. 28 | 29 | Args: 30 | G1 (list[torch.Tensor]): List of parameter gradients from the first microbatch. 31 | G2 (list[torch.Tensor]): List of parameter gradients from the second microbatch. 32 | cos_distance_thresh (float): Threshold for cosine distance. 33 | 34 | Returns: 35 | tuple: (filtered_grad, cos_distance) 36 | filtered_grad (list[torch.Tensor] or None): The averaged gradients if cosine distance 37 | is below threshold, otherwise None. 
38 | cos_distance (float): The computed cosine distance between G1 and G2. 39 | """ 40 | # Flatten G1 and G2 into vectors 41 | G1_flat = torch.cat([g1.view(-1) for g1 in G1]) 42 | G2_flat = torch.cat([g2.view(-1) for g2 in G2]) 43 | 44 | # Compute cosine similarity 45 | cos_sim = torch.nn.functional.cosine_similarity(G1_flat, G2_flat, dim=0) 46 | 47 | # Compute cosine distance 48 | cos_distance = 1 - cos_sim 49 | 50 | if cos_distance > cos_distance_thresh: 51 | filtered_grad = None 52 | else: 53 | filtered_grad = [(g1 + g2) / 2 for g1, g2 in zip(G1, G2)] 54 | 55 | return filtered_grad, cos_distance.item() 56 | 57 | def _compute_gradients(b, optimizer, model, criterion, device): 58 | """ 59 | Computes gradients for a given microbatch. 60 | 61 | Args: 62 | b (Subset): The microbatch dataset. 63 | optimizer (torch.optim.Optimizer): The optimizer. 64 | model (torch.nn.Module): The model. 65 | criterion (torch.nn.Module): The loss function. 66 | device (torch.device): The device to use. 67 | 68 | Returns: 69 | tuple: (G, loss, labels, outputs) 70 | G (list[torch.Tensor]): List of gradients for each model parameter. 71 | loss (torch.Tensor): Computed loss for the batch. 72 | labels (torch.Tensor): Labels for the batch. 73 | outputs (torch.Tensor): Model outputs for the batch. 74 | """ 75 | loader = DataLoader(b, batch_size=len(b), shuffle=False) 76 | data = next(iter(loader)) 77 | images, labels = data[0].to(device), data[1].to(device) 78 | 79 | optimizer.zero_grad() 80 | outputs = model(images) 81 | loss = criterion(outputs, labels) 82 | loss.backward() 83 | 84 | G = [p.grad.clone() for p in model.parameters()] 85 | optimizer.zero_grad() 86 | return G, loss, labels, outputs 87 | 88 | def step_GAF(model, 89 | optimizer, 90 | criterion, 91 | list_of_microbatches, 92 | verbose=True, 93 | cos_distance_thresh=1.0, 94 | device=torch.device('cpu')): 95 | """ 96 | Performs one Gradient Agreement Filtering (GAF) step given a list of microbatches. 97 | 98 | Args: 99 | model (torch.nn.Module): The model to train. 100 | optimizer (torch.optim.Optimizer): The optimizer. 101 | criterion (torch.nn.Module): The loss function. 102 | list_of_microbatches (list[Subset]): A list of data subsets (microbatches) for GAF. 103 | verbose (bool): Whether to print debug information. 104 | cos_distance_thresh (float): Cosine distance threshold for filtering. This is \tau in the paper. Must be between 0 and 2. We recommend 0.9 to 1 for an HPP sweep. 105 | device (torch.device): Device on which to perform computation. TODO: You may want to distribute this across GPUs which we may implement later to be in parellel. 106 | 107 | Returns: 108 | dict: A dictionary containing loss, cosine_distance, and agreed_count. 
109 | """ 110 | model.train() 111 | total_loss = 0 112 | total_correct = 0 113 | 114 | # Compute gradients on the first microbatch 115 | G_current, loss, labels, outputs = _compute_gradients(list_of_microbatches[0], optimizer, model, criterion, device) 116 | 117 | # update total_loss 118 | total_loss += loss * labels.size(0) 119 | 120 | # update @1 accuracy 121 | _, predicted = torch.max(outputs.data, 1) 122 | correct_top1 = (predicted == labels).sum().item() 123 | total_correct += correct_top1 124 | 125 | # Fuse gradients with subsequent microbatches 126 | agreed_count = 0 127 | 128 | for i, mb in enumerate(list_of_microbatches[1:]): 129 | # compute microgradients and filter them based on cosine distance: 130 | G, loss_i, labels_i, outputs_i = _compute_gradients(mb, optimizer, model, criterion, device) 131 | G_current_temp, cosine_distance = _filter_gradients_cosine_sim(G_current, G, cos_distance_thresh) 132 | 133 | # update total_loss 134 | total_loss += loss_i * labels_i.size(0) 135 | 136 | # update @1 accuracy 137 | _, predicted = torch.max(outputs_i.data, 1) 138 | correct_top1 = (predicted == labels_i).sum().item() 139 | total_correct += correct_top1 140 | 141 | if verbose: 142 | print(f"Gradient fusion iteration {i+1}/{len(list_of_microbatches)-1}, Cosine Distance: {cosine_distance:.4f}") 143 | 144 | if G_current_temp is not None: 145 | G_current = G_current_temp 146 | agreed_count += 1 147 | 148 | # If at least one agreed, update params 149 | if agreed_count > 0: 150 | with torch.no_grad(): 151 | for param, grad in zip(model.parameters(), G_current): 152 | param.grad = grad 153 | optimizer.step() 154 | 155 | # Compute metrics 156 | total = labels.size(0) * len(list_of_microbatches) 157 | result = {'train_loss': total_loss.item(), 'cosine_distance': cosine_distance, 'agreed_count':agreed_count, 'train_accuracy':total_correct/total } 158 | 159 | return result 160 | 161 | def train_GAF(model, 162 | args, 163 | train_dataset, 164 | val_dataset, 165 | optimizer, 166 | criterion, 167 | use_wandb=True, 168 | verbose=True, 169 | cos_distance_thresh=1.0, 170 | device=torch.device('cpu')): 171 | """ 172 | Trains the model using Gradient Agreement Filtering (GAF) across multiple epochs. 173 | This mimics the HuggingFace Trainer interface. 174 | 175 | Args: 176 | model (torch.nn.Module): The model to train. 177 | args (object): A simple namespace or dictionary with training arguments. 178 | Expected fields: 179 | - args.epochs (int) 180 | - args.batch_size (int) 181 | - args.num_batches_to_force_agreement (int) 182 | train_dataset (torch.utils.data.Dataset): Training dataset. 183 | val_dataset (torch.utils.data.Dataset): Validation dataset. 184 | optimizer (torch.optim.Optimizer): Optimizer to be used. 185 | criterion (torch.nn): Loss function to be used. 186 | use_wandb (bool): Whether to log metrics to Weights & Biases. 187 | verbose (bool): Whether to print progress. 188 | cos_distance_thresh (float): Cosine distance threshold for GAF filtering. This is \tau in the paper. Must be between 0 and 2. We recommend 0.9 to 1 for an HPP sweep. 189 | device (torch.device): Device on which to train on. TODO: Make this an array in future iters size at most num_batches_to_force_agreement to run them in parallel. 190 | 191 | Returns: 192 | None 193 | """ 194 | model.to(device) 195 | 196 | # Check if wandb is imported 197 | if use_wandb: 198 | try: 199 | import wandb 200 | except ImportError: 201 | raise ImportError("Please install wandb to use use_wandb option. 
You can install it using 'pip install wandb'.")
202 | 
203 |     # Create dataloaders
204 |     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
205 |     val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
206 | 
207 |     # A simple function to sample multiple microbatches for GAF
208 |     def sample_iid_mbs_for_GAF(full_dataset, batch_size, n):
209 |         # For simplicity: just take n random subsets of size batch_size
210 |         indices = list(range(len(full_dataset)))
211 |         random.shuffle(indices)
212 |         # If not enough data for n batches, just cycle
213 |         if len(indices) < batch_size * n:
214 |             multiples = (batch_size * n) // len(indices) + 1
215 |             indices = (indices * multiples)[:batch_size * n]
216 |         batches = []
217 |         for i in range(n):
218 |             subset_indices = indices[i*batch_size:(i+1)*batch_size]
219 |             batches.append(Subset(full_dataset, subset_indices))
220 |         return batches
221 | 
222 |     # Training loop
223 |     for epoch in range(args.epochs):
224 |         model.train()
225 |         running_loss = 0.0
226 |         count = 0
227 | 
228 |         steps_per_epoch = len(train_dataset) // (args.batch_size * args.num_batches_to_force_agreement)
229 | 
230 |         for step in range(steps_per_epoch):
231 |             # Sample microbatches
232 |             mbs = sample_iid_mbs_for_GAF(train_dataset, args.batch_size, args.num_batches_to_force_agreement)
233 |             result = step_GAF(model,
234 |                               optimizer,
235 |                               criterion,
236 |                               list_of_microbatches=mbs,
237 |                               verbose=verbose,
238 |                               cos_distance_thresh=cos_distance_thresh,
239 |                               device=device)
240 |             running_loss += result['train_loss']
241 |             count += 1
242 | 
243 |             if verbose:
244 |                 print(f'Epoch: {epoch:7d}, Iteration: {step:7d}, Train Loss: {result["train_loss"]:.9f}, Train Acc: {result["train_accuracy"]:.4f}, Cosine Distance: {result["cosine_distance"]:.4f}, Agreement Count: {result["agreed_count"]:d}')
245 | 
246 |             # Log to wandb
247 |             if use_wandb:
248 |                 wandb.log(result)
249 | 
250 |         # Validation step
251 |         model.eval()
252 |         val_loss = 0.0
253 |         val_count = 0
254 |         all_preds = []
255 |         all_labels = []
256 |         with torch.no_grad():
257 |             for batch in val_loader:
258 |                 images, labels = batch[0].to(device), batch[1].to(device)
259 |                 outputs = model(images)
260 |                 loss = criterion(outputs, labels)
261 |                 val_loss += loss.item()
262 |                 val_count += 1
263 |                 _, preds = torch.max(outputs, dim=1)
264 |                 all_preds.extend(preds.cpu())
265 |                 all_labels.extend(labels.cpu())
266 | 
267 |         val_loss /= max(val_count, 1)
268 | 
269 |         val_accuracy = sum([pred == label for pred, label in zip(all_preds, all_labels)]) / len(all_preds)
270 |         message = {'epoch': epoch+1, 'train_loss': running_loss/max(count,1), 'val_loss': val_loss, 'train_accuracy': result['train_accuracy'], 'val_accuracy': val_accuracy.item()}
271 | 
272 | 
273 |         if verbose:
274 |             print(f'Epoch: {epoch:7d}, Train Loss: {message["train_loss"]:.9f}, Train Acc: {message["train_accuracy"]:.4f}, Val Loss: {message["val_loss"]:.9f}, Val Acc: {message["val_accuracy"]:.4f}')
275 |         # log to wandb
276 |         if use_wandb:
277 |             try:
278 |                 wandb.log(message)
279 |             except Exception as e:
280 |                 print(f"Failed to log to wandb: {e}")
281 | 
282 |     return None
283 | 
--------------------------------------------------------------------------------
/test/run_sweeps.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | WANDB_API_KEY=
4 | 
5 | # Example run:
6 | #WANDB_API_KEY= python ../examples/1_cifar_100_train_loop_exposed.py --GAF True --optimizer "SGD+Nesterov+val_plateau" --learning_rate 0.01 --momentum 0.9
--nesterov True --wandb True --verbose True --num_samples_per_class_per_batch 2 --num_batches_to_force_agreement 2 --label_error_percentage 0.15 --cos_distance_thresh 0.97 7 | 8 | # Kill existing screen sessions that start with "gaf_" 9 | screen -ls | grep '\.gaf_' | awk '{print $1}' | xargs -I{} screen -S {} -X quit 10 | 11 | # HPP SWEEP 12 | num_samples_per_class_per_batch=(1 2 4) 13 | num_batches_to_force_agreement=(2 4 10) 14 | label_error_percentage=(0 .1 .5) 15 | cos_distance_thresh=(.95 .97 .99 1.0 1.01 1.05 2.0) 16 | 17 | 18 | # Non-GAF runs 19 | for a in "${label_error_percentage[@]}"; do 20 | sleep 30 # this allows the GPU mem to fill up and to download any files needed 21 | screen_name="gaf_baseline_${a}" 22 | screen -dmS "$screen_name" bash -c "export WANDB_API_KEY=${WANDB_API_KEY}; \ 23 | python ../examples/1_cifar_100_train_loop_exposed.py --GAF False \ 24 | --optimizer \"SGD+Nesterov+val_plateau\" \ 25 | --learning_rate 0.01 --momentum 0.9 --nesterov True \ 26 | --wandb True --verbose True \ 27 | --label_error_percentage ${a}" 28 | done 29 | 30 | # GAF runs 31 | for a in "${num_samples_per_class_per_batch[@]}"; do 32 | for b in "${num_batches_to_force_agreement[@]}"; do 33 | for c in "${label_error_percentage[@]}"; do 34 | for d in "${cos_distance_thresh[@]}"; do 35 | sleep 0.1 36 | screen_name="gaf_${a}_${b}_${c}_${d}" 37 | screen -dmS "$screen_name" bash -c "export WANDB_API_KEY=${WANDB_API_KEY}; \ 38 | python ../examples/1_cifar_100_train_loop_exposed.py \ 39 | --GAF True --optimizer \"SGD+Nesterov+val_plateau\" \ 40 | --learning_rate 0.01 --momentum 0.9 --nesterov True \ 41 | --wandb True --verbose True \ 42 | --num_samples_per_class_per_batch ${a} \ 43 | --num_batches_to_force_agreement ${b} \ 44 | --label_error_percentage ${c} \ 45 | --cos_distance_thresh ${d}"\ 46 | 47 | done 48 | done 49 | done 50 | done 51 | 52 | -------------------------------------------------------------------------------- /test/test_gaf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from torch.utils.data import TensorDataset, DataLoader, Subset 4 | import pytest 5 | from gradient_agreement_filtering.gaf import _filter_gradients_cosine_sim, _compute_gradients 6 | 7 | @pytest.fixture 8 | def setup(): 9 | # Simple model 10 | model = nn.Linear(10, 2) 11 | criterion = nn.CrossEntropyLoss() 12 | optimizer = optim.SGD(model.parameters(), lr=0.1) 13 | device = torch.device('cpu') 14 | # Simple dataset 15 | X = torch.randn(20, 10) 16 | y = torch.randint(0, 2, (20,)) 17 | dataset = TensorDataset(X, y) 18 | return model, criterion, optimizer, device, dataset 19 | 20 | def test_filter_gradients_cosine_sim(setup): 21 | model, _, _, _, _ = setup 22 | # Create dummy gradients 23 | G1 = [torch.randn(p.shape) for p in model.parameters()] 24 | G2 = [torch.randn(p.shape) for p in model.parameters()] 25 | 26 | filtered_grad, cos_dist = _filter_gradients_cosine_sim(G1, G2, cos_distance_thresh=2.0) 27 | # With a large threshold, we expect some averaging to happen 28 | assert filtered_grad is not None 29 | assert isinstance(cos_dist, float) 30 | 31 | # With a very small threshold, likely no agreement 32 | filtered_grad_none, _ = _filter_gradients_cosine_sim(G1, G2, cos_distance_thresh=0.0) 33 | # Most likely none since random gradients are unlikely to match exactly 34 | assert filtered_grad_none is None 35 | 36 | def test_compute_gradients(setup): 37 | model, criterion, optimizer, device, dataset = setup 38 | subset_indices = 
list(range(4)) 39 | b = Subset(dataset, subset_indices) 40 | G, loss, labels, outputs = _compute_gradients(b, optimizer, model, criterion, device) 41 | assert isinstance(G, list) 42 | assert isinstance(loss, torch.Tensor) 43 | assert isinstance(labels, torch.Tensor) 44 | assert isinstance(outputs, torch.Tensor) 45 | assert len(G) == len(list(model.parameters())) 46 | --------------------------------------------------------------------------------
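
The threshold semantics that test_filter_gradients_cosine_sim exercises with random gradients can also be seen on hand-picked numbers. The short sketch below is not part of the repository; it imports the same private helper the test file uses and builds two orthogonal "gradients" whose cosine distance is exactly 1.0, so a tau of 0.97 rejects the pair while the default tau of 1.0 accepts and averages it.

import torch
from gradient_agreement_filtering.gaf import _filter_gradients_cosine_sim

# Two single-parameter "gradients" that are orthogonal: cosine similarity 0,
# hence cosine distance 1 - 0 = 1.0.
G1 = [torch.tensor([1.0, 0.0])]
G2 = [torch.tensor([0.0, 1.0])]

# tau = 0.97: distance 1.0 > 0.97, so the pair is filtered out and no update would be taken.
filtered, dist = _filter_gradients_cosine_sim(G1, G2, cos_distance_thresh=0.97)
print(dist, filtered)   # 1.0 None

# tau = 1.0 (the default): distance 1.0 <= 1.0, so the gradients are averaged.
filtered, dist = _filter_gradients_cosine_sim(G1, G2, cos_distance_thresh=1.0)
print(dist, filtered)   # 1.0 [tensor([0.5000, 0.5000])]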
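Stepping up a level, step_GAF is easiest to see away from the CIFAR-100 plumbing in the example scripts. The following sketch is likewise not part of the repository: the toy model, the random tensors, and the batch sizes are illustrative assumptions. It shows step_GAF used the way the gaf.py docstring describes, inside a hand-written loop in lieu of optimizer.step().

import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, Subset

from gradient_agreement_filtering import step_GAF

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-in for a real dataset: 512 random "images" with random labels.
X = torch.randn(512, 3, 32, 32)
y = torch.randint(0, 100, (512,))
dataset = TensorDataset(X, y)

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 100)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

batch_size = 64
num_batches_to_force_agreement = 2  # needs at least 2 microbatches to compare gradients

for step in range(10):
    # Build the microbatches whose gradients must agree before a step is taken.
    perm = torch.randperm(len(dataset)).tolist()
    mbs = [Subset(dataset, perm[i * batch_size:(i + 1) * batch_size])
           for i in range(num_batches_to_force_agreement)]

    # step_GAF runs its own forward/backward passes and only calls optimizer.step()
    # when the cosine distance between the microbatch gradients stays at or below
    # cos_distance_thresh (tau).
    result = step_GAF(model, optimizer, criterion, mbs,
                      verbose=False, cos_distance_thresh=0.97, device=device)

    print(f"step {step}: loss={result['train_loss']:.4f} "
          f"cos_dist={result['cosine_distance']:.4f} agreed={result['agreed_count']}")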
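For the trainer-style interface, a similar sketch (again not part of the repository; the namespace fields mirror the documented args.epochs, args.batch_size, and args.num_batches_to_force_agreement, and everything else is an illustrative assumption) hands the whole loop to train_GAF in lieu of trainer.train().

from types import SimpleNamespace

import torch
from torch import nn, optim
from torch.utils.data import TensorDataset

from gradient_agreement_filtering import train_GAF

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 100))
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

# Random tensors standing in for train/val splits.
train_dataset = TensorDataset(torch.randn(512, 3, 32, 32), torch.randint(0, 100, (512,)))
val_dataset = TensorDataset(torch.randn(128, 3, 32, 32), torch.randint(0, 100, (128,)))

# Only the fields train_GAF reads are supplied here.
args = SimpleNamespace(epochs=2, batch_size=64, num_batches_to_force_agreement=2)

train_GAF(model, args, train_dataset, val_dataset, optimizer, criterion,
          use_wandb=False,   # keep the sketch dependency-free
          verbose=True,
          cos_distance_thresh=0.97,
          device=device)

With use_wandb=True, wandb must be installed, and the per-step GAF metrics and per-epoch train/val metrics are logged to Weights & Biases instead of only being printed.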